import fastNLP
import wandb
# import torch

# Structural hyperparameters for rt/by/BY prediction that do not change
# during the prediction process. Parts that need optimizing belong in the
# model itself (could also be called "args").
maxlength = 60
acid_size = 22  # number of amino acid types (including unk and pad)
finetune_epoch = 10
N_epochs = 80
vocab_save = False
vocab = fastNLP.Vocabulary().load("AA_vocab")  # NOTE: loads at import time
BATCH_SIZE = 256
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# device = "cpu"
num_col = 36  # number of by ions
NH3_loss = True  # keep the NH3 neutral loss for BY ions
seed = 101
dropout = 0.2
embed_size = 512
nhead = 8
# num_layers = 6
GNN_edge_ablation = "GIN"  # GIN
GNN_edge_hidden_dim = 128  # 64, 128
GNN_edge_decoder_type = "mlp"  # hadamardlinear, hadamardmlp, mlp
GNN_edge_num_layers = 7
GNN_global_ablation = "GCN"  # GCN
GNN_global_hidden_dim = 32
GNN_global_num_layers = 7  # 7
# previously tried: 16,7,GCN / 16,4,GIN / 16,4,GCN
# ms2_method = "cos_sqrt"  # similarity calc: cos, pcc, cos_sqrt
loc = True
rt_method = "R2"  # R2, cos, delta


class WandbCallback(fastNLP.Callback):
    r"""fastNLP callback that mirrors training progress to Weights & Biases.

    Opens a wandb run when training begins, logs the training loss before
    every backward pass and the evaluation metrics after each validation,
    and finishes the run when training ends.
    """

    def __init__(self, project, name, config: dict):
        r"""
        :param str project: wandb project to log this run under
        :param str name: display name of this run
        :param dict config: model hyperparameters recorded as run metadata
        """
        super().__init__()
        self.project = project
        self.name = name
        self.config = config

    def on_train_begin(self):
        # Start the wandb run; hyperparameters are attached via `config`.
        wandb.init(
            # Set entity to specify your username or team name
            # ex: entity="carey",
            # Set the project where this run will be logged
            project=self.project,
            name=self.name,
            # Track hyperparameters and run metadata
            config=self.config)

    def on_train_end(self):
        wandb.finish()

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        # `eval_result` is the metric dict produced by fastNLP's tester.
        wandb.log(eval_result)

    def on_backward_begin(self, loss):
        # Called once per training step, before backward; logs the batch loss.
        wandb.log({"loss": loss})