source file: /opt/devel/celery/testproj/../celery/tests/test_task.py
file stats: 143 lines, 143 executed: 100.0% covered
import unittest
import uuid
import logging
from StringIO import StringIO

from celery import task
from celery import registry
from celery.log import setup_logger
from celery import messaging


# Task run functions can't be closures/lambdas, as they're pickled.
def return_True(self, **kwargs):
    """Task ``run`` replacement that ignores its arguments and returns True."""
    return True


def raise_exception(self, **kwargs):
    """Task ``run`` replacement that always raises, naming the task class."""
    raise Exception("%s error" % self.__class__)


class IncrementCounterTask(task.Task):
    """Test task that adds ``increment_by`` to a class-level counter.

    ``count`` lives on the class, so every instance shares it; tests that
    check its value must reset it first.
    """
    name = "c.unittest.increment_counter_task"
    count = 0

    def run(self, increment_by=None, **kwargs):
        """Add ``increment_by`` (default 1) to ``IncrementCounterTask.count``.

        :param increment_by: Amount to add. ``None`` or missing counts as 1
            (the taskset test sends one subtask with no arguments); an
            explicit ``0`` adds nothing.
        """
        if increment_by is None:
            increment_by = 1
        self.__class__.count += increment_by


class TestCeleryTasks(unittest.TestCase):
    """Tests for task classes, task dispatch and task-metadata helpers."""

    def createTaskCls(self, cls_name, task_name=None):
        """Build a ``task.Task`` subclass whose ``run`` returns True.

        :param cls_name: Name given to the generated class.
        :param task_name: Task name to set as the ``name`` attribute; when
            falsy the class is created without a ``name``.
        :returns: The new class (not an instance).
        """
        attrs = {}
        if task_name:
            attrs["name"] = task_name
        cls = type(cls_name, (task.Task, ), attrs)
        cls.run = return_True
        return cls

    def assertNextTaskDataEquals(self, consumer, task_id, task_name,
            **kwargs):
        """Fetch the next message from ``consumer`` and check its payload.

        Asserts that a message is available. Then asserts that its decoded
        body has ``task_id`` under ``celeryID`` and ``task_name`` under
        ``celeryTASK``. Each extra keyword must appear as a key/value pair.
        """
        next_task = consumer.fetch()
        # Fail with a readable message rather than an AttributeError on None.
        self.assertTrue(next_task is not None,
                        "Expected a task message, but the queue was empty")
        task_data = consumer.decoder(next_task.body)
        self.assertEqual(task_data["celeryID"], task_id)
        self.assertEqual(task_data["celeryTASK"], task_name)
        for arg_name, arg_value in kwargs.items():
            self.assertEqual(task_data.get(arg_name), arg_value)

    def test_raising_task(self):
        rtask = self.createTaskCls("RaisingTask", "c.unittest.t.rtask")
        rtask.run = raise_exception
        sio = StringIO()

        taskinstance = rtask()
        taskinstance(loglevel=logging.INFO, logfile=sio)
        self.assertTrue("Task got exception" in sio.getvalue())

    def test_incomplete_task_cls(self):
        class IncompleteTask(task.Task):
            name = "c.unittest.t.itask"

        self.assertRaises(NotImplementedError, IncompleteTask().run)

    def test_regular_task(self):
        T1 = self.createTaskCls("T1", "c.unittest.t.t1")
        self.assertTrue(isinstance(T1(), T1))
        self.assertTrue(T1().run())
        self.assertTrue(callable(T1()),
                "Task class is callable()")
        self.assertTrue(T1()(),
                "Task class runs run() when called")

        # task without name raises NotImplementedError
        T2 = self.createTaskCls("T2")
        self.assertRaises(NotImplementedError, T2)

        registry.tasks.register(T1)
        t1 = T1()
        consumer = t1.get_consumer()
        self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
        consumer.discard_all()
        self.assertTrue(consumer.fetch() is None)

        # Without arguments.
        tid = t1.delay()
        self.assertNextTaskDataEquals(consumer, tid, t1.name)

        # With arguments.
        tid2 = task.delay_task(t1.name, name="George Constanza")
        self.assertNextTaskDataEquals(consumer, tid2, t1.name,
                name="George Constanza")

        self.assertRaises(registry.tasks.NotRegistered, task.delay_task,
                "some.task.that.should.never.exist.X.X.X.X.X")

        # Discarding all tasks: exactly the one message queued here is purged.
        task.discard_all()
        task.delay_task(t1.name)
        self.assertEqual(task.discard_all(), 1)
        self.assertTrue(consumer.fetch() is None)

        self.assertFalse(task.is_done(tid))
        task.mark_as_done(tid, result=None)
        self.assertTrue(task.is_done(tid))

        publisher = t1.get_publisher()
        self.assertTrue(isinstance(publisher, messaging.TaskPublisher))

    def test_taskmeta_cache(self):
        # TODO Needs to test task meta without TASK_META_USE_DB.
        tid = str(uuid.uuid4())
        ckey = task.gen_task_done_cache_key(tid)
        self.assertTrue(tid in ckey)


class TestTaskSet(unittest.TestCase):
    """Tests for dispatching a ``task.TaskSet`` of subtasks."""

    def test_counter_taskset(self):
        # The counter is shared class-level state; start from a known value
        # so the final sum does not depend on which tests ran before.
        IncrementCounterTask.count = 0
        # Purge leftover messages so the fetches below only see the
        # subtasks published by this taskset.
        task.discard_all()

        ts = task.TaskSet(IncrementCounterTask, [
            {},
            {"increment_by": 2},
            {"increment_by": 3},
            {"increment_by": 4},
            {"increment_by": 5},
            {"increment_by": 6},
            {"increment_by": 7},
            {"increment_by": 8},
            {"increment_by": 9},
        ])
        self.assertEqual(ts.task_name, IncrementCounterTask.name)
        self.assertEqual(ts.total, 9)

        taskset_id, subtask_ids = ts.run()

        consumer = IncrementCounterTask().get_consumer()
        for subtask_id in subtask_ids:
            message = consumer.fetch()
            self.assertTrue(message is not None,
                            "Expected a subtask message, but the queue "
                            "was empty")
            m = consumer.decoder(message.body)
            self.assertEqual(m.get("celeryTASKSET"), taskset_id)
            self.assertEqual(m.get("celeryTASK"), IncrementCounterTask.name)
            self.assertEqual(m.get("celeryID"), subtask_id)
            IncrementCounterTask().run(increment_by=m.get("increment_by"))
        # The empty-args subtask counts as 1, so the total is 1 + 2 + ... + 9.
        self.assertEqual(IncrementCounterTask.count, sum(range(1, 10)))