# DeepGPO / DeepGPO_multiple / multiple_code / utils.py
import fastNLP
import wandb
# import torch
# Structural hyperparameters that do not change during rt, by and BY prediction
# are kept here; the parts that need tuning are written into the model
# (these could also be called "args").
maxlength = 60
acid_size = 22  # number of amino-acid types (including unk and pad)
finetune_epoch = 10
N_epochs = 80
vocab_save = False
# Loaded at import time from the file "AA_vocab".
vocab = fastNLP.Vocabulary().load("AA_vocab")
BATCH_SIZE = 256
# device="cuda:0" if torch.cuda.is_available () else "cpu"
# device="cpu"
num_col = 36  # number of by ions
NH3_loss = True  # keep the NH3 neutral loss for BY
seed = 101
dropout = 0.2
embed_size = 512
nhead = 8
# num_layers=6

GNN_edge_ablation = "GIN"  # GIN
GNN_edge_hidden_dim = 128  # 64, 128
GNN_edge_decoder_type = "mlp"  # hadamardlinear, hadamardmlp, mlp
GNN_edge_num_layers = 7

GNN_global_ablation = "GCN"  # GCN
GNN_global_hidden_dim = 32
GNN_global_num_layers = 7  # 7
# 16,7,GCN  16,4,GIN  16,4,GCN
# NOTE(review): presumably other (hidden_dim, num_layers, ablation) settings
# that were tried -- confirm with the author.
# ms2_method="cos_sqrt"  # similarity calc: cos, pcc, cos_sqrt
loc = True
rt_method = "R2"  # R2, cos, delta

class WandbCallback(fastNLP.Callback):
    r"""
    fastNLP callback that reports training progress to Weights & Biases.

    ``on_train_begin`` starts a wandb run with the stored settings and
    ``on_train_end`` finishes it. In between, ``on_backward_begin`` logs the
    loss it receives and ``on_valid_end`` logs the evaluation result it
    receives.
    """

    def __init__(self, project, name, config: dict):
        r"""
        Remember the settings that are later handed to ``wandb.init``.

        :param project: value passed to ``wandb.init`` as ``project``
        :param str name: value passed to ``wandb.init`` as ``name``
        :param dict config: model hyperparameters, passed to ``wandb.init``
            as ``config``
        """
        super().__init__()
        self.project, self.name, self.config = project, name, config

    def on_train_begin(self):
        r"""Start the wandb run using the project, name and config stored on this callback."""
        # ``entity`` (user or team name) is deliberately not set here.
        run_settings = {
            "project": self.project,
            "name": self.name,
            "config": self.config,
        }
        wandb.init(**run_settings)

    def on_train_end(self):
        r"""Finish the wandb run."""
        wandb.finish()

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        r"""
        Log the evaluation result to wandb.

        Only ``eval_result`` is used; the remaining parameters are accepted
        but ignored.
        """
        wandb.log(eval_result)

    def on_backward_begin(self, loss):
        r"""Log the given loss to wandb under the key ``"loss"``."""
        loss_record = {"loss": loss}
        wandb.log(loss_record)