In [1]:
import os
import json
import torch
import torchtext as text
import sys
import tqdm
from datasets import load_dataset
import pandas as pd
import string
from torchtext.data import get_tokenizer
from vocab import *
from utils import *
from constants import *
import pickle

from tqdm.notebook import tqdm

In [2]:
def _flatten_squad_split(split):
    """Expand one SQuAD split into parallel passage/question/answer lists.

    Every example contributes one row per reference answer, so an example
    with several answers repeats its passage and question text.

    Args:
        split: iterable of examples; each example is read as
            example['context'], example['question'] and
            example['answers']['text'] (a list of answer strings).

    Returns:
        dict with keys 'passage', 'question' and 'answer', each mapping to a
        list; all three lists have the same length.
    """
    rows = {'passage': [], 'question': [], 'answer': []}
    for example in split:
        for answer_text in example['answers']['text']:
            rows['passage'].append(example['context'])
            rows['question'].append(example['question'])
            rows['answer'].append(answer_text)
    return rows


dataset = load_dataset("squad")

# Both splits go through the same helper so they are flattened identically.
train_dic = _flatten_squad_split(dataset['train'])
train = pd.DataFrame(train_dic)

val_dic = _flatten_squad_split(dataset['validation'])
val = pd.DataFrame(val_dic)

