# ============================================================================
# ALERTSTAR: COMPLETE NOTEBOOK — 10 MODELS + ALL ABLATIONS
#
# Models:
# 1. StarE — qualifier-enriched relation attention
# 2. ShrinkE — shrinking transform qualifier fusion
# 3. TrueNBFNet — Neural Bellman-Ford (qualifier-unaware)
# 4. AlertStar — gated StarE + path composition (ours)
# 5. StarQE — complex query answering (1p/2p/2i/2u)
# 6. NBFNet+StarQE — residual path-augmented complex queries
# 7. HyNT — Transformer 2-task competitor
# 8. MultiTask-AS — Transformer 4-task (ours, main)
# 9. HR-NBFNet — Hyper-Relational Bellman-Ford
# qualifier-aware at every propagation step
# matches slide formulation exactly
# 10. MultiTask_HR_NBFNet — HR-NBFNet + 4-task multi-task training (NEW)
# combines graph propagation with auxiliary
# relation/qual-key/qual-value supervision
#
# Ablations:
# A1 — AlertStar component ablation (NoQual / NoPath / NoGate / Full)
# A2 — AlertStar gate value trajectory
# A3 — MultiTask-AS auxiliary task contribution
# A4 — Qualifier density sensitivity (Q33 / Q66 / Q100)
# HR-NBFNet + MultiTask_HR_NBFNet included in A4
# ============================================================================
# ============================================================================
# 1 — Imports & Config
# ============================================================================
import os, json, warnings, copy
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from collections import defaultdict
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_scatter import scatter_add, scatter_mean
# Resolve the compute device once; every model/trainer below reads CONFIG['device'].
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Central hyper-parameter registry shared by all 10 models and 4 ablations.
CONFIG = {
    # ── paths ─────────────────────────────────────────────────────────────
    'data_path': '/.../inductive_q100_statements',
    'output_path': '/.../results_q3366100_ablation_10ci100/',
    # ── shared hyper-params ───────────────────────────────────────────────
    'embedding_dim': 200,
    'dropout': 0.2,
    'device': device,
    # ── standard training (StarE / ShrinkE / AlertStar) ──────────────────
    'batch_size': 128,
    'learning_rate': 0.0005,
    'epochs': 20,
    # ── TrueNBFNet ────────────────────────────────────────────────────────
    'nbfnet_epochs': 20,
    'nbfnet_lr': 0.0005,
    'nbfnet_layers': 3,
    'nbfnet_chunk_size': 10000,
    'nbfnet_max_per_group': 8,
    # ── StarQE / NBFNet+StarQE ────────────────────────────────────────────
    'query_epochs': 20,
    'query_lr': 0.0005,
    # ── HyNT ──────────────────────────────────────────────────────────────
    'hynt_epochs': 20,
    'hynt_lr': 0.0005,
    'hynt_batch': 64,
    'hynt_n_heads': 4,
    'hynt_n_layers': 2,
    # ── MultiTask-AS ──────────────────────────────────────────────────────
    'mt_epochs': 20,
    'mt_lr': 0.0005,
    'mt_batch': 64,
    # ── HR-NBFNet ─────────────────────────────────────────────────────────
    'hr_nbfnet_epochs': 20,
    'hr_nbfnet_lr': 0.0005,
    'hr_nbfnet_layers': 3,
    'hr_nbfnet_chunk_size': 5000,
    'hr_nbfnet_max_quals': 8,
    # ── MultiTask_HR_NBFNet ───────────────────────────────────────────────
    'mt_hr_epochs': 20,
    'mt_hr_lr': 0.0005,
    'mt_hr_batch': 32, # smaller batch — each sample triggers a BF pass
    # ── qualifier density paths for Ablation A4 ───────────────────────────
    'q33_path': '/.../inductive_q33_statements',
    'q66_path': '/.../inductive_q66_statements',
}
os.makedirs(CONFIG['output_path'], exist_ok=True)
# Global accumulators filled in by the sections below.
all_results = {}        # model name → final test metrics
ablation_results = {}   # ablation id → metrics
gate_history = []       # per-epoch AlertStar gate values (Ablation A2)
print("Config ready")
# ============================================================================
# 2 — Data Loading
# ============================================================================
class DataPreprocessor:
    """Parses CSV-style statement files and builds integer-id vocabularies.

    Line format: ``head,relation,tail[,k1:v1|k2:v2|...]``. Ids are assigned
    in first-seen order, shared across all splits loaded by one instance.
    """
    def __init__(self):
        # token → contiguous integer id, assigned in encounter order
        self.entity2id = {}
        self.relation2id = {}
        self.qualifier_key2id = {}
        self.qualifier_value2id = {}

    def _parse(self, line):
        """Parse one raw line; return a statement dict or None if malformed."""
        fields = line.strip().split(',')
        if len(fields) < 3:
            return None
        quals = []
        if len(fields) > 3:
            # qualifiers live in the 4th field, '|'-separated 'key:value' pairs
            for chunk in fields[3].split('|'):
                chunk = chunk.strip()
                if ':' not in chunk:
                    continue
                qk, qv = chunk.split(':', 1)
                quals.append((qk.strip(), qv.strip()))
        return {'head': fields[0], 'relation': fields[1],
                'tail': fields[2], 'qualifiers': quals}

    def _register(self, triple):
        """Assign fresh ids to any tokens of *triple* not yet in the vocabs."""
        def _add(token, vocab):
            if token not in vocab:
                vocab[token] = len(vocab)
        _add(triple['head'], self.entity2id)
        _add(triple['tail'], self.entity2id)
        _add(triple['relation'], self.relation2id)
        for qk, qv in triple['qualifiers']:
            _add(qk, self.qualifier_key2id)
            _add(qv, self.qualifier_value2id)

    def load(self, data_path):
        """Load train/validation/test files; return (train, valid, test) lists."""
        splits = {}
        for name, fname in [('train', 'train.txt'),
                            ('valid', 'validation.txt'),
                            ('test', 'test.txt')]:
            rows = []
            with open(os.path.join(data_path, fname)) as fh:
                for raw in fh:
                    parsed = self._parse(raw)
                    if parsed:
                        self._register(parsed)
                        rows.append(parsed)
            splits[name] = rows
            print(f" {name}: {len(rows):,}")
        print(f" NE={len(self.entity2id)} NR={len(self.relation2id)} "
              f"NQK={len(self.qualifier_key2id)} NQV={len(self.qualifier_value2id)}")
        return splits['train'], splits['valid'], splits['test']
class HRDataset(Dataset):
    """Wraps parsed string statements as integer-id samples via a preprocessor."""
    def __init__(self, data, preprocessor):
        self.data = data
        self.p = preprocessor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        quals = [(self.p.qualifier_key2id[qk], self.p.qualifier_value2id[qv])
                 for qk, qv in sample['qualifiers']]
        return {'head': self.p.entity2id[sample['head']],
                'relation': self.p.relation2id[sample['relation']],
                'tail': self.p.entity2id[sample['tail']],
                'qualifiers': quals}
def collate(batch):
    """Stack head/relation/tail ids into LongTensors; qualifiers stay ragged."""
    def as_long(key):
        return torch.tensor([sample[key] for sample in batch], dtype=torch.long)
    return {'head': as_long('head'),
            'relation': as_long('relation'),
            'tail': as_long('tail'),
            'qualifiers': [sample['qualifiers'] for sample in batch]}
def mt_collate(batch):
    """Bucket a mixed multi-task batch into per-task lists keyed by sample['task']."""
    buckets = defaultdict(list)
    for sample in batch:
        buckets[sample['task']].append(sample)
    return buckets
# ── Load the Q100 dataset once and expose vocab sizes as module globals ──────
print("Loading data...")
preprocessor = DataPreprocessor()
train_data, valid_data, test_data = preprocessor.load(CONFIG['data_path'])
train_ds = HRDataset(train_data, preprocessor)
valid_ds = HRDataset(valid_data, preprocessor)
test_ds = HRDataset(test_data, preprocessor)
# Vocabulary sizes used by every model constructor below.
NE = len(preprocessor.entity2id)            # number of entities
NR = len(preprocessor.relation2id)          # number of relations
NQK = len(preprocessor.qualifier_key2id)    # number of qualifier keys
NQV = len(preprocessor.qualifier_value2id)  # number of qualifier values
DIM = CONFIG['embedding_dim']
print("Data ready")
# ============================================================================
# 3 — Shared Evaluation & Training Utilities
# ============================================================================
def evaluate_model(model, dataset, device, max_samples=500):
    """Rank-based tail-prediction eval on up to *max_samples* statements.

    Scores all entities per query and ranks the gold tail (raw, unfiltered
    ranks). Returns a dict with MR, MRR and Hits@{1,3,10}.
    """
    model.eval()
    n_eval = min(len(dataset), max_samples)
    collected = []
    with torch.no_grad():
        for idx in tqdm(range(n_eval), desc="Evaluating", leave=False):
            sample = dataset[idx]
            heads = torch.tensor([sample['head']], device=device)
            rels = torch.tensor([sample['relation']], device=device)
            quals = [sample['qualifiers']]
            scores = model(heads, rels, quals).squeeze()
            order = torch.argsort(scores, descending=True)
            hit = (order == sample['tail']).nonzero(as_tuple=True)[0]
            collected.append(hit.item() + 1)
    collected = np.array(collected)
    return {'mr': float(np.mean(collected)),
            'mrr': float(np.mean(1.0 / collected)),
            'hits@1': float(np.mean(collected <= 1)),
            'hits@3': float(np.mean(collected <= 3)),
            'hits@10': float(np.mean(collected <= 10))}
def train_standard(model, model_name, train_ds_=None, valid_ds_=None,
                   ne_override=None, lr=None, epochs=None, gate_track=False):
    """Generic margin-ranking training loop for triple-scoring models.

    Args:
        model: scorer with signature ``model(head, rel, quals, tail=None)``.
        model_name: used for logging and the best-checkpoint filename.
        train_ds_/valid_ds_: optional dataset overrides (default: module
            globals ``train_ds`` / ``valid_ds``).
        ne_override: entity count for uniform negative sampling (default NE).
        lr/epochs: optional overrides; falsy values fall back to CONFIG
            (note: an explicit 0 would also fall back).
        gate_track: if True, record sigmoid(model.gate) once per epoch
            (used by the AlertStar gate-trajectory ablation A2).

    Returns:
        (model, history, epoch_gates) — history holds validation metrics
        collected every 5th epoch; the best-MRR checkpoint is saved to disk
        but NOT reloaded here.
    """
    dev = CONFIG['device']
    lr = lr or CONFIG['learning_rate']
    eps = epochs or CONFIG['epochs']
    _ne = ne_override or NE
    _tds = train_ds_ or train_ds
    _vds = valid_ds_ or valid_ds
    model.to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loader = DataLoader(_tds, batch_size=CONFIG['batch_size'],
                        shuffle=True, collate_fn=collate)
    print(f"\n{'='*60}\nTRAINING {model_name} "
          f"({sum(p.numel() for p in model.parameters()):,} params)\n{'='*60}")
    best_mrr, history, epoch_gates = 0.0, [], []
    for epoch in range(eps):
        model.train(); total = 0.0
        for b in tqdm(loader, desc=f"Epoch {epoch+1}/{eps}", leave=False):
            h, r, t, q = (b['head'].to(dev), b['relation'].to(dev),
                          b['tail'].to(dev), b['qualifiers'])
            pos = model(h, r, q, t)
            # uniform random negatives; not filtered against known positives
            neg = model(h, r, q, torch.randint(0, _ne, (len(h),), device=dev))
            loss = F.margin_ranking_loss(pos, neg, torch.ones_like(pos), margin=1.0)
            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step(); total += loss.item()
        # track the learned scalar gate (AlertStar full model only)
        if gate_track and hasattr(model, 'gate') and model.gate is not None:
            epoch_gates.append({'epoch': epoch+1,
                                'gate': torch.sigmoid(model.gate).item()})
        # periodic validation; checkpoint selection is by validation MRR
        if (epoch+1) % 5 == 0:
            m = evaluate_model(model, _vds, dev)
            print(f" Epoch {epoch+1}: loss={total/len(loader):.4f} "
                  f"MRR={m['mrr']:.4f} MR={m['mr']:.1f} "
                  f"H@1={m['hits@1']:.4f} H@3={m['hits@3']:.4f} "
                  f"H@10={m['hits@10']:.4f}"
                  + (f" gate={epoch_gates[-1]['gate']:.4f}"
                     if gate_track and epoch_gates else ""))
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
                print(" → best saved")
            history.append(m)
    return model, history, epoch_gates
print("Utilities ready")
# ============================================================================
# 4 — Model 1: StarE
# ============================================================================
class StarEModel(nn.Module):
    """StarE baseline: relation embeddings enriched by qualifier attention.

    Scoring: dot product of (head + enriched relation) with tail embeddings.
    """
    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        # NOTE: submodule creation order is fixed — it determines RNG-based init
        self.ent = nn.Embedding(ne, dim)
        self.rel = nn.Embedding(nr, dim)
        self.qk = nn.Embedding(nqk, dim)
        self.qv = nn.Embedding(nqv, dim)
        self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True)
        self.drop = nn.Dropout(dropout)
        for emb in [self.ent, self.rel, self.qk, self.qv]:
            nn.init.xavier_uniform_(emb.weight)

    def _enrich(self, r_emb, quals, dev):
        """Attend the relation vector over summed (key+value) qualifier embeddings."""
        if not quals:
            return r_emb
        key_ids = torch.tensor([pair[0] for pair in quals], device=dev)
        val_ids = torch.tensor([pair[1] for pair in quals], device=dev)
        memory = (self.qk(key_ids) + self.qv(val_ids)).unsqueeze(0)
        enriched, _ = self.attn(r_emb.view(1, 1, -1), memory, memory)
        return enriched.squeeze()

    def forward(self, head, relation, qualifiers, tail=None):
        dev = head.device
        h_emb = self.ent(head)
        rel_rows = []
        for i in range(head.size(0)):
            base = self.rel(relation[i:i+1]).squeeze()
            rel_rows.append(self._enrich(base, qualifiers[i], dev))
        x = self.drop(h_emb + torch.stack(rel_rows))
        if tail is not None:
            return (x * self.ent(tail)).sum(-1)
        return x @ self.ent.weight.t()
print("StarE defined")
# ============================================================================
# 5 — Model 2: ShrinkE
# ============================================================================
class ShrinkEModel(nn.Module):
    """ShrinkE baseline: qualifier context fused into the relation by a
    'shrinking' MLP over [relation ; mean qualifier embedding]."""
    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        # submodule creation order fixed (RNG-dependent initialization)
        self.ent = nn.Embedding(ne, dim)
        self.rel = nn.Embedding(nr, dim)
        self.qk = nn.Embedding(nqk, dim)
        self.qv = nn.Embedding(nqv, dim)
        self.shrink = nn.Sequential(nn.Linear(dim*2, dim), nn.Tanh(), nn.Dropout(dropout))
        self.proj = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)
        for emb in [self.ent, self.rel, self.qk, self.qv]:
            nn.init.xavier_uniform_(emb.weight)

    def _shrink(self, r_emb, quals, dev):
        """Fuse the mean (key+value) qualifier context into the relation."""
        if not quals:
            return r_emb
        key_ids = torch.tensor([pair[0] for pair in quals], device=dev)
        val_ids = torch.tensor([pair[1] for pair in quals], device=dev)
        context = (self.qk(key_ids) + self.qv(val_ids)).mean(0, keepdim=True)
        fused = self.shrink(torch.cat([r_emb.unsqueeze(0), context], -1))
        return fused.squeeze()

    def forward(self, head, relation, qualifiers, tail=None):
        dev = head.device
        h_emb = self.proj(self.ent(head))
        rel_rows = [self._shrink(self.rel(relation[i:i+1]).squeeze(),
                                 qualifiers[i], dev)
                    for i in range(head.size(0))]
        x = self.drop(h_emb + self.proj(torch.stack(rel_rows)))
        if tail is not None:
            return (x * self.ent(tail)).sum(-1)
        return x @ self.ent.weight.t()
