# AlertStar — alertstar.py
# ============================================================================
# ALERTSTAR: COMPLETE END-TO-END NOTEBOOK WITH ALL METRICS
# Hyper-Relational Knowledge Graph for Cybersecurity Alert Prediction
#
# UPDATED: Now includes MR, MRR, Hits@1, Hits@3, Hits@10 for ALL models
#          including MultiTask predictions
#
# Models:
#   1. StarE          — qualifier attention baseline
#   2. ShrinkE        — shrinking transform baseline
#   3. NBFNet     — actual Bellman-Ford (memory-safe)
#   4. AlertStar      — StarE + NBFNet hybrid (ours)
#   5. StarQE         — complex query answering
#   6. NBFNet+StarQE  — hybrid complex queries
#   7. MultiTask-AS   — Transformer multi-task (ours, main)
#
# Dataset: Alert-33%, 66% and 100% qualifiers-Cybersecurity dataset
# ============================================================================


# ============================================================================
# 1 — Installs (uncomment if needed)
# ============================================================================
# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
# !pip install tqdm pandas matplotlib seaborn scikit-learn


# ============================================================================
# 2 — Imports & Config
# ============================================================================

import os, json, warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_scatter import scatter_add, scatter_mean

# Select GPU when available; all models and tensors below are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Central hyper-parameter dictionary shared by every training/eval loop below.
CONFIG = {
    # paths
    # NOTE(review): '/.../' values are placeholders — set real paths before running.
    'data_path':   '/.../inductive_q100_statements',
    'output_path': '/.../results_inductive_q100_h1/',

    # shared
    'embedding_dim': 200,
    'dropout':       0.2,
    'device':        device,

    # standard training (StarE / ShrinkE / AlertStar)
    'batch_size':    128,
    'learning_rate': 0.0005,
    'epochs':        20,

    # NBFNet specific
    'nbfnet_epochs':        20,
    'nbfnet_lr':            0.0005,
    'nbfnet_layers':        3,
    'nbfnet_chunk_size':    10000,  # edges per chunk — controls GPU memory
    'nbfnet_max_per_group': 8,      # tails per (h,r) group per step

    # complex query / multi-task
    'query_epochs': 20,
    'query_lr':     0.0005,
    'mt_epochs':    20,
    'mt_lr':        0.0005,
    'mt_batch':     64,
}

# Results directory must exist before any torch.save below.
os.makedirs(CONFIG['output_path'], exist_ok=True)
all_results = {}          # stores final metrics for all models
print("Config ready")


# ============================================================================
# 3 — Data Loading & Preprocessing
# ============================================================================

class DataPreprocessor:
    """Builds integer vocabularies for entities, relations and qualifiers
    from comma-separated statement files.

    Line format: ``head,relation,tail[,key:value | key:value | ...]``.
    """

    def __init__(self):
        self.entity2id          = {}   # entity token  -> id
        self.relation2id        = {}   # relation token -> id
        self.qualifier_key2id   = {}   # qualifier key  -> id
        self.qualifier_value2id = {}   # qualifier value -> id

    def _parse(self, line):
        """Parse one line into a statement dict, or None if malformed.

        Returns ``{'head', 'relation', 'tail', 'qualifiers'}`` where
        qualifiers is a list of (key, value) string pairs.
        """
        parts = line.strip().split(',')   # fields are comma-separated
        if len(parts) < 3:
            return None                   # blank or malformed line
        head, relation, tail = parts[0], parts[1], parts[2]

        # Optional 4th field holds qualifiers: "key:value | key:value"
        qualifiers = []
        if len(parts) > 3:
            for pair in parts[3].split('|'):
                pair = pair.strip()
                if ':' in pair:
                    key, value = pair.split(':', 1)   # value may itself contain ':'
                    qualifiers.append((key.strip(), value.strip()))

        return {'head': head, 'relation': relation,
                'tail': tail, 'qualifiers': qualifiers}

    def _register(self, triple):
        """Assign fresh consecutive ids to any unseen tokens in *triple*."""
        for token, vocab in [(triple['head'],     self.entity2id),
                             (triple['tail'],     self.entity2id),
                             (triple['relation'], self.relation2id)]:
            if token not in vocab:
                vocab[token] = len(vocab)
        for qk, qv in triple['qualifiers']:
            if qk not in self.qualifier_key2id:
                self.qualifier_key2id[qk] = len(self.qualifier_key2id)
            if qv not in self.qualifier_value2id:
                self.qualifier_value2id[qv] = len(self.qualifier_value2id)

    def load(self, data_path):
        """Load train/valid/test splits, registering vocab as a side effect.

        Returns (train, valid, test) as lists of statement dicts.
        """
        files = [('train', 'train.txt'),
                 ('valid', 'validation.txt'),
                 ('test',  'test.txt')]
        splits = {}
        for name, fname in files:
            data = []
            # Explicit encoding so parsing doesn't depend on the platform default.
            with open(os.path.join(data_path, fname), encoding='utf-8') as f:
                for line in f:
                    t = self._parse(line)
                    if t:
                        self._register(t)
                        data.append(t)
            splits[name] = data
            print(f"  {name}: {len(data):,}")
        print(f"  NE={len(self.entity2id)}  NR={len(self.relation2id)}  "
              f"NQK={len(self.qualifier_key2id)}  NQV={len(self.qualifier_value2id)}")
        return splits['train'], splits['valid'], splits['test']


class HRDataset(Dataset):
    """Torch dataset that maps raw hyper-relational statements into id space."""

    def __init__(self, data, preprocessor):
        # Keep both handles public: `data` and `p` are also read directly by
        # the NBFNet evaluation helpers elsewhere in this file.
        self.data = data
        self.p    = preprocessor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        stmt  = self.data[idx]
        vocab = self.p
        qual_ids = []
        for key, val in stmt['qualifiers']:
            qual_ids.append((vocab.qualifier_key2id[key],
                             vocab.qualifier_value2id[val]))
        return {
            'head':       vocab.entity2id[stmt['head']],
            'relation':   vocab.relation2id[stmt['relation']],
            'tail':       vocab.entity2id[stmt['tail']],
            'qualifiers': qual_ids,
        }


def collate(batch):
    """Collate id-space samples into long tensors; qualifiers stay a ragged list."""
    def as_long(field):
        return torch.tensor([sample[field] for sample in batch], dtype=torch.long)

    return {
        'head':       as_long('head'),
        'relation':   as_long('relation'),
        'tail':       as_long('tail'),
        'qualifiers': [sample['qualifiers'] for sample in batch],
    }


# --- Load splits and build id vocabularies (module-level state used below) ---
print("Loading data...")
preprocessor = DataPreprocessor()
train_data, valid_data, test_data = preprocessor.load(CONFIG['data_path'])

train_ds = HRDataset(train_data, preprocessor)
valid_ds = HRDataset(valid_data, preprocessor)
test_ds  = HRDataset(test_data,  preprocessor)

# Vocabulary sizes — used to size every embedding table below.
NE  = len(preprocessor.entity2id)
NR  = len(preprocessor.relation2id)
NQK = len(preprocessor.qualifier_key2id)
NQV = len(preprocessor.qualifier_value2id)
DIM = CONFIG['embedding_dim']
print("Data ready")


# ============================================================================
# 4 — UPDATED: Complete Evaluation Utilities with All Metrics
# ============================================================================

def evaluate_model(model, dataset, device, max_samples=500):
    """
    Rank-based tail-prediction evaluation.

    Scores every entity as candidate tail for up to ``max_samples``
    statements and ranks the gold tail among all entities.

    NOTE(review): ranks are RAW — other known-true tails are NOT filtered
    out of the candidate list, despite the section header saying
    "filtered"; confirm which protocol is intended.

    Returns a dict with MR, MRR, Hits@1, Hits@3 and Hits@10 (all floats).
    """
    model.eval()
    ranks = []
    with torch.no_grad():
        for i in tqdm(range(min(len(dataset), max_samples)),
                      desc="Evaluating", leave=False):
            s = dataset[i]
            h = torch.tensor([s['head']],     device=device)
            r = torch.tensor([s['relation']], device=device)
            t = s['tail']
            q = [s['qualifiers']]

            scores = model(h, r, q).squeeze()           # [NE]
            # 1-based position of the gold tail in descending score order.
            rank   = (torch.argsort(scores, descending=True) == t
                      ).nonzero(as_tuple=True)[0].item() + 1
            ranks.append(rank)

    ranks = np.array(ranks)
    return {
        'mr':      float(np.mean(ranks)),
        'mrr':     float(np.mean(1.0 / ranks)),
        'hits@1':  float(np.mean(ranks <= 1)),
        'hits@3':  float(np.mean(ranks <= 3)),
        'hits@10': float(np.mean(ranks <= 10)),
    }


