# AlertStar / alertstar_ablation.py
# ============================================================================
# ALERTSTAR: COMPLETE NOTEBOOK — 10 MODELS + ALL ABLATIONS
#
# Models:
#   1.  StarE                — qualifier-enriched relation attention
#   2.  ShrinkE              — shrinking transform qualifier fusion
#   3.  TrueNBFNet           — Neural Bellman-Ford (qualifier-unaware)
#   4.  AlertStar            — gated StarE + path composition (ours)
#   5.  StarQE               — complex query answering (1p/2p/2i/2u)
#   6.  NBFNet+StarQE        — residual path-augmented complex queries
#   7.  HyNT                 — Transformer 2-task competitor
#   8.  MultiTask-AS         — Transformer 4-task (ours, main)
#   9.  HR-NBFNet            — Hyper-Relational Bellman-Ford
#                              qualifier-aware at every propagation step
#                              matches slide formulation exactly
#  10.  MultiTask_HR_NBFNet  — HR-NBFNet + 4-task multi-task training (NEW)
#                              combines graph propagation with auxiliary
#                              relation/qual-key/qual-value supervision
#
# Ablations:
#   A1 — AlertStar component ablation (NoQual / NoPath / NoGate / Full)
#   A2 — AlertStar gate value trajectory
#   A3 — MultiTask-AS auxiliary task contribution
#   A4 — Qualifier density sensitivity (Q33 / Q66 / Q100)
#        HR-NBFNet + MultiTask_HR_NBFNet included in A4
# ============================================================================


# ============================================================================
# 1 — Imports & Config
# ============================================================================

import os, json, warnings, copy
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from collections import defaultdict
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_scatter import scatter_add, scatter_mean

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Experiment-wide settings, grouped by the model / trainer that reads them.
# NOTE: keep the trailing '/' on output_path — checkpoint filenames may be
# built by plain string concatenation onto it.
CONFIG = dict(
    # paths
    data_path='/.../inductive_q100_statements',
    output_path='/.../results_q3366100_ablation_10ci100/',

    # shared hyper-params
    embedding_dim=200,
    dropout=0.2,
    device=device,

    # standard training (StarE / ShrinkE / AlertStar)
    batch_size=128,
    learning_rate=0.0005,
    epochs=20,

    # TrueNBFNet
    nbfnet_epochs=20,
    nbfnet_lr=0.0005,
    nbfnet_layers=3,
    nbfnet_chunk_size=10000,
    nbfnet_max_per_group=8,

    # StarQE / NBFNet+StarQE
    query_epochs=20,
    query_lr=0.0005,

    # HyNT
    hynt_epochs=20,
    hynt_lr=0.0005,
    hynt_batch=64,
    hynt_n_heads=4,
    hynt_n_layers=2,

    # MultiTask-AS
    mt_epochs=20,
    mt_lr=0.0005,
    mt_batch=64,

    # HR-NBFNet
    hr_nbfnet_epochs=20,
    hr_nbfnet_lr=0.0005,
    hr_nbfnet_layers=3,
    hr_nbfnet_chunk_size=5000,
    hr_nbfnet_max_quals=8,

    # MultiTask_HR_NBFNet
    mt_hr_epochs=20,
    mt_hr_lr=0.0005,
    mt_hr_batch=32,    # smaller batch — each sample triggers a BF pass

    # qualifier density paths for Ablation A4
    q33_path='/.../inductive_q33_statements',
    q66_path='/.../inductive_q66_statements',
)

os.makedirs(CONFIG['output_path'], exist_ok=True)
# Result accumulators (presumably filled in by later experiment cells).
all_results, ablation_results, gate_history = {}, {}, []
print("Config ready")


# ============================================================================
# 2 — Data Loading
# ============================================================================

class DataPreprocessor:
    """Parse hyper-relational statement files and build id vocabularies.

    Each line of a split file is read as
    ``head,relation,tail[,key1:val1|key2:val2|...]``. Every token seen in any
    loaded split (train, validation and test) is registered, so ids are shared
    across splits and assigned in first-seen order.
    """

    def __init__(self):
        self.entity2id          = {}   # entity string -> contiguous id
        self.relation2id        = {}   # relation string -> contiguous id
        self.qualifier_key2id   = {}   # qualifier key string -> contiguous id
        self.qualifier_value2id = {}   # qualifier value string -> contiguous id

    def _parse(self, line):
        """Parse one statement line.

        Returns:
            ``{'head', 'relation', 'tail', 'qualifiers'}`` where
            ``qualifiers`` is a list of ``(key, value)`` string pairs, or
            ``None`` if the line lacks a non-empty head, relation and tail
            (e.g. blank lines). Only the fourth comma-separated field is read
            as qualifiers; any further fields are ignored.
        """
        parts = line.strip().split(',')
        if len(parts) < 3:
            return None
        # Strip each core token the same way qualifier keys/values are
        # stripped, so "a, r, b" and "a,r,b" map to the same ids.
        head, relation, tail = (p.strip() for p in parts[:3])
        if not (head and relation and tail):
            # Rows like ",," would otherwise register "" as an entity.
            return None
        qualifiers = []
        if len(parts) > 3:
            for pair in parts[3].split('|'):
                pair = pair.strip()
                if ':' in pair:
                    key, value = pair.split(':', 1)
                    qualifiers.append((key.strip(), value.strip()))
        return {'head': head, 'relation': relation,
                'tail': tail, 'qualifiers': qualifiers}

    def _register(self, triple):
        """Assign ids to any unseen entity, relation, qualifier key or value."""
        for token, vocab in [(triple['head'],     self.entity2id),
                             (triple['tail'],     self.entity2id),
                             (triple['relation'], self.relation2id)]:
            if token not in vocab:
                vocab[token] = len(vocab)
        for qk, qv in triple['qualifiers']:
            if qk not in self.qualifier_key2id:
                self.qualifier_key2id[qk] = len(self.qualifier_key2id)
            if qv not in self.qualifier_value2id:
                self.qualifier_value2id[qv] = len(self.qualifier_value2id)

    def load(self, data_path):
        """Load ``train.txt``, ``validation.txt`` and ``test.txt`` from a directory.

        Files are decoded as UTF-8 explicitly, rather than with the
        platform-dependent locale default. Per-split counts and vocabulary
        sizes are printed.

        Returns:
            ``(train, valid, test)`` lists of parsed statement dicts.
        """
        splits = {}
        for name, fname in [('train', 'train.txt'),
                            ('valid', 'validation.txt'),
                            ('test',  'test.txt')]:
            data = []
            with open(os.path.join(data_path, fname), encoding='utf-8') as f:
                for line in f:
                    t = self._parse(line)
                    if t is not None:
                        self._register(t)
                        data.append(t)
            splits[name] = data
            print(f"  {name}: {len(data):,}")
        print(f"  NE={len(self.entity2id)}  NR={len(self.relation2id)}  "
              f"NQK={len(self.qualifier_key2id)}  NQV={len(self.qualifier_value2id)}")
        return splits['train'], splits['valid'], splits['test']


class HRDataset(Dataset):
    """Map-style dataset that converts parsed statements into integer ids.

    Args:
        data: list of statement dicts (``head``/``relation``/``tail`` strings
            and a ``qualifiers`` list of ``(key, value)`` strings).
        preprocessor: object exposing ``entity2id``, ``relation2id``,
            ``qualifier_key2id`` and ``qualifier_value2id`` lookups.
    """

    def __init__(self, data, preprocessor):
        self.data = data
        self.p    = preprocessor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        stmt  = self.data[idx]
        vocab = self.p
        ent   = vocab.entity2id
        return dict(
            head=ent[stmt['head']],
            relation=vocab.relation2id[stmt['relation']],
            tail=ent[stmt['tail']],
            qualifiers=[(vocab.qualifier_key2id[key],
                         vocab.qualifier_value2id[value])
                        for key, value in stmt['qualifiers']],
        )


def collate(batch):
    """Batch HRDataset samples for DataLoader.

    ``head``, ``relation`` and ``tail`` become 1-D ``torch.long`` tensors;
    the variable-length ``qualifiers`` stay a Python list of per-sample lists.
    """
    def ids(field):
        return torch.tensor([sample[field] for sample in batch], dtype=torch.long)

    return {'head': ids('head'),
            'relation': ids('relation'),
            'tail': ids('tail'),
            'qualifiers': [sample['qualifiers'] for sample in batch]}


def mt_collate(batch):
    """Group multi-task samples by their ``'task'`` field.

    Returns a ``defaultdict(list)`` mapping each task name to its samples,
    preserving batch order within each task.
    """
    per_task = defaultdict(list)
    for sample in batch:
        task_name = sample['task']
        per_task[task_name].append(sample)
    return per_task


print("Loading data...")
preprocessor = DataPreprocessor()
train_data, valid_data, test_data = preprocessor.load(CONFIG['data_path'])

# One id-mapped dataset per split, all sharing the same vocabularies.
train_ds, valid_ds, test_ds = (HRDataset(split, preprocessor)
                               for split in (train_data, valid_data, test_data))

# Vocabulary sizes used to size embedding tables and negative sampling.
NE, NR = len(preprocessor.entity2id), len(preprocessor.relation2id)
NQK, NQV = len(preprocessor.qualifier_key2id), len(preprocessor.qualifier_value2id)
DIM = CONFIG['embedding_dim']
print("Data ready")


# ============================================================================
# 3 — Shared Evaluation & Training Utilities
# ============================================================================

def evaluate_model(model, dataset, device, max_samples=500):
    """Compute raw (unfiltered) tail-prediction ranking metrics.

    Each evaluated sample is scored against every entity via
    ``model(h, r, [qualifiers])``; the rank of the true tail is its 1-based
    position in descending score order.

    Args:
        model: scorer returning one row of entity scores for a single
            (head, relation, qualifiers) query. It is switched to eval mode
            and left there.
        dataset: indexable sequence of HRDataset-style sample dicts.
        device: device the query tensors are created on.
        max_samples: evaluate at most this many leading samples; ``None``
            evaluates the whole dataset.

    Returns:
        dict with ``mr``, ``mrr``, ``hits@1``, ``hits@3``, ``hits@10``.
        If nothing is evaluated (empty dataset or ``max_samples=0``), all
        metrics are ``0.0`` instead of NaN.
    """
    model.eval()
    n_eval = len(dataset) if max_samples is None else min(len(dataset), max_samples)
    ranks = []
    with torch.no_grad():
        for i in tqdm(range(n_eval), desc="Evaluating", leave=False):
            s = dataset[i]
            h = torch.tensor([s['head']],     device=device)
            r = torch.tensor([s['relation']], device=device)
            # reshape(-1) instead of squeeze(): stays 1-D even with one entity.
            scores = model(h, r, [s['qualifiers']]).reshape(-1)
            order  = torch.argsort(scores, descending=True)
            rank   = (order == s['tail']).nonzero(as_tuple=True)[0].item() + 1
            ranks.append(rank)
    if not ranks:
        return {'mr': 0.0, 'mrr': 0.0, 'hits@1': 0.0, 'hits@3': 0.0, 'hits@10': 0.0}
    ranks = np.array(ranks)
    return {'mr':      float(np.mean(ranks)),
            'mrr':     float(np.mean(1.0 / ranks)),
            'hits@1':  float(np.mean(ranks <= 1)),
            'hits@3':  float(np.mean(ranks <= 3)),
            'hits@10': float(np.mean(ranks <= 10))}


def train_standard(model, model_name, train_ds_=None, valid_ds_=None,
                   ne_override=None, lr=None, epochs=None, gate_track=False):
    """Train a per-statement scorer with margin ranking loss.

    Each positive tail is contrasted with one uniformly random entity, which
    may occasionally coincide with the positive. Validation with
    ``evaluate_model`` runs every 5 epochs and after the final epoch. The
    best-MRR weights are saved to ``<output_path>/<model_name>_best.pt``.

    Args:
        model: scorer ``model(head, relation, qualifiers, tail=None)`` that
            returns one score per row when ``tail`` is given.
        model_name: label used in logs and in the checkpoint filename.
        train_ds_, valid_ds_: datasets to use (default: global ``train_ds`` /
            ``valid_ds``).
        ne_override: entity count for negative sampling (default ``NE``).
        lr, epochs: override ``CONFIG['learning_rate']`` / ``CONFIG['epochs']``.
        gate_track: if True and the model has a non-None ``gate`` parameter,
            record ``sigmoid(gate)`` after every epoch.

    Returns:
        ``(model, history, epoch_gates)``: the model with last-epoch weights
        (the best checkpoint is only on disk), one metrics dict per
        evaluation, and the per-epoch gate records.
    """
    dev  = CONFIG['device']
    lr   = CONFIG['learning_rate'] if lr is None else lr
    eps  = CONFIG['epochs']        if epochs is None else epochs
    _ne  = NE                      if ne_override is None else ne_override
    # `is None` rather than `or`: a Dataset's truthiness comes from __len__,
    # so an explicitly passed empty dataset must not fall back to the global.
    _tds = train_ds if train_ds_ is None else train_ds_
    _vds = valid_ds if valid_ds_ is None else valid_ds_
    model.to(dev)
    opt    = torch.optim.Adam(model.parameters(), lr=lr)
    loader = DataLoader(_tds, batch_size=CONFIG['batch_size'],
                        shuffle=True, collate_fn=collate)
    ckpt_path = os.path.join(CONFIG['output_path'], f"{model_name}_best.pt")
    print(f"\n{'='*60}\nTRAINING {model_name}  "
          f"({sum(p.numel() for p in model.parameters()):,} params)\n{'='*60}")
    best_mrr, history, epoch_gates = 0.0, [], []
    for epoch in range(eps):
        model.train(); total = 0.0
        for b in tqdm(loader, desc=f"Epoch {epoch+1}/{eps}", leave=False):
            h, r, t, q = (b['head'].to(dev), b['relation'].to(dev),
                          b['tail'].to(dev),  b['qualifiers'])
            pos  = model(h, r, q, t)
            neg  = model(h, r, q, torch.randint(0, _ne, (len(h),), device=dev))
            loss = F.margin_ranking_loss(pos, neg, torch.ones_like(pos), margin=1.0)
            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step(); total += loss.item()
        if gate_track and getattr(model, 'gate', None) is not None:
            epoch_gates.append({'epoch': epoch+1,
                                'gate': torch.sigmoid(model.gate).item()})
        # Also evaluate after the last epoch so runs whose length is not a
        # multiple of 5 still get validated and checkpointed.
        if (epoch+1) % 5 == 0 or epoch+1 == eps:
            m = evaluate_model(model, _vds, dev)
            print(f"  Epoch {epoch+1}: loss={total/max(len(loader), 1):.4f}  "
                  f"MRR={m['mrr']:.4f}  MR={m['mr']:.1f}  "
                  f"H@1={m['hits@1']:.4f}  H@3={m['hits@3']:.4f}  "
                  f"H@10={m['hits@10']:.4f}"
                  + (f"  gate={epoch_gates[-1]['gate']:.4f}"
                     if gate_track and epoch_gates else ""))
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(), ckpt_path)
                print("    → best saved")
            history.append(m)
    return model, history, epoch_gates


print("Utilities ready")


# ============================================================================
# 4 — Model 1: StarE
# ============================================================================

class StarEModel(nn.Module):
    """StarE-style scorer: relation embeddings enriched by qualifier attention.

    Per statement, the relation embedding attends (4-head
    ``nn.MultiheadAttention``) over the summed qualifier key + value
    embeddings. The query ``dropout(E[h] + R*[r])`` is scored against
    entity embeddings by dot product.
    """

    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        self.ent  = nn.Embedding(ne,  dim)
        self.rel  = nn.Embedding(nr,  dim)
        self.qk   = nn.Embedding(nqk, dim)
        self.qv   = nn.Embedding(nqv, dim)
        self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True)
        self.drop = nn.Dropout(dropout)
        for table in (self.ent, self.rel, self.qk, self.qv):
            nn.init.xavier_uniform_(table.weight)

    def _enrich(self, r_emb, quals, dev):
        """Attend ``r_emb`` over ``quals`` ([(key_id, value_id), ...]).

        An empty qualifier list returns ``r_emb`` unchanged.
        """
        if quals:
            key_ids = torch.tensor([pair[0] for pair in quals], device=dev)
            val_ids = torch.tensor([pair[1] for pair in quals], device=dev)
            memory  = (self.qk(key_ids) + self.qv(val_ids)).unsqueeze(0)  # [1, n_quals, dim]
            attended, _ = self.attn(r_emb.view(1, 1, -1), memory, memory)
            r_emb = attended.squeeze()
        return r_emb

    def forward(self, head, relation, qualifiers, tail=None):
        """Return ``[B]`` scores for ``tail`` or ``[B, num_entities]`` scores."""
        dev = head.device
        head_emb = self.ent(head)
        rel_rows = []
        for i in range(head.size(0)):
            base = self.rel(relation[i:i+1]).squeeze()
            rel_rows.append(self._enrich(base, qualifiers[i], dev))
        query = self.drop(head_emb + torch.stack(rel_rows))
        if tail is None:
            return query @ self.ent.weight.t()
        return (query * self.ent(tail)).sum(-1)

print("StarE defined")


# ============================================================================
# 5 — Model 2: ShrinkE
# ============================================================================