print("✓ ShrinkE defined")
# ============================================================================
# 6 — Model 3: TrueNBFNet (qualifier-unaware Bellman-Ford)
# ============================================================================
class NBFConvLayer(nn.Module):
    """One Bellman-Ford message-passing step with chunked scatter-add
    aggregation (chunking bounds peak memory on large edge lists)."""
    def __init__(self, dim, num_relation, chunk_size=10000, layer_norm=True):
        super().__init__()
        self.chunk_size = chunk_size
        self.rel_emb = nn.Embedding(num_relation, dim)
        self.linear = nn.Linear(dim*2, dim)
        self.ln = nn.LayerNorm(dim) if layer_norm else None
        self.act = nn.ReLU()
        nn.init.xavier_uniform_(self.rel_emb.weight)

    def forward(self, graph, node_feat):
        """node_feat: [num_nodes, batch, dim]; graph['edge_list']: [E, 3] (src, dst, rel)."""
        edges = graph['edge_list']
        num_nodes, batch, dim = node_feat.shape
        agg = torch.zeros(num_nodes, batch, dim, device=node_feat.device)
        for lo in range(0, edges.size(0), self.chunk_size):
            part = edges[lo:lo + self.chunk_size]
            src, dst, rel = part[:, 0], part[:, 1], part[:, 2]
            # relation-modulated messages from source nodes, summed into dst
            msg = node_feat[src] * self.rel_emb(rel).unsqueeze(1)
            agg.scatter_add_(0, dst.view(-1, 1, 1).expand_as(msg), msg)
            del msg
        combined = self.linear(torch.cat([node_feat, agg], dim=-1))
        if self.ln:
            shape = combined.shape
            combined = self.ln(combined.flatten(0, 1)).view(shape)
        return self.act(combined)
