import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
import os 
import argparse
import yaml
import random
from skimage.measure import label as label2
from skimage.measure import regionprops
from scipy.ndimage import label as label1
def set_reproducibility(seed):
    """
    Seed every random number generator used in training and enable deterministic kernels.
    Parameters:
    seed (int): Value used to seed Python's `random`, NumPy and PyTorch (CPU and, if present, CUDA).
    Returns:
    None
    """
    # cuBLAS needs a fixed workspace configuration for deterministic results on CUDA.
    # Format is ":<buffer size in KiB>:<buffer count>"; ":16:8" uses less GPU memory but may be slower.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ':4096:8'
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # warn_only=True: ops without a deterministic implementation warn instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def str2bool(v):
    """
    Interpret a command-line string (or an actual bool) as a boolean, for use as an argparse `type=`.
    Parameters:
    v (str or bool): Value to interpret; strings are matched case-insensitively.
    Returns:
    bool: True for 'yes'/'true'/'t'/'y'/'1', False for 'no'/'false'/'f'/'n'/'0'; bools pass through.
    Raises:
    argparse.ArgumentTypeError: If the string matches neither set.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    truthy = {'yes', 'true', 't', 'y', '1'}
    falsy = {'no', 'false', 'f', 'n', '0'}
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def parse_args():
    """
    Parse command-line arguments and merge them into the YAML training configuration.
    Returns:
    dict: The configuration loaded from ``config/<yaml_file>``, with values overridden by any
          command-line option that was given explicitly.
    Description:
    - ``--yaml_file`` (default 'wildfire.yaml') is resolved relative to the ``config`` directory;
      an absolute path is used as-is.
    - Optional overrides: LR, CHECKPOINT, BATCHSIZE, WD, SESSIONAME, EPOCHS, DEVICE,
      STOPCOUNTER, ONLINELOG, PRETRAINED, OPTIM, SCHED.
    - Only options that were actually passed (non-None) are considered, and only keys that
      already exist in the YAML file are overridden; other command-line values are ignored.
    - An empty YAML file is treated as an empty configuration.
    """
    parser = argparse.ArgumentParser(description='Setting the training parameters')
    parser.add_argument('--yaml_file', type=str, help='Path to YAML file', default='wildfire.yaml')
    parser.add_argument('--LR', type=float, help='Learning Rate')
    parser.add_argument('--CHECKPOINT', type=str, help='Checkpoint')
    parser.add_argument('--BATCHSIZE', type=int, help='Batch Size')
    parser.add_argument('--WD', type=float, help='Weight decay')
    parser.add_argument('--SESSIONAME', type=str, help='Session Name')
    parser.add_argument('--EPOCHS', type=int, help='Number of Epochs')
    parser.add_argument('--DEVICE', type=str, help='Device "cpu" or "cuda"')
    parser.add_argument('--STOPCOUNTER', type=int, help='Stop counter')
    parser.add_argument('--ONLINELOG', type=str2bool, help='Online Log in weight and biases')
    parser.add_argument('--PRETRAINED', type=str, help='full path to pretrain weights')
    parser.add_argument('--OPTIM', type=str, help='Optimizer')
    parser.add_argument('--SCHED', type=str, help='Scheduler')

    # Drop options that were not given so they cannot clobber YAML values with None.
    args = {key: value for key, value in vars(parser.parse_args()).items() if value is not None}
    config_path = os.path.join('config', args['yaml_file'])
    with open(config_path, 'r', encoding='utf-8') as file:
        # safe_load returns None for an empty document; treat that as "no settings".
        config = yaml.safe_load(file) or {}
    for key, value in args.items():
        if key in config:
            config[key] = value
    return config
def plot_instance_masks(gt_instance_mask, pred_instance_mask, gt_class, pred_major_class, instance_id, img_shape=None):
    """
    Debug helper: show a ground-truth instance, the prediction inside it, and the mismatched pixels.
    Parameters:
    gt_instance_mask (array-like): Boolean mask of the instance, either 2-D or flattened.
    pred_instance_mask (array-like): Predicted labels, either a full map with as many elements as
        ``gt_instance_mask`` or only the labels inside the instance (i.e. ``pred[gt_instance_mask]``).
    gt_class (int): Ground-truth class of the instance.
    pred_major_class (int): Most frequent predicted class inside the instance.
    instance_id (int): Instance identifier shown in the title.
    img_shape (tuple, optional): (H, W) used to reshape the masks. Defaults to the shape of a 2-D
        ``gt_instance_mask``, otherwise (254, 254).
    Returns:
    None
    """
    gt = np.asarray(gt_instance_mask)
    if img_shape is None:
        img_shape = gt.shape if gt.ndim == 2 else (254, 254)
    gt = gt.reshape(img_shape).astype(bool)

    pred = np.asarray(pred_instance_mask)
    if pred.size == gt.size:
        pred_full = pred.reshape(img_shape)
    else:
        # Only the labels inside the instance were passed: scatter them back to their pixels.
        pred_full = np.zeros(img_shape, dtype=pred.dtype)
        pred_full[gt] = pred.reshape(-1)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(gt, cmap='gray')
    axes[0].set_title(f'Ground Truth Instance (Class {gt_class}) - ID {instance_id}')
    # Hide everything outside the instance so only the relevant predictions are coloured.
    axes[1].imshow(np.ma.masked_where(~gt, pred_full))
    axes[1].set_title(f'Prediction inside instance (majority class {pred_major_class})')
    axes[2].imshow(gt & (pred_full != gt_class), cmap='gray')
    axes[2].set_title('Pixels not predicted as ground-truth class')
    for ax in axes:
        ax.axis('off')

    plt.show()
def process_mask(mask, k, class_ids=(1, 2)):
    """
    Remove small connected components from a label mask.
    Parameters:
    mask (torch.Tensor or numpy.ndarray): Integer label mask. Tensors are detached and moved to CPU.
    k (int or float): Minimum component size in pixels; components with fewer pixels are dropped.
    class_ids (iterable of int, optional): Labels whose components are filtered and kept.
        Defaults to (1, 2).
    Returns:
    numpy.ndarray: Array with the same shape and dtype as ``mask`` that contains only the
    components of ``class_ids`` with at least ``k`` pixels; every other pixel is 0.
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    else:
        mask = np.asarray(mask)
    output_mask = np.zeros_like(mask)
    # Full connectivity: diagonal neighbours belong to the same component.
    structure = np.ones((3,) * mask.ndim, dtype=bool)

    for class_id in class_ids:
        labeled_mask, num_components = label1(mask == class_id, structure=structure)
        if num_components == 0:
            continue
        # areas[j] = pixel count of component j (index 0 is everything outside this class).
        areas = np.bincount(labeled_mask.ravel())
        keep = areas >= k
        keep[0] = False
        # Lookup table indexed by component label -> keep/drop, applied to all pixels at once.
        output_mask[keep[labeled_mask]] = class_id

    return output_mask
