# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import absolute_import, division, print_function

import argparse
import glob
import logging
import os
import pickle
import random
import re
import gc
import shutil
import json

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler

try:
    from torch.utils.tensorboard import SummaryWriter
except:
    from tensorboardX import SummaryWriter

from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
                          BertConfig, BertForMaskedLM, BertTokenizer,
                          GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
                          OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
                          RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
                          DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)


def mask_tokens(inputs, tokenizer, mlm_probability):
    """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
    labels = inputs.clone()
    # We sample a few tokens in each sequence for masked-LM training (with probability mlm_probability defaults to 0.15 in Bert/RoBERTa)
    probability_matrix = torch.full(labels.shape, mlm_probability).to(inputs.device)
    special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool).to(inputs.device), value=0.0)
    if tokenizer._pad_token is not None:
        padding_mask = labels.eq(tokenizer.pad_token_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool().to(inputs.device) & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool().to(inputs.device) & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long).to(inputs.device)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels
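
# --- Example (not part of the original file): how mask_tokens is typically called. ---
# This is a hedged sketch only; the real training loop lives in the accompanying run
# script. It assumes a RoBERTa-style tokenizer, which defines the mask and pad tokens
# that mask_tokens relies on.
def _example_mask_tokens():
    """Illustrative only: build a tiny padded batch and apply dynamic masking."""
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    encoded = tokenizer(["def add ( a , b ) : return a + b"],
                        padding="max_length", truncation=True, max_length=16)
    batch = torch.tensor(encoded["input_ids"])
    inputs, labels = mask_tokens(batch, tokenizer, mlm_probability=0.15)
    # `labels` is -100 everywhere except the positions selected for prediction.
    return inputs, labels
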

class TextDataset(Dataset):
    def __init__(self, tokenizer, args, logger, file_type='train', block_size=512):
        if args.local_rank == -1:
            local_rank = 0
            world_size = 1
        else:
            local_rank = args.local_rank
            world_size = torch.distributed.get_world_size()

        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        cached_file = os.path.join(args.output_dir, file_type+"_langs_%s"%(args.langs)+"_blocksize_%d"%(block_size)+"_wordsize_%d"%(world_size)+"_rank_%d"%(local_rank))
        if os.path.exists(cached_file) and not args.overwrite_cache:
            if file_type == 'train':
                logger.warning("Loading features from cached file %s", cached_file)
            with open(cached_file, 'rb') as handle:
                self.inputs = pickle.load(handle)

        else:
            self.inputs = []

            datafile = os.path.join(args.data_dir, "norm.pkl")
            if file_type == 'train':
                logger.warning("Creating features from dataset file at %s", datafile)
            datas = pickle.load(open(datafile, "rb"))["norm"]
            # The last 1000 examples are held out for validation.
            if file_type == "train":
                datas = datas[:-1000]
            else:
                datas = datas[-1000:]

            length = len(datas)
            for idx, data in enumerate(datas):
                # Shard the data across distributed workers: each rank keeps every
                # world_size-th example.
                if idx % world_size == local_rank:
                    code = " ".join(data)
                    code_tokens = tokenizer.tokenize(code)[:block_size-2]
                    code_tokens = [tokenizer.cls_token] + code_tokens + [tokenizer.sep_token]
                    code_ids = tokenizer.convert_tokens_to_ids(code_tokens)
                    padding_length = block_size - len(code_ids)
                    code_ids += [tokenizer.pad_token_id] * padding_length
                    self.inputs.append(code_ids)

                if idx % (length//10) == 0:
                    percent = idx / (length//10) * 10
                    logger.warning("Rank %d, load %d"%(local_rank, percent))

            if file_type == 'train':
                logger.warning("Rank %d Training %d samples"%(local_rank, len(self.inputs)))
                logger.warning("Saving features into cached file %s", cached_file)
            with open(cached_file, 'wb') as handle:
                pickle.dump(self.inputs, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, item):
        return torch.tensor(self.inputs[item])


class ClassifierDataset(Dataset):
    def __init__(self, tokenizer, args, logger, file_type='train', block_size=512):
        if args.local_rank == -1:
            local_rank = 0
            world_size = 1
        else:
            local_rank = args.local_rank
            world_size = torch.distributed.get_world_size()

        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        cached_file = os.path.join(args.output_dir, file_type+"_langs_%s"%(args.langs)+"_blocksize_%d"%(block_size)+"_wordsize_%d"%(world_size)+"_rank_%d"%(local_rank))
        if os.path.exists(cached_file) and not args.overwrite_cache:
            if file_type == 'train':
                logger.warning("Loading features from cached file %s", cached_file)
            with open(cached_file, 'rb') as handle:
                datas = pickle.load(handle)
                self.inputs, self.labels = datas["inputs"], datas["labels"]

        else:
            self.inputs = []
            self.labels = []

            if file_type != "test":
                datafile = os.path.join(args.data_dir, "train.pkl")
            else:
                datafile = os.path.join(args.data_dir, "test.pkl")
            if file_type == 'train':
                logger.warning("Creating features from dataset file at %s", datafile)
            datas = pickle.load(open(datafile, "rb"))
            labels = datas["label"]
            inputs = datas["norm"]
            # Train/dev split: the last 1000 examples of train.pkl serve as the dev set.
            if file_type == "train":
                inputs, labels = inputs[:-1000], labels[:-1000]
            elif file_type == "dev":
                inputs, labels = inputs[-1000:], labels[-1000:]

            length = len(inputs)
            for idx, (data, label) in enumerate(zip(inputs, labels)):
                if idx % world_size == local_rank:
                    code = " ".join(data)
                    code_tokens = tokenizer.tokenize(code)[:block_size-2]
                    code_tokens = [tokenizer.cls_token] + code_tokens + [tokenizer.sep_token]
                    code_ids = tokenizer.convert_tokens_to_ids(code_tokens)
                    padding_length = block_size - len(code_ids)
                    code_ids += [tokenizer.pad_token_id] * padding_length
                    self.inputs.append(code_ids)
                    # Labels in the raw data are 1-based; shift to 0-based for the classifier.
                    self.labels.append(label-1)

                if idx % (length//10) == 0:
                    percent = idx / (length//10) * 10
                    logger.warning("Rank %d, load %d"%(local_rank, percent))

            if file_type == 'train':
                logger.warning("Rank %d Training %d samples"%(local_rank, len(self.inputs)))
                logger.warning("Saving features into cached file %s", cached_file)
            with open(cached_file, 'wb') as handle:
                pickle.dump({"inputs": self.inputs, "labels": self.labels}, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, item):
        return torch.tensor(self.inputs[item]), torch.tensor(self.labels[item])
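
# --- Example (not part of the original file): wrapping the datasets in a DataLoader. ---
# A hedged sketch only: both TextDataset and ClassifierDataset already shard examples by
# rank inside __init__ (idx % world_size == local_rank), so each process can iterate its
# own shard with a plain RandomSampler. `args`, `logger`, and the tokenizer come from the
# surrounding training script, which is not part of this file.
def _example_classifier_loader(tokenizer, args, logger, batch_size=8):
    """Illustrative only: iterate over one rank's shard of the classification data."""
    dataset = ClassifierDataset(tokenizer, args, logger, file_type='train', block_size=512)
    loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=batch_size)
    for input_ids, labels in loader:
        # input_ids: (batch_size, block_size); labels: (batch_size,)
        yield input_ids, labels
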

class EvalDataset(Dataset):
    def __init__(self, tokenizer, args, logger, file_type='train', block_size=1024):
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        cached_file = os.path.join(args.output_dir, file_type+"_blocksize_%d"%(block_size))
        if os.path.exists(cached_file) and not args.overwrite_cache:
            with open(cached_file, 'rb') as handle:
                self.inputs = pickle.load(handle)

        else:
            self.inputs = []

            datafile = os.path.join(args.data_dir, f"{file_type}.txt")
            with open(datafile) as f:
                data = f.readlines()

            length = len(data)
            logger.info("Data size: %d"%(length))
            input_ids = []
            for idx, x in enumerate(data):
                x = x.strip()
                # Wrap each example with <s> ... </s> boundary tokens if it is not already wrapped.
                if x.startswith("<s>") and x.endswith("</s>"):
                    pass
                else:
                    x = "<s> " + x + " </s>"
                try:
                    input_ids.extend(tokenizer.encode(x))
                except Exception:
                    pass
                if idx % (length//10) == 0:
                    percent = idx / (length//10) * 10
                    logger.warning("load %d"%(percent))
            del data
            gc.collect()

            logger.info(f"tokens: {len(input_ids)}")
            self.split(input_ids, tokenizer, logger, block_size=block_size)
            del input_ids
            gc.collect()

            with open(cached_file, 'wb') as handle:
                pickle.dump(self.inputs, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def split(self, input_ids, tokenizer, logger, block_size=1024):
        sample = []
        i = 0
        while i < len(input_ids):
            sample = input_ids[i: i+block_size]
            if len(sample) == block_size:
                # Walk backwards from the end of the block until a token boundary is found
                # (a token starting with the GPT-2 space marker '\u0120', or a special token),
                # so that a block never ends in the middle of a word.
                for j in range(block_size):
                    if tokenizer.convert_ids_to_tokens(sample[block_size-1-j])[0] == '\u0120':
                        break
                    if sample[block_size-1-j] in [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.sep_token_id]:
                        if sample[block_size-1-j] != tokenizer.bos_token_id:
                            j -= 1
                        break
                if j == block_size-1:
                    print(tokenizer.decode(sample))
                    exit()
                sample = sample[: block_size-1-j]
            # print(len(sample))
            i += len(sample)
            pad_len = block_size - len(sample)
            sample += [tokenizer.pad_token_id] * pad_len
            self.inputs.append(sample)

            if len(self.inputs) % 10000 == 0:
                logger.info(f"{len(self.inputs)} samples")

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, item):
        return torch.tensor(self.inputs[item])


class lineDataset(Dataset):
    def __init__(self, tokenizer, args, logger, file_type='test', block_size=924):
        datafile = os.path.join(args.data_dir, f"{file_type}.json")
        with open(datafile) as f:
            datas = f.readlines()

        length = len(datas)
        logger.info("Data size: %d"%(length))
        self.inputs = []
        self.gts = []
        for data in datas:
            data = json.loads(data.strip())
            # Keep only the last block_size tokens of the context; the ground-truth
            # completion is stored as a string.
            self.inputs.append(tokenizer.encode(data["input"])[-block_size:])
            self.gts.append(data["gt"])

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, item):
        return torch.tensor(self.inputs[item]), self.gts[item]
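
# --- Example (not part of the original file): consuming lineDataset for evaluation. ---
# A hedged sketch under the assumption that a causal LM (e.g. GPT2LMHeadModel) generates
# a completion for each context and the result is compared against the stored ground
# truth. The real evaluation loop, metrics, and decoding settings live in the
# accompanying run script, not here.
def _example_line_completion(model, tokenizer, args, logger, device="cpu"):
    """Illustrative only: greedy line completion over the test set, one example at a time."""
    dataset = lineDataset(tokenizer, args, logger, file_type='test', block_size=924)
    model.to(device)
    model.eval()
    predictions, ground_truths = [], []
    with torch.no_grad():
        for idx in range(len(dataset)):
            input_ids, gt = dataset[idx]
            input_ids = input_ids.unsqueeze(0).to(device)
            outputs = model.generate(input_ids, max_length=input_ids.size(1) + 100,
                                     pad_token_id=tokenizer.pad_token_id)
            # Strip the context and keep only the newly generated tokens.
            completion = tokenizer.decode(outputs[0][input_ids.size(1):])
            predictions.append(completion)
            ground_truths.append(gt)
    return predictions, ground_truths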