# -*- coding: utf-8 -*-
"""
Created on Mon Nov  9 15:31:11 2020

@author: DrLC
"""

from torch.utils.data import Dataset

import pickle
import json
import torch
import random


class ClassifierDataset(Dataset):

    def __init__(self, file_path, file_type='train', n_dev=1000, max_len=512):

        file_type = file_type.lower()
        assert file_type in ['train', 'test', 'dev']

        with open(file_path, "rb") as f:
            d = pickle.load(f)
        inputs = d['norm']
        labels = d['label']
        assert len(inputs) == len(labels)
        # The dev split is carved off the front of the training pickle,
        # so n_dev must leave at least one training sample.
        assert (file_type in ['train', 'dev'] and n_dev < len(inputs)) \
            or file_type == "test"

        if file_type == 'dev':
            self.inputs, self.labels = inputs[:n_dev], labels[:n_dev]
        elif file_type == 'train':
            self.inputs, self.labels = inputs[n_dev:], labels[n_dev:]
        else:  # 'test'
            self.inputs, self.labels = inputs, labels

        # Truncate every sample to max_len tokens and strip whitespace
        # around each token.
        for i in range(len(self.inputs)):
            length = min(len(self.inputs[i]), max_len)
            self.inputs[i] = [self.inputs[i][j].strip() for j in range(length)]

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]

    def __len__(self):
        return len(self.inputs)


def classifier_collate_fn(batch):
    # Join each token list back into one whitespace-separated string; a
    # downstream tokenizer re-splits it.  Labels are stored 1-indexed in
    # the pickle, so shift them to 0-indexed for the loss function.
    data = [" ".join(item[0]) for item in batch]
    target = torch.tensor([item[1] - 1 for item in batch])
    return data, target


class CloneDataset(Dataset):

    def __init__(self, src_path, file_path, file_type='train',
                 keep_ratio=1, max_len=512):

        file_type = file_type.lower()
        assert file_type in ['train', 'test', 'valid']

        # Source file: one JSON record per line, mapping an 'idx' key to
        # the tokenized function body under 'func'.
        src = {}
        with open(src_path, "r") as f:
            for line in f:
                r = json.loads(line)
                src[r['idx']] = r['func']
        self.src = src

        # Pair file: each line is "idx1 idx2 label".
        with open(file_path, "r") as f:
            d = [i.strip().split() for i in f.readlines()]
        if keep_ratio >= 1:
            pass
        elif keep_ratio >= 0:
            # Randomly keep only a keep_ratio fraction of the pairs.
            rand_idx = random.sample(list(range(len(d))),
                                     int(len(d) * keep_ratio))
            d = [d[i] for i in rand_idx]
        else:
            raise ValueError("keep_ratio must be non-negative")

        idxs1 = [i[0] for i in d]
        idxs2 = [i[1] for i in d]
        labels = [int(i[2]) for i in d]
        assert len(idxs1) == len(labels) and len(idxs1) == len(idxs2)
        self.idxs1, self.idxs2, self.labels = idxs1, idxs2, labels

        # Truncate every function to max_len tokens and strip whitespace
        # around each token.
        for k in self.src.keys():
            length = min(len(self.src[k]), max_len)
            self.src[k] = [self.src[k][j].strip() for j in range(length)]

    def __getitem__(self, idx):
        return (self.src[self.idxs1[idx]],
                self.src[self.idxs2[idx]],
                self.labels[idx])

    def __len__(self):
        return len(self.labels)


def clone_collate_fn(batch):
    # Flatten each pair into two consecutive strings, so a batch of
    # batch_size pairs yields 2 * batch_size inputs and batch_size labels.
    data = []
    for item in batch:
        data.append(" ".join(item[0]))
        data.append(" ".join(item[1]))
    target = torch.tensor([item[2] for item in batch])
    return data, target


if __name__ == "__main__":

    from torch.utils.data import DataLoader, RandomSampler
    import tqdm

    def check_clone_loader(dataloader):
        # Sanity checks: every batch carries two inputs per label, no
        # input exceeds 512 tokens, and labels are binary.
        print(len(dataloader))
        for inputs, labels in tqdm.tqdm(dataloader):
            assert len(inputs) == 2 * len(labels)
            for s in inputs:
                assert len(s.split()) <= 512, \
                    "\n" + str(len(s.split())) + "\n" + str(s)
            for l in labels:
                assert l.item() in [0, 1]

    '''
    dataset = ClassifierDataset('../data/train.pkl', 'train', 1000)
    sampler = RandomSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=2,
                            drop_last=True, collate_fn=classifier_collate_fn)
    for inputs, labels in dataloader:
        for s in inputs:
            assert len(s.split()) <= 512, \
                "\n" + str(len(s.split())) + "\n" + str(s)
    '''

    dataset = CloneDataset('../data/bcb_data.jsonl', '../data/bcb_train.txt',
                           'train')
    sampler = RandomSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=2,
                            drop_last=True, collate_fn=clone_collate_fn)
    check_clone_loader(dataloader)
    dataset = CloneDataset('../data/bcb_data.jsonl', '../data/bcb_test.txt',
                           'test')
    sampler = RandomSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=2,
                            drop_last=True, collate_fn=clone_collate_fn)
    check_clone_loader(dataloader)

    dataset = CloneDataset('../data/bcb_data.jsonl', '../data/bcb_valid.txt',
                           'valid', 0.1)
    sampler = RandomSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=2,
                            drop_last=True, collate_fn=clone_collate_fn)
    check_clone_loader(dataloader)
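
    # A minimal sketch (left commented out, and assuming the HuggingFace
    # `transformers` package plus the `microsoft/codebert-base` checkpoint,
    # neither of which this file itself confirms) of how the collated
    # batches are typically consumed downstream: both collate functions
    # return raw strings precisely so a subword tokenizer can batch-encode
    # them alongside the ready-made label tensor.
    '''
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
    for inputs, labels in dataloader:
        # Batch-encode the whitespace-joined token strings; `labels` is
        # already a LongTensor produced by the collate function.
        encoded = tokenizer(inputs, padding=True, truncation=True,
                            max_length=512, return_tensors="pt")
        break
    '''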