def get_confusion_matrix_instancewise(seg_gt, output, num_class, k=0.5, ignore=255):
    """
    Compute an instance-level confusion matrix for a segmentation task.
    Each connected region of a single non-background ground-truth class (scipy.ndimage.label with
    its default structuring element) is one instance. An instance counts as correct (row = column =
    its class) when at least a fraction ``k`` of its pixels are predicted as its class. Otherwise it
    is counted at (its class, the most frequent predicted class inside it). Predicted regions that
    do not overlap any ground-truth instance are not counted.
    Parameters:
    seg_gt (torch.Tensor): Ground truth segmentation map with shape (N, H, W). Class 0 is treated
        as background and never forms an instance.
    output (torch.Tensor): Model output predictions with shape (N, num_class, H, W); the
        prediction is the argmax over dim 1.
    num_class (int): The number of classes in the segmentation task.
    k (float): Minimum fraction of an instance's pixels that must be predicted as its class.
    ignore (int, optional): Ground-truth label that never belongs to an instance. Defaults to 255.
    Returns:
    numpy.ndarray: int64 confusion matrix of shape (num_class, num_class); rows are ground-truth
    classes, columns are assigned classes.
    """
    seg_gt = seg_gt.to(torch.long).detach().cpu().numpy()
    seg_pred = output.argmax(dim=1).to(torch.long).detach().cpu().numpy()
    confusion_matrix = np.zeros((num_class, num_class), dtype=np.int64)

    for gt_mask, pred_mask in zip(seg_gt, seg_pred):  # iterate over the batch
        for gt_class in np.unique(gt_mask[gt_mask != ignore]):
            if gt_class == 0:
                continue  # background is not an instance
            # Label one class at a time so touching regions of different classes stay separate.
            labeled_gt, num_instances = label1(gt_mask == gt_class)
            if num_instances == 0:
                continue
            # pair_counts[j, c] = pixels of instance j+1 predicted as class c (row 0 = outside).
            pair_index = labeled_gt.ravel().astype(np.int64) * num_class + pred_mask.ravel()
            pair_counts = np.bincount(
                pair_index, minlength=(num_instances + 1) * num_class
            ).reshape(num_instances + 1, num_class)[1:]
            areas = pair_counts.sum(axis=1)
            overlap = pair_counts[:, gt_class] / areas
            # argmax picks the smallest class index on ties, like argmax over np.unique counts.
            majority = pair_counts.argmax(axis=1)
            # NOTE(review): a missed instance whose plurality class is still gt_class lands on the
            # diagonal as well; confirm that is the intended accounting.
            targets = np.where(overlap >= k, gt_class, majority)
            confusion_matrix[gt_class] += np.bincount(targets, minlength=num_class)

    return confusion_matrix
