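"""Training script for HAN on heterogeneous node-classification benchmarks
(the attn_HAN variant from the SeHGNN motivation experiments)."""
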
import torch
from sklearn.metrics import f1_score
import dgl
from utils import load_data, EarlyStopping, set_random_seed, setup_log_dir
import torch.nn.functional as F
from model_hetero import HAN, HAN_freebase
import argparse
import datetime


def score(logits, labels):
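    """Return (score, micro_f1, macro_f1) for the given logits and labels.

    Note: the first value is the mean of micro- and macro-F1, used as the
    early-stopping score, rather than raw accuracy (see the commented return).
    """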
    _, indices = torch.max(logits, dim=1)
    prediction = indices.long().cpu().numpy()
    labels = labels.cpu().numpy()

    accuracy = (prediction == labels).sum() / len(prediction)
    micro_f1 = f1_score(labels, prediction, average='micro')
    macro_f1 = f1_score(labels, prediction, average='macro')

    # return accuracy, micro_f1, macro_f1
    return (micro_f1 + macro_f1) / 2, micro_f1, macro_f1


def evaluate(model, g, features, labels, mask, loss_func):
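    """Evaluate on the nodes selected by `mask`; returns (loss, score, micro_f1, macro_f1)."""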
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
    loss = loss_func(logits[mask], labels[mask])
    accuracy, micro_f1, macro_f1 = score(logits[mask], labels[mask])

    return loss, accuracy, micro_f1, macro_f1


def main(args):
    (g, features, labels, num_classes, train_idx, val_idx, test_idx,
     train_mask, val_mask, test_mask, meta_paths) = load_data(args.dataset, feat_type=0)

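    # older PyTorch releases lack BoolTensor; cast the uint8 masks when supported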
    if hasattr(torch, 'BoolTensor'):
        train_mask = train_mask.bool()
        val_mask = val_mask.bool()
        test_mask = test_mask.bool()

    # if args.dataset == 'DBLP':
    #     labels[test_idx] = torch.load('../../data/DBLP_test_labels.pt', map_location='cpu')
    #     new_labels = labels.to(args.device)
    # elif args.dataset == 'ACM_HGB':
    #     labels[test_idx] = torch.load('../../data/ACM_test_labels.pt', map_location='cpu')
    #     new_labels = labels.to(args.device)
    # else:
    # with the test-label reload above disabled, new_labels is just labels
    new_labels = labels.to(args.device)

    features = features.to(args.device)
    labels = labels.to(args.device)
    train_mask = train_mask.to(args.device)
    val_mask = val_mask.to(args.device)
    test_mask = test_mask.to(args.device)
    g = g.to(args.device)

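    # Freebase uses a dedicated model class (see model_hetero); the other
    # datasets use HAN with or without GAT-style node-level attention.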
    if args.dataset == 'Freebase':
        model = HAN_freebase(
            meta_paths=meta_paths,
            in_size=features.shape[1],
            hidden_size=args.hidden_units,
            out_size=num_classes,
            num_heads=args.num_heads,
            dropout=args.dropout).to(args.device)
    else:
        if args.use_gat:
            model = HAN(
                meta_paths=meta_paths,
                in_size=features.shape[1],
                hidden_size=args.hidden_units,
                out_size=num_classes,
                num_heads=args.num_heads,
                dropout=args.dropout,
                remove_sa=args.remove_sa_attn,
                use_gat=True,
                attn_dropout=args.attn_dropout).to(args.device)
        else:
            # without GAT, fold the heads into the hidden size (num_heads=[1])
            # so the layer width matches the multi-head configuration above
            model = HAN(
                meta_paths=meta_paths,
                in_size=features.shape[1],
                hidden_size=args.hidden_units * args.num_heads[0],
                out_size=num_classes,
                num_heads=[1],
                dropout=args.dropout,
                remove_sa=args.remove_sa_attn,
                use_gat=False,
                attn_dropout=args.attn_dropout).to(args.device)

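    # early stopping tracks the best validation score; the corresponding
    # checkpoint is restored after training (stopper.load_checkpoint below)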
    stopper = EarlyStopping(patience=args.patience)
    loss_fcn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)

    train_times = []
    for epoch in range(args.num_epochs):
        if epoch > 5:
            # skip the first few epochs as warm-up before timing;
            # only synchronize when actually running on CUDA
            if args.device.startswith('cuda'):
                torch.cuda.synchronize()
            tic = datetime.datetime.now()
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch > 5:
            if args.device.startswith('cuda'):
                torch.cuda.synchronize()
            toc = datetime.datetime.now()
            train_times.append(toc - tic)

        train_acc, train_micro_f1, train_macro_f1 = score(logits[train_mask], labels[train_mask])
        val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, g, features, labels, val_mask, loss_fcn)
        early_stop = stopper.step(val_loss.item(), val_acc, model)

        print('Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | '
              'Val Loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(
            epoch + 1, loss.item(), train_micro_f1, train_macro_f1, val_loss.item(), val_micro_f1, val_macro_f1))

        if early_stop:
            break

    # average per-epoch time over the timed (post-warm-up) epochs; guard
    # against early stopping before any epoch was timed
    if train_times:
        avg_train_time = sum(train_times, datetime.timedelta(0)) / len(train_times)
        print('average train time per epoch:', avg_train_time)

    stopper.load_checkpoint(model)

    val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, g, features, new_labels, val_mask, loss_fcn)
    print('Val loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(
        val_loss.item(), val_micro_f1, val_macro_f1))
    test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, g, features, new_labels, test_mask, loss_fcn)
    print('Test loss {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}'.format(
        test_loss.item(), test_micro_f1, test_macro_f1))


if __name__ == '__main__':
    parser = argparse.ArgumentParser('HAN')
    parser.add_argument('-s', '--seeds', nargs='+', type=int, default=[1],
                        help='Random seeds')
    parser.add_argument('-ld', '--log-dir', type=str, default='results',
                        help='Dir for saving training results')
    parser.add_argument('--hetero', action='store_true',
                        help='Use metapath coalescing with DGL\'s own dataset')
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--dataset', type=str, default='DBLP',
                        choices=['DBLP', 'ACM', 'ACM_HGB', 'Freebase'])
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--patience', type=int, default=100)
    parser.add_argument("--lr", type=float, default=0.005)
    parser.add_argument("--num-heads", nargs='+',type=int, default=[8])
    parser.add_argument("--hidden-units", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.6)
    parser.add_argument("--attn-dropout", type=float, default=0.6)
    parser.add_argument("--weight-decay", type=float, default=0.001)

    parser.add_argument("--use-gat", action='store_true', default=False)
    parser.add_argument("--remove-sa-attn", action='store_true', default=False)

    args = parser.parse_args()
    print(args)
    base_log_dir = args.log_dir
    for seed in args.seeds:
        if seed > 0:
            set_random_seed(seed)
        # restore the base dir before deriving a fresh per-run log dir,
        # since setup_log_dir overwrites args.log_dir on each iteration
        args.log_dir = base_log_dir
        args.log_dir = setup_log_dir(args)
        main(args)
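
# Example invocations (a sketch; assumes this repo's utils/model_hetero
# modules and the prepared datasets are available):
#   python main.py --dataset DBLP --use-gat
#   python main.py --dataset Freebase --num_epochs 200 --seeds 1 2 3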