class ShrinkEModel(nn.Module):
    """ShrinkE-style scorer: qualifiers fused into the relation by a shrink MLP.

    The mean of the (key + value) qualifier embeddings is concatenated with
    the relation embedding and passed through Linear -> Tanh -> Dropout. Head
    and (fused) relation are both projected by ``proj``, summed and scored
    against raw entity embeddings by dot product.
    """

    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        self.ent    = nn.Embedding(ne,  dim)
        self.rel    = nn.Embedding(nr,  dim)
        self.qk     = nn.Embedding(nqk, dim)
        self.qv     = nn.Embedding(nqv, dim)
        self.shrink = nn.Sequential(nn.Linear(dim*2, dim), nn.Tanh(), nn.Dropout(dropout))
        self.proj   = nn.Linear(dim, dim)
        self.drop   = nn.Dropout(dropout)
        for table in (self.ent, self.rel, self.qk, self.qv):
            nn.init.xavier_uniform_(table.weight)

    def _shrink(self, r_emb, quals, dev):
        """Fuse ``quals`` ([(key_id, value_id), ...]) into ``r_emb``.

        An empty qualifier list returns ``r_emb`` unchanged.
        """
        if quals:
            key_ids = torch.tensor([pair[0] for pair in quals], device=dev)
            val_ids = torch.tensor([pair[1] for pair in quals], device=dev)
            qual_ctx = (self.qk(key_ids) + self.qv(val_ids)).mean(0, keepdim=True)  # [1, dim]
            fused_in = torch.cat([r_emb.unsqueeze(0), qual_ctx], -1)               # [1, 2*dim]
            r_emb = self.shrink(fused_in).squeeze()
        return r_emb

    def forward(self, head, relation, qualifiers, tail=None):
        """Return ``[B]`` scores for ``tail`` or ``[B, num_entities]`` scores."""
        dev = head.device
        head_proj = self.proj(self.ent(head))
        fused = torch.stack([
            self._shrink(self.rel(relation[i:i+1]).squeeze(), qualifiers[i], dev)
            for i in range(head.size(0))
        ])
        query = self.drop(head_proj + self.proj(fused))
        if tail is None:
            return query @ self.ent.weight.t()
        return (query * self.ent(tail)).sum(-1)

print("✓ ShrinkE defined")


# ============================================================================
# 6 — Model 3: TrueNBFNet (qualifier-unaware Bellman-Ford)
# ============================================================================

class NBFConvLayer(nn.Module):
    """One relational message-passing step of the Bellman-Ford network.

    Each edge ``(src, dst, rel)`` sends ``node_feat[src] * rel_emb[rel]``
    (elementwise). Messages are summed at their destination, concatenated
    with the node's current state, and passed through
    Linear -> (optional) LayerNorm -> ReLU. Edges are processed
    ``chunk_size`` at a time.
    """

    def __init__(self, dim, num_relation, chunk_size=10000, layer_norm=True):
        super().__init__()
        self.chunk_size = chunk_size
        self.rel_emb    = nn.Embedding(num_relation, dim)
        self.linear     = nn.Linear(dim*2, dim)
        self.ln         = nn.LayerNorm(dim) if layer_norm else None
        self.act        = nn.ReLU()
        nn.init.xavier_uniform_(self.rel_emb.weight)

    def forward(self, graph, node_feat):
        """Args: ``graph['edge_list']`` rows are (src, dst, rel); ``node_feat``
        is ``[N, B, D]``. Returns the updated ``[N, B, D]`` node states."""
        edges = graph['edge_list']
        n_nodes, batch, width = node_feat.shape
        summed = torch.zeros(n_nodes, batch, width, device=node_feat.device)
        for offset in range(0, edges.size(0), self.chunk_size):
            part = edges[offset:offset + self.chunk_size]
            msg = node_feat[part[:, 0]] * self.rel_emb(part[:, 2]).unsqueeze(1)  # [E_chunk, B, D]
            target = part[:, 1].view(-1, 1, 1).expand_as(msg)
            summed.scatter_add_(0, target, msg)
            del msg
        hidden = self.linear(torch.cat([node_feat, summed], dim=-1))
        if self.ln is not None:
            hidden = self.ln(hidden.flatten(0, 1)).view(n_nodes, batch, -1)
        return self.act(hidden)


class TrueNBFNet(nn.Module):
    """Qualifier-unaware Neural Bellman-Ford network.

    From a single source ``h``, node states are propagated for
    ``num_layers`` ``NBFConvLayer`` steps (with residual shortcut if
    ``short_cut``). Every node is then scored by an MLP over
    ``[state, query_embedding]``. The relation vocabulary is doubled
    (``nr * 2``) to cover the inverse edges added by ``build_nbfnet_graph``.
    """

    def __init__(self, ne, nr, dim=200, num_layers=3,
                 short_cut=True, chunk_size=10000, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        self.dim       = dim
        self.short_cut = short_cut
        n_rel_dir      = nr * 2   # forward + inverse relation ids
        self.query_emb = nn.Embedding(n_rel_dir, dim)
        self.layers    = nn.ModuleList(
            [NBFConvLayer(dim, n_rel_dir, chunk_size=chunk_size)
             for _ in range(num_layers)])
        self.mlp = nn.Sequential(
            nn.Linear(dim*2, dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, 1))
        nn.init.xavier_uniform_(self.query_emb.weight)

    def _bellman_ford(self, graph, h, r, device):
        """Propagate from source ``h`` under query relation ``r``.

        Returns a ``[num_nodes, 2*dim]`` tensor of final node states
        concatenated with the query embedding.
        """
        n_nodes = graph['num_nodes']
        query   = self.query_emb(torch.tensor([r], device=device))   # [1, dim]
        state   = torch.zeros(n_nodes, 1, self.dim, device=device)
        state[h, 0] = query[0]   # only the source node starts non-zero
        for layer in self.layers:
            updated = layer(graph, state)
            state = updated + state if self.short_cut else updated
        query_per_node = query.unsqueeze(0).expand(n_nodes, -1, -1)
        return torch.cat([state, query_per_node], dim=-1).squeeze(1)

    def forward(self, head, relation, qualifiers=None, tail=None, graph=None):
        """Score ``tail`` ([B]) or all nodes ([B, num_nodes]).

        Every row must share the same (head, relation) pair, because a
        single Bellman-Ford pass is run for the whole batch. ``qualifiers``
        is accepted for interface compatibility and ignored.
        """
        assert graph is not None
        assert (head==head[0]).all() and (relation==relation[0]).all()
        node_repr = self._bellman_ford(graph, head[0].item(), relation[0].item(),
                                       head.device)
        if tail is None:
            all_scores = self.mlp(node_repr).squeeze(-1)
            return all_scores.unsqueeze(0).expand(head.size(0), -1)
        return self.mlp(node_repr[tail]).squeeze(-1)


def build_nbfnet_graph(train_data, preprocessor, device):
    """Build the directed edge list used by TrueNBFNet.

    Each statement (h, r, t) contributes a forward edge ``(h, t, r)`` and an
    inverse edge ``(t, h, r + num_relations)``. Qualifiers are ignored.

    Returns:
        dict with ``edge_list`` (an ``[num_edges, 3]`` long tensor of
        ``(src, dst, relation)`` rows on ``device``) and ``num_nodes``.
    """
    ent, rel = preprocessor.entity2id, preprocessor.relation2id
    nr = len(rel)
    rows = []
    for t in train_data:
        h, r, tl = ent[t['head']], rel[t['relation']], ent[t['tail']]
        rows.append((h, tl, r))
        rows.append((tl, h, r + nr))
    # view(-1, 3) keeps the [E, 3] shape even with no edges; torch.tensor([])
    # alone would be 1-D with shape [0] and break `chunk[:, 0]` indexing.
    edge_list = torch.tensor(rows, dtype=torch.long, device=device).view(-1, 3)
    return {'edge_list': edge_list, 'num_nodes': len(ent)}


def train_true_nbfnet(model, model_name="TrueNBFNet",
                      train_data_=None, valid_ds_=None, graph_=None):
    """Train TrueNBFNet with margin ranking loss over (head, relation) groups.

    Training statements are grouped by (head, relation), so each model call
    runs one Bellman-Ford pass for all sampled tails of a group. At most
    ``CONFIG['nbfnet_max_per_group']`` tails are sampled per group per epoch,
    and each is paired with one uniformly random negative. Validation with
    ``evaluate_nbfnet`` runs every 5 epochs and after the final epoch. The
    best-MRR weights are saved to ``<output_path>/<model_name>_best.pt``.

    Args:
        model: TrueNBFNet-compatible model accepting ``graph=``.
        model_name: label used in logs and in the checkpoint filename.
        train_data_: parsed statements (default: global ``train_data``).
        valid_ds_: HRDataset for validation (default: global ``valid_ds``).
        graph_: dict from ``build_nbfnet_graph`` (default: global
            ``nbfnet_graph``, which must exist when this default is used).

    Returns:
        ``(model, history)`` with one metrics dict per evaluation.
    """
    dev  = CONFIG['device']
    # `is None` rather than `or`: empty lists / datasets are falsy and would
    # otherwise be silently replaced by the global defaults.
    _td  = train_data   if train_data_ is None else train_data_
    _vds = valid_ds     if valid_ds_   is None else valid_ds_
    _g   = nbfnet_graph if graph_      is None else graph_
    model.to(dev)
    graph = {k: v.to(dev) if torch.is_tensor(v) else v for k, v in _g.items()}
    opt   = torch.optim.Adam(model.parameters(), lr=CONFIG['nbfnet_lr'])
    ne_   = graph['num_nodes']
    mpg   = CONFIG['nbfnet_max_per_group']
    n_epochs = CONFIG['nbfnet_epochs']

    groups = defaultdict(list)
    for i, s in enumerate(_td):
        groups[(preprocessor.entity2id[s['head']],
                preprocessor.relation2id[s['relation']])].append(i)
    keys = list(groups)

    print(f"\n{'='*60}\nTRAINING {model_name}\n{'='*60}")
    best_mrr, history = 0.0, []
    for epoch in range(n_epochs):
        model.train(); np.random.shuffle(keys)
        total, cnt = 0.0, 0
        for (h, r) in tqdm(keys, desc=f"Epoch {epoch+1}", leave=False):
            chosen = np.random.choice(groups[(h, r)],
                                      min(len(groups[(h, r)]), mpg), replace=False)
            t_pos = torch.tensor(
                [preprocessor.entity2id[_td[i]['tail']] for i in chosen], device=dev)
            B     = len(t_pos)
            heads = torch.full((B,), h, dtype=torch.long, device=dev)
            rels  = torch.full((B,), r, dtype=torch.long, device=dev)
            pos   = model(heads, rels, tail=t_pos, graph=graph)
            neg   = model(heads, rels, tail=torch.randint(0, ne_, (B,), device=dev),
                          graph=graph)
            loss  = F.margin_ranking_loss(pos, neg, torch.ones(B, device=dev), margin=1.0)
            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step(); total += loss.item(); cnt += 1
        # Also evaluate after the last epoch so short runs get checkpointed.
        if (epoch+1) % 5 == 0 or epoch+1 == n_epochs:
            m = evaluate_nbfnet(model, _vds, graph, dev)
            print(f"  Epoch {epoch+1}: loss={total/max(cnt, 1):.4f}  "
                  f"MRR={m['mrr']:.4f}  MR={m['mr']:.1f}  "
                  f"H@1={m['hits@1']:.4f}  H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           os.path.join(CONFIG['output_path'], f"{model_name}_best.pt"))
            history.append(m)
    return model, history


def evaluate_nbfnet(model, dataset, graph, device, max_groups=300):
    """Raw tail-prediction metrics for a graph model, one pass per (h, r) group.

    Samples are grouped by (head, relation), and only the first
    ``max_groups`` groups (in first-seen order) are evaluated. Each group is
    scored with a single ``model(heads, rels, graph=graph)`` call, and every
    true tail is ranked by its 1-based position in descending score order.

    Groups whose forward pass raises ``RuntimeError`` (for example CUDA
    out-of-memory on a large graph) are skipped, and the number skipped is
    printed. Other exceptions propagate instead of silently producing zero
    metrics.

    Returns:
        dict with ``mr``, ``mrr``, ``hits@1``, ``hits@3``, ``hits@10``;
        all ``0.0`` when no rank could be computed.
    """
    model.eval()
    groups = defaultdict(list)
    for i, s in enumerate(dataset.data):
        groups[(preprocessor.entity2id[s['head']],
                preprocessor.relation2id[s['relation']])].append(i)
    ranks, skipped = [], 0
    with torch.no_grad():
        for (h, r) in tqdm(list(groups)[:max_groups], desc="Eval", leave=False):
            tails = [preprocessor.entity2id[dataset.data[i]['tail']]
                     for i in groups[(h, r)]]
            B     = len(tails)
            heads = torch.full((B,), h, dtype=torch.long, device=device)
            rels  = torch.full((B,), r, dtype=torch.long, device=device)
            try:
                scores = model(heads, rels, graph=graph)
            except RuntimeError:
                skipped += 1
                continue
            for i, tgt in enumerate(tails):
                pos = (torch.argsort(scores[i], descending=True) == tgt
                       ).nonzero(as_tuple=True)[0]
                if len(pos):
                    ranks.append(pos[0].item() + 1)
    if skipped:
        # warnings are globally filtered in this script, so report via print
        print(f"  evaluate_nbfnet: skipped {skipped} group(s) after RuntimeError")
    if not ranks:
        return {'mr': 0., 'mrr': 0., 'hits@1': 0., 'hits@3': 0., 'hits@10': 0.}
    ranks = np.array(ranks)
    return {'mr': float(np.mean(ranks)), 'mrr': float(np.mean(1. / ranks)),
            'hits@1': float(np.mean(ranks <= 1)), 'hits@3': float(np.mean(ranks <= 3)),
            'hits@10': float(np.mean(ranks <= 10))}

print("TrueNBFNet defined")


# ============================================================================
# 7 — Model 4: AlertStar (with ablation flags)
# ============================================================================

class AlertStarModel(nn.Module):
    """
    Gated mix of a StarE branch and a path-composition branch.

    StarE branch: ``LN1(E[h] + R*[r])``, where R* is the relation embedding
    after multi-head attention over qualifier (key + value) embeddings.
    Path branch:  ``LN2(E[h] + PathNet([E[h], R*[r]]))``.
    Mixed as ``g * stare + (1 - g) * path`` and scored by dot product.

    use_qual=False  → AS-NoQual  (qualifiers ignored, R* = R)
    use_path=False  → AS-NoPath  (StarE branch only, no gate)
    fixed_gate=0.5  → AS-NoGate  (constant g instead of learned sigmoid(gate))
    default         → AS-Full    (learned gate, raw value initialised to 0.5)
    """
    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1,
                 use_qual=True, use_path=True, fixed_gate=None):
        super().__init__()
        self.num_entities = ne
        self.use_qual = use_qual
        self.use_path = use_path
        self.fixed_gate = fixed_gate
        self.ent  = nn.Embedding(ne,  dim)
        self.rel  = nn.Embedding(nr,  dim)
        self.qk   = nn.Embedding(nqk, dim)
        self.qv   = nn.Embedding(nqv, dim)
        self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True)
        self.ln1  = nn.LayerNorm(dim)
        self.path_net = nn.Sequential(
            nn.Linear(dim*2, dim), nn.LayerNorm(dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(dim, dim))
        self.ln2  = nn.LayerNorm(dim)
        learn_gate = use_path and fixed_gate is None
        self.gate = nn.Parameter(torch.tensor(0.5)) if learn_gate else None
        self.drop = nn.Dropout(dropout)
        for table in (self.ent, self.rel, self.qk, self.qv):
            nn.init.xavier_uniform_(table.weight)

    def _enrich(self, r_emb, quals, dev):
        """Attend ``r_emb`` over ``quals`` unless qualifiers are disabled or empty."""
        if self.use_qual and quals:
            key_ids = torch.tensor([pair[0] for pair in quals], device=dev)
            val_ids = torch.tensor([pair[1] for pair in quals], device=dev)
            memory  = (self.qk(key_ids) + self.qv(val_ids)).unsqueeze(0)
            attended, _ = self.attn(r_emb.view(1, 1, -1), memory, memory)
            return attended.squeeze()
        return r_emb

    def forward(self, head, relation, qualifiers, tail=None):
        """Return ``[B]`` scores for ``tail`` or ``[B, num_entities]`` scores."""
        dev      = head.device
        head_emb = self.ent(head)
        rel_emb  = torch.stack([
            self._enrich(self.rel(relation[i:i+1]).squeeze(), qualifiers[i], dev)
            for i in range(head.size(0))])
        stare = self.ln1(head_emb + rel_emb)
        if self.use_path:
            path = self.ln2(head_emb + self.path_net(torch.cat([head_emb, rel_emb], dim=-1)))
            if self.gate is None:
                g = torch.tensor(self.fixed_gate, device=dev)
            else:
                g = torch.sigmoid(self.gate)
            mixed = g*stare + (1-g)*path
        else:
            mixed = stare
        query = self.drop(mixed)
        if tail is None:
            return query @ self.ent.weight.t()
        return (query * self.ent(tail)).sum(-1)

print("AlertStar defined")

# ============================================================================
# 8 — Model 5: StarQE (Complex Query Answering: 1p / 2p / 2i / 2u)
# ============================================================================
# StarQE extends link prediction to four first-order logic query types:
#   1p: ∃e: r(h,e)                       — direct 1-hop
#   2p: ∃e1: r1(h,e1) ∧ r2(e1,e)        — 2-hop chain
#   2i: r1(h1,e) ∧ r2(h2,e)             — 2-anchor intersection
#   2u: r1(h1,e) ∨ r2(h2,e)             — 2-anchor union
# 1p, 2p and 2i queries are trained with margin ranking loss (the 2u operator
# is defined but not used during training); evaluation uses 1p queries only
# (tail prediction) for a fair comparison with the other models.
# ============================================================================

class StarQEModel(nn.Module):
    """StarQE-style complex-query model over hyper-relational statements.

    The anchor step is ``ρ(E[h] + R*[r])``, where ρ is
    Linear -> LayerNorm -> ReLU -> Dropout and R* is the StarE-style,
    qualifier-attended relation embedding. 2p applies ρ twice; 2i fuses two
    anchor steps with a linear intersection operator; 2u averages them.
    Qualifiers only enrich the first relation / first anchor.
    """

    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        self.dim  = dim
        self.ent  = nn.Embedding(ne, dim)
        self.rel  = nn.Embedding(nr, dim)
        # qualifier embeddings used to enrich relations, as in StarE
        self.qk   = nn.Embedding(nqk, dim)
        self.qv   = nn.Embedding(nqv, dim)
        self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True)
        # composition ρ: maps x + R[r] to the next entity-space vector
        self.compose = nn.Sequential(
            nn.Linear(dim, dim), nn.LayerNorm(dim), nn.ReLU(), nn.Dropout(dropout))
        # intersection operator W_∩: fuses two anchor embeddings
        self.intersect = nn.Linear(dim*2, dim)
        self.drop = nn.Dropout(dropout)
        for table in (self.ent, self.rel, self.qk, self.qv):
            nn.init.xavier_uniform_(table.weight)

    def _enrich_rel(self, r_id, quals, dev):
        """Relation embedding for ``r_id``, attended over ``quals`` if any."""
        r_emb = self.rel(torch.tensor([r_id], device=dev)).squeeze()
        if quals:
            key_ids = torch.tensor([pair[0] for pair in quals], device=dev)
            val_ids = torch.tensor([pair[1] for pair in quals], device=dev)
            memory  = (self.qk(key_ids) + self.qv(val_ids)).unsqueeze(0)
            attended, _ = self.attn(r_emb.view(1, 1, -1), memory, memory)
            r_emb = attended.squeeze()
        return r_emb

    def _compose_1p(self, h_id, r_id, quals, dev):
        """1p: x = ρ(E[h] + R*[r])"""
        anchor = self.ent(torch.tensor([h_id], device=dev)).squeeze()
        return self.compose(anchor + self._enrich_rel(r_id, quals, dev))

    def _compose_2p(self, h_id, r1_id, r2_id, quals, dev):
        """2p: x1 = ρ(E[h] + R*[r1]),  x = ρ(x1 + R[r2])"""
        hop1 = self._compose_1p(h_id, r1_id, quals, dev)
        return self.compose(hop1 + self._enrich_rel(r2_id, [], dev))

    def _compose_2i(self, h1, r1, h2, r2, quals, dev):
        """2i: W_∩([ρ(E[h1] + R*[r1]), ρ(E[h2] + R[r2])])"""
        left  = self._compose_1p(h1, r1, quals, dev)
        right = self._compose_1p(h2, r2, [], dev)
        return self.intersect(torch.cat([left, right], dim=-1))

    def _compose_2u(self, h1, r1, h2, r2, quals, dev):
        """2u: mean of ρ(E[h1] + R*[r1]) and ρ(E[h2] + R[r2])"""
        left  = self._compose_1p(h1, r1, quals, dev)
        right = self._compose_1p(h2, r2, [], dev)
        return (left + right) / 2

    def forward(self, head, relation, qualifiers, tail=None):
        """Standard 1p interface — compatible with evaluate_model.

        Returns ``[B]`` scores for ``tail`` or ``[B, num_entities]`` scores.
        """
        dev = head.device
        queries = [self._compose_1p(head[i].item(), relation[i].item(),
                                    qualifiers[i], dev)
                   for i in range(head.size(0))]
        x = self.drop(torch.stack(queries))          # [B, dim]
        if tail is None:
            return x @ self.ent.weight.t()
        return (x * self.ent(tail)).sum(-1)

    def score_query(self, query_vec, tail_id, dev):
        """Dot-product score of a composed query vector against one tail id."""
        tail_emb = self.ent(torch.tensor([tail_id], device=dev)).squeeze()
        return (query_vec * tail_emb).sum()

