DalTag / backend / daltag / daltag.py
daltag.py
Raw
import random
import string
import requests
import json
import torch
import numpy as np
import transformers
import queue
from django.db.models import Prefetch, F, Count
from django.apps import apps
# from examples.models import Example
# from projects.models import Project
# from labels.models import Category
# from label_types.models import CategoryType
import daltag.text_classification.text_classification as tc
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Local cache directory passed as ``cache_dir`` to Hugging Face
# ``from_pretrained`` calls in load_base_model (relative to the process CWD).
pretrained_model_path = "./pre_trained_models"

def get_models():
    """Resolve the Django model classes used by DalTag.

    Models are looked up lazily through the app registry (instead of being
    imported at module level) so this module can be imported before Django
    has finished loading its apps.

    Returns:
        tuple ``(Example, Category, CategoryType, Project)``.
    """
    Example = apps.get_model('examples', 'Example')
    Category = apps.get_model('labels', 'Category')
    CategoryType = apps.get_model('label_types', 'CategoryType')
    # The app label is 'projects' (as everywhere else in this module);
    # 'project' is not a registered app and raises LookupError.
    Project = apps.get_model('projects', 'Project')

    return Example, Category, CategoryType, Project

class ModelDoubleton:
    """Shared, lazily-created model holders: one for inference, one for training.

    ``get_instance("inference")`` / ``get_instance("training")`` return the
    holder of that kind, loading the base model on first use.
    ``reload_inference_model`` promotes the training holder to be the new
    inference holder.
    """
    _inference_instance = None
    _train_instance = None
    # Number of holders that currently exist (0, 1 or 2).
    _instance_count = 0

    def __init__(self, model_type, model, tokenizer):
        self.model_type = model_type
        self.model = model
        self.tokenizer = tokenizer

    @staticmethod
    def _refresh_instance_count():
        """Recompute ``_instance_count`` from the holders that actually exist."""
        ModelDoubleton._instance_count = sum(
            holder is not None
            for holder in (ModelDoubleton._inference_instance, ModelDoubleton._train_instance)
        )

    @staticmethod
    def get_instance(model_type):
        """Return the shared holder for ``model_type``, creating it if needed.

        Args:
            model_type: ``"inference"`` or ``"training"``.

        Raises:
            ValueError: for any other model_type (a subclass of Exception, so
                existing ``except Exception`` handlers still catch it).
        """
        if model_type == "inference":
            if ModelDoubleton._inference_instance is None:
                model, tokenizer = load_base_model()
                ModelDoubleton._inference_instance = ModelDoubleton("inference", model, tokenizer)
                ModelDoubleton._refresh_instance_count()
            return ModelDoubleton._inference_instance
        if model_type == "training":
            if ModelDoubleton._train_instance is None:
                model, tokenizer = load_base_model()
                ModelDoubleton._train_instance = ModelDoubleton("training", model, tokenizer)
                ModelDoubleton._refresh_instance_count()
            return ModelDoubleton._train_instance
        raise ValueError("model type not allowed")

    @staticmethod
    def reload_inference_model():
        """Promote the training holder to inference and clear the training slot.

        If no training holder exists, the inference slot is cleared too, so
        the next ``get_instance("inference")`` reloads the base model.
        """
        promoted = ModelDoubleton._train_instance
        if promoted is not None:
            # Keep the holder's label in sync with the role it now plays.
            promoted.model_type = "inference"
        ModelDoubleton._inference_instance = promoted
        ModelDoubleton._train_instance = None
        ModelDoubleton._refresh_instance_count()


def load_base_model(model_name="cardiffnlp/twitter-xlm-roberta-base-sentiment"):
    """Load a sequence-classification model together with its tokenizer.

    Both artifacts come from Hugging Face ``from_pretrained`` and are cached
    under ``pretrained_model_path``. The tokenizer is loaded first.

    Returns:
        tuple ``(model, tokenizer)``.
    """
    cache_kwargs = {"cache_dir": pretrained_model_path}
    tok = AutoTokenizer.from_pretrained(model_name, **cache_kwargs)
    clf = AutoModelForSequenceClassification.from_pretrained(model_name, **cache_kwargs)
    return clf, tok


def handle_inference_request(q_text):
    """Classify a single text with the shared inference model.

    Args:
        q_text: the text to classify (a single string; the ``.item()`` calls
            below only work for a batch of one).

    Returns:
        dict with ``"label"``, the model's label name for the top-scoring
        class, and ``"accuracy"``, the softmax probability of that class
        (a float in [0, 1]). The key name is kept for caller compatibility.
    """
    holder = ModelDoubleton.get_instance('inference')
    # Truncate so texts longer than the model's max sequence length don't fail.
    inputs = holder.tokenizer(q_text, return_tensors="pt", truncation=True)
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        logits = holder.model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)
    predicted = int(torch.argmax(probs, dim=-1).item())
    return {
        "label": holder.model.config.id2label[predicted],
        # Confidence of the predicted class, not the class index.
        "accuracy": probs[0, predicted].item(),
    }