def train_standard(model, model_name, lr=None, epochs=None):
    """Training loop for StarE / ShrinkE / AlertStar (standard dataset).

    Optimizes a margin-ranking loss between scored positive tails and one
    uniformly sampled negative tail per statement.  Every 5 epochs the
    model is evaluated on the module-level ``valid_ds`` and the best-MRR
    weights are checkpointed under ``CONFIG['output_path']``.

    NOTE(review): ``lr or CONFIG[...]`` means an explicit ``lr=0`` or
    ``epochs=0`` silently falls back to the config defaults.

    Returns (model, history); history holds only the every-5-epoch metrics.
    """
    dev  = CONFIG['device']
    lr   = lr    or CONFIG['learning_rate']
    eps  = epochs or CONFIG['epochs']
    model.to(dev)
    opt  = torch.optim.Adam(model.parameters(), lr=lr)
    loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'],
                        shuffle=True, collate_fn=collate)

    print(f"\n{'='*60}")
    print(f"TRAINING {model_name}  "
          f"({sum(p.numel() for p in model.parameters()):,} params)")
    print(f"{'='*60}")

    best_mrr, history = 0.0, []
    for epoch in range(eps):
        model.train()
        total = 0.0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{eps}")
        for b in pbar:
            h, r, t, q = (b['head'].to(dev), b['relation'].to(dev),
                          b['tail'].to(dev),  b['qualifiers'])
            pos = model(h, r, q, t)
            # Uniform negative tails; may occasionally coincide with a true tail.
            neg = model(h, r, q, torch.randint(0, NE, (len(h),), device=dev))
            # target=+1: positive scores should exceed negatives by the margin.
            loss = F.margin_ranking_loss(pos, neg,
                                         torch.ones_like(pos), margin=1.0)
            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            total += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        # Periodic validation + best-checkpoint bookkeeping.
        if (epoch + 1) % 5 == 0:
            m = evaluate_model(model, valid_ds, dev)
            print(f"  Epoch {epoch+1}: loss={total/len(loader):.4f}  "
                  f"MRR={m['mrr']:.4f}  MR={m['mr']:.1f}  "
                  f"H@1={m['hits@1']:.4f}  H@3={m['hits@3']:.4f}  H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
                print("    → best saved")
            history.append(m)

    return model, history


# Notebook section marker — confirms the evaluation/training utilities cell ran.
print("Utilities ready")


# ============================================================================
# 5 — Model 1: StarE
# ============================================================================

class StarEModel(nn.Module):
    """StarE baseline: relation embeddings enriched by qualifier attention,
    scored against entity embeddings with a DistMult-style dot product."""

    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        # NOTE: module creation order is kept stable for reproducible seeding.
        self.ent  = nn.Embedding(ne,  dim)
        self.rel  = nn.Embedding(nr,  dim)
        self.qk   = nn.Embedding(nqk, dim)
        self.qv   = nn.Embedding(nqv, dim)
        self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True)
        self.drop = nn.Dropout(dropout)
        for table in (self.ent, self.rel, self.qk, self.qv):
            nn.init.xavier_uniform_(table.weight)

    def _enrich(self, r_emb, quals, dev):
        """Attend from the relation over its (key + value) qualifier pairs."""
        if not quals:
            return r_emb
        key_ids = torch.tensor([pair[0] for pair in quals], device=dev)
        val_ids = torch.tensor([pair[1] for pair in quals], device=dev)
        memory  = (self.qk(key_ids) + self.qv(val_ids)).unsqueeze(0)  # [1, nq, dim]
        attended, _ = self.attn(r_emb.view(1, 1, -1), memory, memory)
        return attended.squeeze()

    def forward(self, head, relation, qualifiers, tail=None):
        dev   = head.device
        h_emb = self.ent(head)                             # [B, dim]
        enriched = []
        for i in range(head.size(0)):
            base = self.rel(relation[i:i+1]).squeeze()
            enriched.append(self._enrich(base, qualifiers[i], dev))
        r_emb = torch.stack(enriched)                      # [B, dim]
        fused = self.drop(h_emb + r_emb)
        if tail is not None:
            # Score only the supplied tails.
            return (fused * self.ent(tail)).sum(-1)        # [B]
        # Score against every entity.
        return fused @ self.ent.weight.t()                 # [B, NE]


# Notebook section marker — confirms the StarE cell ran.
print("StarE defined")


# ============================================================================
# 6 — Model 2: ShrinkE
# ============================================================================

class ShrinkEModel(nn.Module):
    """ShrinkE baseline: relation blended ('shrunk') with the mean qualifier
    vector via a learned transform; entities pass through a shared projection."""

    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        # NOTE: module creation order is kept stable for reproducible seeding.
        self.ent    = nn.Embedding(ne,  dim)
        self.rel    = nn.Embedding(nr,  dim)
        self.qk     = nn.Embedding(nqk, dim)
        self.qv     = nn.Embedding(nqv, dim)
        self.shrink = nn.Sequential(nn.Linear(dim*2, dim), nn.Tanh(),
                                    nn.Dropout(dropout))
        self.proj   = nn.Linear(dim, dim)
        self.drop   = nn.Dropout(dropout)
        for table in (self.ent, self.rel, self.qk, self.qv):
            nn.init.xavier_uniform_(table.weight)

    def _shrink(self, r_emb, quals, dev):
        """Blend the relation with the mean (key + value) qualifier vector."""
        if not quals:
            return r_emb
        key_ids = torch.tensor([pair[0] for pair in quals], device=dev)
        val_ids = torch.tensor([pair[1] for pair in quals], device=dev)
        context = (self.qk(key_ids) + self.qv(val_ids)).mean(0, keepdim=True)
        return self.shrink(torch.cat([r_emb.unsqueeze(0), context], -1)).squeeze()

    def forward(self, head, relation, qualifiers, tail=None):
        dev   = head.device
        h_emb = self.proj(self.ent(head))
        shrunk = []
        for i in range(head.size(0)):
            base = self.rel(relation[i:i+1]).squeeze()
            shrunk.append(self._shrink(base, qualifiers[i], dev))
        r_emb = torch.stack(shrunk)
        fused = self.drop(h_emb + self.proj(r_emb))
        if tail is not None:
            return (fused * self.ent(tail)).sum(-1)
        return fused @ self.ent.weight.t()


# Notebook section marker — confirms the ShrinkE cell ran.
print("ShrinkE defined")


# ============================================================================
# 7 — Model 3: NBFNet (Bellman-Ford, memory-safe)
# ============================================================================

class NBFConvLayer(nn.Module):
    """One Bellman-Ford message-passing layer.

    Edges are processed in fixed-size chunks so peak memory stays bounded
    regardless of graph size.
    """

    def __init__(self, dim, num_relation, chunk_size=10000,
                 message_func='distmult', layer_norm=True):
        super().__init__()
        self.chunk_size   = chunk_size
        self.message_func = message_func
        self.rel_emb  = nn.Embedding(num_relation, dim)
        self.linear   = nn.Linear(dim * 2, dim)
        self.ln       = nn.LayerNorm(dim) if layer_norm else None
        self.act      = nn.ReLU()
        nn.init.xavier_uniform_(self.rel_emb.weight)

    def forward(self, graph, node_feat):
        """node_feat : [N, B, dim] → updated features of the same shape."""
        edge_list = graph['edge_list']
        N, B, D   = node_feat.shape
        dev       = node_feat.device
        agg       = torch.zeros(N, B, D, device=dev)

        # Accumulate incoming messages chunk by chunk to cap memory usage.
        for start in range(0, edge_list.size(0), self.chunk_size):
            chunk = edge_list[start : start + self.chunk_size]
            src, dst, rel = chunk[:, 0], chunk[:, 1], chunk[:, 2]

            src_feat = node_feat[src]
            rel_feat = self.rel_emb(rel).unsqueeze(1)

            if self.message_func == 'transe':
                msg = src_feat + rel_feat
            else:
                # 'distmult' (and any unrecognized mode) uses the multiplicative form.
                msg = src_feat * rel_feat

            agg.scatter_add_(0, dst.view(-1, 1, 1).expand_as(msg), msg)
            del src_feat, rel_feat, msg

        combined = self.linear(torch.cat([node_feat, agg], dim=-1))
        if self.ln:
            shape    = combined.shape
            combined = self.ln(combined.flatten(0, 1)).view(shape)
        return self.act(combined)


