# ============================================================================ # ALERTSTAR: COMPLETE END-TO-END NOTEBOOK WITH ALL METRICS # Hyper-Relational Knowledge Graph for Cybersecurity Alert Prediction # # UPDATED: Now includes MR, MRR, Hits@1, Hits@3, Hits@10 for ALL models # including MultiTask predictions # # Models: # 1. StarE — qualifier attention baseline # 2. ShrinkE — shrinking transform baseline # 3. NBFNet — actual Bellman-Ford (memory-safe) # 4. AlertStar — StarE + NBFNet hybrid (ours) # 5. StarQE — complex query answering # 6. NBFNet+StarQE — hybrid complex queries # 7. MultiTask-AS — Transformer multi-task (ours, main) # # Dataset: Alert-33%, 66% and 100% qualifiers-Cybersecurity dataset # ============================================================================ # ============================================================================ # 1 — Installs (uncomment if needed) # ============================================================================ # !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+cu118.html # !pip install tqdm pandas matplotlib seaborn scikit-learn # ============================================================================ # 2 — Imports & Config # ============================================================================ import os, json, warnings warnings.filterwarnings('ignore') import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from collections import defaultdict from tqdm import tqdm import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from torch_scatter import scatter_add, scatter_mean device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") CONFIG = { # paths 'data_path': '/.../inductive_q100_statements', 'output_path': '/.../results_inductive_q100_h1/', # shared 'embedding_dim': 200, 'dropout': 0.2, 'device': device, # standard training (StarE / ShrinkE / AlertStar) 'batch_size': 128, 'learning_rate': 0.0005, 'epochs': 20, # NBFNet specific 'nbfnet_epochs': 20, 'nbfnet_lr': 0.0005, 'nbfnet_layers': 3, 'nbfnet_chunk_size': 10000, # edges per chunk — controls GPU memory 'nbfnet_max_per_group': 8, # tails per (h,r) group per step # complex query / multi-task 'query_epochs': 20, 'query_lr': 0.0005, 'mt_epochs': 20, 'mt_lr': 0.0005, 'mt_batch': 64, } os.makedirs(CONFIG['output_path'], exist_ok=True) all_results = {} # stores final metrics for all models print("Config ready") # ============================================================================ # 3 — Data Loading & Preprocessing # ============================================================================ class DataPreprocessor: def __init__(self): self.entity2id = {} self.relation2id = {} self.qualifier_key2id = {} self.qualifier_value2id = {} def _parse(self, line): parts = line.strip().split(',') # Changed from ',' to '\t' if len(parts) < 3: return None head, relation, tail = parts[0], parts[1], parts[2] # Parse qualifiers if present (format: "key:value | key:value") qualifiers = [] if len(parts) > 3: qual_str = parts[3] for pair in qual_str.split('|'): pair = pair.strip() if ':' in pair: key, value = pair.split(':', 1) qualifiers.append((key.strip(), value.strip())) return {'head': head, 'relation': relation, 'tail': tail, 'qualifiers': qualifiers} def _register(self, triple): for token, vocab in [(triple['head'], self.entity2id), (triple['tail'], self.entity2id), (triple['relation'], self.relation2id)]: if token not in vocab: vocab[token] = len(vocab) for qk, qv in triple['qualifiers']: if qk not in self.qualifier_key2id: self.qualifier_key2id[qk] = len(self.qualifier_key2id) if qv not in self.qualifier_value2id: self.qualifier_value2id[qv] = len(self.qualifier_value2id) def load(self, data_path): files = [('train', 'train.txt'), ('valid', 'validation.txt'), ('test', 'test.txt')] splits = {} for name, fname in files: data = [] with open(os.path.join(data_path, fname)) as f: for line in f: t = self._parse(line) if t: self._register(t) data.append(t) splits[name] = data print(f" {name}: {len(data):,}") print(f" NE={len(self.entity2id)} NR={len(self.relation2id)} " f"NQK={len(self.qualifier_key2id)} NQV={len(self.qualifier_value2id)}") return splits['train'], splits['valid'], splits['test'] class HRDataset(Dataset): def __init__(self, data, preprocessor): self.data = data self.p = preprocessor def __len__(self): return len(self.data) def __getitem__(self, idx): t = self.data[idx] return { 'head': self.p.entity2id[t['head']], 'relation': self.p.relation2id[t['relation']], 'tail': self.p.entity2id[t['tail']], 'qualifiers': [(self.p.qualifier_key2id[qk], self.p.qualifier_value2id[qv]) for qk, qv in t['qualifiers']], } def collate(batch): return { 'head': torch.tensor([b['head'] for b in batch], dtype=torch.long), 'relation': torch.tensor([b['relation'] for b in batch], dtype=torch.long), 'tail': torch.tensor([b['tail'] for b in batch], dtype=torch.long), 'qualifiers': [b['qualifiers'] for b in batch], } print("Loading data...") preprocessor = DataPreprocessor() train_data, valid_data, test_data = preprocessor.load(CONFIG['data_path']) train_ds = HRDataset(train_data, preprocessor) valid_ds = HRDataset(valid_data, preprocessor) test_ds = HRDataset(test_data, preprocessor) NE = len(preprocessor.entity2id) NR = len(preprocessor.relation2id) NQK = len(preprocessor.qualifier_key2id) NQV = len(preprocessor.qualifier_value2id) DIM = CONFIG['embedding_dim'] print("Data ready") # ============================================================================ # 4 — UPDATED: Complete Evaluation Utilities with All Metrics # ============================================================================ def evaluate_model(model, dataset, device, max_samples=500): """ Complete filtered MRR/MR/Hits evaluation for link prediction. Returns ALL metrics: MR, MRR, Hits@1, Hits@3, Hits@10 """ model.eval() ranks = [] with torch.no_grad(): for i in tqdm(range(min(len(dataset), max_samples)), desc="Evaluating", leave=False): s = dataset[i] h = torch.tensor([s['head']], device=device) r = torch.tensor([s['relation']], device=device) t = s['tail'] q = [s['qualifiers']] scores = model(h, r, q).squeeze() # [NE] rank = (torch.argsort(scores, descending=True) == t ).nonzero(as_tuple=True)[0].item() + 1 ranks.append(rank) ranks = np.array(ranks) return { 'mr': float(np.mean(ranks)), 'mrr': float(np.mean(1.0 / ranks)), 'hits@1': float(np.mean(ranks <= 1)), 'hits@3': float(np.mean(ranks <= 3)), 'hits@10': float(np.mean(ranks <= 10)), } def train_standard(model, model_name, lr=None, epochs=None): """Training loop for StarE / ShrinkE / AlertStar (standard dataset).""" dev = CONFIG['device'] lr = lr or CONFIG['learning_rate'] eps = epochs or CONFIG['epochs'] model.to(dev) opt = torch.optim.Adam(model.parameters(), lr=lr) loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'], shuffle=True, collate_fn=collate) print(f"\n{'='*60}") print(f"TRAINING {model_name} " f"({sum(p.numel() for p in model.parameters()):,} params)") print(f"{'='*60}") best_mrr, history = 0.0, [] for epoch in range(eps): model.train() total = 0.0 pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{eps}") for b in pbar: h, r, t, q = (b['head'].to(dev), b['relation'].to(dev), b['tail'].to(dev), b['qualifiers']) pos = model(h, r, q, t) neg = model(h, r, q, torch.randint(0, NE, (len(h),), device=dev)) loss = F.margin_ranking_loss(pos, neg, torch.ones_like(pos), margin=1.0) opt.zero_grad(); loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step() total += loss.item() pbar.set_postfix({'loss': f'{loss.item():.4f}'}) if (epoch + 1) % 5 == 0: m = evaluate_model(model, valid_ds, dev) print(f" Epoch {epoch+1}: loss={total/len(loader):.4f} " f"MRR={m['mrr']:.4f} MR={m['mr']:.1f} " f"H@1={m['hits@1']:.4f} H@3={m['hits@3']:.4f} H@10={m['hits@10']:.4f}") if m['mrr'] > best_mrr: best_mrr = m['mrr'] torch.save(model.state_dict(), f"{CONFIG['output_path']}{model_name}_best.pt") print(" → best saved") history.append(m) return model, history print("Utilities ready") # ============================================================================ # 5 — Model 1: StarE # ============================================================================ class StarEModel(nn.Module): def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1): super().__init__() self.num_entities = ne self.ent = nn.Embedding(ne, dim) self.rel = nn.Embedding(nr, dim) self.qk = nn.Embedding(nqk, dim) self.qv = nn.Embedding(nqv, dim) self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True) self.drop = nn.Dropout(dropout) for e in [self.ent, self.rel, self.qk, self.qv]: nn.init.xavier_uniform_(e.weight) def _enrich(self, r_emb, quals, dev): """Enrich relation with qualifier attention.""" if not quals: return r_emb k = self.qk(torch.tensor([q[0] for q in quals], device=dev)) v = self.qv(torch.tensor([q[1] for q in quals], device=dev)) kv = (k + v).unsqueeze(0) # [1, nq, dim] out, _ = self.attn(r_emb.view(1,1,-1), kv, kv) return out.squeeze() def forward(self, head, relation, qualifiers, tail=None): dev = head.device h = self.ent(head) # [B, dim] r = torch.stack([ self._enrich(self.rel(relation[i:i+1]).squeeze(), qualifiers[i], dev) for i in range(head.size(0))]) # [B, dim] x = self.drop(h + r) if tail is not None: return (x * self.ent(tail)).sum(-1) # [B] return x @ self.ent.weight.t() # [B, NE] print("StarE defined") # ============================================================================ # 6 — Model 2: ShrinkE # ============================================================================ class ShrinkEModel(nn.Module): def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1): super().__init__() self.num_entities = ne self.ent = nn.Embedding(ne, dim) self.rel = nn.Embedding(nr, dim) self.qk = nn.Embedding(nqk, dim) self.qv = nn.Embedding(nqv, dim) self.shrink = nn.Sequential(nn.Linear(dim*2, dim), nn.Tanh(), nn.Dropout(dropout)) self.proj = nn.Linear(dim, dim) self.drop = nn.Dropout(dropout) for e in [self.ent, self.rel, self.qk, self.qv]: nn.init.xavier_uniform_(e.weight) def _shrink(self, r_emb, quals, dev): if not quals: return r_emb k = self.qk(torch.tensor([q[0] for q in quals], device=dev)) v = self.qv(torch.tensor([q[1] for q in quals], device=dev)) qc = (k + v).mean(0, keepdim=True) return self.shrink(torch.cat([r_emb.unsqueeze(0), qc], -1)).squeeze() def forward(self, head, relation, qualifiers, tail=None): dev = head.device h = self.proj(self.ent(head)) r = torch.stack([ self._shrink(self.rel(relation[i:i+1]).squeeze(), qualifiers[i], dev) for i in range(head.size(0))]) x = self.drop(h + self.proj(r)) if tail is not None: return (x * self.ent(tail)).sum(-1) return x @ self.ent.weight.t() print("ShrinkE defined") # ============================================================================ # 7 — Model 3: NBFNet (Bellman-Ford, memory-safe) # ============================================================================ class NBFConvLayer(nn.Module): """One Bellman-Ford layer with chunked edge processing.""" def __init__(self, dim, num_relation, chunk_size=10000, message_func='distmult', layer_norm=True): super().__init__() self.chunk_size = chunk_size self.message_func = message_func self.rel_emb = nn.Embedding(num_relation, dim) self.linear = nn.Linear(dim * 2, dim) self.ln = nn.LayerNorm(dim) if layer_norm else None self.act = nn.ReLU() nn.init.xavier_uniform_(self.rel_emb.weight) def forward(self, graph, node_feat): """node_feat : [N, 1, dim], returns : [N, 1, dim]""" edge_list = graph['edge_list'] N, B, D = node_feat.shape dev = node_feat.device agg = torch.zeros(N, B, D, device=dev) for start in range(0, edge_list.size(0), self.chunk_size): chunk = edge_list[start : start + self.chunk_size] src, dst, rel = chunk[:,0], chunk[:,1], chunk[:,2] sf = node_feat[src] rf = self.rel_emb(rel).unsqueeze(1) if self.message_func == 'distmult': msg = sf * rf elif self.message_func == 'transe': msg = sf + rf else: msg = sf * rf agg.scatter_add_(0, dst.view(-1,1,1).expand_as(msg), msg) del sf, rf, msg out = self.linear(torch.cat([node_feat, agg], dim=-1)) if self.ln: s = out.shape out = self.ln(out.flatten(0,1)).view(s) return self.act(out) class NBFNet(nn.Module): """Neural Bellman-Ford Network.""" def __init__(self, ne, nr, dim=200, num_layers=3, short_cut=True, chunk_size=10000, dropout=0.1): super().__init__() self.num_entities = ne self.dim = dim self.short_cut = short_cut nr2 = nr * 2 self.query_emb = nn.Embedding(nr2, dim) self.layers = nn.ModuleList([ NBFConvLayer(dim, nr2, chunk_size=chunk_size) for _ in range(num_layers) ]) self.mlp = nn.Sequential( nn.Linear(dim*2, dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, 1) ) nn.init.xavier_uniform_(self.query_emb.weight) def _bellman_ford(self, graph, h, r, device): """Single (h, r) BF propagation → [N, feature_dim]""" N = graph['num_nodes'] q = self.query_emb(torch.tensor([r], device=device)) feat = torch.zeros(N, 1, self.dim, device=device) feat[h, 0] = q[0] for layer in self.layers: h_new = layer(graph, feat) if self.short_cut: h_new = h_new + feat feat = h_new node_q = q.unsqueeze(0).expand(N, -1, -1) out = torch.cat([feat, node_q], dim=-1).squeeze(1) return out def forward(self, head, relation, qualifiers=None, tail=None, graph=None): assert graph is not None, "NBFNet requires graph dict" assert (head == head[0]).all() and (relation == relation[0]).all(), \ "All batch items must share the same (head, relation)" dev = head.device feat = self._bellman_ford(graph, head[0].item(), relation[0].item(), dev) if tail is not None: score = self.mlp(feat[tail]).squeeze(-1) else: B = head.size(0) score = self.mlp(feat).squeeze(-1) score = score.unsqueeze(0).expand(B, -1) return score def build_nbfnet_graph(train_data, preprocessor, device): """Build [E,3] edge tensor with inverse edges.""" nr = len(preprocessor.relation2id) rows = [] for t in train_data: h = preprocessor.entity2id[t['head']] r = preprocessor.relation2id[t['relation']] tl = preprocessor.entity2id[t['tail']] rows += [(h, tl, r), (tl, h, r + nr)] return { 'edge_list': torch.tensor(rows, dtype=torch.long, device=device), 'num_nodes': len(preprocessor.entity2id), } def train_true_nbfnet(model, model_name="NBFNet"): """Training loop that groups by (h,r) so BF runs once per group.""" dev = CONFIG['device'] lr = CONFIG['nbfnet_lr'] eps = CONFIG['nbfnet_epochs'] mpg = CONFIG['nbfnet_max_per_group'] model.to(dev) graph = {k: v.to(dev) if torch.is_tensor(v) else v for k, v in nbfnet_graph.items()} opt = torch.optim.Adam(model.parameters(), lr=lr) groups = defaultdict(list) for i, s in enumerate(train_data): h = preprocessor.entity2id[s['head']] r = preprocessor.relation2id[s['relation']] groups[(h, r)].append(i) keys = list(groups.keys()) print(f"\n{'='*60}\nTRAINING {model_name} " f"({sum(p.numel() for p in model.parameters()):,} params)") print(f"Graph: {graph['edge_list'].size(0):,} edges, " f"{graph['num_nodes']} nodes | " f"{len(keys)} unique (h,r) pairs\n{'='*60}") best_mrr, history = 0.0, [] for epoch in range(eps): model.train() np.random.shuffle(keys) total, cnt = 0.0, 0 pbar = tqdm(keys, desc=f"Epoch {epoch+1}/{eps}") for (h, r) in pbar: chosen = np.random.choice(groups[(h,r)], min(len(groups[(h,r)]), mpg), replace=False) t_pos = torch.tensor( [preprocessor.entity2id[train_data[i]['tail']] for i in chosen], device=dev) B = len(t_pos) heads = torch.full((B,), h, dtype=torch.long, device=dev) rels = torch.full((B,), r, dtype=torch.long, device=dev) pos = model(heads, rels, tail=t_pos, graph=graph) neg = model(heads, rels, tail=torch.randint(0, NE, (B,), device=dev), graph=graph) loss = F.margin_ranking_loss(pos, neg, torch.ones(B, device=dev), margin=1.0) opt.zero_grad(); loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step() total += loss.item(); cnt += 1 pbar.set_postfix({'loss': f'{loss.item():.4f}'}) if (epoch + 1) % 5 == 0: m = evaluate_nbfnet(model, valid_ds, graph, dev) print(f" Epoch {epoch+1}: loss={total/cnt:.4f} " f"MRR={m['mrr']:.4f} MR={m['mr']:.1f} " f"H@1={m['hits@1']:.4f} H@3={m['hits@3']:.4f} H@10={m['hits@10']:.4f}") if m['mrr'] > best_mrr: best_mrr = m['mrr'] torch.save(model.state_dict(), f"{CONFIG['output_path']}{model_name}_best.pt") print(" → best saved") history.append(m) return model, history def evaluate_nbfnet(model, dataset, graph, device, max_groups=300): """Evaluation that also groups by (h,r) for BF with ALL metrics.""" model.eval() groups = defaultdict(list) for i, s in enumerate(dataset.data): h = preprocessor.entity2id[s['head']] r = preprocessor.relation2id[s['relation']] groups[(h, r)].append(i) ranks = [] with torch.no_grad(): for (h, r) in tqdm(list(groups.keys())[:max_groups], desc="Evaluating", leave=False): idxs = groups[(h, r)] tails = [preprocessor.entity2id[dataset.data[i]['tail']] for i in idxs] B = len(tails) heads = torch.full((B,), h, dtype=torch.long, device=device) rels = torch.full((B,), r, dtype=torch.long, device=device) try: scores = model(heads, rels, graph=graph) for i, tgt in enumerate(tails): pos = (torch.argsort(scores[i], descending=True ) == tgt).nonzero(as_tuple=True)[0] if len(pos): ranks.append(pos[0].item() + 1) except Exception: continue if not ranks: return {'mr':0., 'mrr':0., 'hits@1':0., 'hits@3':0., 'hits@10':0.} ranks = np.array(ranks) return { 'mr': float(np.mean(ranks)), 'mrr': float(np.mean(1./ranks)), 'hits@1': float(np.mean(ranks<=1)), 'hits@3': float(np.mean(ranks<=3)), 'hits@10': float(np.mean(ranks<=10)), } print("Building graph for NBFNet...") nbfnet_graph = build_nbfnet_graph(train_data, preprocessor, CONFIG['device']) print(f" NBFNet ready " f"(graph: {nbfnet_graph['edge_list'].size(0):,} edges, " f"{nbfnet_graph['num_nodes']} nodes)") # ============================================================================ # 8 — Model 4: AlertStar # ============================================================================ class AlertStarModel(nn.Module): """AlertStar: qualifier-aware attention (StarE) fused with path composition.""" def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1): super().__init__() self.num_entities = ne self.ent = nn.Embedding(ne, dim) self.rel = nn.Embedding(nr, dim) self.qk = nn.Embedding(nqk, dim) self.qv = nn.Embedding(nqv, dim) self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True) self.ln1 = nn.LayerNorm(dim) self.path_net = nn.Sequential( nn.Linear(dim*2, dim), nn.LayerNorm(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim) ) self.ln2 = nn.LayerNorm(dim) self.gate = nn.Parameter(torch.tensor(0.5)) self.drop = nn.Dropout(dropout) for e in [self.ent, self.rel, self.qk, self.qv]: nn.init.xavier_uniform_(e.weight) def _enrich(self, r_emb, quals, dev): if not quals: return r_emb k = self.qk(torch.tensor([q[0] for q in quals], device=dev)) v = self.qv(torch.tensor([q[1] for q in quals], device=dev)) kv = (k + v).unsqueeze(0) out, _ = self.attn(r_emb.view(1,1,-1), kv, kv) return out.squeeze() def forward(self, head, relation, qualifiers, tail=None): dev = head.device h = self.ent(head) r = torch.stack([ self._enrich(self.rel(relation[i:i+1]).squeeze(), qualifiers[i], dev) for i in range(head.size(0))]) stare = self.ln1(h + r) path = self.path_net(torch.cat([h, r], dim=-1)) path = self.ln2(h + path) g = torch.sigmoid(self.gate) x = self.drop(g * stare + (1 - g) * path) if tail is not None: return (x * self.ent(tail)).sum(-1) return x @ self.ent.weight.t() print("AlertStar defined") # ============================================================================ # 9 — Model 5: StarQE (Complex Query Answering) # ============================================================================ def generate_complex_queries(train_data, preprocessor, n=500): """Generate 1p / 2p / 2i / 2u query structures from training triples.""" e2id = preprocessor.entity2id r2id = preprocessor.relation2id hr2t = defaultdict(list) for t in train_data: h = e2id[t['head']]; r = r2id[t['relation']]; tl = e2id[t['tail']] hr2t[(h, r)].append(tl) triples = [(e2id[t['head']], r2id[t['relation']], e2id[t['tail']]) for t in train_data] queries = {'1p': [], '2p': [], '2i': [], '2u': []} np.random.shuffle(triples) for h, r, t in triples[:min(len(triples), n*4)]: if len(queries['1p']) < n: queries['1p'].append({'anchor': h, 'relations': [r], 'answer': t}) for r2, t2_list in hr2t.items(): if r2[0] == t and t2_list and len(queries['2p']) < n: queries['2p'].append({ 'anchor': h, 'relations': [r, r2[1]], 'answer': t2_list[0]}) break if len(queries['2i']) < n: candidates = [(h2,r2) for (h2,r2), tl in hr2t.items() if t in tl and h2!=h] if candidates: h2, r2 = candidates[0] queries['2i'].append({ 'anchors': [h, h2], 'relations': [r, r2], 'answer': t}) if len(queries['2u']) < n: candidates = [(h2,r2) for (h2,r2), tl in hr2t.items() if tl and h2!=h][:1] if candidates: h2, r2 = candidates[0] queries['2u'].append({ 'anchors': [h, h2], 'relations': [r, r2], 'answer': t}) if all(len(v) >= n for v in queries.values()): break print("Complex query stats:") for k, v in queries.items(): print(f" {k}: {len(v)}") return queries class StarQEModel(nn.Module): """StarQE: qualifier-aware complex query answering.""" def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1): super().__init__() self.num_entities = ne self.ent = nn.Embedding(ne, dim) self.rel = nn.Embedding(nr, dim) self.qk = nn.Embedding(nqk, dim) self.qv = nn.Embedding(nqv, dim) self.attn = nn.MultiheadAttention(dim, 4, dropout=dropout, batch_first=True) self.proj = nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) self.inter = nn.Sequential(nn.Linear(dim*2, dim), nn.ReLU()) self.drop = nn.Dropout(dropout) for e in [self.ent, self.rel, self.qk, self.qv]: nn.init.xavier_uniform_(e.weight) def _compose(self, anchor_emb, relations): """Compose anchor + chain of relations.""" x = anchor_emb for r in relations: r_emb = self.rel(torch.tensor([r], device=x.device)).squeeze() x = self.proj(x + r_emb) return x def answer_query(self, query_type, query): dev = CONFIG['device'] if query_type == '1p': h = self.ent(torch.tensor([query['anchor']], device=dev)).squeeze() return self._compose(h, query['relations']) elif query_type == '2p': h = self.ent(torch.tensor([query['anchor']], device=dev)).squeeze() return self._compose(h, query['relations']) elif query_type == '2i': h1 = self.ent(torch.tensor([query['anchors'][0]], device=dev)).squeeze() h2 = self.ent(torch.tensor([query['anchors'][1]], device=dev)).squeeze() e1 = self._compose(h1, [query['relations'][0]]) e2 = self._compose(h2, [query['relations'][1]]) return self.inter(torch.cat([e1, e2], -1)) elif query_type == '2u': h1 = self.ent(torch.tensor([query['anchors'][0]], device=dev)).squeeze() h2 = self.ent(torch.tensor([query['anchors'][1]], device=dev)).squeeze() e1 = self._compose(h1, [query['relations'][0]]) e2 = self._compose(h2, [query['relations'][1]]) return (e1 + e2) / 2 def forward(self, head, relation, qualifiers, tail=None): """Standard link-pred interface (1p queries only).""" dev = head.device h = self.ent(head) r = self.rel(relation) x = self.drop(self.proj(h + r)) if tail is not None: return (x * self.ent(tail)).sum(-1) return x @ self.ent.weight.t() def train_query_model(model, complex_queries, model_name): """Train on complex queries.""" dev = CONFIG['device'] model.to(dev) opt = torch.optim.Adam(model.parameters(), lr=CONFIG['query_lr']) print(f"\n{'='*60}\nTRAINING {model_name} " f"({sum(p.numel() for p in model.parameters()):,} params)\n{'='*60}") best_mrr, history = 0.0, [] all_queries = [(qt, q) for qt, qs in complex_queries.items() for q in qs] for epoch in range(CONFIG['query_epochs']): model.train() np.random.shuffle(all_queries) total = 0.0 pbar = tqdm(all_queries, desc=f"Epoch {epoch+1}/{CONFIG['query_epochs']}") for qt, q in pbar: qvec = model.answer_query(qt, q).unsqueeze(0) answer = torch.tensor([q['answer']], device=dev) pos = (qvec * model.ent(answer)).sum(-1) neg_id = torch.randint(0, NE, (1,), device=dev) neg = (qvec * model.ent(neg_id)).sum(-1) loss = F.margin_ranking_loss(pos, neg, torch.ones(1, device=dev), margin=1.0) opt.zero_grad(); loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step() total += loss.item() pbar.set_postfix({'loss': f'{loss.item():.4f}'}) if (epoch + 1) % 5 == 0: m = evaluate_complex_queries(model, complex_queries) mr = np.mean([v['mrr'] for v in m.values()]) print(f" Epoch {epoch+1}: loss={total/len(all_queries):.4f} " f"Avg-MRR={mr:.4f}") if mr > best_mrr: best_mrr = mr torch.save(model.state_dict(), f"{CONFIG['output_path']}{model_name}_best.pt") print(" → best saved") history.append({'epoch': epoch+1, **m}) return model, history def evaluate_complex_queries(model, complex_queries, max_per_type=200): """Evaluate all query types with ALL metrics.""" model.eval() results = {} with torch.no_grad(): for qt, queries in complex_queries.items(): ranks = [] for q in queries[:max_per_type]: try: qvec = model.answer_query(qt, q) scores = qvec @ model.ent.weight.t() rank = (torch.argsort(scores, descending=True ) == q['answer'] ).nonzero(as_tuple=True)[0].item() + 1 ranks.append(rank) except Exception: continue if ranks: ranks = np.array(ranks) results[qt] = { 'mr': float(np.mean(ranks)), 'mrr': float(np.mean(1./ranks)), 'hits@1': float(np.mean(ranks<=1)), 'hits@3': float(np.mean(ranks<=3)), 'hits@10':float(np.mean(ranks<=10)), } else: results[qt] = {'mr':0., 'mrr':0., 'hits@1':0., 'hits@3':0., 'hits@10':0.} return results print("StarQE defined") # ============================================================================ # 10 — Model 6: NBFNet+StarQE # ============================================================================ class NBFNetStarQE(StarQEModel): """Extends StarQE with path-composition from NBFNet.""" def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1): super().__init__(ne, nr, nqk, nqv, dim, dropout) self.path_net = nn.Sequential( nn.Linear(dim*2, dim), nn.LayerNorm(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim) ) def _compose(self, anchor_emb, relations): """Path composition using NBFNet-style network.""" x = anchor_emb for r in relations: r_emb = self.rel(torch.tensor([r], device=x.device)).squeeze() x = anchor_emb + self.path_net(torch.cat([x, r_emb], -1)) return x print("NBFNet+StarQE defined") # ============================================================================ # 11 — UPDATED: Model 7: MultiTask AlertStar with ALL Metrics # ============================================================================ class MultiTaskDataset(Dataset): """Each hyper-relational triple generates 4 tasks.""" def __init__(self, data, preprocessor): self.samples = [] p = preprocessor for triple in data: h = p.entity2id[triple['head']] r = p.relation2id[triple['relation']] t = p.entity2id[triple['tail']] qs = [(p.qualifier_key2id[qk], p.qualifier_value2id[qv]) for qk, qv in triple['qualifiers']] self.samples.append({'task':'tail', 'h':h,'r':r,'t':t,'qs':qs}) self.samples.append({'task':'relation', 'h':h,'r':r,'t':t,'qs':qs}) if qs: keys = [qk for qk, _ in qs] self.samples.append({'task':'qual_key','h':h,'r':r,'t':t, 'qs':qs,'keys':keys}) for qk, qv in qs: self.samples.append({'task':'qual_value','h':h,'r':r,'t':t, 'qs':qs,'qk':qk,'qv':qv}) def __len__(self): return len(self.samples) def __getitem__(self, idx): return self.samples[idx] def mt_collate(batch): """Group by task, pad as needed.""" by_task = defaultdict(list) for s in batch: by_task[s['task']].append(s) return by_task class MultiTaskAlertStar(nn.Module): """Transformer-based multi-task model.""" def __init__(self, ne, nr, nqk, nqv, dim=200, dropout=0.1, n_heads=4, n_layers=3): super().__init__() self.num_entities = ne self.ent = nn.Embedding(ne, dim) self.rel = nn.Embedding(nr, dim) self.qk = nn.Embedding(nqk, dim) self.qv = nn.Embedding(nqv, dim) enc_layer = nn.TransformerEncoderLayer( d_model=dim, nhead=n_heads, dim_feedforward=dim*4, dropout=dropout, batch_first=True) self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers) def head(out): return nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, out)) self.tail_head = head(ne) self.rel_head = head(nr) self.qk_head = head(nqk) self.qv_head = head(nqv) for e in [self.ent, self.rel, self.qk, self.qv]: nn.init.xavier_uniform_(e.weight) def _encode(self, h, r, t, qs, dev, mask_pos=None): """Build token sequence and encode with Transformer.""" tokens = [self.ent(torch.tensor([h], device=dev)), self.rel(torch.tensor([r], device=dev)), self.ent(torch.tensor([t], device=dev))] for qk_id, qv_id in qs: tokens.append(self.qk(torch.tensor([qk_id], device=dev))) tokens.append(self.qv(torch.tensor([qv_id], device=dev))) seq = torch.cat(tokens, dim=0).unsqueeze(0) if mask_pos is not None and mask_pos < seq.size(1): seq = seq.clone() seq[0, mask_pos] = 0.0 out = self.encoder(seq) return out[0, 0] def forward_task(self, task, sample, dev): h, r, t, qs = sample['h'], sample['r'], sample['t'], sample['qs'] if task == 'tail': ctx = self._encode(h, r, 0, qs, dev, mask_pos=2) return self.tail_head(ctx) elif task == 'relation': ctx = self._encode(h, r, t, qs, dev, mask_pos=1) return self.rel_head(ctx) elif task == 'qual_key': ctx = self._encode(h, r, t, [], dev) return self.qk_head(ctx) elif task == 'qual_value': qk = sample['qk'] filtered = [(k, v) for k, v in qs if k != qk] ctx = self._encode(h, r, t, filtered, dev) return self.qv_head(ctx) def train_multitask(model, model_name="MultiTask_AlertStar"): dev = CONFIG['device'] model.to(dev) opt = torch.optim.Adam(model.parameters(), lr=CONFIG['mt_lr']) mt_train = MultiTaskDataset(train_data, preprocessor) mt_valid = MultiTaskDataset(valid_data, preprocessor) loader = DataLoader(mt_train, batch_size=CONFIG['mt_batch'], shuffle=True, collate_fn=mt_collate) weights = {'tail': 1.0, 'relation': 1.0, 'qual_key': 0.5, 'qual_value': 0.8} print(f"\n{'='*60}\nTRAINING {model_name} " f"({sum(p.numel() for p in model.parameters()):,} params)\n{'='*60}") print(f" Train samples: {len(mt_train):,} " f"Valid samples: {len(mt_valid):,}") best_avg, history = 0.0, [] for epoch in range(CONFIG['mt_epochs']): model.train() total_loss = 0.0; cnt = 0 pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{CONFIG['mt_epochs']}") for by_task in pbar: batch_loss = torch.tensor(0., device=dev) for task, samples in by_task.items(): for s in samples: try: logits = model.forward_task(task, s, dev) w = weights.get(task, 1.0) if task == 'tail': tgt = torch.tensor(s['t'], device=dev) loss = F.cross_entropy(logits.unsqueeze(0), tgt.unsqueeze(0)) elif task == 'relation': tgt = torch.tensor(s['r'], device=dev) loss = F.cross_entropy(logits.unsqueeze(0), tgt.unsqueeze(0)) elif task == 'qual_key': tgt = torch.zeros(NQK, device=dev) tgt[s['keys']] = 1.0 loss = F.binary_cross_entropy_with_logits(logits, tgt) elif task == 'qual_value': tgt = torch.tensor(s['qv'], device=dev) loss = F.cross_entropy(logits.unsqueeze(0), tgt.unsqueeze(0)) else: continue batch_loss = batch_loss + w * loss except Exception: continue opt.zero_grad(); batch_loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step() total_loss += batch_loss.item(); cnt += 1 pbar.set_postfix({'loss': f'{batch_loss.item():.4f}'}) if (epoch + 1) % 5 == 0: m = evaluate_multitask(model, mt_valid, dev) avg = np.mean([m['tail']['mrr'], m['relation']['mrr'], m['qual_key']['f1'], m['qual_value']['mrr']]) # Changed from accuracy to mrr print(f" Epoch {epoch+1}: loss={total_loss/cnt:.4f} avg={avg:.4f}") print(f" tail: MR={m['tail']['mr']:.1f} MRR={m['tail']['mrr']:.4f} " f"H@1={m['tail']['hits@1']:.4f} H@3={m['tail']['hits@3']:.4f} " f"H@10={m['tail']['hits@10']:.4f}") print(f" relation: MR={m['relation']['mr']:.1f} MRR={m['relation']['mrr']:.4f} " f"Acc={m['relation']['accuracy']:.4f}") print(f" qual-key: F1={m['qual_key']['f1']:.4f}") print(f" qual-value: MR={m['qual_value']['mr']:.1f} MRR={m['qual_value']['mrr']:.4f} " f"Acc={m['qual_value']['accuracy']:.4f}") if avg > best_avg: best_avg = avg torch.save(model.state_dict(), f"{CONFIG['output_path']}{model_name}_best.pt") print(" → best saved") history.append({'epoch': epoch+1, **m}) return model, history def evaluate_multitask(model, mt_dataset, device, max_per_task=200): """UPDATED: Evaluate multitask with ALL metrics including MR, MRR, Hits for all tasks.""" model.eval() preds = defaultdict(list) by_task = defaultdict(list) for s in mt_dataset.samples: by_task[s['task']].append(s) results = {} with torch.no_grad(): # ── tail prediction (ALL METRICS) ─────────────────────────────── ranks = [] for s in by_task['tail'][:max_per_task]: try: logits = model.forward_task('tail', s, device) rank = (torch.argsort(logits, descending=True ) == s['t']).nonzero(as_tuple=True )[0].item() + 1 ranks.append(rank) except Exception: continue ranks = np.array(ranks) if ranks else np.array([NE]) results['tail'] = { 'mr': float(np.mean(ranks)), 'mrr': float(np.mean(1./ranks)), 'hits@1': float(np.mean(ranks<=1)), 'hits@3': float(np.mean(ranks<=3)), 'hits@10': float(np.mean(ranks<=10)), } # ── relation prediction (ALL METRICS) ─────────────────────────── ranks = [] correct = [] for s in by_task['relation'][:max_per_task]: try: logits = model.forward_task('relation', s, device) rank = (torch.argsort(logits, descending=True ) == s['r']).nonzero(as_tuple=True )[0].item() + 1 ranks.append(rank) pred = logits.argmax().item() correct.append(int(pred == s['r'])) except Exception: continue ranks = np.array(ranks) if ranks else np.array([NR]) results['relation'] = { 'mr': float(np.mean(ranks)), 'mrr': float(np.mean(1./ranks)), 'hits@1': float(np.mean(ranks<=1)), 'hits@3': float(np.mean(ranks<=3)), 'hits@10': float(np.mean(ranks<=10)), 'accuracy': float(np.mean(correct)) if correct else 0.0 } # ── qualifier key prediction (multi-label) ────────────────────── prec_list, rec_list = [], [] for s in by_task['qual_key'][:max_per_task]: try: logits = model.forward_task('qual_key', s, device) pred = (torch.sigmoid(logits) > 0.5).cpu().numpy() true = np.zeros(NQK); true[s['keys']] = 1 tp = (pred * true).sum() prec_list.append(tp / max(pred.sum(), 1)) rec_list.append(tp / max(true.sum(), 1)) except Exception: continue p = float(np.mean(prec_list)) if prec_list else 0. r = float(np.mean(rec_list)) if rec_list else 0. results['qual_key'] = { 'precision': p, 'recall': r, 'f1': 2*p*r/(p+r+1e-8)} # ── qualifier value prediction (ALL METRICS like HyNT) ───────── ranks = [] correct = [] for s in by_task['qual_value'][:max_per_task]: try: logits = model.forward_task('qual_value', s, device) # Ranking metrics (like HyNT) rank = (torch.argsort(logits, descending=True ) == s['qv']).nonzero(as_tuple=True )[0].item() + 1 ranks.append(rank) # Accuracy metric pred = logits.argmax().item() correct.append(int(pred == s['qv'])) except Exception: continue ranks = np.array(ranks) if ranks else np.array([NQV]) results['qual_value'] = { 'mr': float(np.mean(ranks)), 'mrr': float(np.mean(1./ranks)), 'hits@1': float(np.mean(ranks<=1)), 'hits@3': float(np.mean(ranks<=3)), 'hits@10': float(np.mean(ranks<=10)), 'accuracy': float(np.mean(correct)) if correct else 0.0 } return results print("MultiTask AlertStar defined") # ============================================================================ # CELL 12 — Train All Models # ============================================================================ print("\n" + "="*70) print("STARTING MODEL TRAINING") print("="*70) # ── StarE ────────────────────────────────────────────────────────────────── stare_model = StarEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout']) stare_model, _ = train_standard(stare_model, "StarE") all_results['StarE'] = evaluate_model(stare_model, test_ds, CONFIG['device']) # ── ShrinkE ──────────────────────────────────────────────────────────────── shrinke_model = ShrinkEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout']) shrinke_model, _ = train_standard(shrinke_model, "ShrinkE") all_results['ShrinkE'] = evaluate_model(shrinke_model, test_ds, CONFIG['device']) # ── NBFNet ───────────────────────────────────────────────────────────── nbfnet_model = NBFNet( NE, NR, dim=DIM, num_layers = CONFIG['nbfnet_layers'], chunk_size = CONFIG['nbfnet_chunk_size'], dropout = CONFIG['dropout'], short_cut = True, ) nbfnet_model, _ = train_true_nbfnet(nbfnet_model, "NBFNet") nbfnet_graph_dev = {k: v.to(CONFIG['device']) if torch.is_tensor(v) else v for k, v in nbfnet_graph.items()} all_results['NBFNet'] = evaluate_nbfnet( nbfnet_model, test_ds, nbfnet_graph_dev, CONFIG['device']) # ── AlertStar ────────────────────────────────────────────────────────────── alertstar_model = AlertStarModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout']) alertstar_model, _ = train_standard(alertstar_model, "AlertStar") all_results['AlertStar'] = evaluate_model(alertstar_model, test_ds, CONFIG['device']) print("\n✓ Standard models trained") for m, r in all_results.items(): print(f" {m:15s} MR={r['mr']:7.1f} MRR={r['mrr']:.4f} " f"H@1={r['hits@1']:.4f} H@3={r['hits@3']:.4f} H@10={r['hits@10']:.4f}") # ── Complex Query Models ─────────────────────────────────────────────────── print("\nGenerating complex queries...") complex_queries = generate_complex_queries(train_data, preprocessor, n=500) # StarQE starqe_model = StarQEModel(NE, NR, NQK, NQV, DIM, CONFIG['dropout']) starqe_model, _ = train_query_model(starqe_model, complex_queries, "StarQE") starqe_lp = evaluate_model(starqe_model, test_ds, CONFIG['device']) starqe_cq = evaluate_complex_queries(starqe_model, complex_queries) all_results['StarQE'] = starqe_lp # NBFNet+StarQE nbfqe_model = NBFNetStarQE(NE, NR, NQK, NQV, DIM, CONFIG['dropout']) nbfqe_model, _ = train_query_model(nbfqe_model, complex_queries, "NBFNet+StarQE") nbfqe_lp = evaluate_model(nbfqe_model, test_ds, CONFIG['device']) nbfqe_cq = evaluate_complex_queries(nbfqe_model, complex_queries) all_results['NBFNet+StarQE'] = nbfqe_lp print("\n Complex query models trained") for name, cq in [("StarQE", starqe_cq), ("NBFNet+StarQE", nbfqe_cq)]: avg = np.mean([v['mrr'] for v in cq.values()]) print(f" {name}: Avg-MRR={avg:.4f}") for k, v in cq.items(): print(f" {k}: MR={v['mr']:6.1f} MRR={v['mrr']:.4f} " f"H@1={v['hits@1']:.4f} H@3={v['hits@3']:.4f} H@10={v['hits@10']:.4f}") # ── MultiTask AlertStar ──────────────────────────────────────────────────── mt_model = MultiTaskAlertStar(NE, NR, NQK, NQV, DIM, dropout=CONFIG['dropout']) mt_model, mt_history = train_multitask(mt_model) mt_test_ds = MultiTaskDataset(test_data, preprocessor) mt_results = evaluate_multitask(mt_model, mt_test_ds, CONFIG['device']) # Store tail prediction metrics all_results['MultiTask_AlertStar'] = mt_results['tail'] print("\n MultiTask AlertStar evaluation complete") print(f"\n TAIL PREDICTION:") print(f" MR = {mt_results['tail']['mr']:.1f}") print(f" MRR = {mt_results['tail']['mrr']:.4f}") print(f" Hits@1 = {mt_results['tail']['hits@1']:.4f}") print(f" Hits@3 = {mt_results['tail']['hits@3']:.4f}") print(f" Hits@10= {mt_results['tail']['hits@10']:.4f}") print(f"\n RELATION PREDICTION:") print(f" MR = {mt_results['relation']['mr']:.1f}") print(f" MRR = {mt_results['relation']['mrr']:.4f}") print(f" Hits@1 = {mt_results['relation']['hits@1']:.4f}") print(f" Hits@3 = {mt_results['relation']['hits@3']:.4f}") print(f" Hits@10= {mt_results['relation']['hits@10']:.4f}") print(f" Acc = {mt_results['relation']['accuracy']:.4f}") print(f"\n QUALIFIER KEY:") print(f" Precision = {mt_results['qual_key']['precision']:.4f}") print(f" Recall = {mt_results['qual_key']['recall']:.4f}") print(f" F1 = {mt_results['qual_key']['f1']:.4f}") print(f"\n QUALIFIER VALUE:") print(f" MR = {mt_results['qual_value']['mr']:.1f}") print(f" MRR = {mt_results['qual_value']['mrr']:.4f}") print(f" Hits@1 = {mt_results['qual_value']['hits@1']:.4f}") print(f" Hits@3 = {mt_results['qual_value']['hits@3']:.4f}") print(f" Hits@10= {mt_results['qual_value']['hits@10']:.4f}") print(f" Acc = {mt_results['qual_value']['accuracy']:.4f}") # ============================================================================ # 13 — UPDATED: Complete Results Table with All Metrics # ============================================================================ print("\n" + "="*80) print("COMPLETE MODEL COMPARISON — ALL METRICS") print("="*80) df = pd.DataFrame({ 'Model': list(all_results.keys()), 'MR': [all_results[m]['mr'] for m in all_results], 'MRR': [all_results[m]['mrr'] for m in all_results], 'Hits@1': [all_results[m]['hits@1'] for m in all_results], 'Hits@3': [all_results[m]['hits@3'] for m in all_results], 'Hits@10': [all_results[m]['hits@10'] for m in all_results], }) df = df.sort_values('MRR', ascending=False).reset_index(drop=True) print(df.to_string(index=False)) best = df.iloc[0] print(f"\n Best Model: {best['Model']}") print(f" MR={best['MR']:.1f} MRR={best['MRR']:.4f} " f"H@1={best['Hits@1']:.4f} H@3={best['Hits@3']:.4f} H@10={best['Hits@10']:.4f}") # Multi-task detailed table print("\n" + "="*80) print("MULTI-TASK ALERTSTAR — DETAILED METRICS (HyNT-comparable)") print("="*80) mt_df = pd.DataFrame({ 'Task': ['Tail Pred', 'Relation Pred', 'Qual Key', 'Qual Value'], 'Primary Metric': [ f"MRR={mt_results['tail']['mrr']:.4f}", f"MRR={mt_results['relation']['mrr']:.4f}", f"F1={mt_results['qual_key']['f1']:.4f}", f"MRR={mt_results['qual_value']['mrr']:.4f}" ], 'MR': [ f"{mt_results['tail']['mr']:.1f}", f"{mt_results['relation']['mr']:.1f}", '-', f"{mt_results['qual_value']['mr']:.1f}" ], 'Hits@1': [ f"{mt_results['tail']['hits@1']:.4f}", f"{mt_results['relation']['hits@1']:.4f}", '-', f"{mt_results['qual_value']['hits@1']:.4f}" ], 'Hits@3': [ f"{mt_results['tail']['hits@3']:.4f}", f"{mt_results['relation']['hits@3']:.4f}", '-', f"{mt_results['qual_value']['hits@3']:.4f}" ], 'Hits@10': [ f"{mt_results['tail']['hits@10']:.4f}", f"{mt_results['relation']['hits@10']:.4f}", '-', f"{mt_results['qual_value']['hits@10']:.4f}" ], 'Accuracy/F1': [ '-', f"{mt_results['relation']['accuracy']:.4f}", f"{mt_results['qual_key']['f1']:.4f}", f"{mt_results['qual_value']['accuracy']:.4f}" ] }) print(mt_df.to_string(index=False)) # Complex query detailed table print("\n" + "="*80) print("COMPLEX QUERY ANSWERING — DETAILED METRICS") print("="*80) for model_name, cq_results in [("StarQE", starqe_cq), ("NBFNet+StarQE", nbfqe_cq)]: print(f"\n{model_name}:") cq_df = pd.DataFrame({ 'Query Type': list(cq_results.keys()), 'MR': [cq_results[qt]['mr'] for qt in cq_results], 'MRR': [cq_results[qt]['mrr'] for qt in cq_results], 'Hits@1': [cq_results[qt]['hits@1'] for qt in cq_results], 'Hits@3': [cq_results[qt]['hits@3'] for qt in cq_results], 'Hits@10': [cq_results[qt]['hits@10'] for qt in cq_results], }) print(cq_df.to_string(index=False)) # Save all results with open(f"{CONFIG['output_path']}all_results.json", 'w') as f: json.dump({ 'link_prediction': all_results, 'complex_queries': { 'StarQE': starqe_cq, 'NBFNet+StarQE':nbfqe_cq, }, 'multitask': mt_results, }, f, indent=2) print(f"\n Results saved to {CONFIG['output_path']}all_results.json") # ============================================================================ # 14 — UPDATED: Visualization with All Metrics # ============================================================================ fig, axes = plt.subplots(3, 3, figsize=(20, 15)) fig.suptitle("AlertStar: Complete Model Comparison (All Metrics)", fontsize=16, fontweight='bold') palette = sns.color_palette("Set2", len(df)) # Row 1: Link Prediction - MR, MRR, Hits@1 axes[0,0].bar(df['Model'], df['MR'], color=palette) axes[0,0].set_title("Link Prediction — MR (lower is better)", fontweight='bold') axes[0,0].set_ylabel("Mean Rank") axes[0,0].tick_params(axis='x', rotation=45, labelsize=9) axes[0,1].bar(df['Model'], df['MRR'], color=palette) axes[0,1].set_title("Link Prediction — MRR", fontweight='bold') axes[0,1].set_ylabel("MRR") axes[0,1].tick_params(axis='x', rotation=45, labelsize=9) for i, v in enumerate(df['MRR']): axes[0,1].text(i, v+0.005, f'{v:.3f}', ha='center', fontsize=8) axes[0,2].bar(df['Model'], df['Hits@1'], color=palette) axes[0,2].set_title("Link Prediction — Hits@1", fontweight='bold') axes[0,2].set_ylabel("Hits@1") axes[0,2].tick_params(axis='x', rotation=45, labelsize=9) # Row 2: Link Prediction - Hits@3, Hits@10, Combined axes[1,0].bar(df['Model'], df['Hits@3'], color=palette) axes[1,0].set_title("Link Prediction — Hits@3", fontweight='bold') axes[1,0].set_ylabel("Hits@3") axes[1,0].tick_params(axis='x', rotation=45, labelsize=9) axes[1,1].bar(df['Model'], df['Hits@10'], color=palette) axes[1,1].set_title("Link Prediction — Hits@10", fontweight='bold') axes[1,1].set_ylabel("Hits@10") axes[1,1].tick_params(axis='x', rotation=45, labelsize=9) # Combined metric visualization x = np.arange(len(df)) width = 0.15 axes[1,2].bar(x-2*width, df['Hits@1'], width, label='H@1', color='steelblue') axes[1,2].bar(x-width, df['Hits@3'], width, label='H@3', color='orange') axes[1,2].bar(x, df['Hits@10'], width, label='H@10', color='green') axes[1,2].bar(x+width, df['MRR'], width, label='MRR', color='red') axes[1,2].set_xticks(x) axes[1,2].set_xticklabels(df['Model'], rotation=45, ha='right', fontsize=8) axes[1,2].set_title("All Metrics Combined", fontweight='bold') axes[1,2].legend(fontsize=8) # Row 3: Complex Queries, Multi-task, Capability Matrix query_types = list(starqe_cq.keys()) x_cq = np.arange(len(query_types)) w = 0.35 axes[2,0].bar(x_cq-w/2, [starqe_cq[qt]['mrr'] for qt in query_types], w, label='StarQE', color='steelblue') axes[2,0].bar(x_cq+w/2, [nbfqe_cq[qt]['mrr'] for qt in query_types], w, label='NBFNet+StarQE', color='coral') axes[2,0].set_xticks(x_cq) axes[2,0].set_xticklabels(query_types) axes[2,0].set_title("Complex Query MRR by Type", fontweight='bold') axes[2,0].set_ylabel("MRR") axes[2,0].legend() # Multi-task results tasks = ['Tail\n(MRR)', 'Relation\n(MRR)', 'Qual-Key\n(F1)', 'Qual-Val\n(Acc)'] scores = [mt_results['tail']['mrr'], mt_results['relation']['mrr'], mt_results['qual_key']['f1'], mt_results['qual_value']['accuracy']] bars = axes[2,1].bar(tasks, scores, color=sns.color_palette("Set3", 4)) axes[2,1].set_title("MultiTask AlertStar — All Tasks", fontweight='bold') axes[2,1].set_ylabel("Score") axes[2,1].set_ylim(0, 1.1) for bar, v in zip(bars, scores): axes[2,1].text(bar.get_x()+bar.get_width()/2, v+0.02, f'{v:.3f}', ha='center', fontsize=10, fontweight='bold') # Capability matrix model_names = list(all_results.keys()) caps = {m: {'Link\nPred': 1, 'Relation\nPred': 1 if m == 'MultiTask_AlertStar' else 0, 'Qual\nPred': 1 if m == 'MultiTask_AlertStar' else 0, 'Complex\nQuery': 1 if m in ['StarQE','NBFNet+StarQE'] else 0} for m in model_names} cap_matrix = np.array([[caps[m][c] for c in ['Link\nPred','Relation\nPred', 'Qual\nPred','Complex\nQuery']] for m in model_names]) im = axes[2,2].imshow(cap_matrix, cmap='YlGn', aspect='auto', vmin=0, vmax=1) axes[2,2].set_xticks(range(4)) axes[2,2].set_xticklabels(['Link\nPred','Relation\nPred', 'Qual\nPred','Complex\nQuery'], fontsize=9) axes[2,2].set_yticks(range(len(model_names))) axes[2,2].set_yticklabels(model_names, fontsize=9) axes[2,2].set_title("Capability Matrix", fontweight='bold') for i in range(len(model_names)): for j in range(4): axes[2,2].text(j, i, '✓' if cap_matrix[i,j] else '✗', ha='center', va='center', fontsize=12, color='darkgreen' if cap_matrix[i,j] else 'gray') plt.tight_layout() plt.savefig(f"{CONFIG['output_path']}complete_comparison_all_metrics.png", dpi=300, bbox_inches='tight') plt.show() print(f"✓ Visualization saved") # ============================================================================ # 15- Final Summary # ============================================================================ def get_name(vocab_dict, idx): """Reverse lookup from id → name.""" rev = {v: k for k, v in vocab_dict.items()} return rev.get(idx, f"<{idx}>") dev = CONFIG['device'] print("\n" + "="*70) print("EXPERIMENT COMPLETE!") print("="*70) print(f"\n TRAINED {len(all_results)} MODELS:") for m in all_results: print(f" ✓ {m}") print(f"\n BEST LINK PREDICTION MODEL: {best['Model']}") print(f" MR = {best['MR']:.1f}") print(f" MRR = {best['MRR']:.4f}") print(f" Hits@1 = {best['Hits@1']:.4f}") print(f" Hits@3 = {best['Hits@3']:.4f}") print(f" Hits@10= {best['Hits@10']:.4f}") print(f"\n All outputs saved to: {CONFIG['output_path']}") print(" - all_results.json (complete metrics)") print(" - complete_comparison_all_metrics.png (visualization)") print(" - Model weights (.pt files)") print("\n" + "="*70)