from dataclasses import dataclass
import re

import torch
from torch import dtype, nn

from config import LlamaConfig
from utils import *
class LlamaPreTrainedModel(nn.Module):
    """Base class that stores the model config and handles weight initialization."""
    config_class = LlamaConfig
    base_model_prefix = "llama"

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.n_layers = config.n_layers

    def init_weights(self):
        # Recursively initialize the weights of every submodule.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    @property
    def dtype(self) -> dtype:
        # dtype of the model's parameters (see get_parameter_dtype in utils).
        return get_parameter_dtype(self)
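
# A minimal smoke test of the base class (a sketch, not part of the model).
# `_ToyModel` and the SimpleNamespace config are hypothetical stand-ins used
# here so the example does not depend on LlamaConfig's constructor defaults;
# in practice you would subclass with real Llama layers and a real LlamaConfig.
if __name__ == "__main__":
    from types import SimpleNamespace

    class _ToyModel(LlamaPreTrainedModel):
        # Hypothetical subclass used only to exercise init_weights / dtype.
        def __init__(self, config):
            super().__init__(config)
            self.embed = nn.Embedding(config.vocab_size, 8)
            self.proj = nn.Linear(8, config.vocab_size)
            self.init_weights()  # applies _init_weights to every submodule

    cfg = SimpleNamespace(vocab_size=100, n_layers=2)
    model = _ToyModel(cfg)
    # Parameters of nn.Linear / nn.Embedding are float32 by default.
    print(model.dtype)  # expected: torch.float32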