class NBFNet(nn.Module):
    """Neural Bellman-Ford Network for tail prediction.

    Runs one Bellman-Ford propagation per (head, relation) query and scores
    every node with an MLP over the concatenated [node feature, query] pair.
    """

    def __init__(self, ne, nr, dim=200, num_layers=3,
                 short_cut=True, chunk_size=10000, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        self.dim          = dim
        self.short_cut    = short_cut
        nr2               = nr * 2      # forward + inverse relations

        self.query_emb = nn.Embedding(nr2, dim)
        self.layers    = nn.ModuleList([
            NBFConvLayer(dim, nr2, chunk_size=chunk_size)
            for _ in range(num_layers)
        ])
        self.mlp = nn.Sequential(
            nn.Linear(dim*2, dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(dim, 1)
        )
        nn.init.xavier_uniform_(self.query_emb.weight)

    def _bellman_ford(self, graph, h, r, device):
        """Propagate from source node ``h`` under relation ``r``.

        Returns [N, 2*dim]: per-node feature concatenated with the query.
        """
        num_nodes = graph['num_nodes']
        query     = self.query_emb(torch.tensor([r], device=device))

        # Boundary condition: only the source node starts with the query vector.
        feat = torch.zeros(num_nodes, 1, self.dim, device=device)
        feat[h, 0] = query[0]

        for layer in self.layers:
            updated = layer(graph, feat)
            feat    = updated + feat if self.short_cut else updated

        per_node_query = query.unsqueeze(0).expand(num_nodes, -1, -1)
        return torch.cat([feat, per_node_query], dim=-1).squeeze(1)

    def forward(self, head, relation, qualifiers=None, tail=None, graph=None):
        assert graph is not None, "NBFNet requires graph dict"
        assert (head == head[0]).all() and (relation == relation[0]).all(), \
               "All batch items must share the same (head, relation)"

        dev  = head.device
        feat = self._bellman_ford(graph, head[0].item(),
                                  relation[0].item(), dev)

        if tail is not None:
            # Score only the requested tail nodes.
            return self.mlp(feat[tail]).squeeze(-1)

        # Score all nodes, replicated across the batch dimension.
        batch      = head.size(0)
        all_scores = self.mlp(feat).squeeze(-1)
        return all_scores.unsqueeze(0).expand(batch, -1)


def build_nbfnet_graph(train_data, preprocessor, device):
    """Materialize training triples as a [2E, 3] (src, dst, rel) edge tensor.

    Each triple contributes a forward edge plus an inverse edge whose
    relation id is offset by the number of base relations.
    """
    num_rel = len(preprocessor.relation2id)
    edges   = []
    for stmt in train_data:
        src = preprocessor.entity2id[stmt['head']]
        rel = preprocessor.relation2id[stmt['relation']]
        dst = preprocessor.entity2id[stmt['tail']]
        edges.append((src, dst, rel))
        edges.append((dst, src, rel + num_rel))
    return {
        'edge_list': torch.tensor(edges, dtype=torch.long, device=device),
        'num_nodes': len(preprocessor.entity2id),
    }


def train_true_nbfnet(model, model_name="NBFNet"):
    """Training loop that groups by (h,r) so BF runs once per group.

    Relies on module-level state: ``nbfnet_graph``, ``train_data``,
    ``preprocessor``, ``valid_ds``, ``NE`` and ``CONFIG``.  For each
    (head, relation) group a single Bellman-Ford propagation scores up to
    ``nbfnet_max_per_group`` positive tails against uniform negatives.

    Returns (model, history) with every-5-epoch validation metrics.
    """
    dev  = CONFIG['device']
    lr   = CONFIG['nbfnet_lr']
    eps  = CONFIG['nbfnet_epochs']
    mpg  = CONFIG['nbfnet_max_per_group']

    model.to(dev)
    # Move graph tensors to the training device; non-tensor entries pass through.
    graph = {k: v.to(dev) if torch.is_tensor(v) else v
             for k, v in nbfnet_graph.items()}
    opt  = torch.optim.Adam(model.parameters(), lr=lr)

    # Index training statements by (head, relation) so each group shares one BF pass.
    groups = defaultdict(list)
    for i, s in enumerate(train_data):
        h = preprocessor.entity2id[s['head']]
        r = preprocessor.relation2id[s['relation']]
        groups[(h, r)].append(i)
    keys = list(groups.keys())

    print(f"\n{'='*60}\nTRAINING {model_name}  "
          f"({sum(p.numel() for p in model.parameters()):,} params)")
    print(f"Graph: {graph['edge_list'].size(0):,} edges, "
          f"{graph['num_nodes']} nodes  |  "
          f"{len(keys)} unique (h,r) pairs\n{'='*60}")

    best_mrr, history = 0.0, []

    for epoch in range(eps):
        model.train()
        np.random.shuffle(keys)
        total, cnt = 0.0, 0

        pbar = tqdm(keys, desc=f"Epoch {epoch+1}/{eps}")
        for (h, r) in pbar:
            # Cap tails per group so a single step's memory stays bounded.
            chosen = np.random.choice(groups[(h,r)],
                                      min(len(groups[(h,r)]), mpg),
                                      replace=False)
            t_pos = torch.tensor(
                [preprocessor.entity2id[train_data[i]['tail']] for i in chosen],
                device=dev)
            B = len(t_pos)
            heads = torch.full((B,), h, dtype=torch.long, device=dev)
            rels  = torch.full((B,), r, dtype=torch.long, device=dev)

            pos = model(heads, rels, tail=t_pos, graph=graph)
            # Uniform negative tails; may occasionally coincide with a true tail.
            neg = model(heads, rels,
                        tail=torch.randint(0, NE, (B,), device=dev),
                        graph=graph)
            loss = F.margin_ranking_loss(pos, neg,
                                         torch.ones(B, device=dev), margin=1.0)
            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            total += loss.item(); cnt += 1
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        # Periodic validation + best-checkpoint bookkeeping.
        if (epoch + 1) % 5 == 0:
            m = evaluate_nbfnet(model, valid_ds, graph, dev)
            print(f"  Epoch {epoch+1}: loss={total/cnt:.4f}  "
                  f"MRR={m['mrr']:.4f}  MR={m['mr']:.1f}  "
                  f"H@1={m['hits@1']:.4f}  H@3={m['hits@3']:.4f}  H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
                print("    → best saved")
            history.append(m)

    return model, history


def evaluate_nbfnet(model, dataset, graph, device, max_groups=300):
    """Evaluation that also groups by (h,r) for BF with ALL metrics.

    Evaluates at most ``max_groups`` (head, relation) groups; ranks are
    RAW (other known-true tails are not filtered out).  Uses the
    module-level ``preprocessor`` for id lookups.
    """
    model.eval()
    groups = defaultdict(list)
    for i, s in enumerate(dataset.data):
        h = preprocessor.entity2id[s['head']]
        r = preprocessor.relation2id[s['relation']]
        groups[(h, r)].append(i)

    ranks = []
    with torch.no_grad():
        for (h, r) in tqdm(list(groups.keys())[:max_groups],
                           desc="Evaluating", leave=False):
            idxs  = groups[(h, r)]
            tails = [preprocessor.entity2id[dataset.data[i]['tail']]
                     for i in idxs]
            B     = len(tails)
            heads = torch.full((B,), h, dtype=torch.long, device=device)
            rels  = torch.full((B,), r, dtype=torch.long, device=device)
            try:
                scores = model(heads, rels, graph=graph)
                for i, tgt in enumerate(tails):
                    # 1-based position of the gold tail in descending score order.
                    pos = (torch.argsort(scores[i], descending=True
                                        ) == tgt).nonzero(as_tuple=True)[0]
                    if len(pos):
                        ranks.append(pos[0].item() + 1)
            except Exception:
                # NOTE(review): broad swallow — hides OOM/shape errors and silently
                # drops the whole group; consider at least logging the exception.
                continue

    if not ranks:
        return {'mr':0., 'mrr':0., 'hits@1':0., 'hits@3':0., 'hits@10':0.}
    ranks = np.array(ranks)
    return {
        'mr':      float(np.mean(ranks)),
        'mrr':     float(np.mean(1./ranks)),
        'hits@1':  float(np.mean(ranks<=1)),
        'hits@3':  float(np.mean(ranks<=3)),
        'hits@10': float(np.mean(ranks<=10)),
    }


# Build the NBFNet edge tensor once at module level (reused by training/eval).
print("Building graph for NBFNet...")
nbfnet_graph = build_nbfnet_graph(train_data, preprocessor, CONFIG['device'])
print(f" NBFNet ready  "
      f"(graph: {nbfnet_graph['edge_list'].size(0):,} edges, "
      f"{nbfnet_graph['num_nodes']} nodes)")


# ============================================================================
# 8 — Model 4: AlertStar
# ============================================================================

class AlertStarModel(nn.Module):
    """AlertStar (ours): gated fusion of a StarE-style qualifier-attention
    branch and an NBFNet-inspired path-composition branch."""

    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1):
        super().__init__()
        self.num_entities = ne

        # NOTE: module creation order is kept stable for reproducible seeding.
        self.ent  = nn.Embedding(ne,  dim)
        self.rel  = nn.Embedding(nr,  dim)
        self.qk   = nn.Embedding(nqk, dim)
        self.qv   = nn.Embedding(nqv, dim)

        self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True)
        self.ln1  = nn.LayerNorm(dim)

        self.path_net = nn.Sequential(
            nn.Linear(dim*2, dim), nn.LayerNorm(dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(dim, dim)
        )
        self.ln2  = nn.LayerNorm(dim)

        # Learnable scalar mixing weight (passed through sigmoid in forward).
        self.gate = nn.Parameter(torch.tensor(0.5))
        self.drop = nn.Dropout(dropout)

        for table in (self.ent, self.rel, self.qk, self.qv):
            nn.init.xavier_uniform_(table.weight)

    def _enrich(self, r_emb, quals, dev):
        """Attend from the relation over its (key + value) qualifier pairs."""
        if not quals:
            return r_emb
        key_ids = torch.tensor([pair[0] for pair in quals], device=dev)
        val_ids = torch.tensor([pair[1] for pair in quals], device=dev)
        memory  = (self.qk(key_ids) + self.qv(val_ids)).unsqueeze(0)
        attended, _ = self.attn(r_emb.view(1, 1, -1), memory, memory)
        return attended.squeeze()

    def forward(self, head, relation, qualifiers, tail=None):
        dev   = head.device
        h_emb = self.ent(head)
        enriched = []
        for i in range(head.size(0)):
            base = self.rel(relation[i:i+1]).squeeze()
            enriched.append(self._enrich(base, qualifiers[i], dev))
        r_emb = torch.stack(enriched)

        # Branch 1: StarE-style additive fusion.
        stare_branch = self.ln1(h_emb + r_emb)
        # Branch 2: path composition over the concatenated (head, relation) pair.
        path_branch  = self.path_net(torch.cat([h_emb, r_emb], dim=-1))
        path_branch  = self.ln2(h_emb + path_branch)

        mix   = torch.sigmoid(self.gate)
        fused = self.drop(mix * stare_branch + (1 - mix) * path_branch)

        if tail is not None:
            return (fused * self.ent(tail)).sum(-1)
        return fused @ self.ent.weight.t()


# Notebook section marker — confirms the AlertStar cell ran.
print("AlertStar defined")


# ============================================================================
# 9 — Model 5: StarQE (Complex Query Answering)
# ============================================================================

def generate_complex_queries(train_data, preprocessor, n=500):
    """Generate 1p / 2p / 2i / 2u query structures from training triples.

    Query shapes (answer is always an entity id):
      1p: (anchor, r, ?)                  — direct link
      2p: (anchor, r1, ?x), (?x, r2, ?)   — two-hop chain
      2i: two (anchor, r) branches that reach the SAME answer
      2u: two branches treated as a union

    NOTE(review): branch candidates for 2p/2i/2u are taken in dict
    iteration order, and the 2u 'answer' is copied from the first branch
    only — results depend on insertion order and are approximate.
    """
    e2id = preprocessor.entity2id
    r2id = preprocessor.relation2id

    # (head, relation) → list of observed tails.
    hr2t = defaultdict(list)
    for t in train_data:
        h = e2id[t['head']]; r = r2id[t['relation']]; tl = e2id[t['tail']]
        hr2t[(h, r)].append(tl)

    triples = [(e2id[t['head']], r2id[t['relation']], e2id[t['tail']])
               for t in train_data]
    queries = {'1p': [], '2p': [], '2i': [], '2u': []}

    np.random.shuffle(triples)
    for h, r, t in triples[:min(len(triples), n*4)]:
        if len(queries['1p']) < n:
            queries['1p'].append({'anchor': h, 'relations': [r], 'answer': t})

        # 2p: first second-hop whose head is this triple's tail.
        for r2, t2_list in hr2t.items():
            if r2[0] == t and t2_list and len(queries['2p']) < n:
                queries['2p'].append({
                    'anchor': h, 'relations': [r, r2[1]], 'answer': t2_list[0]})
                break

        # 2i: another (h2, r2) branch that also reaches tail t.
        if len(queries['2i']) < n:
            candidates = [(h2,r2) for (h2,r2), tl in hr2t.items() if t in tl and h2!=h]
            if candidates:
                h2, r2 = candidates[0]
                queries['2i'].append({
                    'anchors': [h, h2], 'relations': [r, r2], 'answer': t})

        # 2u: union with an arbitrary other branch (answer stays the first branch's).
        if len(queries['2u']) < n:
            candidates = [(h2,r2) for (h2,r2), tl in hr2t.items()
                          if tl and h2!=h][:1]
            if candidates:
                h2, r2 = candidates[0]
                queries['2u'].append({
                    'anchors': [h, h2], 'relations': [r, r2], 'answer': t})

        # Stop early once every query type has n examples.
        if all(len(v) >= n for v in queries.values()):
            break

    print("Complex query stats:")
    for k, v in queries.items():
        print(f"  {k}: {len(v)}")
    return queries


class StarQEModel(nn.Module):
    """StarQE: embedding-composition model for complex query answering
    (1p / 2p chains, 2i intersections, 2u unions)."""

    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        # NOTE: module creation order is kept stable for reproducible seeding.
        self.ent  = nn.Embedding(ne,  dim)
        self.rel  = nn.Embedding(nr,  dim)
        self.qk   = nn.Embedding(nqk, dim)
        self.qv   = nn.Embedding(nqv, dim)
        self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True)
        self.proj  = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
        self.inter = nn.Sequential(nn.Linear(dim*2, dim), nn.ReLU())
        self.drop  = nn.Dropout(dropout)
        for table in (self.ent, self.rel, self.qk, self.qv):
            nn.init.xavier_uniform_(table.weight)

    def _compose(self, anchor_emb, relations):
        """Fold a chain of relation ids into the anchor embedding."""
        state = anchor_emb
        for rel_id in relations:
            rel_vec = self.rel(torch.tensor([rel_id], device=state.device)).squeeze()
            state   = self.proj(state + rel_vec)
        return state

    def answer_query(self, query_type, query):
        """Embed a structured query; caller compares it to entity embeddings."""
        dev = CONFIG['device']

        if query_type in ('1p', '2p'):
            # Both are plain relation chains from a single anchor.
            anchor = self.ent(torch.tensor([query['anchor']], device=dev)).squeeze()
            return self._compose(anchor, query['relations'])

        def branch(idx):
            anchor = self.ent(torch.tensor([query['anchors'][idx]],
                                           device=dev)).squeeze()
            return self._compose(anchor, [query['relations'][idx]])

        if query_type == '2i':
            return self.inter(torch.cat([branch(0), branch(1)], -1))
        if query_type == '2u':
            return (branch(0) + branch(1)) / 2
        # Unknown query types fall through and return None (as before).

    def forward(self, head, relation, qualifiers, tail=None):
        """Standard link-pred interface (1p queries only)."""
        h_emb = self.ent(head)
        r_emb = self.rel(relation)
        fused = self.drop(self.proj(h_emb + r_emb))
        if tail is not None:
            return (fused * self.ent(tail)).sum(-1)
        return fused @ self.ent.weight.t()


def train_query_model(model, complex_queries, model_name):
    """Train on complex queries.

    Flattens all query types into one list and optimizes a margin-ranking
    loss per query (one uniform negative entity each).  Every 5 epochs the
    model is evaluated with ``evaluate_complex_queries`` and the best
    average-MRR weights are checkpointed.  Uses module-level ``CONFIG``
    and ``NE``.

    Returns (model, history) with the every-5-epoch per-type metrics.
    """
    dev = CONFIG['device']
    model.to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=CONFIG['query_lr'])

    print(f"\n{'='*60}\nTRAINING {model_name}  "
          f"({sum(p.numel() for p in model.parameters()):,} params)\n{'='*60}")

    best_mrr, history = 0.0, []
    # Flatten: [(query_type, query_dict), ...] across all four types.
    all_queries = [(qt, q) for qt, qs in complex_queries.items() for q in qs]

    for epoch in range(CONFIG['query_epochs']):
        model.train()
        np.random.shuffle(all_queries)
        total = 0.0

        pbar = tqdm(all_queries, desc=f"Epoch {epoch+1}/{CONFIG['query_epochs']}")
        for qt, q in pbar:
            qvec   = model.answer_query(qt, q).unsqueeze(0)
            answer = torch.tensor([q['answer']], device=dev)
            pos    = (qvec * model.ent(answer)).sum(-1)
            # One uniform negative entity per query; may coincide with the answer.
            neg_id = torch.randint(0, NE, (1,), device=dev)
            neg    = (qvec * model.ent(neg_id)).sum(-1)
            loss   = F.margin_ranking_loss(pos, neg,
                                           torch.ones(1, device=dev), margin=1.0)
            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            total += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        # Periodic evaluation (on the training queries themselves) + checkpoint.
        if (epoch + 1) % 5 == 0:
            m  = evaluate_complex_queries(model, complex_queries)
            mr = np.mean([v['mrr'] for v in m.values()])
            print(f"  Epoch {epoch+1}: loss={total/len(all_queries):.4f}  "
                  f"Avg-MRR={mr:.4f}")
            if mr > best_mrr:
                best_mrr = mr
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
                print("    → best saved")
            history.append({'epoch': epoch+1, **m})

    return model, history