def get_confusion_matrix(seg_gt, output, num_class, ignore=255):
    """
    Compute the confusion matrix for a segmentation task.
    Parameters:
    seg_gt (torch.Tensor): Ground truth segmentation map with shape (N, H, W).
    output (torch.Tensor): Model output predictions with shape (N, num_class, H, W).
    num_class (int): The number of classes in the segmentation task.
    ignore (int, optional): Class index to ignore in the computation. Defaults to 255.
    Returns:
    numpy.ndarray: Confusion matrix of shape (num_class, num_class).
    """
    seg_gt = seg_gt.to(torch.long)
    seg_pred = output.argmax(dim=1).to(torch.long)
    valid_mask = seg_gt != ignore
    seg_gt = seg_gt[valid_mask]
    seg_pred = seg_pred[valid_mask]
    index = seg_gt * num_class + seg_pred
    confusion_matrix = torch.bincount(index, minlength=num_class**2, weights=None).reshape(num_class, num_class).detach().cpu().numpy()
    return confusion_matrix
def calculate_metrics(confusion_matrix, runloss, cls_names = None, cls_weights=None, val=False):
    """
    Calculate per-class and averaged evaluation metrics from a confusion matrix.
    Parameters:
    confusion_matrix (numpy.ndarray): Counts of shape (num_classes, num_classes); rows are
        ground truth, columns are predictions.
    runloss (float): Loss value reported under the "avgloss" key.
    cls_names (sequence, optional): Names used in the per-class keys. Defaults to the class indices.
    cls_weights (array-like, optional): Non-negative per-class weights for "weighted_f1"; they are
        normalised to sum to 1. Defaults to uniform weights, so "weighted_f1" equals "avg_f1".
    val (bool, optional): If True, every key is prefixed with "val ". Defaults to False.
    Returns:
    tuple: (metrics, iou) where metrics is a dict with averaged metrics (avg_f1, avg_acc,
    total_acc, avg_precision, avg_recall, miou, avgloss, weighted_f1) plus per-class precision,
    recall, f1_score, iou and accuracy, and iou is the per-class IoU array.
    Conventions: a class with no predictions gets precision 1; a class with no ground truth gets
    recall 0; an empty matrix gets accuracy 1.
    """
    confusion_matrix = np.asarray(confusion_matrix)
    num_classes = confusion_matrix.shape[0]
    prefix = 'val ' if val else ''
    precision = np.zeros(num_classes)
    recall = np.zeros(num_classes)
    f1_score = np.zeros(num_classes)
    accuracy = np.zeros(num_classes)
    iou = np.zeros(num_classes)
    total_samples = np.sum(confusion_matrix)

    for i in range(num_classes):
        TP = confusion_matrix[i, i]
        FP = np.sum(confusion_matrix[:, i]) - TP
        FN = np.sum(confusion_matrix[i, :]) - TP
        TN = total_samples - (TP + FP + FN)
        precision[i] = TP / (TP + FP) if (TP + FP) != 0 else 1
        recall[i] = TP / (TP + FN) if (TP + FN) != 0 else 0
        f1_score[i] = (2 * precision[i] * recall[i] / (precision[i] + recall[i])
                       if (precision[i] + recall[i]) != 0 else 0)
        accuracy[i] = (TP + TN) / total_samples if total_samples != 0 else 1
        iou[i] = TP / (TP + FN + FP) if (TP + FN + FP) != 0 else 0

    # Same empty-matrix convention as the per-class accuracy (avoids a 0/0 NaN).
    total_acc = np.trace(confusion_matrix) / total_samples if total_samples != 0 else 1
    metrics = {
        f"{prefix}avg_f1": np.mean(f1_score),
        f"{prefix}avg_acc": np.mean(accuracy),
        f"{prefix}total_acc": total_acc,
        f"{prefix}avg_precision": np.mean(precision),
        f"{prefix}avg_recall": np.mean(recall),
        f"{prefix}miou": np.mean(iou),
        f"{prefix}avgloss": runloss}

    # `is None` rather than `== None`: the latter is elementwise (ambiguous) for arrays.
    names = np.arange(num_classes) if cls_names is None else cls_names
    if cls_weights is None:
        weights = np.ones(num_classes) / num_classes
    else:
        weights = np.asarray(cls_weights, dtype=float)
        weights = weights / weights.sum()
    metrics[f'{prefix}weighted_f1'] = float(np.dot(weights, f1_score))

    for i in range(num_classes):
        metrics[f'{prefix}precision {names[i]}'] = precision[i]
        metrics[f'{prefix}recall {names[i]}'] = recall[i]
        metrics[f'{prefix}f1_score {names[i]}'] = f1_score[i]
        metrics[f'{prefix}iou {names[i]}'] = iou[i]
        metrics[f'{prefix}accuracy {names[i]}'] = accuracy[i]
    return metrics, iou
