# mini-llama/src/base_llama.py
import re
from dataclasses import dataclass

import torch
from torch import dtype, nn

from config import LlamaConfig
from utils import *

class LlamaPreTrainedModel(nn.Module):
  """Base class that stores the LlamaConfig and provides weight
  initialization shared by Llama model subclasses."""
  config_class = LlamaConfig
  base_model_prefix = "llama"

  def __init__(self, config: LlamaConfig):
    super().__init__()
    self.config = config
    self.vocab_size = config.vocab_size
    self.n_layers = config.n_layers

  def init_weights(self):
    # Recursively apply _init_weights to every submodule.
    self.apply(self._init_weights)

  def _init_weights(self, module):
    """ Initialize the weights """
    if isinstance(module, nn.Linear):
        # Linear layers: weights ~ N(0, 0.02), biases zeroed.
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        # Embedding tables: weights ~ N(0, 0.02).
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  @property
  def dtype(self) -> dtype:
    # dtype of the model's parameters, as resolved by utils.get_parameter_dtype.
    return get_parameter_dtype(self)
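

# Usage sketch (not part of the original file): how a concrete model might
# subclass LlamaPreTrainedModel and trigger weight initialization. The
# ToyLlama class and the LlamaConfig keyword arguments shown here
# (vocab_size, n_layers, dim) are illustrative assumptions, not definitions
# from this repo, so the example is left commented out.
#
# class ToyLlama(LlamaPreTrainedModel):
#   def __init__(self, config: LlamaConfig):
#     super().__init__(config)
#     self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
#     self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
#     self.init_weights()  # applies _init_weights to every submodule
#
#   def forward(self, input_ids):
#     return self.output(self.tok_embeddings(input_ids))
#
# config = LlamaConfig(vocab_size=32000, n_layers=2, dim=64)
# model = ToyLlama(config)
# print(model.dtype)  # typically torch.float32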