def evaluate_complex_queries(model, complex_queries, max_per_type=200):
    """Rank the gold answer for every query and report MR/MRR/Hits@{1,3,10}.

    Queries that raise during scoring are skipped; a query type whose
    queries all fail (or that has no queries) gets an all-zero metric row.
    """
    model.eval()
    results = {}
    with torch.no_grad():
        for qtype, queries in complex_queries.items():
            ranks = []
            for query in queries[:max_per_type]:
                try:
                    qvec = model.answer_query(qtype, query)
                    scores = qvec @ model.ent.weight.t()
                    order = torch.argsort(scores, descending=True)
                    position = (order == query['answer']).nonzero(
                        as_tuple=True)[0].item()
                    ranks.append(position + 1)  # 1-based rank
                except Exception:
                    continue  # best-effort: skip queries the model can't score
            if not ranks:
                results[qtype] = {'mr': 0., 'mrr': 0., 'hits@1': 0.,
                                  'hits@3': 0., 'hits@10': 0.}
                continue
            arr = np.array(ranks)
            results[qtype] = {
                'mr':      float(arr.mean()),
                'mrr':     float((1. / arr).mean()),
                'hits@1':  float((arr <= 1).mean()),
                'hits@3':  float((arr <= 3).mean()),
                'hits@10': float((arr <= 10).mean()),
            }
    return results