class DalTagEngine(object):
    """
    DalTag Engine implements singleton pattern
    - handles requests for
        - inference
        - training
        - active learning strategies

    Per-project state lives in ``self.project_mapping``: a dict keyed by
    project id whose values are dicts with the keys ``"project_id"``,
    ``"acl"`` (active-learning strategy instance), ``"net"`` (text
    classifier) and ``"data"`` (dataset wrapper).
    """
    _singleton_instance = None
    _init_setup_complete = False

    # Base model that puts a project on the SetFit (few-shot) pipeline.
    _SETFIT_MODEL = "sentence-transformers/paraphrase-mpnet-base-v2"
    # Project types the engine loads from the database.
    _SUPPORTED_PROJECT_TYPES = ('DocumentClassification', 'SequenceLabeling')

    def __new__(cls, *args, **kwargs):
        """Create the singleton on first call, otherwise return the existing one.

        ``project_mapping`` is created together with the instance so the
        handlers can be called (and simply return None) before init_engine().
        """
        if cls._singleton_instance is None:
            instance = super(DalTagEngine, cls).__new__(cls)
            instance.project_mapping = dict()
            cls._singleton_instance = instance
        return cls._singleton_instance

    # ------------------------------------------------------------------ #
    # request handlers
    # ------------------------------------------------------------------ #
    def train_handler(self, project_id):
        """Train the project's active-learning strategy; None if the project is unknown."""
        project = self.project_mapping.get(project_id)
        if project is None:
            return None
        project.get("acl").train()
        # TODO: consider running training asynchronously (e.g. a Celery task).

    def inference_handler(self, project_id, q_text):
        """Return the project's strategy inference for q_text; None if the project is unknown."""
        project = self.project_mapping.get(project_id)
        if project is None:
            return None
        return project.get("acl").infer(q_text)

    def acl_query_handler(self, project_id):
        """Return the strategy's query result with n=-1; None if the project is unknown."""
        project = self.project_mapping.get(project_id)
        if project is None:
            return None
        return project.get("acl").query(n=-1)  # NOTE(review): query must handle n=-1

    # ------------------------------------------------------------------ #
    # strategy / project management
    # ------------------------------------------------------------------ #
    def get_strategy_instance(self, strategy_name):
        """Map a strategy name to its class; unknown names fall back to RandomSampling."""
        mapping = {
            "RandomSampling": tc.RandomSampling,
            "MarginSampling": tc.MarginSampling,
            "MNLPSampling": tc.MNLPSampling,
            "BALDDropout": tc.BALDDropout,
            # Placeholder until tc.AdversarialDeepFool is enabled.
            "AdversarialDeepFool": tc.RandomSampling,
            "LeastConfidence": tc.LeastConfidence,
            "LeastConfidenceDropout": tc.LeastConfidenceDropout,
        }
        StrategyClass = mapping.get(strategy_name)
        if StrategyClass is None:
            return tc.RandomSampling
        return StrategyClass

    def _build_acl(self, ds, net, strategy_name, fewshot):
        """Instantiate the named strategy over (ds, net), in few-shot mode for SetFit."""
        StrategyClass = self.get_strategy_instance(strategy_name)
        if fewshot:
            return StrategyClass(ds, net, fewshot=True)
        return StrategyClass(ds, net)

    @staticmethod
    def _get_config(project_id):
        """Return the project's DalTagConfig row, or None."""
        DalTagConfig = apps.get_model('projects', 'DalTagConfig')
        return DalTagConfig.objects.filter(project_id=project_id).first()

    def _fetch_projects(self):
        """Return all projects whose type the engine supports."""
        Project = apps.get_model('projects', 'Project')
        return Project.objects.filter(project_type__in=self._SUPPORTED_PROJECT_TYPES).all()

    def _add_project_by_id(self, project_id):
        """Fetch a project from the database and add it; no-op if it does not exist."""
        Project = apps.get_model('projects', 'Project')
        project = Project.objects.filter(id=project_id).first()
        if project is not None:
            self.add_project(project)

    def update_acl_strategy(self, project_id):
        """Rebuild the project's AL strategy from its DalTagConfig.

        A project that is not loaded yet is loaded from the database instead.
        """
        entry = self.project_mapping.get(project_id)
        if entry is None:
            self._add_project_by_id(project_id)
            return
        conf = self._get_config(project_id)
        # get_strategy_instance falls back to RandomSampling for None.
        strategy_name = conf.acl_strategy if conf is not None else None
        # The classifier that is actually loaded decides whether few-shot mode applies.
        fewshot = entry["net"].model_name == self._SETFIT_MODEL
        entry["acl"] = self._build_acl(entry["data"], entry["net"], strategy_name, fewshot)

    def update_projects(self):
        """Load every supported project that is not in the mapping yet."""
        self.projects = self._fetch_projects()
        for p in self.projects:
            if p.id not in self.project_mapping:
                self.add_project(p)

    def add_project(self, project):
        """Build data, classifier and AL strategy for a project and register it.

        Only DocumentClassification projects are handled, and projects without
        any examples are skipped. A default DalTagConfig (RandomSampling +
        bert-base-uncased) is created when the project has none.
        """
        if project is None or project.project_type != 'DocumentClassification':
            return

        conf = self._get_config(project.id)
        if conf is None:
            DalTagConfig = apps.get_model('projects', 'DalTagConfig')
            updates = {"acl_strategy": "RandomSampling", "base_model": "bert-base-uncased"}
            conf, _created = DalTagConfig.objects.update_or_create(
                defaults=updates, project_id=project.id
            )

        project_data = self.get_project_data(project.id)
        if not project_data:
            return

        # SetFit uses its own dataset wrapper and classifier; everything else
        # goes through the generic TextData / TextClassifierNN pair.
        fewshot = conf.base_model == self._SETFIT_MODEL
        if fewshot:
            ds = tc.TextDataSETFIT({"data": project_data})
        else:
            ds = tc.TextData({"data": project_data})
        ds.pre_processor()

        ClassifierClass = tc.TextClassifierSETFIT if fewshot else tc.TextClassifierNN
        text_classifier = ClassifierClass(conf.base_model, ds.get_labels())
        text_classifier.load_models()

        self.project_mapping[project.id] = {
            "project_id": project.id,
            "acl": self._build_acl(ds, text_classifier, conf.acl_strategy, fewshot),
            "net": text_classifier,
            "data": ds,
        }

    def reload_project(self, project_id):
        """Re-sync a project's classifier and AL strategy with its DalTagConfig.

        Unknown projects, or projects without a config, are loaded from
        scratch. Switching to or from the SetFit model also rebuilds the
        project from scratch, because SetFit needs a different dataset wrapper
        (TextDataSETFIT vs TextData).
        """
        print("reload project, ", project_id)
        conf = self._get_config(project_id)
        entry = self.project_mapping.get(project_id)
        if entry is None or conf is None:
            self._add_project_by_id(project_id)
            return

        net = entry["net"]
        if net.model_name != conf.base_model:
            if self._SETFIT_MODEL in (net.model_name, conf.base_model):
                self._add_project_by_id(project_id)
                return
            print("update in model")
            net = tc.TextClassifierNN(conf.base_model, entry["data"].get_labels())
            net.load_models()
            entry["net"] = net
        else:
            print("no update in model")

        fewshot = conf.base_model == self._SETFIT_MODEL
        entry["acl"] = self._build_acl(entry["data"], net, conf.acl_strategy, fewshot)

    def init_engine(self):
        """
        Load every supported project from the database into project_mapping.

        Each project gets its own dataset, classifier and active-learning
        strategy instance. Runs only once; later calls are no-ops.
        """
        if self._init_setup_complete:
            print("init already done... skipping")
            return

        print("initialize engine")
        self.projects = self._fetch_projects()
        self.project_mapping = dict()
        for p in self.projects:
            self.add_project(p)
        self._init_setup_complete = True

    def get_project_data(self, project_id):
        """Return every example of a project together with its category labels.

        Returns:
            list of dicts ``{"id": example id, "text": example text,
            "label": list of label texts}``. Labeled examples come first (one
            entry per example, labels in query order); examples without any
            category annotation follow with ``"label": []``.
        """
        Example = apps.get_model('examples', 'Example')
        Category = apps.get_model('labels', 'Category')
        CategoryType = apps.get_model('label_types', 'CategoryType')

        category_types = CategoryType.objects.filter(project_id=project_id)
        # values() + annotate() groups rows by (example, label), collapsing
        # duplicate annotations of the same label on the same example.
        rows = Category.objects.filter(label_id__in=category_types).values(
            'example__id', 'example__text', 'label__text'
        ).annotate(label_count=Count('label__text'))

        # Group by example id (not by text, and without assuming the rows are
        # sorted) so id, text and labels always belong to the same example.
        labeled = {}
        for row in rows:
            example_id = row['example__id']
            entry = labeled.get(example_id)
            if entry is None:
                entry = {"id": example_id, "text": row['example__text'], "label": []}
                labeled[example_id] = entry
            entry["label"].append(row['label__text'])

        data = list(labeled.values())
        data.extend(
            {"id": e.id, "text": e.text, "label": []}
            for e in Example.objects.filter(project_id=project_id)
            if e.id not in labeled
        )
        return data