# production-taskbar / backend / informing / tasks.py
import json
import logging
from asyncio import get_event_loop, new_event_loop, set_event_loop
from datetime import datetime, timedelta
from typing import Any, List, Optional

from celery import shared_task
from celery.exceptions import Ignore, InvalidTaskError
from channels.layers import get_channel_layer
from django.conf import settings
from django.utils.timezone import make_aware, utc
from django_celery_beat.models import PeriodicTask
from rest_framework.utils.serializer_helpers import ReturnDict

from .models import Notification, OrganizationalUnit
from .serializers import TaskShiftSerializer

# Django's DEBUG flag, read once when this module is imported.
is_debug = settings.DEBUG
# Ids of notifications for which broadcast_notification has queued a deferred
# (ETA) re-run. NOTE(review): this list lives in this Python process's memory
# only, so it is not shared between worker processes and is lost on restart;
# confirm the worker runs a single process or move this to a shared store.
delayed_tasks: List[int] = list()


def get_now() -> datetime:
    """Return the current wall-clock time as a timezone-aware datetime.

    ``datetime.now()`` yields a naive local time; ``make_aware`` attaches
    Django's current time zone to it.
    """
    naive_now = datetime.now()
    return make_aware(naive_now)


def update_last_run_at(notification: Notification, task_name: str) -> None:
    """Record that *notification* has just been broadcast.

    A one-off notification is deactivated first. Then ``last_run_at`` is set
    to the current time on both the notification row and every celery-beat
    ``PeriodicTask`` whose name contains *task_name*.

    All writes go through ``QuerySet.update`` instead of ``save()`` so that
    they do not produce a history record.

    Args:
        notification: the notification that was just sent.
        task_name: fragment of the ``PeriodicTask`` name to update.

    Raises:
        InvalidTaskError: if *task_name* is empty or matches no
            ``PeriodicTask``. A one-off notification has already been
            deactivated when this is raised.
    """
    # Queryset (not the instance) so the update bypasses save()/history.
    queryset = Notification.objects.filter(id=notification.pk)
    if notification.one_off:
        queryset.update(is_active=False)

    # NOTE(review): ``name__contains`` is a substring match, so a task_name
    # such as "x_1" would also match "x_10"; confirm task naming makes this
    # unambiguous.
    tasks = (PeriodicTask.objects.filter(name__contains=task_name)
             if task_name else None)
    # exists() issues a cheap EXISTS query instead of loading every row.
    if tasks is None or not tasks.exists():
        raise InvalidTaskError(
            f'Error on {notification} last_run_at update. No celery task {task_name} exists.'
        )

    now = get_now()
    queryset.update(last_run_at=now)
    tasks.update(last_run_at=now)


def _get_reusable_event_loop() -> Any:
    """Return this thread's current event loop, creating and setting one if needed."""
    try:
        loop = get_event_loop()
    except RuntimeError:
        # Newer Pythons no longer create a loop implicitly when none is set.
        loop = None
    if loop is None or loop.is_closed():
        loop = new_event_loop()
        set_event_loop(loop)
    return loop


def send(notification: Notification) -> None:
    """Push *notification* to each recipient's ``informing_<pk>`` channel group.

    Recipients are the notification's explicit ``recipients``. When it has
    none, every ``OrganizationalUnit`` at ``notification.location`` is
    targeted instead.

    Delivery is best-effort: a failure for one recipient is logged with its
    traceback and the remaining recipients are still attempted. Delivery
    errors are not raised to the caller.
    """
    recipients = notification.recipients.all()
    if not recipients.exists():
        recipients = OrganizationalUnit.objects.filter(
            location=notification.location)
    channel_layer = get_channel_layer()

    # Reuse one event loop instead of async_to_sync, which led to
    # 'Event loop is closed' errors.
    loop = _get_reusable_event_loop()

    # The payload is identical for every recipient: serialize it once, which
    # also gives every recipient the same timestamp.
    text = json.dumps({
        "notification_id": notification.pk,
        "content": notification.content,
        "need_confirmation": notification.need_confirmation,
        "close_delay": notification.close_delay,
        "is_overlay": notification.is_overlay,
        "timestamp": datetime.now().timestamp()
    })

    for recipient in recipients:
        group_name = f'informing_{recipient.pk}'
        try:
            # A fresh message dict per group, in case the layer mutates it.
            loop.run_until_complete(
                channel_layer.group_send(
                    group_name,
                    {"type": "informing.notify", "text": text},
                ))
        except Exception:
            logging.exception('Failed to send notification %s to group %s',
                              notification.pk, group_name)


def send_and_update(notification: Notification, task_name: str) -> None:
    """Broadcast *notification*, then record the run on it and its beat task.

    Any exception from the bookkeeping step (``update_last_run_at``)
    propagates to the caller.
    """
    # Bookkeeping happens only after the broadcast attempt has completed.
    send(notification)
    update_last_run_at(notification, task_name)