# Notebook cell-completion marker.
print("StarQE defined")


# ============================================================================
# 10 — Model 6: NBFNet+StarQE
# ============================================================================

class NBFNetStarQE(StarQEModel):
    """StarQE variant whose query composition uses an NBFNet-style path MLP."""

    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1):
        super().__init__(ne, nr, nqk, nqv, dim, dropout)
        # Two-layer MLP that fuses (current state, relation) into a path update.
        self.path_net = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.LayerNorm(dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )

    def _compose(self, anchor_emb, relations):
        """Fold a relation sequence into the anchor embedding.

        NOTE(review): each step adds the MLP update to `anchor_emb`, not to
        the running state — so the walk is re-anchored every hop. Confirm
        this is intentional (vs. `state + self.path_net(...)`).
        """
        state = anchor_emb
        for rel_id in relations:
            rel_emb = self.rel(torch.tensor([rel_id],
                                            device=state.device)).squeeze()
            state = anchor_emb + self.path_net(torch.cat([state, rel_emb], -1))
        return state


# Notebook cell-completion marker.
print("NBFNet+StarQE defined")


# ============================================================================
# 11 — UPDATED: Model 7: MultiTask AlertStar with ALL Metrics
# ============================================================================

class MultiTaskDataset(Dataset):
    """Expand each hyper-relational statement into multiple training tasks.

    Every statement yields a 'tail' and a 'relation' sample; statements with
    qualifiers additionally yield one 'qual_key' sample (multi-label over
    their keys) and one 'qual_value' sample per (key, value) pair.
    """

    def __init__(self, data, preprocessor):
        self.samples = []
        pre = preprocessor
        for stmt in data:
            head = pre.entity2id[stmt['head']]
            rel = pre.relation2id[stmt['relation']]
            tail = pre.entity2id[stmt['tail']]
            quals = [(pre.qualifier_key2id[qk], pre.qualifier_value2id[qv])
                     for qk, qv in stmt['qualifiers']]

            base = {'h': head, 'r': rel, 't': tail, 'qs': quals}
            self.samples.append({'task': 'tail', **base})
            self.samples.append({'task': 'relation', **base})
            if quals:
                self.samples.append({'task': 'qual_key', **base,
                                     'keys': [k for k, _ in quals]})
                for key_id, val_id in quals:
                    self.samples.append({'task': 'qual_value', **base,
                                         'qk': key_id, 'qv': val_id})

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def mt_collate(batch):
    """Bucket a mini-batch of samples by task name (no tensor padding here)."""
    grouped = defaultdict(list)
    for sample in batch:
        grouped[sample['task']].append(sample)
    return grouped


class MultiTaskAlertStar(nn.Module):
    """Transformer encoder over (head, relation, tail, qualifier…) tokens,
    with one MLP prediction head per task."""

    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1,
                 n_heads=4, n_layers=3):
        super().__init__()
        self.num_entities = ne

        # Embedding tables for every token type.
        self.ent = nn.Embedding(ne, dim)
        self.rel = nn.Embedding(nr, dim)
        self.qk = nn.Embedding(nqk, dim)
        self.qv = nn.Embedding(nqv, dim)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=n_heads, dim_feedforward=dim * 4,
            dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)

        def make_head(n_out):
            # Small MLP projecting the context vector to task logits.
            return nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim),
                                 nn.ReLU(), nn.Dropout(dropout),
                                 nn.Linear(dim, n_out))

        self.tail_head = make_head(ne)
        self.rel_head = make_head(nr)
        self.qk_head = make_head(nqk)
        self.qv_head = make_head(nqv)

        for table in (self.ent, self.rel, self.qk, self.qv):
            nn.init.xavier_uniform_(table.weight)

    def _encode(self, h, r, t, qs, dev, mask_pos=None):
        """Encode the statement as a token sequence; optionally zero one token.

        Returns the encoder output at position 0 (the head token) as the
        context vector.
        """
        pieces = [self.ent(torch.tensor([h], device=dev)),
                  self.rel(torch.tensor([r], device=dev)),
                  self.ent(torch.tensor([t], device=dev))]
        for key_id, val_id in qs:
            pieces.append(self.qk(torch.tensor([key_id], device=dev)))
            pieces.append(self.qv(torch.tensor([val_id], device=dev)))

        seq = torch.cat(pieces, dim=0).unsqueeze(0)

        # Zero out the target token so the encoder must infer it from context.
        if mask_pos is not None and mask_pos < seq.size(1):
            seq = seq.clone()
            seq[0, mask_pos] = 0.0

        return self.encoder(seq)[0, 0]

    def forward_task(self, task, sample, dev):
        """Dispatch one sample through the encoder and the task's head."""
        h, r, t, qs = sample['h'], sample['r'], sample['t'], sample['qs']

        if task == 'tail':
            # Entity 0 is a placeholder; position 2 is zero-masked anyway.
            return self.tail_head(self._encode(h, r, 0, qs, dev, mask_pos=2))

        if task == 'relation':
            return self.rel_head(self._encode(h, r, t, qs, dev, mask_pos=1))

        if task == 'qual_key':
            # Predict which qualifier keys attach to the bare triple.
            return self.qk_head(self._encode(h, r, t, [], dev))

        if task == 'qual_value':
            # Hide every pair with the target key, keep the other qualifiers.
            held_out = sample['qk']
            remaining = [(k, v) for k, v in qs if k != held_out]
            return self.qv_head(self._encode(h, r, t, remaining, dev))


