# ============================================================================ # ALERTSTAR: COMPLETE NOTEBOOK — 10 MODELS + ALL ABLATIONS # # Models: # 1. StarE — qualifier-enriched relation attention # 2. ShrinkE — shrinking transform qualifier fusion # 3. TrueNBFNet — Neural Bellman-Ford (qualifier-unaware) # 4. AlertStar — gated StarE + path composition (ours) # 5. StarQE — complex query answering (1p/2p/2i/2u) # 6. NBFNet+StarQE — residual path-augmented complex queries # 7. HyNT — Transformer 2-task competitor # 8. MultiTask-AS — Transformer 4-task (ours, main) # 9. HR-NBFNet — Hyper-Relational Bellman-Ford # qualifier-aware at every propagation step # matches slide formulation exactly # 10. MultiTask_HR_NBFNet — HR-NBFNet + 4-task multi-task training (NEW) # combines graph propagation with auxiliary # relation/qual-key/qual-value supervision # # Ablations: # A1 — AlertStar component ablation (NoQual / NoPath / NoGate / Full) # A2 — AlertStar gate value trajectory # A3 — MultiTask-AS auxiliary task contribution # A4 — Qualifier density sensitivity (Q33 / Q66 / Q100) # HR-NBFNet + MultiTask_HR_NBFNet included in A4 # ============================================================================ # ============================================================================ # 1 — Imports & Config # ============================================================================ import os, json, warnings, copy warnings.filterwarnings('ignore') import numpy as np import pandas as pd import matplotlib.pyplot as plt import matplotlib.patches as mpatches import seaborn as sns from collections import defaultdict from tqdm import tqdm import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from torch_scatter import scatter_add, scatter_mean device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") CONFIG = { # ── paths ───────────────────────────────────────────────────────────── 'data_path': '/.../inductive_q100_statements', 'output_path': '/.../results_q3366100_ablation_10ci100/', # ── shared hyper-params ─────────────────────────────────────────────── 'embedding_dim': 200, 'dropout': 0.2, 'device': device, # ── standard training (StarE / ShrinkE / AlertStar) ────────────────── 'batch_size': 128, 'learning_rate': 0.0005, 'epochs': 20, # ── TrueNBFNet ──────────────────────────────────────────────────────── 'nbfnet_epochs': 20, 'nbfnet_lr': 0.0005, 'nbfnet_layers': 3, 'nbfnet_chunk_size': 10000, 'nbfnet_max_per_group': 8, # ── StarQE / NBFNet+StarQE ──────────────────────────────────────────── 'query_epochs': 20, 'query_lr': 0.0005, # ── HyNT ────────────────────────────────────────────────────────────── 'hynt_epochs': 20, 'hynt_lr': 0.0005, 'hynt_batch': 64, 'hynt_n_heads': 4, 'hynt_n_layers': 2, # ── MultiTask-AS ────────────────────────────────────────────────────── 'mt_epochs': 20, 'mt_lr': 0.0005, 'mt_batch': 64, # ── HR-NBFNet ───────────────────────────────────────────────────────── 'hr_nbfnet_epochs': 20, 'hr_nbfnet_lr': 0.0005, 'hr_nbfnet_layers': 3, 'hr_nbfnet_chunk_size': 5000, 'hr_nbfnet_max_quals': 8, # ── MultiTask_HR_NBFNet ─────────────────────────────────────────────── 'mt_hr_epochs': 20, 'mt_hr_lr': 0.0005, 'mt_hr_batch': 32, # smaller batch — each sample triggers a BF pass # ── qualifier density paths for Ablation A4 ─────────────────────────── 'q33_path': '/.../inductive_q33_statements', 'q66_path': '/.../inductive_q66_statements', } os.makedirs(CONFIG['output_path'], 
            exist_ok=True)

all_results = {}
ablation_results = {}
gate_history = []
print("Config ready")

# ============================================================================
# 2 — Data Loading
# ============================================================================
class DataPreprocessor:
    def __init__(self):
        self.entity2id = {}
        self.relation2id = {}
        self.qualifier_key2id = {}
        self.qualifier_value2id = {}

    def _parse(self, line):
        # Expected statement format (one per line):
        #   head,relation,tail,qk1:qv1|qk2:qv2|...
        # i.e. a comma-separated triple, optionally followed by a fourth field
        # of '|'-separated key:value qualifier pairs.
        parts = line.strip().split(',')
        if len(parts) < 3:
            return None
        head, relation, tail = parts[0], parts[1], parts[2]
        qualifiers = []
        if len(parts) > 3:
            for pair in parts[3].split('|'):
                pair = pair.strip()
                if ':' in pair:
                    key, value = pair.split(':', 1)
                    qualifiers.append((key.strip(), value.strip()))
        return {'head': head, 'relation': relation, 'tail': tail,
                'qualifiers': qualifiers}

    def _register(self, triple):
        for token, vocab in [(triple['head'], self.entity2id),
                             (triple['tail'], self.entity2id),
                             (triple['relation'], self.relation2id)]:
            if token not in vocab:
                vocab[token] = len(vocab)
        for qk, qv in triple['qualifiers']:
            if qk not in self.qualifier_key2id:
                self.qualifier_key2id[qk] = len(self.qualifier_key2id)
            if qv not in self.qualifier_value2id:
                self.qualifier_value2id[qv] = len(self.qualifier_value2id)

    def load(self, data_path):
        splits = {}
        for name, fname in [('train', 'train.txt'),
                            ('valid', 'validation.txt'),
                            ('test', 'test.txt')]:
            data = []
            with open(os.path.join(data_path, fname)) as f:
                for line in f:
                    t = self._parse(line)
                    if t:
                        self._register(t)
                        data.append(t)
            splits[name] = data
            print(f"  {name}: {len(data):,}")
        print(f"  NE={len(self.entity2id)} NR={len(self.relation2id)} "
              f"NQK={len(self.qualifier_key2id)} NQV={len(self.qualifier_value2id)}")
        return splits['train'], splits['valid'], splits['test']

class HRDataset(Dataset):
    def __init__(self, data, preprocessor):
        self.data = data
        self.p = preprocessor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        t = self.data[idx]
        return {
            'head': self.p.entity2id[t['head']],
            'relation': self.p.relation2id[t['relation']],
            'tail': self.p.entity2id[t['tail']],
            'qualifiers': [(self.p.qualifier_key2id[qk],
                            self.p.qualifier_value2id[qv])
                           for qk, qv in t['qualifiers']],
        }

def collate(batch):
    return {
        'head': torch.tensor([b['head'] for b in batch], dtype=torch.long),
        'relation': torch.tensor([b['relation'] for b in batch], dtype=torch.long),
        'tail': torch.tensor([b['tail'] for b in batch], dtype=torch.long),
        'qualifiers': [b['qualifiers'] for b in batch],
    }

def mt_collate(batch):
    by_task = defaultdict(list)
    for s in batch:
        by_task[s['task']].append(s)
    return by_task

print("Loading data...")
preprocessor = DataPreprocessor()
train_data, valid_data, test_data = preprocessor.load(CONFIG['data_path'])
train_ds = HRDataset(train_data, preprocessor)
valid_ds = HRDataset(valid_data, preprocessor)
test_ds = HRDataset(test_data, preprocessor)
NE = len(preprocessor.entity2id)
NR = len(preprocessor.relation2id)
NQK = len(preprocessor.qualifier_key2id)
NQV = len(preprocessor.qualifier_value2id)
DIM = CONFIG['embedding_dim']
print("Data ready")

# ============================================================================
# 3 — Shared Evaluation & Training Utilities
# ============================================================================
def evaluate_model(model, dataset, device, max_samples=500):
    model.eval()
    ranks = []
    with torch.no_grad():
        for i in tqdm(range(min(len(dataset), max_samples)),
                      desc="Evaluating", leave=False):
            s = dataset[i]
            h = torch.tensor([s['head']], device=device)
            r =
torch.tensor([s['relation']], device=device) t = s['tail'] q = [s['qualifiers']] scores = model(h, r, q).squeeze() rank = (torch.argsort(scores, descending=True) == t ).nonzero(as_tuple=True)[0].item() + 1 ranks.append(rank) ranks = np.array(ranks) return {'mr': float(np.mean(ranks)), 'mrr': float(np.mean(1.0 / ranks)), 'hits@1': float(np.mean(ranks <= 1)), 'hits@3': float(np.mean(ranks <= 3)), 'hits@10': float(np.mean(ranks <= 10))} def train_standard(model, model_name, train_ds_=None, valid_ds_=None, ne_override=None, lr=None, epochs=None, gate_track=False): dev = CONFIG['device'] lr = lr or CONFIG['learning_rate'] eps = epochs or CONFIG['epochs'] _ne = ne_override or NE _tds = train_ds_ or train_ds _vds = valid_ds_ or valid_ds model.to(dev) opt = torch.optim.Adam(model.parameters(), lr=lr) loader = DataLoader(_tds, batch_size=CONFIG['batch_size'], shuffle=True, collate_fn=collate) print(f"\n{'='*60}\nTRAINING {model_name} " f"({sum(p.numel() for p in model.parameters()):,} params)\n{'='*60}") best_mrr, history, epoch_gates = 0.0, [], [] for epoch in range(eps): model.train(); total = 0.0 for b in tqdm(loader, desc=f"Epoch {epoch+1}/{eps}", leave=False): h, r, t, q = (b['head'].to(dev), b['relation'].to(dev), b['tail'].to(dev), b['qualifiers']) pos = model(h, r, q, t) neg = model(h, r, q, torch.randint(0, _ne, (len(h),), device=dev)) loss = F.margin_ranking_loss(pos, neg, torch.ones_like(pos), margin=1.0) opt.zero_grad(); loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step(); total += loss.item() if gate_track and hasattr(model, 'gate') and model.gate is not None: epoch_gates.append({'epoch': epoch+1, 'gate': torch.sigmoid(model.gate).item()}) if (epoch+1) % 5 == 0: m = evaluate_model(model, _vds, dev) print(f" Epoch {epoch+1}: loss={total/len(loader):.4f} " f"MRR={m['mrr']:.4f} MR={m['mr']:.1f} " f"H@1={m['hits@1']:.4f} H@3={m['hits@3']:.4f} " f"H@10={m['hits@10']:.4f}" + (f" gate={epoch_gates[-1]['gate']:.4f}" if gate_track and epoch_gates else "")) if m['mrr'] > best_mrr: best_mrr = m['mrr'] torch.save(model.state_dict(), f"{CONFIG['output_path']}{model_name}_best.pt") print(" → best saved") history.append(m) return model, history, epoch_gates print("Utilities ready") # ============================================================================ # 4 — Model 1: StarE # ============================================================================ class StarEModel(nn.Module): def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1): super().__init__() self.num_entities = ne self.ent = nn.Embedding(ne, dim) self.rel = nn.Embedding(nr, dim) self.qk = nn.Embedding(nqk, dim) self.qv = nn.Embedding(nqv, dim) self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True) self.drop = nn.Dropout(dropout) for e in [self.ent, self.rel, self.qk, self.qv]: nn.init.xavier_uniform_(e.weight) def _enrich(self, r_emb, quals, dev): if not quals: return r_emb k = self.qk(torch.tensor([q[0] for q in quals], device=dev)) v = self.qv(torch.tensor([q[1] for q in quals], device=dev)) kv = (k + v).unsqueeze(0) out, _ = self.attn(r_emb.view(1,1,-1), kv, kv) return out.squeeze() def forward(self, head, relation, qualifiers, tail=None): dev = head.device h = self.ent(head) r = torch.stack([self._enrich(self.rel(relation[i:i+1]).squeeze(), qualifiers[i], dev) for i in range(head.size(0))]) x = self.drop(h + r) if tail is not None: return (x * self.ent(tail)).sum(-1) return x @ self.ent.weight.t() print("StarE defined") # 
============================================================================ # 5 — Model 2: ShrinkE # ============================================================================ class ShrinkEModel(nn.Module): def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1): super().__init__() self.num_entities = ne self.ent = nn.Embedding(ne, dim) self.rel = nn.Embedding(nr, dim) self.qk = nn.Embedding(nqk, dim) self.qv = nn.Embedding(nqv, dim) self.shrink = nn.Sequential(nn.Linear(dim*2, dim), nn.Tanh(), nn.Dropout(dropout)) self.proj = nn.Linear(dim, dim) self.drop = nn.Dropout(dropout) for e in [self.ent, self.rel, self.qk, self.qv]: nn.init.xavier_uniform_(e.weight) def _shrink(self, r_emb, quals, dev): if not quals: return r_emb k = self.qk(torch.tensor([q[0] for q in quals], device=dev)) v = self.qv(torch.tensor([q[1] for q in quals], device=dev)) qc = (k + v).mean(0, keepdim=True) return self.shrink(torch.cat([r_emb.unsqueeze(0), qc], -1)).squeeze() def forward(self, head, relation, qualifiers, tail=None): dev = head.device h = self.proj(self.ent(head)) r = torch.stack([self._shrink(self.rel(relation[i:i+1]).squeeze(), qualifiers[i], dev) for i in range(head.size(0))]) x = self.drop(h + self.proj(r)) if tail is not None: return (x * self.ent(tail)).sum(-1) return x @ self.ent.weight.t() print("✓ ShrinkE defined") # ============================================================================ # 6 — Model 3: TrueNBFNet (qualifier-unaware Bellman-Ford) # ============================================================================ class NBFConvLayer(nn.Module): def __init__(self, dim, num_relation, chunk_size=10000, layer_norm=True): super().__init__() self.chunk_size = chunk_size self.rel_emb = nn.Embedding(num_relation, dim) self.linear = nn.Linear(dim*2, dim) self.ln = nn.LayerNorm(dim) if layer_norm else None self.act = nn.ReLU() nn.init.xavier_uniform_(self.rel_emb.weight) def forward(self, graph, node_feat): edge_list = graph['edge_list'] N, B, D = node_feat.shape dev = node_feat.device agg = torch.zeros(N, B, D, device=dev) for start in range(0, edge_list.size(0), self.chunk_size): chunk = edge_list[start:start+self.chunk_size] src, dst, rel = chunk[:,0], chunk[:,1], chunk[:,2] msg = node_feat[src] * self.rel_emb(rel).unsqueeze(1) agg.scatter_add_(0, dst.view(-1,1,1).expand_as(msg), msg) del msg out = self.linear(torch.cat([node_feat, agg], dim=-1)) if self.ln: s = out.shape; out = self.ln(out.flatten(0,1)).view(s) return self.act(out) class TrueNBFNet(nn.Module): def __init__(self, ne, nr, dim=200, num_layers=3, short_cut=True, chunk_size=10000, dropout=0.1): super().__init__() self.num_entities = ne self.dim = dim self.short_cut = short_cut nr2 = nr * 2 self.query_emb = nn.Embedding(nr2, dim) self.layers = nn.ModuleList([ NBFConvLayer(dim, nr2, chunk_size=chunk_size) for _ in range(num_layers)]) self.mlp = nn.Sequential( nn.Linear(dim*2, dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, 1)) nn.init.xavier_uniform_(self.query_emb.weight) def _bellman_ford(self, graph, h, r, device): N = graph['num_nodes'] q = self.query_emb(torch.tensor([r], device=device)) feat = torch.zeros(N, 1, self.dim, device=device) feat[h, 0] = q[0] for layer in self.layers: h_new = layer(graph, feat) if self.short_cut: h_new = h_new + feat feat = h_new node_q = q.unsqueeze(0).expand(N, -1, -1) return torch.cat([feat, node_q], dim=-1).squeeze(1) def forward(self, head, relation, qualifiers=None, tail=None, graph=None): assert graph is not None assert (head==head[0]).all() and (relation==relation[0]).all() 
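        # NOTE: the whole batch is assumed to share a single (head, relation)
        # pair — the asserts above enforce this — so one Bellman-Ford pass from
        # that source scores every candidate tail. This is why the NBFNet
        # trainers group triples by (h, r) before batching.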
dev = head.device feat = self._bellman_ford(graph, head[0].item(), relation[0].item(), dev) if tail is not None: return self.mlp(feat[tail]).squeeze(-1) B = head.size(0) return self.mlp(feat).squeeze(-1).unsqueeze(0).expand(B, -1) def build_nbfnet_graph(train_data, preprocessor, device): nr, rows = len(preprocessor.relation2id), [] for t in train_data: h = preprocessor.entity2id[t['head']] r = preprocessor.relation2id[t['relation']] tl = preprocessor.entity2id[t['tail']] rows += [(h, tl, r), (tl, h, r+nr)] return {'edge_list': torch.tensor(rows, dtype=torch.long, device=device), 'num_nodes': len(preprocessor.entity2id)} def train_true_nbfnet(model, model_name="TrueNBFNet", train_data_=None, valid_ds_=None, graph_=None): dev = CONFIG['device'] _td = train_data_ or train_data _vds = valid_ds_ or valid_ds _g = graph_ or nbfnet_graph model.to(dev) graph = {k: v.to(dev) if torch.is_tensor(v) else v for k,v in _g.items()} opt = torch.optim.Adam(model.parameters(), lr=CONFIG['nbfnet_lr']) ne_ = graph['num_nodes'] mpg = CONFIG['nbfnet_max_per_group'] groups = defaultdict(list) for i, s in enumerate(_td): groups[(preprocessor.entity2id[s['head']], preprocessor.relation2id[s['relation']])].append(i) keys = list(groups.keys()) print(f"\n{'='*60}\nTRAINING {model_name}\n{'='*60}") best_mrr, history = 0.0, [] for epoch in range(CONFIG['nbfnet_epochs']): model.train(); np.random.shuffle(keys) total, cnt = 0.0, 0 for (h, r) in tqdm(keys, desc=f"Epoch {epoch+1}", leave=False): chosen = np.random.choice(groups[(h,r)], min(len(groups[(h,r)]),mpg), replace=False) t_pos = torch.tensor( [preprocessor.entity2id[_td[i]['tail']] for i in chosen], device=dev) B = len(t_pos) heads = torch.full((B,), h, dtype=torch.long, device=dev) rels = torch.full((B,), r, dtype=torch.long, device=dev) pos = model(heads, rels, tail=t_pos, graph=graph) neg = model(heads, rels, tail=torch.randint(0,ne_,(B,),device=dev), graph=graph) loss = F.margin_ranking_loss(pos, neg, torch.ones(B,device=dev), margin=1.0) opt.zero_grad(); loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step(); total += loss.item(); cnt += 1 if (epoch+1) % 5 == 0: m = evaluate_nbfnet(model, _vds, graph, dev) print(f" Epoch {epoch+1}: MRR={m['mrr']:.4f} MR={m['mr']:.1f} " f"H@1={m['hits@1']:.4f} H@10={m['hits@10']:.4f}") if m['mrr'] > best_mrr: best_mrr = m['mrr'] torch.save(model.state_dict(), f"{CONFIG['output_path']}{model_name}_best.pt") history.append(m) return model, history def evaluate_nbfnet(model, dataset, graph, device, max_groups=300): model.eval() groups = defaultdict(list) for i, s in enumerate(dataset.data): groups[(preprocessor.entity2id[s['head']], preprocessor.relation2id[s['relation']])].append(i) ranks = [] with torch.no_grad(): for (h, r) in tqdm(list(groups.keys())[:max_groups], desc="Eval", leave=False): idxs = groups[(h,r)] tails = [preprocessor.entity2id[dataset.data[i]['tail']] for i in idxs] B = len(tails) heads = torch.full((B,), h, dtype=torch.long, device=device) rels = torch.full((B,), r, dtype=torch.long, device=device) try: scores = model(heads, rels, graph=graph) for i, tgt in enumerate(tails): pos = (torch.argsort(scores[i],descending=True)==tgt ).nonzero(as_tuple=True)[0] if len(pos): ranks.append(pos[0].item()+1) except Exception: continue if not ranks: return {'mr':0.,'mrr':0.,'hits@1':0.,'hits@3':0.,'hits@10':0.} ranks = np.array(ranks) return {'mr':float(np.mean(ranks)), 'mrr':float(np.mean(1./ranks)), 'hits@1':float(np.mean(ranks<=1)), 'hits@3':float(np.mean(ranks<=3)), 
            'hits@10': float(np.mean(ranks <= 10))}

print("TrueNBFNet defined")

# ============================================================================
# 7 — Model 4: AlertStar (with ablation flags)
# ============================================================================
class AlertStarModel(nn.Module):
    """
    use_qual=False → AS-NoQual
    use_path=False → AS-NoPath
    fixed_gate=0.5 → AS-NoGate
    default        → AS-Full
    """
    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1,
                 use_qual=True, use_path=True, fixed_gate=None):
        super().__init__()
        self.num_entities = ne
        self.use_qual, self.use_path, self.fixed_gate = use_qual, use_path, fixed_gate
        self.ent = nn.Embedding(ne, dim)
        self.rel = nn.Embedding(nr, dim)
        self.qk = nn.Embedding(nqk, dim)
        self.qv = nn.Embedding(nqv, dim)
        self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(dim)
        self.path_net = nn.Sequential(
            nn.Linear(dim*2, dim), nn.LayerNorm(dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(dim, dim))
        self.ln2 = nn.LayerNorm(dim)
        self.gate = (nn.Parameter(torch.tensor(0.5))
                     if use_path and fixed_gate is None else None)
        self.drop = nn.Dropout(dropout)
        for e in [self.ent, self.rel, self.qk, self.qv]:
            nn.init.xavier_uniform_(e.weight)

    def _enrich(self, r_emb, quals, dev):
        if not self.use_qual or not quals:
            return r_emb
        k = self.qk(torch.tensor([q[0] for q in quals], device=dev))
        v = self.qv(torch.tensor([q[1] for q in quals], device=dev))
        kv = (k+v).unsqueeze(0)
        out, _ = self.attn(r_emb.view(1, 1, -1), kv, kv)
        return out.squeeze()

    def forward(self, head, relation, qualifiers, tail=None):
        dev = head.device
        h = self.ent(head)
        r = torch.stack([self._enrich(self.rel(relation[i:i+1]).squeeze(),
                                      qualifiers[i], dev)
                         for i in range(head.size(0))])
        stare = self.ln1(h + r)
        if not self.use_path:
            x = self.drop(stare)
        else:
            path = self.ln2(h + self.path_net(torch.cat([h, r], dim=-1)))
            g = (torch.sigmoid(self.gate) if self.gate is not None
                 else torch.tensor(self.fixed_gate, device=dev))
            x = self.drop(g*stare + (1-g)*path)
        if tail is not None:
            return (x * self.ent(tail)).sum(-1)
        return x @ self.ent.weight.t()

print("AlertStar defined")

# ============================================================================
# 8 — Model 5: StarQE (Complex Query Answering: 1p / 2p / 2i / 2u)
# ============================================================================
# StarQE extends link prediction to four first-order logic query types:
#   1p:  r(h,e)                    — direct 1-hop (e is the answer variable)
#   2p:  ∃e1: r1(h,e1) ∧ r2(e1,e)  — 2-hop chain
#   2i:  r1(h1,e) ∧ r2(h2,e)       — 2-anchor intersection
#   2u:  r1(h1,e) ∨ r2(h2,e)       — 2-anchor union
# All types are trained with margin ranking loss and evaluated on 1p queries
# (tail prediction) for a fair comparison with other models.
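# ----------------------------------------------------------------------------
# A minimal standalone sketch of the four compositions above on toy tensors.
# It is illustrative only, not part of the pipeline: `_rho` stands in for the
# learned composition MLP, and the 2i fusion below is a plain mean rather than
# the model's learned intersection operator; all `_`-prefixed names are
# assumptions made for this sketch.
import torch
import torch.nn as nn

_d = 8
_E = torch.randn(5, _d)                             # toy entity table (5 entities)
_R = torch.randn(3, _d)                             # toy relation table (3 relations)
_rho = nn.Sequential(nn.Linear(_d, _d), nn.ReLU())  # stand-in for ρ

_x_1p = _rho(_E[0] + _R[1])                         # 1p: ρ(E[h] + R[r])
_x_2p = _rho(_x_1p + _R[2])                         # 2p: chain the 1p hop through r2
_b1, _b2 = _rho(_E[0] + _R[1]), _rho(_E[3] + _R[2])
_x_2i = (_b1 + _b2) / 2                             # 2i: toy fusion of two anchors
_x_2u = (_b1 + _b2) / 2                             # 2u: union as mean of branches
print((_x_1p @ _E.t()).shape)                       # scores over all 5 toy entities
# ----------------------------------------------------------------------------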
# ============================================================================ class StarQEModel(nn.Module): def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1): super().__init__() self.num_entities = ne self.dim = dim self.ent = nn.Embedding(ne, dim) self.rel = nn.Embedding(nr, dim) # qualifier embeddings — StarQE uses them to enrich relations like StarE self.qk = nn.Embedding(nqk, dim) self.qv = nn.Embedding(nqv, dim) self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True) # composition MLP: projects x^{l-1} + R[r] into next entity embedding self.compose = nn.Sequential( nn.Linear(dim, dim), nn.LayerNorm(dim), nn.ReLU(), nn.Dropout(dropout)) # intersection operator W_∩: fuses two anchor embeddings self.intersect = nn.Linear(dim*2, dim) self.drop = nn.Dropout(dropout) for e in [self.ent, self.rel, self.qk, self.qv]: nn.init.xavier_uniform_(e.weight) def _enrich_rel(self, r_id, quals, dev): """Qualifier-enrich a relation embedding (same as StarE).""" r_emb = self.rel(torch.tensor([r_id], device=dev)).squeeze() if not quals: return r_emb k = self.qk(torch.tensor([q[0] for q in quals], device=dev)) v = self.qv(torch.tensor([q[1] for q in quals], device=dev)) kv = (k + v).unsqueeze(0) out, _ = self.attn(r_emb.view(1,1,-1), kv, kv) return out.squeeze() def _compose_1p(self, h_id, r_id, quals, dev): """1p: x = ρ(E[h] + R*[r])""" h_emb = self.ent(torch.tensor([h_id], device=dev)).squeeze() r_emb = self._enrich_rel(r_id, quals, dev) return self.compose(h_emb + r_emb) def _compose_2p(self, h_id, r1_id, r2_id, quals, dev): """2p: x1 = ρ(E[h]+R[r1]), x = ρ(x1+R[r2])""" h_emb = self.ent(torch.tensor([h_id], device=dev)).squeeze() r1_emb = self._enrich_rel(r1_id, quals, dev) r2_emb = self._enrich_rel(r2_id, [], dev) x1 = self.compose(h_emb + r1_emb) return self.compose(x1 + r2_emb) def _compose_2i(self, h1, r1, h2, r2, quals, dev): """2i: intersection of two 1p queries""" e1 = self.compose( self.ent(torch.tensor([h1],device=dev)).squeeze() + self._enrich_rel(r1, quals, dev)) e2 = self.compose( self.ent(torch.tensor([h2],device=dev)).squeeze() + self._enrich_rel(r2, [], dev)) return self.intersect(torch.cat([e1, e2], dim=-1)) def _compose_2u(self, h1, r1, h2, r2, quals, dev): """2u: union (mean) of two 1p queries""" e1 = self.compose( self.ent(torch.tensor([h1],device=dev)).squeeze() + self._enrich_rel(r1, quals, dev)) e2 = self.compose( self.ent(torch.tensor([h2],device=dev)).squeeze() + self._enrich_rel(r2, [], dev)) return (e1 + e2) / 2 def forward(self, head, relation, qualifiers, tail=None): """Standard 1p interface — compatible with evaluate_model.""" dev = head.device B = head.size(0) outs = [] for i in range(B): x = self._compose_1p(head[i].item(), relation[i].item(), qualifiers[i], dev) outs.append(x) x = self.drop(torch.stack(outs)) # [B, dim] if tail is not None: return (x * self.ent(tail)).sum(-1) return x @ self.ent.weight.t() def score_query(self, query_vec, tail_id, dev): """Score a composed query vector against a specific tail.""" t_emb = self.ent(torch.tensor([tail_id], device=dev)).squeeze() return (query_vec * t_emb).sum() def train_starqe(model, model_name="StarQE", train_data_=None, valid_ds_=None, ne_override=None): """ Trains on all four query types derived from 1-hop triples: 1p: direct link 2p: chain 2i: intersection 2u: union Falls back to 1p only if not enough distinct (h,r) pairs for chains. 
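    Each epoch iterates single triples: every triple always contributes a 1p
    loss, while 2p and 2i losses are added opportunistically whenever a
    compatible chain (t appearing as a head) or co-tail anchor exists.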
""" dev = CONFIG['device'] _td = train_data_ or train_data _vds = valid_ds_ or valid_ds _ne = ne_override or NE model.to(dev) opt = torch.optim.Adam(model.parameters(), lr=CONFIG['query_lr']) # Build query samples from triples def build_queries(data): # index (h,r) → list of tails hr2tails = defaultdict(list) triples = [] for s in data: h = preprocessor.entity2id[s['head']] r = preprocessor.relation2id[s['relation']] t = preprocessor.entity2id[s['tail']] qs = [(preprocessor.qualifier_key2id[qk], preprocessor.qualifier_value2id[qv]) for qk, qv in s['qualifiers']] hr2tails[(h,r)].append(t) triples.append((h, r, t, qs)) return triples, hr2tails triples, hr2tails = build_queries(_td) hr_keys = [k for k, v in hr2tails.items() if len(v) >= 1] print(f"\n{'='*60}\nTRAINING {model_name} " f"({sum(p.numel() for p in model.parameters()):,} params)") print(f" Triples: {len(triples):,} (h,r) pairs: {len(hr_keys):,}\n{'='*60}") best_mrr, history = 0.0, [] for epoch in range(CONFIG['query_epochs']): model.train(); total, cnt = 0.0, 0 np.random.shuffle(triples) for h, r, t, qs in tqdm(triples, desc=f"Epoch {epoch+1}", leave=False): losses = [] # ── 1p query ───────────────────────────────────────────────── neg_t = np.random.randint(0, _ne) h_t = torch.tensor([h], device=dev) r_t = torch.tensor([r], device=dev) pos_s = model(h_t, r_t, [qs], torch.tensor([t], device=dev)) neg_s = model(h_t, r_t, [qs], torch.tensor([neg_t],device=dev)) losses.append(F.margin_ranking_loss( pos_s, neg_s, torch.ones(1,device=dev), margin=1.0)) # ── 2p query (chain through t as intermediate) ──────────────── # find another relation r2 where t is a head r2_candidates = [rr for (hh,rr) in hr_keys if hh == t] if r2_candidates: r2 = np.random.choice(r2_candidates) tails2 = hr2tails[(t, r2)] if tails2: t2 = np.random.choice(tails2) neg2 = np.random.randint(0, _ne) try: q_pos = model._compose_2p(h, r, r2, qs, dev) q_neg = q_pos # same query vec, different tail t2_emb = model.ent(torch.tensor([t2], device=dev)).squeeze() n2_emb = model.ent(torch.tensor([neg2],device=dev)).squeeze() pos_2p = (model.drop(q_pos) * t2_emb).sum() neg_2p = (model.drop(q_neg) * n2_emb).sum() losses.append(F.margin_ranking_loss( pos_2p.unsqueeze(0), neg_2p.unsqueeze(0), torch.ones(1,device=dev), margin=1.0)) except Exception: pass # ── 2i query (intersect two 1p queries sharing tail t) ──────── r2_for_t = [rr for (hh,rr) in hr_keys if t in hr2tails.get((hh,rr),[])] r_for_t = [rr for (hh,rr) in hr_keys if hh != h and t in hr2tails.get((hh,rr),[])] if r_for_t: h2 = np.random.choice([hh for (hh,rr) in hr_keys if rr in r_for_t and t in hr2tails.get((hh,rr),[]) and hh != h] or [h]) r2 = np.random.choice(r_for_t) neg2 = np.random.randint(0, _ne) try: q2i = model._compose_2i(h, r, h2, r2, qs, dev) t_e = model.ent(torch.tensor([t], device=dev)).squeeze() n_e = model.ent(torch.tensor([neg2],device=dev)).squeeze() pos_2i = (model.drop(q2i)*t_e).sum() neg_2i = (model.drop(q2i)*n_e).sum() losses.append(F.margin_ranking_loss( pos_2i.unsqueeze(0), neg_2i.unsqueeze(0), torch.ones(1,device=dev), margin=1.0)) except Exception: pass if not losses: continue loss = sum(losses) / len(losses) opt.zero_grad(); loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step(); total += loss.item(); cnt += 1 if (epoch+1) % 5 == 0: m = evaluate_model(model, _vds, dev) print(f" Epoch {epoch+1}: loss={total/max(cnt,1):.4f} " f"MRR={m['mrr']:.4f} MR={m['mr']:.1f} " f"H@1={m['hits@1']:.4f} H@10={m['hits@10']:.4f}") if m['mrr'] > best_mrr: best_mrr = m['mrr'] 
torch.save(model.state_dict(), f"{CONFIG['output_path']}{model_name}_best.pt") print(" → best saved") history.append(m) return model, history print("StarQE defined") # ============================================================================ # 9 — Model 6: NBFNet+StarQE (residual path-augmented complex queries) # ============================================================================ # Replaces StarQE's linear composition ρ(x+R[r]) with a residual MLP: # x^l = x^0 + PathNet( Concat(x^{l-1}, R*[r_l]) ) # Intersection and union operators inherited unchanged from StarQE. # ============================================================================ class NBFNetStarQEModel(nn.Module): def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1): super().__init__() self.num_entities = ne self.dim = dim self.ent = nn.Embedding(ne, dim) self.rel = nn.Embedding(nr, dim) self.qk = nn.Embedding(nqk, dim) self.qv = nn.Embedding(nqv, dim) self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True) # residual PathNet: R^{2d} -> R^d -> R^d self.path_net = nn.Sequential( nn.Linear(dim*2, dim), nn.LayerNorm(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim)) # intersection operator self.intersect = nn.Linear(dim*2, dim) self.drop = nn.Dropout(dropout) for e in [self.ent, self.rel, self.qk, self.qv]: nn.init.xavier_uniform_(e.weight) def _enrich_rel(self, r_id, quals, dev): r_emb = self.rel(torch.tensor([r_id], device=dev)).squeeze() if not quals: return r_emb k = self.qk(torch.tensor([q[0] for q in quals], device=dev)) v = self.qv(torch.tensor([q[1] for q in quals], device=dev)) kv = (k+v).unsqueeze(0) out, _ = self.attn(r_emb.view(1,1,-1), kv, kv) return out.squeeze() def _compose_step(self, x, x0, r_emb): """x^l = x^0 + PathNet(Concat(x^{l-1}, R[r]))""" return x0 + self.path_net(torch.cat([x, r_emb], dim=-1)) def _compose_1p(self, h_id, r_id, quals, dev): h = self.ent(torch.tensor([h_id], device=dev)).squeeze() r_star = self._enrich_rel(r_id, quals, dev) return self._compose_step(h, h, r_star) def _compose_2p(self, h_id, r1_id, r2_id, quals, dev): h = self.ent(torch.tensor([h_id], device=dev)).squeeze() r1_emb = self._enrich_rel(r1_id, quals, dev) r2_emb = self._enrich_rel(r2_id, [], dev) x1 = self._compose_step(h, h, r1_emb) return self._compose_step(x1, h, r2_emb) def _compose_2i(self, h1, r1, h2, r2, quals, dev): e1 = self._compose_1p(h1, r1, quals, dev) e2 = self._compose_1p(h2, r2, [], dev) return self.intersect(torch.cat([e1, e2], dim=-1)) def _compose_2u(self, h1, r1, h2, r2, quals, dev): e1 = self._compose_1p(h1, r1, quals, dev) e2 = self._compose_1p(h2, r2, [], dev) return (e1 + e2) / 2 def forward(self, head, relation, qualifiers, tail=None): dev = head.device B = head.size(0) outs = [] for i in range(B): x = self._compose_1p(head[i].item(), relation[i].item(), qualifiers[i], dev) outs.append(x) x = self.drop(torch.stack(outs)) if tail is not None: return (x * self.ent(tail)).sum(-1) return x @ self.ent.weight.t() # NBFNet+StarQE reuses train_starqe with the same interface print("NBFNet+StarQE defined") # ============================================================================ # 10 — Model 7: HyNT # ============================================================================ class HyNTModel(nn.Module): def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1, n_heads=4, n_layers=2): super().__init__() self.num_entities = ne self.dim = dim self.ent = nn.Embedding(ne, dim) self.rel = nn.Embedding(nr, dim) self.qk = nn.Embedding(nqk, dim) self.qv = 
nn.Embedding(nqv, dim) self.qual_attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True) enc_layer = nn.TransformerEncoderLayer( d_model=dim, nhead=n_heads, dim_feedforward=dim*4, dropout=dropout, batch_first=True) self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers) def _head(out): return nn.Sequential(nn.Linear(dim,dim), nn.LayerNorm(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim,out)) self.tail_head = _head(ne) self.qv_head = _head(nqv) self.qk_gate = nn.Sequential(nn.Linear(dim*2, dim), nn.Sigmoid()) for e in [self.ent, self.rel, self.qk, self.qv]: nn.init.xavier_uniform_(e.weight) def _aggregate_qualifiers(self, qs, dev): if not qs: return torch.zeros(self.dim, device=dev) k_embs = self.qk(torch.tensor([q[0] for q in qs], device=dev)) v_embs = self.qv(torch.tensor([q[1] for q in qs], device=dev)) kv = (k_embs + v_embs).unsqueeze(0) out, _ = self.qual_attn(kv, kv, kv) return out.mean(dim=1).squeeze(0) def _encode(self, h, r, t, qs, dev, mask_tail=False): h_e = self.ent(torch.tensor([h], device=dev)) r_e = self.rel(torch.tensor([r], device=dev)) t_e = (self.ent(torch.tensor([t], device=dev)) if not mask_tail else torch.zeros(1, self.dim, device=dev)) q_c = self._aggregate_qualifiers(qs, dev).unsqueeze(0) seq = torch.cat([h_e, r_e, t_e, q_c], dim=0).unsqueeze(0) return self.encoder(seq)[0, 0] def forward_tail(self, s, dev): return self.tail_head(self._encode(s['h'],s['r'],0,s['qs'],dev,mask_tail=True)) def forward_qv(self, s, dev): filtered = [(k,v) for k,v in s['qs'] if k != s['qk']] ctx = self._encode(s['h'],s['r'],s['t'],filtered,dev) qk_emb = self.qk(torch.tensor([s['qk']],device=dev)).squeeze() gate = self.qk_gate(torch.cat([ctx, qk_emb], dim=-1)) return self.qv_head(gate * ctx) def forward(self, head, relation, qualifiers, tail=None): dev = head.device; B = head.size(0) outs = [self.forward_tail({'h':head[i].item(),'r':relation[i].item(), 't':0,'qs':qualifiers[i]}, dev) for i in range(B)] scores = torch.stack(outs, dim=0) if tail is not None: return scores[torch.arange(B), tail] return scores class HyNTDataset(Dataset): def __init__(self, data, preprocessor): self.samples = [] p = preprocessor for triple in data: h = p.entity2id[triple['head']] r = p.relation2id[triple['relation']] t = p.entity2id[triple['tail']] qs = [(p.qualifier_key2id[qk], p.qualifier_value2id[qv]) for qk, qv in triple['qualifiers']] self.samples.append({'task':'tail','h':h,'r':r,'t':t,'qs':qs}) for qk_id, qv_id in qs: self.samples.append({'task':'qv','h':h,'r':r,'t':t, 'qs':qs,'qk':qk_id,'qv':qv_id}) def __len__(self): return len(self.samples) def __getitem__(self, i): return self.samples[i] def train_hynt(model, model_name="HyNT", train_data_=None, valid_ds_=None): dev = CONFIG['device'] _td = train_data_ or train_data _vds = valid_ds_ or valid_ds model.to(dev) opt = torch.optim.Adam(model.parameters(), lr=CONFIG['hynt_lr']) loader = DataLoader(HyNTDataset(_td, preprocessor), batch_size=CONFIG['hynt_batch'], shuffle=True, collate_fn=mt_collate) print(f"\n{'='*60}\nTRAINING {model_name} " f"({sum(p.numel() for p in model.parameters()):,} params)\n{'='*60}") best_mrr, history = 0.0, [] for epoch in range(CONFIG['hynt_epochs']): model.train(); total, cnt = 0.0, 0 for by_task in tqdm(loader, desc=f"Epoch {epoch+1}", leave=False): bl = torch.tensor(0., device=dev) for task, samples in by_task.items(): for s in samples: try: if task == 'tail': logits = model.forward_tail(s, dev) tgt = torch.tensor(s['t'], device=dev) bl = bl + F.cross_entropy(logits.unsqueeze(0), 
tgt.unsqueeze(0)) elif task == 'qv': logits = model.forward_qv(s, dev) tgt = torch.tensor(s['qv'], device=dev) bl = bl + 0.8*F.cross_entropy(logits.unsqueeze(0), tgt.unsqueeze(0)) except Exception: continue opt.zero_grad(); bl.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step(); total += bl.item(); cnt += 1 if (epoch+1) % 5 == 0: m = evaluate_model(model, _vds, dev) print(f" Epoch {epoch+1}: loss={total/max(cnt,1):.4f} " f"MRR={m['mrr']:.4f} H@1={m['hits@1']:.4f} H@10={m['hits@10']:.4f}") if m['mrr'] > best_mrr: best_mrr = m['mrr'] torch.save(model.state_dict(), f"{CONFIG['output_path']}{model_name}_best.pt") history.append(m) return model, history print("HyNT defined") # ============================================================================ # 11 — Model 8: MultiTask AlertStar # ============================================================================ class MultiTaskDataset(Dataset): def __init__(self, data, preprocessor, tasks=None, nqk_override=None): self.samples = [] self.nqk = nqk_override or NQK if tasks is None: tasks = ['tail','relation','qual_key','qual_value'] p = preprocessor for triple in data: h = p.entity2id[triple['head']] r = p.relation2id[triple['relation']] t = p.entity2id[triple['tail']] qs = [(p.qualifier_key2id[qk], p.qualifier_value2id[qv]) for qk, qv in triple['qualifiers']] if 'tail' in tasks: self.samples.append({'task':'tail','h':h,'r':r,'t':t,'qs':qs}) if 'relation' in tasks: self.samples.append({'task':'relation','h':h,'r':r,'t':t,'qs':qs}) if qs: if 'qual_key' in tasks: self.samples.append({'task':'qual_key','h':h,'r':r,'t':t, 'qs':qs,'keys':[qk for qk,_ in qs]}) if 'qual_value' in tasks: for qk, qv in qs: self.samples.append({'task':'qual_value','h':h,'r':r,'t':t, 'qs':qs,'qk':qk,'qv':qv}) def __len__(self): return len(self.samples) def __getitem__(self, i): return self.samples[i] class MultiTaskAlertStar(nn.Module): def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1, n_heads=4, n_layers=3): super().__init__() self.num_entities = ne self.ent = nn.Embedding(ne, dim) self.rel = nn.Embedding(nr, dim) self.qk = nn.Embedding(nqk, dim) self.qv = nn.Embedding(nqv, dim) enc_layer = nn.TransformerEncoderLayer( d_model=dim, nhead=n_heads, dim_feedforward=dim*4, dropout=dropout, batch_first=True) self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers) def head(out): return nn.Sequential(nn.Linear(dim,dim), nn.LayerNorm(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim,out)) self.tail_head = head(ne); self.rel_head = head(nr) self.qk_head = head(nqk); self.qv_head = head(nqv) for e in [self.ent,self.rel,self.qk,self.qv]: nn.init.xavier_uniform_(e.weight) def _encode(self, h, r, t, qs, dev, mask_pos=None): tokens = [self.ent(torch.tensor([h],device=dev)), self.rel(torch.tensor([r],device=dev)), self.ent(torch.tensor([t],device=dev))] for qk_id, qv_id in qs: tokens += [self.qk(torch.tensor([qk_id],device=dev)), self.qv(torch.tensor([qv_id],device=dev))] seq = torch.cat(tokens, dim=0).unsqueeze(0) if mask_pos is not None and mask_pos < seq.size(1): seq = seq.clone(); seq[0, mask_pos] = 0.0 return self.encoder(seq)[0, 0] def forward_task(self, task, s, dev): h, r, t, qs = s['h'], s['r'], s['t'], s['qs'] if task == 'tail': return self.tail_head(self._encode(h,r,0,qs,dev,mask_pos=2)) elif task == 'relation': return self.rel_head(self._encode(h,r,t,qs,dev,mask_pos=1)) elif task == 'qual_key': return self.qk_head(self._encode(h,r,t,[],dev)) elif task == 'qual_value': filtered = [(k,v) for k,v in qs if k != s['qk']] return 
            self.qv_head(self._encode(h, r, t, filtered, dev))

    def forward(self, head, relation, qualifiers, tail=None):
        dev = head.device; B = head.size(0)
        outs = [self.forward_task('tail',
                                  {'h': head[i].item(), 'r': relation[i].item(),
                                   't': 0, 'qs': qualifiers[i]}, dev)
                for i in range(B)]
        scores = torch.stack(outs, dim=0)
        if tail is not None:
            return scores[torch.arange(B), tail]
        return scores

def train_multitask(model, model_name="MultiTask_AlertStar", active_tasks=None,
                    train_data_=None, valid_ds_=None, nqk_override=None):
    if active_tasks is None:
        active_tasks = ['tail', 'relation', 'qual_key', 'qual_value']
    dev = CONFIG['device']
    _td = train_data_ or train_data
    _vds = valid_ds_ or valid_ds
    _nqk = nqk_override or NQK
    model.to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=CONFIG['mt_lr'])
    loader = DataLoader(MultiTaskDataset(_td, preprocessor, tasks=active_tasks,
                                         nqk_override=_nqk),
                        batch_size=CONFIG['mt_batch'], shuffle=True,
                        collate_fn=mt_collate)
    weights = {'tail': 1.0, 'relation': 1.0, 'qual_key': 0.5, 'qual_value': 0.8}
    print(f"\n{'='*60}\nTRAINING {model_name} tasks={active_tasks}\n{'='*60}")
    best_mrr, history = 0.0, []
    for epoch in range(CONFIG['mt_epochs']):
        model.train(); total, cnt = 0.0, 0
        for by_task in tqdm(loader, desc=f"Epoch {epoch+1}", leave=False):
            bl = torch.tensor(0., device=dev)
            for task, samples in by_task.items():
                for s in samples:
                    try:
                        logits = model.forward_task(task, s, dev)
                        w = weights.get(task, 1.0)
                        if task == 'tail':
                            tgt = torch.tensor(s['t'], device=dev)
                            loss = F.cross_entropy(logits.unsqueeze(0),
                                                   tgt.unsqueeze(0))
                        elif task == 'relation':
                            tgt = torch.tensor(s['r'], device=dev)
                            loss = F.cross_entropy(logits.unsqueeze(0),
                                                   tgt.unsqueeze(0))
                        elif task == 'qual_key':
                            tgt = torch.zeros(_nqk, device=dev)
                            tgt[s['keys']] = 1.0
                            loss = F.binary_cross_entropy_with_logits(logits, tgt)
                        elif task == 'qual_value':
                            tgt = torch.tensor(s['qv'], device=dev)
                            loss = F.cross_entropy(logits.unsqueeze(0),
                                                   tgt.unsqueeze(0))
                        else:
                            continue
                        bl = bl + w*loss
                    except Exception:
                        continue
            opt.zero_grad(); bl.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step(); total += bl.item(); cnt += 1
        if (epoch+1) % 5 == 0:
            m = evaluate_model(model, _vds, dev)
            print(f"  Epoch {epoch+1}: loss={total/max(cnt,1):.4f} "
                  f"MRR={m['mrr']:.4f} H@1={m['hits@1']:.4f} H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
            history.append(m)
    return model, history

print("MultiTask AlertStar defined")

# ============================================================================
# 12 — Model 9: HR-NBFNet (Hyper-Relational Bellman-Ford)
#
# Matches slide formulation exactly:
#   h(0)_uvqq' <- INDICATOR(u, v, q', q)
#   phi_q(h_qk, h_qv) = h_qk · h_qv           [DistMult per qualifier pair]
#   h_q  = W_q · SUM phi_q                    [projected qualifier sum]
#   w_q  = sigma(qual_gate(r))                [per-relation scalar gate]
#   MSG  = src_feat * (rel_emb + w_q * h_q)   [qualifier-gated message]
#   h(t) = AGG(msgs) + h(0)                   [shortcut to h(0) not h(t-1)]
# ============================================================================
def build_hr_nbfnet_graph(train_data, preprocessor, device, max_quals=8,
                          p_override=None):
    """
    Build qualifier-aware edge tensors. Unlike TrueNBFNet's plain [E,3]
    edge_list, stores per-edge qualifier pairs in a padded tensor
    [E*2, max_quals, 2] so each propagation layer can apply DistMult
    qualifier composition per edge.
    """
    p = p_override or preprocessor
    nr = len(p.relation2id)
    srcs, dsts, rels, qual_list, nquals = [], [], [], [], []
    for t in train_data:
        h = p.entity2id[t['head']]
        r = p.relation2id[t['relation']]
        tl = p.entity2id[t['tail']]
        qs = [(p.qualifier_key2id[qk], p.qualifier_value2id[qv])
              for qk, qv in t['qualifiers']]
        # forward + inverse edges — both carry the same qualifiers
        srcs.append(h); dsts.append(tl); rels.append(r)
        qual_list.append(qs); nquals.append(len(qs))
        srcs.append(tl); dsts.append(h); rels.append(r+nr)
        qual_list.append(qs); nquals.append(len(qs))
    E = len(srcs)
    quals_tensor = torch.zeros(E, max_quals, 2, dtype=torch.long)
    for i, qs in enumerate(qual_list):
        for j, (qk, qv) in enumerate(qs[:max_quals]):
            quals_tensor[i, j, 0] = qk
            quals_tensor[i, j, 1] = qv
    return {
        'edge_src': torch.tensor(srcs, dtype=torch.long, device=device),
        'edge_dst': torch.tensor(dsts, dtype=torch.long, device=device),
        'edge_rel': torch.tensor(rels, dtype=torch.long, device=device),
        'edge_quals': quals_tensor.to(device),   # [E, max_quals, 2]
        'edge_nquals': torch.tensor(nquals, dtype=torch.long, device=device),
        'num_nodes': len(p.entity2id),
        'nr': nr,
        'max_quals': max_quals,
    }

class HRNBFConvLayer(nn.Module):
    def __init__(self, dim, num_relation, nqk, nqv, chunk_size=5000,
                 layer_norm=True, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.chunk_size = chunk_size
        self.rel_emb = nn.Embedding(num_relation, dim)
        self.qk_emb = nn.Embedding(nqk, dim)
        self.qv_emb = nn.Embedding(nqv, dim)
        self.W_q = nn.Linear(dim, dim, bias=False)
        self.qual_gate = nn.Embedding(num_relation, 1)
        nn.init.ones_(self.qual_gate.weight)   # start near-open: σ(1) ≈ 0.73
        self.linear = nn.Linear(dim*2, dim)
        self.ln = nn.LayerNorm(dim) if layer_norm else None
        self.act = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        for e in [self.rel_emb, self.qk_emb, self.qv_emb]:
            nn.init.xavier_uniform_(e.weight)
        nn.init.xavier_uniform_(self.W_q.weight)

    def _qualifier_embedding(self, edge_quals, edge_nquals, edge_rels):
        """h_q = w_q(r) * W_q · SUM (h_qk * h_qv)"""
        E, max_q, _ = edge_quals.shape
        h_qk = self.qk_emb(edge_quals[:, :, 0])   # [E, max_q, dim]
        h_qv = self.qv_emb(edge_quals[:, :, 1])   # [E, max_q, dim]
        phi_q = h_qk * h_qv                       # DistMult [E, max_q, dim]
        idx = torch.arange(max_q, device=edge_quals.device).unsqueeze(0)
        mask = (idx < edge_nquals.unsqueeze(1)).float().unsqueeze(2)
        phi_q = phi_q * mask                      # zero padding
        h_q = self.W_q(phi_q.sum(dim=1))          # [E, dim]
        w_q = torch.sigmoid(self.qual_gate(edge_rels))   # [E, 1]
        return w_q * h_q

    def forward(self, graph, node_feat, h0):
        src = graph['edge_src']; dst = graph['edge_dst']
        rel = graph['edge_rel']; quals = graph['edge_quals']
        nquals = graph['edge_nquals']
        N = graph['num_nodes']; dev = node_feat.device
        agg = torch.zeros(N, self.dim, device=dev)
        for start in range(0, src.size(0), self.chunk_size):
            end = start + self.chunk_size
            s_ = src[start:end]; d_ = dst[start:end]
            r_ = rel[start:end]; qs_ = quals[start:end]; nqs = nquals[start:end]
            r_emb = self.rel_emb(r_)
            h_q = self._qualifier_embedding(qs_, nqs, r_)
            src_feat = node_feat[s_]
            msg = src_feat * (r_emb + h_q)   # qualifier-gated message
            agg.scatter_add_(0, d_.unsqueeze(1).expand_as(msg), msg)
            del r_emb, h_q, src_feat, msg
        out = self.linear(torch.cat([node_feat, agg], dim=-1))
        if self.ln:
            out = self.ln(out)
        out = self.act(out); out = self.drop(out)
        return out + h0   # shortcut to h(0)

class HRNBFNet(nn.Module):
    def __init__(self, ne, nr, nqk, nqv, dim=200, num_layers=3,
                 chunk_size=5000, dropout=0.1, max_quals=8):
        super().__init__()
        self.num_entities = ne
        self.dim = dim
        self.max_quals = max_quals
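        # Doubled relation vocabulary: ids [0, nr) are forward edges and
        # ids [nr, 2*nr) their inverses, matching build_hr_nbfnet_graph.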
nr2 = nr * 2 self.query_emb = nn.Embedding(nr2, dim) self.query_qk_emb = nn.Embedding(nqk, dim) self.query_qv_emb = nn.Embedding(nqv, dim) self.query_qual_proj = nn.Linear(dim, dim, bias=False) self.layers = nn.ModuleList([ HRNBFConvLayer(dim, nr2, nqk, nqv, chunk_size=chunk_size, dropout=dropout) for _ in range(num_layers)]) self.mlp = nn.Sequential( nn.Linear(dim*2, dim), nn.LayerNorm(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, 1)) for e in [self.query_emb, self.query_qk_emb, self.query_qv_emb]: nn.init.xavier_uniform_(e.weight) nn.init.xavier_uniform_(self.query_qual_proj.weight) def _indicator_init(self, graph, h_idx, r_idx, query_quals, device): """INDICATOR(u,v,q',q): source node gets rel + qualifier context.""" N = graph['num_nodes'] q_rel = self.query_emb(torch.tensor([r_idx], device=device)) if query_quals: qk_ids = torch.tensor([qk for qk,_ in query_quals], device=device) qv_ids = torch.tensor([qv for _,qv in query_quals], device=device) phi = (self.query_qk_emb(qk_ids) * self.query_qv_emb(qv_ids)).sum(0, keepdim=True) q_qual = self.query_qual_proj(phi) else: q_qual = torch.zeros(1, self.dim, device=device) feat = torch.zeros(N, self.dim, device=device) feat[h_idx] = q_rel.squeeze() + q_qual.squeeze() return feat def _propagate(self, graph, h_idx, r_idx, query_quals, device): feat = self._indicator_init(graph, h_idx, r_idx, query_quals, device) h0 = feat.clone() for layer in self.layers: feat = layer(graph, feat, h0) return feat def forward(self, head, relation, qualifiers=None, tail=None, graph=None): assert graph is not None assert (head==head[0]).all() and (relation==relation[0]).all() dev = head.device query_quals = qualifiers[0] if qualifiers else [] feat = self._propagate(graph, head[0].item(), relation[0].item(), query_quals, dev) q_emb = self.query_emb(torch.tensor([relation[0].item()], device=dev)) score_in = torch.cat([feat, q_emb.expand(graph['num_nodes'],-1)], dim=-1) all_scores = self.mlp(score_in).squeeze(-1) if tail is not None: return all_scores[tail] return all_scores.unsqueeze(0).expand(head.size(0), -1) def train_hr_nbfnet(model, model_name="HR_NBFNet", train_data_=None, valid_ds_=None, graph_=None, p_override=None): dev = CONFIG['device'] _td = train_data_ or train_data _vds = valid_ds_ or valid_ds _g = graph_ or hr_nbfnet_graph _p = p_override or preprocessor model.to(dev) graph = {k: v.to(dev) if torch.is_tensor(v) else v for k,v in _g.items()} ne_ = graph['num_nodes'] opt = torch.optim.Adam(model.parameters(), lr=CONFIG['hr_nbfnet_lr']) mpg = CONFIG['nbfnet_max_per_group'] groups = defaultdict(list) for i, s in enumerate(_td): groups[(_p.entity2id[s['head']], _p.relation2id[s['relation']])].append(i) keys = list(groups.keys()) print(f"\n{'='*60}\nTRAINING {model_name} " f"({sum(p.numel() for p in model.parameters()):,} params)") print(f" Graph: {graph['edge_src'].size(0):,} qualifier-aware edges\n{'='*60}") best_mrr, history = 0.0, [] for epoch in range(CONFIG['hr_nbfnet_epochs']): model.train(); np.random.shuffle(keys) total, cnt = 0.0, 0 for (h, r) in tqdm(keys, desc=f"Epoch {epoch+1}", leave=False): chosen = np.random.choice(groups[(h,r)], min(len(groups[(h,r)]),mpg), replace=False) t_pos = torch.tensor( [_p.entity2id[_td[i]['tail']] for i in chosen], device=dev) all_quals = [[(_p.qualifier_key2id[qk], _p.qualifier_value2id[qv]) for qk, qv in _td[i]['qualifiers']] for i in chosen] rep_quals = max(all_quals, key=len) if all_quals else [] B = len(t_pos) heads = torch.full((B,), h, dtype=torch.long, device=dev) rels = torch.full((B,), r, 
dtype=torch.long, device=dev) try: pos = model(heads, rels, qualifiers=[rep_quals]*B, tail=t_pos, graph=graph) neg = model(heads, rels, qualifiers=[rep_quals]*B, tail=torch.randint(0,ne_,(B,),device=dev), graph=graph) loss = F.margin_ranking_loss(pos, neg, torch.ones(B,device=dev), margin=1.0) opt.zero_grad(); loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step(); total += loss.item(); cnt += 1 except Exception: continue if (epoch+1) % 5 == 0: m = evaluate_hr_nbfnet(model, _vds, graph, dev, p_override=_p) print(f" Epoch {epoch+1}: loss={total/max(cnt,1):.4f} " f"MRR={m['mrr']:.4f} MR={m['mr']:.1f} " f"H@1={m['hits@1']:.4f} H@3={m['hits@3']:.4f} " f"H@10={m['hits@10']:.4f}") if m['mrr'] > best_mrr: best_mrr = m['mrr'] torch.save(model.state_dict(), f"{CONFIG['output_path']}{model_name}_best.pt") print(" → best saved") history.append(m) return model, history def evaluate_hr_nbfnet(model, dataset, graph, device, max_groups=300, p_override=None): _p = p_override or preprocessor model.eval() groups = defaultdict(list) for i, s in enumerate(dataset.data): groups[(_p.entity2id[s['head']], _p.relation2id[s['relation']])].append(i) ranks = [] with torch.no_grad(): for (h,r) in tqdm(list(groups.keys())[:max_groups], desc="Eval HR-NBFNet", leave=False): idxs = groups[(h,r)] tails = [_p.entity2id[dataset.data[i]['tail']] for i in idxs] B = len(tails) heads = torch.full((B,),h,dtype=torch.long,device=device) rels = torch.full((B,),r,dtype=torch.long,device=device) s0 = dataset.data[idxs[0]] quals = [(_p.qualifier_key2id[qk], _p.qualifier_value2id[qv]) for qk,qv in s0['qualifiers']] try: scores = model(heads,rels,qualifiers=[quals]*B,graph=graph) for i, tgt in enumerate(tails): pos = (torch.argsort(scores[i],descending=True)==tgt ).nonzero(as_tuple=True)[0] if len(pos): ranks.append(pos[0].item()+1) except Exception: continue if not ranks: return {'mr':0.,'mrr':0.,'hits@1':0.,'hits@3':0.,'hits@10':0.} ranks = np.array(ranks) return {'mr':float(np.mean(ranks)),'mrr':float(np.mean(1./ranks)), 'hits@1':float(np.mean(ranks<=1)),'hits@3':float(np.mean(ranks<=3)), 'hits@10':float(np.mean(ranks<=10))} print("HR-NBFNet defined") # ============================================================================ # 13 — Model 10: MultiTask_HR_NBFNet (NEW) # # Combines HR-NBFNet's qualifier-aware Bellman-Ford propagation with the # 4-task multi-task training strategy from MultiTask AlertStar. 
#
# Architecture:
#   - HR-NBFNet backbone: propagates query-conditioned features with
#     per-edge DistMult qualifier embeddings → node feature f^L[v] ∈ R^d
#   - 4 prediction heads sharing the backbone, each a 2-layer MLP:
#       tail:       MLP(cat(f^L[e'], q_emb))             → R^1    [margin ranking]
#       relation:   MLP(f^L[h])                          → R^|R|  [cross-entropy]
#       qual_key:   MLP(f^L[h])                          → R^|QK| [BCE multi-label]
#       qual_value: MLP(gate(f^L[h], qk_emb) ⊙ f^L[h])   → R^|QV| [cross-entropy]
#
# Key design choices vs MultiTask AlertStar (MT-AS):
#   - MT-AS: Transformer over flat token sequence (local, no graph)
#   - MT-HR: Bellman-Ford over HR graph (global, structure-aware)
#   - Relation/qual-key/qual-value heads use the head node's BF representation,
#     so auxiliary tasks receive path-enriched graph signals
#   - Qualifier value head gates on qk_emb (same as HyNT) for fine-grained
#     attribute discrimination
# ============================================================================
class MultiTaskHRNBFNet(nn.Module):
    def __init__(self, ne, nr, nqk, nqv, dim=200, num_layers=3,
                 chunk_size=5000, dropout=0.1, max_quals=8):
        super().__init__()
        self.num_entities = ne
        self.nr = nr
        self.nqk = nqk
        self.nqv = nqv
        self.dim = dim
        # ── Shared HR-NBFNet backbone ─────────────────────────────────────
        nr2 = nr * 2
        self.query_emb = nn.Embedding(nr2, dim)
        self.query_qk_emb = nn.Embedding(nqk, dim)
        self.query_qv_emb = nn.Embedding(nqv, dim)
        self.query_qual_proj = nn.Linear(dim, dim, bias=False)
        self.layers = nn.ModuleList([
            HRNBFConvLayer(dim, nr2, nqk, nqv, chunk_size=chunk_size,
                           dropout=dropout) for _ in range(num_layers)])
        # ── Prediction heads ──────────────────────────────────────────────
        def _mlp_head(in_dim, out_dim):
            return nn.Sequential(
                nn.Linear(in_dim, dim), nn.LayerNorm(dim), nn.ReLU(),
                nn.Dropout(dropout), nn.Linear(dim, out_dim))
        # tail: concat(f^L[e'], q_emb) → scalar
        self.tail_mlp = _mlp_head(dim*2, 1)
        # relation: f^L[h] → |R| (use head node's BF representation)
        self.rel_head = _mlp_head(dim, nr)
        # qual_key: f^L[h] → |QK| (multi-label)
        self.qk_head = _mlp_head(dim, nqk)
        # qual_value: gate(f^L[h], qk_emb) → |QV|
        self.qv_gate = nn.Sequential(nn.Linear(dim*2, dim), nn.Sigmoid())
        self.qv_head = _mlp_head(dim, nqv)
        # qualifier key embedding (for qv gating)
        self.qk_emb_head = nn.Embedding(nqk, dim)
        for e in [self.query_emb, self.query_qk_emb, self.query_qv_emb,
                  self.qk_emb_head]:
            nn.init.xavier_uniform_(e.weight)
        nn.init.xavier_uniform_(self.query_qual_proj.weight)

    # ── backbone: shared with HRNBFNet ────────────────────────────────────
    def _indicator_init(self, graph, h_idx, r_idx, query_quals, device):
        N = graph['num_nodes']
        q_rel = self.query_emb(torch.tensor([r_idx], device=device))
        if query_quals:
            qk_ids = torch.tensor([qk for qk, _ in query_quals], device=device)
            qv_ids = torch.tensor([qv for _, qv in query_quals], device=device)
            phi = (self.query_qk_emb(qk_ids) *
                   self.query_qv_emb(qv_ids)).sum(0, keepdim=True)
            q_qual = self.query_qual_proj(phi)
        else:
            q_qual = torch.zeros(1, self.dim, device=device)
        feat = torch.zeros(N, self.dim, device=device)
        feat[h_idx] = q_rel.squeeze() + q_qual.squeeze()
        return feat

    def _propagate(self, graph, h_idx, r_idx, query_quals, device):
        feat = self._indicator_init(graph, h_idx, r_idx, query_quals, device)
        h0 = feat.clone()
        for layer in self.layers:
            feat = layer(graph, feat, h0)
        return feat   # [N, dim]

    # ── task-specific forward passes ──────────────────────────────────────
    def forward_tail_task(self, graph, h_idx, r_idx, query_quals,
                          tail_ids, device):
        """Tail prediction: score MLP(cat(f^L[tail], q_emb))."""
        feat =
self._propagate(graph, h_idx, r_idx, query_quals, device) q_emb = self.query_emb(torch.tensor([r_idx], device=device)) # score all entities score_in = torch.cat([feat, q_emb.expand(graph['num_nodes'],-1)], dim=-1) all_scores = self.tail_mlp(score_in).squeeze(-1) # [N] if tail_ids is not None: return all_scores[tail_ids] return all_scores def forward_rel_task(self, graph, h_idx, r_idx, query_quals, device): """Relation prediction: classify from head node's BF representation.""" feat = self._propagate(graph, h_idx, r_idx, query_quals, device) return self.rel_head(feat[h_idx]) # [nr] def forward_qk_task(self, graph, h_idx, r_idx, device): """Qualifier key prediction (multi-label) from head BF representation.""" # run with empty qualifiers so we don't leak qk info feat = self._propagate(graph, h_idx, r_idx, [], device) return self.qk_head(feat[h_idx]) # [nqk] def forward_qv_task(self, graph, h_idx, r_idx, query_quals, target_qk, device): """Qualifier value prediction gated on the target qualifier key.""" feat = self._propagate(graph, h_idx, r_idx, query_quals, device) h_repr = feat[h_idx] qk_emb = self.qk_emb_head(torch.tensor([target_qk], device=device)).squeeze() gate = self.qv_gate(torch.cat([h_repr, qk_emb], dim=-1)) return self.qv_head(gate * h_repr) # [nqv] # ── standard forward for evaluate_model compatibility ───────────────── def forward(self, head, relation, qualifiers=None, tail=None, graph=None): assert graph is not None assert (head==head[0]).all() and (relation==relation[0]).all() dev = head.device query_quals = qualifiers[0] if qualifiers else [] all_scores = self.forward_tail_task( graph, head[0].item(), relation[0].item(), query_quals, None, dev) if tail is not None: return all_scores[tail] return all_scores.unsqueeze(0).expand(head.size(0), -1) def train_mt_hr_nbfnet(model, model_name="MultiTask_HR_NBFNet", train_data_=None, valid_ds_=None, graph_=None, p_override=None, nqk_override=None, nqv_override=None): """ Multi-task training for HR-NBFNet backbone. Each (h,r) group triggers one BF propagation pass; auxiliary tasks (relation, qual_key, qual_value) add supervision on the head node's final graph representation f^L[h]. 
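    Each auxiliary head runs its own Bellman-Ford propagation, so a single
    (h, r) group can trigger several full passes over the graph per step,
    making this trainer markedly slower than single-task HR-NBFNet.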
Task weights: tail=1.0, relation=0.8, qual_key=0.5, qual_value=0.8 """ dev = CONFIG['device'] _td = train_data_ or train_data _vds = valid_ds_ or valid_ds _g = graph_ or hr_nbfnet_graph _p = p_override or preprocessor _nqk = nqk_override or NQK _nqv = nqv_override or NQV model.to(dev) graph = {k: v.to(dev) if torch.is_tensor(v) else v for k,v in _g.items()} ne_ = graph['num_nodes'] opt = torch.optim.Adam(model.parameters(), lr=CONFIG['mt_hr_lr']) mpg = CONFIG['nbfnet_max_per_group'] weights = {'tail':1.0, 'relation':0.8, 'qual_key':0.5, 'qual_value':0.8} # Group training triples by (h, r) groups = defaultdict(list) for i, s in enumerate(_td): groups[(_p.entity2id[s['head']], _p.relation2id[s['relation']])].append(i) keys = list(groups.keys()) print(f"\n{'='*60}\nTRAINING {model_name} " f"({sum(p.numel() for p in model.parameters()):,} params)") print(f" Graph: {graph['edge_src'].size(0):,} qualifier-aware edges") print(f" Tasks: tail / relation / qual_key / qual_value\n{'='*60}") best_mrr, history = 0.0, [] for epoch in range(CONFIG['mt_hr_epochs']): model.train(); np.random.shuffle(keys) total, cnt = 0.0, 0 for (h, r) in tqdm(keys, desc=f"Epoch {epoch+1}", leave=False): chosen = np.random.choice(groups[(h,r)], min(len(groups[(h,r)]),mpg), replace=False) samples = [_td[i] for i in chosen] all_quals = [[(_p.qualifier_key2id[qk], _p.qualifier_value2id[qv]) for qk,qv in s['qualifiers']] for s in samples] rep_quals = max(all_quals, key=len) if all_quals else [] t_ids = torch.tensor( [_p.entity2id[s['tail']] for s in samples], device=dev) neg_ids = torch.randint(0, ne_, (len(samples),), device=dev) batch_loss = torch.tensor(0., device=dev) try: # ── Task 1: Tail prediction (margin ranking) ─────────────── pos_scores = model.forward_tail_task( graph, h, r, rep_quals, t_ids, dev) neg_scores = model.forward_tail_task( graph, h, r, rep_quals, neg_ids, dev) loss_tail = F.margin_ranking_loss( pos_scores, neg_scores, torch.ones(len(samples), device=dev), margin=1.0) batch_loss = batch_loss + weights['tail'] * loss_tail # ── Task 2: Relation prediction (cross-entropy) ──────────── rel_logits = model.forward_rel_task(graph, h, r, rep_quals, dev) loss_rel = F.cross_entropy( rel_logits.unsqueeze(0), torch.tensor([r], device=dev)) batch_loss = batch_loss + weights['relation'] * loss_rel # ── Task 3: Qualifier key prediction (BCE multi-label) ───── if rep_quals: qk_logits = model.forward_qk_task(graph, h, r, dev) qk_target = torch.zeros(_nqk, device=dev) qk_target[[qk for qk,_ in rep_quals[:_nqk]]] = 1.0 loss_qk = F.binary_cross_entropy_with_logits( qk_logits, qk_target) batch_loss = batch_loss + weights['qual_key'] * loss_qk # ── Task 4: Qualifier value prediction (cross-entropy) ───── if rep_quals: tgt_qk, tgt_qv = rep_quals[0] qv_logits = model.forward_qv_task( graph, h, r, rep_quals, tgt_qk, dev) loss_qv = F.cross_entropy( qv_logits.unsqueeze(0), torch.tensor([tgt_qv], device=dev)) batch_loss = batch_loss + weights['qual_value'] * loss_qv opt.zero_grad(); batch_loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step(); total += batch_loss.item(); cnt += 1 except Exception: continue if (epoch+1) % 5 == 0: m = evaluate_hr_nbfnet(model, _vds, graph, dev, p_override=_p) print(f" Epoch {epoch+1}: loss={total/max(cnt,1):.4f} " f"MRR={m['mrr']:.4f} MR={m['mr']:.1f} " f"H@1={m['hits@1']:.4f} H@3={m['hits@3']:.4f} " f"H@10={m['hits@10']:.4f}") if m['mrr'] > best_mrr: best_mrr = m['mrr'] torch.save(model.state_dict(), f"{CONFIG['output_path']}{model_name}_best.pt") print(" → best saved") 
# ============================================================================
# 14 — TRAIN ALL 10 MAIN MODELS
# ============================================================================
print("\n" + "="*70)
print("PHASE 1: TRAINING ALL 10 MAIN MODELS")
print("="*70)

# ── 1. StarE ─────────────────────────────────────────────────────────────
stare_model = StarEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
stare_model, _, _ = train_standard(stare_model, "StarE")
all_results['StarE'] = evaluate_model(stare_model, test_ds, CONFIG['device'])

# ── 2. ShrinkE ───────────────────────────────────────────────────────────
shrinke_model = ShrinkEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
shrinke_model, _, _ = train_standard(shrinke_model, "ShrinkE")
all_results['ShrinkE'] = evaluate_model(shrinke_model, test_ds,
                                        CONFIG['device'])

# ── 3. TrueNBFNet ────────────────────────────────────────────────────────
nbfnet_graph = build_nbfnet_graph(train_data, preprocessor, CONFIG['device'])
nbfnet_model = TrueNBFNet(NE, NR, dim=DIM,
                          num_layers=CONFIG['nbfnet_layers'],
                          chunk_size=CONFIG['nbfnet_chunk_size'],
                          dropout=CONFIG['dropout'], short_cut=True)
nbfnet_model, _ = train_true_nbfnet(nbfnet_model, "TrueNBFNet")
nbfnet_graph_dev = {k: v.to(CONFIG['device']) if torch.is_tensor(v) else v
                    for k, v in nbfnet_graph.items()}
all_results['TrueNBFNet'] = evaluate_nbfnet(
    nbfnet_model, test_ds, nbfnet_graph_dev, CONFIG['device'])

# ── 4. AlertStar ─────────────────────────────────────────────────────────
alertstar_model = AlertStarModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
alertstar_model, _, gate_history = train_standard(
    alertstar_model, "AlertStar", gate_track=True)
all_results['AlertStar'] = evaluate_model(alertstar_model, test_ds,
                                          CONFIG['device'])

# ── 5. StarQE ────────────────────────────────────────────────────────────
starqe_model = StarQEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
starqe_model, _ = train_starqe(starqe_model, "StarQE")
all_results['StarQE'] = evaluate_model(starqe_model, test_ds,
                                       CONFIG['device'])

# ── 6. NBFNet+StarQE ─────────────────────────────────────────────────────
nbfstarqe_model = NBFNetStarQEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
nbfstarqe_model, _ = train_starqe(nbfstarqe_model, "NBFNet_StarQE")
all_results['NBFNet_StarQE'] = evaluate_model(nbfstarqe_model, test_ds,
                                              CONFIG['device'])

# ── 7. HyNT ──────────────────────────────────────────────────────────────
hynt_model = HyNTModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'],
                       n_heads=CONFIG['hynt_n_heads'],
                       n_layers=CONFIG['hynt_n_layers'])
hynt_model, _ = train_hynt(hynt_model, "HyNT")
all_results['HyNT'] = evaluate_model(hynt_model, test_ds, CONFIG['device'])

# ── 8. MultiTask AlertStar ───────────────────────────────────────────────
mt_model = MultiTaskAlertStar(NE, NR, NQK, NQV, DIM,
                              dropout=CONFIG['dropout'])
mt_model, _ = train_multitask(mt_model, "MultiTask_AlertStar")
all_results['MultiTask_AlertStar'] = evaluate_model(mt_model, test_ds,
                                                    CONFIG['device'])
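# ----------------------------------------------------------------------------
# Optional sketch (an assumption, not part of the original flow): every trainer
# in this phase saves its best-validation weights as
# "<output_path><name>_best.pt", while the evaluations use final-epoch weights.
# A helper like this could restore the best checkpoint first; `maybe_load_best`
# is a hypothetical name introduced here and is not called anywhere below.
def maybe_load_best(model, name):
    ckpt = f"{CONFIG['output_path']}{name}_best.pt"
    if os.path.exists(ckpt):
        model.load_state_dict(torch.load(ckpt, map_location=CONFIG['device']))
        print(f"  restored best checkpoint for {name}")
    return model
# e.g. maybe_load_best(stare_model, "StarE") before evaluate_model(...)
# ----------------------------------------------------------------------------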
# ── 9. HR-NBFNet ─────────────────────────────────────────────────────────
print("\nBuilding HR-NBFNet qualifier-aware graph...")
hr_nbfnet_graph = build_hr_nbfnet_graph(
    train_data, preprocessor, CONFIG['device'],
    max_quals=CONFIG['hr_nbfnet_max_quals'])
print(f"HR graph: {hr_nbfnet_graph['edge_src'].size(0):,} edges")
hr_model = HRNBFNet(NE, NR, NQK, NQV, DIM,
                    num_layers=CONFIG['hr_nbfnet_layers'],
                    chunk_size=CONFIG['hr_nbfnet_chunk_size'],
                    dropout=CONFIG['dropout'],
                    max_quals=CONFIG['hr_nbfnet_max_quals'])
hr_model, _ = train_hr_nbfnet(hr_model, "HR_NBFNet")
hr_graph_dev = {k: v.to(CONFIG['device']) if torch.is_tensor(v) else v
                for k, v in hr_nbfnet_graph.items()}
all_results['HR_NBFNet'] = evaluate_hr_nbfnet(
    hr_model, test_ds, hr_graph_dev, CONFIG['device'])

# ── 10. MultiTask_HR_NBFNet (NEW) ────────────────────────────────────────
# Reuses hr_nbfnet_graph — same qualifier-aware edge structure
mt_hr_model = MultiTaskHRNBFNet(NE, NR, NQK, NQV, DIM,
                                num_layers=CONFIG['hr_nbfnet_layers'],
                                chunk_size=CONFIG['hr_nbfnet_chunk_size'],
                                dropout=CONFIG['dropout'],
                                max_quals=CONFIG['hr_nbfnet_max_quals'])
mt_hr_model, _ = train_mt_hr_nbfnet(mt_hr_model, "MultiTask_HR_NBFNet")
all_results['MultiTask_HR_NBFNet'] = evaluate_hr_nbfnet(
    mt_hr_model, test_ds, hr_graph_dev, CONFIG['device'])

# ── Summary ──────────────────────────────────────────────────────────────
print("\nAll 10 models trained")
print("\nMAIN RESULTS:")
for m, r in all_results.items():
    tag = " ← NEW" if m in ('HR_NBFNet', 'MultiTask_HR_NBFNet') else ""
    print(f"  {m:30s} MRR={r['mrr']:.4f} MR={r['mr']:7.1f} "
          f"H@1={r['hits@1']:.4f} H@3={r['hits@3']:.4f} "
          f"H@10={r['hits@10']:.4f}{tag}")

# ============================================================================
# 15 — ABLATION A1: AlertStar Component Ablation
# ============================================================================
print("\n" + "="*70)
print("ABLATION A1: AlertStar Component Ablation")
print("="*70)

ablation_configs = [
    ("AS-NoQual", dict(use_qual=False, use_path=True,  fixed_gate=None)),
    ("AS-NoPath", dict(use_qual=True,  use_path=False, fixed_gate=None)),
    ("AS-NoGate", dict(use_qual=True,  use_path=True,  fixed_gate=0.5)),
    ("AS-Full",   dict(use_qual=True,  use_path=True,  fixed_gate=None)),
]
ablation_A1 = {}
for name, cfg in ablation_configs:
    m = AlertStarModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'], **cfg)
    m, _, _ = train_standard(m, f"Ablation_{name}")
    ablation_A1[name] = evaluate_model(m, test_ds, CONFIG['device'])
    r = ablation_A1[name]
    print(f"  {name}: MRR={r['mrr']:.4f} H@1={r['hits@1']:.4f} "
          f"H@10={r['hits@10']:.4f}")
ablation_results['A1_AlertStar_Components'] = ablation_A1
print("Ablation A1 complete")

# ============================================================================
# 16 — ABLATION A2: Gate Value Trajectory
# ============================================================================
print("\n" + "="*70)
print("ABLATION A2: Gate Value Analysis")
print("="*70)
if gate_history:
    gate_df = pd.DataFrame(gate_history)
    print(gate_df.to_string(index=False))
    final_gate = gate_history[-1]['gate']
    print(f"\n  Final gate g = {final_gate:.4f}")
    if final_gate > 0.6:
        print("  → Attention stream dominates (qualifier-awareness)")
    elif final_gate < 0.4:
        print("  → Path stream dominates (structural reasoning)")
    else:
        print("  → Balanced — both streams equally useful")
    ablation_results['A2_Gate_Values'] = gate_history
else:
    print("  No gate history available")
print("Ablation A2 complete")
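# ----------------------------------------------------------------------------
# Hedged illustration of the gate read out above. Assuming AlertStar blends its
# qualifier-attention stream and its path-composition stream with one learned
# scalar, g = σ(θ), out = g·attn + (1−g)·path, as the A2 interpretation
# suggests, the toy module below shows the mechanism. The names (`theta`, the
# two-stream blend) are assumptions, not AlertStarModel's actual internals.
class ToyGate(nn.Module):
    def __init__(self):
        super().__init__()
        self.theta = nn.Parameter(torch.zeros(1))   # g starts at σ(0) = 0.5

    def forward(self, attn_stream, path_stream):
        g = torch.sigmoid(self.theta)               # the value A2 tracks
        return g * attn_stream + (1 - g) * path_stream

_tg_demo = ToyGate()
_blended = _tg_demo(torch.randn(2, DIM), torch.randn(2, DIM))
print(f"  toy gate g = {torch.sigmoid(_tg_demo.theta).item():.3f}, "
      f"blended shape = {tuple(_blended.shape)}")
# ----------------------------------------------------------------------------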
# ============================================================================
# 17 — ABLATION A3: MultiTask Auxiliary Task Ablation
# ============================================================================
print("\n" + "="*70)
print("ABLATION A3: MultiTask AlertStar — Task Contribution")
print("="*70)

mt_ablation_configs = [
    ("MT-TailOnly",     ['tail']),
    ("MT-Tail+Rel",     ['tail', 'relation']),
    ("MT-Tail+QualKey", ['tail', 'qual_key']),
    ("MT-Tail+QualVal", ['tail', 'qual_value']),
    ("MT-Full",         ['tail', 'relation', 'qual_key', 'qual_value']),
]
ablation_A3 = {}
for name, tasks in mt_ablation_configs:
    m = MultiTaskAlertStar(NE, NR, NQK, NQV, DIM, dropout=CONFIG['dropout'])
    m, _ = train_multitask(m, f"Ablation_{name}", active_tasks=tasks)
    ablation_A3[name] = evaluate_model(m, test_ds, CONFIG['device'])
    r = ablation_A3[name]
    print(f"  {name}: MRR={r['mrr']:.4f} H@1={r['hits@1']:.4f} "
          f"H@10={r['hits@10']:.4f}")
ablation_results['A3_MultiTask_Tasks'] = ablation_A3
print("Ablation A3 complete")

# ============================================================================
# 18 — ABLATION A4: Qualifier Density (Q33 / Q66 / Q100)
#
# Models compared: StarE, AlertStar, HyNT, MultiTask-AS,
#                  HR-NBFNet, MultiTask_HR_NBFNet
# For Q100 (the main dataset): reuse the main results.
# For Q33/Q66: retrain all 6 models from scratch.
# ============================================================================
print("\n" + "="*70)
print("ABLATION A4: Qualifier Density Sensitivity")
print("="*70)

DENSITY_PATHS = {
    'Q100': CONFIG['data_path'],
    'Q33':  CONFIG['q33_path'],
    'Q66':  CONFIG['q66_path'],
}
DENSITY_MODELS = ['StarE', 'AlertStar', 'HyNT',
                  'MultiTask_AlertStar', 'HR_NBFNet', 'MultiTask_HR_NBFNet']
ablation_A4 = {}

for density_label, dpath in DENSITY_PATHS.items():
    print(f"\n{'='*50} {density_label} {'='*50}")
    if dpath == CONFIG['data_path']:
        # Q100 is the main dataset — reuse the already-trained models
        ablation_A4[density_label] = {
            m: all_results[m] for m in DENSITY_MODELS if m in all_results}
        print(f"  Reusing trained models for {density_label}")
        continue

    # Load density-specific dataset
    p2 = DataPreprocessor()
    tr2, va2, te2 = p2.load(dpath)
    ne2 = len(p2.entity2id); nr2 = len(p2.relation2id)
    nqk2 = len(p2.qualifier_key2id); nqv2 = len(p2.qualifier_value2id)
    tr_ds2 = HRDataset(tr2, p2); va_ds2 = HRDataset(va2, p2)
    te_ds2 = HRDataset(te2, p2)
    density_res = {}

    # StarE
    m = StarEModel(ne2, nr2, nqk2, nqv2, DIM, CONFIG['dropout'])
    m, _, _ = train_standard(m, f"A4_{density_label}_StarE",
                             train_ds_=tr_ds2, valid_ds_=va_ds2,
                             ne_override=ne2)
    density_res['StarE'] = evaluate_model(m, te_ds2, CONFIG['device'])

    # AlertStar
    m = AlertStarModel(ne2, nr2, nqk2, nqv2, DIM, CONFIG['dropout'])
    m, _, _ = train_standard(m, f"A4_{density_label}_AlertStar",
                             train_ds_=tr_ds2, valid_ds_=va_ds2,
                             ne_override=ne2)
    density_res['AlertStar'] = evaluate_model(m, te_ds2, CONFIG['device'])

    # HyNT — build its own HyNTDataset with p2
    _ht2 = HyNTDataset(tr2, p2)
    m = HyNTModel(ne2, nr2, nqk2, nqv2, DIM, CONFIG['dropout'],
                  n_heads=CONFIG['hynt_n_heads'],
                  n_layers=CONFIG['hynt_n_layers'])
    m.to(CONFIG['device'])
    _opt2 = torch.optim.Adam(m.parameters(), lr=CONFIG['hynt_lr'])
    _loader2 = DataLoader(_ht2, batch_size=CONFIG['hynt_batch'],
                          shuffle=True, collate_fn=mt_collate)
    _best2 = 0.0
    for _ep in range(CONFIG['hynt_epochs']):
        m.train()
        for _bt in tqdm(_loader2, desc=f"A4 HyNT {density_label} ep{_ep+1}",
                        leave=False):
            _bl = torch.tensor(0., device=CONFIG['device'])
            for _task, _samps in _bt.items():
                for _s in _samps:
                    try:
                        if _task == 'tail':
                            _lg = m.forward_tail(_s, CONFIG['device'])
                            _tg = torch.tensor(_s['t'],
                                               device=CONFIG['device'])
                            _bl = _bl + F.cross_entropy(
                                _lg.unsqueeze(0), _tg.unsqueeze(0))
                        elif _task == 'qv':
                            _lg = m.forward_qv(_s, CONFIG['device'])
                            _tg = torch.tensor(_s['qv'],
                                               device=CONFIG['device'])
                            _bl = _bl + 0.8 * F.cross_entropy(
                                _lg.unsqueeze(0), _tg.unsqueeze(0))
                    except Exception:
                        continue
            _opt2.zero_grad(); _bl.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0); _opt2.step()
        if (_ep+1) % 5 == 0:
            _mv = evaluate_model(m, va_ds2, CONFIG['device'])
            if _mv['mrr'] > _best2:
                _best2 = _mv['mrr']
                torch.save(m.state_dict(),
                           f"{CONFIG['output_path']}A4_{density_label}_HyNT_best.pt")
    density_res['HyNT'] = evaluate_model(m, te_ds2, CONFIG['device'])

    # MultiTask-AS
    m = MultiTaskAlertStar(ne2, nr2, nqk2, nqv2, DIM,
                           dropout=CONFIG['dropout'])
    _mt2_ds = MultiTaskDataset(tr2, p2, nqk_override=nqk2)
    _mt2_ldr = DataLoader(_mt2_ds, batch_size=CONFIG['mt_batch'],
                          shuffle=True, collate_fn=mt_collate)
    _wts2 = {'tail': 1.0, 'relation': 1.0, 'qual_key': 0.5, 'qual_value': 0.8}
    m.to(CONFIG['device'])
    _mo2 = torch.optim.Adam(m.parameters(), lr=CONFIG['mt_lr'])
    _bm2 = 0.0
    for _ep in range(CONFIG['mt_epochs']):
        m.train()
        for _bt in tqdm(_mt2_ldr, desc=f"A4 MT {density_label} ep{_ep+1}",
                        leave=False):
            _bl = torch.tensor(0., device=CONFIG['device'])
            for _task, _samps in _bt.items():
                for _s in _samps:
                    try:
                        _lg = m.forward_task(_task, _s, CONFIG['device'])
                        _w = _wts2.get(_task, 1.0)
                        if _task == 'tail':
                            _tg = torch.tensor(_s['t'],
                                               device=CONFIG['device'])
                            _ls = F.cross_entropy(_lg.unsqueeze(0),
                                                  _tg.unsqueeze(0))
                        elif _task == 'relation':
                            _tg = torch.tensor(_s['r'],
                                               device=CONFIG['device'])
                            _ls = F.cross_entropy(_lg.unsqueeze(0),
                                                  _tg.unsqueeze(0))
                        elif _task == 'qual_key':
                            _tg = torch.zeros(nqk2, device=CONFIG['device'])
                            _tg[_s['keys']] = 1.0
                            _ls = F.binary_cross_entropy_with_logits(_lg, _tg)
                        elif _task == 'qual_value':
                            _tg = torch.tensor(_s['qv'],
                                               device=CONFIG['device'])
                            _ls = F.cross_entropy(_lg.unsqueeze(0),
                                                  _tg.unsqueeze(0))
                        else:
                            continue
                        _bl = _bl + _w * _ls
                    except Exception:
                        continue
            _mo2.zero_grad(); _bl.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0); _mo2.step()
        if (_ep+1) % 5 == 0:
            _mv = evaluate_model(m, va_ds2, CONFIG['device'])
            if _mv['mrr'] > _bm2:
                _bm2 = _mv['mrr']
                torch.save(m.state_dict(),
                           f"{CONFIG['output_path']}A4_{density_label}_MT_best.pt")
    density_res['MultiTask_AlertStar'] = evaluate_model(m, te_ds2,
                                                        CONFIG['device'])

    # HR-NBFNet for this density (p2 is already the positional
    # preprocessor argument, matching the call in section 14)
    _hr_g2 = build_hr_nbfnet_graph(tr2, p2, CONFIG['device'],
                                   max_quals=CONFIG['hr_nbfnet_max_quals'])
    m = HRNBFNet(ne2, nr2, nqk2, nqv2, DIM,
                 num_layers=CONFIG['hr_nbfnet_layers'],
                 chunk_size=CONFIG['hr_nbfnet_chunk_size'],
                 dropout=CONFIG['dropout'],
                 max_quals=CONFIG['hr_nbfnet_max_quals'])
    m, _ = train_hr_nbfnet(m, f"A4_{density_label}_HR_NBFNet",
                           train_data_=tr2, valid_ds_=va_ds2,
                           graph_=_hr_g2, p_override=p2)
    _hr_g2_dev = {k: v.to(CONFIG['device']) if torch.is_tensor(v) else v
                  for k, v in _hr_g2.items()}
    density_res['HR_NBFNet'] = evaluate_hr_nbfnet(
        m, te_ds2, _hr_g2_dev, CONFIG['device'], p_override=p2)

    # MultiTask_HR_NBFNet for this density
    m = MultiTaskHRNBFNet(ne2, nr2, nqk2, nqv2, DIM,
                          num_layers=CONFIG['hr_nbfnet_layers'],
                          chunk_size=CONFIG['hr_nbfnet_chunk_size'],
                          dropout=CONFIG['dropout'],
                          max_quals=CONFIG['hr_nbfnet_max_quals'])
    m, _ = train_mt_hr_nbfnet(m, f"A4_{density_label}_MultiTask_HR_NBFNet",
                              train_data_=tr2, valid_ds_=va_ds2,
                              graph_=_hr_g2, p_override=p2,
                              nqk_override=nqk2, nqv_override=nqv2)
    density_res['MultiTask_HR_NBFNet'] = evaluate_hr_nbfnet(
        m, te_ds2, _hr_g2_dev, CONFIG['device'], p_override=p2)

    ablation_A4[density_label] = density_res
    for mname, metrics in density_res.items():
        print(f"  {density_label} {mname:30s} MRR={metrics['mrr']:.4f} "
              f"H@1={metrics['hits@1']:.4f} H@10={metrics['hits@10']:.4f}")
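# ----------------------------------------------------------------------------
# Optional sanity summary (a sketch, not part of the original analysis): print
# each model's MRR change going from full qualifiers (Q100) down to Q33,
# assuming both density keys were populated above.
if {'Q33', 'Q100'} <= set(ablation_A4):
    print("\n  ΔMRR (Q100 → Q33), per model:")
    for _mn in DENSITY_MODELS:
        _hi = ablation_A4['Q100'].get(_mn, {}).get('mrr')
        _lo = ablation_A4['Q33'].get(_mn, {}).get('mrr')
        if _hi is not None and _lo is not None:
            print(f"    {_mn:30s} {_lo - _hi:+.4f}")
# ----------------------------------------------------------------------------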
f"H@1={metrics['hits@1']:.4f} H@10={metrics['hits@10']:.4f}") ablation_results['A4_Qualifier_Density'] = ablation_A4 print("\n Ablation A4 complete") # ============================================================================ # 19 — COMPLETE RESULTS TABLES # ============================================================================ print("\n" + "="*80) print("TABLE 1: COMPLETE 10-MODEL COMPARISON") print("="*80) df_main = pd.DataFrame({ 'Model': list(all_results.keys()), 'MR': [all_results[m]['mr'] for m in all_results], 'MRR': [all_results[m]['mrr'] for m in all_results], 'Hits@1': [all_results[m]['hits@1'] for m in all_results], 'Hits@3': [all_results[m]['hits@3'] for m in all_results], 'Hits@10': [all_results[m]['hits@10'] for m in all_results], }).sort_values('MRR', ascending=False).reset_index(drop=True) print(df_main.to_string(index=False)) print("\n" + "="*80) print("TABLE 2: TrueNBFNet vs HR-NBFNet vs MultiTask_HR_NBFNet") print("="*80) for mname in ['TrueNBFNet','HR_NBFNet','MultiTask_HR_NBFNet']: r = all_results.get(mname, {}) print(f" {mname:30s} MRR={r.get('mrr',0):.4f} MR={r.get('mr',0):7.1f} " f"H@1={r.get('hits@1',0):.4f} H@3={r.get('hits@3',0):.4f} " f"H@10={r.get('hits@10',0):.4f}") print("\n" + "="*80) print("TABLE 3: StarQE family comparison") print("="*80) for mname in ['StarQE','NBFNet_StarQE']: r = all_results.get(mname, {}) print(f" {mname:30s} MRR={r.get('mrr',0):.4f} MR={r.get('mr',0):7.1f} " f"H@1={r.get('hits@1',0):.4f} H@10={r.get('hits@10',0):.4f}") df_A1 = pd.DataFrame({ 'Variant': list(ablation_A1.keys()), 'Qual?': ['✗','✓','✓','✓'], 'Path?': ['✓','✗','✓','✓'], 'Gate?': ['learned','learned','fixed=0.5','learned'], 'MRR': [ablation_A1[v]['mrr'] for v in ablation_A1], 'Hits@1': [ablation_A1[v]['hits@1'] for v in ablation_A1], 'Hits@3': [ablation_A1[v]['hits@3'] for v in ablation_A1], 'Hits@10': [ablation_A1[v]['hits@10'] for v in ablation_A1], 'MR': [ablation_A1[v]['mr'] for v in ablation_A1], }) print("\n" + "="*80) print("TABLE 4: ABLATION A1 — AlertStar Components") print("="*80) print(df_A1.to_string(index=False)) print("\n" + "="*80) print("TABLE 5: ABLATION A2 — Gate Trajectory") print("="*80) if gate_history: print(pd.DataFrame(gate_history).to_string(index=False)) df_A3 = pd.DataFrame({ 'Variant': list(ablation_A3.keys()), 'Tasks': [str(t[1]) for t in mt_ablation_configs], 'MRR': [ablation_A3[v]['mrr'] for v in ablation_A3], 'Hits@1': [ablation_A3[v]['hits@1'] for v in ablation_A3], 'Hits@10': [ablation_A3[v]['hits@10'] for v in ablation_A3], 'MR': [ablation_A3[v]['mr'] for v in ablation_A3], }) print("\n" + "="*80) print("TABLE 6: ABLATION A3 — MT Task Contribution") print("="*80) print(df_A3.to_string(index=False)) print("\n" + "="*80) print("TABLE 7: ABLATION A4 — Qualifier Density") print("="*80) for density, res in ablation_A4.items(): print(f"\n {density}:") for mname, metrics in res.items(): print(f" {mname:30s} MRR={metrics['mrr']:.4f} " f"H@1={metrics['hits@1']:.4f} H@10={metrics['hits@10']:.4f}") # ============================================================================ # 20 — VISUALIZATIONS # ============================================================================ n_models = len(df_main) palette_main = sns.color_palette("tab10", n_models) met_list = ['mrr','hits@1','hits@3','hits@10'] lbl_list = ['MRR','H@1','H@3','H@10'] fig = plt.figure(figsize=(30, 24)) fig.suptitle("AlertStar — 10-Model Complete Results & Ablation Studies", fontsize=18, fontweight='bold', y=0.98) gs = fig.add_gridspec(3, 3, hspace=0.5, wspace=0.35) # ── P1: 
# ============================================================================
# 20 — VISUALIZATIONS
# ============================================================================
n_models = len(df_main)
palette_main = sns.color_palette("tab10", n_models)
met_list = ['mrr', 'hits@1', 'hits@3', 'hits@10']
lbl_list = ['MRR', 'H@1', 'H@3', 'H@10']

fig = plt.figure(figsize=(30, 24))
fig.suptitle("AlertStar — 10-Model Complete Results & Ablation Studies",
             fontsize=18, fontweight='bold', y=0.98)
gs = fig.add_gridspec(3, 3, hspace=0.5, wspace=0.35)

# ── P1: Main MRR ─────────────────────────────────────────────────────────
ax1 = fig.add_subplot(gs[0, 0])
bars = ax1.bar(df_main['Model'], df_main['MRR'], color=palette_main)
ax1.set_title("Main Results — MRR (all 10 models)", fontweight='bold')
ax1.set_ylabel("MRR")
ax1.tick_params(axis='x', rotation=55, labelsize=6)
for bar, v in zip(bars, df_main['MRR']):
    ax1.text(bar.get_x() + bar.get_width()/2, v + 0.002, f'{v:.3f}',
             ha='center', fontsize=5, fontweight='bold')
highlight = {'HR_NBFNet': 'goldenrod', 'MultiTask_HR_NBFNet': 'crimson',
             'HyNT': 'red', 'StarQE': 'steelblue', 'NBFNet_StarQE': 'navy'}
model_list = list(df_main['Model'])
for i, bar in enumerate(bars):
    mname = model_list[i]
    if mname in highlight:
        bar.set_edgecolor(highlight[mname]); bar.set_linewidth(2.5)

# ── P2: Hits@k ───────────────────────────────────────────────────────────
ax2 = fig.add_subplot(gs[0, 1])
x2 = np.arange(n_models); w = 0.25
ax2.bar(x2 - w, df_main['Hits@1'],  w, label='H@1',  color='steelblue')
ax2.bar(x2,     df_main['Hits@3'],  w, label='H@3',  color='orange')
ax2.bar(x2 + w, df_main['Hits@10'], w, label='H@10', color='green')
ax2.set_xticks(x2)
ax2.set_xticklabels(df_main['Model'], rotation=55, ha='right', fontsize=6)
ax2.set_title("Main Results — Hits@k", fontweight='bold')
ax2.legend(fontsize=8)

# ── P3: NBFNet family (TrueNBFNet / HR-NBFNet / MT-HR-NBFNet) ────────────
ax3 = fig.add_subplot(gs[0, 2])
nbf_mods = ['TrueNBFNet', 'HR_NBFNet', 'MultiTask_HR_NBFNet']
nbf_res = {m: all_results[m] for m in nbf_mods if m in all_results}
x3 = np.arange(len(lbl_list)); w3 = 0.25
pal3 = ['steelblue', 'goldenrod', 'crimson']
for i, (mname, res) in enumerate(nbf_res.items()):
    vals = [res.get(mk, 0) for mk in met_list]
    offset = (i - len(nbf_res)/2 + 0.5) * w3
    ax3.bar(x3 + offset, vals, w3, label=mname, color=pal3[i])
ax3.set_xticks(x3); ax3.set_xticklabels(lbl_list)
ax3.set_title("NBFNet Family Comparison", fontweight='bold')
ax3.legend(fontsize=8); ax3.set_ylabel("Score")

# ── P4: StarQE family ────────────────────────────────────────────────────
ax4 = fig.add_subplot(gs[1, 0])
qe_mods = ['StarE', 'StarQE', 'NBFNet_StarQE', 'AlertStar']
qe_res = {m: all_results[m] for m in qe_mods if m in all_results}
x4 = np.arange(len(lbl_list)); w4 = 0.2
pal4 = sns.color_palette("Set2", len(qe_res))
for i, (mname, res) in enumerate(qe_res.items()):
    vals = [res.get(mk, 0) for mk in met_list]
    offset = (i - len(qe_res)/2 + 0.5) * w4
    ax4.bar(x4 + offset, vals, w4, label=mname, color=pal4[i])
ax4.set_xticks(x4); ax4.set_xticklabels(lbl_list)
ax4.set_title("StarQE Family vs Baselines", fontweight='bold')
ax4.legend(fontsize=8)

# ── P5: Ablation A1 ──────────────────────────────────────────────────────
ax5 = fig.add_subplot(gs[1, 1])
x5 = np.arange(len(df_A1)); w5 = 0.25
ax5.bar(x5 - w5, df_A1['MRR'],     w5, label='MRR',  color='steelblue')
ax5.bar(x5,      df_A1['Hits@1'],  w5, label='H@1',  color='orange')
ax5.bar(x5 + w5, df_A1['Hits@10'], w5, label='H@10', color='green')
ax5.set_xticks(x5)
ax5.set_xticklabels(df_A1['Variant'], rotation=20, ha='right', fontsize=9)
ax5.set_title("A1: AlertStar Components", fontweight='bold')
ax5.legend(fontsize=8)
full_mrr = df_A1[df_A1['Variant'] == 'AS-Full']['MRR'].values[0]
for xi, (_, row) in zip(x5, df_A1.iterrows()):
    drop = full_mrr - row['MRR']
    if abs(drop) > 0.001:
        ax5.text(xi - w5, row['MRR'] + 0.003, f'Δ{drop:+.3f}',
                 ha='center', fontsize=7, color='red')
# ── P6: Gate Trajectory ──────────────────────────────────────────────────
ax6 = fig.add_subplot(gs[1, 2])
if gate_history:
    epochs = [g['epoch'] for g in gate_history]
    gates  = [g['gate'] for g in gate_history]
    ax6.plot(epochs, gates, 'o-', color='purple', lw=2, markersize=6)
    ax6.axhline(0.5, color='gray', ls='--', alpha=0.7,
                label='g=0.5 (balanced)')
    ax6.fill_between(epochs, gates, 0.5, alpha=0.12,
                     color='blue' if gates[-1] > 0.5 else 'orange')
    ax6.set_ylim(0, 1.05)
    ax6.set_xlabel("Epoch"); ax6.set_ylabel("g = σ(θ)")
    ax6.set_title("A2: AlertStar Gate Trajectory", fontweight='bold')
    ax6.legend(fontsize=8)
    ax6.annotate(f"Final: {gates[-1]:.3f}", xy=(epochs[-1], gates[-1]),
                 xytext=(epochs[-1]-3, gates[-1]+0.08),
                 arrowprops=dict(arrowstyle='->', color='purple'),
                 fontsize=9, color='purple')
else:
    ax6.text(0.5, 0.5, 'No gate data', ha='center', va='center',
             transform=ax6.transAxes)
    ax6.set_title("A2: Gate Trajectory", fontweight='bold')

# ── P7: Ablation A3 ──────────────────────────────────────────────────────
ax7 = fig.add_subplot(gs[2, 0])
x7 = np.arange(len(df_A3))
ax7.bar(x7 - w5, df_A3['MRR'],     w5, label='MRR',  color='steelblue')
ax7.bar(x7,      df_A3['Hits@1'],  w5, label='H@1',  color='orange')
ax7.bar(x7 + w5, df_A3['Hits@10'], w5, label='H@10', color='green')
ax7.set_xticks(x7)
ax7.set_xticklabels(df_A3['Variant'], rotation=25, ha='right', fontsize=8)
ax7.set_title("A3: MT Task Contribution", fontweight='bold')
ax7.legend(fontsize=8)

# ── P8: Ablation A4 — Density (MRR) ──────────────────────────────────────
ax8 = fig.add_subplot(gs[2, 1:])
densities = list(ablation_A4.keys())
n_dm = len(DENSITY_MODELS)
x8 = np.arange(len(densities))
w8 = 0.12
pal8 = sns.color_palette("tab10", n_dm)
for mi, mname in enumerate(DENSITY_MODELS):
    mrrs = [ablation_A4[d].get(mname, {}).get('mrr', 0) for d in densities]
    offset = (mi - n_dm/2 + 0.5) * w8
    bars8 = ax8.bar(x8 + offset, mrrs, w8, label=mname, color=pal8[mi])
    # bold border for the new methods
    if mname in ('HR_NBFNet', 'MultiTask_HR_NBFNet'):
        for b in bars8:
            b.set_edgecolor('black'); b.set_linewidth(1.5)
ax8.set_xticks(x8); ax8.set_xticklabels(densities)
ax8.set_xlabel("Qualifier Density"); ax8.set_ylabel("MRR")
ax8.set_title("A4: MRR vs Qualifier Density (all 6 models)",
              fontweight='bold')
ax8.legend(fontsize=7, loc='upper left', ncol=2)

plt.savefig(f"{CONFIG['output_path']}complete_10model_results.png",
            dpi=300, bbox_inches='tight')
plt.show()
print("Visualization saved")

# ============================================================================
# 21 — SAVE ALL RESULTS
# ============================================================================
with open(f"{CONFIG['output_path']}all_results_10models.json", 'w') as f:
    json.dump({'main_results': all_results,
               'ablation_results': ablation_results,
               'gate_history': gate_history},
              f, indent=2,
              default=float)  # guard: metrics may be numpy scalars

print(f"\nSaved to {CONFIG['output_path']}")
print("\n" + "="*70)
print("COMPLETE — 10 models + 4 ablations ready")
print("="*70)
print(f"\nAll {len(all_results)} models ranked by MRR:")
for m, r in sorted(all_results.items(), key=lambda x: -x[1]['mrr']):
    tag = (" ← NEW" if m in ('HR_NBFNet', 'MultiTask_HR_NBFNet')
           else " ← restored" if m in ('StarQE', 'NBFNet_StarQE') else "")
    print(f"  {m:30s} MRR={r['mrr']:.4f} H@10={r['hits@10']:.4f}{tag}")
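# ----------------------------------------------------------------------------
# Optional round-trip check (a sketch; the `_reloaded` name is new here):
# reload the JSON just written and confirm all ten model entries survived
# serialization intact.
with open(f"{CONFIG['output_path']}all_results_10models.json") as _f:
    _reloaded = json.load(_f)
assert set(_reloaded['main_results']) == set(all_results)
print(f"JSON round-trip OK — {len(_reloaded['main_results'])} models on disk")
# ----------------------------------------------------------------------------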