'''
Loads the Multi30k dataset, tokenizes the English sentences with a pre-trained BERT tokenizer,
encodes a batch with the transformer model and passes the encoding through an erasure channel.
'''
import torch
from datasets import load_dataset
from transformers import BertTokenizerFast

import transformer  # local module: transformer encoder model
import channel      # local module: erasure channel


def collate_batch(batch):
    '''
    Pads the sequences in a batch to a common length, which is needed for batched training.
    Input: batch, a list of samples, each a dictionary with the fields
    input_ids, token_type_ids, attention_mask.
    '''
    batch_inputs = [item['input_ids'] for item in batch]
    batch_token_types = [item['token_type_ids'] for item in batch]
    batch_attn_masks = [item['attention_mask'] for item in batch]
    batch_inputs_padded = torch.nn.utils.rnn.pad_sequence(batch_inputs, batch_first=True)
    batch_token_types_padded = torch.nn.utils.rnn.pad_sequence(batch_token_types, batch_first=True)
    batch_attn_masks_padded = torch.nn.utils.rnn.pad_sequence(batch_attn_masks, batch_first=True)
    return {
        'input_ids': batch_inputs_padded,
        'token_type_ids': batch_token_types_padded,
        'attention_mask': batch_attn_masks_padded
    }


def tokenize(dataset):
    '''
    Tokenizes the dataset and returns it as PyTorch tensors along with the BERT tokenizer.
    Uses the pre-trained 'bert-base-cased' tokenizer from Hugging Face and pads every
    sequence to max_length = 128.
    '''
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')  # pre-trained BERT tokenizer
    # tokenize the English sentences and drop the German column
    tokens = dataset.map(
        lambda examples: tokenizer(examples['en'], truncation=False, padding="max_length", max_length=128),
        batched=True,
        remove_columns=['de'])
    # convert to tensors
    tokens.set_format(type="torch", columns=['input_ids', 'token_type_ids', 'attention_mask'])
    return tokens, tokenizer


def create_dataloaders(tokens):
    '''
    Creates PyTorch dataloaders from the tokenized dataset. `tokens` should be in PyTorch
    format and contain the splits train, validation and test.
    '''
    train_dl = torch.utils.data.DataLoader(tokens['train'], batch_size=32, shuffle=True, collate_fn=collate_batch)
    val_dl = torch.utils.data.DataLoader(tokens['validation'], batch_size=32, shuffle=True, collate_fn=collate_batch)
    test_dl = torch.utils.data.DataLoader(tokens['test'], batch_size=32, shuffle=False, collate_fn=collate_batch)
    return train_dl, val_dl, test_dl


def get_attention_mask(padded_mask):
    '''
    Returns an additive attention mask built from the padding mask of a batch.
    The input mask (padded_mask) needs to be of shape [batch_size, seq_length],
    with 1 for real tokens and 0 for padding. Padding positions are filled with a
    large negative value so they receive (almost) zero weight after the softmax.
    '''
    mask = padded_mask.float()
    mask = mask.masked_fill(mask == 0, -9e6)
    mask = mask.masked_fill(mask == 1, 0)
    return mask


def main():
    dataset = load_dataset("bentrevett/multi30k")  # load Multi30k dataset
    tokens, tokenizer = tokenize(dataset)  # tokenize data
    train_dl, val_dl, test_dl = create_dataloaders(tokens)  # get pytorch dataloaders

    vocab_size = tokenizer.vocab_size
    embedding_dim = 768  # embedding dimension
    model = transformer.TransformerModel(ntoken=vocab_size, d_model=embedding_dim, nhead=2, num_layers=1, dropout=0.1)

    # encode one batch and pass the encoding through the erasure channel
    train_iter = iter(train_dl)
    batch = next(train_iter)
    encoding = model(batch['input_ids'], get_attention_mask(batch['attention_mask']))
    corrupted_msg, erasures = channel.erasure_channel(encoding)
    print("done")


if __name__ == "__main__":
    main()


# Training loop kept for reference (commented out). If re-enabled, it additionally needs:
#   import matplotlib.pyplot as plt
#   from IPython import display
# and a definition of expand_attention_mask_for_heads (see the sketch after this block).
#
# optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
# criterion = torch.nn.MSELoss()
# num_epochs = 10
# train_losses = []
# val_losses = []
# fig = plt.figure(figsize=(20, 5), dpi=80, facecolor='w', edgecolor='k')
# axes = fig.subplots(1, 2)
# for e in range(num_epochs):
#     train_iter = iter(train_dl)
#     model.train()
#     for i in range(len(train_dl)):
#         batch = next(train_iter)
#         attn_mask = expand_attention_mask_for_heads(batch['attention_mask'], 2)
#         pred = model(batch['input_ids'], attn_mask)
#         # calculate the loss and backward the gradient
#         loss = criterion(pred, batch['input_ids'].float())
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         train_losses.append(loss.item())
#         if i % 100 == 0:
#             axes[0].cla()
#             # plot the training error on a log plot
#             axes[0].plot(train_losses, label='loss')
#             axes[0].set_yscale('log')
#             axes[0].set_title('Training loss')
#             axes[0].set_xlabel('number of gradient iterations')
#             axes[0].legend()
#             # clear output window and display updated figure
#             display.clear_output(wait=True)
#             display.display(plt.gcf())
#             print("Training epoch {}, iteration {} of {} ({} %), loss={:0.5f}".format(e, i, len(train_dl), 100*i//len(train_dl), train_losses[-1]))
#     val_iter = iter(val_dl)
#     model.eval()
#     for i in range(len(val_dl)):
#         batch = next(val_iter)
#         attn_mask = expand_attention_mask_for_heads(batch['attention_mask'], 12)
#         pred = model(batch['input_ids'], attn_mask)
#         # calculate the loss
#         with torch.no_grad():
#             loss = criterion(pred, batch['input_ids'])
#         val_losses.append(loss.item())
#         if i % 10 == 0:
#             axes[1].cla()
#             # plot the validation error on a log plot
#             axes[1].plot(val_losses, label='loss')
#             axes[1].set_yscale('log')
#             axes[1].set_title('Validation loss')
#             axes[1].set_xlabel('number of gradient iterations')
#             axes[1].legend()
#             # clear output window and display updated figure
#             display.clear_output(wait=True)
#             display.display(plt.gcf())
#             print("Validation epoch {}, iteration {} of {} ({} %), loss={:0.5f}".format(e, i, len(val_dl), 100*i//len(val_dl), val_losses[-1]))
# plt.close('all')
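# Possible sketch of expand_attention_mask_for_heads, referenced by the commented-out training
# loop above but not defined in this file. This is an assumption, not the original
# implementation: it expands the [batch_size, seq_length] padding mask into the
# [batch_size * num_heads, seq_length, seq_length] additive mask that
# torch.nn.MultiheadAttention accepts as a 3-D attn_mask.
def expand_attention_mask_for_heads(padding_mask, num_heads):
    batch_size, seq_length = padding_mask.shape
    # Additive mask per sequence: 0 where attending is allowed, -9e6 where the key is padding.
    additive = get_attention_mask(padding_mask)  # [batch_size, seq_length]
    # Broadcast over query positions: every query position sees the same key padding pattern.
    additive = additive[:, None, :].expand(batch_size, seq_length, seq_length)
    # Repeat each sample once per attention head: [batch_size * num_heads, seq_length, seq_length].
    return additive.repeat_interleave(num_heads, dim=0)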