def train_multitask(model, model_name="MultiTask_AlertStar"):
    """Train the multi-task Transformer on all four objectives.

    Builds MultiTaskDataset splits from the notebook globals `train_data` /
    `valid_data` / `preprocessor`, optimises a weighted sum of per-task
    losses, and every 5 epochs validates, checkpoints the best model (by
    average of tail-MRR, relation-MRR, qual-key-F1 and qual-value-MRR) and
    appends the metrics to the returned history.

    Returns:
        (model, history) — the last-epoch model and the list of
        validation-metric dicts.
    """
    dev  = CONFIG['device']
    model.to(dev)
    opt  = torch.optim.Adam(model.parameters(), lr=CONFIG['mt_lr'])

    mt_train = MultiTaskDataset(train_data, preprocessor)
    mt_valid = MultiTaskDataset(valid_data, preprocessor)

    loader = DataLoader(mt_train, batch_size=CONFIG['mt_batch'],
                        shuffle=True, collate_fn=mt_collate)

    # Auxiliary qualifier tasks are down-weighted so link/relation prediction
    # dominates the gradient signal.
    weights = {'tail': 1.0, 'relation': 1.0,
               'qual_key': 0.5, 'qual_value': 0.8}

    print(f"\n{'='*60}\nTRAINING {model_name}  "
          f"({sum(p.numel() for p in model.parameters()):,} params)\n{'='*60}")
    print(f"  Train samples: {len(mt_train):,}  "
          f"Valid samples: {len(mt_valid):,}")

    best_avg, history = 0.0, []

    for epoch in range(CONFIG['mt_epochs']):
        model.train()
        total_loss = 0.0; cnt = 0

        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{CONFIG['mt_epochs']}")
        for by_task in pbar:
            batch_loss = torch.tensor(0., device=dev)
            n_terms = 0  # loss terms actually accumulated in this batch

            for task, samples in by_task.items():
                for s in samples:
                    try:
                        logits = model.forward_task(task, s, dev)
                        w      = weights.get(task, 1.0)

                        if task == 'tail':
                            tgt  = torch.tensor(s['t'], device=dev)
                            loss = F.cross_entropy(logits.unsqueeze(0),
                                                   tgt.unsqueeze(0))
                        elif task == 'relation':
                            tgt  = torch.tensor(s['r'], device=dev)
                            loss = F.cross_entropy(logits.unsqueeze(0),
                                                   tgt.unsqueeze(0))
                        elif task == 'qual_key':
                            # Multi-label target over all qualifier keys.
                            tgt  = torch.zeros(NQK, device=dev)
                            tgt[s['keys']] = 1.0
                            loss = F.binary_cross_entropy_with_logits(logits, tgt)
                        elif task == 'qual_value':
                            tgt  = torch.tensor(s['qv'], device=dev)
                            loss = F.cross_entropy(logits.unsqueeze(0),
                                                   tgt.unsqueeze(0))
                        else:
                            continue

                        batch_loss = batch_loss + w * loss
                        n_terms += 1
                    except Exception:
                        # Best-effort: a malformed sample must not abort training.
                        continue

            # BUGFIX: if every sample in the batch failed, batch_loss is a
            # constant tensor with no grad_fn and .backward() would raise.
            if n_terms == 0:
                continue

            opt.zero_grad(); batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            total_loss += batch_loss.item(); cnt += 1
            pbar.set_postfix({'loss': f'{batch_loss.item():.4f}'})

        # Validate every 5 epochs; keep the checkpoint with the best average.
        if (epoch + 1) % 5 == 0:
            m = evaluate_multitask(model, mt_valid, dev)
            avg = np.mean([m['tail']['mrr'],
                           m['relation']['mrr'],
                           m['qual_key']['f1'],
                           m['qual_value']['mrr']])  # MRR (HyNT-comparable), not accuracy
            # max(cnt, 1) guards against an epoch where no batch stepped.
            print(f"  Epoch {epoch+1}: loss={total_loss/max(cnt, 1):.4f}  avg={avg:.4f}")
            print(f"    tail: MR={m['tail']['mr']:.1f} MRR={m['tail']['mrr']:.4f} "
                  f"H@1={m['tail']['hits@1']:.4f} H@3={m['tail']['hits@3']:.4f} "
                  f"H@10={m['tail']['hits@10']:.4f}")
            print(f"    relation: MR={m['relation']['mr']:.1f} MRR={m['relation']['mrr']:.4f} "
                  f"Acc={m['relation']['accuracy']:.4f}")
            print(f"    qual-key: F1={m['qual_key']['f1']:.4f}")
            print(f"    qual-value: MR={m['qual_value']['mr']:.1f} MRR={m['qual_value']['mrr']:.4f} "
                  f"Acc={m['qual_value']['accuracy']:.4f}")
            if avg > best_avg:
                best_avg = avg
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
                print("    → best saved")
            history.append({'epoch': epoch+1, **m})

    return model, history


def _rank_of(logits, target):
    """Return the 1-based rank of `target` under descending logit order."""
    order = torch.argsort(logits, descending=True)
    return (order == target).nonzero(as_tuple=True)[0].item() + 1


def _ranking_metrics(rank_array):
    """MR / MRR / Hits@{1,3,10} from a numpy array of 1-based ranks."""
    return {
        'mr':      float(np.mean(rank_array)),
        'mrr':     float(np.mean(1. / rank_array)),
        'hits@1':  float(np.mean(rank_array <= 1)),
        'hits@3':  float(np.mean(rank_array <= 3)),
        'hits@10': float(np.mean(rank_array <= 10)),
    }


def evaluate_multitask(model, mt_dataset, device, max_per_task=200):
    """Evaluate all four tasks with full ranking metrics.

    Returns a dict keyed by task name: 'tail' gets MR/MRR/Hits, 'relation'
    and 'qual_value' additionally get top-1 accuracy, and 'qual_key' gets
    precision/recall/F1 for multi-label prediction at a 0.5 threshold.
    Samples that raise are skipped; if a ranking task yields no valid ranks,
    the notebook globals NE/NR/NQV serve as worst-case sentinel ranks.
    """
    model.eval()

    # Group evaluation samples by task.
    by_task = defaultdict(list)
    for s in mt_dataset.samples:
        by_task[s['task']].append(s)

    results = {}
    with torch.no_grad():

        # ── tail prediction ─────────────────────────────────────────────
        ranks = []
        for s in by_task['tail'][:max_per_task]:
            try:
                ranks.append(_rank_of(model.forward_task('tail', s, device),
                                      s['t']))
            except Exception:
                continue
        results['tail'] = _ranking_metrics(
            np.array(ranks) if ranks else np.array([NE]))

        # ── relation prediction (ranking + accuracy) ────────────────────
        ranks, correct = [], []
        for s in by_task['relation'][:max_per_task]:
            try:
                logits = model.forward_task('relation', s, device)
                ranks.append(_rank_of(logits, s['r']))
                correct.append(int(logits.argmax().item() == s['r']))
            except Exception:
                continue
        results['relation'] = _ranking_metrics(
            np.array(ranks) if ranks else np.array([NR]))
        results['relation']['accuracy'] = (float(np.mean(correct))
                                           if correct else 0.0)

        # ── qualifier key prediction (multi-label) ──────────────────────
        prec_list, rec_list = [], []
        for s in by_task['qual_key'][:max_per_task]:
            try:
                logits = model.forward_task('qual_key', s, device)
                pred   = (torch.sigmoid(logits) > 0.5).cpu().numpy()
                true   = np.zeros(NQK); true[s['keys']] = 1
                tp = (pred * true).sum()
                # max(…, 1) avoids division by zero on empty predictions.
                prec_list.append(tp / max(pred.sum(), 1))
                rec_list.append(tp / max(true.sum(), 1))
            except Exception:
                continue
        p = float(np.mean(prec_list)) if prec_list else 0.
        r = float(np.mean(rec_list))  if rec_list  else 0.
        results['qual_key'] = {
            'precision': p, 'recall': r,
            'f1': 2*p*r/(p+r+1e-8)}

        # ── qualifier value prediction (ranking + accuracy, like HyNT) ──
        ranks, correct = [], []
        for s in by_task['qual_value'][:max_per_task]:
            try:
                logits = model.forward_task('qual_value', s, device)
                ranks.append(_rank_of(logits, s['qv']))
                correct.append(int(logits.argmax().item() == s['qv']))
            except Exception:
                continue
        results['qual_value'] = _ranking_metrics(
            np.array(ranks) if ranks else np.array([NQV]))
        results['qual_value']['accuracy'] = (float(np.mean(correct))
                                             if correct else 0.0)

    return results


# Notebook cell-completion marker.
print("MultiTask AlertStar defined")


# ============================================================================
# CELL 12 — Train All Models
# ============================================================================
# Orchestration cell: trains every model variant in turn and stores its test
# metrics in `all_results` (assumed initialised in an earlier cell — TODO
# confirm). Module-level names created here (stare_model, mt_results, …) are
# consumed by the later table/plot cells.

print("\n" + "="*70)
print("STARTING MODEL TRAINING")
print("="*70)

# ── StarE ──────────────────────────────────────────────────────────────────
stare_model = StarEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
stare_model, _ = train_standard(stare_model, "StarE")
all_results['StarE'] = evaluate_model(stare_model, test_ds, CONFIG['device'])

# ── ShrinkE ────────────────────────────────────────────────────────────────
shrinke_model = ShrinkEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
shrinke_model, _ = train_standard(shrinke_model, "ShrinkE")
all_results['ShrinkE'] = evaluate_model(shrinke_model, test_ds, CONFIG['device'])

# ── NBFNet ─────────────────────────────────────────────────────────────
# NBFNet has its own training loop and evaluation entry point because it
# message-passes over an explicit graph rather than scoring embeddings.
nbfnet_model = NBFNet(
    NE, NR, dim=DIM,
    num_layers  = CONFIG['nbfnet_layers'],
    chunk_size  = CONFIG['nbfnet_chunk_size'],
    dropout     = CONFIG['dropout'],
    short_cut   = True,
)
nbfnet_model, _ = train_true_nbfnet(nbfnet_model, "NBFNet")
# Move every tensor of the message-passing graph onto the evaluation device.
nbfnet_graph_dev = {k: v.to(CONFIG['device']) if torch.is_tensor(v) else v
                    for k, v in nbfnet_graph.items()}
