{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c2fee74f",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import torch\n",
"import torchtext as text\n",
"import sys\n",
"import tqdm\n",
"from datasets import load_dataset\n",
"import pandas as pd\n",
"import string\n",
"from torchtext.data import get_tokenizer\n",
"from vocab import *\n",
"from utils import *\n",
"from constants import *\n",
"import pickle\n",
"\n",
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e34185c3",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset squad (/tmp/xdg-cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e5702ad773134cf983bba5331063de68",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dataset = load_dataset(\"squad\")\n",
"\n",
"train_dic = {'passage': [], 'question': [], 'answer': []}\n",
"for i in range(len(dataset['train'])):\n",
" datum = dataset['train'][i]\n",
" for j in range(len(datum['answers']['text'])):\n",
" train_dic['passage'].append(datum['context'])\n",
" train_dic['question'].append(datum['question'])\n",
" train_dic['answer'].append(datum['answers']['text'][j])\n",
"\n",
"train = pd.DataFrame(train_dic)\n",
"\n",
"val_dic = {'passage': [], 'question': [], 'answer': []}\n",
"for datum in dataset['validation']:\n",
" for elem in datum['answers']['text']:\n",
" ans_id = 0\n",
" val_dic['passage'].append(datum['context'])\n",
" val_dic['question'].append(datum['question'])\n",
" val_dic['answer'].append(elem)\n",
"\n",
"val = pd.DataFrame(val_dic)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "70fa0515",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "980e0c9145a4401eaa5cc5a82a9f721a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/87599 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "26efe8fa4acd4aac98addfa80543a6ec",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/34726 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for i in tqdm(range(len(train))):\n",
" row = train.iloc[i]\n",
" passage = clean_text(row['passage'].lower())\n",
" question = clean_text(row['question'].lower())\n",
" answer = clean_text(row['answer'].lower())\n",
" \n",
" train.iloc[i]['passage'] = passage\n",
" train.iloc[i]['question'] = question\n",
" train.iloc[i]['answer'] = answer\n",
" \n",
"for i in tqdm(range(len(val))):\n",
" row = val.iloc[i]\n",
" passage = clean_text(row['passage'].lower())\n",
" question = clean_text(row['question'].lower())\n",
" answer = clean_text(row['answer'].lower())\n",
" \n",
" val.iloc[i]['passage'] = passage\n",
" val.iloc[i]['question'] = question\n",
" val.iloc[i]['answer'] = answer"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "63809566",
"metadata": {},
"outputs": [],
"source": [
"#Split Val into val and test\n",
"\n",
"val = val.sample(frac=1).reset_index(drop=True)\n",
"test = val[:10000]\n",
"val = val[10000:]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b5059d5c",
"metadata": {},
"outputs": [],
"source": [
"train.to_csv('./data/train.csv', index=False)\n",
"val.to_csv('./data/val.csv', index=False)\n",
"test.to_csv('./data/test.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b61acf74",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Building Vocabulary\n",
"Saved the vocab.\n"
]
}
],
"source": [
"vocab = build_vocab()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "e27ec51d",
"metadata": {},
"outputs": [],
"source": [
"def get_processed_data(df, tokenizer):\n",
" data = []\n",
" for idx in tqdm(range(len(df))):\n",
" pass_tokens = ['<start>'] + tokenizer(df.iloc[idx][\"passage\"]) + ['<end>']\n",
" ans_tokens = ['<start>'] + tokenizer(df.iloc[idx][\"answer\"]) + ['<end>']\n",
" q_tokens = ['<start>'] + tokenizer(df.iloc[idx][\"question\"]) + ['<end>']\n",
" # pass_tokens = ['<start>'] + list(map(tokenizer, df.iloc[idx][\"passage\"])) + ['<end>']\n",
" # ans_tokens = ['<start>'] + list(map(tokenizer, df.iloc[idx][\"answer\"])) + ['<end>']\n",
" # q_tokens = ['<start>'] + list(map(tokenizer, df.iloc[idx][\"question\"])) + ['<end>']\n",
"\n",
" pass_len = MAX_PASSAGE_LEN + 2 # +2 for start and end tokens\n",
" ans_len = MAX_ANSWER_LEN + 2\n",
" q_len = MAX_QUESTION_LEN + 2\n",
"\n",
" passage = [vocab(word) for word in pass_tokens]\n",
" answer = [vocab(word) for word in ans_tokens]\n",
" question = [vocab(word) for word in q_tokens]\n",
"\n",
" # padding to same length\n",
" pass_idxs = torch.zeros(pass_len)\n",
" ans_idxs = torch.zeros(ans_len)\n",
" q_idxs = torch.zeros(q_len)\n",
"\n",
" pass_idxs[:len(passage)] = torch.FloatTensor(passage)\n",
" ans_idxs[:len(answer)] = torch.FloatTensor(answer)\n",
" q_idxs[:len(question)] = torch.FloatTensor(question)\n",
"\n",
" data.append((pass_idxs, ans_idxs, q_idxs))\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2d03b997",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c050703069b4455da7807680794e088f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/87599 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6f1aeb41cd514bb7a3c5b7b661e0a43d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/24726 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "42cdac0229814c5591ab4516602c8b9b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tokenizer = get_tokenizer(\"basic_english\")\n",
"train_processed = get_processed_data(train, tokenizer)\n",
"val_processed = get_processed_data(val, tokenizer)\n",
"test_processed = get_processed_data(test, tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "d771f36c",
"metadata": {},
"outputs": [],
"source": [
"with open('./data/train_processed.pickle', 'wb') as train_file:\n",
" pickle.dump(train_processed, train_file)\n",
"\n",
"with open('./data/val_processed.pickle', 'wb') as val_file:\n",
" pickle.dump(val_processed, val_file)\n",
"\n",
"with open('./data/test_processed.pickle', 'wb') as test_file:\n",
" pickle.dump(test_processed, test_file)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a8f3781c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([ 1., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 6., 14., 15.,\n",
" 16., 17., 18., 19., 20., 9., 21., 22., 23., 6., 24., 25., 12., 26.,\n",
" 27., 28., 23., 6., 14., 15., 29., 30., 31., 5., 20., 9., 32., 22.,\n",
" 23., 33., 34., 35., 36., 34., 6., 37., 38., 39., 40., 41., 12., 42.,\n",
" 43., 6., 14., 15., 20., 6., 44., 23., 6., 45., 46., 12., 26., 47.,\n",
" 6., 44., 20., 6., 48., 5., 9., 49., 50., 23., 51., 29., 52., 12.,\n",
" 31., 20., 9., 53., 23., 6., 48., 54., 55., 5., 56., 57., 6., 24.,\n",
" 25., 58., 59., 43., 60., 61., 62., 27., 63., 12., 54., 6., 64., 23.,\n",
" 6., 14., 65., 66., 29., 27., 9., 67., 68., 69., 70., 71., 72., 73.,\n",
" 29., 6., 18., 19., 74., 5., 20., 9., 75., 5., 76., 77., 22., 23.,\n",
" 25., 12., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),\n",
" tensor([ 1., 60., 61., 62., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0.]),\n",
" tensor([ 1., 43., 78., 79., 6., 24., 25., 80., 81., 27., 63., 27., 55., 56.,\n",
" 82., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0.]))"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pickle.load('./data/train_processed')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5eb5fea5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "147ef1215b7e2b4bf3a64983c233460acb54149cdc3836f93d9a84ff7ba2f913"
},
"kernelspec": {
"display_name": "Python 3 (clean)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}