def train_starqe(model, model_name="StarQE",
                 train_data_=None, valid_ds_=None, ne_override=None):
    """
    Train a StarQE-style model on 1p, 2p and 2i queries built from triples.

    For every training statement r(h, t) with qualifiers qs, the loss is the
    mean over whichever of these query types can be built:

      1p: r(h, ?)                     — always (answer t)
      2p: r(h, t) ∧ r2(t, ?)          — if t is the head of some observed (t, r2)
      2i: r(h, ?) ∧ r2(h2, ?)         — if another anchor (h2 ≠ h, r2) with
                                         r2(h2, t) observed exists (answer t)

    Qualifiers ``qs`` enrich only the first hop / first anchor. The 2u
    operator is not used in training. Negatives are uniform random entities,
    and the loss is margin ranking with margin 1. Validation
    (``evaluate_model``, 1p tail prediction) runs every 5 epochs and after
    the final epoch. The best-MRR weights are saved to
    ``<output_path>/<model_name>_best.pt``.

    Returns:
        ``(model, history)`` with one metrics dict per evaluation.
    """
    dev  = CONFIG['device']
    # `is None` rather than `or`: empty lists / datasets are falsy and would
    # otherwise be silently replaced by the global defaults.
    _td  = train_data if train_data_ is None else train_data_
    _vds = valid_ds   if valid_ds_   is None else valid_ds_
    _ne  = NE         if ne_override is None else ne_override
    n_epochs = CONFIG['query_epochs']
    model.to(dev)
    opt  = torch.optim.Adam(model.parameters(), lr=CONFIG['query_lr'])

    def build_queries(data):
        """Encode statements as (h, r, t, quals) ids; index (h, r) -> tails."""
        hr2tails = defaultdict(list)
        triples  = []
        for s in data:
            h  = preprocessor.entity2id[s['head']]
            r  = preprocessor.relation2id[s['relation']]
            t  = preprocessor.entity2id[s['tail']]
            qs = [(preprocessor.qualifier_key2id[qk],
                   preprocessor.qualifier_value2id[qv])
                  for qk, qv in s['qualifiers']]
            hr2tails[(h, r)].append(t)
            triples.append((h, r, t, qs))
        return triples, hr2tails

    def build_indices(hr2tails):
        """Index relations by head, and (head, relation) anchors by tail.

        Built once, so per-triple lookups avoid scanning every (h, r) key.
        """
        head2rels  = defaultdict(list)   # h -> [r, ...] with (h, r) observed
        tail2pairs = defaultdict(list)   # t -> [(h, r), ...] with r(h, t) observed
        for (hh, rr), tails in hr2tails.items():
            head2rels[hh].append(rr)
            for tt in dict.fromkeys(tails):   # de-duplicate, keep order
                tail2pairs[tt].append((hh, rr))
        return head2rels, tail2pairs

    def pick(seq):
        """Uniformly random element of a non-empty sequence."""
        return seq[np.random.randint(len(seq))]

    def query_margin_loss(query_vec, pos_id, neg_id):
        """Margin ranking loss of a composed query vs a positive/negative tail."""
        pos_e = model.ent(torch.tensor([pos_id], device=dev)).squeeze()
        neg_e = model.ent(torch.tensor([neg_id], device=dev)).squeeze()
        # model.drop is applied separately for the positive and negative score
        pos = (model.drop(query_vec) * pos_e).sum()
        neg = (model.drop(query_vec) * neg_e).sum()
        return F.margin_ranking_loss(pos.unsqueeze(0), neg.unsqueeze(0),
                                     torch.ones(1, device=dev), margin=1.0)

    triples, hr2tails     = build_queries(_td)
    head2rels, tail2pairs = build_indices(hr2tails)

    print(f"\n{'='*60}\nTRAINING {model_name}  "
          f"({sum(p.numel() for p in model.parameters()):,} params)")
    print(f"  Triples: {len(triples):,}  (h,r) pairs: {len(hr2tails):,}\n{'='*60}")

    best_mrr, history = 0.0, []
    for epoch in range(n_epochs):
        model.train(); total, cnt = 0.0, 0
        np.random.shuffle(triples)
        for h, r, t, qs in tqdm(triples, desc=f"Epoch {epoch+1}", leave=False):
            losses = []
            # ── 1p query ─────────────────────────────────────────────────
            neg_t = np.random.randint(0, _ne)
            h_t   = torch.tensor([h], device=dev)
            r_t   = torch.tensor([r], device=dev)
            pos_s = model(h_t, r_t, [qs], torch.tensor([t],     device=dev))
            neg_s = model(h_t, r_t, [qs], torch.tensor([neg_t], device=dev))
            losses.append(F.margin_ranking_loss(
                pos_s, neg_s, torch.ones(1, device=dev), margin=1.0))

            # ── 2p query (chain through t as the intermediate entity) ───
            r2_candidates = head2rels.get(t)
            if r2_candidates:
                r2  = pick(r2_candidates)
                t2  = pick(hr2tails[(t, r2)])
                q2p = model._compose_2p(h, r, r2, qs, dev)
                losses.append(query_margin_loss(q2p, t2, np.random.randint(0, _ne)))

            # ── 2i query (second anchor (h2, r2), h2 != h, also reaching t) ─
            # (h2, r2) is sampled as a pair so that r2(h2, t) actually holds.
            anchors = [(hh, rr) for hh, rr in tail2pairs.get(t, ()) if hh != h]
            if anchors:
                h2, r2 = pick(anchors)
                q2i = model._compose_2i(h, r, h2, r2, qs, dev)
                losses.append(query_margin_loss(q2i, t, np.random.randint(0, _ne)))

            loss = sum(losses) / len(losses)
            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step(); total += loss.item(); cnt += 1

        # Also evaluate after the last epoch so short runs get checkpointed.
        if (epoch+1) % 5 == 0 or epoch+1 == n_epochs:
            m = evaluate_model(model, _vds, dev)
            print(f"  Epoch {epoch+1}: loss={total/max(cnt,1):.4f}  "
                  f"MRR={m['mrr']:.4f}  MR={m['mr']:.1f}  "
                  f"H@1={m['hits@1']:.4f}  H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           os.path.join(CONFIG['output_path'], f"{model_name}_best.pt"))
                print("    → best saved")
            history.append(m)
    return model, history

print("StarQE defined")


# ============================================================================
# 9 — Model 6: NBFNet+StarQE (residual path-augmented complex queries)
# ============================================================================
# Replaces StarQE's single-step composition ρ(x+R[r]) (Linear→LayerNorm→ReLU)
# with a residual MLP:
#   x^l = x^0 + PathNet( Concat(x^{l-1}, R*[r_l]) )
# Intersection and union operators inherited unchanged from StarQE.
# ============================================================================

class NBFNetStarQEModel(nn.Module):
    """StarQE variant whose relation projection is a residual MLP ("PathNet").

    Each projection step computes ``x_l = x_0 + PathNet([x_{l-1}; r])`` where
    ``x_0`` is the anchor-entity embedding and ``r`` is the relation embedding.
    When qualifiers are given, ``r`` is replaced by the output of multi-head
    attention from the relation embedding over ``qk + qv`` qualifier
    embeddings.  ``2i`` merges two 1p branches with a learned linear layer and
    ``2u`` averages them.  Answers are scored by dot product against the
    entity embedding table.

    Args:
        ne, nr, nqk, nqv: number of entities, relations, qualifier keys and
            qualifier values.
        dim: embedding width (must be divisible by the 4 attention heads).
        dropout: dropout used in attention, inside PathNet and on query vectors.
    """

    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        self.dim = dim
        # Sub-modules are created in a fixed order so that seeded
        # initialisation and state_dict keys stay reproducible.
        self.ent = nn.Embedding(ne, dim)
        self.rel = nn.Embedding(nr, dim)
        self.qk = nn.Embedding(nqk, dim)
        self.qv = nn.Embedding(nqv, dim)
        self.attn = nn.MultiheadAttention(
            dim, 4, dropout=dropout, batch_first=True)
        pathnet_layers = [
            nn.Linear(2 * dim, dim),   # [x; r] in R^{2d} -> R^d
            nn.LayerNorm(dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),       # R^d -> R^d residual update
        ]
        self.path_net = nn.Sequential(*pathnet_layers)
        self.intersect = nn.Linear(2 * dim, dim)   # 2i merge operator
        self.drop = nn.Dropout(dropout)
        for table in (self.ent, self.rel, self.qk, self.qv):
            nn.init.xavier_uniform_(table.weight)

    def _anchor(self, ent_id, dev):
        """Embedding vector ``[dim]`` of a single entity id."""
        return self.ent(torch.tensor([ent_id], device=dev)).squeeze()

    def _enrich_rel(self, r_id, quals, dev):
        """Relation vector; attention-enriched by ``quals`` when any are given."""
        base = self.rel(torch.tensor([r_id], device=dev)).squeeze()
        if not quals:
            return base
        key_ids = torch.tensor([pair[0] for pair in quals], device=dev)
        val_ids = torch.tensor([pair[1] for pair in quals], device=dev)
        memory = (self.qk(key_ids) + self.qv(val_ids)).unsqueeze(0)
        attended, _ = self.attn(base.view(1, 1, -1), memory, memory)
        return attended.squeeze()

    def _compose_step(self, x, x0, r_emb):
        """x^l = x^0 + PathNet(Concat(x^{l-1}, R[r]))"""
        update = self.path_net(torch.cat([x, r_emb], dim=-1))
        return x0 + update

    def _compose_1p(self, h_id, r_id, quals, dev):
        """1-hop query: project the anchor through one (qualified) relation."""
        anchor = self._anchor(h_id, dev)
        return self._compose_step(anchor, anchor,
                                  self._enrich_rel(r_id, quals, dev))

    def _compose_2p(self, h_id, r1_id, r2_id, quals, dev):
        """2-hop chain; qualifiers attach to the first hop only, and both
        hops use the anchor embedding as their residual base."""
        anchor = self._anchor(h_id, dev)
        first_rel = self._enrich_rel(r1_id, quals, dev)
        second_rel = self._enrich_rel(r2_id, [], dev)
        hop = self._compose_step(anchor, anchor, first_rel)
        return self._compose_step(hop, anchor, second_rel)

    def _branches(self, h1, r1, h2, r2, quals, dev):
        """Two independent 1p branches; only the first carries qualifiers."""
        return (self._compose_1p(h1, r1, quals, dev),
                self._compose_1p(h2, r2, [], dev))

    def _compose_2i(self, h1, r1, h2, r2, quals, dev):
        """Intersection: learned linear merge of the two branch vectors."""
        left, right = self._branches(h1, r1, h2, r2, quals, dev)
        return self.intersect(torch.cat([left, right], dim=-1))

    def _compose_2u(self, h1, r1, h2, r2, quals, dev):
        """Union: element-wise mean of the two branch vectors."""
        left, right = self._branches(h1, r1, h2, r2, quals, dev)
        return (left + right) / 2

    def forward(self, head, relation, qualifiers, tail=None):
        """Score 1p queries.

        Returns ``[B]`` scores for the given ``tail`` ids, or ``[B, ne]``
        scores against every entity when ``tail`` is None.
        """
        dev = head.device
        batch = head.size(0)
        queries = [
            self._compose_1p(head[i].item(), relation[i].item(),
                             qualifiers[i], dev)
            for i in range(batch)
        ]
        x = self.drop(torch.stack(queries))
        if tail is None:
            return x @ self.ent.weight.t()
        return (x * self.ent(tail)).sum(-1)

# NBFNet+StarQE reuses train_starqe with the same interface
print("NBFNet+StarQE defined")


# ============================================================================
# 10 — Model 7: HyNT
# ============================================================================

class HyNTModel(nn.Module):
    """HyNT-style competitor: Transformer over [head, relation, tail, qualifier-context] tokens.

    The qualifier context is the mean of self-attended ``qk + qv`` embeddings
    (a zero vector when there are no qualifiers).  The encoder output at
    position 0 feeds two heads:

    * ``tail_head`` - logits over entities, with the tail token zeroed;
    * ``qv_head``   - logits over qualifier values, after gating the context
      with the queried key's embedding (that key's own pair is removed from
      the context).

    Args:
        ne, nr, nqk, nqv: entity / relation / qualifier-key / qualifier-value
            vocabulary sizes.
        dim: model width (must be divisible by ``n_heads``).
        dropout: dropout rate for attention, encoder and heads.
        n_heads, n_layers: attention heads and encoder depth.
    """

    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1,
                 n_heads=4, n_layers=2):
        super().__init__()
        self.num_entities = ne
        self.dim = dim
        # Creation order preserved for reproducible seeded initialisation.
        self.ent = nn.Embedding(ne, dim)
        self.rel = nn.Embedding(nr, dim)
        self.qk = nn.Embedding(nqk, dim)
        self.qv = nn.Embedding(nqv, dim)
        self.qual_attn = nn.MultiheadAttention(
            dim, n_heads, dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim, nhead=n_heads, dim_feedforward=4 * dim,
                dropout=dropout, batch_first=True),
            num_layers=n_layers)

        def mlp_head(n_out):
            # Linear -> LayerNorm -> ReLU -> Dropout -> Linear(dim, n_out)
            return nn.Sequential(
                nn.Linear(dim, dim), nn.LayerNorm(dim), nn.ReLU(),
                nn.Dropout(dropout), nn.Linear(dim, n_out))

        self.tail_head = mlp_head(ne)
        self.qv_head = mlp_head(nqv)
        self.qk_gate = nn.Sequential(nn.Linear(2 * dim, dim), nn.Sigmoid())
        for table in (self.ent, self.rel, self.qk, self.qv):
            nn.init.xavier_uniform_(table.weight)

    def _aggregate_qualifiers(self, qs, dev):
        """Mean of self-attended qualifier tokens; zeros ``[dim]`` if ``qs`` is empty."""
        if not qs:
            return torch.zeros(self.dim, device=dev)
        key_ids = torch.tensor([pair[0] for pair in qs], device=dev)
        val_ids = torch.tensor([pair[1] for pair in qs], device=dev)
        tokens = (self.qk(key_ids) + self.qv(val_ids)).unsqueeze(0)
        mixed, _ = self.qual_attn(tokens, tokens, tokens)
        return mixed.mean(dim=1).squeeze(0)

    def _encode(self, h, r, t, qs, dev, mask_tail=False):
        """Encode one fact; returns the encoder output at position 0 (head token)."""
        head_tok = self.ent(torch.tensor([h], device=dev))
        rel_tok = self.rel(torch.tensor([r], device=dev))
        if mask_tail:
            tail_tok = torch.zeros(1, self.dim, device=dev)
        else:
            tail_tok = self.ent(torch.tensor([t], device=dev))
        ctx_tok = self._aggregate_qualifiers(qs, dev).unsqueeze(0)
        sequence = torch.cat([head_tok, rel_tok, tail_tok, ctx_tok], dim=0)
        return self.encoder(sequence.unsqueeze(0))[0, 0]

    def forward_tail(self, s, dev):
        """Entity logits for sample dict ``s`` (keys 'h', 'r', 'qs'); tail is masked."""
        encoded = self._encode(s['h'], s['r'], 0, s['qs'], dev, mask_tail=True)
        return self.tail_head(encoded)

    def forward_qv(self, s, dev):
        """Qualifier-value logits for the key ``s['qk']``, excluding that key's pairs from the context."""
        target_key = s['qk']
        context_quals = [(k, v) for k, v in s['qs'] if k != target_key]
        ctx = self._encode(s['h'], s['r'], s['t'], context_quals, dev)
        key_vec = self.qk(torch.tensor([target_key], device=dev)).squeeze()
        gated = self.qk_gate(torch.cat([ctx, key_vec], dim=-1)) * ctx
        return self.qv_head(gated)

    def forward(self, head, relation, qualifiers, tail=None):
        """``[B, ne]`` tail logits, or ``[B]`` logits of the given ``tail`` ids."""
        dev = head.device
        batch = head.size(0)
        rows = []
        for i in range(batch):
            sample = {'h': head[i].item(), 'r': relation[i].item(),
                      't': 0, 'qs': qualifiers[i]}
            rows.append(self.forward_tail(sample, dev))
        scores = torch.stack(rows, dim=0)
        if tail is None:
            return scores
        return scores[torch.arange(batch), tail]


