import os
import sys
sys.path.append('../../')
import time
import argparse
from sklearn.metrics import f1_score
import torch
import torch.nn.functional as F
import numpy as np
import random
from utils.pytorchtools import EarlyStopping
from utils.data import load_data
from GNN import myGAT, myGAT2
import dgl
import uuid


def set_random_seed(seed=0):
    # Fix all relevant RNGs for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


def sp_to_spt(mat):
    # Convert a scipy sparse matrix to a torch sparse FloatTensor.
    coo = mat.tocoo()
    values = coo.data
    indices = np.vstack((coo.row, coo.col))

    i = torch.LongTensor(indices)
    v = torch.FloatTensor(values)
    shape = coo.shape

    return torch.sparse.FloatTensor(i, v, torch.Size(shape))


def mat2tensor(mat):
    # Dense arrays become dense tensors; sparse matrices stay sparse.
    if type(mat) is np.ndarray:
        return torch.from_numpy(mat).type(torch.FloatTensor)
    return sp_to_spt(mat)


def run_model_DBLP(args):
    feats_type = args.feats_type
    features_list, adjM, labels, train_val_test_idx, dl = load_data(args.dataset, load_full_test=args.test_full)
    print(f'Show dataset {args.dataset}')
    print(dl.nodes['count'])
    print(dl.links)
    if args.dataset == 'Freebase':
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    features_list = [mat2tensor(features).to(device) for features in features_list]

    # Build per-node-type input features according to --feats-type.
    if feats_type == 0:
        in_dims = [features.shape[1] for features in features_list]
    elif feats_type == 1 or feats_type == 5:
        save = 0 if feats_type == 1 else 2
        in_dims = []
        for i in range(0, len(features_list)):
            if i == save:
                in_dims.append(features_list[i].shape[1])
            else:
                in_dims.append(10)
                features_list[i] = torch.zeros((features_list[i].shape[0], 10)).to(device)
    elif feats_type == 2 or feats_type == 4:
        save = feats_type - 2
        in_dims = [features.shape[0] for features in features_list]
        for i in range(0, len(features_list)):
            if i == save:
                in_dims[i] = features_list[i].shape[1]
                continue
            dim = features_list[i].shape[0]
            indices = np.vstack((np.arange(dim), np.arange(dim)))
            indices = torch.LongTensor(indices)
            values = torch.FloatTensor(np.ones(dim))
            features_list[i] = torch.sparse.FloatTensor(indices, values, torch.Size([dim, dim])).to(device)
    elif feats_type == 3:
        # One-hot (identity) features for every node type.
        in_dims = [features.shape[0] for features in features_list]
        for i in range(len(features_list)):
            dim = features_list[i].shape[0]
            indices = np.vstack((np.arange(dim), np.arange(dim)))
            indices = torch.LongTensor(indices)
            values = torch.FloatTensor(np.ones(dim))
            features_list[i] = torch.sparse.FloatTensor(indices, values, torch.Size([dim, dim])).to(device)

    train_idx = train_val_test_idx['train_idx']
    train_idx = np.sort(train_idx)
    val_idx = train_val_test_idx['val_idx']
    val_idx = np.sort(val_idx)
    test_idx = train_val_test_idx['test_idx']
    test_idx = np.sort(test_idx)
    if args.test_full:
        if args.dataset == 'DBLP':
            labels[test_idx] = torch.load('../../data/DBLP_test_labels.pt', map_location='cpu')
        elif args.dataset == 'ACM':
            labels[test_idx] = torch.load('../../data/ACM_test_labels.pt', map_location='cpu')
    labels = torch.LongTensor(labels).to(device)

    # Symmetrize the adjacency matrix and add self-loops.
    g = dgl.DGLGraph(adjM + (adjM.T))
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)
    g = g.to(device)

    if not os.path.exists(f'pre_data_{args.dataset}.pt'):
        # Assign a relation id to every (u, v) pair: original links keep their
        # type id, self-loops and reverse edges get fresh ids.
        edge2type = {}
        for k in dl.links['data']:
            for u, v in zip(*dl.links['data'][k].nonzero()):
                edge2type[(u, v)] = k
        for i in range(dl.nodes['total']):
            if (i, i) not in edge2type:
                edge2type[(i, i)] = len(dl.links['count'])
        for k in dl.links['data']:
            for u, v in zip(*dl.links['data'][k].nonzero()):
                if (v, u) not in edge2type:
                    edge2type[(v, u)] = k + 1 + len(dl.links['count'])
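
        # Note on the relation-id layout built above (derived from this code, not
        # from any external docs): with R = len(dl.links['count']) original link
        # types, ids 0..R-1 keep the given direction, id R marks the added
        # self-loops, and ids R+1..2R mark reverse edges, i.e. 2R+1 relation types
        # in total, matching the len(dl.links['count'])*2+1 passed to myGAT2 below.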
        e_feat = []
        for u, v in zip(*g.edges()):
            u = u.cpu().item()
            v = v.cpu().item()
            e_feat.append(edge2type[(u, v)])
        e_feat = torch.tensor(e_feat, dtype=torch.long)
        torch.save(e_feat, f'pre_data_{args.dataset}.pt')
    else:
        # Reuse the cached edge-type tensor from an earlier run.
        e_feat = torch.load(f'pre_data_{args.dataset}.pt')
    e_feat = e_feat.to(device)
    num_relations = e_feat.max().item() + 1

    for _ in range(args.repeat):
        num_classes = dl.labels_train['num_classes']
        heads = [args.num_heads] * args.num_layers + [1]
        net = myGAT2(g, args.edge_feats, len(dl.links['count']) * 2 + 1, in_dims, args.hidden_dim,
                     num_classes, args.num_layers, heads, F.elu, args.dropout, args.dropout,
                     args.slope, True, 0.05, num_relations=num_relations)
        net.to(device)
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        net.train()
        save_path = 'checkpoint'
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        save_path = f'{save_path}/checkpoint_{args.dataset}_{args.num_layers}_{uuid.uuid4().hex}.pkl'
        print('Model will save to', save_path)
        early_stopping = EarlyStopping(patience=args.patience, verbose=True, save_path=save_path)
        train_times = []
        for epoch in range(args.epoch):
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # make sure pending kernels do not skew the timing
            t_start = time.time()

            # Training forward/backward pass.
            net.train()
            if args.average_attention_values:
                logits = net(features_list, e_feat, average_na_layers=[0, 1, 2])
            elif args.average_semantic_values:
                logits = net(features_list, e_feat, average_sa_layers=[0, 1, 2])
            else:
                logits = net(features_list, e_feat)
            logp = F.log_softmax(logits, 1)
            train_loss = F.nll_loss(logp[train_idx], labels[train_idx])
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            t_end = time.time()
            if epoch > 5:
                train_times.append(t_end - t_start)
            train_acc = [
                f1_score(labels[train_idx].cpu().squeeze(), logits[train_idx].cpu().argmax(dim=-1), average='micro'),
                f1_score(labels[train_idx].cpu().squeeze(), logits[train_idx].cpu().argmax(dim=-1), average='macro')
            ]
            print('train_acc', train_acc)
            print('Epoch {:05d} | Train_Loss: {:.4f} | Time: {:.4f}'.format(epoch, train_loss.item(), t_end - t_start))

            # Validation pass.
            t_start = time.time()
            net.eval()
            with torch.no_grad():
                if args.average_attention_values:
                    logits = net(features_list, e_feat, average_na_layers=[0, 1, 2])
                elif args.average_semantic_values:
                    logits = net(features_list, e_feat, average_sa_layers=[0, 1, 2])
                else:
                    logits = net(features_list, e_feat)
                logp = F.log_softmax(logits, 1)
                val_loss = F.nll_loss(logp[val_idx], labels[val_idx])
            t_end = time.time()
            val_acc = [
                f1_score(labels[val_idx].cpu().squeeze(), logits[val_idx].cpu().argmax(dim=-1), average='micro'),
                f1_score(labels[val_idx].cpu().squeeze(), logits[val_idx].cpu().argmax(dim=-1), average='macro')
            ]
            print('val_acc', val_acc)
            print('Epoch {:05d} | Val_Loss {:.4f} | Time(s) {:.4f}'.format(
                epoch, val_loss.item(), t_end - t_start))
            early_stopping(val_loss, net)
            if early_stopping.early_stop:
                print('Early stopping!')
                break
        print('average train times', sum(train_times) / len(train_times))

        # Evaluate the best checkpoint.
        net.load_state_dict(torch.load(save_path), strict=True)
        net.eval()
        if not args.average_attention_values and not args.average_semantic_values:
            print('\n\nThe result of original HGB is:')
            with torch.no_grad():
                logits = net(features_list, e_feat, average_na_layers=[])
                logp = F.log_softmax(logits, 1)
                val_loss = F.nll_loss(logp[val_idx], labels[val_idx])
            val_acc = [
                f1_score(labels[val_idx].cpu().squeeze(), logits[val_idx].cpu().argmax(dim=-1), average='micro'),
                f1_score(labels[val_idx].cpu().squeeze(), logits[val_idx].cpu().argmax(dim=-1), average='macro')
            ]
            print('\tval_acc', val_acc)
            test_acc = [
                f1_score(labels[test_idx].cpu().squeeze(), logits[test_idx].cpu().argmax(dim=-1), average='micro'),
                f1_score(labels[test_idx].cpu().squeeze(), logits[test_idx].cpu().argmax(dim=-1), average='macro')
            ]
            print('\ttest_acc', test_acc)

            print('\nResult when averaging attention values over neighbors of the same node type on the well-trained original HGB model:')
            with torch.no_grad():
                logits = net(features_list, e_feat, average_na_layers=[0, 1, 2])
                logp = F.log_softmax(logits, 1)
                val_loss = F.nll_loss(logp[val_idx], labels[val_idx])
            val_acc = [
                f1_score(labels[val_idx].cpu().squeeze(), logits[val_idx].cpu().argmax(dim=-1), average='micro'),
                f1_score(labels[val_idx].cpu().squeeze(), logits[val_idx].cpu().argmax(dim=-1), average='macro')
            ]
            print('val_acc', val_acc)
            test_acc = [
                f1_score(labels[test_idx].cpu().squeeze(), logits[test_idx].cpu().argmax(dim=-1), average='micro'),
                f1_score(labels[test_idx].cpu().squeeze(), logits[test_idx].cpu().argmax(dim=-1), average='macro')
            ]
            print('test_acc', test_acc)

        if args.average_attention_values or args.average_semantic_values:
            if args.average_attention_values:
                print('\nThe result of HGB* is:')
            else:
                print('\nThe result of HGB† is:')
            with torch.no_grad():
                if args.average_attention_values:
                    logits = net(features_list, e_feat, average_na_layers=[0, 1, 2])
                elif args.average_semantic_values:
                    logits = net(features_list, e_feat, average_sa_layers=[0, 1, 2])
                else:
                    assert 0
                logp = F.log_softmax(logits, 1)
                val_loss = F.nll_loss(logp[val_idx], labels[val_idx])
            val_acc = [
                f1_score(labels[val_idx].cpu().squeeze(), logits[val_idx].cpu().argmax(dim=-1), average='micro'),
                f1_score(labels[val_idx].cpu().squeeze(), logits[val_idx].cpu().argmax(dim=-1), average='macro')
            ]
            print('val_acc', val_acc)
            test_acc = [
                f1_score(labels[test_idx].cpu().squeeze(), logits[test_idx].cpu().argmax(dim=-1), average='micro'),
                f1_score(labels[test_idx].cpu().squeeze(), logits[test_idx].cpu().argmax(dim=-1), average='macro')
            ]
            print('test_acc', test_acc)


if __name__ == '__main__':
    ap = argparse.ArgumentParser(description='MRGNN testing for the DBLP dataset')
    ap.add_argument('--feats-type', type=int, default=3,
                    help='Type of the node features used. ' +
                         '0 - loaded features; ' +
                         '1 - only target node features (zero vec for others); ' +
                         '2 - only target node features (id vec for others); ' +
                         '3 - all id vec. Default is 3; ' +
                         '4 - only term features (id vec for others); ' +
                         '5 - only term features (zero vec for others).')
    ap.add_argument('-s', '--seeds', nargs='+', type=int, default=[1], help='Random seeds')
    ap.add_argument('--hidden-dim', type=int, default=64, help='Dimension of the node hidden state. Default is 64.')
    ap.add_argument('--num-heads', type=int, default=8, help='Number of the attention heads. Default is 8.')
    ap.add_argument('--epoch', type=int, default=300, help='Number of epochs.')
    ap.add_argument('--patience', type=int, default=30, help='Patience.')
    ap.add_argument('--repeat', type=int, default=1, help='Repeat the training and testing for N times. Default is 1.')
    ap.add_argument('--num-layers', type=int, default=2)
    ap.add_argument('--lr', type=float, default=5e-4)
    ap.add_argument('--dropout', type=float, default=0.5)
    ap.add_argument('--weight-decay', type=float, default=1e-4)
    ap.add_argument('--slope', type=float, default=0.05)
    ap.add_argument('--dataset', type=str)
    ap.add_argument('--edge-feats', type=int, default=64)
    ap.add_argument('--run', type=int, default=1)
    ap.add_argument('--root', type=str, default='../../data')
    ap.add_argument('--average-attention-values', action='store_true', default=False)
    ap.add_argument('--average-semantic-values', action='store_true', default=False)
    ap.add_argument('--test-full', action='store_true', default=False)
    args = ap.parse_args()
    # args = ap.parse_args('--dataset DBLP --seed 1 --average-semantic-values'.split(' '))
    # args = ap.parse_args('--dataset DBLP --average-attention-values --seed 1'.split(' '))
    # run_model_DBLP(args)
    print(args)
    assert args.average_attention_values + args.average_semantic_values < 2, \
        'cannot average neighbor attention and semantic attention simultaneously'
    for seed in args.seeds:
        if seed > 0:
            set_random_seed(seed)
        run_model_DBLP(args)
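
# Example invocations (a sketch based only on the flags defined above; the script
# file name "run_new.py" is an assumption, substitute the actual file name):
#   python run_new.py --dataset DBLP --seeds 1 2 3
#   python run_new.py --dataset DBLP --seeds 1 --average-attention-values
#   python run_new.py --dataset ACM --seeds 1 --average-semantic-values --test-full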