import torch
from sklearn.metrics import f1_score
import dgl
from utils import load_data, EarlyStopping, set_random_seed, setup_log_dir
from utils import load_dblp, load_acm, load_acm_hgb, load_imdb  # load_imdb is required by the IMDB branch below and assumed to live in utils like the other loaders
import torch.nn.functional as F
from model_hetero import HAN, HAN_freebase
import argparse
import datetime


def score(logits, labels, bce=False):
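    """Compute accuracy, micro-F1 and macro-F1 for the given logits.

    With bce=True the logits are treated as multi-label scores thresholded at 0;
    otherwise the arg-max class is used.
    """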
    if bce:
        prediction = (logits > 0.).data.cpu().long().numpy()
    else:
        _, indices = torch.max(logits, dim=1)
        prediction = indices.long().cpu().numpy()
    labels = labels.cpu().numpy()

    accuracy = (prediction == labels).mean()  # element-wise mean also covers the multi-label (bce) case
    micro_f1 = f1_score(labels, prediction, average='micro')
    macro_f1 = f1_score(labels, prediction, average='macro')

    return accuracy, micro_f1, macro_f1


def evaluate(model, g, features, labels, mask, loss_func, bce=False):
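    """Run a gradient-free forward pass and score the nodes selected by `mask`."""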
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
    loss = loss_func(logits[mask], labels[mask])
    accuracy, micro_f1, macro_f1 = score(logits[mask], labels[mask], bce=bce)

    return loss, accuracy, micro_f1, macro_f1


def main(args):
    if args.dataset == 'DBLP':
        g, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, \
            val_mask, test_mask = load_dblp(args.tree_like_a_leaves, args.star_like_a_leaves)
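        # the tree-/star-like DBLP variants expose the target author nodes under type 'A' instead of 'a'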
        if args.tree_like_a_leaves or args.star_like_a_leaves:
            args.tgt_type = 'A'
        else:
            args.tgt_type = 'a'
    elif args.dataset == 'ACM':
        g, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, \
            val_mask, test_mask = load_acm()
        args.tgt_type = 'p'
    elif args.dataset == 'ACM_HGB':
        g, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, \
            val_mask, test_mask = load_acm_hgb(pp_symmetric=True)
        args.tgt_type = 'p'
    elif args.dataset == 'IMDB':
        g, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, \
            val_mask, test_mask = load_imdb()
        args.tgt_type = 'm'
    else:
        raise NotImplementedError(f'Not ready for dataset {args.dataset}')
    print(g)

    if len(args.metapaths):
        meta_paths = []
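        # each --metapaths entry is a node-type string read from the target node outward;
        # a single character (e.g. 'a') stands for the target node itself, while longer strings
        # are reversed and split into consecutive edge types, e.g. 'apc' -> ['cp', 'pa']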
        for ele in args.metapaths:
            if len(ele) == 1:
                meta_paths.append([ele])
            else:
                ele = ele[::-1]
                meta_paths.append([
                    ele[i:i+2] for i in range(len(ele)-1)])
    elif args.dataset == 'DBLP': # apa, aptpa, apcpa
        '''
            a(0*) -- p(1) -- c(2)
                     |
                     t(3)
            a(0): [4057, 334]
            p(1): [18385, 4231]
            c(2): [26108, 50]
            t(3): None or [20, 20]
        ''' # PS: edge ap means a->p; metapath [['ap', 'pc']] means  a->p->c
        meta_paths = [['ap', 'pa'], ['ap', 'pt', 'tp', 'pa'], ['ap', 'pc', 'cp', 'pa']]
    elif args.dataset == 'ACM': # pap, pfp
        '''
            a -- p* -- f
            p: [4025, 1903]
            a: [17351, 1903]
            f: [72, 1903]
        '''
        meta_paths = [['pa', 'ap'], ['pf', 'fp']]
    elif args.dataset == 'ACM_HGB': # ppsp, pap, psp, ptp
        '''
            a(0) -- p(1*) -- s(2)
                    |
                    t(3)
            p: [5959, 1902]
            a: [3025, 1902]
            s: [56, 1902]
            t: None or [1902, 1902]
        '''
        meta_paths = [['pp', 'ps', 'sp'], ['pa', 'ap'], ['ps', 'sp'], ['pt', 'tp']]
    else:
        raise ValueError(f'Unrecognized dataset {args.dataset}')

    if args.use_gat: # GAT requires the source and target node types to be the same, while GCN does not
        raise NotImplementedError('GAT-based metapath aggregation is not implemented')
        # for metapath in metapaths:
        #     assert metapath[0] == args.tgt_type and metapath[1] == args.tgt_type

    if hasattr(torch, 'BoolTensor'):
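        # masks may come back from the loaders as uint8; cast to bool when the installed PyTorch supports it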
        train_mask = train_mask.bool()
        val_mask = val_mask.bool()
        test_mask = test_mask.bool()

    features = {k: v.to(args.device) for k, v in features.items()}
    labels = labels.to(args.device)
    train_mask = train_mask.to(args.device)
    val_mask = val_mask.to(args.device)
    test_mask = test_mask.to(args.device)
    g = g.to(args.device)

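    # note: the dataset loaders above do not handle 'Freebase', so this branch is not reachable in this script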
    if args.dataset == 'Freebase':
        model = HAN_freebase(
            meta_paths=meta_paths,
            in_size=features.shape[1],
            hidden_size=args.hidden_units,
            out_size=num_classes,
            num_heads=args.num_heads,
            dropout=args.dropout).to(args.device)
    else:
        model = HAN(
            meta_paths=meta_paths, tgt_type=args.tgt_type,
            in_sizes={k: v.shape[1] for k, v in features.items()},
            hidden_size=args.hidden_units,
            out_size=num_classes,
            num_heads=args.num_heads,
            dropout=args.dropout, use_gat=args.use_gat, attn_dropout=args.attn_dropout, residual=args.residual).to(args.device)
    print(model)

    stopper = EarlyStopping(patience=args.patience)
    if args.dataset == 'IMDB':
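        # IMDB here is a multi-label task (a movie can belong to several genres), hence
        # BCE-with-logits and the thresholding path in score(bce=True)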
        loss_fcn = torch.nn.BCEWithLogitsLoss()
    else:
        loss_fcn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)

    train_times = []
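    # wall-clock time per epoch; the first six epochs (epoch <= 5) are skipped as warm-up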
    for epoch in range(args.num_epochs):
        if epoch > 5:
            torch.cuda.synchronize()
            tic = datetime.datetime.now()
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch > 5:
            torch.cuda.synchronize()
            toc = datetime.datetime.now()
            train_times.append(toc-tic)

        train_acc, train_micro_f1, train_macro_f1 = score(logits[train_mask], labels[train_mask], bce=args.dataset=='IMDB')
        val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, g, features, labels, val_mask, loss_fcn, bce=args.dataset=='IMDB')
        early_stop = stopper.step(val_loss.data.item(), val_acc, model)

        print('Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | '
              'Val Loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(
            epoch + 1, loss.item(), train_micro_f1, train_macro_f1, val_loss.item(), val_micro_f1, val_macro_f1))

        if early_stop:
            break

    if train_times:
        avg_train_time = sum(train_times, datetime.timedelta(0)) / len(train_times)
        print('average train time per epoch', avg_train_time)

    stopper.load_checkpoint(model)

    val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, g, features, labels, val_mask, loss_fcn, bce=args.dataset=='IMDB')
    print('Val loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(
        val_loss.item(), val_micro_f1, val_macro_f1))
    test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, g, features, labels, test_mask, loss_fcn, bce=args.dataset=='IMDB')
    print('Test loss {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}'.format(
        test_loss.item(), test_micro_f1, test_macro_f1))

if __name__ == '__main__':
    parser = argparse.ArgumentParser('HAN')
    parser.add_argument('-s', '--seeds', nargs='+', type=int, default=[1],
                        help='Random seeds')
    parser.add_argument('-ld', '--log-dir', type=str, default='results',
                        help='Dir for saving training results')
    parser.add_argument('--hetero', action='store_true',
                        help='Use metapath coalescing with DGL\'s own dataset')
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--dataset', type=str, default='DBLP',
                        choices=['DBLP', 'ACM', 'ACM_HGB', 'Freebase', 'IMDB'])
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--patience', type=int, default=100)
    parser.add_argument("--use-gat", action='store_true', default=False)
    parser.add_argument("--lr", type=float, default=0.005)
    parser.add_argument("--num-heads", nargs='+',type=int, default=[8])
    parser.add_argument("--hidden-units", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.6)
    parser.add_argument("--attn-dropout", type=float, default=0.6)
    parser.add_argument("--weight-decay", type=float, default=0.001)

    parser.add_argument('--metapaths', nargs='+', type=str, default=[])
    parser.add_argument("--residual", action='store_true', default=False)

    parser.add_argument("--tree-like-a-leaves", action='store_true', default=False)
    parser.add_argument("--tree-like-a-roots", action='store_true', default=False)
    parser.add_argument("--star-like-a-leaves", action='store_true', default=False)
    parser.add_argument("--star-like-m-roots", action='store_true', default=False)

    args = parser.parse_args()

    # example configurations:
    # args = parser.parse_args('--dataset DBLP'.split(' '))
    # args = parser.parse_args('--dataset DBLP --metapaths a ap apc --num-heads 8'.split(' '))
    # args = parser.parse_args('--dataset IMDB --metapaths m ma md --num-heads 8'.split(' '))
    # args = parser.parse_args('--dataset DBLP --metapaths Ap Apa Apc --num-heads 8 --star-like-a-leaves'.split(' '))
    # args = parser.parse_args('--dataset DBLP --metapaths AP APa APc APcp APcpa --num-heads 8 --tree-like-a-leaves'.split(' '))
    # set_random_seed(args.seed[0])

    print(args)
    args.log_dir = setup_log_dir(args)
    for seed in args.seeds:
        if seed > 0:
            set_random_seed(seed)
        main(args)