class HyNTDataset(Dataset):
    """Flatten hyper-relational triples into HyNT training samples.

    Every triple yields one ``'tail'`` sample and then one ``'qv'`` sample per
    qualifier pair, in qualifier order.  All samples from the same triple
    share a single id-encoded qualifier list ``qs``.
    """

    def __init__(self, data, preprocessor):
        self.samples = []
        for triple in data:
            head_id = preprocessor.entity2id[triple['head']]
            rel_id = preprocessor.relation2id[triple['relation']]
            tail_id = preprocessor.entity2id[triple['tail']]
            quals = [
                (preprocessor.qualifier_key2id[key],
                 preprocessor.qualifier_value2id[val])
                for key, val in triple['qualifiers']
            ]
            common = dict(h=head_id, r=rel_id, t=tail_id, qs=quals)
            self.samples.append(dict(task='tail', **common))
            self.samples.extend(
                dict(task='qv', **common, qk=key_id, qv=val_id)
                for key_id, val_id in quals)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def train_hynt(model, model_name="HyNT", train_data_=None, valid_ds_=None):
    """Train a HyNT model on tail prediction plus qualifier-value prediction.

    Each batch from ``mt_collate`` maps task name -> samples.  ``'tail'``
    samples add cross-entropy over entities; ``'qv'`` samples add
    0.8 x cross-entropy over qualifier values.  Per-sample losses are summed
    (not averaged) into one batch loss.  Every 5 epochs the model is validated
    with ``evaluate_model`` and the best-MRR weights are saved to
    ``{CONFIG['output_path']}{model_name}_best.pt``.

    Args:
        model: a HyNTModel (anything exposing ``forward_tail`` / ``forward_qv``).
        model_name: label used in logs and in the checkpoint filename.
        train_data_: training triples; the global ``train_data`` is used when
            this is None or empty.
        valid_ds_: validation dataset; falls back to the global ``valid_ds``.

    Returns:
        ``(model, history)``; ``history`` holds one metrics dict per
        validation round.
    """
    dev  = CONFIG['device']
    _td  = train_data_ or train_data
    _vds = valid_ds_   or valid_ds
    model.to(dev)
    opt    = torch.optim.Adam(model.parameters(), lr=CONFIG['hynt_lr'])
    loader = DataLoader(HyNTDataset(_td, preprocessor),
                        batch_size=CONFIG['hynt_batch'],
                        shuffle=True, collate_fn=mt_collate)
    print(f"\n{'='*60}\nTRAINING {model_name}  "
          f"({sum(p.numel() for p in model.parameters()):,} params)\n{'='*60}")
    best_mrr, history = 0.0, []
    for epoch in range(CONFIG['hynt_epochs']):
        model.train()
        total, cnt, skipped = 0.0, 0, 0
        for by_task in tqdm(loader, desc=f"Epoch {epoch+1}", leave=False):
            bl = torch.tensor(0., device=dev)
            n_terms = 0
            for task, samples in by_task.items():
                for s in samples:
                    try:
                        if task == 'tail':
                            logits = model.forward_tail(s, dev)
                            tgt    = torch.tensor(s['t'], device=dev)
                            term   = F.cross_entropy(logits.unsqueeze(0), tgt.unsqueeze(0))
                        elif task == 'qv':
                            logits = model.forward_qv(s, dev)
                            tgt    = torch.tensor(s['qv'], device=dev)
                            term   = 0.8 * F.cross_entropy(logits.unsqueeze(0), tgt.unsqueeze(0))
                        else:
                            continue
                    except Exception:
                        # Still best-effort per sample, but counted so that
                        # systematic failures are visible in the logs.
                        skipped += 1
                        continue
                    bl = bl + term
                    n_terms += 1
            if n_terms == 0:
                # bl is still the constant 0-tensor (no grad_fn): calling
                # backward() on it would raise, so skip this batch.
                continue
            opt.zero_grad(); bl.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step(); total += bl.item(); cnt += 1
        if skipped:
            print(f"  Epoch {epoch+1}: skipped {skipped} sample(s) whose forward/loss raised")
        if (epoch+1) % 5 == 0:
            m = evaluate_model(model, _vds, dev)
            print(f"  Epoch {epoch+1}: loss={total/max(cnt,1):.4f}  "
                  f"MRR={m['mrr']:.4f}  H@1={m['hits@1']:.4f}  H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
            history.append(m)
    return model, history

print("HyNT defined")


# ============================================================================
# 11 — Model 8: MultiTask AlertStar
# ============================================================================

class MultiTaskDataset(Dataset):
    """Expand triples into per-task samples for multi-task training.

    Per triple, samples are emitted in this order: ``'tail'``, ``'relation'``,
    then (only if the triple has qualifiers) one ``'qual_key'`` sample listing
    all key ids and one ``'qual_value'`` sample per qualifier pair.  Only tasks
    named in ``tasks`` are emitted; ``tasks=None`` means all four.

    ``self.nqk`` records the qualifier-key vocabulary size
    (``nqk_override`` if truthy, else the global ``NQK``).
    """

    def __init__(self, data, preprocessor, tasks=None, nqk_override=None):
        self.samples = []
        self.nqk = nqk_override or NQK
        wanted = (['tail', 'relation', 'qual_key', 'qual_value']
                  if tasks is None else tasks)
        use_tail = 'tail' in wanted
        use_rel = 'relation' in wanted
        use_key = 'qual_key' in wanted
        use_val = 'qual_value' in wanted
        for triple in data:
            head_id = preprocessor.entity2id[triple['head']]
            rel_id = preprocessor.relation2id[triple['relation']]
            tail_id = preprocessor.entity2id[triple['tail']]
            quals = [
                (preprocessor.qualifier_key2id[key],
                 preprocessor.qualifier_value2id[val])
                for key, val in triple['qualifiers']
            ]
            base = dict(h=head_id, r=rel_id, t=tail_id, qs=quals)
            if use_tail:
                self.samples.append(dict(task='tail', **base))
            if use_rel:
                self.samples.append(dict(task='relation', **base))
            if not quals:
                continue
            if use_key:
                self.samples.append(
                    dict(task='qual_key', **base, keys=[k for k, _ in quals]))
            if use_val:
                self.samples.extend(
                    dict(task='qual_value', **base, qk=k, qv=v)
                    for k, v in quals)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


class MultiTaskAlertStar(nn.Module):
    """Transformer over a flat ``[h, r, t, qk_1, qv_1, qk_2, qv_2, ...]`` token sequence with four task heads.

    The encoder output at position 0 (the head-entity token) is fed to:

    * ``tail``       - tail token zeroed; logits over entities.
    * ``relation``   - relation token zeroed; logits over relations.
    * ``qual_key``   - no qualifier tokens; multi-label logits over keys.
    * ``qual_value`` - the other qualifier pairs, followed by the queried key
      with its value slot zeroed; logits over qualifier values.

    Args:
        ne, nr, nqk, nqv: entity / relation / qualifier-key / qualifier-value
            vocabulary sizes.
        dim: model width (must be divisible by ``n_heads``).
        dropout, n_heads, n_layers: Transformer hyper-parameters.
    """

    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1,
                 n_heads=4, n_layers=3):
        super().__init__()
        self.num_entities = ne
        self.ent  = nn.Embedding(ne,  dim)
        self.rel  = nn.Embedding(nr,  dim)
        self.qk   = nn.Embedding(nqk, dim)
        self.qv   = nn.Embedding(nqv, dim)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=n_heads,
            dim_feedforward=dim*4, dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        def head(out):
            # Linear -> LayerNorm -> ReLU -> Dropout -> Linear(dim, out)
            return nn.Sequential(nn.Linear(dim,dim), nn.LayerNorm(dim),
                                 nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim,out))
        self.tail_head = head(ne); self.rel_head  = head(nr)
        self.qk_head   = head(nqk); self.qv_head  = head(nqv)
        for e in [self.ent,self.rel,self.qk,self.qv]:
            nn.init.xavier_uniform_(e.weight)

    def _encode(self, h, r, t, qs, dev, mask_pos=None):
        """Encode one fact and return the encoder output at position 0.

        Token layout is ``[h, r, t, qk_1, qv_1, ...]``.  When ``mask_pos`` is
        a valid position, that token's input embedding is replaced by zeros.
        """
        tokens = [self.ent(torch.tensor([h],device=dev)),
                  self.rel(torch.tensor([r],device=dev)),
                  self.ent(torch.tensor([t],device=dev))]
        for qk_id, qv_id in qs:
            tokens += [self.qk(torch.tensor([qk_id],device=dev)),
                       self.qv(torch.tensor([qv_id],device=dev))]
        seq = torch.cat(tokens, dim=0).unsqueeze(0)
        if mask_pos is not None and mask_pos < seq.size(1):
            seq = seq.clone(); seq[0, mask_pos] = 0.0
        return self.encoder(seq)[0, 0]

    def forward_task(self, task, s, dev):
        """Logits of ``task`` for sample dict ``s``; returns None for unknown tasks."""
        h, r, t, qs = s['h'], s['r'], s['t'], s['qs']
        if task == 'tail':
            return self.tail_head(self._encode(h,r,0,qs,dev,mask_pos=2))
        elif task == 'relation':
            return self.rel_head(self._encode(h,r,t,qs,dev,mask_pos=1))
        elif task == 'qual_key':
            return self.qk_head(self._encode(h,r,t,[],dev))
        elif task == 'qual_value':
            context = [(k,v) for k,v in qs if k != s['qk']]
            # Without the queried key the target is ambiguous whenever a fact
            # has several qualifiers.  Append the key with a placeholder value
            # id (0) and zero that value slot via mask_pos: the encoder knows
            # *which* key to fill, but the label never enters the input.
            value_slot = 3 + 2 * len(context) + 1
            return self.qv_head(self._encode(h, r, t, context + [(s['qk'], 0)],
                                             dev, mask_pos=value_slot))
        return None

    def forward(self, head, relation, qualifiers, tail=None):
        """``[B, ne]`` tail logits, or ``[B]`` logits of the given ``tail`` ids."""
        dev  = head.device; B = head.size(0)
        outs = [self.forward_task('tail',{'h':head[i].item(),'r':relation[i].item(),
                                           't':0,'qs':qualifiers[i]},dev)
                for i in range(B)]
        scores = torch.stack(outs, dim=0)
        if tail is not None: return scores[torch.arange(B), tail]
        return scores


