{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "c2fee74f", "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import torch\n", "import torchtext as text\n", "import sys\n", "import tqdm\n", "from datasets import load_dataset\n", "import pandas as pd\n", "import string\n", "from torchtext.data import get_tokenizer\n", "from vocab import *\n", "from utils import *\n", "from constants import *\n", "import pickle\n", "\n", "from tqdm.notebook import tqdm" ] }, { "cell_type": "code", "execution_count": 2, "id": "e34185c3", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset squad (/tmp/xdg-cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e5702ad773134cf983bba5331063de68", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/2 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dataset = load_dataset(\"squad\")\n", "\n", "train_dic = {'passage': [], 'question': [], 'answer': []}\n", "for i in range(len(dataset['train'])):\n", " datum = dataset['train'][i]\n", " for j in range(len(datum['answers']['text'])):\n", " train_dic['passage'].append(datum['context'])\n", " train_dic['question'].append(datum['question'])\n", " train_dic['answer'].append(datum['answers']['text'][j])\n", "\n", "train = pd.DataFrame(train_dic)\n", "\n", "val_dic = {'passage': [], 'question': [], 'answer': []}\n", "for datum in dataset['validation']:\n", " for elem in datum['answers']['text']:\n", " ans_id = 0\n", " val_dic['passage'].append(datum['context'])\n", " val_dic['question'].append(datum['question'])\n", " val_dic['answer'].append(elem)\n", "\n", "val = pd.DataFrame(val_dic)" ] }, { "cell_type": "code", "execution_count": 3, "id": "70fa0515", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "980e0c9145a4401eaa5cc5a82a9f721a", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/87599 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "26efe8fa4acd4aac98addfa80543a6ec", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/34726 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "for i in tqdm(range(len(train))):\n", " row = train.iloc[i]\n", " passage = clean_text(row['passage'].lower())\n", " question = clean_text(row['question'].lower())\n", " answer = clean_text(row['answer'].lower())\n", " \n", " train.iloc[i]['passage'] = passage\n", " train.iloc[i]['question'] = question\n", " train.iloc[i]['answer'] = answer\n", " \n", "for i in tqdm(range(len(val))):\n", " row = val.iloc[i]\n", " passage = clean_text(row['passage'].lower())\n", " question = clean_text(row['question'].lower())\n", " answer = clean_text(row['answer'].lower())\n", " \n", " val.iloc[i]['passage'] = passage\n", " val.iloc[i]['question'] = question\n", " val.iloc[i]['answer'] = answer" ] }, { "cell_type": "code", "execution_count": 4, "id": "63809566", "metadata": {}, "outputs": [], "source": [ "#Split Val into val and test\n", "\n", "val = val.sample(frac=1).reset_index(drop=True)\n", "test = val[:10000]\n", "val = val[10000:]" ] }, { "cell_type": "code", "execution_count": 5, "id": "b5059d5c", "metadata": {}, "outputs": [], "source": [ "train.to_csv('./data/train.csv', index=False)\n", "val.to_csv('./data/val.csv', index=False)\n", "test.to_csv('./data/test.csv', index=False)" ] }, { "cell_type": "code", "execution_count": 6, "id": "b61acf74", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Building Vocabulary\n", "Saved the vocab.\n" ] } ], "source": [ "vocab = build_vocab()" ] }, { "cell_type": "code", "execution_count": 11, "id": "e27ec51d", "metadata": {}, "outputs": [], "source": [ "def get_processed_data(df, tokenizer):\n", " data = []\n", " for idx in tqdm(range(len(df))):\n", " pass_tokens = ['<start>'] + tokenizer(df.iloc[idx][\"passage\"]) + ['<end>']\n", " ans_tokens = ['<start>'] + tokenizer(df.iloc[idx][\"answer\"]) + ['<end>']\n", " q_tokens = ['<start>'] + tokenizer(df.iloc[idx][\"question\"]) + ['<end>']\n", " # pass_tokens = ['<start>'] + list(map(tokenizer, df.iloc[idx][\"passage\"])) + ['<end>']\n", " # ans_tokens = ['<start>'] + list(map(tokenizer, df.iloc[idx][\"answer\"])) + ['<end>']\n", " # q_tokens = ['<start>'] + list(map(tokenizer, df.iloc[idx][\"question\"])) + ['<end>']\n", "\n", " pass_len = MAX_PASSAGE_LEN + 2 # +2 for start and end tokens\n", " ans_len = MAX_ANSWER_LEN + 2\n", " q_len = MAX_QUESTION_LEN + 2\n", "\n", " passage = [vocab(word) for word in pass_tokens]\n", " answer = [vocab(word) for word in ans_tokens]\n", " question = [vocab(word) for word in q_tokens]\n", "\n", " # padding to same length\n", " pass_idxs = torch.zeros(pass_len)\n", " ans_idxs = torch.zeros(ans_len)\n", " q_idxs = torch.zeros(q_len)\n", "\n", " pass_idxs[:len(passage)] = torch.FloatTensor(passage)\n", " ans_idxs[:len(answer)] = torch.FloatTensor(answer)\n", " q_idxs[:len(question)] = torch.FloatTensor(question)\n", "\n", " data.append((pass_idxs, ans_idxs, q_idxs))\n", " return data" ] }, { "cell_type": "code", "execution_count": 12, "id": "2d03b997", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c050703069b4455da7807680794e088f", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/87599 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6f1aeb41cd514bb7a3c5b7b661e0a43d", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/24726 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "42cdac0229814c5591ab4516602c8b9b", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10000 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "tokenizer = get_tokenizer(\"basic_english\")\n", "train_processed = get_processed_data(train, tokenizer)\n", "val_processed = get_processed_data(val, tokenizer)\n", "test_processed = get_processed_data(test, tokenizer)" ] }, { "cell_type": "code", "execution_count": 16, "id": "d771f36c", "metadata": {}, "outputs": [], "source": [ "with open('./data/train_processed.pickle', 'wb') as train_file:\n", " pickle.dump(train_processed, train_file)\n", "\n", "with open('./data/val_processed.pickle', 'wb') as val_file:\n", " pickle.dump(val_processed, val_file)\n", "\n", "with open('./data/test_processed.pickle', 'wb') as test_file:\n", " pickle.dump(test_processed, test_file)" ] }, { "cell_type": "code", "execution_count": 15, "id": "a8f3781c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([ 1., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 6., 14., 15.,\n", " 16., 17., 18., 19., 20., 9., 21., 22., 23., 6., 24., 25., 12., 26.,\n", " 27., 28., 23., 6., 14., 15., 29., 30., 31., 5., 20., 9., 32., 22.,\n", " 23., 33., 34., 35., 36., 34., 6., 37., 38., 39., 40., 41., 12., 42.,\n", " 43., 6., 14., 15., 20., 6., 44., 23., 6., 45., 46., 12., 26., 47.,\n", " 6., 44., 20., 6., 48., 5., 9., 49., 50., 23., 51., 29., 52., 12.,\n", " 31., 20., 9., 53., 23., 6., 48., 54., 55., 5., 56., 57., 6., 24.,\n", " 25., 58., 59., 43., 60., 61., 62., 27., 63., 12., 54., 6., 64., 23.,\n", " 6., 14., 65., 66., 29., 27., 9., 67., 68., 69., 70., 71., 72., 73.,\n", " 29., 6., 18., 19., 74., 5., 20., 9., 75., 5., 76., 77., 22., 23.,\n", " 25., 12., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),\n", " tensor([ 1., 60., 61., 62., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0.]),\n", " tensor([ 1., 43., 78., 79., 6., 24., 25., 80., 81., 27., 63., 27., 55., 56.,\n", " 82., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0.]))" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pickle.load('./data/train_processed')" ] }, { "cell_type": "code", "execution_count": null, "id": "5eb5fea5", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "147ef1215b7e2b4bf3a64983c233460acb54149cdc3836f93d9a84ff7ba2f913" }, "kernelspec": { "display_name": "Python 3 (clean)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5" } }, "nbformat": 4, "nbformat_minor": 5 }