all_results['NBFNet'] = evaluate_nbfnet(
    nbfnet_model, test_ds, nbfnet_graph_dev, CONFIG['device'])

# ── AlertStar ──────────────────────────────────────────────────────────────
alertstar_model = AlertStarModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
alertstar_model, _ = train_standard(alertstar_model, "AlertStar")
all_results['AlertStar'] = evaluate_model(alertstar_model, test_ds, CONFIG['device'])

print("\n✓ Standard models trained")
for m, r in all_results.items():
    print(f"  {m:15s}  MR={r['mr']:7.1f}  MRR={r['mrr']:.4f}  "
          f"H@1={r['hits@1']:.4f}  H@3={r['hits@3']:.4f}  H@10={r['hits@10']:.4f}")

# ── Complex Query Models ───────────────────────────────────────────────────
print("\nGenerating complex queries...")
complex_queries = generate_complex_queries(train_data, preprocessor, n=500)

# StarQE — trained on complex queries, then evaluated on both plain link
# prediction and the complex-query benchmark.
starqe_model = StarQEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
starqe_model, _ = train_query_model(starqe_model, complex_queries, "StarQE")
starqe_lp  = evaluate_model(starqe_model, test_ds, CONFIG['device'])
starqe_cq  = evaluate_complex_queries(starqe_model, complex_queries)
all_results['StarQE'] = starqe_lp

