import torch
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
class Client_GC():
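    """A federated client for graph classification: holds the local model, data loaders, and optimizer."""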
def __init__(self, model, client_id, client_name, train_size, dataLoader, optimizer, args):
self.model = model.to(args.device)
self.id = client_id
self.name = client_name
self.train_size = train_size
self.dataLoader = dataLoader
self.optimizer = optimizer
self.args = args
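        # W: live views of the model parameters; dW: parameter updates; W_old: cached weight snapshot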
self.W = {key: value for key, value in self.model.named_parameters()}
self.dW = {key: torch.zeros_like(value) for key, value in self.model.named_parameters()}
self.W_old = {key: value.data.clone() for key, value in self.model.named_parameters()}
self.gconvNames = None
        self.train_stats = ([0], [0], [0], [0])  # placeholder; replaced by a stats dict after training
self.weightsNorm = 0.
self.gradsNorm = 0.
self.convGradsNorm = 0.
self.convWeightsNorm = 0.
self.convDWsNorm = 0.
def download_from_server(self, args, server):
self.gconvNames = server.W.keys()
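        # FedStar downloads only the structure-encoder parameters (names containing '_s');
        # every other algorithm copies the full server model.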
if args.alg == 'fedstar':
for k in server.W:
if '_s' in k:
self.W[k].data = server.W[k].data.clone()
else:
for k in server.W:
self.W[k].data = server.W[k].data.clone()
def cache_weights(self):
for name in self.W.keys():
self.W_old[name].data = self.W[name].data.clone()
def reset(self):
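        # restore the shared conv layers from the cached weight snapshot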
copy(target=self.W, source=self.W_old, keys=self.gconvNames)
def local_train(self, local_epoch):
""" For self-train & FedAvg """
train_stats = train_gc(self.model, self.dataLoader, self.optimizer, local_epoch, self.args.device)
self.train_stats = train_stats
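        # track weight- and gradient-norms (full model and conv layers) for monitoring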
self.weightsNorm = torch.norm(flatten(self.W)).item()
weights_conv = {key: self.W[key] for key in self.gconvNames}
self.convWeightsNorm = torch.norm(flatten(weights_conv)).item()
grads = {key: value.grad for key, value in self.W.items()}
self.gradsNorm = torch.norm(flatten(grads)).item()
grads_conv = {key: self.W[key].grad for key in self.gconvNames}
self.convGradsNorm = torch.norm(flatten(grads_conv)).item()
def local_train_loss(self, local_epoch):
""" For self-train & FedAvg """
train_stats = train_gc_loss(self.model, self.dataLoader, self.optimizer, local_epoch, self.args.device)
self.train_stats = train_stats
self.weightsNorm = torch.norm(flatten(self.W)).item()
weights_conv = {key: self.W[key] for key in self.gconvNames}
self.convWeightsNorm = torch.norm(flatten(weights_conv)).item()
grads = {key: value.grad for key, value in self.W.items()}
self.gradsNorm = torch.norm(flatten(grads)).item()
grads_conv = {key: self.W[key].grad for key in self.gconvNames}
self.convGradsNorm = torch.norm(flatten(grads_conv)).item()
def compute_weight_update(self, local_epoch):
""" For GCFL """
copy(target=self.W_old, source=self.W, keys=self.gconvNames)
train_stats = train_gc(self.model, self.dataLoader, self.optimizer, local_epoch, self.args.device)
subtract_(target=self.dW, minuend=self.W, subtrahend=self.W_old)
self.train_stats = train_stats
self.weightsNorm = torch.norm(flatten(self.W)).item()
weights_conv = {key: self.W[key] for key in self.gconvNames}
self.convWeightsNorm = torch.norm(flatten(weights_conv)).item()
dWs_conv = {key: self.dW[key] for key in self.gconvNames}
self.convDWsNorm = torch.norm(flatten(dWs_conv)).item()
grads = {key: value.grad for key, value in self.W.items()}
self.gradsNorm = torch.norm(flatten(grads)).item()
grads_conv = {key: self.W[key].grad for key in self.gconvNames}
self.convGradsNorm = torch.norm(flatten(grads_conv)).item()
def evaluate(self):
return eval_gc(self.model, self.dataLoader['test'], self.args.device)
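    # extended evaluation: additionally returns macro F1, precision, recall, and AUC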
def evaluate_(self):
return eval_gc_(self.model, self.dataLoader['test'], self.args.device)
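    # extended evaluation on node features truncated to the first three columns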
def evaluate_loss_(self):
return eval_gc_loss_(self.model, self.dataLoader['test'], self.args.device)
def local_train_prox(self, local_epoch, mu):
""" For FedProx """
train_stats = train_gc_prox(self.model, self.dataLoader, self.optimizer, local_epoch, self.args.device,
self.gconvNames, self.W, mu, self.W_old)
self.train_stats = train_stats
self.weightsNorm = torch.norm(flatten(self.W)).item()
weights_conv = {key: self.W[key] for key in self.gconvNames}
self.convWeightsNorm = torch.norm(flatten(weights_conv)).item()
grads = {key: value.grad for key, value in self.W.items()}
self.gradsNorm = torch.norm(flatten(grads)).item()
grads_conv = {key: self.W[key].grad for key in self.gconvNames}
self.convGradsNorm = torch.norm(flatten(grads_conv)).item()
def evaluate_prox(self, mu):
return eval_gc_prox(self.model, self.dataLoader['test'], self.args.device, self.gconvNames, mu, self.W_old)
def copy(target, source, keys):
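    # in-place copy of the tensors named in keys from source into target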
for name in keys:
target[name].data = source[name].data.clone()
def subtract_(target, minuend, subtrahend):
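    # target[name] = minuend[name] - subtrahend[name], written into target's .data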
for name in target:
target[name].data = minuend[name].data.clone() - subtrahend[name].data.clone()
def flatten(w):
return torch.cat([v.flatten() for v in w.values()])
def calc_gradsNorm(gconvNames, Ws):
grads_conv = {k: Ws[k].grad for k in gconvNames}
convGradsNorm = torch.norm(flatten(grads_conv)).item()
return convGradsNorm
def train_gc(model, dataloaders, optimizer, local_epoch, device):
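    # train for local_epoch epochs; returns per-epoch training/validation/test losses and accuracies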
losses_train, accs_train, losses_val, accs_val, losses_test, accs_test = [], [], [], [], [], []
train_loader, val_loader, test_loader = dataloaders['train'], dataloaders['val'], dataloaders['test']
for epoch in range(local_epoch):
model.train()
total_loss = 0.
ngraphs = 0
acc_sum = 0
        for batch in train_loader:
batch.to(device)
optimizer.zero_grad()
pred = model(batch)
label = batch.y
acc_sum += pred.max(dim=1)[1].eq(label).sum().item()
loss = model.loss(pred, label)
loss.backward()
optimizer.step()
total_loss += loss.item() * batch.num_graphs
ngraphs += batch.num_graphs
total_loss /= ngraphs
acc = acc_sum / ngraphs
loss_v, acc_v = eval_gc(model, val_loader, device)
loss_tt, acc_tt = eval_gc(model, test_loader, device)
losses_train.append(total_loss)
accs_train.append(acc)
losses_val.append(loss_v)
accs_val.append(acc_v)
losses_test.append(loss_tt)
accs_test.append(acc_tt)
return {'trainingLosses': losses_train, 'trainingAccs': accs_train, 'valLosses': losses_val, 'valAccs': accs_val,
'testLosses': losses_test, 'testAccs': accs_test}
def train_gc_loss(model, dataloaders, optimizer, local_epoch, device):
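    # same as train_gc, but node features are truncated to the first three columns (eval via eval_gc_loss)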
losses_train, accs_train, losses_val, accs_val, losses_test, accs_test = [], [], [], [], [], []
train_loader, val_loader, test_loader = dataloaders['train'], dataloaders['val'], dataloaders['test']
for epoch in range(local_epoch):
model.train()
total_loss = 0.
ngraphs = 0
acc_sum = 0
        for batch in train_loader:
            batch.x = batch.x[:, 0:3]  # keep only the first three node-feature columns
batch.to(device)
optimizer.zero_grad()
pred = model(batch)
label = batch.y
acc_sum += pred.max(dim=1)[1].eq(label).sum().item()
loss = model.loss(pred, label)
loss.backward()
optimizer.step()
total_loss += loss.item() * batch.num_graphs
ngraphs += batch.num_graphs
total_loss /= ngraphs
acc = acc_sum / ngraphs
loss_v, acc_v = eval_gc_loss(model, val_loader, device)
loss_tt, acc_tt = eval_gc_loss(model, test_loader, device)
losses_train.append(total_loss)
accs_train.append(acc)
losses_val.append(loss_v)
accs_val.append(acc_v)
losses_test.append(loss_tt)
accs_test.append(acc_tt)
return {'trainingLosses': losses_train, 'trainingAccs': accs_train, 'valLosses': losses_val, 'valAccs': accs_val,
'testLosses': losses_test, 'testAccs': accs_test}
def train_gc_pe(model, dataloaders, optimizer, local_epoch, device):
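    # variant of train_gc that replaces node features with all-ones (feature-agnostic, structure-only training)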
losses_train, accs_train, losses_val, accs_val, losses_test, accs_test = [], [], [], [], [], []
train_loader, val_loader, test_loader = dataloaders['train'], dataloaders['val'], dataloaders['test']
for epoch in range(local_epoch):
model.train()
total_loss = 0.
ngraphs = 0
acc_sum = 0
        for batch in train_loader:
batch.x = torch.ones_like(batch.x)
batch.to(device)
optimizer.zero_grad()
pred = model(batch)
label = batch.y
acc_sum += pred.max(dim=1)[1].eq(label).sum().item()
loss = model.loss(pred, label)
loss.backward()
optimizer.step()
total_loss += loss.item() * batch.num_graphs
ngraphs += batch.num_graphs
total_loss /= ngraphs
acc = acc_sum / ngraphs
loss_v, acc_v = eval_gc(model, val_loader, device)
loss_tt, acc_tt = eval_gc(model, test_loader, device)
losses_train.append(total_loss)
accs_train.append(acc)
losses_val.append(loss_v)
accs_val.append(acc_v)
losses_test.append(loss_tt)
accs_test.append(acc_tt)
return {'trainingLosses': losses_train, 'trainingAccs': accs_train, 'valLosses': losses_val, 'valAccs': accs_val,
'testLosses': losses_test, 'testAccs': accs_test}
def eval_gc(model, test_loader, device):
model.eval()
total_loss = 0.
acc_sum = 0.
ngraphs = 0
for batch in test_loader:
batch.to(device)
with torch.no_grad():
pred = model(batch)
label = batch.y
loss = model.loss(pred, label)
total_loss += loss.item() * batch.num_graphs
acc_sum += pred.max(dim=1)[1].eq(label).sum().item()
ngraphs += batch.num_graphs
    return total_loss / ngraphs, acc_sum / ngraphs
def eval_gc_loss(model, test_loader, device):
model.eval()
total_loss = 0.
acc_sum = 0.
ngraphs = 0
for batch in test_loader:
        batch.x = batch.x[:, 0:3]  # match train_gc_loss: evaluate on the same truncated features
batch.to(device)
with torch.no_grad():
pred = model(batch)
label = batch.y
loss = model.loss(pred, label)
total_loss += loss.item() * batch.num_graphs
acc_sum += pred.max(dim=1)[1].eq(label).sum().item()
ngraphs += batch.num_graphs
    return total_loss / ngraphs, acc_sum / ngraphs
def eval_gc_(model, test_loader, device):
    model.eval()
    total_loss = 0.
    acc_sum = 0.
    ngraphs = 0
    labels_all, outs_all = [], []
    for batch in test_loader:
        batch.to(device)
        with torch.no_grad():
            out = model(batch)
            label = batch.y
            loss = model.loss(out, label)
        total_loss += loss.item() * batch.num_graphs
        acc_sum += out.max(dim=1)[1].eq(label).sum().item()
        ngraphs += batch.num_graphs
        labels_all.append(label.detach().cpu())
        outs_all.append(out.detach().cpu())
    # aggregate over the whole test set so the metrics below are not computed on the last batch only
    out = torch.cat(outs_all, dim=0)
    y_true = torch.cat(labels_all, dim=0).numpy()
    y_pred = out.argmax(dim=1).numpy()
    f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    precision = precision_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    recall = recall_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    # one-hot encode with the model's output width; raw logits serve as class scores
    # note: roc_auc_score raises if some class never appears in the test labels
    auc = roc_auc_score(np.eye(out.shape[1])[y_true], out.numpy())
    return total_loss / ngraphs, acc_sum / ngraphs, f1, precision, recall, auc
def eval_gc_loss_(model, test_loader, device):
    model.eval()
    total_loss = 0.
    acc_sum = 0.
    ngraphs = 0
    labels_all, outs_all = [], []
    for batch in test_loader:
        batch.x = batch.x[:, 0:3]  # keep only the first three node-feature columns
        batch.to(device)
        with torch.no_grad():
            out = model(batch)
            label = batch.y
            loss = model.loss(out, label)
        total_loss += loss.item() * batch.num_graphs
        acc_sum += out.max(dim=1)[1].eq(label).sum().item()
        ngraphs += batch.num_graphs
        labels_all.append(label.detach().cpu())
        outs_all.append(out.detach().cpu())
    # aggregate over the whole test set (not just the last batch) before computing metrics
    out = torch.cat(outs_all, dim=0)
    y_true = torch.cat(labels_all, dim=0).numpy()
    y_pred = out.argmax(dim=1).numpy()
    f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    precision = precision_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    recall = recall_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=1)
    auc = roc_auc_score(np.eye(out.shape[1])[y_true], out.numpy())
    return total_loss / ngraphs, acc_sum / ngraphs, f1, precision, recall, auc
def _prox_term(model, gconvNames, Wt):
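    # FedProx proximal term: squared L2 distance between the shared layers and the global weights Wt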
prox = torch.tensor(0., requires_grad=True)
for name, param in model.named_parameters():
# only add the prox term for sharing layers (gConv)
if name in gconvNames:
prox = prox + torch.norm(param - Wt[name]).pow(2)
return prox
def train_gc_prox(model, dataloaders, optimizer, local_epoch, device, gconvNames, Ws, mu, Wt):
losses_train, accs_train, losses_val, accs_val, losses_test, accs_test = [], [], [], [], [], []
convGradsNorm = []
train_loader, val_loader, test_loader = dataloaders['train'], dataloaders['val'], dataloaders['test']
for epoch in range(local_epoch):
model.train()
total_loss = 0.
ngraphs = 0
acc_sum = 0
        for batch in train_loader:
batch.to(device)
optimizer.zero_grad()
pred = model(batch)
label = batch.y
acc_sum += pred.max(dim=1)[1].eq(label).sum().item()
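            # task loss plus the FedProx proximal penalty, weighted by mu/2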
loss = model.loss(pred, label) + mu / 2. * _prox_term(model, gconvNames, Wt)
loss.backward()
optimizer.step()
total_loss += loss.item() * batch.num_graphs
ngraphs += batch.num_graphs
total_loss /= ngraphs
acc = acc_sum / ngraphs
loss_v, acc_v = eval_gc(model, val_loader, device)
loss_tt, acc_tt = eval_gc(model, test_loader, device)
losses_train.append(total_loss)
accs_train.append(acc)
losses_val.append(loss_v)
accs_val.append(acc_v)
losses_test.append(loss_tt)
accs_test.append(acc_tt)
convGradsNorm.append(calc_gradsNorm(gconvNames, Ws))
return {'trainingLosses': losses_train, 'trainingAccs': accs_train, 'valLosses': losses_val, 'valAccs': accs_val,
'testLosses': losses_test, 'testAccs': accs_test, 'convGradsNorm': convGradsNorm}
def eval_gc_prox(model, test_loader, device, gconvNames, mu, Wt):
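    # evaluation counterpart of train_gc_prox: the reported loss includes the proximal penalty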
model.eval()
total_loss = 0.
acc_sum = 0.
ngraphs = 0
for batch in test_loader:
batch.to(device)
with torch.no_grad():
pred = model(batch)
label = batch.y
loss = model.loss(pred, label) + mu / 2. * _prox_term(model, gconvNames, Wt)
total_loss += loss.item() * batch.num_graphs
acc_sum += pred.max(dim=1)[1].eq(label).sum().item()
ngraphs += batch.num_graphs
    return total_loss / ngraphs, acc_sum / ngraphs