# KGTOSA / GNN-Methods / NodeClassifcation / SeHGNN / hgb / main.py
import os
import gc
import re
import time
import uuid
import argparse
import datetime
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor
from torch_sparse import remove_diag, set_diag

from model import *
from utils import *
from sparse_tools import SparseAdjList


def main(args):
    if args.seed > 0:
        set_random_seed(args.seed)

    g, adjs, init_labels, num_classes, dl, train_nid, val_nid, test_nid, test_nid_full \
        = load_dataset(args)

    if not args.neighbor_attention:
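        # Row-normalize each relation's adjacency so that neighbor aggregation computes a
        # mean over neighbors: every nonzero entry becomes 1 / (number of nonzeros in its row).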
        for k in adjs.keys():
            adjs[k].storage._value = None
            adjs[k].storage._value = torch.ones(adjs[k].nnz()) / adjs[k].sum(dim=-1)[adjs[k].storage.row()]

    # =======
    # rearrange node idx (for feats & labels)
    # =======
    train_node_nums = len(train_nid)
    valid_node_nums = len(val_nid)
    test_node_nums = len(test_nid)
    trainval_point = train_node_nums
    valtest_point = trainval_point + valid_node_nums
    total_num_nodes = train_node_nums + valid_node_nums + test_node_nums
    num_nodes = dl.nodes['count'][0]

    if total_num_nodes < num_nodes:
        flag = np.ones(num_nodes, dtype=bool)
        flag[train_nid] = 0
        flag[val_nid] = 0
        flag[test_nid] = 0
        extra_nid = np.where(flag)[0]
        print(f'Found {len(extra_nid)} extra nids for dataset {args.dataset}')
    else:
        extra_nid = np.array([])

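    # Reorder nodes so that train, val, test and the remaining (extra) nodes are contiguous;
    # init2sort maps sorted position -> original node id, and sort2init is its inverse,
    # used later to map predictions back to the original ordering.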
    init2sort = torch.LongTensor(np.concatenate([train_nid, val_nid, test_nid, extra_nid]))
    sort2init = torch.argsort(init2sort)
    assert torch.all(init_labels[init2sort][sort2init] == init_labels)
    labels = init_labels[init2sort]

    # =======
    # neighbor aggregation
    # =======
    if args.dataset == 'DBLP':
        tgt_type = 'A'
        node_types = ['A', 'P', 'T', 'V']
        extra_metapath = []
    elif args.dataset == 'ACM':
        tgt_type = 'P'
        node_types = ['P', 'A', 'C']
        extra_metapath = []
    elif args.dataset == 'IMDB':
        tgt_type = 'M'
        node_types = ['M', 'A', 'D', 'K']
        extra_metapath = []
    elif args.dataset == 'Freebase':
        tgt_type = '0'
        node_types = [str(i) for i in range(8)]
        extra_metapath = []
    else:
        assert 0
    extra_metapath = [ele for ele in extra_metapath if len(ele) > args.num_hops + 1]

    print(f'Current num hops = {args.num_hops}')

    if args.dataset == 'Freebase':
        prop_device = 'cuda:{}'.format(args.gpu) if not args.cpu else 'cpu'
    else:
        prop_device = 'cpu'
    store_device = 'cpu'

    if args.dataset == 'Freebase':
        if not os.path.exists('./Freebase_adjs'):
            os.makedirs('./Freebase_adjs')
        num_tgt_nodes = dl.nodes['count'][0]

    # compute k-hop feature
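    # Propagated features / adjacencies are keyed by metapath strings over node types
    # (e.g. 'AP', 'APA' on DBLP), each starting from the target node type tgt_type.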
    prop_tic = datetime.datetime.now()
    if args.dataset != 'Freebase':
        if len(extra_metapath):
            max_length = max(args.num_hops + 1, max([len(ele) for ele in extra_metapath]))
        else:
            max_length = args.num_hops + 1

        if args.neighbor_attention:
            meta_adjs = hg_propagate_sparse_pyg(adjs, tgt_type, args.num_hops, max_length, extra_metapath, prop_feats=True, echo=True, prop_device='cpu')
            assert tgt_type not in meta_adjs
            raw_feats = {k: g.nodes[k].data[k].clone() for k in g.ndata.keys()}
            print(f'For tgt {tgt_type}, Involved raw_feat keys {raw_feats.keys()}, feats keys {meta_adjs.keys()}')
        elif args.two_layer:
            assert node_types[0] == tgt_type
            meta_adjs = hg_propagate_sparse_pyg(adjs, node_types, args.num_hops, max_length, extra_metapath, prop_feats=True, echo=True, prop_device='cpu')
            for k in meta_adjs.keys(): assert len(k) > 1, k 
            raw_feats = {k: g.nodes[k].data[k].clone() for k in g.ndata.keys()}
            print(f'For tgt {tgt_type}, Involved raw_feat keys {raw_feats.keys()}, feats keys {meta_adjs.keys()}')
        else:
            g = hg_propagate_feat_dgl(g, tgt_type, args.num_hops, max_length, extra_metapath, echo=True)
            feats = {}
            keys = list(g.nodes[tgt_type].data.keys())
            print(f'For tgt {tgt_type}, feature keys {keys}')
            for k in keys:
                feats[k] = g.nodes[tgt_type].data.pop(k)
    else:
        if len(extra_metapath):
            max_length = max(args.num_hops + 1, max([len(ele) for ele in extra_metapath]))
        else:
            max_length = args.num_hops + 1

        if args.two_layer:
            meta_adjs = hg_propagate_sparse_pyg(adjs, node_types, args.num_hops, max_length, extra_metapath, prop_feats=True, echo=True, prop_device='cpu')
            for k in meta_adjs.keys(): assert len(k) > 1, k
        elif args.num_hops == 1:
            meta_adjs = {k: v.clone() for k, v in adjs.items() if k[0] == tgt_type}
        else:
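            # Propagated Freebase adjacencies are cached on disk via SparseAdjList (one file
            # group per source node type); runs with the same seed/hops reload them instead of
            # recomputing, and fresh saves are read back and verified against the in-memory copies.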
            save_name = f'./Freebase_adjs/feat_seed{args.seed}_hop{args.num_hops}'
            if args.seed > 0 and os.path.exists(f'{save_name}_00_int64.npy'):
                # meta_adjs = torch.load(save_name)
                meta_adjs = {}
                for srcname in tqdm(dl.nodes['count'].keys()):
                    tmp = SparseAdjList(f'{save_name}_0{srcname}', None, None, num_tgt_nodes, dl.nodes['count'][srcname], with_values=True)
                    for k in tmp.keys:
                        assert k not in meta_adjs
                    meta_adjs.update(tmp.load_adjs(expand=True))
                    del tmp
            else:
                meta_adjs = hg_propagate_sparse_pyg(adjs, tgt_type, args.num_hops, max_length, extra_metapath, prop_feats=True, echo=True, prop_device=prop_device)

                meta_adj_list = []
                for srcname in dl.nodes['count'].keys():
                    keys = [k for k in meta_adjs.keys() if k[-1] == str(srcname)]
                    tmp = SparseAdjList(f'{save_name}_0{srcname}', keys, meta_adjs, num_tgt_nodes, dl.nodes['count'][srcname], with_values=True)
                    meta_adj_list.append(tmp)

                for srcname in dl.nodes['count'].keys():
                    tmp = SparseAdjList(f'{save_name}_0{srcname}', None, None, num_tgt_nodes, dl.nodes['count'][srcname], with_values=True)
                    tmp_adjs = tmp.load_adjs(expand=True)
                    print(srcname, tmp.keys)
                    for k in tmp.keys:
                        assert torch.all(meta_adjs[k].storage.rowptr() == tmp_adjs[k].storage.rowptr())
                        assert torch.all(meta_adjs[k].storage.col() == tmp_adjs[k].storage.col())
                        assert torch.all(meta_adjs[k].storage.value() == tmp_adjs[k].storage.value())
                    del tmp_adjs, tmp
                    gc.collect()

        feats = {k: v.clone() for k, v in meta_adjs.items() if len(k) <= args.num_hops + 1 or k in extra_metapath}
        if args.neighbor_attention:
            for k in feats.keys():
                feats[k].storage._value = None

        assert '0' not in feats
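        # Freebase target nodes carry no raw attributes, so a sparse identity matrix serves
        # as their 0-hop (one-hot) feature, which the model then embeds.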
        if not args.neighbor_attention and not args.two_layer:
            feats['0'] = SparseTensor.eye(dl.nodes['count'][0])
        print(f'For tgt {tgt_type}, Involved keys {feats.keys()}')

    if args.dataset in ['DBLP', 'ACM', 'IMDB']:
        if args.neighbor_attention or args.two_layer:
            data_size = {k: g.ndata[k][k].size(-1) for k in g.ndata.keys()}
            raw_feats[tgt_type] = raw_feats[tgt_type][init2sort]

            feats = {}
            for k, v in tqdm(meta_adjs.items()):
                assert len(k) > 1
                if k[0] == tgt_type and k[-1] == tgt_type:
                    feats[k] = v[init2sort, init2sort]
                elif k[0] == tgt_type:
                    feats[k] = v[init2sort]
                else:
                    assert args.two_layer
                    if k[-1] == tgt_type:
                        feats[k] = v[:, init2sort]
                    else:
                        feats[k] = v
        else:
            data_size = {k: v.size(-1) for k, v in feats.items()}
            feats = {k: v[init2sort] for k, v in feats.items()}
    elif args.dataset == 'Freebase':
        data_size = dict(dl.nodes['count'])
        if args.neighbor_attention or args.two_layer:
            raw_feats = {}
            for k, count in data_size.items():
                raw_feats[k] = SparseTensor(row=torch.arange(count), col=torch.arange(count))

        for k, v in tqdm(feats.items()):
            if len(k) == 1:
                assert not args.neighbor_attention and not args.two_layer
                continue

            if k[0] == '0' and k[-1] == '0':
                # feats[k] = v[init2sort[:total_num_nodes], init2sort]
                # feats[k] = v[init2sort, init2sort]
                feats[k], _ = v.sample_adj(init2sort, -1, False) # faster, 50% time acceleration
            elif k[0] == '0':
                # feats[k] = v[init2sort[:total_num_nodes]]
                feats[k] = v[init2sort]
            else:
                assert args.two_layer, k
                if k[-1] == tgt_type:
                    feats[k] = v[:, init2sort]
                else:
                    feats[k] = v
    else:
        assert 0
    prop_toc = datetime.datetime.now()
    print(f'Time used for feat prop {prop_toc - prop_tic}')
    gc.collect()

    # =======
    checkpt_folder = f'./output/{args.dataset}/'
    if not os.path.exists(checkpt_folder):
        os.makedirs(checkpt_folder)
    checkpt_file = checkpt_folder + uuid.uuid4().hex
    print('checkpt_file', checkpt_file)

    if args.amp:
        scalar = torch.cuda.amp.GradScaler()
    else:
        scalar = None

    device = 'cuda:{}'.format(args.gpu) if not args.cpu else 'cpu'
    if args.dataset != 'IMDB':
        labels_cuda = labels.long().to(device)
    else:
        labels = labels.float()
        labels_cuda = labels.to(device)

    for stage in [0]:
        epochs = args.stage

        # =======
        # labels propagate alongside the metapath
        # =======
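        # Only training-set labels are encoded (one-hot, or multi-hot for IMDB) before being
        # propagated; remove_diag() below drops each node's own entry so a node never receives
        # its own label as a feature.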
        label_feats = {}
        if args.label_feats:
            if args.dataset != 'IMDB':
                label_onehot = torch.zeros((num_nodes, num_classes))
                label_onehot[train_nid] = F.one_hot(init_labels[train_nid], num_classes).float()
            else:
                label_onehot = torch.zeros((num_nodes, num_classes))
                label_onehot[train_nid] = init_labels[train_nid].float()

            if args.dataset == 'DBLP':
                extra_metapath = []
            elif args.dataset == 'IMDB':
                extra_metapath = []
            elif args.dataset == 'ACM':
                extra_metapath = []
            elif args.dataset == 'Freebase':
                extra_metapath = []
            else:
                assert 0

            extra_metapath = [ele for ele in extra_metapath if len(ele) > args.num_label_hops + 1]
            if len(extra_metapath):
                max_length = max(args.num_label_hops + 1, max([len(ele) for ele in extra_metapath]))
            else:
                max_length = args.num_label_hops + 1

            print(f'Current label-prop num hops = {args.num_label_hops}')
            # compute k-hop feature
            prop_tic = datetime.datetime.now()
            if args.dataset == 'Freebase' and args.num_label_hops <= args.num_hops and len(extra_metapath) == 0:
                meta_adjs = {k: v for k, v in meta_adjs.items() if k[-1] == '0' and len(k) < max_length}
            else:
                if args.dataset == 'Freebase':
                    save_name = f'./Freebase_adjs/label_seed{args.seed}_hop{args.num_label_hops}'
                    if args.seed > 0 and os.path.exists(f'{save_name}_int64.npy'):
                        meta_adj_list = SparseAdjList(save_name, None, None, num_tgt_nodes, num_tgt_nodes, with_values=True)
                        meta_adjs = meta_adj_list.load_adjs(expand=True)
                    else:
                        meta_adjs = hg_propagate_sparse_pyg(
                            adjs, tgt_type, args.num_label_hops, max_length, extra_metapath, prop_feats=False, echo=True, prop_device=prop_device)
                        meta_adj_list = SparseAdjList(save_name, meta_adjs.keys(), meta_adjs, num_tgt_nodes, num_tgt_nodes, with_values=True)

                        tmp = SparseAdjList(save_name, None, None, num_tgt_nodes, num_tgt_nodes, with_values=True)
                        tmp_adjs = tmp.load_adjs(expand=True)
                        for k in tmp.keys:
                            assert torch.all(meta_adjs[k].storage.rowptr() == tmp_adjs[k].storage.rowptr())
                            assert torch.all(meta_adjs[k].storage.col() == tmp_adjs[k].storage.col())
                            assert torch.all(meta_adjs[k].storage.value() == tmp_adjs[k].storage.value())
                        del tmp_adjs, tmp
                        gc.collect()
                else:
                    meta_adjs = hg_propagate_sparse_pyg(
                        adjs, tgt_type, args.num_label_hops, max_length, extra_metapath, prop_feats=False, echo=True, prop_device=prop_device)

            if args.dataset == 'Freebase':
                if 0:  # disabled path: propagate labels along every metapath first, then filter by validation accuracy
                    label_onehot_g = label_onehot.to(prop_device)
                    for k, v in tqdm(meta_adjs.items()):
                        if args.dataset != 'Freebase':
                            label_feats[k] = remove_diag(v) @ label_onehot
                        else:
                            label_feats[k] = (remove_diag(v).to(prop_device) @ label_onehot_g).to(store_device)

                    del label_onehot_g
                    torch.cuda.empty_cache()
                    gc.collect()

                    condition = lambda ra,rb,rc,k: rb > 0.2
                    check_acc(label_feats, condition, init_labels, train_nid, val_nid, test_nid, show_test=False)

                    left_keys = ['00', '000', '0000', '0010', '0030', '0040', '0050', '0060', '0070']
                    remove_keys = list(set(list(label_feats.keys())) - set(left_keys))
                    for k in remove_keys:
                        label_feats.pop(k)
                else:
                    left_keys = ['00', '000', '0000', '0010', '0030', '0040', '0050', '0060', '0070']
                    remove_keys = list(set(list(meta_adjs.keys())) - set(left_keys))
                    for k in remove_keys:
                        meta_adjs.pop(k)

                    label_onehot_g = label_onehot.to(prop_device)
                    for k, v in tqdm(meta_adjs.items()):
                        if args.dataset != 'Freebase':
                            label_feats[k] = remove_diag(v) @ label_onehot
                        else:
                            label_feats[k] = (remove_diag(v).to(prop_device) @ label_onehot_g).to(store_device)

                    del label_onehot_g
                    torch.cuda.empty_cache()
                    gc.collect()
            else:
                for k, v in tqdm(meta_adjs.items()):
                    if args.dataset != 'Freebase':
                        label_feats[k] = remove_diag(v) @ label_onehot
                    else:
                        label_feats[k] = (remove_diag(v).to(prop_device) @ label_onehot_g).to(store_device)
                gc.collect()

                if args.dataset == 'IMDB':
                    condition = lambda ra,rb,rc,k: True
                    check_acc(label_feats, condition, init_labels, train_nid, val_nid, test_nid, show_test=False, loss_type='bce')
                else:
                    condition = lambda ra,rb,rc,k: True
                    check_acc(label_feats, condition, init_labels, train_nid, val_nid, test_nid, show_test=True)
            print('Involved label keys', label_feats.keys())

            label_feats = {k: v[init2sort] for k,v in label_feats.items()}
            prop_toc = datetime.datetime.now()
            print(f'Time used for label prop {prop_toc - prop_tic}')

        # =======
        # Train & eval loaders
        # =======
        train_loader = torch.utils.data.DataLoader(
            torch.arange(train_node_nums), batch_size=args.batch_size, shuffle=True, drop_last=False)

        # =======
        # Mask & Smooth
        # =======
        with_mask = False
        # if args.dataset == 'Freebase':
        #     init_mask = {k: v.storage.rowcount() != 0 for k, v in feats.items()}
        #     with_mask = True
        # else:
        #     print(f'TODO: `with_mask` has not been implemented for {args.dataset}')

        # if with_mask:
        #     train_mask = {k: (v[:total_num_nodes] & (torch.randn(total_num_nodes) > 0)).float() for k, v in init_mask.items()}
        #     full_mask = {k: v.float() for k, v in init_mask.items()}
        # else:
        #     train_mask = full_mask = None

        # Freebase train/val/test/full_nodes: 1909/477/5568/40402
        # IMDB     train/val/test/full_nodes: 1097/274/3202/359
        eval_loader, full_loader = [], []
        batchsize = 2 * args.batch_size

        if args.two_layer:
            for batch_idx in range((total_num_nodes-1) // batchsize + 1):
                batch_start = batch_idx * batchsize
                batch_end = min(total_num_nodes, (batch_idx+1) * batchsize)
                batch = torch.arange(batch_start, batch_end)

                layer2_feats = {k: x[batch_start:batch_end] for k, x in feats.items() if k[0] == tgt_type}
                batch_labels_feats = {k: x[batch_start:batch_end] for k, x in label_feats.items()}

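                # For each source node type, collect the union of column (neighbor) indices
                # actually touched by this batch, then restrict the layer-2 matrices to those
                # columns and slice the layer-1 features accordingly.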
                involved_keys = {}
                for k, v in layer2_feats.items():
                    src = k[-1]
                    if src not in involved_keys:
                        involved_keys[src] = []
                    involved_keys[src].append(torch.unique(v.storage.col()))
                involved_keys = {k: torch.unique(torch.cat(v)) for k, v in involved_keys.items()}

                for k, v in layer2_feats.items():
                    src = k[-1]
                    old_nnz = v.nnz()
                    layer2_feats[k] = v[:, involved_keys[src]]
                    assert layer2_feats[k].nnz() == old_nnz

                layer1_feats = {k: v[involved_keys[k[0]]] for k, v in feats.items() if k[0] in involved_keys}

                eval_loader.append((involved_keys, layer1_feats, batch, layer2_feats, batch_labels_feats))

            for batch_idx in range((num_nodes-total_num_nodes-1) // batchsize + 1):
                batch_start = batch_idx * batchsize + total_num_nodes
                batch_end = min(num_nodes, (batch_idx+1) * batchsize + total_num_nodes)
                batch = torch.arange(batch_start, batch_end)

                layer2_feats = {k: x[batch_start:batch_end] for k, x in feats.items() if k[0] == tgt_type}
                batch_labels_feats = {k: x[batch_start:batch_end] for k, x in label_feats.items()}

                involved_keys = {}
                for k, v in layer2_feats.items():
                    src = k[-1]
                    if src not in involved_keys:
                        involved_keys[src] = []
                    involved_keys[src].append(torch.unique(v.storage.col()))
                involved_keys = {k: torch.unique(torch.cat(v)) for k, v in involved_keys.items()}

                for k, v in layer2_feats.items():
                    src = k[-1]
                    old_nnz = v.nnz()
                    layer2_feats[k] = v[:, involved_keys[src]]
                    assert layer2_feats[k].nnz() == old_nnz

                layer1_feats = {k: v[involved_keys[k[0]]] for k, v in feats.items() if k[0] in involved_keys}

                full_loader.append((involved_keys, layer1_feats, batch, layer2_feats, batch_labels_feats))
        else:
            for batch_idx in range((total_num_nodes-1) // batchsize + 1):
                batch_start = batch_idx * batchsize
                batch_end = min(total_num_nodes, (batch_idx+1) * batchsize)
                batch = torch.arange(batch_start, batch_end)

                batch_feats = {k: x[batch_start:batch_end] for k, x in feats.items()}
                batch_labels_feats = {k: x[batch_start:batch_end] for k, x in label_feats.items()}
                if with_mask:
                    batch_mask = {k: x[batch_start:batch_end] for k, x in full_mask.items()}
                else:
                    batch_mask = None
                eval_loader.append((batch, batch_feats, batch_labels_feats, batch_mask))

            for batch_idx in range((num_nodes-total_num_nodes-1) // batchsize + 1):
                batch_start = batch_idx * batchsize + total_num_nodes
                batch_end = min(num_nodes, (batch_idx+1) * batchsize + total_num_nodes)
                batch = torch.arange(batch_start, batch_end)

                batch_feats = {k: x[batch_start:batch_end] for k, x in feats.items()}
                batch_labels_feats = {k: x[batch_start:batch_end] for k, x in label_feats.items()}
                if with_mask:
                    batch_mask = {k: x[batch_start:batch_end] for k, x in full_mask.items()}
                else:
                    batch_mask = None
                full_loader.append((batch, batch_feats, batch_labels_feats, batch_mask))

        # =======
        # Construct network
        # =======
        torch.cuda.empty_cache()
        gc.collect()
        if args.neighbor_attention:
            model = SeHGNN_NA(args.embed_size, args.hidden, num_classes, feats.keys(), label_feats.keys(), tgt_type,
                args.dropout, args.input_drop, args.att_drop, args.label_drop,
                args.n_layers_1, args.n_layers_2, args.act, args.residual, bns=args.bns, data_size=data_size, num_heads=args.num_heads)
        elif args.two_layer:
            model = SeHGNN_2L(args.embed_size, args.hidden, num_classes,
                feats.keys(), [k for k in feats.keys() if k[0] == tgt_type], label_feats.keys(), node_types,
                args.dropout, args.input_drop, args.att_drop, args.label_drop,
                args.n_layers_1, args.n_layers_2, args.act, args.residual, bns=args.bns, data_size=data_size)
        else:
            model = SeHGNN(args.embed_size, args.hidden, num_classes, feats.keys(), label_feats.keys(), tgt_type,
                args.dropout, args.input_drop, args.att_drop, args.label_drop,
                args.n_layers_1, args.n_layers_2, args.act, args.residual, bns=args.bns, data_size=data_size,
                remove_transformer=args.remove_transformer, independent_attn=args.independent_attn)
        model = model.to(device)
        if args.seed == args.seeds[0]:
            print(model)
            print("# Params:", get_n_params(model))

        if args.dataset == 'IMDB':
            loss_fcn = nn.BCEWithLogitsLoss()
        else:
            loss_fcn = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                    weight_decay=args.weight_decay)

        best_epoch = -1
        best_val_loss = 1000000
        best_test_loss = 0
        best_val = (0,0)
        best_test = (0,0)
        val_loss_list, test_loss_list = [], []
        val_acc_list, test_acc_list = [], []
        actual_loss_list, actual_acc_list = [], []
        store_list = []
        best_pred = None
        count = 0

        train_times = []

        if args.neighbor_attention or args.two_layer:
            model.feats = {k: v.to(device) for k, v in raw_feats.items()}

        for epoch in tqdm(range(args.stage)):
            gc.collect()
            torch.cuda.synchronize()
            start = time.time()
            if args.two_layer:
                loss, acc = train_2l(model, feats, label_feats, labels_cuda, loss_fcn, optimizer, train_loader, evaluator, tgt_type, scalar=scalar)
            else:
                loss, acc = train(model, feats, label_feats, labels_cuda, loss_fcn, optimizer, train_loader, evaluator, scalar=scalar)
            torch.cuda.synchronize()
            end = time.time()

            log = "Epoch {}, training Time(s): {:.4f}, estimated train loss {:.4f}, acc {:.4f}, {:.4f}\n".format(epoch, end - start,loss, acc[0]*100, acc[1]*100)
            torch.cuda.empty_cache()
            train_times.append(end-start)

            start = time.time()
            with torch.no_grad():
                model.eval()
                raw_preds = []
                if args.two_layer:
                    for batch1, layer1_feats, batch2, layer2_feats, batch_labels_feats in eval_loader:
                        batch1 = {k: v.to(device) for k,v in batch1.items()}
                        layer1_feats = {k: v.to(device) for k,v in layer1_feats.items()}
                        batch2 = batch2.to(device)
                        layer2_feats = {k: v.to(device) for k,v in layer2_feats.items()}
                        batch_labels_feats = {k: x.to(device) for k, x in batch_labels_feats.items()}
                        raw_preds.append(model(layer1_feats, batch1, layer2_feats, batch2, batch_labels_feats).cpu())
                else:
                    for batch, batch_feats, batch_labels_feats, batch_mask in eval_loader:
                        batch = batch.to(device)
                        batch_feats = {k: x.to(device) for k, x in batch_feats.items()}
                        batch_labels_feats = {k: x.to(device) for k, x in batch_labels_feats.items()}
                        if with_mask:
                            batch_mask = {k: x.to(device) for k, x in batch_mask.items()}
                        else:
                            batch_mask = None
                        raw_preds.append(model(batch, batch_feats, batch_labels_feats, batch_mask).cpu())

                raw_preds = torch.cat(raw_preds, dim=0)
                loss_train = loss_fcn(raw_preds[:trainval_point], labels[:trainval_point]).item()
                loss_val = loss_fcn(raw_preds[trainval_point:valtest_point], labels[trainval_point:valtest_point]).item()
                loss_test = loss_fcn(raw_preds[valtest_point:total_num_nodes], labels[valtest_point:total_num_nodes]).item()

            if args.dataset != 'IMDB':
                preds = raw_preds.argmax(dim=-1)
            else:
                preds = (raw_preds > 0.).int()

            train_acc = evaluator(preds[:trainval_point], labels[:trainval_point])
            val_acc = evaluator(preds[trainval_point:valtest_point], labels[trainval_point:valtest_point])
            test_acc = evaluator(preds[valtest_point:total_num_nodes], labels[valtest_point:total_num_nodes])

            end = time.time()
            log += f'evaluation Time: {end-start}, Train loss: {loss_train}, Val loss: {loss_val}, Test loss: {loss_test}\n'
            log += 'Train acc: ({:.4f}, {:.4f}), Val acc: ({:.4f}, {:.4f}), Test acc: ({:.4f}, {:.4f}) ({})\n'.format(
                train_acc[0]*100, train_acc[1]*100, val_acc[0]*100, val_acc[1]*100, test_acc[0]*100, test_acc[1]*100, total_num_nodes-valtest_point)

            if (args.dataset != 'Freebase' and loss_val <= best_val_loss) or (args.dataset == 'Freebase' and sum(val_acc) >= sum(best_val)):
                best_epoch = epoch
                best_val_loss = loss_val
                best_test_loss = loss_test
                best_val = val_acc
                best_test = test_acc

                best_pred = raw_preds
                torch.save(model.state_dict(), f'{checkpt_file}.pkl')
                count = 0
            else:
                count += 1
                if count >= args.patience:
                    break

            if epoch > 0 and epoch % 10 == 0: 
                log = log + f'\tCurrent best at epoch {best_epoch} with Val loss {best_val_loss:.4f} ({best_val[0]*100:.4f}, {best_val[1]*100:.4f})' \
                    + f', Test loss {best_test_loss:.4f} ({best_test[0]*100:.4f}, {best_test[1]*100:.4f})'
            print(log)

        print('average train time', sum(train_times) / len(train_times))

        print(f'Best Epoch {best_epoch} at {checkpt_file.split("/")[-1]}\n\tFinal Val loss {best_val_loss:.4f} ({best_val[0]*100:.4f}, {best_val[1]*100:.4f})'
            + f', Test loss {best_test_loss:.4f} ({best_test[0]*100:.4f}, {best_test[1]*100:.4f})')

        if len(full_loader):
            model.load_state_dict(torch.load(f'{checkpt_file}.pkl', map_location='cpu'), strict=True)
            torch.cuda.empty_cache()
            with torch.no_grad():
                model.eval()
                raw_preds = []
                if args.two_layer:
                    for batch1, layer1_feats, batch2, layer2_feats, batch_labels_feats in full_loader:
                        batch1 = {k: v.to(device) for k,v in batch1.items()}
                        layer1_feats = {k: v.to(device) for k,v in layer1_feats.items()}
                        batch2 = batch2.to(device)
                        layer2_feats = {k: v.to(device) for k,v in layer2_feats.items()}
                        batch_labels_feats = {k: x.to(device) for k, x in batch_labels_feats.items()}
                        raw_preds.append(model(layer1_feats, batch1, layer2_feats, batch2, batch_labels_feats).cpu())
                else:
                    for batch, batch_feats, batch_labels_feats, batch_mask in full_loader:
                        batch = batch.to(device)
                        batch_feats = {k: x.to(device) for k, x in batch_feats.items()}
                        batch_labels_feats = {k: x.to(device) for k, x in batch_labels_feats.items()}
                        if with_mask:
                            batch_mask = {k: x.to(device) for k, x in batch_mask.items()}
                        else:
                            batch_mask = None
                        raw_preds.append(model(batch, batch_feats, batch_labels_feats, batch_mask).cpu())
                raw_preds = torch.cat(raw_preds, dim=0)
            best_pred = torch.cat((best_pred, raw_preds), dim=0)

        torch.save(best_pred, f'{checkpt_file}.pt')

        if args.dataset != 'IMDB':
            predict_prob = best_pred.softmax(dim=1)
        else:
            predict_prob = torch.sigmoid(best_pred)

        test_logits = predict_prob[sort2init][test_nid_full]
        if args.dataset != 'IMDB':
            pred = test_logits.cpu().numpy().argmax(axis=1)
            dl.gen_file_for_evaluate(test_idx=test_nid_full, label=pred, file_path=f"{args.dataset}_{args.seed}_{checkpt_file.split('/')[-1]}.txt")
        else:
            pred = (test_logits.cpu().numpy()>0.5).astype(int)
            dl.gen_file_for_evaluate(test_idx=test_nid_full, label=pred, file_path=f"{args.dataset}_{args.seed}_{checkpt_file.split('/')[-1]}.txt", mode='multi')

    if args.dataset != 'IMDB':
        preds = predict_prob.argmax(dim=1, keepdim=True)
    else:
        preds = (predict_prob > 0.5).int()
    train_acc = evaluator(labels[:trainval_point], preds[:trainval_point])
    val_acc = evaluator(labels[trainval_point:valtest_point], preds[trainval_point:valtest_point])
    test_acc = evaluator(labels[valtest_point:total_num_nodes], preds[valtest_point:total_num_nodes])

    print(f'train_acc ({train_acc[0]*100:.2f}, {train_acc[1]*100:.2f}) ' \
        + f'val_acc ({val_acc[0]*100:.2f}, {val_acc[1]*100:.2f}) ' \
        + f'test_acc ({test_acc[0]*100:.2f}, {test_acc[1]*100:.2f})')
    print(checkpt_file.split('/')[-1])


def parse_args(args=None):
    parser = argparse.ArgumentParser(description='SeHGNN')
    ## For environment construction
    parser.add_argument("--seeds", nargs='+', type=int, default=[1],
                        help="the seeds used for training (one run per seed)")
    parser.add_argument("--dataset", type=str, default="ogbn-mag")
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--cpu", action='store_true', default=False)
    parser.add_argument("--root", type=str, default="../data/")
    parser.add_argument("--stage", type=int, default=200, help="The epoch setting for each stage.")
    parser.add_argument("--embed-size", type=int, default=256,
                        help="inital embedding size of nodes with no attributes")
    parser.add_argument("--num-hops", type=int, default=2,
                        help="number of hops for propagation of raw labels")
    parser.add_argument("--label-feats", action='store_true', default=False,
                        help="whether to use the label propagated features")
    parser.add_argument("--num-label-hops", type=int, default=2,
                        help="number of hops for propagation of raw features")
    ## For network structure
    parser.add_argument("--hidden", type=int, default=512)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout on activation")
    parser.add_argument("--n-layers-1", type=int, default=2,
                        help="number of layers of feature projection")
    parser.add_argument("--n-layers-2", type=int, default=3,
                        help="number of layers of the downstream task")
    parser.add_argument("--input-drop", type=float, default=0.1,
                        help="input dropout of input features")
    parser.add_argument("--att-drop", type=float, default=0.,
                        help="attention dropout of model")
    parser.add_argument("--label-drop", type=float, default=0.,
                        help="label feature dropout of model")
    parser.add_argument("--residual", action='store_true', default=False,
                        help="whether to add residual branch the raw input features")
    parser.add_argument("--act", type=str, default='relu',
                        help="the activation function of the model")
    parser.add_argument("--bns", action='store_true', default=False,
                        help="whether to process the input features")
    parser.add_argument("--label-bns", action='store_true', default=False,
                        help="whether to process the input label features")
    ## for training
    parser.add_argument("--amp", action='store_true', default=False,
                        help="whether to amp to accelerate training with float16(half) calculation")
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight-decay", type=float, default=0)
    parser.add_argument("--eval-every", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=10000)
    parser.add_argument("--patience", type=int, default=100,
                        help="early stop of times of the experiment")
    parser.add_argument("--drop-metapath", type=float, default=0,
                        help="whether to process the input features")
    ## for ablation
    parser.add_argument("-na", "--neighbor-attention", action='store_true', default=False)
    parser.add_argument("--num-heads", type=int, default=1)
    parser.add_argument("--two-layer", action='store_true', default=False)
    parser.add_argument("--remove-transformer", action='store_true', default=False)
    parser.add_argument("--independent-attn", action='store_true', default=False)

    return parser.parse_args(args)

if __name__ == '__main__':
    args = parse_args()

    # args.bns = args.bns and args.dataset == 'Freebase' # remove bn for full-batch learning
    if args.dataset == 'ACM':
        args.ACM_keep_F = False
    assert args.neighbor_attention + args.two_layer + args.remove_transformer <= 1

    args.seed = args.seeds[0]
    print(args)

    for seed in args.seeds:
        args.seed = seed
        print('Restart with seed =', seed)
        main(args)
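
# Example invocation (illustrative; the flags map to the argparse options above and the
# values shown are placeholders rather than recommended hyper-parameters):
#   python main.py --dataset DBLP --num-hops 2 --label-feats --num-label-hops 2 --amp --seeds 1 2 3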