class TrueNBFNet(nn.Module):
    """Qualifier-unaware NBFNet: one Bellman-Ford propagation per
    (head, relation) query over a graph with explicit inverse relations."""
    def __init__(self, ne, nr, dim=200, num_layers=3,
                 short_cut=True, chunk_size=10000, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        self.dim = dim
        self.short_cut = short_cut
        doubled = nr * 2  # forward + inverse relation vocabulary
        self.query_emb = nn.Embedding(doubled, dim)
        self.layers = nn.ModuleList(
            [NBFConvLayer(dim, doubled, chunk_size=chunk_size)
             for _ in range(num_layers)])
        self.mlp = nn.Sequential(
            nn.Linear(dim*2, dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, 1))
        nn.init.xavier_uniform_(self.query_emb.weight)

    def _bellman_ford(self, graph, h, r, device):
        """Propagate the query embedding from source *h*; return [N, 2*dim]."""
        num_nodes = graph['num_nodes']
        query = self.query_emb(torch.tensor([r], device=device))
        feat = torch.zeros(num_nodes, 1, self.dim, device=device)
        feat[h, 0] = query[0]  # boundary condition: only the source is seeded
        for layer in self.layers:
            updated = layer(graph, feat)
            feat = updated + feat if self.short_cut else updated
        per_node_query = query.unsqueeze(0).expand(num_nodes, -1, -1)
        return torch.cat([feat, per_node_query], dim=-1).squeeze(1)

    def forward(self, head, relation, qualifiers=None, tail=None, graph=None):
        assert graph is not None
        # one BF pass serves the whole batch, so all heads/relations must agree
        assert (head == head[0]).all() and (relation == relation[0]).all()
        dev = head.device
        feat = self._bellman_ford(graph, head[0].item(), relation[0].item(), dev)
        if tail is not None:
            return self.mlp(feat[tail]).squeeze(-1)
        batch = head.size(0)
        return self.mlp(feat).squeeze(-1).unsqueeze(0).expand(batch, -1)
def build_nbfnet_graph(train_data, preprocessor, device):
    """Build the (src, dst, rel) edge list with explicit inverse edges.

    The inverse of relation r gets id r + NR, matching TrueNBFNet's doubled
    relation vocabulary.
    """
    num_rel = len(preprocessor.relation2id)
    edges = []
    for triple in train_data:
        src = preprocessor.entity2id[triple['head']]
        rel = preprocessor.relation2id[triple['relation']]
        dst = preprocessor.entity2id[triple['tail']]
        edges.append((src, dst, rel))
        edges.append((dst, src, rel + num_rel))
    return {'edge_list': torch.tensor(edges, dtype=torch.long, device=device),
            'num_nodes': len(preprocessor.entity2id)}
def train_true_nbfnet(model, model_name="TrueNBFNet",
                      train_data_=None, valid_ds_=None, graph_=None):
    """Train a Bellman-Ford model with (head, relation)-grouped batching.

    Triples sharing (head, relation) reuse a single BF pass; at most
    CONFIG['nbfnet_max_per_group'] positive tails per group are scored
    against uniform random negatives with margin ranking loss.
    Defaults fall back to module globals `train_data`, `valid_ds` and
    `nbfnet_graph` (the latter is assumed to be built elsewhere in the file
    via build_nbfnet_graph — confirm before reuse).

    Returns:
        (model, history) — history holds validation metrics every 5th epoch;
        the best-MRR checkpoint is saved to disk but not reloaded.
    """
    dev = CONFIG['device']
    _td = train_data_ or train_data
    _vds = valid_ds_ or valid_ds
    _g = graph_ or nbfnet_graph
    model.to(dev)
    # move all graph tensors onto the training device
    graph = {k: v.to(dev) if torch.is_tensor(v) else v for k,v in _g.items()}
    opt = torch.optim.Adam(model.parameters(), lr=CONFIG['nbfnet_lr'])
    ne_ = graph['num_nodes']
    mpg = CONFIG['nbfnet_max_per_group']
    # bucket triples by (head, relation): one BF pass per bucket
    groups = defaultdict(list)
    for i, s in enumerate(_td):
        groups[(preprocessor.entity2id[s['head']],
                preprocessor.relation2id[s['relation']])].append(i)
    keys = list(groups.keys())
    print(f"\n{'='*60}\nTRAINING {model_name}\n{'='*60}")
    best_mrr, history = 0.0, []
    for epoch in range(CONFIG['nbfnet_epochs']):
        model.train(); np.random.shuffle(keys)
        total, cnt = 0.0, 0
        for (h, r) in tqdm(keys, desc=f"Epoch {epoch+1}", leave=False):
            # subsample positives so large groups don't dominate
            chosen = np.random.choice(groups[(h,r)],
                                      min(len(groups[(h,r)]),mpg), replace=False)
            t_pos = torch.tensor(
                [preprocessor.entity2id[_td[i]['tail']] for i in chosen], device=dev)
            B = len(t_pos)
            heads = torch.full((B,), h, dtype=torch.long, device=dev)
            rels = torch.full((B,), r, dtype=torch.long, device=dev)
            pos = model(heads, rels, tail=t_pos, graph=graph)
            neg = model(heads, rels, tail=torch.randint(0,ne_,(B,),device=dev), graph=graph)
            loss = F.margin_ranking_loss(pos, neg, torch.ones(B,device=dev), margin=1.0)
            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step(); total += loss.item(); cnt += 1
        if (epoch+1) % 5 == 0:
            m = evaluate_nbfnet(model, _vds, graph, dev)
            print(f" Epoch {epoch+1}: MRR={m['mrr']:.4f} MR={m['mr']:.1f} "
                  f"H@1={m['hits@1']:.4f} H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
            history.append(m)
    return model, history
def evaluate_nbfnet(model, dataset, graph, device, max_groups=300):
    """Grouped ranking eval for Bellman-Ford models.

    Queries sharing (head, relation) are scored with one model call; at most
    *max_groups* groups are evaluated. Relies on the module-level
    `preprocessor` for id lookups. Returns MR/MRR/Hits@{1,3,10} over raw ranks.
    """
    model.eval()
    # bucket evaluation triples by (head, relation)
    groups = defaultdict(list)
    for i, s in enumerate(dataset.data):
        groups[(preprocessor.entity2id[s['head']],
                preprocessor.relation2id[s['relation']])].append(i)
    ranks = []
    with torch.no_grad():
        for (h, r) in tqdm(list(groups.keys())[:max_groups], desc="Eval", leave=False):
            idxs = groups[(h,r)]
            tails = [preprocessor.entity2id[dataset.data[i]['tail']] for i in idxs]
            B = len(tails)
            heads = torch.full((B,), h, dtype=torch.long, device=device)
            rels = torch.full((B,), r, dtype=torch.long, device=device)
            try:
                scores = model(heads, rels, graph=graph)
                for i, tgt in enumerate(tails):
                    pos = (torch.argsort(scores[i],descending=True)==tgt
                           ).nonzero(as_tuple=True)[0]
                    if len(pos): ranks.append(pos[0].item()+1)
            # NOTE(review): deliberate best-effort — any scoring failure silently
            # skips the whole group; consider at least logging the exception
            except Exception: continue
    if not ranks:
        # nothing scored successfully — return zeroed metrics rather than NaN
        return {'mr':0.,'mrr':0.,'hits@1':0.,'hits@3':0.,'hits@10':0.}
    ranks = np.array(ranks)
    return {'mr':float(np.mean(ranks)), 'mrr':float(np.mean(1./ranks)),
            'hits@1':float(np.mean(ranks<=1)), 'hits@3':float(np.mean(ranks<=3)),
            'hits@10':float(np.mean(ranks<=10))}
print("TrueNBFNet defined")
# ============================================================================
# 7 — Model 4: AlertStar (with ablation flags)
# ============================================================================
class AlertStarModel(nn.Module):
    """AlertStar: gated blend of a qualifier-enriched StarE branch and a
    residual path-composition branch.

    Ablation switches:
        use_qual=False → AS-NoQual (qualifiers ignored)
        use_path=False → AS-NoPath (StarE branch only)
        fixed_gate=0.5 → AS-NoGate (constant, non-learned gate)
        defaults       → AS-Full
    """
    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1,
                 use_qual=True, use_path=True, fixed_gate=None):
        super().__init__()
        self.num_entities = ne
        self.use_qual, self.use_path, self.fixed_gate = use_qual, use_path, fixed_gate
        # submodule creation order fixed (RNG-dependent initialization)
        self.ent = nn.Embedding(ne, dim)
        self.rel = nn.Embedding(nr, dim)
        self.qk = nn.Embedding(nqk, dim)
        self.qv = nn.Embedding(nqv, dim)
        self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(dim)
        self.path_net = nn.Sequential(
            nn.Linear(dim*2, dim), nn.LayerNorm(dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(dim, dim))
        self.ln2 = nn.LayerNorm(dim)
        # learned scalar gate only for the full model; ablations use a constant
        self.gate = (nn.Parameter(torch.tensor(0.5))
                     if use_path and fixed_gate is None else None)
        self.drop = nn.Dropout(dropout)
        for emb in [self.ent, self.rel, self.qk, self.qv]:
            nn.init.xavier_uniform_(emb.weight)

    def _enrich(self, r_emb, quals, dev):
        """Qualifier attention over the relation; identity when disabled/absent."""
        if not self.use_qual or not quals:
            return r_emb
        key_ids = torch.tensor([pair[0] for pair in quals], device=dev)
        val_ids = torch.tensor([pair[1] for pair in quals], device=dev)
        memory = (self.qk(key_ids) + self.qv(val_ids)).unsqueeze(0)
        enriched, _ = self.attn(r_emb.view(1, 1, -1), memory, memory)
        return enriched.squeeze()

    def forward(self, head, relation, qualifiers, tail=None):
        dev = head.device
        h_emb = self.ent(head)
        rel_rows = [self._enrich(self.rel(relation[i:i+1]).squeeze(),
                                 qualifiers[i], dev)
                    for i in range(head.size(0))]
        r_emb = torch.stack(rel_rows)
        stare = self.ln1(h_emb + r_emb)
        if not self.use_path:
            x = self.drop(stare)
        else:
            path = self.ln2(h_emb + self.path_net(torch.cat([h_emb, r_emb], dim=-1)))
            if self.gate is not None:
                g = torch.sigmoid(self.gate)
            else:
                g = torch.tensor(self.fixed_gate, device=dev)
            x = self.drop(g*stare + (1-g)*path)
        if tail is not None:
            return (x * self.ent(tail)).sum(-1)
        return x @ self.ent.weight.t()
print("AlertStar defined")
# ============================================================================
# 8 — Model 5: StarQE (Complex Query Answering: 1p / 2p / 2i / 2u)
# ============================================================================
# StarQE extends link prediction to four first-order logic query types:
# 1p: ∃e: r(h,e) — direct 1-hop
# 2p: ∃e1: r1(h,e1) ∧ r2(e1,e) — 2-hop chain
# 2i: r1(h1,e) ∧ r2(h2,e) — 2-anchor intersection
# 2u: r1(h1,e) ∨ r2(h2,e) — 2-anchor union
# All types are trained with margin ranking loss and evaluated on 1p queries
# (tail prediction) for a fair comparison with other models.
# ============================================================================
class StarQEModel(nn.Module):
    """Complex-query model: composes 1p/2p/2i/2u queries in embedding space.

    Composition is additive, ρ(x + R*[r]). Qualifiers enrich only the FIRST
    hop's relation — later hops and second anchors pass an empty qualifier
    list (see _compose_2p/_compose_2i/_compose_2u below).
    """
    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        self.dim = dim
        self.ent = nn.Embedding(ne, dim)
        self.rel = nn.Embedding(nr, dim)
        # qualifier embeddings — StarQE uses them to enrich relations like StarE
        self.qk = nn.Embedding(nqk, dim)
        self.qv = nn.Embedding(nqv, dim)
        self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True)
        # composition MLP: projects x^{l-1} + R[r] into next entity embedding
        self.compose = nn.Sequential(
            nn.Linear(dim, dim), nn.LayerNorm(dim), nn.ReLU(), nn.Dropout(dropout))
        # intersection operator W_∩: fuses two anchor embeddings
        self.intersect = nn.Linear(dim*2, dim)
        self.drop = nn.Dropout(dropout)
        for e in [self.ent, self.rel, self.qk, self.qv]:
            nn.init.xavier_uniform_(e.weight)
    def _enrich_rel(self, r_id, quals, dev):
        """Qualifier-enrich a relation embedding (same as StarE)."""
        r_emb = self.rel(torch.tensor([r_id], device=dev)).squeeze()
        if not quals:
            return r_emb
        k = self.qk(torch.tensor([q[0] for q in quals], device=dev))
        v = self.qv(torch.tensor([q[1] for q in quals], device=dev))
        kv = (k + v).unsqueeze(0)
        out, _ = self.attn(r_emb.view(1,1,-1), kv, kv)
        return out.squeeze()
    def _compose_1p(self, h_id, r_id, quals, dev):
        """1p: x = ρ(E[h] + R*[r])"""
        h_emb = self.ent(torch.tensor([h_id], device=dev)).squeeze()
        r_emb = self._enrich_rel(r_id, quals, dev)
        return self.compose(h_emb + r_emb)
    def _compose_2p(self, h_id, r1_id, r2_id, quals, dev):
        """2p: x1 = ρ(E[h]+R[r1]), x = ρ(x1+R[r2])"""
        h_emb = self.ent(torch.tensor([h_id], device=dev)).squeeze()
        r1_emb = self._enrich_rel(r1_id, quals, dev)
        # second hop is deliberately unqualified (empty list)
        r2_emb = self._enrich_rel(r2_id, [], dev)
        x1 = self.compose(h_emb + r1_emb)
        return self.compose(x1 + r2_emb)
    def _compose_2i(self, h1, r1, h2, r2, quals, dev):
        """2i: intersection of two 1p queries"""
        e1 = self.compose(
            self.ent(torch.tensor([h1],device=dev)).squeeze() +
            self._enrich_rel(r1, quals, dev))
        # second anchor is unqualified
        e2 = self.compose(
            self.ent(torch.tensor([h2],device=dev)).squeeze() +
            self._enrich_rel(r2, [], dev))
        return self.intersect(torch.cat([e1, e2], dim=-1))
    def _compose_2u(self, h1, r1, h2, r2, quals, dev):
        """2u: union (mean) of two 1p queries"""
        e1 = self.compose(
            self.ent(torch.tensor([h1],device=dev)).squeeze() +
            self._enrich_rel(r1, quals, dev))
        e2 = self.compose(
            self.ent(torch.tensor([h2],device=dev)).squeeze() +
            self._enrich_rel(r2, [], dev))
        return (e1 + e2) / 2
    def forward(self, head, relation, qualifiers, tail=None):
        """Standard 1p interface — compatible with evaluate_model."""
        dev = head.device
        B = head.size(0)
        outs = []
        for i in range(B):
            x = self._compose_1p(head[i].item(), relation[i].item(),
                                 qualifiers[i], dev)
            outs.append(x)
        x = self.drop(torch.stack(outs)) # [B, dim]
        if tail is not None:
            return (x * self.ent(tail)).sum(-1)
        return x @ self.ent.weight.t()
    def score_query(self, query_vec, tail_id, dev):
        """Score a composed query vector against a specific tail."""
        t_emb = self.ent(torch.tensor([tail_id], device=dev)).squeeze()
        return (query_vec * t_emb).sum()
def train_starqe(model, model_name="StarQE",
                 train_data_=None, valid_ds_=None, ne_override=None):
    """
    Trains on all four query types derived from 1-hop triples:
    1p: direct link 2p: chain 2i: intersection 2u: union
    Falls back to 1p only if not enough distinct (h,r) pairs for chains.

    Fixes vs. the previous revision:
    - removed a dead `r2_for_t` list comprehension that scanned every (h,r)
      pair per triple without ever being used;
    - hoisted the head→relations and tail→(h,r) candidate lookups out of
      the per-triple loop into indexes built once (was O(|triples|·|hr_keys|)).
      Candidate lists and their order are unchanged, so sampling behavior
      is preserved.

    Returns (model, history); the best-MRR checkpoint is saved to disk.
    """
    dev = CONFIG['device']
    _td = train_data_ or train_data
    _vds = valid_ds_ or valid_ds
    _ne = ne_override or NE
    model.to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=CONFIG['query_lr'])
    # Build query samples from triples
    def build_queries(data):
        # index (h,r) → list of tails
        hr2tails = defaultdict(list)
        triples = []
        for s in data:
            h = preprocessor.entity2id[s['head']]
            r = preprocessor.relation2id[s['relation']]
            t = preprocessor.entity2id[s['tail']]
            qs = [(preprocessor.qualifier_key2id[qk],
                   preprocessor.qualifier_value2id[qv])
                  for qk, qv in s['qualifiers']]
            hr2tails[(h,r)].append(t)
            triples.append((h, r, t, qs))
        return triples, hr2tails
    triples, hr2tails = build_queries(_td)
    hr_keys = [k for k, v in hr2tails.items() if len(v) >= 1]
    # Loop-invariant candidate indexes (previously recomputed per triple):
    #   head2rels[e]  — relations whose (h,r) pair has head e (hr_keys order)
    #   tail_index[e] — (h,r) pairs that can reach tail e     (hr_keys order)
    head2rels = defaultdict(list)
    for (hh, rr) in hr_keys:
        head2rels[hh].append(rr)
    tail_index = defaultdict(list)
    for (hh, rr) in hr_keys:
        for tl in set(hr2tails[(hh, rr)]):  # set(): one entry per pair
            tail_index[tl].append((hh, rr))
    print(f"\n{'='*60}\nTRAINING {model_name} "
          f"({sum(p.numel() for p in model.parameters()):,} params)")
    print(f" Triples: {len(triples):,} (h,r) pairs: {len(hr_keys):,}\n{'='*60}")
    best_mrr, history = 0.0, []
    for epoch in range(CONFIG['query_epochs']):
        model.train(); total, cnt = 0.0, 0
        np.random.shuffle(triples)
        for h, r, t, qs in tqdm(triples, desc=f"Epoch {epoch+1}", leave=False):
            losses = []
            # ── 1p query ─────────────────────────────────────────────────
            neg_t = np.random.randint(0, _ne)
            h_t = torch.tensor([h], device=dev)
            r_t = torch.tensor([r], device=dev)
            pos_s = model(h_t, r_t, [qs], torch.tensor([t], device=dev))
            neg_s = model(h_t, r_t, [qs], torch.tensor([neg_t],device=dev))
            losses.append(F.margin_ranking_loss(
                pos_s, neg_s, torch.ones(1,device=dev), margin=1.0))
            # ── 2p query (chain through t as intermediate) ────────────────
            # relations r2 for which t acts as a head
            r2_candidates = head2rels.get(t, [])
            if r2_candidates:
                r2 = np.random.choice(r2_candidates)
                tails2 = hr2tails[(t, r2)]
                if tails2:
                    t2 = np.random.choice(tails2)
                    neg2 = np.random.randint(0, _ne)
                    try:
                        q_pos = model._compose_2p(h, r, r2, qs, dev)
                        q_neg = q_pos # same query vec, different tail
                        t2_emb = model.ent(torch.tensor([t2], device=dev)).squeeze()
                        n2_emb = model.ent(torch.tensor([neg2],device=dev)).squeeze()
                        pos_2p = (model.drop(q_pos) * t2_emb).sum()
                        neg_2p = (model.drop(q_neg) * n2_emb).sum()
                        losses.append(F.margin_ranking_loss(
                            pos_2p.unsqueeze(0), neg_2p.unsqueeze(0),
                            torch.ones(1,device=dev), margin=1.0))
                    except Exception: pass
            # ── 2i query (intersect two 1p queries sharing tail t) ────────
            r_for_t = [rr for (hh, rr) in tail_index.get(t, []) if hh != h]
            if r_for_t:
                r_for_t_set = set(r_for_t)
                h2 = np.random.choice([hh for (hh, rr) in tail_index.get(t, [])
                                       if rr in r_for_t_set and hh != h] or [h])
                r2 = np.random.choice(r_for_t)
                neg2 = np.random.randint(0, _ne)
                try:
                    q2i = model._compose_2i(h, r, h2, r2, qs, dev)
                    t_e = model.ent(torch.tensor([t], device=dev)).squeeze()
                    n_e = model.ent(torch.tensor([neg2],device=dev)).squeeze()
                    pos_2i = (model.drop(q2i)*t_e).sum()
                    neg_2i = (model.drop(q2i)*n_e).sum()
                    losses.append(F.margin_ranking_loss(
                        pos_2i.unsqueeze(0), neg_2i.unsqueeze(0),
                        torch.ones(1,device=dev), margin=1.0))
                except Exception: pass
            if not losses: continue
            loss = sum(losses) / len(losses)
            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step(); total += loss.item(); cnt += 1
        if (epoch+1) % 5 == 0:
            m = evaluate_model(model, _vds, dev)
            print(f" Epoch {epoch+1}: loss={total/max(cnt,1):.4f} "
                  f"MRR={m['mrr']:.4f} MR={m['mr']:.1f} "
                  f"H@1={m['hits@1']:.4f} H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
                print(" → best saved")
            history.append(m)
    return model, history
print("StarQE defined")
# ============================================================================
# 9 — Model 6: NBFNet+StarQE (residual path-augmented complex queries)
# ============================================================================
# Replaces StarQE's linear composition ρ(x+R[r]) with a residual MLP:
# x^l = x^0 + PathNet( Concat(x^{l-1}, R*[r_l]) )
# Intersection and union operators inherited unchanged from StarQE.
# ============================================================================
class NBFNetStarQEModel(nn.Module):
    """StarQE variant whose composition step is a residual PathNet:
    x^l = x^0 + PathNet([x^{l-1} ; R*[r_l]]).
    Intersection and union operators are inherited unchanged from StarQE."""
    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1):
        super().__init__()
        self.num_entities = ne
        self.dim = dim
        # submodule creation order fixed (RNG-dependent initialization)
        self.ent = nn.Embedding(ne, dim)
        self.rel = nn.Embedding(nr, dim)
        self.qk = nn.Embedding(nqk, dim)
        self.qv = nn.Embedding(nqv, dim)
        self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True)
        # residual PathNet: R^{2d} -> R^d -> R^d
        self.path_net = nn.Sequential(
            nn.Linear(dim*2, dim), nn.LayerNorm(dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(dim, dim))
        # intersection operator
        self.intersect = nn.Linear(dim*2, dim)
        self.drop = nn.Dropout(dropout)
        for emb in [self.ent, self.rel, self.qk, self.qv]:
            nn.init.xavier_uniform_(emb.weight)

    def _enrich_rel(self, r_id, quals, dev):
        """Qualifier-attended relation embedding (plain embedding if no quals)."""
        base = self.rel(torch.tensor([r_id], device=dev)).squeeze()
        if not quals:
            return base
        key_ids = torch.tensor([pair[0] for pair in quals], device=dev)
        val_ids = torch.tensor([pair[1] for pair in quals], device=dev)
        memory = (self.qk(key_ids) + self.qv(val_ids)).unsqueeze(0)
        attended, _ = self.attn(base.view(1, 1, -1), memory, memory)
        return attended.squeeze()

    def _compose_step(self, x, x0, r_emb):
        """Residual hop: x^l = x^0 + PathNet(Concat(x^{l-1}, R[r]))."""
        return x0 + self.path_net(torch.cat([x, r_emb], dim=-1))

    def _compose_1p(self, h_id, r_id, quals, dev):
        anchor = self.ent(torch.tensor([h_id], device=dev)).squeeze()
        return self._compose_step(anchor, anchor, self._enrich_rel(r_id, quals, dev))

    def _compose_2p(self, h_id, r1_id, r2_id, quals, dev):
        anchor = self.ent(torch.tensor([h_id], device=dev)).squeeze()
        hop1 = self._compose_step(anchor, anchor, self._enrich_rel(r1_id, quals, dev))
        # second hop is unqualified; residual stays anchored at x^0
        return self._compose_step(hop1, anchor, self._enrich_rel(r2_id, [], dev))

    def _compose_2i(self, h1, r1, h2, r2, quals, dev):
        both = torch.cat([self._compose_1p(h1, r1, quals, dev),
                          self._compose_1p(h2, r2, [], dev)], dim=-1)
        return self.intersect(both)

    def _compose_2u(self, h1, r1, h2, r2, quals, dev):
        return (self._compose_1p(h1, r1, quals, dev)
                + self._compose_1p(h2, r2, [], dev)) / 2

    def forward(self, head, relation, qualifiers, tail=None):
        dev = head.device
        vecs = [self._compose_1p(head[i].item(), relation[i].item(),
                                 qualifiers[i], dev)
                for i in range(head.size(0))]
        x = self.drop(torch.stack(vecs))
        if tail is not None:
            return (x * self.ent(tail)).sum(-1)
        return x @ self.ent.weight.t()
# NBFNet+StarQE reuses train_starqe with the same interface
print("NBFNet+StarQE defined")
# ============================================================================
# 10 — Model 7: HyNT
# ============================================================================
class HyNTModel(nn.Module):
    """HyNT: Transformer encoder over (head, rel, tail, qualifier-context)
    token sequences with two task heads — tail prediction and qualifier-value
    prediction."""
    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1,
                 n_heads=4, n_layers=2):
        super().__init__()
        self.num_entities = ne
        self.dim = dim
        # submodule creation order fixed (RNG-dependent initialization)
        self.ent = nn.Embedding(ne, dim)
        self.rel = nn.Embedding(nr, dim)
        self.qk = nn.Embedding(nqk, dim)
        self.qv = nn.Embedding(nqv, dim)
        self.qual_attn = nn.MultiheadAttention(dim, n_heads,
                                               dropout=dropout, batch_first=True)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=n_heads,
            dim_feedforward=dim*4, dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        def _head(out):
            # shared two-layer prediction-head template
            return nn.Sequential(nn.Linear(dim,dim), nn.LayerNorm(dim),
                                 nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim,out))
        self.tail_head = _head(ne)
        self.qv_head = _head(nqv)
        self.qk_gate = nn.Sequential(nn.Linear(dim*2, dim), nn.Sigmoid())
        for emb in [self.ent, self.rel, self.qk, self.qv]:
            nn.init.xavier_uniform_(emb.weight)

    def _aggregate_qualifiers(self, qs, dev):
        """Self-attend over (key+value) qualifier embeddings; zeros if none."""
        if not qs:
            return torch.zeros(self.dim, device=dev)
        keys = self.qk(torch.tensor([pair[0] for pair in qs], device=dev))
        vals = self.qv(torch.tensor([pair[1] for pair in qs], device=dev))
        tokens = (keys + vals).unsqueeze(0)
        attended, _ = self.qual_attn(tokens, tokens, tokens)
        return attended.mean(dim=1).squeeze(0)

    def _encode(self, h, r, t, qs, dev, mask_tail=False):
        """Encode the 4-token statement sequence; return the first token's state."""
        h_tok = self.ent(torch.tensor([h], device=dev))
        r_tok = self.rel(torch.tensor([r], device=dev))
        if mask_tail:
            t_tok = torch.zeros(1, self.dim, device=dev)
        else:
            t_tok = self.ent(torch.tensor([t], device=dev))
        q_tok = self._aggregate_qualifiers(qs, dev).unsqueeze(0)
        seq = torch.cat([h_tok, r_tok, t_tok, q_tok], dim=0).unsqueeze(0)
        return self.encoder(seq)[0, 0]

    def forward_tail(self, s, dev):
        """Score all entities as tail for one sample (tail token masked out)."""
        ctx = self._encode(s['h'], s['r'], 0, s['qs'], dev, mask_tail=True)
        return self.tail_head(ctx)

    def forward_qv(self, s, dev):
        """Score qualifier values for key s['qk'], with that pair held out."""
        held_out = [(k, v) for k, v in s['qs'] if k != s['qk']]
        ctx = self._encode(s['h'], s['r'], s['t'], held_out, dev)
        key_emb = self.qk(torch.tensor([s['qk']], device=dev)).squeeze()
        gate = self.qk_gate(torch.cat([ctx, key_emb], dim=-1))
        return self.qv_head(gate * ctx)

    def forward(self, head, relation, qualifiers, tail=None):
        dev = head.device
        batch = head.size(0)
        logits = torch.stack(
            [self.forward_tail({'h': head[i].item(), 'r': relation[i].item(),
                                't': 0, 'qs': qualifiers[i]}, dev)
             for i in range(batch)], dim=0)
        if tail is not None:
            return logits[torch.arange(batch), tail]
        return logits
class HyNTDataset(Dataset):
    """Flattens hyper-relational triples into HyNT training samples.

    Each triple yields one 'tail' sample plus one 'qv' sample per
    qualifier pair; all ids come from the preprocessor's vocab dicts.
    """
    def __init__(self, data, preprocessor):
        self.samples = []
        p = preprocessor
        for triple in data:
            h = p.entity2id[triple['head']]
            r = p.relation2id[triple['relation']]
            t = p.entity2id[triple['tail']]
            qs = [(p.qualifier_key2id[k], p.qualifier_value2id[v])
                  for k, v in triple['qualifiers']]
            self.samples.append(
                {'task': 'tail', 'h': h, 'r': r, 't': t, 'qs': qs})
            self.samples.extend(
                {'task': 'qv', 'h': h, 'r': r, 't': t,
                 'qs': qs, 'qk': qk_id, 'qv': qv_id}
                for qk_id, qv_id in qs)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        return self.samples[i]
