import transformer
import torch
from datasets import load_dataset
from transformers import BertTokenizerFast
import channel

def collate_batch(batch):
    '''
    Pads the sequences in a batch to a common length so they can be stacked for batch training.
    Input: batch, a list of samples, each a dictionary with the fields input_ids, token_type_ids
    and attention_mask. Returns a dictionary of padded [batch_size, seq_length] tensors.
    '''
    batch_inputs = [item['input_ids'] for item in batch]
    batch_token_types = [item['token_type_ids'] for item in batch]
    batch_attn_masks = [item['attention_mask'] for item in batch]

    batch_inputs_padded = torch.nn.utils.rnn.pad_sequence(batch_inputs, batch_first=True)
    batch_token_types_padded = torch.nn.utils.rnn.pad_sequence(batch_token_types, batch_first=True)
    batch_attn_masks_padded = torch.nn.utils.rnn.pad_sequence(batch_attn_masks, batch_first=True)

    return {
        'input_ids': batch_inputs_padded,
        'token_type_ids': batch_token_types_padded,
        'attention_mask': batch_attn_masks_padded
    }
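
# Illustrative usage sketch (not called anywhere; the toy token values are made up):
# collate_batch pads every field in a batch of unequal-length samples to the
# length of the longest sample in that batch.
def _collate_batch_example():
    toy_batch = [
        {'input_ids': torch.tensor([101, 7592, 102]),
         'token_type_ids': torch.tensor([0, 0, 0]),
         'attention_mask': torch.tensor([1, 1, 1])},
        {'input_ids': torch.tensor([101, 102]),
         'token_type_ids': torch.tensor([0, 0]),
         'attention_mask': torch.tensor([1, 1])},
    ]
    padded = collate_batch(toy_batch)
    assert padded['input_ids'].shape == (2, 3)  # shorter sample padded with a trailing 0
    return padded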

def tokenize(dataset):
    '''
    Tokenizes the English side of the dataset and returns the tokenized dataset (as PyTorch tensors)
    together with the BERT tokenizer.
    Uses the pre-trained 'bert-base-cased' tokenizer from Hugging Face.
    Pads every sequence to max_length = 128.
    '''
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased') # pre-trained BERT tokenizer
    tokens = dataset.map(lambda examples: tokenizer(examples['en'], truncation=True, padding="max_length", max_length=128),
                         batched=True, remove_columns=['de']) # tokenize the English sentences, drop the German column
    # expose the tokenized fields as PyTorch tensors
    tokens.set_format(type="torch", columns=['input_ids', 'token_type_ids', 'attention_mask'])
    return tokens, tokenizer
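
# Illustrative inspection sketch (not called anywhere): what one tokenized sample
# looks like after tokenize(). Assumes `tokens` and `tokenizer` come from
# tokenize(load_dataset("bentrevett/multi30k")).
def _inspect_tokens_example(tokens, tokenizer):
    sample = tokens['train'][0]
    print(sample['input_ids'].shape)  # torch.Size([128]) because of padding="max_length"
    print(tokenizer.decode(sample['input_ids'], skip_special_tokens=True))  # the original English sentence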

def create_dataloaders(tokens):
    '''
    Creates PyTorch DataLoaders from the tokenized dataset. `tokens` should be in PyTorch tensor
    format and behave like a dictionary with the splits train, validation and test.
    '''
    train_dl = torch.utils.data.DataLoader(tokens['train'], batch_size=32, shuffle=True, collate_fn=collate_batch)
    val_dl = torch.utils.data.DataLoader(tokens['validation'], batch_size=32, shuffle=True, collate_fn=collate_batch)
    test_dl = torch.utils.data.DataLoader(tokens['test'], batch_size=32, shuffle=False, collate_fn=collate_batch)
    return train_dl, val_dl, test_dl
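
# Illustrative batch-shape sketch (not called anywhere): with batch_size=32 and the
# fixed max_length=128 padding used in tokenize(), every field of a batch is a
# [32, 128] LongTensor (the final batch of an epoch may have fewer rows).
def _dataloader_batch_example(train_dl):
    batch = next(iter(train_dl))
    print(batch['input_ids'].shape, batch['attention_mask'].shape)
    return batch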

def get_attention_mask(padded_mask):
    '''
    Converts a 0/1 padding mask into an additive attention mask: positions with value 1 (real tokens)
    map to 0 and positions with value 0 (padding) map to a large negative number, so padded positions
    are suppressed once the mask is added to the attention scores.
    Input: padded_mask of shape [batch_size, seq_length].
    '''

    mask = padded_mask.float()
    mask = mask.masked_fill(mask == 0, -9e6)  # padding positions -> large negative value
    mask = mask.masked_fill(mask == 1, 0)     # real-token positions (still 1 after the first fill) -> 0
    return mask
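
# Quick illustration (not called anywhere): the 0/1 padding mask [[1, 1, 0]]
# becomes the additive mask [[0., 0., -9e6]], so the single padded position is
# suppressed when the mask is added to the attention scores.
def _attention_mask_example():
    return get_attention_mask(torch.tensor([[1, 1, 0]]))

# Hypothetical helper (an assumption, not part of the original file): the
# commented-out training loop at the bottom calls expand_attention_mask_for_heads,
# which is not defined anywhere in this file. One plausible sketch is to expand the
# [batch_size, seq_length] padding mask into an additive
# [batch_size * nhead, seq_length, seq_length] mask, the per-head attention-mask
# shape accepted by torch.nn.MultiheadAttention.
def expand_attention_mask_for_heads(padded_mask, nhead):
    batch_size, seq_len = padded_mask.shape
    additive = get_attention_mask(padded_mask)                        # [batch_size, seq_length]
    additive = additive[:, None, :].expand(batch_size, seq_len, seq_len)
    return additive.repeat_interleave(nhead, dim=0)                   # [batch_size * nhead, seq_length, seq_length]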

def main():
    dataset = load_dataset("bentrevett/multi30k") # load Multi30k dataset
    tokens, tokenizer = tokenize(dataset) # tokenize data
    train_dl, val_dl, test_dl = create_dataloaders(tokens) # get pytorch dataloaders

    vocab_size = tokenizer.vocab_size
    embedding_dim = 768  # embedding dimension

    model = transformer.TransformerModel(ntoken=vocab_size, d_model=embedding_dim, nhead=2, num_layers=1, dropout=0.1)
    train_iter = iter(train_dl)
    batch = next(train_iter)

    encoding = model(batch['input_ids'], get_attention_mask(batch['attention_mask']))

    corrupted_msg, erasures = channel.erasure_channel(encoding)
    
    print("done")

if __name__ == "__main__":
    main()
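
    # NOTE: the commented-out training/validation loop below additionally assumes
    # `import matplotlib.pyplot as plt` and `from IPython import display`, and the
    # expand_attention_mask_for_heads helper sketched above.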
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    # criterion = torch.nn.MSELoss()
    # num_epochs = 10
    # train_losses = []
    # val_losses = []

    # fig = plt.figure(figsize=(20, 5), dpi= 80, facecolor='w', edgecolor='k')
    # axes = fig.subplots(1,2)

    # for e in range(num_epochs):
    #     train_iter = iter(train_dl)
    #     model.train()
    #     for i in range(len(train_dl)):
    #         batch = next(train_iter)
    #         attn_mask = expand_attention_mask_for_heads(batch['attention_mask'], 2)
    #         pred = model(batch['input_ids'], attn_mask)

    #         # calculate the loss and backward the gradient
    #         loss = criterion(pred, batch['input_ids'].float())
    #         optimizer.zero_grad()
    #         loss.backward()
    #         optimizer.step()
    #         train_losses.append(loss.item())

    #         if i%100==0:
    #             axes[0].cla()

    #             # plot the training error on a log plot
    #             axes[0].plot(train_losses, label='loss')
    #             axes[0].set_yscale('log')
    #             axes[0].set_title('Training loss')
    #             axes[0].set_xlabel('number of gradient iterations')
    #             axes[0].legend()

    #             # clear output window and display updated figure
    #             display.clear_output(wait=True)
    #             display.display(plt.gcf())
    #             print("Training epoch {}, iteration {} of {} ({} %), loss={:0.5f}".format(e, i, len(train_dl), 100*i//len(train_dl), train_losses[-1]))

    #     val_iter = iter(val_dl)
    #     model.eval()
    #     for i in range(len(val_dl)):
    #         batch = next(val_iter)
    #         attn_mask = expand_attention_mask_for_heads(batch['attention_mask'], 2)
    #         pred = model(batch['input_ids'], attn_mask)

    #         # calculate the loss
    #         with torch.no_grad():
    #             loss = criterion(pred, batch['input_ids'].float())
    #             val_losses.append(loss.item())

    #         if i%10==0:
    #             axes[1].cla()

    #             # plot the validation error on a log plot
    #             axes[1].plot(val_losses, label='loss')
    #             axes[1].set_yscale('log')
    #             axes[1].set_title('Validation loss')
    #             axes[1].set_xlabel('number of gradient iterations')
    #             axes[1].legend()

    #             # clear output window and display updated figure
    #             display.clear_output(wait=True)
    #             display.display(plt.gcf())
    #             print("Validation epoch {}, iteration {} of {} ({} %), loss={:0.5f}".format(e, i, len(val_dl), 100*i//len(val_dl), val_losses[-1]))
    # plt.close('all')