import math

import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """A single post-LayerNorm Transformer encoder layer: self-attention followed by a feed-forward network."""

    def __init__(self, embed_dim=768, num_heads=12, dim_feedforward=100, dropout=0.0, kdim=None, vdim=None):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, kdim=kdim, vdim=vdim, batch_first=True
        )
        # Position-wise feed-forward network.
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(dim_feedforward, embed_dim),
        )
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention sub-layer with residual connection; `mask` is a key-padding
        # mask of shape (batch, seq) where True marks positions to ignore.
        attn_output, _ = self.attention(x, x, x, key_padding_mask=mask)
        x = x + self.dropout(attn_output)
        x = self.layer_norm1(x)
        # Feed-forward sub-layer with residual connection.
        x = x + self.dropout(self.linear(x))
        x = self.layer_norm2(x)
        return x


class TransformerEncoder(nn.Module):
    """A stack of `nlayers` identical EncoderBlocks."""

    def __init__(self, nlayers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**kwargs) for _ in range(nlayers)])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return x


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to a batch-first input of shape (batch, seq, d_model)."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)   # even indices: sine
        pe[:, 1::2] = torch.cos(position * div_term)   # odd indices: cosine
        pe = pe.unsqueeze(0)                           # (1, max_len, d_model), broadcast over the batch
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x


class TransformerModel(nn.Module):
    """Token embedding + positional encoding + encoder stack + a linear head producing one scalar per token."""

    def __init__(self, ntoken, d_model, nhead, d_hid, nlayers, dropout=0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model)
        self.transformer_encoder = TransformerEncoder(
            nlayers, embed_dim=d_model, num_heads=nhead, dim_feedforward=d_hid, dropout=dropout
        )
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, 1)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask=None):
        # `src` holds token indices of shape (batch, seq); `src_mask` is an optional key-padding mask.
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output.squeeze(-1)  # (batch, seq): one scalar per token position
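

# Minimal usage sketch, not part of the original module definitions. The vocabulary
# size, sequence length, and hyperparameters below are illustrative assumptions,
# chosen only to show the expected tensor shapes and how the key-padding mask is passed.
if __name__ == "__main__":
    vocab_size, seq_len, batch_size = 1000, 16, 4
    model = TransformerModel(ntoken=vocab_size, d_model=64, nhead=4, d_hid=128, nlayers=2, dropout=0.1)

    tokens = torch.randint(0, vocab_size, (batch_size, seq_len))   # (batch, seq) token indices
    # Boolean key-padding mask: True marks positions the attention should ignore.
    padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    padding_mask[:, -2:] = True                                    # pretend the last two tokens are padding

    out = model(tokens, src_mask=padding_mask)
    print(out.shape)                                               # torch.Size([4, 16]): one scalar per token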