# KGTOSA / GNN-Methods / NodeClassifcation / IBS / train / trainer.py
import logging
import os
import time
from collections import defaultdict

import numpy as np
import torch
from sklearn.metrics import f1_score
from torch_sparse import SparseTensor

from dataloaders.BaseLoader import BaseLoader
from .prefetch_generators import BackgroundGenerator
from .train_utils import run_batch, run_rgcn_batch
from data.data_utils import MyGraph
from tqdm import tqdm
from models.RGCN import rgcn_test
# from torch.utils.tensorboard import SummaryWriter

device_type = 'cpu'  # set to 'cuda' to enable the CUDA-specific synchronization and memory logging below
class Trainer:
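    """Epoch-based trainer with background-prefetched subgraph loaders.

    A minimal usage sketch (illustrative only; the loaders, the model, and the
    `mode` string come from the rest of the repo and are assumptions here):

        trainer = Trainer(mode='ibs', num_batches=64, micro_batch=4)
        model = trainer.train(train_loader, self_val_loader, ppr_val_loader,
                              batch_val_loader, model, lr=1e-3, reg=5e-4,
                              comment='demo', run_no='0')
        trainer.inference(self_val_loader, ppr_val_loader, batch_val_loader,
                          self_test_loader, ppr_test_loader, batch_test_loader,
                          model)
    """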
    def __init__(self,
                 mode: str,
                 num_batches: int,
                 micro_batch: int = 1,
                 batch_size: int = 1,
                 epoch_max: int = 800,
                 epoch_min: int = 300,
                 patience: int = 100,
                 device: str = device_type):

        super().__init__()

        self.mode = mode
        self.device = device
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.micro_batch = micro_batch
        self.epoch_max = epoch_max
        self.epoch_min = epoch_min
        self.patience = patience

        self.database = defaultdict(list)

    def get_loss_scaling(self, len_loader):
        # len_loader is either an int or a zero-argument callable returning the length
        n = len_loader if isinstance(len_loader, int) else len_loader()
        micro_batch = int(min(self.micro_batch, n))
        loss_scaling_lst = [micro_batch] * (n // micro_batch)
        if n % micro_batch:  # avoid appending a trailing 0 when n divides evenly
            loss_scaling_lst.append(n % micro_batch)
        return loss_scaling_lst, micro_batch
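
    # Worked example: with self.micro_batch = 4 and a loader of length 10, the
    # list is [4, 4, 2]; each entry is the scaling factor handed to
    # run_batch/run_rgcn_batch for the subgraphs of one micro-batch group.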

    def train(self,
              train_loader,
              self_val_loader,
              ppr_val_loader,
              batch_val_loader,
              model,
              lr,
              reg,
              comment='',
              run_no='', org_graph=None, x_feat=None):

        #         writer = SummaryWriter('./runs')
        patience_count = 0
        best_accs = {'train': 0., 'self': 0., 'part': 0., 'ppr': 0.}
        best_val_acc = -1.

        if not os.path.isdir('./saved_models'):
            os.mkdir('./saved_models')
        model_dir = os.path.join('./saved_models', comment)
        if not os.path.isdir(model_dir):
            os.mkdir(model_dir)
        model_path = os.path.join(model_dir, f'model_{run_no}.pt')

        # start training
        training_curve = defaultdict(list)

        # opt = torch.optim.Adam(model.p_list, lr=lr, weight_decay=reg)
        # NOTE: `reg` is accepted but not applied here; the commented line above
        # suggests it was intended as Adam's weight_decay.
        opt = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.33, patience=30,
                                                               cooldown=10, min_lr=1e-4)
        # Detect the RGCN model once; this compares the defining module's
        # basename, matching the original str(type(model)) string parsing.
        is_rgcn = type(model).__module__.split('.')[-1] == 'RGCN'
        max_subgraphs = 50  # cap on subgraphs consumed per epoch from each loader
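        # BackgroundGenerator prefetches batches on a worker thread; below, the
        # next phase's loader is constructed before the current loop finishes so
        # its prefetching overlaps with the remaining computation.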
        next_loader = BackgroundGenerator(train_loader)
        for epoch in range(self.epoch_max):
            logging.info(f"Epoch {epoch}")
            data_dic = {'self': {'loss': 0., 'acc': 0., 'num': 0},
                        'part': {'loss': 0., 'acc': 0., 'num': 0},
                        'train': {'loss': 0., 'acc': 0., 'num': 0},
                        'ppr': {'loss': 0., 'acc': 0., 'num': 0}, }

            update_count = 0

            # train
            model.train()
            loss_scaling_lst, cur_micro_batch = self.get_loss_scaling(train_loader.loader_len)
            loader, next_loader = next_loader, None

            start_time = time.time()
            # print ("len(loader)=",loader.len())
            pbar = tqdm(total=max_subgraphs)
            pbar.set_description("Train Loader Progress")
            subg_idx = 0
            while True:
                data = loader.next()
                # NOTE: when the cap is reached, the batch just fetched is discarded.
                if subg_idx == max_subgraphs:
                    if batch_val_loader is not None:
                        next_loader = BackgroundGenerator(batch_val_loader)
                    elif ppr_val_loader is not None:
                        next_loader = BackgroundGenerator(ppr_val_loader)
                    else:
                        next_loader = BackgroundGenerator(self_val_loader)
                    break
                if data is None:
                    loader = None
                    loss = None
                    corrects = None
                    break
                else:
                    if data[1]:  # stop signal
                        if batch_val_loader is not None:
                            next_loader = BackgroundGenerator(batch_val_loader)
                        elif ppr_val_loader is not None:
                            next_loader = BackgroundGenerator(ppr_val_loader)
                        else:
                            next_loader = BackgroundGenerator(self_val_loader)
                if is_rgcn:
                    opt.zero_grad()
                    # The RGCN path steps the optimizer per subgraph instead of
                    # accumulating gradients over micro-batches.
                    loss, corrects, num_nodes, pred, y, acc = run_rgcn_batch(
                        model, data[0], loss_scaling_lst[0], org_graph, x=x_feat)
                    # print('Train Acc=', acc)
                    opt.step()
                else:
                    loss, corrects, num_nodes, _, _ = run_batch(model, data[0], loss_scaling_lst[0])
                pbar.update(1)
                data_dic['train']['loss'] += loss
                data_dic['train']['acc'] += corrects
                data_dic['train']['num'] += num_nodes
                update_count += 1
                # print("num_nodes=",num_nodes,"corrects=",corrects)

                if update_count >= cur_micro_batch:
                    # One optimizer step per micro-batch group; pop the consumed
                    # scaling entry so the next group uses its own divisor.
                    loss_scaling_lst.pop(0)
                    opt.step()
                    opt.zero_grad()
                    update_count = 0
                subg_idx += 1
            pbar.close()

            # remainder
            if update_count:
                opt.step()
                opt.zero_grad()

            train_time = time.time() - start_time

            if device_type == 'cuda':
                logging.info('After train loader -- '
                             f'Allocated: {torch.cuda.memory_allocated()}, '
                             f'Max allocated: {torch.cuda.max_memory_allocated()}, '
                             f'Reserved: {torch.cuda.memory_reserved()}')

            model.eval()

            # part val first, for fairness of all methods
            start_time = time.time()
            if batch_val_loader is not None:
                loader, next_loader = next_loader, None
                subg_idx = 0
                pbar = tqdm(total=max_subgraphs)
                pbar.set_description(f"Valid Loader Progress")
                while True:
                    data = loader.next()
                    if subg_idx == max_subgraphs:
                        if ppr_val_loader is not None:
                            next_loader = BackgroundGenerator(ppr_val_loader)
                        else:
                            next_loader = BackgroundGenerator(self_val_loader)
                        break
                    if data is None:
                        loader = None
                        loss = None
                        corrects = None
                        break
                    else:
                        if data[1]:  # stop signal
                            if ppr_val_loader is not None:
                                next_loader = BackgroundGenerator(ppr_val_loader)
                            else:
                                next_loader = BackgroundGenerator(self_val_loader)

                    with torch.no_grad():
                        if is_rgcn:
                            loss, corrects, num_nodes, pred, y, acc = run_rgcn_batch(
                                model, data[0], loss_scaling_lst[0], org_graph, x=x_feat)
                            # print('Valid Acc=', acc)
                        else:
                            loss, corrects, num_nodes, _, _ = run_batch(model, data[0])
                        pbar.update(1)
                        data_dic['part']['loss'] += loss
                        data_dic['part']['acc'] += corrects
                        data_dic['part']['num'] += num_nodes

                    subg_idx += 1
                pbar.close()

            part_val_time = time.time() - start_time

            # ppr val
            start_time = time.time()
            if ppr_val_loader is not None:
                loader, next_loader = next_loader, None
                subg_idx = 0
                pbar = tqdm(total=max_subgraphs)
                pbar.set_description(f"ppr val Loader Progress")
                while True:
                    data = loader.next()
                    if subg_idx == max_subgraphs:
                        next_loader = BackgroundGenerator(self_val_loader)
                        break
                    if data is None:
                        loader = None
                        loss = None
                        corrects = None
                        break
                    else:
                        if data[1]:  # stop signal
                            next_loader = BackgroundGenerator(self_val_loader)

                    with torch.no_grad():
                        if is_rgcn:
                            loss, corrects, num_nodes, pred, y, acc = run_rgcn_batch(
                                model, data[0], loss_scaling_lst[0], org_graph, x=x_feat)
                            # print('ppr Acc=', acc)
                        else:
                            loss, corrects, num_nodes, _, _ = run_batch(model, data[0])
                        pbar.update(1)
                        data_dic['ppr']['loss'] += loss
                        data_dic['ppr']['acc'] += corrects
                        data_dic['ppr']['num'] += num_nodes
                    subg_idx += 1
                pbar.close()

            ppr_val_time = time.time() - start_time

            # original val
            loader, next_loader = next_loader, None
            start_time = time.time()

            subg_idx = 0
            pbar = tqdm(total=max_subgraphs)
            pbar.set_description(f"original val Loader Progress")
            while True:
                data = loader.next()
                if subg_idx==max_subgraphs:
                    if epoch < self.epoch_max - 1:
                        next_loader = BackgroundGenerator(train_loader)
                    else:
                        next_loader = None
                    break
                if data is None:
                    loader = None
                    loss = None
                    corrects = None
                    break
                else:
                    if data[1]:  # stop signal
                        if epoch < self.epoch_max - 1:
                            next_loader = BackgroundGenerator(train_loader)
                        else:
                            next_loader = None

                with torch.no_grad():
                    if is_rgcn:
                        loss, corrects, num_nodes, pred, y, acc = run_rgcn_batch(
                            model, data[0], loss_scaling_lst[0], org_graph, x=x_feat)
                        # print('original val Acc=', acc)
                    else:
                        loss, corrects, num_nodes, _, _ = run_batch(model, data[0])
                    pbar.update(1)
                    data_dic['self']['loss'] += loss
                    data_dic['self']['acc'] += corrects
                    data_dic['self']['num'] += num_nodes
                subg_idx+=1
            pbar.close()

            self_val_time = time.time() - start_time

            # update training info
            for cat in ['train', 'self', 'part', 'ppr']:
                if data_dic[cat]['num'] > 0:
                    data_dic[cat]['loss'] = (data_dic[cat]['loss'] / data_dic[cat]['num']).item()
                    data_dic[cat]['acc'] = (data_dic[cat]['acc'] / data_dic[cat]['num']).item()
                best_accs[cat] = max(best_accs[cat], data_dic[cat]['acc'])

            # lr scheduler
            criterion_val_loss = data_dic['part']['loss'] if data_dic['part']['loss'] != 0 else data_dic['self']['loss']
            if scheduler is not None:
                scheduler.step(criterion_val_loss)

            # early stop
            criterion_val_acc = data_dic['part']['acc'] if data_dic['part']['acc'] != 0 else data_dic['self']['acc']
            if criterion_val_acc > best_val_acc:
                best_val_acc = criterion_val_acc
                patience_count = 0
                torch.save(model.state_dict(), model_path)
            else:
                patience_count += 1
                if epoch > self.epoch_min and patience_count > self.patience:
                    scheduler = None
                    opt = None
                    # The validation loop above can exit at max_subgraphs with a
                    # live loader, so drain it here instead of asserting it is None.
                    if loader is not None:
                        loader.stop_signal = True
                        while loader.is_alive():
                            loader.next()
                        loader = None

                    if next_loader is not None:
                        next_loader.stop_signal = True
                        while next_loader.is_alive():
                            batch = next_loader.next()
                        next_loader = None
                    torch.cuda.empty_cache()
                    break

            logging.info(f"train_acc: {data_dic['train']['acc']:.5f}, "
                         f"self_val_acc: {data_dic['self']['acc']:.5f}, "
                         f"part_val_acc: {data_dic['part']['acc']:.5f}, "
                         f"ppr_val_acc: {data_dic['ppr']['acc']:.5f}, "
                         f"lr: {opt.param_groups[0]['lr']}, "
                         f"patience: {patience_count} / {self.patience}\n")

            # maintain curves
            training_curve['per_train_time'].append(train_time)
            training_curve['per_self_val_time'].append(self_val_time)
            training_curve['per_part_val_time'].append(part_val_time)
            training_curve['per_ppr_val_time'].append(ppr_val_time)
            training_curve['train_loss'].append(data_dic['train']['loss'])
            training_curve['train_acc'].append(data_dic['train']['acc'])
            training_curve['self_val_loss'].append(data_dic['self']['loss'])
            training_curve['self_val_acc'].append(data_dic['self']['acc'])
            training_curve['ppr_val_loss'].append(data_dic['ppr']['loss'])
            training_curve['ppr_val_acc'].append(data_dic['ppr']['acc'])
            training_curve['part_val_loss'].append(data_dic['part']['loss'])
            training_curve['part_val_acc'].append(data_dic['part']['acc'])
            training_curve['lr'].append(opt.param_groups[0]['lr'])

        #             writer.add_scalar('train_loss', data_dic['train']['loss'], epoch)
        #             writer.add_scalar('train_acc', data_dic['train']['acc'], epoch)
        #             writer.add_scalar('self_val_loss', data_dic['self']['loss'], epoch)
        #             writer.add_scalar('self_val_acc', data_dic['self']['acc'], epoch)

        # ending
        self.database['best_train_accs'].append(best_accs['train'])
        self.database['training_curves'].append(training_curve)

        logging.info(f"best train_acc: {best_accs['train']}, "
                     f"best self val_acc: {best_accs['self']}, "
                     f"best part val_acc: {best_accs['part']}"
                     f"best ppr val_acc: {best_accs['ppr']}")

        torch.cuda.empty_cache()
        # assert next_loader is None and loader is None

    #         writer.flush()
        return model

    def train_single_batch(self,
                           dataset,
                           model,
                           lr,
                           reg,
                           val_per_epoch=5,
                           comment='',
                           run_no=''):
        raise NotImplementedError

    @torch.no_grad()
    def inference(self,
                  self_val_loader,
                  ppr_val_loader,
                  batch_val_loader,
                  self_test_loader,
                  ppr_test_loader,
                  batch_test_loader,
                  model,
                  record_numbatch=False, org_graph=None, x_feat=None):

        cat_dict = {('self', 'val',): [self.database['self_val_accs'], self.database['self_val_f1s']],
                    ('part', 'val',): [self.database['part_val_accs'], self.database['part_val_f1s']],
                    ('ppr', 'val',): [self.database['ppr_val_accs'], self.database['ppr_val_f1s']],
                    ('self', 'test',): [self.database['self_test_accs'], self.database['self_test_f1s']],
                    ('part', 'test',): [self.database['part_test_accs'], self.database['part_test_f1s']],
                    ('ppr', 'test',): [self.database['ppr_test_accs'], self.database['ppr_test_f1s']], }

        loader_dict = {'val': {'self': self_val_loader, 'part': batch_val_loader, 'ppr': ppr_val_loader},
                       'test': {'self': self_test_loader, 'part': batch_test_loader, 'ppr': ppr_test_loader}}

        time_dict = {'self': self.database['self_inference_time'],
                     'part': self.database['part_inference_time'],
                     'ppr': self.database['ppr_inference_time']}
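
        # 'self', 'part', and 'ppr' are the three evaluation loader types
        # (original self batches, partition batches, PPR batches); a missing
        # loader leaves acc = f1 = 0 for that cell.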

        print("start inference")
        for cat in ['val', 'test']:
            for sample in ['self', 'part', 'ppr']:
                acc, f1 = 0., 0.
                num_batch = 0
                if device_type=='cuda':
                    torch.cuda.synchronize()
                start_time = time.time()
                if loader_dict[cat][sample] is not None:
                    loader = BackgroundGenerator(loader_dict[cat][sample])

                    pred_labels = []
                    true_labels = []

                    max_subgraphs = 100  # progress-bar estimate only; the loop runs until the loader is exhausted
                    pbar = tqdm(total=max_subgraphs)
                    pbar.set_description(f"Inference Progress {cat}-{sample}")
                    while True:
                        data = loader.next()
                        if data is None:
                            del loader
                            break
                        if is_rgcn:
                            # Assumption: a loss-scaling factor of 1 fills the slot
                            # that the training loop passes as loss_scaling_lst[0].
                            loss, corrects, num_nodes, pred_label_batch, true_label_batch, acc = run_rgcn_batch(
                                model, data[0], 1, org_graph, x=x_feat)
                            if sample == 'self':
                                logging.debug(f"graph={data[0]}")
                        else:
                            _, _, _, pred_label_batch, true_label_batch = run_batch(model, data[0])
                        pbar.update(1)
                        pred_labels.append(pred_label_batch.detach())
                        true_labels.append(true_label_batch.detach())
                        num_batch += 1
                    pbar.close()

                    pred_labels = torch.cat(pred_labels, dim=0).cpu().numpy()
                    true_labels = torch.cat(true_labels, dim=0).cpu().numpy()

                    acc = (pred_labels == true_labels).sum() / len(true_labels)
                    f1 = f1_score(true_labels, pred_labels, average='macro', zero_division=0)

                cat_dict[(sample, cat,)][0].append(acc)
                cat_dict[(sample, cat,)][1].append(f1)

                if record_numbatch:
                    self.database[f'numbatch_{sample}_{cat}'].append(num_batch)

                logging.info("{}_{}_acc: {:.3f}, {}_{}_f1: {:.3f}, ".format(sample, cat, acc, sample, cat, f1))
                if cat == 'test':
                    if device_type=='cuda':
                        torch.cuda.synchronize()
                    time_dict[sample].append(time.time() - start_time)

    @torch.no_grad()
    def full_graph_inference(self,
                             model,
                             graph,
                             train_nodes,
                             val_nodes,
                             test_nodes, x_feat=None, key2int=None, subject_node='Paper'):

        if isinstance(val_nodes, torch.Tensor):
            val_nodes = val_nodes.numpy()
        if isinstance(test_nodes, torch.Tensor):
            test_nodes = test_nodes.numpy()

        start_time = time.time()

        mask = np.union1d(val_nodes, test_nodes)
        val_mask = np.in1d(mask, val_nodes)
        test_mask = np.in1d(mask, test_nodes)
        assert np.all(np.invert(val_mask) == test_mask)
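
        # Example: val_nodes = [1, 3], test_nodes = [0, 2] gives
        # mask = [0, 1, 2, 3], val_mask = [F, T, F, T] and test_mask = ~val_mask,
        # so rows of the stacked outputs line up with the sorted union of nodes.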


        adj = SparseTensor.from_edge_index(graph.edge_index, sparse_sizes=(graph.num_nodes, graph.num_nodes))
        adj = BaseLoader.normalize_adjmat(adj, normalization='sym')
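        # 'sym' presumably applies the GCN-style symmetric normalization
        # D^{-1/2} A D^{-1/2}; see BaseLoader.normalize_adjmat for the exact form.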

        if type(model).__module__.split('.')[-1] == 'RGCN':
            train_acc, valid_acc, test_acc = rgcn_test(model, x_feat, graph, key2int,
                                                       train_nodes, val_nodes, test_nodes, subject_node)
            self.database['full_RGCN_train_acc'].append(train_acc)
            self.database['full_RGCN_valid_acc'].append(valid_acc)
            self.database['full_RGCN_test_acc'].append(test_acc)
        else:
            outputs = model.chunked_pass(MyGraph(x=graph.x, adj=adj, idx=torch.from_numpy(mask)),
                                         self.num_batches // self.batch_size).detach().numpy()  # an estimate of #chunks
            for cat in ['val', 'test']:
                nodes = val_nodes if cat == 'val' else test_nodes
                _mask = val_mask if cat == 'val' else test_mask
                pred = np.argmax(outputs[_mask], axis=1)
                true = graph.y.numpy()[nodes]

                acc = (pred == true).sum() / len(true)
                f1 = f1_score(true, pred, average='macro', zero_division=0)

                self.database[f'full_{cat}_accs'].append(acc)
                self.database[f'full_{cat}_f1s'].append(f1)

                logging.info("full_{}_acc: {:.3f}, full_{}_f1: {:.3f}, ".format(cat, acc, cat, f1))

        self.database['full_inference_time'].append(time.time() - start_time)

    @torch.no_grad()
    def full_graph_inference_hetero(self,
                             model,
                             graph,
                             val_nodes,
                             test_nodes):

        if isinstance(val_nodes, torch.Tensor):
            val_nodes = val_nodes.numpy()
        if isinstance(test_nodes, torch.Tensor):
            test_nodes = test_nodes.numpy()

        start_time = time.time()

        mask = np.union1d(val_nodes, test_nodes)
        val_mask = np.in1d(mask, val_nodes)
        test_mask = np.in1d(mask, test_nodes)
        assert np.all(np.invert(val_mask) == test_mask)

        output_node = list(graph.y_dict.keys())[0]
        for (s, p, o) in graph.edge_index_dict.keys():
            if s == output_node:
                num = graph.edge_index_dict[(s, p, o)].max() + 1
                adj = SparseTensor.from_edge_index(graph.edge_index_dict[(s, p, o)],
                                                   sparse_sizes=(num, num))
                adj = BaseLoader.normalize_adjmat(adj, normalization='sym')
                break  # TODO: compute outputs from the full hetero graph rather than a single adjacency

        outputs = model.chunked_pass(MyGraph(x=graph.x_dict[output_node], adj=adj, idx=torch.from_numpy(mask)),
                                     self.num_batches // self.batch_size).detach().numpy()  # an estimate of #chunks

        for cat in ['val', 'test']:
            nodes = val_nodes if cat == 'val' else test_nodes
            _mask = val_mask if cat == 'val' else test_mask
            pred = np.argmax(outputs[_mask], axis=1)
            true = graph.y_dict[output_node].numpy()[nodes]

            acc = (pred == true).sum() / len(true)
            f1 = f1_score(true, pred, average='macro', zero_division=0)

            self.database[f'full_{cat}_accs'].append(acc)
            self.database[f'full_{cat}_f1s'].append(f1)

            logging.info("full_{}_acc: {:.3f}, full_{}_f1: {:.3f}, ".format(cat, acc, cat, f1))

        self.database['full_inference_time'].append(time.time() - start_time)
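

if __name__ == '__main__':
    # Minimal smoke test for the micro-batch loss scaling only (a sketch; exercising
    # train()/inference() requires the repo's loaders, models, and graph data).
    demo = Trainer(mode='demo', num_batches=1, micro_batch=4)
    assert demo.get_loss_scaling(10) == ([4, 4, 2], 4)  # two full groups + remainder of 2
    assert demo.get_loss_scaling(8) == ([4, 4], 4)      # no trailing zero entry
    print('get_loss_scaling OK')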