import random
import string
import requests
import json
import torch
import numpy as np
import transformers
import queue
from django.db.models import Prefetch, F, Count
from django.apps import apps
# from examples.models import Example
# from projects.models import Project
# from labels.models import Category
# from label_types.models import CategoryType
import daltag.text_classification.text_classification as tc
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification

pretrained_model_path = "./pre_trained_models"


def get_models():
    """Resolve the Django models lazily to avoid circular imports at module load time."""
    Example = apps.get_model('examples', 'Example')
    Category = apps.get_model('labels', 'Category')
    CategoryType = apps.get_model('label_types', 'CategoryType')
    Project = apps.get_model('projects', 'Project')
    return Example, Category, CategoryType, Project


class ModelDoubleton:
    """Holds exactly two shared model instances: one for inference, one for training."""

    _inference_instance = None
    _train_instance = None
    _instance_count = 0

    def __init__(self, model_type, model, tokenizer):
        self.model_type = model_type
        self.model = model
        self.tokenizer = tokenizer

    @staticmethod
    def get_instance(model_type):
        if model_type == "inference":
            if ModelDoubleton._inference_instance is None:
                model, tokenizer = load_base_model()
                ModelDoubleton._inference_instance = ModelDoubleton("inference", model, tokenizer)
                ModelDoubleton._instance_count = 1
            return ModelDoubleton._inference_instance
        elif model_type == "training":
            if ModelDoubleton._train_instance is None:
                model, tokenizer = load_base_model()
                ModelDoubleton._train_instance = ModelDoubleton("training", model, tokenizer)
                ModelDoubleton._instance_count = 2
            return ModelDoubleton._train_instance
        else:
            raise ValueError("model type not allowed")

    @staticmethod
    def reload_inference_model():
        """Promote the training instance to serve inference and clear the training slot."""
        ModelDoubleton._inference_instance = ModelDoubleton._train_instance
        ModelDoubleton._train_instance = None
        ModelDoubleton._instance_count = 1


def load_base_model(model_name="cardiffnlp/twitter-xlm-roberta-base-sentiment"):
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=pretrained_model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir=pretrained_model_path)
    return model, tokenizer


def handle_inference_request(q_text):
    model = ModelDoubleton.get_instance('inference')
    inputs = model.tokenizer(q_text, return_tensors="pt")
    with torch.no_grad():
        outputs = model.model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1)
    prediction = torch.argmax(probs, dim=-1)
    label = model.model.config.id2label[prediction.item()]
    accuracy = probs.max().item()  # softmax confidence of the predicted class
    return {"label": label, "accuracy": accuracy}
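
# Illustrative usage sketch (an assumption about how the helpers above are meant to be
# called, not an API defined elsewhere in the project). Wrapped in a private function so
# that importing this module stays side-effect free.
def _example_sentiment_inference(text="the battery life is great"):
    # Classify with the shared "inference" model held by ModelDoubleton; returns e.g.
    # {"label": "positive", "accuracy": 0.93} (label set and score depend on the checkpoint).
    return handle_inference_request(text)
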

class DalTagEngine(object):
    """
    DalTag Engine. Implements the singleton pattern and handles requests for:
      - inference
      - training
      - active learning strategies
    """

    _singleton_instance = None
    _init_setup_complete = False

    def __new__(cls, *args, **kwargs):
        """Create the singleton object on first use; otherwise return the existing instance."""
        if not cls._singleton_instance:
            cls._singleton_instance = super(DalTagEngine, cls).__new__(cls)
        return cls._singleton_instance

    def train_handler(self, project_id):
        project = self.project_mapping.get(project_id)
        if project is None:
            return None
        acl = project.get("acl")
        acl.train()
        # model.train_model.delay()  # Celery alternative; try if possible

    def inference_handler(self, project_id, q_text):
        project = self.project_mapping.get(project_id, None)
        if project is None:
            return None
        return project.get("acl").infer(q_text)

    def acl_query_handler(self, project_id):
        """Tentative method; it may be removed later."""
        project = self.project_mapping.get(project_id)
        if project is None:
            return None
        return project.get("acl").query(n=-1)  # handle n=-1 in query

    def get_strategy_instance(self, strategy_name):
        """Map a strategy name from the config to its implementation; fall back to RandomSampling."""
        mapping = {
            "RandomSampling": tc.RandomSampling,
            "MarginSampling": tc.MarginSampling,
            "MNLPSampling": tc.MNLPSampling,
            "BALDDropout": tc.BALDDropout,
            "AdversarialDeepFool": tc.RandomSampling,
            # "AdversarialDeepFool": tc.AdversarialDeepFool,
            "LeastConfidence": tc.LeastConfidence,
            "LeastConfidenceDropout": tc.LeastConfidenceDropout,
        }
        StrategyClass = mapping.get(strategy_name)
        if StrategyClass is None:
            return tc.RandomSampling
        return StrategyClass

    def update_acl_strategy(self, project_id):
        """Update the active-learning strategy of a project from its DalTag config."""
        DalTagConfig = apps.get_model('projects', 'DalTagConfig')
        conf = DalTagConfig.objects.filter(project_id=project_id).first()
        project = self.project_mapping.get(project_id)
        if project is None:
            Project = apps.get_model('projects', 'Project')
            project = Project.objects.filter(id=project_id).first()
            self.add_project(project)
            return
        StrategyClass = self.get_strategy_instance(conf.acl_strategy)
        ds = self.project_mapping[project_id]["data"]
        net = self.project_mapping[project_id]["net"]
        self.project_mapping[project_id]["acl"] = StrategyClass(ds, net)

    def update_projects(self):
        """Add newly created projects to the project mapping."""
        Project = apps.get_model('projects', 'Project')
        self.projects = Project.objects.filter(
            project_type__in=['DocumentClassification', 'SequenceLabeling']
        ).all()
        for p in self.projects:
            if self.project_mapping.get(p.id, None) is None:
                self.add_project(p)

    def add_project(self, project):
        """Build dataset, classifier, and strategy for a project and register them in project_mapping."""
        if project.project_type == 'DocumentClassification':
            DalTagConfig = apps.get_model('projects', 'DalTagConfig')
            conf = DalTagConfig.objects.filter(project_id=project.id).first()
            if conf is None:
                # Create a default config for projects that do not have one yet.
                filter_conditions = {"project_id": project.id}
                updates = {"acl_strategy": "RandomSampling", "base_model": "bert-base-uncased"}
                daltag_obj, created = DalTagConfig.objects.update_or_create(defaults=updates, **filter_conditions)
                resp = "created" if created else "updated"
                conf = DalTagConfig.objects.filter(project_id=project.id).first()
            if conf.base_model != "sentence-transformers/paraphrase-mpnet-base-v2":
                # Standard transformer classification path.
                project_data = self.get_project_data(project.id)
                if len(project_data) == 0:
                    return
                ds = tc.TextData({"data": project_data})
                ds.pre_processor()
                train_df = ds.get_training_data()
                test_df = ds.get_testing_data()
                # text_classifier = tc.TextClassifier("bert-base-uncased", ds.get_num_of_labels())
                text_classifier = tc.TextClassifierNN(conf.base_model, ds.get_labels())
                text_classifier.load_models()
                tc_train_data = tc.TextClassificationData(train_df, text_classifier.get_tokenizer(),
                                                          len(text_classifier.labels), labeled=True)
                tc_test_data = tc.TextClassificationData(test_df, text_classifier.get_tokenizer(),
                                                         len(text_classifier.labels), labeled=True)
                StrategyClass = self.get_strategy_instance(conf.acl_strategy)
                self.project_mapping[project.id] = {
                    "project_id": project.id,
                    "acl": StrategyClass(ds, text_classifier),
                    "net": text_classifier,
                    "data": ds
                }
            else:
                # Few-shot (SetFit) path for the sentence-transformers base model.
                project_data = self.get_project_data(project.id)
                if len(project_data) == 0:
                    return
                ds = tc.TextDataSETFIT({"data": project_data})
                ds.pre_processor()
                train_df = ds.get_training_data()
                test_df = ds.get_testing_data()
                # text_classifier = tc.TextClassifier("bert-base-uncased", ds.get_num_of_labels())
                text_classifier = tc.TextClassifierSETFIT(conf.base_model, ds.get_labels())
                text_classifier.load_models()
                # tc_train_data = tc.TextClassificationData(train_df, text_classifier.get_tokenizer(),
                #                                           len(text_classifier.labels), labeled=True)
                # tc_test_data = tc.TextClassificationData(test_df, text_classifier.get_tokenizer(),
                #                                          len(text_classifier.labels), labeled=True)
                StrategyClass = self.get_strategy_instance(conf.acl_strategy)
                self.project_mapping[project.id] = {
                    "project_id": project.id,
                    "acl": StrategyClass(ds, text_classifier, True),
                    "net": text_classifier,
                    "data": ds
                }

    def reload_project(self, project_id):
        """Re-sync a project's classifier and strategy with its DalTag config."""
        print("reload project, ", project_id)
        DalTagConfig = apps.get_model('projects', 'DalTagConfig')
        conf = DalTagConfig.objects.filter(project_id=project_id).first()
        project = self.project_mapping.get(project_id)
        if project is None:
            Project = apps.get_model('projects', 'Project')
            project = Project.objects.filter(id=project_id).first()
            self.add_project(project)
            return
        if project.get("net").model_name != conf.base_model:
            if conf.base_model == "sentence-transformers/paraphrase-mpnet-base-v2":
                # Switching to the few-shot base model requires rebuilding the whole project entry.
                Project = apps.get_model('projects', 'Project')
                project_obj = Project.objects.filter(id=project_id).first()
                self.add_project(project_obj)
                return
            else:
                print("update in model")
                ds = project.get("data")
                text_classifier = tc.TextClassifierNN(conf.base_model, ds.get_labels())
                text_classifier.load_models()
                train_df = ds.get_training_data()
                test_df = ds.get_testing_data()
                tc_train_data = tc.TextClassificationData(train_df, text_classifier.get_tokenizer(),
                                                          len(text_classifier.labels), labeled=True)
                tc_test_data = tc.TextClassificationData(test_df, text_classifier.get_tokenizer(),
                                                         len(text_classifier.labels), labeled=True)
                self.project_mapping[project_id]["net"] = text_classifier
        else:
            print("no update in model")
        StrategyClass = self.get_strategy_instance(conf.acl_strategy)
        ds = self.project_mapping[project_id]["data"]
        net = self.project_mapping[project_id]["net"]
        if conf.base_model != "sentence-transformers/paraphrase-mpnet-base-v2":
            self.project_mapping[project_id]["acl"] = StrategyClass(ds, net)
            return
        self.project_mapping[project_id]["acl"] = StrategyClass(ds, net, fewshot=True)

    def init_engine(self):
        """
        Create the mapping between projects and their dataset/classifier/strategy instances.
        There can be more than one project of the same type, each with its own model instance;
        inference is handled through the query-strategy instance. Iterates over all projects
        in the database and starts their query instances.
        """
        if not self._init_setup_complete:
            print("initialize engine")
            Project = apps.get_model('projects', 'Project')
            self.projects = Project.objects.filter(
                project_type__in=['DocumentClassification', 'SequenceLabeling']
            ).all()
            self.project_mapping = dict()
            for p in self.projects:
                self.add_project(p)
            self._init_setup_complete = True
        else:
            print("init already done... skipping")
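
    # Shape of each self.project_mapping entry, as built in add_project/reload_project:
    #   {"project_id": <int>, "acl": <strategy instance>, "net": <text classifier>, "data": <dataset wrapper>}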
skipping") def get_project_data(self,project_id): Example = apps.get_model('examples', 'Example') Category = apps.get_model('labels', 'Category') CategoryType = apps.get_model('label_types', 'CategoryType') all_examples = Example.objects.filter(project_id=project_id).all() category_list = CategoryType.objects.filter(project_id=project_id) categories = Category.objects.filter(label_id__in=category_list).prefetch_related( Prefetch('example', queryset=Example.objects.only('text')) ).values('example__id', 'example__text', 'label__text').annotate(label_count=Count('label__text')) data = [] labeled_ids = dict() current_example_text = None labels = [] for category in categories: if category['example__text'] != current_example_text: if current_example_text is not None: labeled_ids[category['example__id']] = True data.append({ "id": category['example__id'], "text": current_example_text, "label": labels }) current_example_text = category['example__text'] labels = [] labels.append(category['label__text']) if current_example_text is not None: data.append({ "id": category['example__id'], "text": current_example_text, "label": labels }) for e in all_examples: if not labeled_ids.get(e.id,False): data.append({ "id": e.id, "text": e.text, "label": [] }) return data