def train_hynt(model, model_name="HyNT", train_data_=None, valid_ds_=None):
    """Train a HyNT model on per-sample tail + qualifier-value losses.

    Parameters
    ----------
    model : HyNTModel
        Model to train; moved to ``CONFIG['device']``.
    model_name : str
        Used for logging and the best-checkpoint filename.
    train_data_ / valid_ds_
        Optional overrides for the module-level ``train_data`` / ``valid_ds``.
        Explicit ``is None`` checks are used: the previous ``x or fallback``
        idiom silently replaced falsy-but-valid overrides (e.g. an empty
        dataset) with the globals.

    Returns ``(model, history)`` where ``history`` collects the validation
    metric dicts computed every 5 epochs; the best-MRR state_dict is saved
    under ``CONFIG['output_path']``.
    """
    dev = CONFIG['device']
    _td = train_data if train_data_ is None else train_data_
    _vds = valid_ds if valid_ds_ is None else valid_ds_
    model.to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=CONFIG['hynt_lr'])
    loader = DataLoader(HyNTDataset(_td, preprocessor),
                        batch_size=CONFIG['hynt_batch'],
                        shuffle=True, collate_fn=mt_collate)
    print(f"\n{'='*60}\nTRAINING {model_name} "
          f"({sum(p.numel() for p in model.parameters()):,} params)\n{'='*60}")
    best_mrr, history = 0.0, []
    for epoch in range(CONFIG['hynt_epochs']):
        model.train(); total, cnt = 0.0, 0
        for by_task in tqdm(loader, desc=f"Epoch {epoch+1}", leave=False):
            bl = torch.tensor(0., device=dev)
            n_terms = 0  # loss terms actually accumulated in this batch
            for task, samples in by_task.items():
                for s in samples:
                    try:
                        if task == 'tail':
                            logits = model.forward_tail(s, dev)
                            tgt = torch.tensor(s['t'], device=dev)
                            bl = bl + F.cross_entropy(logits.unsqueeze(0), tgt.unsqueeze(0))
                            n_terms += 1
                        elif task == 'qv':
                            logits = model.forward_qv(s, dev)
                            tgt = torch.tensor(s['qv'], device=dev)
                            # qualifier-value loss is down-weighted (0.8)
                            bl = bl + 0.8*F.cross_entropy(logits.unsqueeze(0), tgt.unsqueeze(0))
                            n_terms += 1
                    except Exception: continue  # best-effort: skip bad samples
            if n_terms == 0:
                # nothing was added: bl has no grad_fn and backward() would raise
                continue
            opt.zero_grad(); bl.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step(); total += bl.item(); cnt += 1
        if (epoch+1) % 5 == 0:
            m = evaluate_model(model, _vds, dev)
            print(f" Epoch {epoch+1}: loss={total/max(cnt,1):.4f} "
                  f"MRR={m['mrr']:.4f} H@1={m['hits@1']:.4f} H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
            history.append(m)
    return model, history
print("HyNT defined")  # notebook cell status marker
# ============================================================================
# 11 — Model 8: MultiTask AlertStar
# ============================================================================
class MultiTaskDataset(Dataset):
    """Expands each triple into up to four supervision samples.

    Sample kinds (controlled by ``tasks``): 'tail', 'relation',
    'qual_key' (one per triple, listing all gold keys) and
    'qual_value' (one per qualifier pair).
    """
    def __init__(self, data, preprocessor, tasks=None, nqk_override=None):
        self.samples = []
        self.nqk = nqk_override or NQK
        active = (set(tasks) if tasks is not None
                  else {'tail', 'relation', 'qual_key', 'qual_value'})
        p = preprocessor
        for triple in data:
            h = p.entity2id[triple['head']]
            r = p.relation2id[triple['relation']]
            t = p.entity2id[triple['tail']]
            qs = [(p.qualifier_key2id[k], p.qualifier_value2id[v])
                  for k, v in triple['qualifiers']]
            base = {'h': h, 'r': r, 't': t, 'qs': qs}
            if 'tail' in active:
                self.samples.append(dict(base, task='tail'))
            if 'relation' in active:
                self.samples.append(dict(base, task='relation'))
            if qs:
                if 'qual_key' in active:
                    self.samples.append(
                        dict(base, task='qual_key',
                             keys=[k for k, _ in qs]))
                if 'qual_value' in active:
                    for qk_id, qv_id in qs:
                        self.samples.append(
                            dict(base, task='qual_value', qk=qk_id, qv=qv_id))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        return self.samples[i]
class MultiTaskAlertStar(nn.Module):
    """4-task Transformer over the token sequence [h, r, t, qk1, qv1, ...].

    One shared encoder feeds four MLP heads (tail / relation / qual_key /
    qual_value); the supervised token is zero-masked where applicable so the
    model cannot copy the answer from the input.
    """
    def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1,
                 n_heads=4, n_layers=3):
        super().__init__()
        self.num_entities = ne
        self.ent = nn.Embedding(ne, dim)
        self.rel = nn.Embedding(nr, dim)
        self.qk = nn.Embedding(nqk, dim)
        self.qv = nn.Embedding(nqv, dim)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads,
                                       dim_feedforward=dim * 4,
                                       dropout=dropout, batch_first=True),
            num_layers=n_layers)

        def head(out):
            # 2-layer MLP head factory shared by all four tasks
            return nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim),
                                 nn.ReLU(), nn.Dropout(dropout),
                                 nn.Linear(dim, out))
        self.tail_head = head(ne)
        self.rel_head = head(nr)
        self.qk_head = head(nqk)
        self.qv_head = head(nqv)
        for table in (self.ent, self.rel, self.qk, self.qv):
            nn.init.xavier_uniform_(table.weight)

    def _encode(self, h, r, t, qs, dev, mask_pos=None):
        """Encode the token sequence; optionally zero one masked position.

        Returns the contextualised first token ([dim]).
        """
        tokens = [self.ent(torch.tensor([h], device=dev)),
                  self.rel(torch.tensor([r], device=dev)),
                  self.ent(torch.tensor([t], device=dev))]
        for key_id, val_id in qs:
            tokens.append(self.qk(torch.tensor([key_id], device=dev)))
            tokens.append(self.qv(torch.tensor([val_id], device=dev)))
        seq = torch.cat(tokens, dim=0).unsqueeze(0)
        if mask_pos is not None and mask_pos < seq.size(1):
            seq = seq.clone()
            seq[0, mask_pos] = 0.0  # blank the supervised slot
        return self.encoder(seq)[0, 0]

    def forward_task(self, task, s, dev):
        """Dispatch one sample dict to its task-specific head."""
        h, r, t, qs = s['h'], s['r'], s['t'], s['qs']
        if task == 'tail':
            # tail slot (position 2) is masked so the model must predict it
            return self.tail_head(self._encode(h, r, 0, qs, dev, mask_pos=2))
        if task == 'relation':
            return self.rel_head(self._encode(h, r, t, qs, dev, mask_pos=1))
        if task == 'qual_key':
            # encode without qualifiers to avoid leaking the gold keys
            return self.qk_head(self._encode(h, r, t, [], dev))
        if task == 'qual_value':
            remaining = [(k, v) for k, v in qs if k != s['qk']]
            return self.qv_head(self._encode(h, r, t, remaining, dev))

    def forward(self, head, relation, qualifiers, tail=None):
        """Batched tail scoring (evaluate_model-compatible interface)."""
        dev = head.device
        B = head.size(0)
        logits = torch.stack(
            [self.forward_task('tail',
                               {'h': head[i].item(), 'r': relation[i].item(),
                                't': 0, 'qs': qualifiers[i]}, dev)
             for i in range(B)], dim=0)
        if tail is not None:
            return logits[torch.arange(B), tail]
        return logits
