# Test script: load a trained byBY MS2 prediction model, evaluate it on a processed
# glycopeptide dataset with the fastNLP Tester, and write the predicted spectra to disk.
from fastNLP import Tester, Const
from preprocess import *
from transformers import BertConfig
from model_gly import *
from Bertmodel import ModelbyBYms2_bert
from utils import *
import os
import argparse
import json
import ipdb
import numpy as np
import torch
from pathlib import Path

# seed, vocab, num_col, BATCH_SIZE, set_seed, PPeptidePipebyBY and Metric_byBY_outputmsms
# are provided by the starred imports above.
set_seed(seed)


# ----------------------- args ------------------------------#
def parsering():
    parser = argparse.ArgumentParser()
    parser.add_argument('--datafold', type=str,
                        default="/remote-home1/yxwang/test/zzb/DeepGlyco/DeepSweet_v1/data/mouse/Five_tissues/",
                        help='data folder')
    parser.add_argument('--trainpathcsv', type=str,
                        default="/remote-home1/yxwang/test/zzb/DeepGlyco/DeepSweet_v1/data/mouse/PXD005413/PXD005413_MouseHeart_data_1st.csv",
                        help='the CSV file used for testing')
    parser.add_argument('--bestmodelpath', type=str,
                        default="/remote-home1/yxwang/test/zzb/DeepGlyco/DeepSweet_v1/data/mouse/Five_tissues/Mouse_test_PXD005413_data_1st_combine_byBYprocessed/checkpoints/2023-04-13-00-45-52-259709/epoch-75_step-19800_mediancos-0.912608.pt",
                        help='path to the best model checkpoint')
    parser.add_argument('--device', type=int, default=0, help='CUDA device index')
    parser.add_argument('--postprocessing', type=str, default="off", help='on/off')
    parser.add_argument('--savename', type=str, default="test_byBY_PXD005413_model_0.912608",
                        help='the output file name')
    parser.add_argument('--ms2_method', type=str, default="cos_sqrt", help='similarity metric')
    args = parser.parse_args()
    return args


args = parsering()

# Output folder and file prefix for the predictions.
savefold = os.path.join(args.datafold, "test_replace_predict")
if not os.path.exists(savefold):
    os.mkdir(savefold)
savename = os.path.join(savefold, args.savename)

device = torch.device('cuda', args.device) if torch.cuda.is_available() else torch.device('cpu')

# Build the byBY-processed JSON from the raw CSV if it has not been generated yet.
trainpathcsv = args.trainpathcsv
traindatajson = trainpathcsv[:-4] + "_byBYprocessed.json"
traindatajson_path = Path(traindatajson)
if not traindatajson_path.exists():
    print(f"{traindatajson} does not exist. Begin matrixwithdict to produce result...")
    os.system("python matrixwithdict.py "
              "--do_byBY "
              "--DDAfile {} "
              "--outputfile {} "
              "--split {}".format(trainpathcsv, traindatajson, "False"))
# -----------------------------------------------------------#

# Load the processed data as a fastNLP DataBundle.
fpath = traindatajson
databundle = PPeptidePipebyBY(vocab=vocab).process_from_file(paths=fpath)
totaldata = databundle.get_dataset("train")
print("totaldata", totaldata)


def encode_dataset(obj):
    # JSON fallback encoder: arrays/tensors become lists; anything else drops into the debugger.
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    else:
        print(type(obj))
        ipdb.set_trace()
        return str(obj)


def save_dataset_as_json(dataset, file_path):
    # Dump the selected fields of the fastNLP DataSet to a JSON file.
    data_dict = {}
    dataset_field = ['GlySpec', 'peptide', 'charge', 'ions_by', 'ions_BY', 'iden_pep', 'ions_BY_p', '_id']
    for field_name in dataset_field:
        data_dict[field_name] = list(dataset.get_field(field_name))
    encoded_dict = json.loads(json.dumps(data_dict, default=encode_dataset))
    with open(file_path, 'w') as f:
        json.dump(encoded_dict, f, indent=4)


save_dataset_as_json(totaldata, savename + "_byBY_testdata.json")

# ------------------------- model ---------------------------#
config = BertConfig.from_pretrained("bert-base-uncased")
bestmodelpath = args.bestmodelpath
deepms2 = ModelbyBYms2_bert(config)
bestmodel = torch.load(bestmodelpath).state_dict()
deepms2.load_state_dict(bestmodel, strict=False)
model_sign = bestmodelpath.split("/")[-1]

from torchinfo import summary
summary(deepms2)

metrics = Metric_byBY_outputmsms(savename=savename, pred=Const.OUTPUT, target=Const.TARGET,
                                 pred_by="pred_by", pred_BY="pred_BY",
                                 target_by="target_by", target_BY="target_BY",
                                 seq_len='seq_len', num_col=num_col,
                                 sequence='sequence', charge="charge", decoration="decoration",
                                 peptide="peptide", PlausibleStruct="PlausibleStruct",
                                 args=args)

from MSELoss_for_byBY import MSELoss_byBY
loss = MSELoss_byBY(pred_by="pred_by", pred_BY="pred_BY", target_by="target_by", target_BY="target_BY")

# ------------------------- tester --------------------------#
pptester = Tester(model=deepms2, device=device, data=totaldata,
                  loss=loss, metrics=metrics, batch_size=BATCH_SIZE)

from timeit import default_timer as timer

test_time_start = timer()
pptester.test()
test_time_end = timer()


def print_test_time(start: float, end: float, device: torch.device = None):
    total_time = end - start
    print(f"Test time on {device}: {total_time:.3f} seconds")
    return total_time


total_test_time = print_test_time(start=test_time_start, end=test_time_end, device=device)

# --------------------- postprocessing ----------------------#
if args.postprocessing == "on":
    import postprocessing_spectra
    print("savename for postprocessing", savename)
    postprocessing_spectra.postprocessing(savename)

print("end")
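# Example invocation (a sketch; the script filename "test_byBY.py" and the paths below are
# placeholders, not part of the original code; substitute the real filename and your own
# data/checkpoint locations):
#
#   python test_byBY.py \
#       --datafold /path/to/Five_tissues/ \
#       --trainpathcsv /path/to/PXD005413_MouseHeart_data_1st.csv \
#       --bestmodelpath /path/to/checkpoints/epoch-75_step-19800_mediancos-0.912608.pt \
#       --device 0 \
#       --postprocessing on \
#       --savename test_byBY_PXD005413_model_0.912608 \
#       --ms2_method cos_sqrt
#
# Predictions and the "<savename>_byBY_testdata.json" dump are written under
# <datafold>/test_replace_predict/.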