Reusing dataset squad (/tmp/xdg-cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


 0%| | 0/2 [00:00<?, ?it/s]

In [3]:
# Lower-case and clean every text column of the train and val frames.
#
# Each column is rebuilt and assigned back as a whole. Writing through
# `frame.iloc[i][column] = value` is chained assignment: the value goes into a
# temporary row object and is not reliably written back to the DataFrame, so
# the cleaned text could be silently lost.
for split_name, frame in (('train', train), ('val', val)):
    for column in ('passage', 'question', 'answer'):
        frame[column] = [
            clean_text(text.lower())
            for text in tqdm(frame[column], desc=f'{split_name}/{column}')
        ]

 0%| | 0/87599 [00:00<?, ?it/s]

 0%| | 0/34726 [00:00<?, ?it/s]

In [4]:
# Split val into val and test.
#
# The shuffle is seeded so the same rows land in each split on every run;
# without a seed the CSV files written afterwards would differ per run.
# NOTE: this cell rebinds `val` to the remainder, so running it twice without
# rebuilding `val` first removes another NUM_TEST_EXAMPLES rows.
VAL_TEST_SPLIT_SEED = 42
NUM_TEST_EXAMPLES = 10000

val = val.sample(frac=1, random_state=VAL_TEST_SPLIT_SEED).reset_index(drop=True)
test = val[:NUM_TEST_EXAMPLES]
val = val[NUM_TEST_EXAMPLES:]

In [5]:
# Write each split to its CSV file under ./data, leaving out the index column.
for split_frame, csv_path in (
    (train, './data/train.csv'),
    (val, './data/val.csv'),
    (test, './data/test.csv'),
):
    split_frame.to_csv(csv_path, index=False)

In [6]:
# Build the vocabulary with build_vocab (not defined in this notebook;
# presumably supplied by one of the star imports above -- confirm).
# The result is later called as vocab(word) to turn a token into an index.
vocab = build_vocab()

Building Vocabulary
Saved the vocab.


In [11]:
def _encode_padded(tokens, max_tokens):
    """Wrap tokens in <start>/<end>, map them through vocab, and zero-pad.

    Args:
        tokens: list of token strings.
        max_tokens: maximum number of tokens kept (markers not counted).

    Returns:
        1-D float tensor of length max_tokens + 2. Unused trailing positions
        stay 0 (presumably the vocabulary's padding index -- confirm).
    """
    # Truncate before adding the markers so the encoded sequence always fits
    # the fixed-size buffer; an over-long sequence would otherwise fail on the
    # slice assignment below with a shape mismatch.
    wrapped = ['<start>'] + tokens[:max_tokens] + ['<end>']
    ids = [vocab(word) for word in wrapped]

    padded = torch.zeros(max_tokens + 2)  # +2 for start and end tokens
    # Indices are stored as floats to keep the tensor dtype this function has
    # always produced; consumers needing integer indices must cast.
    padded[:len(ids)] = torch.tensor(ids, dtype=padded.dtype)
    return padded


def get_processed_data(df, tokenizer):
    """Tokenize, index and pad every row of a passage/answer/question frame.

    Relies on the notebook-level `vocab` callable and on the constants
    MAX_PASSAGE_LEN, MAX_ANSWER_LEN and MAX_QUESTION_LEN.

    Args:
        df: DataFrame with string columns "passage", "answer" and "question".
        tokenizer: callable mapping a string to a list of token strings.

    Returns:
        List with one (passage, answer, question) tuple of 1-D float tensors
        per row, in row order. Tensor lengths are MAX_PASSAGE_LEN + 2,
        MAX_ANSWER_LEN + 2 and MAX_QUESTION_LEN + 2 respectively. Texts with
        more tokens than the corresponding maximum are truncated.
    """
    # Iterate the three columns directly instead of building a row object
    # several times per example with df.iloc[idx][...].
    columns = zip(df["passage"], df["answer"], df["question"])

    data = []
    for passage_text, answer_text, question_text in tqdm(columns, total=len(df)):
        data.append((
            _encode_padded(tokenizer(passage_text), MAX_PASSAGE_LEN),
            _encode_padded(tokenizer(answer_text), MAX_ANSWER_LEN),
            _encode_padded(tokenizer(question_text), MAX_QUESTION_LEN),
        ))
    return data

In [12]:
# Tokenize and encode the train, val and test frames with the
# "basic_english" tokenizer obtained from get_tokenizer.
tokenizer = get_tokenizer("basic_english")
train_processed, val_processed, test_processed = [
    get_processed_data(split_frame, tokenizer)
    for split_frame in (train, val, test)
]

 0%| | 0/87599 [00:00<?, ?it/s]

 0%| | 0/24726 [00:00<?, ?it/s]

 0%| | 0/10000 [00:00<?, ?it/s]

In [16]:
# Serialize each processed split to its own pickle file under ./data.
for processed_split, pickle_path in (
    (train_processed, './data/train_processed.pickle'),
    (val_processed, './data/val_processed.pickle'),
    (test_processed, './data/test_processed.pickle'),
):
    with open(pickle_path, 'wb') as pickle_file:
        pickle.dump(processed_split, pickle_file)

In [15]:
# Read the pickled training split back and show its first example as a
# sanity check.
# pickle.load needs an open binary file object, not a path string, and the
# file written when saving is './data/train_processed.pickle'.
# Only unpickle files you created yourself: loading can run arbitrary code.
with open('./data/train_processed.pickle', 'rb') as train_file:
    train_reloaded = pickle.load(train_file)

train_reloaded[0]

(tensor([ 1., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 6., 14., 15.,
 16., 17., 18., 19., 20., 9., 21., 22., 23., 6., 24., 25., 12., 26.,
 27., 28., 23., 6., 14., 15., 29., 30., 31., 5., 20., 9., 32., 22.,
 23., 33., 34., 35., 36., 34., 6., 37., 38., 39., 40., 41., 12., 42.,
 43., 6., 14., 15., 20., 6., 44., 23., 6., 45., 46., 12., 26., 47.,
 6., 44., 20., 6., 48., 5., 9., 49., 50., 23., 51., 29., 52., 12.,
 31., 20., 9., 53., 23., 6., 48., 54., 55., 5., 56., 57., 6., 24.,
 25., 58., 59., 43., 60., 61., 62., 27., 63., 12., 54., 6., 64., 23.,
 6., 14., 65., 66., 29., 27., 9., 67., 68., 69., 70., 71., 72., 73.,
 29., 6., 18., 19., 74., 5., 20., 9., 75., 5., 76., 77., 22., 23.,
 25., 12., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
 0., 0., 0., 0., 0., 0., 0