# -*- coding: utf-8 -*-
"""
Created on Mon Nov  9 15:31:11 2020

@author: DrLC
"""

from torch.utils.data import Dataset
import pickle
import json
import torch
import random

class ClassifierDataset(Dataset):
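    """Single-input classification dataset backed by a pickle file.

    The pickle is expected (inferred from the loading code) to hold a dict
    with two parallel lists: 'norm' (token sequences) and 'label' (class
    ids).  The first n_dev examples form the 'dev' split, the remainder
    the 'train' split; 'test' keeps the whole file.  Every sequence is
    truncated to max_len tokens.
    """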
    
    def __init__(self, file_path, file_type='train', n_dev=1000, max_len=512):
        
        file_type = file_type.lower()
        assert file_type in ['train', 'test', 'dev']
        with open(file_path, "rb") as f:
            d = pickle.load(f)
            inputs = d['norm']
            labels = d['label']
            assert len(inputs) == len(labels)
            assert (file_type in ['train', 'dev'] and n_dev < len(inputs)) or file_type == "test"
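        # 'dev' takes the first n_dev examples, 'train' takes the rest,
        # and 'test' keeps the file as-is.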
        if file_type == 'dev':
            self.inputs, self.labels = [], []
            for i in range(n_dev):
                self.inputs.append(inputs[i])
                self.labels.append(labels[i])
        elif file_type == 'train':
            self.inputs, self.labels = [], []
            for i in range(n_dev, len(inputs)):
                self.inputs.append(inputs[i])
                self.labels.append(labels[i])
        elif file_type == 'test':
            self.inputs, self.labels = inputs, labels
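        # Truncate every token sequence to max_len and strip stray whitespace.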
        for i in range(len(self.inputs)):
            self.inputs[i] = [tok.strip() for tok in self.inputs[i][:max_len]]
            
    def __getitem__(self, idx):
        
        return self.inputs[idx], self.labels[idx]
    
    def __len__(self):
        
        return len(self.inputs)

def classifier_collate_fn(batch):
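    """Join each token sequence into a whitespace-separated string and shift
    the labels down by one (the stored labels appear to be 1-indexed).
    """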
    
    data = [" ".join(item[0]) for item in batch]
    target = [item[1] - 1 for item in batch]
    target = torch.tensor(target)
    return data, target

class CloneDataset(Dataset):
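    """Clone-detection dataset over pairs of functions.

    src_path is a jsonl file whose rows carry 'idx' and 'func' keys; 'func'
    is treated as a pre-tokenized sequence by the truncation code below.
    file_path lists one pair per line as "idx1 idx2 label".  A keep_ratio
    in [0, 1) randomly subsamples the pairs; values >= 1 keep all of them.
    """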
    
    def __init__(self, src_path, file_path, file_type='train', keep_ratio=1, max_len=512):
        
        file_type = file_type.lower()
        assert file_type in ['train', 'test', 'valid']
        with open(src_path, "r") as f:
            src = {}
            for line in f:
                r = json.loads(line)
                src[r['idx']] = r['func']
            self.src = src
        with open(file_path, "r") as f:
            d = f.readlines()
            d = [i.strip().split() for i in d]
            if 0 <= keep_ratio < 1:
                # Randomly keep a keep_ratio fraction of the pairs.
                rand_idx = random.sample(range(len(d)), int(len(d) * keep_ratio))
                d = [d[i] for i in rand_idx]
            else:
                assert keep_ratio >= 1, "keep_ratio must be non-negative"
            idxs1 = [i[0] for i in d]
            idxs2 = [i[1] for i in d]
            labels = [int(i[2]) for i in d]
            assert len(idxs1) == len(labels) and len(idxs1) == len(idxs2)
        self.idxs1, self.idxs2, self.labels = idxs1, idxs2, labels
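        # Truncate every function to max_len tokens and strip stray whitespace.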
        for k in self.src:
            self.src[k] = [tok.strip() for tok in self.src[k][:max_len]]
            
    def __getitem__(self, idx):
        
        return self.src[self.idxs1[idx]], self.src[self.idxs2[idx]], self.labels[idx]
    
    def __len__(self):
        
        return len(self.labels)

def clone_collate_fn(batch):
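    """Flatten a batch of (code1, code2, label) triples into an interleaved
    list of 2*B strings plus a length-B label tensor, so the two members of
    pair i land at positions 2*i and 2*i + 1.
    """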
    
    data = []
    for item in batch:
        data.append(" ".join(item[0]))
        data.append(" ".join(item[1]))
    target = [item[2] for item in batch]
    target = torch.tensor(target)
    return data, target

if __name__ == "__main__":
    
    from torch.utils.data import DataLoader, RandomSampler
    import tqdm
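    # Smoke tests: iterate each split and check that every example is
    # truncated to 512 tokens and that clone labels are binary.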
    
    '''
    dataset = ClassifierDataset('../data/train.pkl', 'train', 1000)
    sampler = RandomSampler(dataset)
    
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=2,
                            drop_last=True, collate_fn=classifier_collate_fn)
    for inputs, labels in dataloader:
        for s in inputs:
            assert len(s.split()) <= 512, \
                "\n"+str(len(s.split()))+"\n"+str(s)
    '''
    
    dataset = CloneDataset('../data/bcb_data.jsonl', '../data/bcb_train.txt', 'train')
    sampler = RandomSampler(dataset)
    
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=2,
                            drop_last=True, collate_fn=clone_collate_fn)
    print(len(dataloader))
    for inputs, labels in tqdm.tqdm(dataloader):
        assert len(inputs) == 2 * len(labels)
        for s in inputs:
            assert len(s.split()) <= 512, \
                "\n"+str(len(s.split()))+"\n"+str(s)
        for l in labels:
            assert l.item() in [0, 1]
                
    dataset = CloneDataset('../data/bcb_data.jsonl', '../data/bcb_test.txt', 'test')
    sampler = RandomSampler(dataset)
    
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=2,
                            drop_last=True, collate_fn=clone_collate_fn)
    print(len(dataloader))
    for inputs, labels in tqdm.tqdm(dataloader):
        assert len(inputs) == 2 * len(labels)
        for s in inputs:
            assert len(s.split()) <= 512, \
                "\n"+str(len(s.split()))+"\n"+str(s)
        for l in labels:
            assert l.item() in [0, 1]
    
    dataset = CloneDataset('../data/bcb_data.jsonl', '../data/bcb_valid.txt', 'valid', 0.1)
    sampler = RandomSampler(dataset)
    
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=2,
                            drop_last=True, collate_fn=clone_collate_fn)
    print(len(dataloader))
    for inputs, labels in tqdm.tqdm(dataloader):
        assert len(inputs) == 2 * len(labels)
        for s in inputs:
            assert len(s.split()) <= 512, \
                "\n"+str(len(s.split()))+"\n"+str(s)
        for l in labels:
            assert l.item() in [0, 1]