"""Estimate a model's error on a (label-shifted, possibly corrupted) target test set.

Loads a trained checkpoint, temperature-scales it on an in-distribution (iid)
validation split, gathers logits on both the iid split and an out-of-distribution
(ood) test split, applies a Dirichlet label-shift resampling to the ood split,
and then computes one of many error-estimation metrics selected by ``--metric``
(AC, DoC, IM, GDE, ATC-MC/NE, Pseudo, Pseudo-val, ProjNorm, COT, COT-val,
COTT-MC/NE/val-MC, DCOT, SCOTT).  Results are appended to a JSON file under
``results/``; for the ATC/COTT metrics the per-sample cost distribution is also
cached under ``cache/``.
"""

import argparse
from load_data import load_val_dataset, load_test_dataset
from model import ResNet18, ResNet50, VGG11
from misc.temperature_scaling import calibrate
from collections import Counter
from utils import gather_outputs, get_threshold, get_im_estimate, get_temp_dir
from label_shift_utils import get_dirichlet_marginal, get_resampled_indices
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
from torch_datasets.configs import (
    get_n_classes, get_expected_label_distribution, sample_label_dist, sample_val_label_dist
)
from tqdm import tqdm
import time
import math
import ot

# Pin all CUDA work to the first GPU (must be set before CUDA initializes).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def main():
    # generic configs
    parser = argparse.ArgumentParser(description='Estimate target domain performance.')
    parser.add_argument('--arch', default='resnet18', type=str)
    parser.add_argument('--metric', default='EMD', type=str)
    parser.add_argument('--dataset', default='cifar-10', type=str)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--n_val_samples', default=10000, type=int)
    # -1 is passed through to load_test_dataset; presumably it means
    # "use all test samples" — TODO confirm against load_test_dataset.
    parser.add_argument('--n_test_samples', default=-1, type=int)
    parser.add_argument('--dataset_seed', default=1, type=int)
    parser.add_argument('--pretrained', action='store_true', default=False)
    parser.add_argument('--model_seed', default=1, type=int)
    parser.add_argument('--ckpt_epoch', default=20, type=int)
    # synthetic shifts configs
    parser.add_argument('--data_path', default='./data/CIFAR-10', type=str)
    parser.add_argument('--subpopulation', default='same', type=str)
    parser.add_argument('--corruption_path', default='./data/CIFAR-10-C/', type=str)
    parser.add_argument('--corruption', default='clean', type=str)
    parser.add_argument('--severity', default=0, type=int)
    args = parser.parse_args()
    print(vars(args))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Unpack frequently-used settings into locals.
    metric = args.metric
    pretrained = args.pretrained
    model_seed = args.model_seed
    model_epoch = args.ckpt_epoch
    n_test_sample = args.n_test_samples
    dsname = args.dataset
    corruption = args.corruption
    severity = args.severity
    n_class = get_n_classes(args.dataset)
    # When True, the COT/COTT branches would use the true ood labels
    # instead of sampled expected labels, and the metric name gets a
    # '-Py' suffix in the results path.  Hard-coded off here.
    use_py = False

    # load in iid data for calibration
    val_set = load_val_dataset(dsname=dsname,
                               iid_path=args.data_path,
                               pretrained=pretrained,
                               n_val_samples=args.n_val_samples,
                               seed=args.dataset_seed)
    val_iid_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=4)

    # load in ood test data
    valset_ood = load_test_dataset(dsname=dsname,
                                   subpopulation=args.subpopulation,
                                   iid_path=args.data_path,
                                   corr_path=args.corruption_path,
                                   corr=args.corruption,
                                   corr_sev=args.severity,
                                   pretrained=pretrained,
                                   n_test_sample=n_test_sample)
    val_ood_loader = torch.utils.data.DataLoader(valset_ood,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=4)
    # Replace the CLI value (possibly -1) with the actual test-set size.
    n_test_sample = len(valset_ood)

    # Calibration flavor: bias-corrected temperature scaling ('bcts') vs
    # plain temperature scaling ('ts').  opt_bias is hard-coded off.
    opt_bias = False
    cal_str = 'bcts' if opt_bias else 'ts'

    # Cache directories for the gathered model outputs (pickled logits/preds).
    if pretrained:
        cache_dir = f"./cache/{dsname}/{args.arch}_{model_seed}-{model_epoch}/pretrained_{cal_str}"
    else:
        cache_dir = f"./cache/{dsname}/{args.arch}_{model_seed}-{model_epoch}/scratch_{cal_str}"
    os.makedirs(cache_dir, exist_ok=True)
    cache_id_dir = f"{cache_dir}/id_m{model_seed}-{model_epoch}_d{args.dataset_seed}.pkl"
    cache_od_dir = f"{cache_dir}/od_p{args.subpopulation}_m{model_seed}-{model_epoch}_c{corruption}-{severity}_n{n_test_sample}.pkl"

    if pretrained:
        save_dir_path = f"./checkpoints/{dsname}/{args.arch}/pretrained"
    else:
        save_dir_path = f"./checkpoints/{dsname}/{args.arch}/scratch"

    # Checkpoint stores a DataParallel-wrapped model; unwrap via .module.
    ckpt = torch.load(f"{save_dir_path}/base_model_{args.model_seed}-{model_epoch}.pt",
                      map_location=device)
    model = ckpt['model'].module
    model.eval()

    # use temperature scaling to calibrate model
    print('calibrating models...')
    temp_dir = get_temp_dir(cache_dir, model_seed, model_epoch, opt_bias=opt_bias)
    model = calibrate(model, n_class, opt_bias, val_iid_loader, temp_dir)
    print('calibration done.')

    # Gather (logits, argmax predictions, targets) for both splits; results
    # are cached at the given paths by gather_outputs.
    iid_acts, iid_preds, iid_tars = gather_outputs(model, val_iid_loader, device, cache_id_dir)
    ood_acts, ood_preds, ood_tars = gather_outputs(model, val_ood_loader, device, cache_od_dir)

    # Dirichlet label shift: resample the ood split so its label marginal is
    # drawn from Dirichlet(da * n_class * source_marginal).  Larger da =>
    # concentration closer to the source marginal (milder shift).
    da = 50
    src_label_dist = np.array(get_expected_label_distribution(dsname))
    target_label_dist = get_dirichlet_marginal(
        da * get_n_classes(dsname) * src_label_dist, 1
    )
    # NOTE(review): da is the int 50, so `da != 'none'` is always True and the
    # else branch is dead code — presumably a leftover from when da could be a
    # CLI string; confirm before removing.
    if da != 'none':
        target_test_idx = get_resampled_indices(
            ood_tars.cpu().numpy(), n_class, target_label_dist, seed=1,
        )
    else:
        target_test_idx = torch.arange(len(ood_acts))

    # Restrict the ood outputs to the resampled index set.
    ood_acts = ood_acts[target_test_idx]
    ood_preds = ood_preds[target_test_idx]
    ood_tars = ood_tars[target_test_idx]
    # Keep the pre-resampling size: it is baked into the alt-model cache
    # filename in the GDE branch below.
    n_all_test_sample = n_test_sample
    n_test_sample = len(ood_tars)
    # print('shifted target marginal:', target_label_dist.tolist())

    # Convert logits to (calibrated) probabilities on CPU.
    act_fn = nn.Softmax(dim=1)
    iid_acts = act_fn(iid_acts).cpu()
    ood_acts = act_fn(ood_acts).cpu()

    iid_acc = ( (iid_preds == iid_tars).sum() / len(iid_tars) ).item()
    ood_acc = ( (ood_preds == ood_tars).sum() / len(ood_tars) ).item()
    # Mean max-probability on the iid split; after calibration this should be
    # close to iid_acc (small "confidence gap").
    conf = iid_acts.amax(1).mean().item()

    print('n ood test sample:', n_test_sample)
    print('------------------')
    print('validation acc:', iid_acc)
    print('validation confidence:', conf)
    print('confidence gap:', conf - iid_acc)
    print('------------------')
    print()

    # Empirical label/prediction marginals for both splits.
    ood_preds_count = Counter(ood_preds.tolist())
    ood_tars_count = Counter(ood_tars.tolist())
    # NOTE(review): despite the name, this counts the iid *targets*, not the
    # iid predictions — so iid_preds_dist below is really the iid label
    # marginal.  Confirm whether this is intentional (e.g. as a proxy for a
    # calibrated model's prediction marginal) or a copy-paste slip.
    iid_preds_count = Counter(iid_tars.tolist())

    iid_tars_dist = get_expected_label_distribution(args.dataset)
    ood_tars_dist = [ood_tars_count[i] / len(ood_acts) for i in range(n_class)]
    ood_preds_dist = [ood_preds_count[i] / len(ood_acts) for i in range(n_class)]
    iid_preds_dist = [iid_preds_count[i] / len(iid_acts) for i in range(n_class)]

    # Total-variation distances between the various marginals (TV = L1 / 2).
    print('------------------')
    print("ood real label tv:", sum(abs(np.array(ood_tars_dist) - np.array(iid_tars_dist))) / 2 )
    print("ood pseudo label tv:", sum(abs(np.array(ood_preds_dist) - np.array(iid_preds_dist))) / 2 )
    print("ood pseudo-real label tv:", sum(abs(np.array(ood_preds_dist) - np.array(ood_tars_dist))) / 2 )
    print('------------------')
    print()

    start = time.time()

    # ------------------------------------------------------------------
    # Metric dispatch.  Each branch sets `est`, the predicted ood *error*.
    # ------------------------------------------------------------------
    if metric == 'AC':
        # Average Confidence: error ≈ 1 - mean max-probability.
        max_confidence = torch.max(ood_acts, dim=-1)[0]
        est = 1 - torch.mean(max_confidence).item()

    elif metric == 'DoC':
        # Difference of Confidences: shift the known source error by the
        # drop in mean confidence from source to target.
        source_prob = iid_acts.max(1)[0]
        target_prob = ood_acts.max(1)[0]
        source_err = (iid_preds != iid_tars).sum().item() / len(iid_tars)
        est = source_err + torch.mean(source_prob).item() - torch.mean(target_prob).item()

    elif metric == 'IM':
        # Importance-weighting style estimate; semantics live in
        # get_im_estimate — not visible here.
        source_prob = iid_acts.max(1)[0]
        target_prob = ood_acts.max(1)[0]
        est = get_im_estimate(source_prob, target_prob, (iid_preds == iid_tars).cpu()).item()

    elif metric == 'GDE':
        # Generalization Disagreement: error ≈ disagreement rate between two
        # independently-seeded models on the ood data.
        # NOTE(review): seeds.index raises ValueError if model_seed is not in
        # {0, 1, 10}; this branch silently assumes those training seeds exist.
        seeds = [0, 1, 10]
        seed_ind = seeds.index(model_seed)
        # Pick the "next" seed cyclically as the alternate model.
        alt_model_seed = seeds[ (seed_ind + 1) % len(seeds) ]
        alt_ckpt = torch.load(f"{save_dir_path}/base_model_{alt_model_seed}-{model_epoch}.pt",
                              map_location=device)
        alt_model = alt_ckpt['model'].module
        alt_model.eval()
        if pretrained:
            alt_cache_dir = f"./cache/{dsname}/{args.arch}_{alt_model_seed}-{model_epoch}/pretrained_{cal_str}"
        else:
            alt_cache_dir = f"./cache/{dsname}/{args.arch}_{alt_model_seed}-{model_epoch}/scratch_{cal_str}"
        os.makedirs(alt_cache_dir, exist_ok=True)
        alt_temp_dir = get_temp_dir(alt_cache_dir, alt_model_seed, model_epoch, opt_bias=opt_bias)
        alt_model = calibrate(alt_model, n_class, opt_bias, val_iid_loader, alt_temp_dir)
        # Cache name uses the *pre-resampling* test size so it matches the
        # full gather over val_ood_loader.
        alt_cache_od_dir = f"{alt_cache_dir}/od_p{args.subpopulation}_m{alt_model_seed}-{model_epoch}_c{corruption}-{severity}_n{n_all_test_sample}.pkl"
        _, alt_ood_preds, _ = gather_outputs(alt_model, val_ood_loader, device, alt_cache_od_dir)
        # Apply the same label-shift resampling before comparing predictions.
        alt_ood_preds = alt_ood_preds[target_test_idx]
        est = alt_ood_preds.ne(ood_preds).sum().item() / len(alt_ood_preds)

    elif metric == 'ATC-MC':
        # Average Thresholded Confidence (max-confidence score): error ≈
        # fraction of target samples whose confidence falls below a threshold
        # fit on the validation set.
        threshold = get_threshold(model, val_iid_loader, n_class, args)
        mc = ood_acts.max(1)[0]
        est = (mc < threshold).sum().item() / len(ood_acts)
        cost_dist = torch.sort(mc)[0].tolist()

    elif metric == 'ATC-NE':
        # ATC with a negative-entropy score instead of max-confidence.
        # NOTE(review): log2(0) yields -inf and 0 * -inf is NaN, so exact-zero
        # probabilities would poison `ne` — presumably softmax never emits
        # exact zeros here; confirm.
        threshold = get_threshold(model, val_iid_loader, n_class, args)
        ne = torch.sum(ood_acts * torch.log2(ood_acts), dim=1)
        est = (ne < threshold).sum().item() / len(ood_acts)
        cost_dist = torch.sort(ne)[0].tolist()

    elif metric == 'Pseudo':
        # Pseudo-label shift (TV between target predictions and the expected
        # source label marginal) plus the AC error, capped at 1.
        max_confidence = torch.max(ood_acts, dim=-1)[0]
        ac_target_error = 1 - torch.mean(max_confidence).item()
        est = min( sum(abs(np.array(ood_preds_dist) - np.array(iid_tars_dist))) / 2 + ac_target_error, 1 )

    elif metric == 'Pseudo-val':
        # Same as 'Pseudo' but compared against the iid (validation) marginal.
        max_confidence = torch.max(ood_acts, dim=-1)[0]
        ac_target_error = 1 - torch.mean(max_confidence).item()
        est = min( sum(abs(np.array(ood_preds_dist) - np.array(iid_preds_dist))) / 2 + ac_target_error, 1 )

    elif metric == 'ProjNorm':
        # Pseudo-label shift plus the known iid error, capped at 1.
        est = min( sum(abs(np.array(ood_preds_dist) - np.array(iid_preds_dist))) / 2 + ( 1 - iid_acc ), 1 )

    elif metric == 'COT-val':
        # Confidence Optimal Transport against one-hot labels sampled from the
        # validation marginal.  Error ≈ EMD / 2 under the L1 ground metric.
        # Large test sets are processed in resampled batches of <= 10000 to
        # keep the OT problem tractable.
        batch_size = min(10000, n_test_sample)
        n_batch = math.ceil( n_test_sample / batch_size )
        print( f'total of {n_test_sample} test samples, running {n_batch} batches.' )
        if n_batch > 1:
            est = 0
            random.seed(0)
            for _ in range(n_batch):
                # Sample with replacement from the target outputs.
                rand_inds = torch.as_tensor( random.choices( list(range(n_test_sample)), k=batch_size ) )
                iid_acts_batch = nn.functional.one_hot(
                    sample_val_label_dist(iid_preds_dist, n_class, batch_size)
                )
                ood_acts_batch = ood_acts[rand_inds]
                M = torch.cdist(iid_acts_batch.float(), ood_acts_batch, p=1)
                # Empty weight tensors mean uniform marginals in POT.
                weights = torch.as_tensor([])
                est += ( ot.emd2(weights, weights, M, numItermax=1e8, numThreads=8) / 2 ).item()
            est = est / n_batch
        else:
            torch.manual_seed(0)
            exp_labels = sample_val_label_dist(iid_preds_dist, n_class, len(ood_acts))
            iid_acts = nn.functional.one_hot(exp_labels)
            M = torch.cdist(iid_acts.float(), ood_acts, p=1)
            weights = torch.as_tensor([])
            est = ( ot.emd2(weights, weights, M, numItermax=1e8, numThreads=8) / 2 ).item()

    elif metric == 'COT':
        # COT against one-hot labels sampled from the *expected dataset*
        # marginal (or true ood labels when use_py is True).
        batch_size = min(10000, n_test_sample)
        n_batch = math.ceil( n_test_sample / batch_size)
        print( f'total of {n_test_sample} test samples, running {n_batch} batches.' )
        if n_batch > 1:
            est = 0
            random.seed(10)
            for _ in range(n_batch):
                rand_inds = torch.as_tensor( random.choices( list(range(n_test_sample)), k=batch_size ) )
                if not use_py:
                    iid_acts_batch = nn.functional.one_hot(
                        sample_label_dist(dsname, n_class, batch_size)
                    )
                else:
                    iid_acts_batch = nn.functional.one_hot( ood_tars[rand_inds].cpu() )
                ood_acts_batch = ood_acts[rand_inds]
                M = torch.cdist(iid_acts_batch.float(), ood_acts_batch, p=1)
                weights = torch.as_tensor([])
                est += ( ot.emd2(weights, weights, M, numItermax=1e8, numThreads=8) / 2 ).item()
            est = est / n_batch
        else:
            torch.manual_seed(0)
            exp_labels = sample_label_dist(dsname, n_class, len(ood_acts))
            if not use_py:
                iid_acts = nn.functional.one_hot(exp_labels)
            else:
                iid_acts = nn.functional.one_hot(ood_tars).cpu()
            M = torch.cdist(iid_acts.float(), ood_acts, p=1)
            weights = torch.as_tensor([])
            est = ( ot.emd2(weights, weights, M, numItermax=1e8, numThreads=8) / 2 ).item()

    elif metric in ['COTT-MC', 'COTT-NE', 'COTT-val-MC']:
        # COT with Thresholding: solve the OT plan between sampled one-hot
        # labels and target probabilities, derive a per-sample cost, and count
        # how many costs fall below a validation-fit threshold.
        threshold = get_threshold(model, val_iid_loader, n_class, args)
        batch_size = min(10000, n_test_sample)
        n_batch = math.ceil( n_test_sample / batch_size )
        print( f'total of {n_test_sample} test samples, running {n_batch} batches.' )
        if n_batch > 1:
            est = 0
            random.seed(10)
            cost_dist = []
            for _ in range(n_batch):
                rand_inds = torch.as_tensor( random.choices( list(range(n_test_sample)), k=batch_size ) )
                ood_acts_batch = ood_acts[rand_inds]
                if metric == 'COTT-val-MC':
                    exp_labels_batch = sample_val_label_dist(iid_preds_dist, n_class, batch_size)
                else:
                    exp_labels_batch = sample_label_dist(dsname, n_class, batch_size)
                if not use_py:
                    iid_acts_batch = nn.functional.one_hot(exp_labels_batch)
                else:
                    iid_acts_batch = nn.functional.one_hot(ood_tars[rand_inds].cpu())
                M = torch.cdist(iid_acts_batch.float(), ood_acts_batch, p=1)
                weights = torch.as_tensor([])
                # Full transport plan (not just the cost) is needed here.
                Pi = ot.emd(weights, weights, M, numItermax=1e8)
                if metric in ['COTT-MC', 'COTT-val-MC']:
                    # Per-row transport cost; scaled by n rows (Pi sums to 1)
                    # and negated so "lower = worse", matching the threshold.
                    costs = ( Pi * M.shape[0] * M ).sum(1) * -1
                elif metric == 'COTT-NE':
                    # Negative entropy of the average of each one-hot label
                    # and its OT-matched target probability vector.
                    matched_ood_acts_batch = ood_acts_batch[torch.argmax(Pi, dim=1)]
                    matched_acts = (matched_ood_acts_batch + iid_acts_batch) / 2
                    costs = ( matched_acts * torch.log2( matched_acts ) ).sum(1)
                est = est + (costs < threshold).sum().item() / batch_size
                cost_dist.append(costs)
            est = est / n_batch
            cost_dist = torch.sort(torch.cat(cost_dist, dim=0))[0].tolist()
        else:
            torch.manual_seed(10)
            if metric == 'COTT-val-MC':
                exp_labels = sample_val_label_dist(iid_preds_dist, n_class, n_test_sample)
            else:
                exp_labels = sample_label_dist(dsname, n_class, n_test_sample)
            if not use_py:
                iid_acts = nn.functional.one_hot(exp_labels)
            else:
                iid_acts = nn.functional.one_hot(ood_tars.cpu())
            M = torch.cdist(iid_acts.float(), ood_acts, p=1)
            weights = torch.as_tensor([])
            Pi = ot.emd(weights, weights, M, numItermax=1e8)
            if metric in ['COTT-MC', 'COTT-val-MC']:
                costs = ( Pi * M.shape[0] * M ).sum(1) * -1
            elif metric == 'COTT-NE':
                matched_ood_acts = ood_acts[torch.argmax(Pi, dim=1)]
                matched_acts = (matched_ood_acts + iid_acts) / 2
                costs = ( matched_acts * torch.log2( matched_acts ) ).sum(1)
            # In this branch batch_size == n_test_sample (single batch).
            est = (costs < threshold).sum().item() / batch_size
            cost_dist = torch.sort(costs)[0].tolist()

    elif metric == 'DCOT':
        # "Doubled" COT: a cached base EMD between iid probabilities and
        # sampled one-hot labels (Chebyshev/L-inf ground metric) plus an EMD
        # between ood and iid probabilities.
        # NOTE(review): cache_dir is rebound here from a directory path to a
        # JSON file path — confusing but local to this branch.
        if args.pretrained:
            cache_dir = f"cache/{dsname}/{args.arch}_{model_seed}-{model_epoch}/pretrained_dcot_base.json"
        else:
            cache_dir = f"cache/{dsname}/{args.arch}_{model_seed}-{model_epoch}/scratch_dcot_base.json"
        if not os.path.exists(cache_dir):
            torch.manual_seed(0)
            exp_labels = sample_label_dist(dsname, n_class, len(ood_acts))
            label_acts = nn.functional.one_hot(exp_labels)
            # Pairwise L-inf distances, computed via broadcasting
            # (n_iid x n_labels matrix — memory-heavy for large splits).
            M = torch.max( torch.abs( iid_acts.unsqueeze(1) - label_acts.unsqueeze(0) ), dim=-1)[0]
            weights = torch.as_tensor([])
            base_est = ( ot.emd2( weights, weights, M, numItermax=1e8, numThreads=8 ) ).item()
            with open(cache_dir, 'w') as f:
                json.dump({'base': base_est}, f)
        else:
            with open(cache_dir, 'r') as f:
                base_est = json.load(f)['base']
        M2 = torch.max( torch.abs( ood_acts.unsqueeze(1) - iid_acts.unsqueeze(0) ), dim=-1)[0]
        weights = torch.as_tensor([])
        add_est = ( ot.emd2( weights, weights, M2, numItermax=1e8, numThreads=8 ) ).item()
        print('base est:', base_est)
        print('add est:', add_est)
        print('iid error:', 1 - iid_acc)
        est = base_est + add_est

    elif metric == 'SCOTT':
        # Sliced COT with Thresholding: project probabilities onto 8 random
        # unit directions, compare sorted 1-D projections (sliced Wasserstein
        # matching), and threshold the per-slice displacement.
        t = get_threshold(model, val_iid_loader, n_class, args)
        torch.manual_seed(10)
        slices = torch.randn(8, n_class)
        # Normalize each random direction to unit L2 norm.
        # NOTE(review): `slice` shadows the builtin — rename when convenient.
        slices = torch.stack([slice / torch.sqrt( torch.sum( slice ** 2 ) ) for slice in slices], dim=0)
        exp_labels = sample_label_dist(dsname, n_class, len(ood_acts))
        iid_acts = nn.functional.one_hot(exp_labels)
        iid_act_scores = iid_acts.float() @ slices.T
        ood_act_scores = ood_acts.float() @ slices.T
        # Element-wise distance between sorted projections = 1-D OT matching.
        scores = torch.abs( torch.sort(ood_act_scores, dim=0)[0] - torch.sort(iid_act_scores, dim=0)[0] )
        est = ( scores > t ).sum().item() / len(ood_acts) / len(slices)

    # ------------------------------------------------------------------
    # Report and persist the estimate.
    # ------------------------------------------------------------------
    print('------------------')
    print('True OOD error:', 1 - ood_acc)
    print(f'{metric} predicted OOD error:', est)
    print(f'MAE: {abs(1 - ood_acc - est)}')
    print(f'Time: {time.time() - start}')
    print('------------------')
    print()

    if use_py:
        # Tag results that used true target labels instead of sampled ones.
        metric = metric + '-Py'

    # The raw CLI value (-1 = all) is used in the results path, not the
    # post-resampling sample count.
    n_test_str = args.n_test_samples
    if pretrained:
        result_dir = f"results/{dsname}/pretrained_{cal_str}/da_{da}/{args.arch}_{model_seed}-{model_epoch}/{metric}_{n_test_str}/{corruption}.json"
    else:
        result_dir = f"results/{dsname}/scratch_{cal_str}/da_{da}/{args.arch}_{model_seed}-{model_epoch}/{metric}_{n_test_str}/{corruption}.json"
    print(result_dir)

    # Append this run's record to the (possibly pre-existing) results file.
    os.makedirs(os.path.dirname(result_dir), exist_ok=True)
    if not os.path.exists(result_dir):
        with open(result_dir, 'w') as f:
            json.dump([], f)
    with open(result_dir, 'r') as f:
        data = json.load(f)
    data.append({
        'corruption': corruption,
        'corruption level': severity,
        'metric': float(est),
        'ref': metric,
        'acc': float(ood_acc),
        'error': 1 - ood_acc,
        'subpopulation': args.subpopulation,
        'pretrained': pretrained
    })
    with open(result_dir, 'w') as f:
        json.dump(data, f)

    # For threshold-based metrics, also persist the full sorted cost
    # distribution (cost_dist/threshold are only bound in those branches).
    if metric in ['ATC-MC', 'ATC-NE', 'COTT-MC', 'COTT-NE']:
        if args.pretrained:
            cost_dist_dir = f"cache/{dsname}/{args.arch}_{model_seed}-{model_epoch}/pretrained_{cal_str}/da_{da}/{metric}_costs/{corruption}.json"
        else:
            cost_dist_dir = f"cache/{dsname}/{args.arch}_{model_seed}-{model_epoch}/scratch_{cal_str}/da_{da}/{metric}_costs/{corruption}.json"
        os.makedirs(os.path.dirname(cost_dist_dir), exist_ok=True)
        if not os.path.exists(cost_dist_dir):
            with open(cost_dist_dir, 'w') as f:
                json.dump([], f)
        with open(cost_dist_dir, 'r') as f:
            saved_costs = json.load(f)
        saved_costs.append({
            'costs': cost_dist,
            't': threshold,
            'ood error': 1 - ood_acc,
            'iid error': 1 - iid_acc,
            'pred': ood_preds.tolist(),
            'target': ood_tars.tolist(),
            'pseudo-source shift': sum(abs(np.array(ood_preds_dist) - np.array(iid_tars_dist))) / 2,
            'pseudo-target shift': sum(abs(np.array(ood_preds_dist) - np.array(ood_tars_dist))) / 2
        })
        with open(cost_dist_dir, 'w') as f:
            json.dump(saved_costs, f)


if __name__ == "__main__":
    main()