def qualitive_eval(inf_model, val_data, ex_path='./outputs', name='example.png'):
    """
    Save a 2x5 grid of up to 10 random validation images with the model's labels overlaid.
    Parameters:
    inf_model (torch.nn.Module): The model used for inference. Assumed to return per-pixel
        class indices of shape (1, H, W) rather than logits — TODO confirm.
    val_data (Dataset): Validation dataset; must expose ``std_rgb``, ``mean_rgb`` (used to undo
        normalisation of the first three channels) and ``label2color``.
    ex_path (str, optional): Directory where the figure is saved (created if missing).
        Defaults to './outputs'.
    name (str, optional): Filename of the saved figure. Defaults to 'example.png'.
    Returns:
    None
    """
    valid_loader = DataLoader(val_data, batch_size=1, shuffle=True)
    f, ax = plt.subplots(2, 5, figsize=(20, 5))
    with torch.no_grad():  # inference only; no autograd graph needed
        # zip stops after 10 samples, or earlier if the dataset is smaller.
        for sampleid, batch in zip(range(10), valid_loader):
            images = batch[0]
            outputs = inf_model(images)
            images = (np.transpose(images[0, :3, :, :].detach().cpu().numpy(), (1, 2, 0)) * val_data.std_rgb) + val_data.mean_rgb
            # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
            images = np.clip(images * 255, 0, 255).astype('uint8')
            outputs = outputs.detach().cpu().numpy().astype('uint8')[0]
            non_bg_idxs = outputs != 0
            outputs = val_data.label2color(outputs)
            # Blend the class colour over non-background pixels (~20% image, ~80% colour).
            images[non_bg_idxs] = 0.199 * images[non_bg_idxs] + 0.799 * outputs[non_bg_idxs]
            ax[sampleid // 5][sampleid % 5].imshow(images, aspect='auto')
    os.makedirs(ex_path, exist_ok=True)
    f.savefig(os.path.join(ex_path, name))
    plt.close(f)
def qualitive_test(inf_model, val_data, ex_path='./outputs'):
    """
    Run the model over every sample of a dataset and save one overlay image per sample.
    Parameters:
    inf_model (torch.nn.Module): The model used for inference. Assumed to return per-pixel
        class indices of shape (1, H, W) rather than logits — TODO confirm.
    val_data (Dataset): Dataset whose items provide the image at index 0 and the sample's
        filename at index 3; must expose ``std_rgb``, ``mean_rgb``, ``label2color`` and
        ``num_classes``.
    ex_path (str, optional): Directory where the images are saved (created if missing).
        Defaults to './outputs'.
    Returns:
    None
    """
    valid_loader = DataLoader(val_data, batch_size=1, shuffle=True)
    os.makedirs(ex_path, exist_ok=True)
    with torch.no_grad():  # inference only; no autograd graph needed
        for batch in valid_loader:
            images, name = batch[0], batch[3]
            outputs = inf_model(images)
            images = (np.transpose(images[0, :3, :, :].detach().cpu().numpy(), (1, 2, 0)) * val_data.std_rgb) + val_data.mean_rgb
            # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
            images = np.clip(images * 255, 0, 255).astype('uint8')
            outputs = outputs.detach().cpu().numpy().astype('uint8')[0]
            non_bg_idxs = outputs != 0
            fire_instances = outputs == 2
            outputs = val_data.label2color(outputs)
            if val_data.num_classes == 3:
                # Zero every colour channel after the first for class-2 pixels (drawn red, assuming RGB).
                outputs[fire_instances, 1:] = 0
            # Blend the class colour over non-background pixels (~20% image, ~80% colour).
            images[non_bg_idxs] = 0.199 * images[non_bg_idxs] + 0.799 * outputs[non_bg_idxs]
            # NOTE(review): name[0][0] assumes the dataset returns the filename wrapped in a
            # sequence; with a bare string it would be only the first character — confirm.
            plt.imsave(os.path.join(ex_path, name[0][0]), images)
            plt.close()