# NBFNet+StarQE — same protocol as StarQE.
nbfqe_model = NBFNetStarQE(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
nbfqe_model, _ = train_query_model(nbfqe_model, complex_queries, "NBFNet+StarQE")
nbfqe_lp  = evaluate_model(nbfqe_model, test_ds, CONFIG['device'])
nbfqe_cq  = evaluate_complex_queries(nbfqe_model, complex_queries)
all_results['NBFNet+StarQE'] = nbfqe_lp

print("\n Complex query models trained")
for name, cq in [("StarQE", starqe_cq), ("NBFNet+StarQE", nbfqe_cq)]:
    avg = np.mean([v['mrr'] for v in cq.values()])
    print(f"  {name}: Avg-MRR={avg:.4f}")
    for k, v in cq.items():
        print(f"    {k}: MR={v['mr']:6.1f} MRR={v['mrr']:.4f} "
              f"H@1={v['hits@1']:.4f} H@3={v['hits@3']:.4f} H@10={v['hits@10']:.4f}")

# ── MultiTask AlertStar ────────────────────────────────────────────────────
mt_model = MultiTaskAlertStar(NE, NR, NQK, NQV, DIM,
                               dropout=CONFIG['dropout'])
mt_model, mt_history = train_multitask(mt_model)

mt_test_ds   = MultiTaskDataset(test_data, preprocessor)
mt_results   = evaluate_multitask(mt_model, mt_test_ds, CONFIG['device'])

# Store tail prediction metrics so MultiTask appears in the shared
# link-prediction comparison table alongside the single-task models.
all_results['MultiTask_AlertStar'] = mt_results['tail']

print("\n MultiTask AlertStar evaluation complete")
print(f"\n  TAIL PREDICTION:")
print(f"    MR     = {mt_results['tail']['mr']:.1f}")
print(f"    MRR    = {mt_results['tail']['mrr']:.4f}")
print(f"    Hits@1 = {mt_results['tail']['hits@1']:.4f}")
print(f"    Hits@3 = {mt_results['tail']['hits@3']:.4f}")
print(f"    Hits@10= {mt_results['tail']['hits@10']:.4f}")
print(f"\n  RELATION PREDICTION:")
print(f"    MR     = {mt_results['relation']['mr']:.1f}")
print(f"    MRR    = {mt_results['relation']['mrr']:.4f}")
print(f"    Hits@1 = {mt_results['relation']['hits@1']:.4f}")
print(f"    Hits@3 = {mt_results['relation']['hits@3']:.4f}")
print(f"    Hits@10= {mt_results['relation']['hits@10']:.4f}")
print(f"    Acc    = {mt_results['relation']['accuracy']:.4f}")
print(f"\n  QUALIFIER KEY:")
print(f"    Precision = {mt_results['qual_key']['precision']:.4f}")
print(f"    Recall    = {mt_results['qual_key']['recall']:.4f}")
print(f"    F1        = {mt_results['qual_key']['f1']:.4f}")
print(f"\n  QUALIFIER VALUE:")
print(f"    MR     = {mt_results['qual_value']['mr']:.1f}")
print(f"    MRR    = {mt_results['qual_value']['mrr']:.4f}")
print(f"    Hits@1 = {mt_results['qual_value']['hits@1']:.4f}")
print(f"    Hits@3 = {mt_results['qual_value']['hits@3']:.4f}")
print(f"    Hits@10= {mt_results['qual_value']['hits@10']:.4f}")
print(f"    Acc    = {mt_results['qual_value']['accuracy']:.4f}")


# ============================================================================
# 13 — UPDATED: Complete Results Table with All Metrics
# ============================================================================
# Builds the cross-model comparison table, the MultiTask detail table and the
# complex-query tables, then dumps everything to all_results.json.

print("\n" + "="*80)
print("COMPLETE MODEL COMPARISON — ALL METRICS")
print("="*80)

# One row per model, sorted by MRR (higher is better).
df = pd.DataFrame({
    'Model':   list(all_results.keys()),
    'MR':      [all_results[m]['mr']      for m in all_results],
    'MRR':     [all_results[m]['mrr']     for m in all_results],
    'Hits@1':  [all_results[m]['hits@1']  for m in all_results],
    'Hits@3':  [all_results[m]['hits@3']  for m in all_results],
    'Hits@10': [all_results[m]['hits@10'] for m in all_results],
})
df = df.sort_values('MRR', ascending=False).reset_index(drop=True)
print(df.to_string(index=False))

best = df.iloc[0]  # top row after the MRR sort = best link-prediction model
print(f"\n Best Model: {best['Model']}")
print(f"   MR={best['MR']:.1f}  MRR={best['MRR']:.4f}  "
      f"H@1={best['Hits@1']:.4f}  H@3={best['Hits@3']:.4f}  H@10={best['Hits@10']:.4f}")

# Multi-task detailed table — qual-key is multi-label, so ranking columns
# don't apply and are shown as '-'.
print("\n" + "="*80)
print("MULTI-TASK ALERTSTAR — DETAILED METRICS (HyNT-comparable)")
print("="*80)

mt_df = pd.DataFrame({
    'Task': ['Tail Pred', 'Relation Pred', 'Qual Key', 'Qual Value'],
    'Primary Metric': [
        f"MRR={mt_results['tail']['mrr']:.4f}",
        f"MRR={mt_results['relation']['mrr']:.4f}",
        f"F1={mt_results['qual_key']['f1']:.4f}",
        f"MRR={mt_results['qual_value']['mrr']:.4f}"
    ],
    'MR': [
        f"{mt_results['tail']['mr']:.1f}",
        f"{mt_results['relation']['mr']:.1f}",
        '-',
        f"{mt_results['qual_value']['mr']:.1f}"
    ],
    'Hits@1': [
        f"{mt_results['tail']['hits@1']:.4f}",
        f"{mt_results['relation']['hits@1']:.4f}",
        '-',
        f"{mt_results['qual_value']['hits@1']:.4f}"
    ],
    'Hits@3': [
        f"{mt_results['tail']['hits@3']:.4f}",
        f"{mt_results['relation']['hits@3']:.4f}",
        '-',
        f"{mt_results['qual_value']['hits@3']:.4f}"
    ],
    'Hits@10': [
        f"{mt_results['tail']['hits@10']:.4f}",
        f"{mt_results['relation']['hits@10']:.4f}",
        '-',
        f"{mt_results['qual_value']['hits@10']:.4f}"
    ],
    'Accuracy/F1': [
        '-',
        f"{mt_results['relation']['accuracy']:.4f}",
        f"{mt_results['qual_key']['f1']:.4f}",
        f"{mt_results['qual_value']['accuracy']:.4f}"
    ]
})
print(mt_df.to_string(index=False))

# Complex query detailed table — one sub-table per query-answering model.
print("\n" + "="*80)
print("COMPLEX QUERY ANSWERING — DETAILED METRICS")
print("="*80)

for model_name, cq_results in [("StarQE", starqe_cq), ("NBFNet+StarQE", nbfqe_cq)]:
    print(f"\n{model_name}:")
    cq_df = pd.DataFrame({
        'Query Type': list(cq_results.keys()),
        'MR':      [cq_results[qt]['mr']      for qt in cq_results],
        'MRR':     [cq_results[qt]['mrr']     for qt in cq_results],
        'Hits@1':  [cq_results[qt]['hits@1']  for qt in cq_results],
        'Hits@3':  [cq_results[qt]['hits@3']  for qt in cq_results],
        'Hits@10': [cq_results[qt]['hits@10'] for qt in cq_results],
    })
    print(cq_df.to_string(index=False))

# Persist every metric dict in one JSON file for downstream analysis.
with open(f"{CONFIG['output_path']}all_results.json", 'w') as f:
    json.dump({
        'link_prediction': all_results,
        'complex_queries': {
            'StarQE':       starqe_cq,
            'NBFNet+StarQE':nbfqe_cq,
        },
        'multitask': mt_results,
    }, f, indent=2)

print(f"\n Results saved to {CONFIG['output_path']}all_results.json")


# ============================================================================
# 14 — UPDATED: Visualization with All Metrics
# ============================================================================
# 3×3 dashboard: rows 1–2 are link-prediction metrics, row 3 covers complex
# queries, the multi-task model and a model-capability matrix.

fig, axes = plt.subplots(3, 3, figsize=(20, 15))
fig.suptitle("AlertStar: Complete Model Comparison (All Metrics)",
             fontsize=16, fontweight='bold')

palette = sns.color_palette("Set2", len(df))

# Row 1: Link Prediction - MR, MRR, Hits@1
axes[0,0].bar(df['Model'], df['MR'], color=palette)
axes[0,0].set_title("Link Prediction — MR (lower is better)", fontweight='bold')
axes[0,0].set_ylabel("Mean Rank")
axes[0,0].tick_params(axis='x', rotation=45, labelsize=9)

axes[0,1].bar(df['Model'], df['MRR'], color=palette)
axes[0,1].set_title("Link Prediction — MRR", fontweight='bold')
axes[0,1].set_ylabel("MRR")
axes[0,1].tick_params(axis='x', rotation=45, labelsize=9)
# Annotate each MRR bar with its value.
for i, v in enumerate(df['MRR']):
    axes[0,1].text(i, v+0.005, f'{v:.3f}', ha='center', fontsize=8)

axes[0,2].bar(df['Model'], df['Hits@1'], color=palette)
axes[0,2].set_title("Link Prediction — Hits@1", fontweight='bold')
axes[0,2].set_ylabel("Hits@1")
axes[0,2].tick_params(axis='x', rotation=45, labelsize=9)

# Row 2: Link Prediction - Hits@3, Hits@10, Combined
axes[1,0].bar(df['Model'], df['Hits@3'], color=palette)
axes[1,0].set_title("Link Prediction — Hits@3", fontweight='bold')
axes[1,0].set_ylabel("Hits@3")
axes[1,0].tick_params(axis='x', rotation=45, labelsize=9)

axes[1,1].bar(df['Model'], df['Hits@10'], color=palette)
axes[1,1].set_title("Link Prediction — Hits@10", fontweight='bold')
axes[1,1].set_ylabel("Hits@10")
axes[1,1].tick_params(axis='x', rotation=45, labelsize=9)

# Combined metric visualization — grouped bars, one group per model.
x = np.arange(len(df))
width = 0.15
axes[1,2].bar(x-2*width, df['Hits@1'], width, label='H@1', color='steelblue')
axes[1,2].bar(x-width, df['Hits@3'], width, label='H@3', color='orange')
axes[1,2].bar(x, df['Hits@10'], width, label='H@10', color='green')
axes[1,2].bar(x+width, df['MRR'], width, label='MRR', color='red')
axes[1,2].set_xticks(x)
axes[1,2].set_xticklabels(df['Model'], rotation=45, ha='right', fontsize=8)
axes[1,2].set_title("All Metrics Combined", fontweight='bold')
axes[1,2].legend(fontsize=8)

# Row 3: Complex Queries, Multi-task, Capability Matrix
query_types = list(starqe_cq.keys())
x_cq = np.arange(len(query_types))
w = 0.35
axes[2,0].bar(x_cq-w/2, [starqe_cq[qt]['mrr'] for qt in query_types],
              w, label='StarQE', color='steelblue')
axes[2,0].bar(x_cq+w/2, [nbfqe_cq[qt]['mrr'] for qt in query_types],
              w, label='NBFNet+StarQE', color='coral')
axes[2,0].set_xticks(x_cq)
axes[2,0].set_xticklabels(query_types)
axes[2,0].set_title("Complex Query MRR by Type", fontweight='bold')
axes[2,0].set_ylabel("MRR")
axes[2,0].legend()

# Multi-task results — one bar per task with its primary metric.
tasks  = ['Tail\n(MRR)', 'Relation\n(MRR)', 'Qual-Key\n(F1)', 'Qual-Val\n(Acc)']
scores = [mt_results['tail']['mrr'],
          mt_results['relation']['mrr'],
          mt_results['qual_key']['f1'],
          mt_results['qual_value']['accuracy']]
bars = axes[2,1].bar(tasks, scores, color=sns.color_palette("Set3", 4))
axes[2,1].set_title("MultiTask AlertStar — All Tasks", fontweight='bold')
axes[2,1].set_ylabel("Score")
axes[2,1].set_ylim(0, 1.1)
for bar, v in zip(bars, scores):
    axes[2,1].text(bar.get_x()+bar.get_width()/2, v+0.02,
                   f'{v:.3f}', ha='center', fontsize=10, fontweight='bold')

# Capability matrix — binary flags for which tasks each model supports.
model_names = list(all_results.keys())
caps = {m: {'Link\nPred': 1,
            'Relation\nPred': 1 if m == 'MultiTask_AlertStar' else 0,
            'Qual\nPred': 1 if m == 'MultiTask_AlertStar' else 0,
            'Complex\nQuery': 1 if m in ['StarQE','NBFNet+StarQE'] else 0}
        for m in model_names}
cap_matrix = np.array([[caps[m][c] for c in ['Link\nPred','Relation\nPred',
                                              'Qual\nPred','Complex\nQuery']]
                        for m in model_names])
im = axes[2,2].imshow(cap_matrix, cmap='YlGn', aspect='auto', vmin=0, vmax=1)
axes[2,2].set_xticks(range(4))
axes[2,2].set_xticklabels(['Link\nPred','Relation\nPred',
                            'Qual\nPred','Complex\nQuery'], fontsize=9)
axes[2,2].set_yticks(range(len(model_names)))
axes[2,2].set_yticklabels(model_names, fontsize=9)
axes[2,2].set_title("Capability Matrix", fontweight='bold')
for i in range(len(model_names)):
    for j in range(4):
        # NOTE(review): both ternary branches below are empty strings — the
        # ✓/✗ glyphs were probably lost in an encoding pass, so only the
        # cell colour currently conveys capability. Confirm intended text.
        axes[2,2].text(j, i, '' if cap_matrix[i,j] else '',
                       ha='center', va='center', fontsize=12,
                       color='darkgreen' if cap_matrix[i,j] else 'gray')

plt.tight_layout()
plt.savefig(f"{CONFIG['output_path']}complete_comparison_all_metrics.png",
            dpi=300, bbox_inches='tight')
plt.show()
print(f"✓ Visualization saved")


# ============================================================================
# 15- Final Summary
# ============================================================================

def get_name(vocab_dict, idx):
    """Map an id back to its name in *vocab_dict*; fall back to "<id>".

    If several names share the same id, the last one wins (dict-inversion
    semantics).
    """
    inverse = {identifier: name for name, identifier in vocab_dict.items()}
    return inverse.get(idx, f"<{idx}>")

# NOTE(review): `dev` (and `get_name` above) appear unused in this visible
# cell — possibly leftovers from a trimmed qualitative-example section.
dev = CONFIG['device']
print("\n" + "="*70)
print("EXPERIMENT COMPLETE!")
print("="*70)

print(f"\n TRAINED {len(all_results)} MODELS:")
for m in all_results:
    print(f"{m}")

# `best` is the top MRR-sorted row computed in the results-table cell.
print(f"\n BEST LINK PREDICTION MODEL: {best['Model']}")
print(f"   MR     = {best['MR']:.1f}")
print(f"   MRR    = {best['MRR']:.4f}")
print(f"   Hits@1 = {best['Hits@1']:.4f}")
print(f"   Hits@3 = {best['Hits@3']:.4f}")
print(f"   Hits@10= {best['Hits@10']:.4f}")

print(f"\n All outputs saved to: {CONFIG['output_path']}")
print("   - all_results.json (complete metrics)")
print("   - complete_comparison_all_metrics.png (visualization)")
print("   - Model weights (.pt files)")

print("\n" + "="*70)