1from channels.generic.websocket import AsyncWebsocketConsumer
2import json
3
4from celery.result import AsyncResult
5from celery_progress.backend import Progress
6
7
class ProgressConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer that relays Celery task progress to clients.

    Each connection joins a channel-layer group named after the ``task_id``
    URL route kwarg, so every client watching the same task receives the
    same ``update_task_progress`` events.

    NOTE(review): the raw task id is used as the group name. Channels rejects
    group names containing characters outside its allowed set. This presumably
    relies on the URL pattern only matching Celery task UUIDs, so confirm that.
    """

    async def connect(self):
        """Join the group for the task id in the URL route, then accept."""
        self.task_id = self.scope['url_route']['kwargs']['task_id']

        await self.channel_layer.group_add(
            self.task_id,
            self.channel_name,
        )

        await self.accept()

    async def disconnect(self, close_code):
        """Leave the task's group when the socket closes."""
        await self.channel_layer.group_discard(
            self.task_id,
            self.channel_name,
        )

    async def receive(self, text_data=None, bytes_data=None):
        """Handle a message sent by the client.

        Only text frames holding a JSON object with
        ``"type": "check_task_completion"`` are acted on. For those, the
        task's current progress is looked up and broadcast to the whole
        group as an ``update_task_progress`` event.

        Binary frames, malformed JSON, non-object payloads and unknown or
        missing ``type`` values are ignored rather than raising. Raising here
        would tear down the consumer on bad client input.

        :param text_data: text frame payload, or ``None`` for binary frames.
        :param bytes_data: binary frame payload (unused by this protocol).
        """
        if text_data is None:
            # Binary frame: this protocol is JSON-over-text only.
            return

        try:
            message = json.loads(text_data)
        except ValueError:
            # json.JSONDecodeError subclasses ValueError; drop garbage input.
            return

        if not isinstance(message, dict):
            return

        if message.get('type') != 'check_task_completion':
            return

        # Local import keeps this edit self-contained. database_sync_to_async
        # runs the blocking result-backend query in a worker thread instead of
        # on the event loop.
        from channels.db import database_sync_to_async

        info = await database_sync_to_async(self._get_progress_info)()

        await self.channel_layer.group_send(
            self.task_id,
            {
                'type': 'update_task_progress',
                'data': info,
            },
        )

    def _get_progress_info(self):
        """Return ``celery_progress`` info for this task (blocking call)."""
        return Progress(AsyncResult(self.task_id)).get_info()

    async def update_task_progress(self, event):
        """Forward an ``update_task_progress`` group event to the client.

        :param event: channel-layer message whose ``'data'`` entry is sent
            to the socket as JSON text.
        """
        await self.send(text_data=json.dumps(event['data']))
42