# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import absolute_import, division, print_function
import argparse
import glob
import logging
import os
import pickle
import random
import re
import gc
import shutil
import json
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
from tensorboardX import SummaryWriter
from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
BertConfig, BertForMaskedLM, BertTokenizer,
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
def mask_tokens(inputs, tokenizer, mlm_probability):
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
labels = inputs.clone()
    # We sample a few tokens in each sequence for masked-LM training
    # (with probability mlm_probability, which defaults to 0.15 in BERT/RoBERTa).
probability_matrix = torch.full(labels.shape, mlm_probability).to(inputs.device)
special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in
labels.tolist()]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool).to(inputs.device), value=0.0)
    if tokenizer.pad_token is not None:
padding_mask = labels.eq(tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool().to(inputs.device) & masked_indices
inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
# 10% of the time, we replace masked input tokens with random word
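    # (0.5 of the 20% of masked positions left after [MASK] replacement, i.e. 10% overall)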
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool().to(inputs.device) & masked_indices & ~indices_replaced
random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long).to(inputs.device)
inputs[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels
class TextDataset(Dataset):
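    """Token-level language-modeling dataset.

    Each example is one normalized code sequence from norm.pkl, tokenized, wrapped with
    CLS/SEP and padded to block_size; the last 1000 sequences are held out for evaluation."""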
def __init__(self, tokenizer, args, logger, file_type='train', block_size=512):
        if args.local_rank == -1:
            local_rank = 0
            world_size = 1
        else:
            local_rank = args.local_rank
            world_size = torch.distributed.get_world_size()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
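        # cache features per (langs, block size, world size, rank) so each process only stores its own shard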
cached_file = os.path.join(args.output_dir, file_type+"_langs_%s"%(args.langs)+"_blocksize_%d"%(block_size)+"_wordsize_%d"%(world_size)+"_rank_%d"%(local_rank))
if os.path.exists(cached_file) and not args.overwrite_cache:
if file_type == 'train':
logger.warning("Loading features from cached file %s", cached_file)
with open(cached_file, 'rb') as handle:
self.inputs = pickle.load(handle)
else:
self.inputs = []
datafile = os.path.join(args.data_dir, "norm.pkl")
if file_type == 'train':
logger.warning("Creating features from dataset file at %s", datafile)
datas = pickle.load(open(datafile, "rb"))["norm"]
if file_type == "train":
datas = datas[:-1000]
else:
datas = datas[-1000:]
length = len(datas)
for idx, data in enumerate(datas):
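                # distributed sharding: each rank tokenizes only every world_size-th example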
if idx % world_size == local_rank:
code = " ".join(data)
code_tokens = tokenizer.tokenize(code)[:block_size-2]
code_tokens = [tokenizer.cls_token] + code_tokens + [tokenizer.sep_token]
code_ids = tokenizer.convert_tokens_to_ids(code_tokens)
padding_length = block_size - len(code_ids)
code_ids += [tokenizer.pad_token_id] * padding_length
self.inputs.append(code_ids)
                # log progress roughly every 10% of the data (guard against length < 10)
                if length // 10 > 0 and idx % (length // 10) == 0:
                    percent = idx / (length // 10) * 10
                    logger.warning("Rank %d, load %d" % (local_rank, percent))
if file_type == 'train':
logger.warning("Rank %d Training %d samples"%(local_rank, len(self.inputs)))
logger.warning("Saving features into cached file %s", cached_file)
with open(cached_file, 'wb') as handle:
pickle.dump(self.inputs, handle, protocol=pickle.HIGHEST_PROTOCOL)
def __len__(self):
return len(self.inputs)
def __getitem__(self, item):
return torch.tensor(self.inputs[item])
class ClassifierDataset(Dataset):
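    """Sequence-classification dataset.

    Pairs each tokenized, padded code sequence with its zero-based label; the last 1000
    training examples are used as the dev split."""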
def __init__(self, tokenizer, args, logger, file_type='train', block_size=512):
        if args.local_rank == -1:
            local_rank = 0
            world_size = 1
        else:
            local_rank = args.local_rank
            world_size = torch.distributed.get_world_size()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
cached_file = os.path.join(args.output_dir, file_type+"_langs_%s"%(args.langs)+"_blocksize_%d"%(block_size)+"_wordsize_%d"%(world_size)+"_rank_%d"%(local_rank))
if os.path.exists(cached_file) and not args.overwrite_cache:
if file_type == 'train':
logger.warning("Loading features from cached file %s", cached_file)
with open(cached_file, 'rb') as handle:
datas = pickle.load(handle)
self.inputs, self.labels = datas["inputs"], datas["labels"]
else:
self.inputs = []
self.labels = []
if file_type != "test":
datafile = os.path.join(args.data_dir, "train.pkl")
else:
datafile = os.path.join(args.data_dir, "test.pkl")
if file_type == 'train':
logger.warning("Creating features from dataset file at %s", datafile)
datas = pickle.load(open(datafile, "rb"))
labels = datas["label"]
inputs = datas["norm"]
if file_type == "train":
inputs, labels = inputs[:-1000], labels[:-1000]
elif file_type == "dev":
inputs, labels = inputs[-1000:], labels[-1000:]
length = len(inputs)
for idx, (data, label) in enumerate(zip(inputs, labels)):
if idx % world_size == local_rank:
code = " ".join(data)
code_tokens = tokenizer.tokenize(code)[:block_size-2]
code_tokens = [tokenizer.cls_token] + code_tokens + [tokenizer.sep_token]
code_ids = tokenizer.convert_tokens_to_ids(code_tokens)
padding_length = block_size - len(code_ids)
code_ids += [tokenizer.pad_token_id] * padding_length
self.inputs.append(code_ids)
self.labels.append(label-1)
                if length // 10 > 0 and idx % (length // 10) == 0:
                    percent = idx / (length // 10) * 10
                    logger.warning("Rank %d, load %d" % (local_rank, percent))
if file_type == 'train':
logger.warning("Rank %d Training %d samples"%(local_rank, len(self.inputs)))
logger.warning("Saving features into cached file %s", cached_file)
with open(cached_file, 'wb') as handle:
pickle.dump({"inputs": self.inputs, "labels": self.labels}, handle, protocol=pickle.HIGHEST_PROTOCOL)
def __len__(self):
return len(self.inputs)
def __getitem__(self, item):
return torch.tensor(self.inputs[item]), torch.tensor(self.labels[item])
class EvalDataset(Dataset):
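    """Evaluation dataset for token-level language modeling.

    Concatenates all tokenized lines into a single token stream and splits it into
    block_size chunks at clean token boundaries (see split below)."""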
def __init__(self, tokenizer, args, logger, file_type='train', block_size=1024):
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
cached_file = os.path.join(args.output_dir, file_type+"_blocksize_%d"%(block_size))
if os.path.exists(cached_file) and not args.overwrite_cache:
with open(cached_file, 'rb') as handle:
self.inputs = pickle.load(handle)
else:
self.inputs = []
datafile = os.path.join(args.data_dir, f"{file_type}.txt")
with open(datafile) as f:
data = f.readlines()
length = len(data)
logger.info("Data size: %d"%(length))
input_ids = []
            for idx, x in enumerate(data):
                x = x.strip()
                # wrap each line with <s> ... </s> markers unless it is already wrapped
                if x.startswith("<s>") and x.endswith("</s>"):
                    pass
                else:
                    x = "<s> " + x + " </s>"
                try:
                    input_ids.extend(tokenizer.encode(x))
                except Exception:
                    # skip lines that cannot be tokenized
                    pass
                if length // 10 > 0 and idx % (length // 10) == 0:
                    percent = idx / (length // 10) * 10
                    logger.warning("load %d" % percent)
del data
gc.collect()
logger.info(f"tokens: {len(input_ids)}")
self.split(input_ids, tokenizer, logger, block_size=block_size)
del input_ids
gc.collect()
with open(cached_file, 'wb') as handle:
pickle.dump(self.inputs, handle, protocol=pickle.HIGHEST_PROTOCOL)
def split(self, input_ids, tokenizer, logger, block_size=1024):
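        """Cut the concatenated token stream into block_size windows, backing up from the
        end of each window to the nearest word boundary (a token starting with '\u0120',
        the GPT-2 space marker) or a BOS/EOS/SEP token, then right-pad the window so that
        no word is split across samples."""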
sample = []
i = 0
while i < len(input_ids):
sample = input_ids[i: i+block_size]
if len(sample) == block_size:
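                # walk backwards from the end of the window to find a safe cut point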
for j in range(block_size):
if tokenizer.convert_ids_to_tokens(sample[block_size-1-j])[0] == '\u0120':
break
if sample[block_size-1-j] in [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.sep_token_id]:
if sample[block_size-1-j] != tokenizer.bos_token_id:
j -= 1
break
if j == block_size-1:
print(tokenizer.decode(sample))
exit()
sample = sample[: block_size-1-j]
# print(len(sample))
i += len(sample)
pad_len = block_size-len(sample)
sample += [tokenizer.pad_token_id]*pad_len
self.inputs.append(sample)
if len(self.inputs) % 10000 == 0:
logger.info(f"{len(self.inputs)} samples")
def __len__(self):
return len(self.inputs)
def __getitem__(self, item):
return torch.tensor(self.inputs[item])
class lineDataset(Dataset):
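    """Line-level completion dataset: each JSON line provides an "input" context (kept to
    the last block_size tokens) and a "gt" ground-truth completion string."""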
def __init__(self, tokenizer, args, logger, file_type='test', block_size=924):
datafile = os.path.join(args.data_dir, f"{file_type}.json")
with open(datafile) as f:
datas = f.readlines()
length = len(datas)
logger.info("Data size: %d"%(length))
self.inputs = []
self.gts = []
for data in datas:
data = json.loads(data.strip())
self.inputs.append(tokenizer.encode(data["input"])[-block_size:])
self.gts.append(data["gt"])
def __len__(self):
return len(self.inputs)
def __getitem__(self, item):
return torch.tensor(self.inputs[item]), self.gts[item]
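# Example usage (a minimal sketch, not executed here; assumes an `args` namespace that
# provides data_dir, output_dir, langs, local_rank and overwrite_cache, plus a tokenizer
# whose cls/sep/pad (TextDataset) or bos/eos/pad (EvalDataset) special tokens are set):
#
#   logger = logging.getLogger(__name__)
#   tokenizer = GPT2Tokenizer.from_pretrained(args.pretrain_dir)  # hypothetical checkpoint path
#   train_dataset = TextDataset(tokenizer, args, logger, file_type="train", block_size=512)
#   train_loader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=8)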