# neural-question-generator / dataset_factory.py
################################################################################
# CSE 253: Programming Assignment 4
# Code snippet by Ajit Kumar, Savyasachi
# Fall 2020
################################################################################

import csv, os, sys, pickle
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch
import pandas as pd
from utils import load_vocab

import torchtext as text
from torchtext.data.utils import get_tokenizer
import constants

class SQUAD(Dataset):
    """Map-style dataset over a pickled collection of examples.

    The examples are presumably pre-processed SQuAD records (confirm against
    the script that wrote the pickle). Each stored example must be indexable;
    ``__getitem__`` returns its first three fields and ignores any others.
    """

    def __init__(self, file):
        """Load all examples from a pickle file into memory.

        Args:
            file (str or os.PathLike): Path to a pickle file containing an
                indexable, sized collection of examples (e.g. a list).

        Note:
            ``pickle.load`` can execute arbitrary code while unpickling; only
            pass files that you produced yourself or otherwise trust.
        """
        with open(file, 'rb') as pickle_file:
            self.data = pickle.load(pickle_file)
        # Not used anywhere in this class; kept as a public attribute because
        # code outside this class may rely on it -- verify before removing.
        self.tokenizer = get_tokenizer("basic_english")

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the first three fields of example ``idx`` as a tuple."""
        # Index once instead of three times; fields beyond [2] are dropped.
        example = self.data[idx]
        return example[0], example[1], example[2]


# Builds your datasets here based on the configuration.
# You are not required to modify this code but you are allowed to.
def _make_loader(dataset, dataset_config, shuffle):
    """Wrap ``dataset`` in a DataLoader using the shared batching settings."""
    return DataLoader(dataset=dataset,
                      batch_size=dataset_config['batch_size'],
                      shuffle=shuffle,
                      num_workers=dataset_config['num_workers'],
                      pin_memory=True)


def get_datasets(config_data):
    """Build the vocabulary and the train/validation/test data loaders.

    Args:
        config_data (dict): Configuration whose ``'dataset'`` section provides
            ``training_file_path``, ``validation_file_path``,
            ``test_file_path``, ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(vocabulary, train_data_loader, val_data_loader,
        test_data_loader)``. The vocabulary is whatever ``load_vocab()``
        returns. Only the training loader shuffles; all three loaders use
        ``pin_memory=True``.

    Raises:
        KeyError: If a required key is missing from the configuration.
    """
    dataset_config = config_data['dataset']

    # Relative paths are resolved against sys.path[0], which is the directory
    # of the script that launched the interpreter (or '' in interactive use).
    base_dir = sys.path[0]
    train_file_path = os.path.join(base_dir, dataset_config['training_file_path'])
    val_file_path = os.path.join(base_dir, dataset_config['validation_file_path'])
    test_file_path = os.path.join(base_dir, dataset_config['test_file_path'])

    vocabulary = load_vocab()

    train_data_loader = _make_loader(SQUAD(train_file_path), dataset_config, shuffle=True)
    val_data_loader = _make_loader(SQUAD(val_file_path), dataset_config, shuffle=False)
    test_data_loader = _make_loader(SQUAD(test_file_path), dataset_config, shuffle=False)

    return vocabulary, train_data_loader, val_data_loader, test_data_loader