mini-llama/src/classifier.py

import torch
import torch.nn.functional as F

# NOTE: adjust these imports to match the module layout of the original model repository
from config import LlamaConfig
from llama import load_pretrained
from tokenizer import Tokenizer


class LlamaZeroShotClassifier(torch.nn.Module):
    def __init__(
        self, config: LlamaConfig, tokenizer: Tokenizer, label_names: list[str]
    ):
        super(LlamaZeroShotClassifier, self).__init__()
        self.num_labels = config.num_labels
        self.llama = load_pretrained(config.pretrained_model_path)
        # Zero-shot classification does not require updating Llama parameters.
        for param in self.llama.parameters():
            param.requires_grad = False
        assert len(label_names) == self.num_labels
        self.tokenizer = tokenizer
        self.label_name_ids = [
            tokenizer.encode(label, bos=False, eos=False) for label in label_names
        ]
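        # Each entry is the list of token ids for one label name; a label that
        # tokenizes into several sub-word pieces contributes several ids.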

    def forward(self, input_ids):
        # compute the completion probability of each label string
        logits, _ = self.llama(input_ids)
        log_probabilities = F.log_softmax(logits, dim=-1)
        # Holds, for each example, the summed log-probability of each label's tokens.
        label_probabilities = torch.zeros(
            (log_probabilities.shape[0], self.num_labels),
            device=log_probabilities.device,
        )
        for i, label_token_ids in enumerate(self.label_name_ids):
            # Sum the log-probabilities assigned to this label's tokens at each position.
            total_log_prob = torch.sum(
                log_probabilities[:, :, label_token_ids], dim=-1
            )
            # Use the score at the first sequence position as the label's score.
            label_probabilities[:, i] = total_log_prob[:, 0]
        return label_probabilities
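

# Illustrative usage (sketch only): scoring a prompt against two label names.
# The LlamaConfig and Tokenizer constructors below are assumed from how they are
# used in this file; the exact signatures, field names, and paths may differ.
#
#   config = LlamaConfig(pretrained_model_path="path/to/checkpoint.pt", num_labels=2)
#   tokenizer = Tokenizer("path/to/tokenizer.model")
#   clf = LlamaZeroShotClassifier(config, tokenizer, ["negative", "positive"])
#   input_ids = torch.tensor([tokenizer.encode("The movie was", bos=True, eos=False)])
#   scores = clf(input_ids)              # shape (1, 2): summed label-token log-probabilities
#   predicted_label = scores.argmax(dim=-1)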


class LlamaEmbeddingClassifier(torch.nn.Module):
    def __init__(self, config: LlamaConfig):
        super(LlamaEmbeddingClassifier, self).__init__()
        self.num_labels = config.num_labels
        self.llama = load_pretrained(config.pretrained_model_path)
        # In "pretrain" mode the Llama parameters are frozen; in "finetune" mode they stay trainable.
        for param in self.llama.parameters():
            if config.option == "pretrain":
                param.requires_grad = False
            elif config.option == "finetune":
                param.requires_grad = True

        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier_head = torch.nn.Linear(self.llama.config.dim, self.num_labels)

    def forward(self, input_ids: torch.Tensor):
        """
        1) Find the hidden state after the final token of the input sequence
        2) Apply dropout (self.dropout) to the hidden state at training time to mitigate
           overfitting.
        2) Pass this through the classifier head (self.classifier_head), which will return
           logits (unnormalized probabilities) over all classes.
        3) Take the log-softmax of the logits and return log-probabilities over all classes.
        """

        # input_ids has shape (bs, seqlen)
        _, h = self.llama(input_ids)  # h has shape (bs, seqlen, dim)
        h = h[:, -1, :]  # shape (bs, dim)

        h = self.dropout(h)
        logits = self.classifier_head(h)

        return F.log_softmax(logits, dim=-1)
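

# Minimal smoke-test sketch for LlamaEmbeddingClassifier. The LlamaConfig and
# Tokenizer constructors are assumed from their usage above and may differ in this
# repository; adjust field names and paths to your setup before running.
if __name__ == "__main__":
    config = LlamaConfig(
        pretrained_model_path="path/to/checkpoint.pt",  # assumed field name
        num_labels=2,
        option="pretrain",  # "pretrain" freezes Llama weights, "finetune" trains them
        hidden_dropout_prob=0.1,
    )
    tokenizer = Tokenizer("path/to/tokenizer.model")  # assumed constructor signature
    model = LlamaEmbeddingClassifier(config)
    input_ids = torch.tensor(
        [tokenizer.encode("an example sentence to classify", bos=True, eos=False)]
    )
    log_probs = model(input_ids)  # shape (1, num_labels)
    print(log_probs.exp())  # class probabilities, summing to 1 per example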