def train_multitask(model, model_name="MultiTask_AlertStar",
                    active_tasks=None, train_data_=None, valid_ds_=None,
                    nqk_override=None):
    """Jointly train the 4-task Transformer with weighted per-task losses.

    Parameters
    ----------
    model : MultiTaskAlertStar (or ablation variant).
    active_tasks : subset of {'tail','relation','qual_key','qual_value'};
        defaults to all four (used by ablation A3).
    train_data_ / valid_ds_ / nqk_override
        Optional overrides for the module-level globals. Explicit ``is None``
        checks are used: the previous ``x or fallback`` idiom silently
        replaced falsy-but-valid overrides (empty dataset, nqk of 0).

    Returns ``(model, history)``; the best validation-MRR state_dict is
    checkpointed under ``CONFIG['output_path']``.
    """
    if active_tasks is None:
        active_tasks = ['tail','relation','qual_key','qual_value']
    dev = CONFIG['device']
    _td = train_data if train_data_ is None else train_data_
    _vds = valid_ds if valid_ds_ is None else valid_ds_
    _nqk = NQK if nqk_override is None else nqk_override
    model.to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=CONFIG['mt_lr'])
    loader = DataLoader(MultiTaskDataset(_td, preprocessor, tasks=active_tasks,
                                         nqk_override=_nqk),
                        batch_size=CONFIG['mt_batch'],
                        shuffle=True, collate_fn=mt_collate)
    # per-task loss weights (tail is the primary objective)
    weights = {'tail':1.0,'relation':1.0,'qual_key':0.5,'qual_value':0.8}
    print(f"\n{'='*60}\nTRAINING {model_name} tasks={active_tasks}\n{'='*60}")
    best_mrr, history = 0.0, []
    for epoch in range(CONFIG['mt_epochs']):
        model.train(); total, cnt = 0.0, 0
        for by_task in tqdm(loader, desc=f"Epoch {epoch+1}", leave=False):
            bl = torch.tensor(0., device=dev)
            n_terms = 0  # loss terms actually accumulated in this batch
            for task, samples in by_task.items():
                for s in samples:
                    try:
                        logits = model.forward_task(task, s, dev)
                        w = weights.get(task, 1.0)
                        if task == 'tail':
                            tgt = torch.tensor(s['t'], device=dev)
                            loss = F.cross_entropy(logits.unsqueeze(0), tgt.unsqueeze(0))
                        elif task == 'relation':
                            tgt = torch.tensor(s['r'], device=dev)
                            loss = F.cross_entropy(logits.unsqueeze(0), tgt.unsqueeze(0))
                        elif task == 'qual_key':
                            # multi-label target: 1 for every gold key
                            tgt = torch.zeros(_nqk, device=dev)
                            tgt[s['keys']] = 1.0
                            loss = F.binary_cross_entropy_with_logits(logits, tgt)
                        elif task == 'qual_value':
                            tgt = torch.tensor(s['qv'], device=dev)
                            loss = F.cross_entropy(logits.unsqueeze(0), tgt.unsqueeze(0))
                        else: continue
                        bl = bl + w*loss
                        n_terms += 1
                    except Exception: continue  # best-effort: skip bad samples
            if n_terms == 0:
                # nothing was added: bl has no grad_fn and backward() would raise
                continue
            opt.zero_grad(); bl.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step(); total += bl.item(); cnt += 1
        if (epoch+1) % 5 == 0:
            m = evaluate_model(model, _vds, dev)
            print(f" Epoch {epoch+1}: loss={total/max(cnt,1):.4f} "
                  f"MRR={m['mrr']:.4f} H@1={m['hits@1']:.4f} H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
            history.append(m)
    return model, history
print("MultiTask AlertStar defined")  # notebook cell status marker
# ============================================================================
# 12 — Model 9: HR-NBFNet (Hyper-Relational Bellman-Ford)
#
# Matches slide formulation exactly:
# h(0)_uvqq' <- INDICATOR(u, v, q', q)
# phi_q(h_qk, h_qv) = h_qk · h_qv [DisMult per qualifier pair]
# h_q = W_q · SUM phi_q [projected qualifier sum]
# w_q = sigma(qual_gate(r)) [per-relation scalar gate]
# MSG = src_feat * (rel_emb + w_q * h_q) [qualifier-gated message]
# h(t) = AGG(msgs) + h(0) [shortcut to h(0) not h(t-1)]
# ============================================================================
def build_hr_nbfnet_graph(train_data, preprocessor, device,
                          max_quals=8, p_override=None):
    """
    Build qualifier-aware edge tensors.
    Unlike TrueNBFNet's plain [E,3] edge_list, stores per-edge qualifier
    pairs in a padded tensor [E*2, max_quals, 2] so each propagation
    layer can apply DisMult qualifier composition per edge.

    Fixes vs previous version:
    - ``p_override or preprocessor`` replaced by an explicit ``is None``
      check (a falsy-but-valid override was silently discarded).
    - ``edge_nquals`` is clamped to ``max_quals`` so the stored count never
      exceeds the padded tensor's capacity (the propagation mask behaves
      identically, but the counts are now consistent with the storage).

    Returns a dict of edge tensors plus ``num_nodes`` / ``nr`` / ``max_quals``.
    """
    p = preprocessor if p_override is None else p_override
    nr = len(p.relation2id)
    srcs, dsts, rels, qual_list, nquals = [], [], [], [], []
    for t in train_data:
        h = p.entity2id[t['head']]
        r = p.relation2id[t['relation']]
        tl = p.entity2id[t['tail']]
        qs = [(p.qualifier_key2id[qk], p.qualifier_value2id[qv])
              for qk, qv in t['qualifiers']]
        n_q = min(len(qs), max_quals)  # count clamped to padding capacity
        # forward + inverse edges — both carry the same qualifiers;
        # inverse relations use ids shifted by nr
        srcs.append(h); dsts.append(tl); rels.append(r)
        qual_list.append(qs); nquals.append(n_q)
        srcs.append(tl); dsts.append(h); rels.append(r + nr)
        qual_list.append(qs); nquals.append(n_q)
    E = len(srcs)
    quals_tensor = torch.zeros(E, max_quals, 2, dtype=torch.long)
    for i, qs in enumerate(qual_list):
        for j, (qk, qv) in enumerate(qs[:max_quals]):
            quals_tensor[i, j, 0] = qk
            quals_tensor[i, j, 1] = qv
    return {
        'edge_src': torch.tensor(srcs, dtype=torch.long, device=device),
        'edge_dst': torch.tensor(dsts, dtype=torch.long, device=device),
        'edge_rel': torch.tensor(rels, dtype=torch.long, device=device),
        'edge_quals': quals_tensor.to(device),   # [E, max_quals, 2]
        'edge_nquals': torch.tensor(nquals, dtype=torch.long, device=device),
        'num_nodes': len(p.entity2id),
        'nr': nr,
        'max_quals': max_quals,
    }
class HRNBFConvLayer(nn.Module):
    """One qualifier-aware Bellman-Ford propagation layer.

    Message for edge (u -> v, r, Q):
        feat[u] * (rel_emb[r] + sigmoid(qual_gate[r]) * W_q · SUM DisMult(Q)),
    sum-aggregated at v, then Linear + LayerNorm + ReLU + Dropout, with a
    residual shortcut back to the layer-0 features h0. Edges are processed
    in chunks of ``chunk_size`` to bound peak memory.
    """
    def __init__(self, dim, num_relation, nqk, nqv,
                 chunk_size=5000, layer_norm=True, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.chunk_size = chunk_size
        self.rel_emb = nn.Embedding(num_relation, dim)
        self.qk_emb = nn.Embedding(nqk, dim)
        self.qv_emb = nn.Embedding(nqv, dim)
        self.W_q = nn.Linear(dim, dim, bias=False)
        # per-relation scalar gate on the qualifier contribution
        self.qual_gate = nn.Embedding(num_relation, 1)
        nn.init.ones_(self.qual_gate.weight)  # start fully open
        self.linear = nn.Linear(dim * 2, dim)
        self.ln = nn.LayerNorm(dim) if layer_norm else None
        self.act = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        for table in (self.rel_emb, self.qk_emb, self.qv_emb):
            nn.init.xavier_uniform_(table.weight)
        nn.init.xavier_uniform_(self.W_q.weight)

    def _qualifier_embedding(self, edge_quals, edge_nquals, edge_rels):
        """h_q = sigmoid(gate(r)) * W_q · SUM_j (h_qk_j * h_qv_j)."""
        max_q = edge_quals.size(1)
        # DisMult composition per qualifier pair: [E, max_q, dim]
        pair_feats = (self.qk_emb(edge_quals[:, :, 0]) *
                      self.qv_emb(edge_quals[:, :, 1]))
        slot = torch.arange(max_q, device=edge_quals.device).unsqueeze(0)
        valid = (slot < edge_nquals.unsqueeze(1)).float().unsqueeze(2)
        summed = (pair_feats * valid).sum(dim=1)        # padding zeroed
        projected = self.W_q(summed)                     # [E, dim]
        gate = torch.sigmoid(self.qual_gate(edge_rels))  # [E, 1]
        return gate * projected

    def forward(self, graph, node_feat, h0):
        """Aggregate qualifier-gated messages in chunks; residual to h0."""
        N = graph['num_nodes']
        dev = node_feat.device
        agg = torch.zeros(N, self.dim, device=dev)
        n_edges = graph['edge_src'].size(0)
        for lo in range(0, n_edges, self.chunk_size):
            hi = lo + self.chunk_size
            s_ = graph['edge_src'][lo:hi]
            d_ = graph['edge_dst'][lo:hi]
            r_ = graph['edge_rel'][lo:hi]
            h_q = self._qualifier_embedding(graph['edge_quals'][lo:hi],
                                            graph['edge_nquals'][lo:hi], r_)
            # qualifier-gated message per edge
            msg = node_feat[s_] * (self.rel_emb(r_) + h_q)
            agg.scatter_add_(0, d_.unsqueeze(1).expand_as(msg), msg)
            del h_q, msg  # free chunk intermediates promptly
        out = self.linear(torch.cat([node_feat, agg], dim=-1))
        if self.ln is not None:
            out = self.ln(out)
        out = self.drop(self.act(out))
        return out + h0  # shortcut to h(0)
class HRNBFNet(nn.Module):
    """Hyper-Relational NBFNet: query-conditioned Bellman-Ford propagation
    where every edge message is enriched by its qualifier pairs.

    The query head node is seeded with the query relation embedding plus a
    projected DisMult pooling of the query qualifiers; after ``num_layers``
    propagation steps an MLP scores every entity against the query.
    """
    def __init__(self, ne, nr, nqk, nqv, dim=200, num_layers=3,
                 chunk_size=5000, dropout=0.1, max_quals=8):
        super().__init__()
        self.num_entities = ne
        self.dim = dim
        self.max_quals = max_quals
        nr2 = nr * 2  # forward + inverse relation ids
        self.query_emb = nn.Embedding(nr2, dim)
        self.query_qk_emb = nn.Embedding(nqk, dim)
        self.query_qv_emb = nn.Embedding(nqv, dim)
        self.query_qual_proj = nn.Linear(dim, dim, bias=False)
        self.layers = nn.ModuleList(
            HRNBFConvLayer(dim, nr2, nqk, nqv,
                           chunk_size=chunk_size, dropout=dropout)
            for _ in range(num_layers))
        self.mlp = nn.Sequential(
            nn.Linear(dim * 2, dim), nn.LayerNorm(dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(dim, 1))
        for table in (self.query_emb, self.query_qk_emb, self.query_qv_emb):
            nn.init.xavier_uniform_(table.weight)
        nn.init.xavier_uniform_(self.query_qual_proj.weight)

    def _indicator_init(self, graph, h_idx, r_idx, query_quals, device):
        """INDICATOR(u,v,q',q): all-zero node features except the source,
        seeded with the query relation plus projected qualifier context."""
        N = graph['num_nodes']
        q_rel = self.query_emb(torch.tensor([r_idx], device=device))
        if query_quals:
            qk_ids = torch.tensor([k for k, _ in query_quals], device=device)
            qv_ids = torch.tensor([v for _, v in query_quals], device=device)
            pooled = (self.query_qk_emb(qk_ids) *
                      self.query_qv_emb(qv_ids)).sum(0, keepdim=True)
            q_qual = self.query_qual_proj(pooled)
        else:
            q_qual = torch.zeros(1, self.dim, device=device)
        feat = torch.zeros(N, self.dim, device=device)
        feat[h_idx] = q_rel.squeeze() + q_qual.squeeze()
        return feat

    def _propagate(self, graph, h_idx, r_idx, query_quals, device):
        """Run all BF layers; each layer shortcuts back to the initial state."""
        feat = self._indicator_init(graph, h_idx, r_idx, query_quals, device)
        h0 = feat.clone()
        for conv in self.layers:
            feat = conv(graph, feat, h0)
        return feat

    def forward(self, head, relation, qualifiers=None, tail=None, graph=None):
        """Score tails for one (h, r) query batch.

        All rows must share the same head/relation (asserted): one
        propagation pass serves the whole batch.
        """
        assert graph is not None
        assert (head == head[0]).all() and (relation == relation[0]).all()
        dev = head.device
        query_quals = qualifiers[0] if qualifiers else []
        feat = self._propagate(graph, head[0].item(),
                               relation[0].item(), query_quals, dev)
        q_emb = self.query_emb(torch.tensor([relation[0].item()], device=dev))
        paired = torch.cat([feat, q_emb.expand(graph['num_nodes'], -1)], dim=-1)
        scores = self.mlp(paired).squeeze(-1)
        if tail is not None:
            return scores[tail]
        return scores.unsqueeze(0).expand(head.size(0), -1)
def train_hr_nbfnet(model, model_name="HR_NBFNet",
                    train_data_=None, valid_ds_=None, graph_=None,
                    p_override=None):
    """Train HR-NBFNet with margin-ranking loss over (h, r) query groups.

    Triples are grouped by (head, relation) so each group requires a single
    Bellman-Ford propagation pass; positives are the group's gold tails and
    negatives are uniformly sampled entities. Validates every 5 epochs and
    checkpoints the best-MRR weights to ``CONFIG['output_path']``.

    ``train_data_`` / ``valid_ds_`` / ``graph_`` / ``p_override`` fall back
    to the module-level globals via explicit ``is None`` checks — the
    previous ``x or fallback`` idiom also discarded falsy-but-valid
    overrides such as an empty dataset.

    Returns ``(model, history)``.
    """
    dev = CONFIG['device']
    _td = train_data if train_data_ is None else train_data_
    _vds = valid_ds if valid_ds_ is None else valid_ds_
    _g = hr_nbfnet_graph if graph_ is None else graph_
    _p = preprocessor if p_override is None else p_override
    model.to(dev)
    graph = {k: v.to(dev) if torch.is_tensor(v) else v for k,v in _g.items()}
    ne_ = graph['num_nodes']
    opt = torch.optim.Adam(model.parameters(), lr=CONFIG['hr_nbfnet_lr'])
    mpg = CONFIG['nbfnet_max_per_group']
    # one propagation pass per (head, relation) group
    groups = defaultdict(list)
    for i, s in enumerate(_td):
        groups[(_p.entity2id[s['head']], _p.relation2id[s['relation']])].append(i)
    keys = list(groups.keys())
    print(f"\n{'='*60}\nTRAINING {model_name} "
          f"({sum(p.numel() for p in model.parameters()):,} params)")
    print(f" Graph: {graph['edge_src'].size(0):,} qualifier-aware edges\n{'='*60}")
    best_mrr, history = 0.0, []
    for epoch in range(CONFIG['hr_nbfnet_epochs']):
        model.train(); np.random.shuffle(keys)
        total, cnt = 0.0, 0
        for (h, r) in tqdm(keys, desc=f"Epoch {epoch+1}", leave=False):
            chosen = np.random.choice(groups[(h,r)],
                                      min(len(groups[(h,r)]),mpg), replace=False)
            t_pos = torch.tensor(
                [_p.entity2id[_td[i]['tail']] for i in chosen], device=dev)
            all_quals = [[(_p.qualifier_key2id[qk], _p.qualifier_value2id[qv])
                          for qk, qv in _td[i]['qualifiers']] for i in chosen]
            # the longest qualifier set stands in as the group's query context
            rep_quals = max(all_quals, key=len) if all_quals else []
            B = len(t_pos)
            heads = torch.full((B,), h, dtype=torch.long, device=dev)
            rels = torch.full((B,), r, dtype=torch.long, device=dev)
            try:
                pos = model(heads, rels, qualifiers=[rep_quals]*B,
                            tail=t_pos, graph=graph)
                neg = model(heads, rels, qualifiers=[rep_quals]*B,
                            tail=torch.randint(0,ne_,(B,),device=dev), graph=graph)
                loss = F.margin_ranking_loss(pos, neg,
                                             torch.ones(B,device=dev), margin=1.0)
                opt.zero_grad(); loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step(); total += loss.item(); cnt += 1
            except Exception: continue  # best-effort: skip failing groups (e.g. OOM)
        if (epoch+1) % 5 == 0:
            m = evaluate_hr_nbfnet(model, _vds, graph, dev, p_override=_p)
            print(f" Epoch {epoch+1}: loss={total/max(cnt,1):.4f} "
                  f"MRR={m['mrr']:.4f} MR={m['mr']:.1f} "
                  f"H@1={m['hits@1']:.4f} H@3={m['hits@3']:.4f} "
                  f"H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
                print(" → best saved")
            history.append(m)
    return model, history
def evaluate_hr_nbfnet(model, dataset, graph, device,
                       max_groups=300, p_override=None):
    """Ranking evaluation for (HR-)NBFNet-style models.

    Groups ``dataset.data`` triples by (head, relation) so each group needs
    a single propagation pass, scores all entities, and records the 1-based
    rank of every gold tail. Only the first ``max_groups`` groups are
    evaluated, for speed; ranks are unfiltered.

    ``p_override`` falls back to the module-level ``preprocessor`` via an
    explicit ``is None`` check (the previous ``x or fallback`` idiom would
    discard a falsy-but-valid override).

    Returns a dict with 'mr', 'mrr', 'hits@1', 'hits@3', 'hits@10'
    (all zeros when no rank could be computed).
    """
    _p = preprocessor if p_override is None else p_override
    model.eval()
    groups = defaultdict(list)
    for i, s in enumerate(dataset.data):
        groups[(_p.entity2id[s['head']], _p.relation2id[s['relation']])].append(i)
    ranks = []
    with torch.no_grad():
        for (h,r) in tqdm(list(groups.keys())[:max_groups],
                          desc="Eval HR-NBFNet", leave=False):
            idxs = groups[(h,r)]
            tails = [_p.entity2id[dataset.data[i]['tail']] for i in idxs]
            B = len(tails)
            heads = torch.full((B,),h,dtype=torch.long,device=device)
            rels = torch.full((B,),r,dtype=torch.long,device=device)
            # the group's first triple supplies the query qualifiers
            s0 = dataset.data[idxs[0]]
            quals = [(_p.qualifier_key2id[qk], _p.qualifier_value2id[qv])
                     for qk,qv in s0['qualifiers']]
            try:
                scores = model(heads,rels,qualifiers=[quals]*B,graph=graph)
                for i, tgt in enumerate(tails):
                    # rank = 1-based position of the gold tail after sorting
                    pos = (torch.argsort(scores[i],descending=True)==tgt
                           ).nonzero(as_tuple=True)[0]
                    if len(pos): ranks.append(pos[0].item()+1)
            except Exception: continue  # best-effort: skip failing groups
    if not ranks:
        return {'mr':0.,'mrr':0.,'hits@1':0.,'hits@3':0.,'hits@10':0.}
    ranks = np.array(ranks)
    return {'mr':float(np.mean(ranks)),'mrr':float(np.mean(1./ranks)),
            'hits@1':float(np.mean(ranks<=1)),'hits@3':float(np.mean(ranks<=3)),
            'hits@10':float(np.mean(ranks<=10))}
print("HR-NBFNet defined")  # notebook cell status marker
# ============================================================================
# 13 — Model 10: MultiTask_HR_NBFNet (NEW)
#
# Combines HR-NBFNet's qualifier-aware Bellman-Ford propagation with the
# 4-task multi-task training strategy from MultiTask AlertStar.
#
# Architecture:
# - HR-NBFNet backbone: propagates query-conditioned features with
# per-edge DisMult qualifier embeddings → node feature f^L[v] ∈ R^d
# - 4 prediction heads sharing the backbone, each a 2-layer MLP:
# tail: MLP(cat(f^L[e'], q_emb)) → R^1 [margin ranking]
# relation: MLP(f^L[h]) → R^|R| [cross-entropy]
# qual_key: MLP(f^L[h]) → R^|QK| [BCE multi-label]
# qual_value: MLP(f^L[h] · qk_emb) → R^|QV| [cross-entropy]
#
# Key design choices vs MultiTask AlertStar (MT-AS):
# - MT-AS: Transformer over flat token sequence (local, no graph)
# - MT-HR: Bellman-Ford over HR graph (global, structure-aware)
# - Relation/qual-key/qual-value heads use head-node BF representation,
# so auxiliary tasks receive path-enriched graph signals
# - Qualifier value head gates on qk_emb (same as HyNT) for fine-grained
# attribute discrimination
# ============================================================================
class MultiTaskHRNBFNet(nn.Module):
    """HR-NBFNet backbone shared by four prediction heads.

    The backbone (query seeding + BF propagation) mirrors HRNBFNet exactly;
    the tail head scores all entities from the propagated node features,
    while the relation / qual_key / qual_value heads read the head node's
    final representation f^L[h]. See the section header above for the
    per-head loss design.
    """
    def __init__(self, ne, nr, nqk, nqv, dim=200, num_layers=3,
                 chunk_size=5000, dropout=0.1, max_quals=8):
        """ne/nr/nqk/nqv: vocabulary sizes; dim: embedding width;
        num_layers: propagation depth; chunk_size: edges per scatter chunk;
        max_quals: qualifier padding width of the graph tensors."""
        super().__init__()
        # vocabulary sizes kept for the heads' output dimensions
        self.num_entities = ne
        self.nr = nr
        self.nqk = nqk
        self.nqv = nqv
        self.dim = dim
        # ── Shared HR-NBFNet backbone ─────────────────────────────────────
        nr2 = nr * 2  # forward + inverse relation ids
        self.query_emb = nn.Embedding(nr2, dim)
        self.query_qk_emb = nn.Embedding(nqk, dim)
        self.query_qv_emb = nn.Embedding(nqv, dim)
        self.query_qual_proj = nn.Linear(dim, dim, bias=False)
        self.layers = nn.ModuleList([
            HRNBFConvLayer(dim, nr2, nqk, nqv,
                           chunk_size=chunk_size, dropout=dropout)
            for _ in range(num_layers)])
        # ── Prediction heads ──────────────────────────────────────────────
        def _mlp_head(in_dim, out_dim):
            # shared 2-layer MLP head factory
            return nn.Sequential(
                nn.Linear(in_dim, dim), nn.LayerNorm(dim), nn.ReLU(),
                nn.Dropout(dropout), nn.Linear(dim, out_dim))
        # tail: concat(f^L[e'], q_emb) → scalar
        self.tail_mlp = _mlp_head(dim*2, 1)
        # relation: f^L[h] → |R| (use head node's BF representation)
        self.rel_head = _mlp_head(dim, nr)
        # qual_key: f^L[h] → |QK| (multi-label)
        self.qk_head = _mlp_head(dim, nqk)
        # qual_value: gate(f^L[h], qk_emb) → |QV|
        self.qv_gate = nn.Sequential(nn.Linear(dim*2, dim), nn.Sigmoid())
        self.qv_head = _mlp_head(dim, nqv)
        # qualifier key embedding (for qv gating)
        self.qk_emb_head = nn.Embedding(nqk, dim)
        for e in [self.query_emb, self.query_qk_emb, self.query_qv_emb,
                  self.qk_emb_head]:
            nn.init.xavier_uniform_(e.weight)
        nn.init.xavier_uniform_(self.query_qual_proj.weight)
    # ── backbone: shared with HRNBFNet ────────────────────────────────────
    def _indicator_init(self, graph, h_idx, r_idx, query_quals, device):
        """All-zero node features except the query head, seeded with the
        query relation plus projected DisMult pooling of the qualifiers."""
        N = graph['num_nodes']
        q_rel = self.query_emb(torch.tensor([r_idx], device=device))
        if query_quals:
            qk_ids = torch.tensor([qk for qk,_ in query_quals], device=device)
            qv_ids = torch.tensor([qv for _,qv in query_quals], device=device)
            phi = (self.query_qk_emb(qk_ids) *
                   self.query_qv_emb(qv_ids)).sum(0, keepdim=True)
            q_qual = self.query_qual_proj(phi)
        else:
            # no query qualifiers → zero qualifier context
            q_qual = torch.zeros(1, self.dim, device=device)
        feat = torch.zeros(N, self.dim, device=device)
        feat[h_idx] = q_rel.squeeze() + q_qual.squeeze()
        return feat
    def _propagate(self, graph, h_idx, r_idx, query_quals, device):
        """Run all BF layers; each layer shortcuts back to the initial state."""
        feat = self._indicator_init(graph, h_idx, r_idx, query_quals, device)
        h0 = feat.clone()
        for layer in self.layers:
            feat = layer(graph, feat, h0)
        return feat  # [N, dim]
    # ── task-specific forward passes ──────────────────────────────────────
    def forward_tail_task(self, graph, h_idx, r_idx, query_quals, tail_ids, device):
        """Tail prediction: score MLP(cat(f^L[tail], q_emb)).

        Returns scores for ``tail_ids`` when given, else all N entities."""
        feat = self._propagate(graph, h_idx, r_idx, query_quals, device)
        q_emb = self.query_emb(torch.tensor([r_idx], device=device))
        # score all entities
        score_in = torch.cat([feat, q_emb.expand(graph['num_nodes'],-1)], dim=-1)
        all_scores = self.tail_mlp(score_in).squeeze(-1)  # [N]
        if tail_ids is not None:
            return all_scores[tail_ids]
        return all_scores
    def forward_rel_task(self, graph, h_idx, r_idx, query_quals, device):
        """Relation prediction: classify from head node's BF representation."""
        feat = self._propagate(graph, h_idx, r_idx, query_quals, device)
        return self.rel_head(feat[h_idx])  # [nr]
    def forward_qk_task(self, graph, h_idx, r_idx, device):
        """Qualifier key prediction (multi-label) from head BF representation."""
        # run with empty qualifiers so we don't leak qk info
        feat = self._propagate(graph, h_idx, r_idx, [], device)
        return self.qk_head(feat[h_idx])  # [nqk]
    def forward_qv_task(self, graph, h_idx, r_idx, query_quals,
                        target_qk, device):
        """Qualifier value prediction gated on the target qualifier key."""
        feat = self._propagate(graph, h_idx, r_idx, query_quals, device)
        h_repr = feat[h_idx]
        qk_emb = self.qk_emb_head(torch.tensor([target_qk], device=device)).squeeze()
        gate = self.qv_gate(torch.cat([h_repr, qk_emb], dim=-1))
        return self.qv_head(gate * h_repr)  # [nqv]
    # ── standard forward for evaluate_model compatibility ─────────────────
    def forward(self, head, relation, qualifiers=None, tail=None, graph=None):
        """Score tails for one (h, r) query batch.

        NOTE(review): asserts every batch row shares the same head/relation —
        callers must group queries by (h, r) before batching.
        """
        assert graph is not None
        assert (head==head[0]).all() and (relation==relation[0]).all()
        dev = head.device
        query_quals = qualifiers[0] if qualifiers else []
        all_scores = self.forward_tail_task(
            graph, head[0].item(), relation[0].item(),
            query_quals, None, dev)
        if tail is not None: return all_scores[tail]
        return all_scores.unsqueeze(0).expand(head.size(0), -1)
def train_mt_hr_nbfnet(model, model_name="MultiTask_HR_NBFNet",
                       train_data_=None, valid_ds_=None, graph_=None,
                       p_override=None, nqk_override=None, nqv_override=None):
    """
    Multi-task training for HR-NBFNet backbone.
    Each (h,r) group triggers one BF propagation pass; auxiliary tasks
    (relation, qual_key, qual_value) add supervision on the head node's
    final graph representation f^L[h].
    Task weights: tail=1.0, relation=0.8, qual_key=0.5, qual_value=0.8

    All overrides fall back to the module-level globals via explicit
    ``is None`` checks — the previous ``x or fallback`` idiom also discarded
    falsy-but-valid overrides (empty dataset, count of 0).

    Returns ``(model, history)``; best validation-MRR weights are
    checkpointed under ``CONFIG['output_path']``.
    """
    dev = CONFIG['device']
    _td = train_data if train_data_ is None else train_data_
    _vds = valid_ds if valid_ds_ is None else valid_ds_
    _g = hr_nbfnet_graph if graph_ is None else graph_
    _p = preprocessor if p_override is None else p_override
    _nqk = NQK if nqk_override is None else nqk_override
    _nqv = NQV if nqv_override is None else nqv_override
    model.to(dev)
    graph = {k: v.to(dev) if torch.is_tensor(v) else v for k,v in _g.items()}
    ne_ = graph['num_nodes']
    opt = torch.optim.Adam(model.parameters(), lr=CONFIG['mt_hr_lr'])
    mpg = CONFIG['nbfnet_max_per_group']
    weights = {'tail':1.0, 'relation':0.8, 'qual_key':0.5, 'qual_value':0.8}
    # Group training triples by (h, r) so each group shares one propagation
    groups = defaultdict(list)
    for i, s in enumerate(_td):
        groups[(_p.entity2id[s['head']], _p.relation2id[s['relation']])].append(i)
    keys = list(groups.keys())
    print(f"\n{'='*60}\nTRAINING {model_name} "
          f"({sum(p.numel() for p in model.parameters()):,} params)")
    print(f" Graph: {graph['edge_src'].size(0):,} qualifier-aware edges")
    print(f" Tasks: tail / relation / qual_key / qual_value\n{'='*60}")
    best_mrr, history = 0.0, []
    for epoch in range(CONFIG['mt_hr_epochs']):
        model.train(); np.random.shuffle(keys)
        total, cnt = 0.0, 0
        for (h, r) in tqdm(keys, desc=f"Epoch {epoch+1}", leave=False):
            chosen = np.random.choice(groups[(h,r)],
                                      min(len(groups[(h,r)]),mpg), replace=False)
            samples = [_td[i] for i in chosen]
            all_quals = [[(_p.qualifier_key2id[qk], _p.qualifier_value2id[qv])
                          for qk,qv in s['qualifiers']] for s in samples]
            # the longest qualifier set stands in as the group's query context
            rep_quals = max(all_quals, key=len) if all_quals else []
            t_ids = torch.tensor(
                [_p.entity2id[s['tail']] for s in samples], device=dev)
            neg_ids = torch.randint(0, ne_, (len(samples),), device=dev)
            batch_loss = torch.tensor(0., device=dev)
            try:
                # ── Task 1: Tail prediction (margin ranking) ───────────────
                pos_scores = model.forward_tail_task(
                    graph, h, r, rep_quals, t_ids, dev)
                neg_scores = model.forward_tail_task(
                    graph, h, r, rep_quals, neg_ids, dev)
                loss_tail = F.margin_ranking_loss(
                    pos_scores, neg_scores,
                    torch.ones(len(samples), device=dev), margin=1.0)
                batch_loss = batch_loss + weights['tail'] * loss_tail
                # ── Task 2: Relation prediction (cross-entropy) ────────────
                rel_logits = model.forward_rel_task(graph, h, r, rep_quals, dev)
                loss_rel = F.cross_entropy(
                    rel_logits.unsqueeze(0),
                    torch.tensor([r], device=dev))
                batch_loss = batch_loss + weights['relation'] * loss_rel
                # ── Task 3: Qualifier key prediction (BCE multi-label) ─────
                if rep_quals:
                    qk_logits = model.forward_qk_task(graph, h, r, dev)
                    qk_target = torch.zeros(_nqk, device=dev)
                    qk_target[[qk for qk,_ in rep_quals[:_nqk]]] = 1.0
                    loss_qk = F.binary_cross_entropy_with_logits(
                        qk_logits, qk_target)
                    batch_loss = batch_loss + weights['qual_key'] * loss_qk
                # ── Task 4: Qualifier value prediction (cross-entropy) ─────
                if rep_quals:
                    tgt_qk, tgt_qv = rep_quals[0]
                    qv_logits = model.forward_qv_task(
                        graph, h, r, rep_quals, tgt_qk, dev)
                    loss_qv = F.cross_entropy(
                        qv_logits.unsqueeze(0),
                        torch.tensor([tgt_qv], device=dev))
                    batch_loss = batch_loss + weights['qual_value'] * loss_qv
                opt.zero_grad(); batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step(); total += batch_loss.item(); cnt += 1
            except Exception: continue  # best-effort: skip failing groups
        if (epoch+1) % 5 == 0:
            m = evaluate_hr_nbfnet(model, _vds, graph, dev, p_override=_p)
            print(f" Epoch {epoch+1}: loss={total/max(cnt,1):.4f} "
                  f"MRR={m['mrr']:.4f} MR={m['mr']:.1f} "
                  f"H@1={m['hits@1']:.4f} H@3={m['hits@3']:.4f} "
                  f"H@10={m['hits@10']:.4f}")
            if m['mrr'] > best_mrr:
                best_mrr = m['mrr']
                torch.save(model.state_dict(),
                           f"{CONFIG['output_path']}{model_name}_best.pt")
                print(" → best saved")
            history.append(m)
    return model, history
print("MultiTask_HR_NBFNet defined")
# ============================================================================
# 14 — TRAIN ALL 10 MAIN MODELS
# ============================================================================
print("\n" + "="*70)
print("PHASE 1: TRAINING ALL 10 MAIN MODELS")
print("="*70)
# ── 1. StarE ─────────────────────────────────────────────────────────────
stare_model = StarEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
stare_model, _, _ = train_standard(stare_model, "StarE")
all_results['StarE'] = evaluate_model(stare_model, test_ds, CONFIG['device'])
# ── 2. ShrinkE ───────────────────────────────────────────────────────────
shrinke_model = ShrinkEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
shrinke_model, _, _ = train_standard(shrinke_model, "ShrinkE")
all_results['ShrinkE'] = evaluate_model(shrinke_model, test_ds, CONFIG['device'])
# ── 3. TrueNBFNet ────────────────────────────────────────────────────────
nbfnet_graph = build_nbfnet_graph(train_data, preprocessor, CONFIG['device'])
nbfnet_model = TrueNBFNet(NE, NR, dim=DIM, num_layers=CONFIG['nbfnet_layers'],
chunk_size=CONFIG['nbfnet_chunk_size'],
dropout=CONFIG['dropout'], short_cut=True)
nbfnet_model, _ = train_true_nbfnet(nbfnet_model, "TrueNBFNet")
nbfnet_graph_dev = {k: v.to(CONFIG['device']) if torch.is_tensor(v) else v
for k,v in nbfnet_graph.items()}
all_results['TrueNBFNet'] = evaluate_nbfnet(
nbfnet_model, test_ds, nbfnet_graph_dev, CONFIG['device'])
# ── 4. AlertStar ─────────────────────────────────────────────────────────
alertstar_model = AlertStarModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
alertstar_model, _, gate_history = train_standard(
alertstar_model, "AlertStar", gate_track=True)
# Score the (already-trained) AlertStar model on the held-out test split.
all_results['AlertStar'] = evaluate_model(alertstar_model, test_ds, CONFIG['device'])
# ── 5. StarQE ────────────────────────────────────────────────────────────
# Complex-query competitor; trained with its dedicated routine.
starqe_model = StarQEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
starqe_model, _ = train_starqe(starqe_model, "StarQE")
all_results['StarQE'] = evaluate_model(starqe_model, test_ds, CONFIG['device'])
# ── 6. NBFNet+StarQE ──────────────────────────────────────────────────────
# Residual path-augmented variant; reuses the StarQE training routine.
nbfstarqe_model = NBFNetStarQEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'])
nbfstarqe_model, _ = train_starqe(nbfstarqe_model, "NBFNet_StarQE")
all_results['NBFNet_StarQE'] = evaluate_model(nbfstarqe_model, test_ds, CONFIG['device'])
# ── 7. HyNT ──────────────────────────────────────────────────────────────
# Transformer competitor; attention-head / layer counts come from CONFIG.
hynt_model = HyNTModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'],
                       n_heads=CONFIG['hynt_n_heads'],
                       n_layers=CONFIG['hynt_n_layers'])
hynt_model, _ = train_hynt(hynt_model, "HyNT")
all_results['HyNT'] = evaluate_model(hynt_model, test_ds, CONFIG['device'])
# ── 8. MultiTask AlertStar ────────────────────────────────────────────────
# Main proposed model: AlertStar backbone + 4-task auxiliary supervision.
mt_model = MultiTaskAlertStar(NE, NR, NQK, NQV, DIM, dropout=CONFIG['dropout'])
mt_model, _ = train_multitask(mt_model, "MultiTask_AlertStar")
all_results['MultiTask_AlertStar'] = evaluate_model(mt_model, test_ds, CONFIG['device'])
# ── 9. HR-NBFNet ─────────────────────────────────────────────────────────
# HR-NBFNet needs a qualifier-aware edge structure built from the train split.
print("\nBuilding HR-NBFNet qualifier-aware graph...")
hr_nbfnet_graph = build_hr_nbfnet_graph(
    train_data, preprocessor, CONFIG['device'],
    max_quals=CONFIG['hr_nbfnet_max_quals'])
print(f"HR graph: {hr_nbfnet_graph['edge_src'].size(0):,} edges")
hr_model = HRNBFNet(NE, NR, NQK, NQV, DIM,
                    num_layers=CONFIG['hr_nbfnet_layers'],
                    chunk_size=CONFIG['hr_nbfnet_chunk_size'],
                    dropout=CONFIG['dropout'],
                    max_quals=CONFIG['hr_nbfnet_max_quals'])
hr_model, _ = train_hr_nbfnet(hr_model, "HR_NBFNet")
# Move every tensor in the graph dict onto the eval device once; the dict is
# reused below for MultiTask_HR_NBFNet evaluation as well.
hr_graph_dev = {k: v.to(CONFIG['device']) if torch.is_tensor(v) else v
                for k,v in hr_nbfnet_graph.items()}
all_results['HR_NBFNet'] = evaluate_hr_nbfnet(
    hr_model, test_ds, hr_graph_dev, CONFIG['device'])
# ── 10. MultiTask_HR_NBFNet (NEW) ─────────────────────────────────────────
# Reuses hr_nbfnet_graph — same qualifier-aware edge structure
mt_hr_model = MultiTaskHRNBFNet(NE, NR, NQK, NQV, DIM,
                                num_layers=CONFIG['hr_nbfnet_layers'],
                                chunk_size=CONFIG['hr_nbfnet_chunk_size'],
                                dropout=CONFIG['dropout'],
                                max_quals=CONFIG['hr_nbfnet_max_quals'])
mt_hr_model, _ = train_mt_hr_nbfnet(mt_hr_model, "MultiTask_HR_NBFNet")
all_results['MultiTask_HR_NBFNet'] = evaluate_hr_nbfnet(
    mt_hr_model, test_ds, hr_graph_dev, CONFIG['device'])
# ── Summary ───────────────────────────────────────────────────────────────
# Print every model's test metrics; the two newest models get a marker.
print("\n All 10 models trained")
print("\nMAIN RESULTS:")
for m, r in all_results.items():
    tag = " ← NEW" if m in ('HR_NBFNet','MultiTask_HR_NBFNet') else ""
    print(f" {m:30s} MRR={r['mrr']:.4f} MR={r['mr']:7.1f} "
          f"H@1={r['hits@1']:.4f} H@3={r['hits@3']:.4f} "
          f"H@10={r['hits@10']:.4f}{tag}")
# ============================================================================
# 15 — ABLATION A1: AlertStar Component Ablation
# ============================================================================
print("\n" + "="*70)
print("ABLATION A1: AlertStar Component Ablation")
print("="*70)
# Each variant disables one AlertStar component; fixed_gate=0.5 freezes the
# otherwise-learned attention/path mixing gate at an even split.
ablation_configs = [
    ("AS-NoQual", dict(use_qual=False, use_path=True, fixed_gate=None)),
    ("AS-NoPath", dict(use_qual=True, use_path=False, fixed_gate=None)),
    ("AS-NoGate", dict(use_qual=True, use_path=True, fixed_gate=0.5)),
    ("AS-Full", dict(use_qual=True, use_path=True, fixed_gate=None)),
]
ablation_A1 = {}
for name, cfg in ablation_configs:
    # Train each ablated variant from scratch with the standard routine.
    m = AlertStarModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout'], **cfg)
    m, _, _ = train_standard(m, f"Ablation_{name}")
    ablation_A1[name] = evaluate_model(m, test_ds, CONFIG['device'])
    r = ablation_A1[name]
    print(f" {name}: MRR={r['mrr']:.4f} H@1={r['hits@1']:.4f} H@10={r['hits@10']:.4f}")
ablation_results['A1_AlertStar_Components'] = ablation_A1
print("Ablation A1 complete")
# ============================================================================
# 16 — ABLATION A2: Gate Value Trajectory
# ============================================================================
print("\n" + "="*70)
print("ABLATION A2: Gate Value Analysis")
print("="*70)
# gate_history was recorded while training AlertStar (gate_track=True above).
# Entries presumably hold {'epoch': ..., 'gate': ...} — confirm against the
# training routine; the gate mixes the attention and path streams.
if gate_history:
    gate_df = pd.DataFrame(gate_history)
    print(gate_df.to_string(index=False))
    final_gate = gate_history[-1]['gate']
    print(f"\n Final gate g = {final_gate:.4f}")
    # Interpret the converged gate: >0.6 attention-heavy, <0.4 path-heavy.
    if final_gate > 0.6: print(" → Attention stream dominates (qualifier-awareness)")
    elif final_gate < 0.4: print(" → Path stream dominates (structural reasoning)")
    else: print(" → Balanced — both streams equally useful")
    ablation_results['A2_Gate_Values'] = gate_history
else:
    print(" No gate history available")
print("Ablation A2 complete")
# ============================================================================
# 17 — ABLATION A3: MultiTask Auxiliary Task Ablation
# ============================================================================
print("\n" + "="*70)
print("ABLATION A3: MultiTask AlertStar — Task Contribution")
print("="*70)
# Each variant trains with a subset of the four heads; 'tail' link prediction
# is always kept because it is the task that gets evaluated.
mt_ablation_configs = [
    ("MT-TailOnly", ['tail']),
    ("MT-Tail+Rel", ['tail','relation']),
    ("MT-Tail+QualKey", ['tail','qual_key']),
    ("MT-Tail+QualVal", ['tail','qual_value']),
    ("MT-Full", ['tail','relation','qual_key','qual_value']),
]
ablation_A3 = {}
for name, tasks in mt_ablation_configs:
    # Fresh model per variant; active_tasks restricts the training losses.
    m = MultiTaskAlertStar(NE, NR, NQK, NQV, DIM, dropout=CONFIG['dropout'])
    m, _ = train_multitask(m, f"Ablation_{name}", active_tasks=tasks)
    ablation_A3[name] = evaluate_model(m, test_ds, CONFIG['device'])
    r = ablation_A3[name]
    print(f" {name}: MRR={r['mrr']:.4f} H@1={r['hits@1']:.4f} H@10={r['hits@10']:.4f}")
ablation_results['A3_MultiTask_Tasks'] = ablation_A3
print("Ablation A3 complete")
# ============================================================================
# 18 — ABLATION A4: Qualifier Density (Q33 / Q66 / Q100)
#
# Models compared: StarE, AlertStar, HyNT, MultiTask-AS,
#                  HR-NBFNet, MultiTask_HR_NBFNet
# Q100 is the main dataset, so its already-trained models are reused; Q33 and
# Q66 get a fresh retrain of all 6 models on their density-specific splits.
# (The loop below keys reuse on CONFIG['data_path'], which is the Q100 path.)
# ============================================================================
print("\n" + "="*70)
print("ABLATION A4: Qualifier Density Sensitivity")
print("="*70)
# Label → dataset path. dict() keeps insertion order, so Q100 is handled
# first (the cheap reuse case), then the two retrained densities.
DENSITY_PATHS = dict(
    Q100=CONFIG['data_path'],
    Q33=CONFIG['q33_path'],
    Q66=CONFIG['q66_path'],
)
# The six models tracked throughout the density study (and in panel P8).
DENSITY_MODELS = [
    'StarE', 'AlertStar', 'HyNT',
    'MultiTask_AlertStar', 'HR_NBFNet', 'MultiTask_HR_NBFNet',
]
# A4 driver: for each density, either reuse the main-run results (Q100) or
# retrain all six models on that density's splits. Results accumulate into
# ablation_A4[density_label][model_name] -> metrics dict.
ablation_A4 = {}
for density_label, dpath in DENSITY_PATHS.items():
    print(f"\n{'='*50} {density_label} {'='*50}")
    # The main dataset (Q100) was already trained above — reuse those results.
    if dpath == CONFIG['data_path']:
        ablation_A4[density_label] = {
            m: all_results[m] for m in DENSITY_MODELS if m in all_results}
        print(f" Reusing trained models for {density_label}")
        continue
    # Load density-specific dataset (fresh vocab sizes per density).
    p2 = DataPreprocessor()
    tr2, va2, te2 = p2.load(dpath)
    ne2 = len(p2.entity2id); nr2 = len(p2.relation2id)
    nqk2 = len(p2.qualifier_key2id); nqv2 = len(p2.qualifier_value2id)
    tr_ds2 = HRDataset(tr2, p2); va_ds2 = HRDataset(va2, p2); te_ds2 = HRDataset(te2, p2)
    density_res = {}
    # StarE
    m = StarEModel(ne2, nr2, nqk2, nqv2, DIM, CONFIG['dropout'])
    m, _, _ = train_standard(m, f"A4_{density_label}_StarE",
                             train_ds_=tr_ds2, valid_ds_=va_ds2, ne_override=ne2)
    density_res['StarE'] = evaluate_model(m, te_ds2, CONFIG['device'])
    # AlertStar
    m = AlertStarModel(ne2, nr2, nqk2, nqv2, DIM, CONFIG['dropout'])
    m, _, _ = train_standard(m, f"A4_{density_label}_AlertStar",
                             train_ds_=tr_ds2, valid_ds_=va_ds2, ne_override=ne2)
    density_res['AlertStar'] = evaluate_model(m, te_ds2, CONFIG['device'])
    # HyNT — build its own HyNTDataset with p2, trained inline with a
    # per-sample 2-task loss (tail + qualifier-value, qv weighted 0.8).
    _ht2 = HyNTDataset(tr2, p2)
    m = HyNTModel(ne2, nr2, nqk2, nqv2, DIM, CONFIG['dropout'],
                  n_heads=4, n_layers=2)
    m.to(CONFIG['device'])
    _opt2 = torch.optim.Adam(m.parameters(), lr=CONFIG['hynt_lr'])
    _loader2 = DataLoader(_ht2, batch_size=CONFIG['hynt_batch'],
                          shuffle=True, collate_fn=mt_collate)
    _best2 = 0.0
    for _ep in range(CONFIG['hynt_epochs']):
        m.train()
        for _bt in tqdm(_loader2, desc=f"A4 HyNT {density_label} ep{_ep+1}", leave=False):
            _bl = torch.tensor(0., device=CONFIG['device'])
            for _task, _samps in _bt.items():
                for _s in _samps:
                    # best-effort: malformed samples are skipped, not fatal
                    try:
                        if _task == 'tail':
                            _lg = m.forward_tail(_s, CONFIG['device'])
                            _tg = torch.tensor(_s['t'], device=CONFIG['device'])
                            _bl = _bl + F.cross_entropy(_lg.unsqueeze(0), _tg.unsqueeze(0))
                        elif _task == 'qv':
                            _lg = m.forward_qv(_s, CONFIG['device'])
                            _tg = torch.tensor(_s['qv'], device=CONFIG['device'])
                            _bl = _bl + 0.8*F.cross_entropy(_lg.unsqueeze(0), _tg.unsqueeze(0))
                    except Exception: continue
            # FIX: if every sample in the batch was skipped, _bl is still the
            # bare 0-tensor (requires_grad=False) and .backward() would raise
            # RuntimeError — skip the optimizer step for such batches.
            if not _bl.requires_grad:
                continue
            _opt2.zero_grad(); _bl.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0); _opt2.step()
        # Checkpoint on validation MRR every 5 epochs.
        if (_ep+1) % 5 == 0:
            _mv = evaluate_model(m, va_ds2, CONFIG['device'])
            if _mv['mrr'] > _best2:
                _best2 = _mv['mrr']
                torch.save(m.state_dict(),
                           f"{CONFIG['output_path']}A4_{density_label}_HyNT_best.pt")
    density_res['HyNT'] = evaluate_model(m, te_ds2, CONFIG['device'])
    # MultiTask-AS — inline 4-task training with per-task loss weights.
    m = MultiTaskAlertStar(ne2, nr2, nqk2, nqv2, DIM, dropout=CONFIG['dropout'])
    _mt2_ds = MultiTaskDataset(tr2, p2, nqk_override=nqk2)
    _mt2_ldr = DataLoader(_mt2_ds, batch_size=CONFIG['mt_batch'],
                          shuffle=True, collate_fn=mt_collate)
    _wts2 = {'tail':1.0,'relation':1.0,'qual_key':0.5,'qual_value':0.8}
    m.to(CONFIG['device']); _mo2 = torch.optim.Adam(m.parameters(), lr=CONFIG['mt_lr'])
    _bm2 = 0.0
    for _ep in range(CONFIG['mt_epochs']):
        m.train()
        for _bt in tqdm(_mt2_ldr, desc=f"A4 MT {density_label} ep{_ep+1}", leave=False):
            _bl = torch.tensor(0., device=CONFIG['device'])
            for _task, _samps in _bt.items():
                for _s in _samps:
                    try:
                        _lg = m.forward_task(_task, _s, CONFIG['device'])
                        _w = _wts2.get(_task, 1.0)
                        if _task == 'tail':
                            _tg = torch.tensor(_s['t'], device=CONFIG['device'])
                            _ls = F.cross_entropy(_lg.unsqueeze(0), _tg.unsqueeze(0))
                        elif _task == 'relation':
                            _tg = torch.tensor(_s['r'], device=CONFIG['device'])
                            _ls = F.cross_entropy(_lg.unsqueeze(0), _tg.unsqueeze(0))
                        elif _task == 'qual_key':
                            # multi-label head: multi-hot target over all keys
                            _tg = torch.zeros(nqk2, device=CONFIG['device'])
                            _tg[_s['keys']] = 1.0
                            _ls = F.binary_cross_entropy_with_logits(_lg, _tg)
                        elif _task == 'qual_value':
                            _tg = torch.tensor(_s['qv'], device=CONFIG['device'])
                            _ls = F.cross_entropy(_lg.unsqueeze(0), _tg.unsqueeze(0))
                        else: continue
                        _bl = _bl + _w*_ls
                    except Exception: continue
            # FIX (same as HyNT loop above): skip all-skipped batches so
            # .backward() never runs on a grad-less zero tensor.
            if not _bl.requires_grad:
                continue
            _mo2.zero_grad(); _bl.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0); _mo2.step()
        if (_ep+1) % 5 == 0:
            _mv = evaluate_model(m, va_ds2, CONFIG['device'])
            if _mv['mrr'] > _bm2:
                _bm2 = _mv['mrr']
                torch.save(m.state_dict(),
                           f"{CONFIG['output_path']}A4_{density_label}_MT_best.pt")
    density_res['MultiTask_AlertStar'] = evaluate_model(m, te_ds2, CONFIG['device'])
    # HR-NBFNet for this density — needs its own qualifier-aware graph.
    _hr_g2 = build_hr_nbfnet_graph(tr2, p2, CONFIG['device'],
                                   max_quals=CONFIG['hr_nbfnet_max_quals'],
                                   p_override=p2)
    m = HRNBFNet(ne2, nr2, nqk2, nqv2, DIM,
                 num_layers=CONFIG['hr_nbfnet_layers'],
                 chunk_size=CONFIG['hr_nbfnet_chunk_size'],
                 dropout=CONFIG['dropout'], max_quals=CONFIG['hr_nbfnet_max_quals'])
    m, _ = train_hr_nbfnet(m, f"A4_{density_label}_HR_NBFNet",
                           train_data_=tr2, valid_ds_=va_ds2,
                           graph_=_hr_g2, p_override=p2)
    # Device copy of the graph, reused by MultiTask_HR_NBFNet below.
    _hr_g2_dev = {k: v.to(CONFIG['device']) if torch.is_tensor(v) else v
                  for k,v in _hr_g2.items()}
    density_res['HR_NBFNet'] = evaluate_hr_nbfnet(
        m, te_ds2, _hr_g2_dev, CONFIG['device'], p_override=p2)
    # MultiTask_HR_NBFNet for this density — same graph, multi-task training.
    m = MultiTaskHRNBFNet(ne2, nr2, nqk2, nqv2, DIM,
                          num_layers=CONFIG['hr_nbfnet_layers'],
                          chunk_size=CONFIG['hr_nbfnet_chunk_size'],
                          dropout=CONFIG['dropout'],
                          max_quals=CONFIG['hr_nbfnet_max_quals'])
    m, _ = train_mt_hr_nbfnet(m, f"A4_{density_label}_MultiTask_HR_NBFNet",
                              train_data_=tr2, valid_ds_=va_ds2,
                              graph_=_hr_g2, p_override=p2,
                              nqk_override=nqk2, nqv_override=nqv2)
    density_res['MultiTask_HR_NBFNet'] = evaluate_hr_nbfnet(
        m, te_ds2, _hr_g2_dev, CONFIG['device'], p_override=p2)
    ablation_A4[density_label] = density_res
    for mname, metrics in density_res.items():
        print(f" {density_label} {mname:30s} MRR={metrics['mrr']:.4f} "
              f"H@1={metrics['hits@1']:.4f} H@10={metrics['hits@10']:.4f}")
ablation_results['A4_Qualifier_Density'] = ablation_A4
print("\n Ablation A4 complete")
# ============================================================================
# 19 — COMPLETE RESULTS TABLES
# ============================================================================
# Table 1 — every model ranked by test MRR (best first).
print("\n" + "="*80)
print("TABLE 1: COMPLETE 10-MODEL COMPARISON")
print("="*80)
_rows = [
    {'Model': name,
     'MR': res['mr'],
     'MRR': res['mrr'],
     'Hits@1': res['hits@1'],
     'Hits@3': res['hits@3'],
     'Hits@10': res['hits@10']}
    for name, res in all_results.items()
]
df_main = (pd.DataFrame(_rows)
           .sort_values('MRR', ascending=False)
           .reset_index(drop=True))
print(df_main.to_string(index=False))
# Table 2 — qualifier-unaware NBFNet vs the two qualifier-aware variants.
print("\n" + "="*80)
print("TABLE 2: TrueNBFNet vs HR-NBFNet vs MultiTask_HR_NBFNet")
print("="*80)
for mname in ['TrueNBFNet','HR_NBFNet','MultiTask_HR_NBFNet']:
    # .get defaults keep the table printable even if a model was skipped.
    r = all_results.get(mname, {})
    print(f" {mname:30s} MRR={r.get('mrr',0):.4f} MR={r.get('mr',0):7.1f} "
          f"H@1={r.get('hits@1',0):.4f} H@3={r.get('hits@3',0):.4f} "
          f"H@10={r.get('hits@10',0):.4f}")
# Table 3 — the two complex-query models.
print("\n" + "="*80)
print("TABLE 3: StarQE family comparison")
print("="*80)
for mname in ['StarQE','NBFNet_StarQE']:
    r = all_results.get(mname, {})
    print(f" {mname:30s} MRR={r.get('mrr',0):.4f} MR={r.get('mr',0):7.1f} "
          f"H@1={r.get('hits@1',0):.4f} H@10={r.get('hits@10',0):.4f}")
# Table 4 — A1 component ablation.
# FIX: the Qual?/Path?/Gate? columns were hard-coded positionally and would
# silently misreport if ablation_configs ever changed; derive them from the
# config dicts instead (output identical for the current four variants).
df_A1 = pd.DataFrame({
    'Variant': list(ablation_A1.keys()),
    'Qual?': ['✓' if cfg['use_qual'] else '✗' for _, cfg in ablation_configs],
    'Path?': ['✓' if cfg['use_path'] else '✗' for _, cfg in ablation_configs],
    'Gate?': ['learned' if cfg['fixed_gate'] is None
              else f"fixed={cfg['fixed_gate']}"
              for _, cfg in ablation_configs],
    'MRR': [ablation_A1[v]['mrr'] for v in ablation_A1],
    'Hits@1': [ablation_A1[v]['hits@1'] for v in ablation_A1],
    'Hits@3': [ablation_A1[v]['hits@3'] for v in ablation_A1],
    'Hits@10': [ablation_A1[v]['hits@10'] for v in ablation_A1],
    'MR': [ablation_A1[v]['mr'] for v in ablation_A1],
})
print("\n" + "="*80)
print("TABLE 4: ABLATION A1 — AlertStar Components")
print("="*80)
print(df_A1.to_string(index=False))
# Table 5 — raw gate trajectory (only if AlertStar tracked its gate).
print("\n" + "="*80)
print("TABLE 5: ABLATION A2 — Gate Trajectory")
print("="*80)
if gate_history: print(pd.DataFrame(gate_history).to_string(index=False))
# Table 6 — A3: which auxiliary heads were active per variant.
# NOTE: the 'Tasks' column relies on mt_ablation_configs and ablation_A3
# sharing insertion order — both are filled by the same A3 loop, so they do.
df_A3 = pd.DataFrame({
    'Variant': list(ablation_A3.keys()),
    'Tasks': [str(t[1]) for t in mt_ablation_configs],
    'MRR': [ablation_A3[v]['mrr'] for v in ablation_A3],
    'Hits@1': [ablation_A3[v]['hits@1'] for v in ablation_A3],
    'Hits@10': [ablation_A3[v]['hits@10'] for v in ablation_A3],
    'MR': [ablation_A3[v]['mr'] for v in ablation_A3],
})
print("\n" + "="*80)
print("TABLE 6: ABLATION A3 — MT Task Contribution")
print("="*80)
print(df_A3.to_string(index=False))
# Table 7 — A4: per-density results, grouped by density label.
print("\n" + "="*80)
print("TABLE 7: ABLATION A4 — Qualifier Density")
print("="*80)
for density, res in ablation_A4.items():
    print(f"\n {density}:")
    for mname, metrics in res.items():
        print(f" {mname:30s} MRR={metrics['mrr']:.4f} "
              f"H@1={metrics['hits@1']:.4f} H@10={metrics['hits@10']:.4f}")
# ============================================================================
# 20 — VISUALIZATIONS
# ============================================================================
# Shared setup for the 3×3 figure: P1–P4 main results, P5–P8 ablations.
n_models = len(df_main)
palette_main = sns.color_palette("tab10", n_models)
# Metric keys as stored in the result dicts, and their display labels
# (kept index-aligned; both are reused by panels P3/P4 below).
met_list = ['mrr','hits@1','hits@3','hits@10']
lbl_list = ['MRR','H@1','H@3','H@10']
fig = plt.figure(figsize=(30, 24))
fig.suptitle("AlertStar — 10-Model Complete Results & Ablation Studies",
             fontsize=18, fontweight='bold', y=0.98)
gs = fig.add_gridspec(3, 3, hspace=0.5, wspace=0.35)
# ── P1: Main MRR ─────────────────────────────────────────────────────────
# Bar chart of test MRR for every model, each bar annotated with its value;
# selected competitor / new models are flagged with a coloured outline.
ax1 = fig.add_subplot(gs[0, 0])
mrr_bars = ax1.bar(df_main['Model'], df_main['MRR'], color=palette_main)
ax1.set_title("Main Results — MRR (all 10 models)", fontweight='bold')
ax1.set_ylabel("MRR")
ax1.tick_params(axis='x', rotation=55, labelsize=6)
for rect, score in zip(mrr_bars, df_main['MRR']):
    ax1.text(rect.get_x() + rect.get_width() / 2, score + 0.002,
             f'{score:.3f}', ha='center', fontsize=5, fontweight='bold')
# Edge colours marking the bars of special interest.
highlight = {'HR_NBFNet':'goldenrod', 'MultiTask_HR_NBFNet':'crimson',
             'HyNT':'red', 'StarQE':'steelblue', 'NBFNet_StarQE':'navy'}
for rect, mname in zip(mrr_bars, df_main['Model']):
    edge = highlight.get(mname)
    if edge is not None:
        rect.set_edgecolor(edge)
        rect.set_linewidth(2.5)
# ── P2: Hits@k ────────────────────────────────────────────────────────────
# Grouped bars: Hits@1 / Hits@3 / Hits@10 side by side for each model.
ax2 = fig.add_subplot(gs[0, 1])
x2 = np.arange(n_models)
w = 0.25
for shift, column, label, colour in ((-w, 'Hits@1', 'H@1', 'steelblue'),
                                     (0.0, 'Hits@3', 'H@3', 'orange'),
                                     (w, 'Hits@10', 'H@10', 'green')):
    ax2.bar(x2 + shift, df_main[column], w, label=label, color=colour)
ax2.set_xticks(x2)
ax2.set_xticklabels(df_main['Model'], rotation=55, ha='right', fontsize=6)
ax2.set_title("Main Results — Hits@k", fontweight='bold')
ax2.legend(fontsize=8)
# ── P3: NBFNet family (TrueNBFNet / HR-NBFNet / MT-HR-NBFNet) ────────────
# All four metrics for the three NBFNet-style models, grouped per metric.
ax3 = fig.add_subplot(gs[0, 2])
nbf_mods = ['TrueNBFNet','HR_NBFNet','MultiTask_HR_NBFNet']
# Keep only the models that actually produced results.
nbf_res = {m: all_results[m] for m in nbf_mods if m in all_results}
x3 = np.arange(len(lbl_list)); w3 = 0.25
pal3 = ['steelblue','goldenrod','crimson']
for i, (mname, res) in enumerate(nbf_res.items()):
    vals = [res.get(mk,0) for mk in met_list]
    # Center each model's bar within the group around the metric tick.
    offset = (i - len(nbf_res)/2 + 0.5)*w3
    ax3.bar(x3+offset, vals, w3, label=mname, color=pal3[i])
ax3.set_xticks(x3); ax3.set_xticklabels(lbl_list)
ax3.set_title("NBFNet Family Comparison", fontweight='bold')
ax3.legend(fontsize=8); ax3.set_ylabel("Score")
# ── P4: StarQE family ─────────────────────────────────────────────────────
# Same grouped-by-metric layout as P3, for the StarE/StarQE lineage.
ax4 = fig.add_subplot(gs[1, 0])
qe_mods = ['StarE','StarQE','NBFNet_StarQE','AlertStar']
qe_res = {m: all_results[m] for m in qe_mods if m in all_results}
x4 = np.arange(len(lbl_list)); w4 = 0.2
pal4 = sns.color_palette("Set2", len(qe_res))
for i, (mname, res) in enumerate(qe_res.items()):
    vals = [res.get(mk,0) for mk in met_list]
    # Center the group of bars around each metric tick.
    offset = (i - len(qe_res)/2 + 0.5)*w4
    ax4.bar(x4+offset, vals, w4, label=mname, color=pal4[i])
ax4.set_xticks(x4); ax4.set_xticklabels(lbl_list)
ax4.set_title("StarQE Family vs Baselines", fontweight='bold')
ax4.legend(fontsize=8)
# ── P5: Ablation A1 ───────────────────────────────────────────────────────
# MRR / H@1 / H@10 per AlertStar variant; each variant's MRR is annotated
# with its drop relative to the full model.
ax5 = fig.add_subplot(gs[1, 1])
x5 = np.arange(len(df_A1)); w5 = 0.25  # w5 is also reused by panel P7 below
ax5.bar(x5-w5, df_A1['MRR'], w5, label='MRR', color='steelblue')
ax5.bar(x5, df_A1['Hits@1'], w5, label='H@1', color='orange')
ax5.bar(x5+w5, df_A1['Hits@10'], w5, label='H@10', color='green')
ax5.set_xticks(x5)
ax5.set_xticklabels(df_A1['Variant'], rotation=20, ha='right', fontsize=9)
ax5.set_title("A1: AlertStar Components", fontweight='bold')
ax5.legend(fontsize=8)
# Δ annotations: drop vs the full model, placed above each MRR bar (x=xi-w5);
# deltas below 0.001 in magnitude are considered noise and skipped.
full_mrr = df_A1[df_A1['Variant']=='AS-Full']['MRR'].values[0]
for xi, (_, row) in zip(x5, df_A1.iterrows()):
    drop = full_mrr - row['MRR']
    if abs(drop) > 0.001:
        ax5.text(xi-w5, row['MRR']+0.003,
                 f'Δ{drop:+.3f}', ha='center', fontsize=7, color='red')
# ── P6: Gate Trajectory ───────────────────────────────────────────────────
# Line plot of the learned gate value over training epochs, with the 0.5
# "balanced" reference line and the final value annotated.
ax6 = fig.add_subplot(gs[1, 2])
if gate_history:
    epochs = [g['epoch'] for g in gate_history]
    gates = [g['gate'] for g in gate_history]
    ax6.plot(epochs, gates, 'o-', color='purple', lw=2, markersize=6)
    ax6.axhline(0.5, color='gray', ls='--', alpha=0.7, label='g=0.5 (balanced)')
    # Shade between the trajectory and the balanced line; blue when the
    # final gate favours attention (>0.5), orange otherwise.
    ax6.fill_between(epochs, gates, 0.5, alpha=0.12,
                     color='blue' if gates[-1]>0.5 else 'orange')
    ax6.set_ylim(0,1.05)
    ax6.set_xlabel("Epoch"); ax6.set_ylabel("g = σ(θ)")
    ax6.set_title("A2: AlertStar Gate Trajectory", fontweight='bold')
    ax6.legend(fontsize=8)
    ax6.annotate(f"Final: {gates[-1]:.3f}",
                 xy=(epochs[-1],gates[-1]),
                 xytext=(epochs[-1]-3, gates[-1]+0.08),
                 arrowprops=dict(arrowstyle='->',color='purple'),
                 fontsize=9, color='purple')
else:
    # Placeholder panel when AlertStar was trained without gate tracking.
    ax6.text(0.5,0.5,'No gate data',ha='center',va='center',transform=ax6.transAxes)
    ax6.set_title("A2: Gate Trajectory",fontweight='bold')
# ── P7: Ablation A3 ───────────────────────────────────────────────────────
# MRR / H@1 / H@10 per multi-task variant.
# NOTE: reuses the bar width w5 defined in the P5 panel above.
ax7 = fig.add_subplot(gs[2, 0])
x7 = np.arange(len(df_A3))
ax7.bar(x7-w5, df_A3['MRR'], w5, label='MRR', color='steelblue')
ax7.bar(x7, df_A3['Hits@1'], w5, label='H@1', color='orange')
ax7.bar(x7+w5, df_A3['Hits@10'], w5, label='H@10', color='green')
ax7.set_xticks(x7)
ax7.set_xticklabels(df_A3['Variant'], rotation=25, ha='right', fontsize=8)
ax7.set_title("A3: MT Task Contribution", fontweight='bold')
ax7.legend(fontsize=8)
# ── P8: Ablation A4 — Density (MRR) ──────────────────────────────────────
# Wide panel spanning two grid cells: MRR vs qualifier density for the six
# models tracked in the A4 study.
ax8 = fig.add_subplot(gs[2, 1:])
densities = list(ablation_A4.keys())
n_dm = len(DENSITY_MODELS)
x8 = np.arange(len(densities))
w8 = 0.12
pal8 = sns.color_palette("tab10", n_dm)
for mi, mname in enumerate(DENSITY_MODELS):
    # Missing model/density combinations render as zero-height bars.
    mrrs = [ablation_A4[d].get(mname,{}).get('mrr',0) for d in densities]
    offset = (mi - n_dm/2 + 0.5)*w8
    bars8 = ax8.bar(x8+offset, mrrs, w8, label=mname, color=pal8[mi])
    # bold border for new methods
    if mname in ('HR_NBFNet','MultiTask_HR_NBFNet'):
        for b in bars8: b.set_edgecolor('black'); b.set_linewidth(1.5)
ax8.set_xticks(x8); ax8.set_xticklabels(densities)
ax8.set_xlabel("Qualifier Density"); ax8.set_ylabel("MRR")
ax8.set_title("A4: MRR vs Qualifier Density (all 6 models)", fontweight='bold')
ax8.legend(fontsize=7, loc='upper left', ncol=2)
# Write the full figure to disk, then display it.
plt.savefig(f"{CONFIG['output_path']}complete_10model_results.png",
            dpi=300, bbox_inches='tight')
plt.show()
print("Visualization saved")
# ============================================================================
# 21 — SAVE ALL RESULTS
# ============================================================================
# Persist every result dict to a single JSON file.
# NOTE(review): json.dump will raise if any metric is a numpy/tensor scalar
# rather than a plain Python float — confirm evaluate_model's return types.
with open(f"{CONFIG['output_path']}all_results_10models.json", 'w') as f:
    json.dump({'main_results': all_results,
               'ablation_results': ablation_results,
               'gate_history': gate_history}, f, indent=2)
print(f"\n Saved to {CONFIG['output_path']}")
print("\n" + "="*70)
print("COMPLETE — 10 models + 4 ablations ready")
print("="*70)
print(f"\nAll {len(all_results)} models ranked by MRR:")
# Final leaderboard: descending test MRR, tagging new / restored models.
for m, r in sorted(all_results.items(), key=lambda x: -x[1]['mrr']):
    tag = " ← NEW" if m in ('HR_NBFNet','MultiTask_HR_NBFNet') else \
          " ← restored" if m in ('StarQE','NBFNet_StarQE') else ""
    print(f" {m:30s} MRR={r['mrr']:.4f} H@10={r['hits@10']:.4f}{tag}")