def get_delayed_datetime(shifts: 'ReturnDict',
                         notification: 'Notification',
                         now: Optional[datetime] = None) -> Any:
    """Return when *notification* should be shown, or ``None`` to show it now.

    Only shifts whose ``weekdays`` contain today's weekday are considered.
    For each such shift, the display window runs from the shift start
    (moved later by ``notification.show_on_shift_start`` minutes) to the
    shift end (moved earlier by ``notification.show_on_shift_end`` minutes).
    A shift whose ``start_time`` is later than its ``end_time`` is treated as
    overnight: its window is taken to have started yesterday.

    Args:
        shifts: serialized shifts. Each item needs ``weekdays`` (compared with
            ``datetime.weekday()``, Monday == 0), plus ``start_time`` and
            ``end_time`` as ``"%H:%M:%S"`` strings.
        notification: object providing ``show_on_shift_start`` and
            ``show_on_shift_end`` in minutes; a falsy value means no offset.
        now: reference time as a *naive* local datetime. Defaults to
            ``datetime.now()``.

    Returns:
        ``None`` if *now* lies inside any shift window. Otherwise, the
        earliest upcoming window start among the considered shifts, as a naive
        datetime, or ``None`` if there is no upcoming start.
        NOTE(review): the last case means a caller treating ``None`` as
        "send now" sends outside every shift; confirm this is intended.
    """
    if now is None:
        now = datetime.now()
    weekday = now.weekday()
    next_shift_datetime = None
    for shift in shifts:
        if weekday not in shift['weekdays']:
            continue
        time_start = datetime.strptime(shift['start_time'], "%H:%M:%S").time()
        time_end = datetime.strptime(shift['end_time'], "%H:%M:%S").time()
        # Today's nominal start; also the next start of an overnight shift.
        datetime_start_origin = datetime.combine(now, time_start)
        datetime_end = datetime.combine(now, time_end)
        if time_start > time_end:
            # Overnight shift: the window ending today began yesterday.
            datetime_start = datetime_start_origin - timedelta(days=1)
        else:
            datetime_start = datetime_start_origin
        if notification.show_on_shift_start:
            delta = timedelta(minutes=notification.show_on_shift_start)
            datetime_start += delta
            datetime_start_origin += delta
        if notification.show_on_shift_end:
            datetime_end -= timedelta(
                minutes=notification.show_on_shift_end)

        if datetime_start <= now <= datetime_end:
            return None
        if datetime_start >= now:
            candidate = datetime_start
        elif datetime_start_origin >= now:
            candidate = datetime_start_origin
        else:
            # This shift's window is already over for today.
            continue
        # Keep the earliest upcoming start rather than whichever shift
        # happened to be iterated last.
        if next_shift_datetime is None or candidate < next_shift_datetime:
            next_shift_datetime = candidate

    return next_shift_datetime


@shared_task(bind=True, ignore_result=False,
             options={"expires": 120})    # type: ignore
def broadcast_notification(self: Any, *args: Any, **kwargs: Any) -> str:
    """Celery task: broadcast a notification, deferring it to shift time if needed.

    Expected kwargs:
        notification_id: pk of the ``Notification`` to send.
        is_delayed: True when this run is the deferred (ETA) re-run that this
            task scheduled for itself.
        task_name: name fragment passed to ``update_last_run_at`` to locate
            the celery-beat ``PeriodicTask`` to update.

    If the notification has active shifts and ``get_delayed_datetime``
    returns a time, the task re-schedules itself for that time with
    ``is_delayed=True``, records the id in ``delayed_tasks``, marks this run
    ``REVOKED`` and stops. Otherwise it sends immediately.

    Returns:
        ``'success'`` after the notification has been sent and its
        ``last_run_at`` bookkeeping updated.

    Raises:
        Ignore: when this run is skipped, either because a deferred run for
            the same notification is pending (per ``delayed_tasks``) or
            because it has just been re-scheduled.
        InvalidTaskError: if no ``Notification`` with the given id exists
            (and may also propagate from ``update_last_run_at``).
    """
    n_id = kwargs.get('notification_id', None)
    is_delayed = kwargs.get('is_delayed', False)
    task_name = kwargs.get('task_name', None)
    logging.info(f'Starting broadcast {task_name}, n_id: {n_id}')
    # A regular (non-deferred) run is skipped while a deferred run for the
    # same notification is pending. NOTE(review): delayed_tasks is a
    # module-level list, so this check only sees entries made in the current
    # worker process.
    if not is_delayed and n_id in delayed_tasks:
        meta = {'info': 'this notification task is delayed'}
        self.update_state(state='REVOKED', meta=meta)
        raise Ignore()
    # The deferred run clears its pending marker and continues.
    if is_delayed and n_id in delayed_tasks:
        delayed_tasks.remove(n_id)

    notification = Notification.objects.filter(pk=n_id).first()
    if notification:
        shifts = TaskShiftSerializer(
            notification.shifts.filter(is_active=True), many=True).data
        if shifts:
            delayed_datetime = get_delayed_datetime(shifts, notification)
            if delayed_datetime:
                # The re-scheduled copy is flagged as the deferred run.
                kwargs['is_delayed'] = True
                # delayed_datetime is naive; astimezone() interprets a naive
                # datetime as system local time. NOTE(review):
                # django.utils.timezone.utc is deprecated in Django 4.1 and
                # removed in 5.0; datetime.timezone.utc is the replacement.
                utc_time = delayed_datetime.astimezone(utc)
                self.apply_async(eta=utc_time,
                                 args=args,
                                 kwargs=kwargs,
                                 ignore_result=False)
                self.update_state(state='REVOKED',
                                  meta={'delayed_until': delayed_datetime})
                delayed_tasks.append(n_id)
                raise Ignore()
        send_and_update(notification, task_name)
        return 'success'

    raise InvalidTaskError(f'Notification object with id {n_id} not found')