neural-question-generator / create_datasets.ipynb
create_datasets.ipynb
Raw
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c2fee74f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import torch\n",
    "import torchtext as text\n",
    "import sys\n",
    "import tqdm\n",
    "from datasets import load_dataset\n",
    "import pandas as pd\n",
    "import string\n",
    "from torchtext.data import get_tokenizer\n",
    "from vocab import *\n",
    "from utils import *\n",
    "from constants import *\n",
    "import pickle\n",
    "\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e34185c3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset squad (/tmp/xdg-cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e5702ad773134cf983bba5331063de68",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset = load_dataset(\"squad\")\n",
    "\n",
    "train_dic = {'passage': [], 'question': [], 'answer': []}\n",
    "for i in range(len(dataset['train'])):\n",
    "    datum = dataset['train'][i]\n",
    "    for j in range(len(datum['answers']['text'])):\n",
    "        train_dic['passage'].append(datum['context'])\n",
    "        train_dic['question'].append(datum['question'])\n",
    "        train_dic['answer'].append(datum['answers']['text'][j])\n",
    "\n",
    "train = pd.DataFrame(train_dic)\n",
    "\n",
    "val_dic = {'passage': [], 'question': [], 'answer': []}\n",
    "for datum in dataset['validation']:\n",
    "    for elem in datum['answers']['text']:\n",
    "        ans_id = 0\n",
    "        val_dic['passage'].append(datum['context'])\n",
    "        val_dic['question'].append(datum['question'])\n",
    "        val_dic['answer'].append(elem)\n",
    "\n",
    "val = pd.DataFrame(val_dic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "70fa0515",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "980e0c9145a4401eaa5cc5a82a9f721a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/87599 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "26efe8fa4acd4aac98addfa80543a6ec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/34726 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for i in tqdm(range(len(train))):\n",
    "    row = train.iloc[i]\n",
    "    passage = clean_text(row['passage'].lower())\n",
    "    question = clean_text(row['question'].lower())\n",
    "    answer = clean_text(row['answer'].lower())\n",
    "    \n",
    "    train.iloc[i]['passage'] = passage\n",
    "    train.iloc[i]['question'] = question\n",
    "    train.iloc[i]['answer'] = answer\n",
    "    \n",
    "for i in tqdm(range(len(val))):\n",
    "    row = val.iloc[i]\n",
    "    passage = clean_text(row['passage'].lower())\n",
    "    question = clean_text(row['question'].lower())\n",
    "    answer = clean_text(row['answer'].lower())\n",
    "    \n",
    "    val.iloc[i]['passage'] = passage\n",
    "    val.iloc[i]['question'] = question\n",
    "    val.iloc[i]['answer'] = answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "63809566",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Split Val into val and test\n",
    "\n",
    "val = val.sample(frac=1).reset_index(drop=True)\n",
    "test = val[:10000]\n",
    "val = val[10000:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b5059d5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "train.to_csv('./data/train.csv', index=False)\n",
    "val.to_csv('./data/val.csv', index=False)\n",
    "test.to_csv('./data/test.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b61acf74",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building Vocabulary\n",
      "Saved the vocab.\n"
     ]
    }
   ],
   "source": [
    "vocab = build_vocab()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e27ec51d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_processed_data(df, tokenizer):\n",
    "        data = []\n",
    "        for idx in tqdm(range(len(df))):\n",
    "                pass_tokens = ['<start>'] + tokenizer(df.iloc[idx][\"passage\"]) + ['<end>']\n",
    "                ans_tokens = ['<start>'] + tokenizer(df.iloc[idx][\"answer\"]) + ['<end>']\n",
    "                q_tokens = ['<start>'] + tokenizer(df.iloc[idx][\"question\"]) + ['<end>']\n",
    "                # pass_tokens = ['<start>'] + list(map(tokenizer, df.iloc[idx][\"passage\"])) + ['<end>']\n",
    "                # ans_tokens = ['<start>'] + list(map(tokenizer, df.iloc[idx][\"answer\"])) + ['<end>']\n",
    "                # q_tokens = ['<start>'] + list(map(tokenizer, df.iloc[idx][\"question\"])) + ['<end>']\n",
    "\n",
    "                pass_len = MAX_PASSAGE_LEN + 2 # +2 for start and end tokens\n",
    "                ans_len = MAX_ANSWER_LEN + 2\n",
    "                q_len = MAX_QUESTION_LEN + 2\n",
    "\n",
    "                passage = [vocab(word) for word in pass_tokens]\n",
    "                answer = [vocab(word) for word in ans_tokens]\n",
    "                question = [vocab(word) for word in q_tokens]\n",
    "\n",
    "                # padding to same length\n",
    "                pass_idxs = torch.zeros(pass_len)\n",
    "                ans_idxs = torch.zeros(ans_len)\n",
    "                q_idxs = torch.zeros(q_len)\n",
    "\n",
    "                pass_idxs[:len(passage)] = torch.FloatTensor(passage)\n",
    "                ans_idxs[:len(answer)] = torch.FloatTensor(answer)\n",
    "                q_idxs[:len(question)] = torch.FloatTensor(question)\n",
    "\n",
    "                data.append((pass_idxs, ans_idxs, q_idxs))\n",
    "        return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2d03b997",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c050703069b4455da7807680794e088f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/87599 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6f1aeb41cd514bb7a3c5b7b661e0a43d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/24726 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "42cdac0229814c5591ab4516602c8b9b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tokenizer = get_tokenizer(\"basic_english\")\n",
    "train_processed = get_processed_data(train, tokenizer)\n",
    "val_processed = get_processed_data(val, tokenizer)\n",
    "test_processed = get_processed_data(test, tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d771f36c",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('./data/train_processed.pickle', 'wb') as train_file:\n",
    "    pickle.dump(train_processed, train_file)\n",
    "\n",
    "with open('./data/val_processed.pickle', 'wb') as val_file:\n",
    "    pickle.dump(val_processed, val_file)\n",
    "\n",
    "with open('./data/test_processed.pickle', 'wb') as test_file:\n",
    "    pickle.dump(test_processed, test_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a8f3781c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 1.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,  6., 14., 15.,\n",
       "         16., 17., 18., 19., 20.,  9., 21., 22., 23.,  6., 24., 25., 12., 26.,\n",
       "         27., 28., 23.,  6., 14., 15., 29., 30., 31.,  5., 20.,  9., 32., 22.,\n",
       "         23., 33., 34., 35., 36., 34.,  6., 37., 38., 39., 40., 41., 12., 42.,\n",
       "         43.,  6., 14., 15., 20.,  6., 44., 23.,  6., 45., 46., 12., 26., 47.,\n",
       "          6., 44., 20.,  6., 48.,  5.,  9., 49., 50., 23., 51., 29., 52., 12.,\n",
       "         31., 20.,  9., 53., 23.,  6., 48., 54., 55.,  5., 56., 57.,  6., 24.,\n",
       "         25., 58., 59., 43., 60., 61., 62., 27., 63., 12., 54.,  6., 64., 23.,\n",
       "          6., 14., 65., 66., 29., 27.,  9., 67., 68., 69., 70., 71., 72., 73.,\n",
       "         29.,  6., 18., 19., 74.,  5., 20.,  9., 75.,  5., 76., 77., 22., 23.,\n",
       "         25., 12.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]),\n",
       " tensor([ 1., 60., 61., 62.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.]),\n",
       " tensor([ 1., 43., 78., 79.,  6., 24., 25., 80., 81., 27., 63., 27., 55., 56.,\n",
       "         82.,  2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "          0.,  0.,  0.,  0.,  0.,  0.]))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pickle.load('./data/train_processed')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eb5fea5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "147ef1215b7e2b4bf3a64983c233460acb54149cdc3836f93d9a84ff7ba2f913"
  },
  "kernelspec": {
   "display_name": "Python 3 (clean)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}