def train_multitask(model, model_name="MultiTask_AlertStar",
                    active_tasks=None, train_data_=None, valid_ds_=None,
                    nqk_override=None):
    """Multi-task training for MultiTaskAlertStar-style models.

    Per-sample task losses are weighted (tail 1.0, relation 1.0, qual_key 0.5,
    qual_value 0.8) and summed into one batch loss:

    * tail / relation / qual_value use cross-entropy against ``s['t']`` /
      ``s['r']`` / ``s['qv']``;
    * qual_key uses BCE-with-logits against a multi-hot vector of length
      ``nqk_override or NQK``.

    Every 5 epochs the model is validated with ``evaluate_model`` and the
    best-MRR weights are saved to
    ``{CONFIG['output_path']}{model_name}_best.pt``.

    Args:
        model: exposes ``forward_task(task, sample, device)``.
        model_name: label used in logs and in the checkpoint filename.
        active_tasks: subset of the four tasks to train (default: all).
        train_data_, valid_ds_: override the global ``train_data`` /
            ``valid_ds`` (used when the argument is None or empty).
        nqk_override: qualifier-key vocabulary size (defaults to ``NQK``).

    Returns:
        ``(model, history)``; ``history`` holds one metrics dict per
        validation round.
    """
    if active_tasks is None:
        active_tasks = ['tail','relation','qual_key','qual_value']
    dev  = CONFIG['device']
    _td  = train_data_ or train_data
    _vds = valid_ds_   or valid_ds
    _nqk = nqk_override or NQK
    model.to(dev)
    opt    = torch.optim.Adam(model.parameters(), lr=CONFIG['mt_lr'])
    loader = DataLoader(MultiTaskDataset(_td, preprocessor, tasks=active_tasks,
                                         nqk_override=_nqk),
                        batch_size=CONFIG['mt_batch'],
                        shuffle=True, collate_fn=mt_collate)
    weights  = {'tail':1.0,'relation':1.0,'qual_key':0.5,'qual_value':0.8}
    # Target field for each single-label (cross-entropy) task.
    ce_label = {'tail':'t','relation':'r','qual_value':'qv'}
    print(f"\n{'='*60}\nTRAINING {model_name}  tasks={active_tasks}\n{'='*60}")
    best_mrr, history = 0.0, []
    for epoch in range(CONFIG['mt_epochs']):
        model.train()
        total, cnt, skipped = 0.0, 0, 0
        for by_task in tqdm(loader, desc=f"Epoch {epoch+1}", leave=False):
            bl = torch.tensor(0., device=dev)
            n_terms = 0
            for task, samples in by_task.items():
                if task not in weights:
                    continue   # unknown task: contributes no loss, skip its forward
                w = weights[task]
                for s in samples:
                    try:
                        logits = model.forward_task(task, s, dev)
                        if task == 'qual_key':
                            tgt  = torch.zeros(_nqk, device=dev)
                            tgt[s['keys']] = 1.0
                            loss = F.binary_cross_entropy_with_logits(logits, tgt)
                        else:
                            tgt  = torch.tensor(s[ce_label[task]], device=dev)
                            loss = F.cross_entropy(logits.unsqueeze(0), tgt.unsqueeze(0))
                    except Exception:
                        # Still best-effort per sample, but counted so that
                        # systematic failures are visible in the logs.
                        skipped += 1
                        continue
                    bl = bl + w*loss
                    n_terms += 1
            if n_terms == 0:
                # bl is still the constant 0-tensor (no grad_fn): calling
                # backward() on it would raise, so skip this batch.
                continue
            opt.zero_grad(); bl.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step(); total += bl.item(); cnt += 1
        if skipped:
            print(f"  Epoch {epoch+1}: skipped {skipped} sample(s) whose forward/loss raised")
        if (epoch+1) % 5 == 0:
            m = evaluate_model(model, _vds, dev)
            print(f"  Epoch {epoch+1}: loss={total/max(cnt,1):.4f}  "
                  f"MRR={m['mrr']:.4f}  H@1={m['hits@1']:.4f}  H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
            history.append(m)
    return model, history

print("MultiTask AlertStar defined")

# ============================================================================
# 12 — Model 9: HR-NBFNet (Hyper-Relational Bellman-Ford)
#
# Matches slide formulation exactly:
#   h(0)_uvqq' <- INDICATOR(u, v, q', q)
#   phi_q(h_qk, h_qv) = h_qk ⊙ h_qv            [DistMult (elementwise) per qualifier pair]
#   h_q = W_q · SUM phi_q                        [projected qualifier sum]
#   w_q = sigma(qual_gate(r))                    [per-relation scalar gate]
#   MSG = src_feat * (rel_emb + w_q * h_q)       [qualifier-gated message]
#   h(t) = AGG(msgs) + h(0)                      [shortcut to h(0) not h(t-1)]
# ============================================================================

def build_hr_nbfnet_graph(train_data, preprocessor, device,
                          max_quals=8, p_override=None):
    """Build qualifier-aware edge tensors for HR-NBFNet.

    Every triple ``(h, r, t, Q)`` contributes two edges that carry the same
    qualifiers: a forward edge ``h -> t`` with relation id ``r`` and an
    inverse edge ``t -> h`` with relation id ``r + nr``.  With ``E`` = total
    edge count (2 x number of triples), the result holds:

    * ``edge_src`` / ``edge_dst`` / ``edge_rel``: ``[E]`` long tensors;
    * ``edge_quals``: ``[E, max_quals, 2]`` (key id, value id) pairs,
      truncated to ``max_quals`` and zero-padded, so each layer can apply
      DistMult qualifier composition per edge;
    * ``edge_nquals``: ``[E]`` original (untruncated) qualifier counts;
    * ``num_nodes``, ``nr``, ``max_quals``: plain ints.

    ``p_override`` (if given) replaces ``preprocessor`` for all id lookups.
    """
    p = p_override or preprocessor
    nr = len(p.relation2id)
    srcs, dsts, rels, qual_list = [], [], [], []
    for triple in train_data:
        head = p.entity2id[triple['head']]
        rel = p.relation2id[triple['relation']]
        tail = p.entity2id[triple['tail']]
        quals = [(p.qualifier_key2id[k], p.qualifier_value2id[v])
                 for k, v in triple['qualifiers']]
        # forward edge, then inverse edge (relation id shifted by nr)
        for s, d, rid in ((head, tail, rel), (tail, head, rel + nr)):
            srcs.append(s)
            dsts.append(d)
            rels.append(rid)
            qual_list.append(quals)
    nquals = [len(qs) for qs in qual_list]

    quals_tensor = torch.zeros(len(srcs), max_quals, 2, dtype=torch.long)
    for row, qs in zip(quals_tensor, qual_list):
        kept = qs[:max_quals]
        if kept:
            row[:len(kept)] = torch.tensor(kept, dtype=torch.long)

    def as_long(values):
        return torch.tensor(values, dtype=torch.long, device=device)

    return {
        'edge_src':    as_long(srcs),
        'edge_dst':    as_long(dsts),
        'edge_rel':    as_long(rels),
        'edge_quals':  quals_tensor.to(device),       # [E, max_quals, 2]
        'edge_nquals': as_long(nquals),
        'num_nodes':   len(p.entity2id),
        'nr':          nr,
        'max_quals':   max_quals,
    }


class HRNBFConvLayer(nn.Module):
    """One qualifier-aware Bellman-Ford propagation layer.

    For every edge ``u -[r, Q]-> v`` the message is

        node_feat[u] * (rel_emb[r] + sigmoid(qual_gate[r]) * W_q(sum_{(k,v) in Q} qk_emb[k] ⊙ qv_emb[v]))

    Messages are summed into their destination node, then
    ``out = Dropout(ReLU(LN(Linear([node_feat; agg])))) + h0``.  The shortcut
    goes to the initial features ``h0``, not to the previous layer's output.

    Args:
        dim: feature width.
        num_relation: rows of the relation and gate tables (HR-NBFNet callers
            pass ``2 * nr`` so inverse relations are covered).
        nqk, nqv: qualifier key / value vocabulary sizes.
        chunk_size: number of edges processed per slice in ``forward``, which
            bounds the size of the temporary per-edge tensors.
        layer_norm: apply LayerNorm after the update Linear when True.
        dropout: dropout applied to the layer output.
    """

    def __init__(self, dim, num_relation, nqk, nqv,
                 chunk_size=5000, layer_norm=True, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.chunk_size = chunk_size
        # Parameter creation order is preserved for reproducible seeded init.
        self.rel_emb = nn.Embedding(num_relation, dim)
        self.qk_emb = nn.Embedding(nqk, dim)
        self.qv_emb = nn.Embedding(nqv, dim)
        self.W_q = nn.Linear(dim, dim, bias=False)
        self.qual_gate = nn.Embedding(num_relation, 1)
        # Gate logits start at 1.0, i.e. sigmoid(1) ~= 0.73: qualifiers begin
        # mostly (not fully) open.
        nn.init.ones_(self.qual_gate.weight)
        self.linear = nn.Linear(2 * dim, dim)
        self.ln = nn.LayerNorm(dim) if layer_norm else None
        self.act = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        for table in (self.rel_emb, self.qk_emb, self.qv_emb):
            nn.init.xavier_uniform_(table.weight)
        nn.init.xavier_uniform_(self.W_q.weight)

    def _qualifier_embedding(self, edge_quals, edge_nquals, edge_rels):
        """Gated qualifier vector per edge: ``w_q(r) * W_q(sum_i qk_i ⊙ qv_i)``.

        ``edge_quals`` is ``[E, max_q, 2]`` (key id, value id).  Slots at or
        beyond ``edge_nquals[e]`` are padding and are zeroed before summing.
        Returns an ``[E, dim]`` tensor.
        """
        _, slots, _ = edge_quals.shape
        pair_emb = self.qk_emb(edge_quals[:, :, 0]) * self.qv_emb(edge_quals[:, :, 1])
        slot_ids = torch.arange(slots, device=edge_quals.device)
        valid = (slot_ids[None, :] < edge_nquals[:, None]).float()
        pair_emb = pair_emb * valid[:, :, None]
        projected = self.W_q(pair_emb.sum(dim=1))
        gate = torch.sigmoid(self.qual_gate(edge_rels))
        return gate * projected

    def forward(self, graph, node_feat, h0):
        """One propagation step over all edges of ``graph``; returns ``[N, dim]``."""
        src, dst, rel = graph['edge_src'], graph['edge_dst'], graph['edge_rel']
        quals, nquals = graph['edge_quals'], graph['edge_nquals']
        agg = torch.zeros(graph['num_nodes'], self.dim, device=node_feat.device)
        for lo in range(0, src.size(0), self.chunk_size):
            hi = lo + self.chunk_size
            rel_chunk = rel[lo:hi]
            qual_vec = self._qualifier_embedding(quals[lo:hi], nquals[lo:hi], rel_chunk)
            message = node_feat[src[lo:hi]] * (self.rel_emb(rel_chunk) + qual_vec)
            targets = dst[lo:hi].unsqueeze(1).expand_as(message)
            agg.scatter_add_(0, targets, message)
        out = self.linear(torch.cat([node_feat, agg], dim=-1))
        if self.ln is not None:
            out = self.ln(out)
        out = self.drop(self.act(out))
        return out + h0                                  # shortcut to h(0)


class HRNBFNet(nn.Module):
    """Hyper-Relational NBFNet: query-conditioned Bellman-Ford over a qualifier-aware graph.

    The query ``(head, relation, qualifiers)`` seeds only the head node with
    ``query_emb[r] + query_qual_proj(sum_i qk_i ⊙ qv_i)``.  ``num_layers``
    HRNBFConvLayer steps then propagate this signal, each adding the initial
    features back as a shortcut.  Every node is scored by an MLP on
    ``[node_feature; query_emb[r]]``.

    Args:
        ne, nr, nqk, nqv: vocabulary sizes.  Relation tables have ``2 * nr``
            rows so inverse-edge ids (``r + nr``) are covered.
        dim, num_layers, chunk_size, dropout: backbone hyper-parameters.
        max_quals: stored on the instance; not otherwise used in this class.
    """

    def __init__(self, ne, nr, nqk, nqv, dim=200, num_layers=3,
                 chunk_size=5000, dropout=0.1, max_quals=8):
        super().__init__()
        self.num_entities = ne
        self.dim          = dim
        self.max_quals    = max_quals
        nr2               = nr * 2
        self.query_emb       = nn.Embedding(nr2, dim)
        self.query_qk_emb    = nn.Embedding(nqk, dim)
        self.query_qv_emb    = nn.Embedding(nqv, dim)
        self.query_qual_proj = nn.Linear(dim, dim, bias=False)
        self.layers = nn.ModuleList([
            HRNBFConvLayer(dim, nr2, nqk, nqv, chunk_size=chunk_size, dropout=dropout)
            for _ in range(num_layers)])
        self.mlp = nn.Sequential(
            nn.Linear(dim*2, dim), nn.LayerNorm(dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(dim, 1))
        for e in [self.query_emb, self.query_qk_emb, self.query_qv_emb]:
            nn.init.xavier_uniform_(e.weight)
        nn.init.xavier_uniform_(self.query_qual_proj.weight)

    def _indicator_init(self, graph, h_idx, r_idx, query_quals, device):
        """INDICATOR(u,v,q',q): only the source node gets rel + qualifier context; all others start at 0."""
        N     = graph['num_nodes']
        q_rel = self.query_emb(torch.tensor([r_idx], device=device))
        if query_quals:
            qk_ids = torch.tensor([qk for qk,_ in query_quals], device=device)
            qv_ids = torch.tensor([qv for _,qv in query_quals], device=device)
            phi    = (self.query_qk_emb(qk_ids) *
                      self.query_qv_emb(qv_ids)).sum(0, keepdim=True)
            q_qual = self.query_qual_proj(phi)
        else:
            q_qual = torch.zeros(1, self.dim, device=device)
        feat        = torch.zeros(N, self.dim, device=device)
        feat[h_idx] = q_rel.squeeze() + q_qual.squeeze()
        return feat

    def _propagate(self, graph, h_idx, r_idx, query_quals, device):
        """Run all layers; every layer receives the initial features as shortcut."""
        feat = self._indicator_init(graph, h_idx, r_idx, query_quals, device)
        h0   = feat.clone()
        for layer in self.layers:
            feat = layer(graph, feat, h0)
        return feat

    def forward(self, head, relation, qualifiers=None, tail=None, graph=None):
        """Score entities for one (head, relation) query repeated across a batch.

        Args:
            head, relation: 1-D long tensors whose entries must all be equal,
                because a single propagation serves the whole batch.
            qualifiers: per-sample qualifier lists; only ``qualifiers[0]`` is used.
            tail: optional entity ids; when given, their scores are returned.
            graph: edge dict with the keys built by ``build_hr_nbfnet_graph``.

        Returns:
            ``all_scores[tail]`` when ``tail`` is given, otherwise a
            ``[B, num_nodes]`` view repeating the single score row.

        Raises:
            ValueError: missing graph, empty batch, or a batch that mixes
                different heads/relations.  (Explicit checks instead of
                ``assert`` so they also run under ``python -O``.)
        """
        if graph is None:
            raise ValueError("HRNBFNet.forward requires graph=...")
        if head.numel() == 0:
            raise ValueError("HRNBFNet.forward received an empty batch")
        if not ((head == head[0]).all() and (relation == relation[0]).all()):
            raise ValueError(
                "HRNBFNet.forward expects a single (head, relation) query per batch")
        dev         = head.device
        query_quals = qualifiers[0] if qualifiers else []
        feat        = self._propagate(graph, head[0].item(),
                                      relation[0].item(), query_quals, dev)
        q_emb       = self.query_emb(torch.tensor([relation[0].item()], device=dev))
        score_in    = torch.cat([feat, q_emb.expand(graph['num_nodes'],-1)], dim=-1)
        all_scores  = self.mlp(score_in).squeeze(-1)
        if tail is not None: return all_scores[tail]
        return all_scores.unsqueeze(0).expand(head.size(0), -1)


def train_hr_nbfnet(model, model_name="HR_NBFNet",
                    train_data_=None, valid_ds_=None, graph_=None,
                    p_override=None):
    """Train an HR-NBFNet-style model with a margin ranking loss.

    Training triples are grouped by (head, relation).  For each group, up to
    ``CONFIG['nbfnet_max_per_group']`` tails are sampled as positives, and
    the group is scored with the qualifiers of the sampled triple that has
    the most qualifiers.  One propagation per group yields ``[B, N]`` scores,
    from which positive scores and uniformly random negative scores are
    gathered for ``margin_ranking_loss`` (margin 1.0).  Every 5 epochs the
    model is evaluated with ``evaluate_hr_nbfnet`` and the best-MRR weights
    are saved to ``{CONFIG['output_path']}{model_name}_best.pt``.

    Args:
        model: called as ``model(heads, rels, qualifiers=..., graph=...)`` and
            expected to return ``[B, num_nodes]`` scores (as HRNBFNet does).
        train_data_, valid_ds_, graph_, p_override: override the globals
            ``train_data`` / ``valid_ds`` / ``hr_nbfnet_graph`` /
            ``preprocessor`` (used when the argument is None or empty).

    Returns:
        ``(model, history)``; ``history`` holds one metrics dict per
        validation round.
    """
    dev    = CONFIG['device']
    _td    = train_data_ or train_data
    _vds   = valid_ds_   or valid_ds
    _g     = graph_      or hr_nbfnet_graph
    _p     = p_override  or preprocessor
    model.to(dev)
    graph  = {k: v.to(dev) if torch.is_tensor(v) else v for k,v in _g.items()}
    ne_    = graph['num_nodes']
    opt    = torch.optim.Adam(model.parameters(), lr=CONFIG['hr_nbfnet_lr'])
    mpg    = CONFIG['nbfnet_max_per_group']
    groups = defaultdict(list)
    for i, s in enumerate(_td):
        groups[(_p.entity2id[s['head']], _p.relation2id[s['relation']])].append(i)
    keys = list(groups.keys())
    print(f"\n{'='*60}\nTRAINING {model_name}  "
          f"({sum(p.numel() for p in model.parameters()):,} params)")
    print(f"  Graph: {graph['edge_src'].size(0):,} qualifier-aware edges\n{'='*60}")
    best_mrr, history = 0.0, []
    for epoch in range(CONFIG['hr_nbfnet_epochs']):
        model.train(); np.random.shuffle(keys)
        total, cnt = 0.0, 0
        failed, first_err = 0, None
        for (h, r) in tqdm(keys, desc=f"Epoch {epoch+1}", leave=False):
            chosen    = np.random.choice(groups[(h,r)],
                                         min(len(groups[(h,r)]),mpg), replace=False)
            t_pos     = torch.tensor(
                [_p.entity2id[_td[i]['tail']] for i in chosen], device=dev)
            all_quals = [[(_p.qualifier_key2id[qk], _p.qualifier_value2id[qv])
                          for qk, qv in _td[i]['qualifiers']] for i in chosen]
            rep_quals = max(all_quals, key=len) if all_quals else []
            B         = len(t_pos)
            heads     = torch.full((B,), h, dtype=torch.long, device=dev)
            rels      = torch.full((B,), r, dtype=torch.long, device=dev)
            try:
                # A single propagation per group: the expensive Bellman-Ford
                # pass is shared by positives and negatives, and both are
                # scored under the same dropout mask.
                scores = model(heads, rels, qualifiers=[rep_quals]*B,
                               graph=graph)                      # [B, N]
                t_neg  = torch.randint(0, ne_, (B,), device=dev)
                pos    = scores.gather(1, t_pos.unsqueeze(1)).squeeze(1)
                neg    = scores.gather(1, t_neg.unsqueeze(1)).squeeze(1)
                loss   = F.margin_ranking_loss(pos, neg,
                                               torch.ones(B,device=dev), margin=1.0)
                opt.zero_grad(); loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
            except Exception as err:
                # Best-effort per group, but failures are counted and reported.
                failed += 1
                if first_err is None:
                    first_err = err
                continue
            total += loss.item(); cnt += 1
        if failed:
            print(f"  Epoch {epoch+1}: {failed} group(s) failed; "
                  f"first error: {first_err!r}")
        if (epoch+1) % 5 == 0:
            m = evaluate_hr_nbfnet(model, _vds, graph, dev, p_override=_p)
            print(f"  Epoch {epoch+1}: loss={total/max(cnt,1):.4f}  "
                  f"MRR={m['mrr']:.4f}  MR={m['mr']:.1f}  "
                  f"H@1={m['hits@1']:.4f}  H@3={m['hits@3']:.4f}  "
                  f"H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
                print("    → best saved")
            history.append(m)
    return model, history


def evaluate_hr_nbfnet(model, dataset, graph, device,
                       max_groups=300, p_override=None):
    """Raw (unfiltered) link-prediction metrics for HR-NBFNet-style models.

    Evaluation triples (``dataset.data``) are grouped by (head, relation) and
    only the first ``max_groups`` groups, in first-seen order, are scored.
    Each group gets one model call using the qualifiers of its *first*
    triple.  A target's rank is its 1-based position in the descending
    argsort of the scores.  Other true tails are not filtered out, so these
    are raw ranks.  Groups whose scoring raises are skipped silently.

    Returns:
        dict with 'mr', 'mrr', 'hits@1', 'hits@3', 'hits@10' (all 0.0 when
        no rank could be computed).
    """
    p = p_override or preprocessor
    model.eval()
    by_query = defaultdict(list)
    for idx, sample in enumerate(dataset.data):
        key = (p.entity2id[sample['head']], p.relation2id[sample['relation']])
        by_query[key].append(idx)

    ranks = []
    with torch.no_grad():
        selected = list(by_query)[:max_groups]
        for h, r in tqdm(selected, desc="Eval HR-NBFNet", leave=False):
            members = by_query[(h, r)]
            targets = [p.entity2id[dataset.data[j]['tail']] for j in members]
            n = len(targets)
            heads = torch.full((n,), h, dtype=torch.long, device=device)
            rels = torch.full((n,), r, dtype=torch.long, device=device)
            first = dataset.data[members[0]]
            quals = [(p.qualifier_key2id[k], p.qualifier_value2id[v])
                     for k, v in first['qualifiers']]
            try:
                scores = model(heads, rels, qualifiers=[quals] * n, graph=graph)
                for row, target in enumerate(targets):
                    order = torch.argsort(scores[row], descending=True)
                    hits = (order == target).nonzero(as_tuple=True)[0]
                    if len(hits):
                        ranks.append(hits[0].item() + 1)
            except Exception:
                continue

    if not ranks:
        return {'mr': 0., 'mrr': 0., 'hits@1': 0., 'hits@3': 0., 'hits@10': 0.}
    arr = np.array(ranks)
    metrics = {'mr': float(np.mean(arr)), 'mrr': float(np.mean(1. / arr))}
    for k in (1, 3, 10):
        metrics[f'hits@{k}'] = float(np.mean(arr <= k))
    return metrics

print("HR-NBFNet defined")


# ============================================================================
# 13 — Model 10: MultiTask_HR_NBFNet (NEW)
#
# Combines HR-NBFNet's qualifier-aware Bellman-Ford propagation with the
# 4-task multi-task training strategy from MultiTask AlertStar.
#
# Architecture:
#   - HR-NBFNet backbone: propagates query-conditioned features with
#     per-edge DistMult qualifier embeddings → node feature f^L[v] ∈ R^d
#   - 4 prediction heads sharing the backbone, each a 2-layer MLP:
#       tail:       MLP(cat(f^L[e'], q_emb))  →  R^1   [margin ranking]
#       relation:   MLP(f^L[h])               →  R^|R| [cross-entropy]
#       qual_key:   MLP(f^L[h])               →  R^|QK| [BCE multi-label]
#       qual_value: MLP(f^L[h] · qk_emb)     →  R^|QV| [cross-entropy]
#
# Key design choices vs MultiTask AlertStar (MT-AS):
#   - MT-AS: Transformer over flat token sequence (local, no graph)
#   - MT-HR: Bellman-Ford over HR graph (global, structure-aware)
#   - Relation/qual-key/qual-value heads use head-node BF representation,
#     so auxiliary tasks receive path-enriched graph signals
#   - Qualifier value head gates on qk_emb (same as HyNT) for fine-grained
#     attribute discrimination
# ============================================================================

class MultiTaskHRNBFNet(nn.Module):
    """HR-NBFNet backbone shared by four task heads.

    The heads are tail, relation, qualifier key and qualifier value. Every
    task call runs its own qualifier-aware Bellman-Ford propagation from the
    query head node (``_propagate``) and reads the resulting node features.

    Args:
        ne: number of entities. It is only stored; scoring uses
            ``graph['num_nodes']``.
        nr: number of relations. Query relation embeddings are allocated for
            ``2 * nr`` ids (presumably forward + inverse edges — confirm
            against the graph builder).
        nqk, nqv: number of qualifier keys / qualifier values.
        dim: size of node features and of the heads' hidden layer.
        num_layers: number of HRNBFConvLayer propagation steps.
        chunk_size, dropout: forwarded to each HRNBFConvLayer. Dropout is also
            used inside the task heads.
        max_quals: accepted for signature parity with HRNBFNet; unused here.
    """

    def __init__(self, ne, nr, nqk, nqv, dim=200, num_layers=3,
                 chunk_size=5000, dropout=0.1, max_quals=8):
        super().__init__()
        self.num_entities = ne
        self.nr           = nr
        self.nqk          = nqk
        self.nqv          = nqv
        self.dim          = dim
        # ── Shared HR-NBFNet backbone ─────────────────────────────────────
        nr2               = nr * 2
        self.query_emb       = nn.Embedding(nr2, dim)
        self.query_qk_emb    = nn.Embedding(nqk, dim)
        self.query_qv_emb    = nn.Embedding(nqv, dim)
        self.query_qual_proj = nn.Linear(dim, dim, bias=False)
        self.layers = nn.ModuleList([
            HRNBFConvLayer(dim, nr2, nqk, nqv,
                           chunk_size=chunk_size, dropout=dropout)
            for _ in range(num_layers)])
        # ── Prediction heads ──────────────────────────────────────────────
        def _mlp_head(in_dim, out_dim):
            # Linear → LayerNorm → ReLU → Dropout → Linear
            return nn.Sequential(
                nn.Linear(in_dim, dim), nn.LayerNorm(dim), nn.ReLU(),
                nn.Dropout(dropout), nn.Linear(dim, out_dim))
        # tail: concat(f^L[e'], q_emb) → scalar
        self.tail_mlp    = _mlp_head(dim*2, 1)
        # relation: f^L[h] → |R|  (use head node's BF representation)
        self.rel_head    = _mlp_head(dim, nr)
        # qual_key: f^L[h] → |QK|  (multi-label)
        self.qk_head     = _mlp_head(dim, nqk)
        # qual_value: sigmoid(W·cat(f^L[h], qk_emb)) ⊙ f^L[h] → |QV|
        self.qv_gate     = nn.Sequential(nn.Linear(dim*2, dim), nn.Sigmoid())
        self.qv_head     = _mlp_head(dim, nqv)
        # qualifier key embedding (for qv gating)
        self.qk_emb_head = nn.Embedding(nqk, dim)
        for e in [self.query_emb, self.query_qk_emb, self.query_qv_emb,
                  self.qk_emb_head]:
            nn.init.xavier_uniform_(e.weight)
        nn.init.xavier_uniform_(self.query_qual_proj.weight)

    # ── backbone: shared with HRNBFNet ────────────────────────────────────
    def _indicator_init(self, graph, h_idx, r_idx, query_quals, device):
        """Return the [N, dim] boundary features, zero except at ``h_idx``.

        The head node receives query_emb(r_idx) plus a projection of the
        summed element-wise product of the query qualifiers' key and value
        embeddings (zero when there are no qualifiers).
        """
        N     = graph['num_nodes']
        q_rel = self.query_emb(torch.tensor([r_idx], device=device))
        if query_quals:
            qk_ids = torch.tensor([qk for qk, _ in query_quals], device=device)
            qv_ids = torch.tensor([qv for _, qv in query_quals], device=device)
            phi    = (self.query_qk_emb(qk_ids) *
                      self.query_qv_emb(qv_ids)).sum(0, keepdim=True)
            q_qual = self.query_qual_proj(phi)
        else:
            q_qual = torch.zeros(1, self.dim, device=device)
        feat        = torch.zeros(N, self.dim, device=device)
        feat[h_idx] = q_rel.squeeze() + q_qual.squeeze()
        return feat

    def _propagate(self, graph, h_idx, r_idx, query_quals, device):
        """Run all conv layers from the indicator init; returns [N, dim].

        The initial features ``h0`` are passed to every layer alongside the
        current state.
        """
        feat = self._indicator_init(graph, h_idx, r_idx, query_quals, device)
        h0   = feat.clone()
        for layer in self.layers:
            feat = layer(graph, feat, h0)
        return feat                                  # [N, dim]

    # ── task-specific forward passes ──────────────────────────────────────
    def forward_tail_task(self, graph, h_idx, r_idx, query_quals, tail_ids, device):
        """Tail prediction: score MLP(cat(f^L[e], q_emb)) for every node e.

        Returns the [N] score vector, or only the entries at ``tail_ids``
        when those are given.
        """
        feat  = self._propagate(graph, h_idx, r_idx, query_quals, device)
        q_emb = self.query_emb(torch.tensor([r_idx], device=device))
        score_in    = torch.cat([feat, q_emb.expand(graph['num_nodes'], -1)], dim=-1)
        all_scores  = self.tail_mlp(score_in).squeeze(-1)   # [N]
        if tail_ids is not None:
            return all_scores[tail_ids]
        return all_scores

    def forward_rel_task(self, graph, h_idx, r_idx, query_quals, device):
        """Relation prediction: classify from the head node's BF representation.

        NOTE(review): ``r_idx`` is injected into the head node's initial
        feature via ``query_emb``. When the classification target is
        ``r_idx`` itself (as in ``train_mt_hr_nbfnet``), the answer is part
        of the input and this task may be near-trivial. Confirm this is
        intended.
        """
        feat = self._propagate(graph, h_idx, r_idx, query_quals, device)
        return self.rel_head(feat[h_idx])            # [nr]

    def forward_qk_task(self, graph, h_idx, r_idx, device):
        """Qualifier key prediction (multi-label) from head BF representation."""
        # run with empty qualifiers so we don't leak qk info
        feat = self._propagate(graph, h_idx, r_idx, [], device)
        return self.qk_head(feat[h_idx])             # [nqk]

    def forward_qv_task(self, graph, h_idx, r_idx, query_quals,
                         target_qk, device):
        """Qualifier value prediction gated on the target qualifier key.

        Every (key, value) pair whose key equals ``target_qk`` is removed from
        the conditioning qualifiers before propagation. Otherwise the value
        being predicted would be written into the head node's initial feature
        and the task would leak its own answer. ``query_quals`` may be None.

        Returns:
            [nqv] logits.
        """
        cond_quals = [(qk, qv) for qk, qv in (query_quals or [])
                      if qk != target_qk]
        feat   = self._propagate(graph, h_idx, r_idx, cond_quals, device)
        h_repr = feat[h_idx]
        qk_emb = self.qk_emb_head(
            torch.tensor([target_qk], device=device)).squeeze()
        gate   = self.qv_gate(torch.cat([h_repr, qk_emb], dim=-1))
        return self.qv_head(gate * h_repr)           # [nqv]

    # ── standard forward for evaluate_model compatibility ─────────────────
    def forward(self, head, relation, qualifiers=None, tail=None, graph=None):
        """Score every node as tail for a batch sharing one (head, relation).

        Only ``qualifiers[0]`` is used as the query's qualifier set. Returns
        ``all_scores[tail]`` when ``tail`` is given, otherwise a [B, N]
        expanded view of the single [N] score vector.
        """
        assert graph is not None
        assert (head == head[0]).all() and (relation == relation[0]).all()
        dev         = head.device
        query_quals = qualifiers[0] if qualifiers else []
        all_scores  = self.forward_tail_task(
            graph, head[0].item(), relation[0].item(),
            query_quals, None, dev)
        if tail is not None:
            return all_scores[tail]
        return all_scores.unsqueeze(0).expand(head.size(0), -1)


def train_mt_hr_nbfnet(model, model_name="MultiTask_HR_NBFNet",
                        train_data_=None, valid_ds_=None, graph_=None,
                        p_override=None, nqk_override=None, nqv_override=None):
    """
    Multi-task training for the MultiTaskHRNBFNet model.

    Training triples are grouped by (head, relation). For each group, up to
    CONFIG['nbfnet_max_per_group'] triples are sampled and one optimiser step
    is taken on a weighted sum of four losses:
      tail        margin ranking, one random negative per positive   w=1.0
      relation    cross-entropy on the group's relation id           w=0.8
      qual_key    BCE over the representative qualifier keys         w=0.5
      qual_value  cross-entropy on the first representative value    w=0.8
    The qualifier set of the sampled triple with the most qualifiers serves as
    the group's representative query qualifiers. For the qual_value task, all
    pairs with the target key are removed from the conditioning so the answer
    is not fed into the query.

    Args:
        model: a MultiTaskHRNBFNet (uses its forward_*_task methods).
        model_name: label for logs and the checkpoint file name.
        train_data_: raw samples (dicts with 'head', 'relation', 'tail',
            'qualifiers'); defaults to the global ``train_data``.
        valid_ds_: validation set for evaluate_hr_nbfnet; defaults to the
            global ``valid_ds``.
        graph_: HR-NBFNet graph dict; defaults to ``hr_nbfnet_graph``.
        p_override: preprocessor with id maps; defaults to ``preprocessor``.
        nqk_override: number of qualifier keys; defaults to ``NQK``.
        nqv_override: accepted for API symmetry; not needed by the losses.

    Returns:
        ``(model, history)``. ``history`` holds one validation-metrics dict
        per evaluation (every 5 epochs). The returned model carries the
        last epoch's weights; the best-MRR weights are only saved to disk.
    """
    dev  = CONFIG['device']
    # `is None` (not `or`) so an explicitly passed empty/falsy override is
    # respected instead of silently replaced by the global default.
    _td  = train_data      if train_data_  is None else train_data_
    _vds = valid_ds        if valid_ds_    is None else valid_ds_
    _g   = hr_nbfnet_graph if graph_       is None else graph_
    _p   = preprocessor    if p_override   is None else p_override
    _nqk = NQK             if nqk_override is None else nqk_override
    model.to(dev)
    graph   = {k: v.to(dev) if torch.is_tensor(v) else v for k, v in _g.items()}
    ne_     = graph['num_nodes']
    opt     = torch.optim.Adam(model.parameters(), lr=CONFIG['mt_hr_lr'])
    mpg     = CONFIG['nbfnet_max_per_group']
    weights = {'tail': 1.0, 'relation': 0.8, 'qual_key': 0.5, 'qual_value': 0.8}

    # Group training triples by (h, r): one propagation source per group.
    groups = defaultdict(list)
    for i, s in enumerate(_td):
        groups[(_p.entity2id[s['head']], _p.relation2id[s['relation']])].append(i)
    keys = list(groups.keys())
    print(f"\n{'='*60}\nTRAINING {model_name}  "
          f"({sum(p.numel() for p in model.parameters()):,} params)")
    print(f"  Graph: {graph['edge_src'].size(0):,} qualifier-aware edges")
    print(f"  Tasks: tail / relation / qual_key / qual_value\n{'='*60}")

    best_mrr, history = 0.0, []
    for epoch in range(CONFIG['mt_hr_epochs']):
        model.train(); np.random.shuffle(keys)
        total, cnt = 0.0, 0
        n_failed, first_err = 0, None
        for (h, r) in tqdm(keys, desc=f"Epoch {epoch+1}", leave=False):
            members   = groups[(h, r)]
            chosen    = np.random.choice(members, min(len(members), mpg),
                                         replace=False)
            samples   = [_td[i] for i in chosen]
            all_quals = [[(_p.qualifier_key2id[qk], _p.qualifier_value2id[qv])
                          for qk, qv in s['qualifiers']] for s in samples]
            rep_quals = max(all_quals, key=len) if all_quals else []
            t_ids     = torch.tensor(
                [_p.entity2id[s['tail']] for s in samples], device=dev)
            neg_ids   = torch.randint(0, ne_, (len(samples),), device=dev)
            try:
                # ── Task 1: Tail prediction (margin ranking) ───────────────
                # One propagation scores every node; positives and negatives
                # are read from the same score vector.
                all_scores = model.forward_tail_task(
                    graph, h, r, rep_quals, None, dev)
                loss_tail  = F.margin_ranking_loss(
                    all_scores[t_ids], all_scores[neg_ids],
                    torch.ones(len(samples), device=dev), margin=1.0)
                batch_loss = weights['tail'] * loss_tail

                # ── Task 2: Relation prediction (cross-entropy) ────────────
                rel_logits = model.forward_rel_task(graph, h, r, rep_quals, dev)
                loss_rel   = F.cross_entropy(
                    rel_logits.unsqueeze(0), torch.tensor([r], device=dev))
                batch_loss = batch_loss + weights['relation'] * loss_rel

                if rep_quals:
                    # ── Task 3: Qualifier key prediction (BCE multi-label) ─
                    qk_logits = model.forward_qk_task(graph, h, r, dev)
                    qk_target = torch.zeros(_nqk, device=dev)
                    qk_target[[qk for qk, _ in rep_quals]] = 1.0
                    loss_qk   = F.binary_cross_entropy_with_logits(
                        qk_logits, qk_target)
                    batch_loss = batch_loss + weights['qual_key'] * loss_qk

                    # ── Task 4: Qualifier value prediction (cross-entropy) ─
                    tgt_qk, tgt_qv = rep_quals[0]
                    # Exclude the target key's pairs from the conditioning so
                    # the value being predicted is not part of the input.
                    cond_quals = [(qk, qv) for qk, qv in rep_quals
                                  if qk != tgt_qk]
                    qv_logits = model.forward_qv_task(
                        graph, h, r, cond_quals, tgt_qk, dev)
                    loss_qv   = F.cross_entropy(
                        qv_logits.unsqueeze(0),
                        torch.tensor([tgt_qv], device=dev))
                    batch_loss = batch_loss + weights['qual_value'] * loss_qv

                opt.zero_grad(); batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step(); total += batch_loss.item(); cnt += 1
            except Exception as err:
                # Best-effort: skip the failing group, but keep a record so
                # systematic failures do not go unnoticed.
                n_failed += 1
                if first_err is None:
                    first_err = err
                continue

        if n_failed:
            print(f"  Epoch {epoch+1}: skipped {n_failed}/{len(keys)} groups "
                  f"(first error: {first_err!r})")

        if (epoch+1) % 5 == 0:
            m = evaluate_hr_nbfnet(model, _vds, graph, dev, p_override=_p)
            print(f"  Epoch {epoch+1}: loss={total/max(cnt,1):.4f}  "
                  f"MRR={m['mrr']:.4f}  MR={m['mr']:.1f}  "
                  f"H@1={m['hits@1']:.4f}  H@3={m['hits@3']:.4f}  "
                  f"H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
                print("    → best saved")
            history.append(m)
    return model, history

# Progress marker: the MultiTask_HR_NBFNet cell has finished executing.
print("MultiTask_HR_NBFNet", "defined")


# ============================================================================
# 14 — TRAIN ALL 10 MAIN MODELS
# ============================================================================
# Each model is trained on the main dataset, scored on test_ds, and its
# metrics dict is stored in all_results under the model's name.
# Statement order matters: globals created here are read later (for example,
# hr_nbfnet_graph is the default graph of train_mt_hr_nbfnet, and
# gate_history feeds ablation A2).
# NOTE(review): training order may also affect RNG-dependent results; confirm
# seeding if exact reproducibility is required.

print("\n" + "="*70)
print("PHASE 1: TRAINING ALL 10 MAIN MODELS")
print("="*70)

# ── 1. StarE ─────────────────────────────────────────────────────────────
# train_standard returns a 3-tuple whose first element is the trained model.
stare_model = StarEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
stare_model, _, _ = train_standard(stare_model, "StarE")
all_results['StarE'] = evaluate_model(stare_model, test_ds, CONFIG['device'])

# ── 2. ShrinkE ───────────────────────────────────────────────────────────
shrinke_model = ShrinkEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
shrinke_model, _, _ = train_standard(shrinke_model, "ShrinkE")
all_results['ShrinkE'] = evaluate_model(shrinke_model, test_ds, CONFIG['device'])

# ── 3. TrueNBFNet ────────────────────────────────────────────────────────
# Graph built from the training triples. Per the file header, TrueNBFNet is
# qualifier-unaware.
nbfnet_graph = build_nbfnet_graph(train_data, preprocessor, CONFIG['device'])
nbfnet_model = TrueNBFNet(NE, NR, dim=DIM, num_layers=CONFIG['nbfnet_layers'],
                           chunk_size=CONFIG['nbfnet_chunk_size'],
                           dropout=CONFIG['dropout'], short_cut=True)
nbfnet_model, _ = train_true_nbfnet(nbfnet_model, "TrueNBFNet")
# Copy of the graph with every tensor entry moved to the device, for evaluation.
nbfnet_graph_dev = {k: v.to(CONFIG['device']) if torch.is_tensor(v) else v
                    for k,v in nbfnet_graph.items()}
all_results['TrueNBFNet'] = evaluate_nbfnet(
    nbfnet_model, test_ds, nbfnet_graph_dev, CONFIG['device'])

# ── 4. AlertStar ─────────────────────────────────────────────────────────
# gate_track=True: the third return value (gate_history) is consumed by A2.
alertstar_model = AlertStarModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
alertstar_model, _, gate_history = train_standard(
    alertstar_model, "AlertStar", gate_track=True)
all_results['AlertStar'] = evaluate_model(alertstar_model, test_ds, CONFIG['device'])

# ── 5. StarQE ────────────────────────────────────────────────────────────
starqe_model = StarQEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
starqe_model, _ = train_starqe(starqe_model, "StarQE")
all_results['StarQE'] = evaluate_model(starqe_model, test_ds, CONFIG['device'])

# ── 6. NBFNet+StarQE ──────────────────────────────────────────────────────
# Same trainer as StarQE.
nbfstarqe_model = NBFNetStarQEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
nbfstarqe_model, _ = train_starqe(nbfstarqe_model, "NBFNet_StarQE")
all_results['NBFNet_StarQE'] = evaluate_model(nbfstarqe_model, test_ds, CONFIG['device'])

# ── 7. HyNT ──────────────────────────────────────────────────────────────
hynt_model = HyNTModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'],
                        n_heads=CONFIG['hynt_n_heads'],
                        n_layers=CONFIG['hynt_n_layers'])
hynt_model, _ = train_hynt(hynt_model, "HyNT")
all_results['HyNT'] = evaluate_model(hynt_model, test_ds, CONFIG['device'])

# ── 8. MultiTask AlertStar ────────────────────────────────────────────────
mt_model = MultiTaskAlertStar(NE, NR, NQK, NQV, DIM, dropout=CONFIG['dropout'])
mt_model, _ = train_multitask(mt_model, "MultiTask_AlertStar")
all_results['MultiTask_AlertStar'] = evaluate_model(mt_model, test_ds, CONFIG['device'])

# ── 9. HR-NBFNet ─────────────────────────────────────────────────────────
# hr_nbfnet_graph is a module-level global; train_mt_hr_nbfnet falls back to
# it when no graph_ is passed (model 10 below relies on that).
print("\nBuilding HR-NBFNet qualifier-aware graph...")
hr_nbfnet_graph = build_hr_nbfnet_graph(
    train_data, preprocessor, CONFIG['device'],
    max_quals=CONFIG['hr_nbfnet_max_quals'])
print(f"HR graph: {hr_nbfnet_graph['edge_src'].size(0):,} edges")

hr_model = HRNBFNet(NE, NR, NQK, NQV, DIM,
                     num_layers=CONFIG['hr_nbfnet_layers'],
                     chunk_size=CONFIG['hr_nbfnet_chunk_size'],
                     dropout=CONFIG['dropout'],
                     max_quals=CONFIG['hr_nbfnet_max_quals'])
hr_model, _ = train_hr_nbfnet(hr_model, "HR_NBFNet")
# Device copy of the HR graph; also reused to evaluate model 10.
hr_graph_dev = {k: v.to(CONFIG['device']) if torch.is_tensor(v) else v
                for k,v in hr_nbfnet_graph.items()}
all_results['HR_NBFNet'] = evaluate_hr_nbfnet(
    hr_model, test_ds, hr_graph_dev, CONFIG['device'])

# ── 10. MultiTask_HR_NBFNet (NEW) ─────────────────────────────────────────
# Reuses hr_nbfnet_graph — same qualifier-aware edge structure
mt_hr_model = MultiTaskHRNBFNet(NE, NR, NQK, NQV, DIM,
                                 num_layers=CONFIG['hr_nbfnet_layers'],
                                 chunk_size=CONFIG['hr_nbfnet_chunk_size'],
                                 dropout=CONFIG['dropout'],
                                 max_quals=CONFIG['hr_nbfnet_max_quals'])
mt_hr_model, _ = train_mt_hr_nbfnet(mt_hr_model, "MultiTask_HR_NBFNet")
all_results['MultiTask_HR_NBFNet'] = evaluate_hr_nbfnet(
    mt_hr_model, test_ds, hr_graph_dev, CONFIG['device'])

# ── Summary ───────────────────────────────────────────────────────────────
# One line per model, in insertion order; the two HR-NBFNet models are tagged.
print("\n All 10 models trained")
print("\nMAIN RESULTS:")
for m, r in all_results.items():
    tag = " ← NEW" if m in ('HR_NBFNet','MultiTask_HR_NBFNet') else ""
    print(f"  {m:30s}  MRR={r['mrr']:.4f}  MR={r['mr']:7.1f}  "
          f"H@1={r['hits@1']:.4f}  H@3={r['hits@3']:.4f}  "
          f"H@10={r['hits@10']:.4f}{tag}")


# ============================================================================
# 15 — ABLATION A1: AlertStar Component Ablation
# ============================================================================

print("\n" + "=" * 70 + "\nABLATION A1: AlertStar Component Ablation\n" + "=" * 70)

# One variant per disabled component (qualifiers, path stream, learned gate),
# plus the full model as reference. Every variant is trained from scratch and
# scored on the main test set.
ablation_configs = [
    ("AS-NoQual", {'use_qual': False, 'use_path': True,  'fixed_gate': None}),
    ("AS-NoPath", {'use_qual': True,  'use_path': False, 'fixed_gate': None}),
    ("AS-NoGate", {'use_qual': True,  'use_path': True,  'fixed_gate': 0.5}),
    ("AS-Full",   {'use_qual': True,  'use_path': True,  'fixed_gate': None}),
]
ablation_A1 = {}
for name, cfg in ablation_configs:
    m = AlertStarModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'], **cfg)
    m, _, _ = train_standard(m, f"Ablation_{name}")
    r = ablation_A1[name] = evaluate_model(m, test_ds, CONFIG['device'])
    print(f"  {name}: MRR={r['mrr']:.4f}  H@1={r['hits@1']:.4f}  H@10={r['hits@10']:.4f}")

ablation_results['A1_AlertStar_Components'] = ablation_A1
print("Ablation A1 complete")


# ============================================================================
# 16 — ABLATION A2: Gate Value Trajectory
# ============================================================================

print("\n" + "=" * 70 + "\nABLATION A2: Gate Value Analysis\n" + "=" * 70)

if not gate_history:
    print("  No gate history available")
else:
    gate_df = pd.DataFrame(gate_history)
    print(gate_df.to_string(index=False))
    final_gate = gate_history[-1]['gate']
    print(f"\n  Final gate g = {final_gate:.4f}")
    # Interpretation thresholds: above 0.6 reads as attention-dominated,
    # below 0.4 as path-dominated, anything in between as balanced.
    if final_gate > 0.6:
        print("  → Attention stream dominates (qualifier-awareness)")
    elif final_gate < 0.4:
        print("  → Path stream dominates (structural reasoning)")
    else:
        print("  → Balanced — both streams equally useful")
    ablation_results['A2_Gate_Values'] = gate_history

print("Ablation A2 complete")


# ============================================================================
# 17 — ABLATION A3: MultiTask Auxiliary Task Ablation
# ============================================================================

print("\n" + "=" * 70 + "\nABLATION A3: MultiTask AlertStar — Task Contribution\n" + "=" * 70)

# Each variant retrains MultiTaskAlertStar from scratch with only the listed
# tasks active; the task lists are stringified verbatim in Table 6.
mt_ablation_configs = [
    ("MT-TailOnly",     ['tail']),
    ("MT-Tail+Rel",     ['tail', 'relation']),
    ("MT-Tail+QualKey", ['tail', 'qual_key']),
    ("MT-Tail+QualVal", ['tail', 'qual_value']),
    ("MT-Full",         ['tail', 'relation', 'qual_key', 'qual_value']),
]
ablation_A3 = {}
for name, tasks in mt_ablation_configs:
    m = MultiTaskAlertStar(NE, NR, NQK, NQV, DIM, dropout=CONFIG['dropout'])
    m, _ = train_multitask(m, f"Ablation_{name}", active_tasks=tasks)
    r = ablation_A3[name] = evaluate_model(m, test_ds, CONFIG['device'])
    print(f"  {name}: MRR={r['mrr']:.4f}  H@1={r['hits@1']:.4f}  H@10={r['hits@10']:.4f}")

ablation_results['A3_MultiTask_Tasks'] = ablation_A3
print("Ablation A3 complete")


# ============================================================================
# 18 — ABLATION A4: Qualifier Density (Q33 / Q66 / Q100)
#
# Models compared: StarE, AlertStar, HyNT, MultiTask-AS,
#                  HR-NBFNet, MultiTask_HR_NBFNet
# For the entry whose path is CONFIG['data_path'] (labelled Q100 below):
# reuse the main results.  For Q33/Q66: retrain all 6.
# ============================================================================

print("\n" + "="*70)
print("ABLATION A4: Qualifier Density Sensitivity")
print("="*70)

# Dataset variants to compare. The entry whose path equals
# CONFIG['data_path'] reuses the main-phase results instead of retraining.
DENSITY_PATHS = {
    'Q100':  CONFIG['data_path'],
    'Q33':  CONFIG['q33_path'],
    'Q66': CONFIG['q66_path'],
}
DENSITY_MODELS = ['StarE','AlertStar','HyNT',
                  'MultiTask_AlertStar','HR_NBFNet','MultiTask_HR_NBFNet']

ablation_A4 = {}

for density_label, dpath in DENSITY_PATHS.items():
    print(f"\n{'='*50}  {density_label}  {'='*50}")

    if dpath == CONFIG['data_path']:
        ablation_A4[density_label] = {
            m: all_results[m] for m in DENSITY_MODELS if m in all_results}
        print(f"  Reusing trained models for {density_label}")
        continue

    # Load the density-specific dataset; all vocabulary sizes are re-derived.
    p2       = DataPreprocessor()
    tr2, va2, te2 = p2.load(dpath)
    ne2  = len(p2.entity2id);   nr2  = len(p2.relation2id)
    nqk2 = len(p2.qualifier_key2id); nqv2 = len(p2.qualifier_value2id)
    tr_ds2 = HRDataset(tr2, p2);  va_ds2 = HRDataset(va2, p2);  te_ds2 = HRDataset(te2, p2)
    density_res = {}

    # StarE
    m = StarEModel(ne2, nr2, nqk2, nqv2, DIM, CONFIG['dropout'])
    m, _, _ = train_standard(m, f"A4_{density_label}_StarE",
                              train_ds_=tr_ds2, valid_ds_=va_ds2, ne_override=ne2)
    density_res['StarE'] = evaluate_model(m, te_ds2, CONFIG['device'])

    # AlertStar
    m = AlertStarModel(ne2, nr2, nqk2, nqv2, DIM, CONFIG['dropout'])
    m, _, _ = train_standard(m, f"A4_{density_label}_AlertStar",
                              train_ds_=tr_ds2, valid_ds_=va_ds2, ne_override=ne2)
    density_res['AlertStar'] = evaluate_model(m, te_ds2, CONFIG['device'])

    # HyNT — build its own HyNTDataset with p2. Uses the same capacity as the
    # main-phase HyNT, whose results are reused for the data_path entry, so
    # densities are compared like-for-like.
    _ht2 = HyNTDataset(tr2, p2)
    m    = HyNTModel(ne2, nr2, nqk2, nqv2, DIM, CONFIG['dropout'],
                     n_heads=CONFIG['hynt_n_heads'],
                     n_layers=CONFIG['hynt_n_layers'])
    m.to(CONFIG['device'])
    _opt2   = torch.optim.Adam(m.parameters(), lr=CONFIG['hynt_lr'])
    _loader2 = DataLoader(_ht2, batch_size=CONFIG['hynt_batch'],
                           shuffle=True, collate_fn=mt_collate)
    _best2 = 0.0
    for _ep in range(CONFIG['hynt_epochs']):
        m.train()
        for _bt in tqdm(_loader2, desc=f"A4 HyNT {density_label} ep{_ep+1}", leave=False):
            _bl = torch.tensor(0., device=CONFIG['device'])
            for _task, _samps in _bt.items():
                for _s in _samps:
                    try:
                        if _task == 'tail':
                            _lg  = m.forward_tail(_s, CONFIG['device'])
                            _tg  = torch.tensor(_s['t'], device=CONFIG['device'])
                            _bl  = _bl + F.cross_entropy(_lg.unsqueeze(0), _tg.unsqueeze(0))
                        elif _task == 'qv':
                            _lg  = m.forward_qv(_s, CONFIG['device'])
                            _tg  = torch.tensor(_s['qv'], device=CONFIG['device'])
                            _bl  = _bl + 0.8*F.cross_entropy(_lg.unsqueeze(0), _tg.unsqueeze(0))
                    except Exception: continue
            # No sample contributed a loss: _bl is still the constant
            # 0-tensor, and backward() on it would raise.
            if not _bl.requires_grad:
                continue
            _opt2.zero_grad(); _bl.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0); _opt2.step()
        if (_ep+1) % 5 == 0:
            _mv = evaluate_model(m, va_ds2, CONFIG['device'])
            if _mv['mrr'] > _best2:
                _best2 = _mv['mrr']
                torch.save(m.state_dict(),
                           f"{CONFIG['output_path']}A4_{density_label}_HyNT_best.pt")
    density_res['HyNT'] = evaluate_model(m, te_ds2, CONFIG['device'])

    # MultiTask-AS
    m = MultiTaskAlertStar(ne2, nr2, nqk2, nqv2, DIM, dropout=CONFIG['dropout'])
    _mt2_ds  = MultiTaskDataset(tr2, p2, nqk_override=nqk2)
    _mt2_ldr = DataLoader(_mt2_ds, batch_size=CONFIG['mt_batch'],
                           shuffle=True, collate_fn=mt_collate)
    _wts2 = {'tail':1.0,'relation':1.0,'qual_key':0.5,'qual_value':0.8}
    m.to(CONFIG['device']); _mo2 = torch.optim.Adam(m.parameters(), lr=CONFIG['mt_lr'])
    _bm2 = 0.0
    for _ep in range(CONFIG['mt_epochs']):
        m.train()
        for _bt in tqdm(_mt2_ldr, desc=f"A4 MT {density_label} ep{_ep+1}", leave=False):
            _bl = torch.tensor(0., device=CONFIG['device'])
            for _task, _samps in _bt.items():
                for _s in _samps:
                    try:
                        _lg = m.forward_task(_task, _s, CONFIG['device'])
                        _w  = _wts2.get(_task, 1.0)
                        if _task == 'tail':
                            _tg = torch.tensor(_s['t'], device=CONFIG['device'])
                            _ls = F.cross_entropy(_lg.unsqueeze(0), _tg.unsqueeze(0))
                        elif _task == 'relation':
                            _tg = torch.tensor(_s['r'], device=CONFIG['device'])
                            _ls = F.cross_entropy(_lg.unsqueeze(0), _tg.unsqueeze(0))
                        elif _task == 'qual_key':
                            _tg = torch.zeros(nqk2, device=CONFIG['device'])
                            _tg[_s['keys']] = 1.0
                            _ls = F.binary_cross_entropy_with_logits(_lg, _tg)
                        elif _task == 'qual_value':
                            _tg = torch.tensor(_s['qv'], device=CONFIG['device'])
                            _ls = F.cross_entropy(_lg.unsqueeze(0), _tg.unsqueeze(0))
                        else: continue
                        _bl = _bl + _w*_ls
                    except Exception: continue
            # Same guard as above: skip batches where nothing produced a loss.
            if not _bl.requires_grad:
                continue
            _mo2.zero_grad(); _bl.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0); _mo2.step()
        if (_ep+1) % 5 == 0:
            _mv = evaluate_model(m, va_ds2, CONFIG['device'])
            if _mv['mrr'] > _bm2:
                _bm2 = _mv['mrr']
                torch.save(m.state_dict(),
                           f"{CONFIG['output_path']}A4_{density_label}_MT_best.pt")
    density_res['MultiTask_AlertStar'] = evaluate_model(m, te_ds2, CONFIG['device'])

    # HR-NBFNet for this density
    _hr_g2 = build_hr_nbfnet_graph(tr2, p2, CONFIG['device'],
                                    max_quals=CONFIG['hr_nbfnet_max_quals'],
                                    p_override=p2)
    m = HRNBFNet(ne2, nr2, nqk2, nqv2, DIM,
                 num_layers=CONFIG['hr_nbfnet_layers'],
                 chunk_size=CONFIG['hr_nbfnet_chunk_size'],
                 dropout=CONFIG['dropout'], max_quals=CONFIG['hr_nbfnet_max_quals'])
    m, _ = train_hr_nbfnet(m, f"A4_{density_label}_HR_NBFNet",
                            train_data_=tr2, valid_ds_=va_ds2,
                            graph_=_hr_g2, p_override=p2)
    _hr_g2_dev = {k: v.to(CONFIG['device']) if torch.is_tensor(v) else v
                  for k,v in _hr_g2.items()}
    density_res['HR_NBFNet'] = evaluate_hr_nbfnet(
        m, te_ds2, _hr_g2_dev, CONFIG['device'], p_override=p2)

    # MultiTask_HR_NBFNet for this density (same density-specific graph)
    m = MultiTaskHRNBFNet(ne2, nr2, nqk2, nqv2, DIM,
                          num_layers=CONFIG['hr_nbfnet_layers'],
                          chunk_size=CONFIG['hr_nbfnet_chunk_size'],
                          dropout=CONFIG['dropout'],
                          max_quals=CONFIG['hr_nbfnet_max_quals'])
    m, _ = train_mt_hr_nbfnet(m, f"A4_{density_label}_MultiTask_HR_NBFNet",
                               train_data_=tr2, valid_ds_=va_ds2,
                               graph_=_hr_g2, p_override=p2,
                               nqk_override=nqk2, nqv_override=nqv2)
    density_res['MultiTask_HR_NBFNet'] = evaluate_hr_nbfnet(
        m, te_ds2, _hr_g2_dev, CONFIG['device'], p_override=p2)

    ablation_A4[density_label] = density_res
    for mname, metrics in density_res.items():
        print(f"  {density_label}  {mname:30s}  MRR={metrics['mrr']:.4f}  "
              f"H@1={metrics['hits@1']:.4f}  H@10={metrics['hits@10']:.4f}")

ablation_results['A4_Qualifier_Density'] = ablation_A4
print("\n Ablation A4 complete")


# ============================================================================
# 19 — COMPLETE RESULTS TABLES
# ============================================================================

print("\n" + "="*80)
print("TABLE 1: COMPLETE 10-MODEL COMPARISON")
print("="*80)
# One row per trained model, sorted by MRR (best first).
df_main = pd.DataFrame({
    'Model':   list(all_results.keys()),
    'MR':      [all_results[m]['mr']      for m in all_results],
    'MRR':     [all_results[m]['mrr']     for m in all_results],
    'Hits@1':  [all_results[m]['hits@1']  for m in all_results],
    'Hits@3':  [all_results[m]['hits@3']  for m in all_results],
    'Hits@10': [all_results[m]['hits@10'] for m in all_results],
}).sort_values('MRR', ascending=False).reset_index(drop=True)
print(df_main.to_string(index=False))

print("\n" + "="*80)
print("TABLE 2: TrueNBFNet vs HR-NBFNet vs MultiTask_HR_NBFNet")
print("="*80)
for mname in ['TrueNBFNet','HR_NBFNet','MultiTask_HR_NBFNet']:
    r = all_results.get(mname, {})
    print(f"  {mname:30s}  MRR={r.get('mrr',0):.4f}  MR={r.get('mr',0):7.1f}  "
          f"H@1={r.get('hits@1',0):.4f}  H@3={r.get('hits@3',0):.4f}  "
          f"H@10={r.get('hits@10',0):.4f}")

print("\n" + "="*80)
print("TABLE 3: StarQE family comparison")
print("="*80)
for mname in ['StarQE','NBFNet_StarQE']:
    r = all_results.get(mname, {})
    print(f"  {mname:30s}  MRR={r.get('mrr',0):.4f}  MR={r.get('mr',0):7.1f}  "
          f"H@1={r.get('hits@1',0):.4f}  H@10={r.get('hits@10',0):.4f}")

# Component flags are derived from the A1 configurations that were actually
# trained, so the table cannot drift from the experiment (and works for any
# number of variants).
_a1_cfg = dict(ablation_configs)
df_A1 = pd.DataFrame({
    'Variant': list(ablation_A1.keys()),
    'Qual?':   ['yes' if _a1_cfg[v]['use_qual'] else 'no' for v in ablation_A1],
    'Path?':   ['yes' if _a1_cfg[v]['use_path'] else 'no' for v in ablation_A1],
    'Gate?':   ['learned' if _a1_cfg[v]['fixed_gate'] is None
                else f"fixed={_a1_cfg[v]['fixed_gate']}" for v in ablation_A1],
    'MRR':     [ablation_A1[v]['mrr']     for v in ablation_A1],
    'Hits@1':  [ablation_A1[v]['hits@1']  for v in ablation_A1],
    'Hits@3':  [ablation_A1[v]['hits@3']  for v in ablation_A1],
    'Hits@10': [ablation_A1[v]['hits@10'] for v in ablation_A1],
    'MR':      [ablation_A1[v]['mr']      for v in ablation_A1],
})
print("\n" + "="*80)
print("TABLE 4: ABLATION A1 — AlertStar Components")
print("="*80)
print(df_A1.to_string(index=False))

print("\n" + "="*80)
print("TABLE 5: ABLATION A2 — Gate Trajectory")
print("="*80)
if gate_history: print(pd.DataFrame(gate_history).to_string(index=False))

# Task lists are looked up by variant name rather than by position.
_a3_tasks = dict(mt_ablation_configs)
df_A3 = pd.DataFrame({
    'Variant': list(ablation_A3.keys()),
    'Tasks':   [str(_a3_tasks[v])          for v in ablation_A3],
    'MRR':     [ablation_A3[v]['mrr']     for v in ablation_A3],
    'Hits@1':  [ablation_A3[v]['hits@1']  for v in ablation_A3],
    'Hits@10': [ablation_A3[v]['hits@10'] for v in ablation_A3],
    'MR':      [ablation_A3[v]['mr']      for v in ablation_A3],
})
print("\n" + "="*80)
print("TABLE 6: ABLATION A3 — MT Task Contribution")
print("="*80)
print(df_A3.to_string(index=False))

print("\n" + "="*80)
print("TABLE 7: ABLATION A4 — Qualifier Density")
print("="*80)
for density, res in ablation_A4.items():
    print(f"\n  {density}:")
    for mname, metrics in res.items():
        print(f"    {mname:30s}  MRR={metrics['mrr']:.4f}  "
              f"H@1={metrics['hits@1']:.4f}  H@10={metrics['hits@10']:.4f}")


# ============================================================================
# 20 — VISUALIZATIONS
# ============================================================================

n_models     = len(df_main)
palette_main = sns.color_palette("tab10", n_models)
met_list     = ['mrr','hits@1','hits@3','hits@10']
lbl_list     = ['MRR','H@1','H@3','H@10']

fig = plt.figure(figsize=(30, 24))
fig.suptitle("AlertStar — 10-Model Complete Results & Ablation Studies",
             fontsize=18, fontweight='bold', y=0.98)
gs  = fig.add_gridspec(3, 3, hspace=0.5, wspace=0.35)

# ── P1: Main MRR ─────────────────────────────────────────────────────────
ax1  = fig.add_subplot(gs[0, 0])
bars = ax1.bar(df_main['Model'], df_main['MRR'], color=palette_main)
ax1.set_title("Main Results — MRR (all 10 models)", fontweight='bold')
ax1.set_ylabel("MRR")
ax1.tick_params(axis='x', rotation=55, labelsize=6)
for bar, v in zip(bars, df_main['MRR']):
    ax1.text(bar.get_x()+bar.get_width()/2, v+0.002,
             f'{v:.3f}', ha='center', fontsize=5, fontweight='bold')
highlight = {'HR_NBFNet':'goldenrod', 'MultiTask_HR_NBFNet':'crimson',
             'HyNT':'red', 'StarQE':'steelblue', 'NBFNet_StarQE':'navy'}
model_list = list(df_main['Model'])
for i, bar in enumerate(bars):
    mname = model_list[i]
    if mname in highlight:
        bar.set_edgecolor(highlight[mname]); bar.set_linewidth(2.5)

# ── P2: Hits@k ────────────────────────────────────────────────────────────
# P2: grouped Hits@1 / Hits@3 / Hits@10 bars per model.  Three bars of width
# `w` sit at offsets -w, 0, +w around each integer x position.
ax2 = fig.add_subplot(gs[0, 1])
x2  = np.arange(n_models)
w   = 0.25
_p2_series = (
    (-w,  'Hits@1',  'H@1',  'steelblue'),
    (0.0, 'Hits@3',  'H@3',  'orange'),
    (w,   'Hits@10', 'H@10', 'green'),
)
for _shift, _column, _label, _colour in _p2_series:
    ax2.bar(x2 + _shift, df_main[_column], w, label=_label, color=_colour)
ax2.set_xticks(x2)
ax2.set_xticklabels(df_main['Model'], rotation=55, ha='right', fontsize=6)
ax2.set_title("Main Results — Hits@k", fontweight='bold')
ax2.legend(fontsize=8)

# ── P3: NBFNet family (TrueNBFNet / HR-NBFNet / MT-HR-NBFNet) ────────────
# P3: TrueNBFNet / HR-NBFNet / MultiTask HR-NBFNet side by side on every
# metric in met_list.  Models missing from all_results are left out, and a
# metric missing from a result dict is plotted as 0.
ax3      = fig.add_subplot(gs[0, 2])
nbf_mods = ['TrueNBFNet', 'HR_NBFNet', 'MultiTask_HR_NBFNet']
nbf_res  = {m: all_results[m] for m in nbf_mods if m in all_results}
x3, w3   = np.arange(len(lbl_list)), 0.25
pal3     = ['steelblue', 'goldenrod', 'crimson']
_half3   = len(nbf_res) / 2
for i, (mname, res) in enumerate(nbf_res.items()):
    # Centre the group of len(nbf_res) bars on each metric's tick.
    offset = (i - _half3 + 0.5) * w3
    vals   = [res.get(mk, 0) for mk in met_list]
    ax3.bar(x3 + offset, vals, w3, label=mname, color=pal3[i])
ax3.set_xticks(x3)
ax3.set_xticklabels(lbl_list)
ax3.set_title("NBFNet Family Comparison", fontweight='bold')
ax3.set_ylabel("Score")
ax3.legend(fontsize=8)

# ── P4: StarQE family ─────────────────────────────────────────────────────
# P4: StarE / StarQE / NBFNet+StarQE / AlertStar on every metric in met_list.
# Only models present in all_results are drawn; missing metrics count as 0.
ax4     = fig.add_subplot(gs[1, 0])
qe_mods = ['StarE', 'StarQE', 'NBFNet_StarQE', 'AlertStar']
qe_res  = {m: all_results[m] for m in qe_mods if m in all_results}
x4, w4  = np.arange(len(lbl_list)), 0.2
pal4    = sns.color_palette("Set2", len(qe_res))
_half4  = len(qe_res) / 2
for i, (mname, res) in enumerate(qe_res.items()):
    # Centre the group of len(qe_res) bars on each metric's tick.
    offset = (i - _half4 + 0.5) * w4
    vals   = [res.get(mk, 0) for mk in met_list]
    ax4.bar(x4 + offset, vals, w4, label=mname, color=pal4[i])
ax4.set_xticks(x4)
ax4.set_xticklabels(lbl_list)
ax4.set_title("StarQE Family vs Baselines", fontweight='bold')
ax4.legend(fontsize=8)

# ── P5: Ablation A1 ───────────────────────────────────────────────────────
# P5: A1 component ablation, with MRR / Hits@1 / Hits@10 bars per variant.
# Above each MRR bar that differs from the 'AS-Full' variant by more than
# 0.001, a red label shows delta = variant MRR - full MRR. A negative value
# means the variant scores lower than AS-Full.
ax5  = fig.add_subplot(gs[1, 1])
x5   = np.arange(len(df_A1)); w5 = 0.25
ax5.bar(x5-w5, df_A1['MRR'],     w5, label='MRR',  color='steelblue')
ax5.bar(x5,    df_A1['Hits@1'],  w5, label='H@1',  color='orange')
ax5.bar(x5+w5, df_A1['Hits@10'], w5, label='H@10', color='green')
ax5.set_xticks(x5)
ax5.set_xticklabels(df_A1['Variant'], rotation=20, ha='right', fontsize=9)
ax5.set_title("A1: AlertStar Components", fontweight='bold')
ax5.legend(fontsize=8)

# Without an 'AS-Full' row there is no reference point. Skip the labels
# instead of crashing the whole figure with an IndexError.
full_rows = df_A1.loc[df_A1['Variant'] == 'AS-Full', 'MRR']
if full_rows.empty:
    print("A1: no 'AS-Full' variant found — skipping Δ annotations")
else:
    full_mrr = full_rows.iloc[0]
    for xi, variant_mrr in zip(x5, df_A1['MRR']):
        # Sign convention: variant - full, so a worse variant shows a negative Δ.
        delta = variant_mrr - full_mrr
        if abs(delta) > 0.001:
            ax5.text(xi - w5, variant_mrr + 0.003,
                     f'Δ{delta:+.3f}', ha='center', fontsize=7, color='red')

# ── P6: Gate Trajectory ───────────────────────────────────────────────────
# P6: gate value recorded per epoch (A2). Each gate_history record is indexed
# by 'epoch' and 'gate'. The area between the curve and 0.5 is shaded blue
# when the final gate is above 0.5 and orange otherwise.
ax6 = fig.add_subplot(gs[1, 2])
if not gate_history:
    ax6.text(0.5, 0.5, 'No gate data', ha='center', va='center',
             transform=ax6.transAxes)
    ax6.set_title("A2: Gate Trajectory", fontweight='bold')
else:
    epochs = [rec['epoch'] for rec in gate_history]
    gates  = [rec['gate'] for rec in gate_history]
    last_epoch, last_gate = epochs[-1], gates[-1]
    shade = 'blue' if last_gate > 0.5 else 'orange'

    ax6.plot(epochs, gates, 'o-', color='purple', lw=2, markersize=6)
    ax6.axhline(0.5, color='gray', ls='--', alpha=0.7, label='g=0.5 (balanced)')
    ax6.fill_between(epochs, gates, 0.5, alpha=0.12, color=shade)
    ax6.set_ylim(0, 1.05)
    ax6.set_xlabel("Epoch")
    ax6.set_ylabel("g = σ(θ)")
    ax6.set_title("A2: AlertStar Gate Trajectory", fontweight='bold')
    ax6.legend(fontsize=8)
    # Arrow from an offset caption to the final recorded gate value.
    ax6.annotate(f"Final: {last_gate:.3f}",
                 xy=(last_epoch, last_gate),
                 xytext=(last_epoch - 3, last_gate + 0.08),
                 arrowprops=dict(arrowstyle='->', color='purple'),
                 fontsize=9, color='purple')

# ── P7: Ablation A3 ───────────────────────────────────────────────────────
# P7: A3 auxiliary-task ablation for MultiTask-AS, with MRR / Hits@1 /
# Hits@10 bars per variant.
ax7  = fig.add_subplot(gs[2, 0])
x7   = np.arange(len(df_A3))
# Panel-local bar width, the same value as the A1 panel, so this panel does
# not depend on a variable defined by another panel's code.
w7   = 0.25
ax7.bar(x7-w7, df_A3['MRR'],     w7, label='MRR',  color='steelblue')
ax7.bar(x7,    df_A3['Hits@1'],  w7, label='H@1',  color='orange')
ax7.bar(x7+w7, df_A3['Hits@10'], w7, label='H@10', color='green')
ax7.set_xticks(x7)
ax7.set_xticklabels(df_A3['Variant'], rotation=25, ha='right', fontsize=8)
ax7.set_title("A3: MT Task Contribution", fontweight='bold')
ax7.legend(fontsize=8)

# ── P8: Ablation A4 — Density (MRR) ──────────────────────────────────────
# P8: A4 qualifier-density sensitivity. For each density key in ablation_A4
# there is one MRR bar per model in DENSITY_MODELS. A model missing for some
# density is plotted as 0. HR-NBFNet variants get a black outline.
ax8       = fig.add_subplot(gs[2, 1:])
densities = list(ablation_A4.keys())
n_dm      = len(DENSITY_MODELS)
x8        = np.arange(len(densities))
# Bar width is 0.12, shrunk when the group of n_dm bars would otherwise span
# more than 0.8 of the unit tick spacing and run into the neighbouring group.
# max(..., 1) avoids division by zero for an empty model list.
w8        = min(0.12, 0.8 / max(n_dm, 1))
pal8      = sns.color_palette("tab10", n_dm)
for mi, mname in enumerate(DENSITY_MODELS):
    mrrs   = [ablation_A4[d].get(mname, {}).get('mrr', 0) for d in densities]
    offset = (mi - n_dm/2 + 0.5)*w8
    bars8  = ax8.bar(x8+offset, mrrs, w8, label=mname, color=pal8[mi])
    # bold border for new methods
    if mname in ('HR_NBFNet', 'MultiTask_HR_NBFNet'):
        for b in bars8:
            b.set_edgecolor('black')
            b.set_linewidth(1.5)
ax8.set_xticks(x8); ax8.set_xticklabels(densities)
ax8.set_xlabel("Qualifier Density"); ax8.set_ylabel("MRR")
# The model count comes from DENSITY_MODELS rather than a hard-coded "6".
ax8.set_title(f"A4: MRR vs Qualifier Density (all {n_dm} models)",
              fontweight='bold')
ax8.legend(fontsize=7, loc='upper left', ncol=2)

# Write the summary figure to the output directory, then display it.
# os.path.join inserts a separator only when output_path lacks one. The file
# is therefore the same as before when output_path ends in "/", and it no
# longer lands beside the directory when it doesn't.
fig_path = os.path.join(CONFIG['output_path'], "complete_10model_results.png")
# Save through `fig` explicitly instead of whichever figure pyplot considers
# current. Saving happens before show(), which may close the figure.
fig.savefig(fig_path, dpi=300, bbox_inches='tight')
plt.show()
print("Visualization saved")


# ============================================================================
# 21 — SAVE ALL RESULTS
# ============================================================================

def _json_default(obj):
    """Convert numpy / torch scalars and arrays to plain Python for JSON.

    json only calls this for objects it cannot serialize natively, so values
    that are already plain floats/ints/lists are written exactly as before.
    Any other type still raises TypeError, as json does by default.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Serialize first, then write. A serialization error therefore cannot leave a
# truncated JSON file behind (json.dump streams into the file as it goes).
results_path = os.path.join(CONFIG['output_path'], "all_results_10models.json")
_payload = json.dumps({'main_results': all_results,
                       'ablation_results': ablation_results,
                       'gate_history': gate_history},
                      indent=2, default=_json_default)
with open(results_path, 'w') as f:
    f.write(_payload)

print(f"\n Saved to {CONFIG['output_path']}")
print("\n" + "="*70)
print("COMPLETE — 10 models + 4 ablations ready")
print("="*70)
print(f"\nAll {len(all_results)} models ranked by MRR:")
# Highest MRR first. Newly added and restored models are tagged in the listing.
for m, r in sorted(all_results.items(), key=lambda x: -x[1]['mrr']):
    if m in ('HR_NBFNet', 'MultiTask_HR_NBFNet'):
        tag = " ← NEW"
    elif m in ('StarQE', 'NBFNet_StarQE'):
        tag = " ← restored"
    else:
        tag = ""
    print(f"  {m:30s}  MRR={r['mrr']:.4f}  H@10={r['hits@10']:.4f}{tag}")