framework-chl-temp / tester.py
tester.py
Raw
import torch
import rasterio
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import netCDF4 as nc
class DeepNN12(nn.Module):
    """Fully connected regressor with three hidden layers (512 -> 256 -> 128).

    Maps ``input_dim`` features to a single output value. Dropout with
    probability 0.2 follows the first two hidden layers; every hidden layer
    uses a LeakyReLU activation.
    """

    def __init__(self, input_dim):
        super().__init__()
        width = 512
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.layer1 = nn.Linear(input_dim, width)
        self.layer2 = nn.Linear(width, width // 2)
        self.layer3 = nn.Linear(width // 2, width // 4)
        self.output = nn.Linear(width // 4, 1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Return the network output for a batch ``x`` whose last dim is ``input_dim``."""
        hidden = self.dropout(F.leaky_relu(self.layer1(x)))
        hidden = self.dropout(F.leaky_relu(self.layer2(hidden)))
        # No dropout after the last hidden layer.
        hidden = F.leaky_relu(self.layer3(hidden))
        return self.output(hidden)
        
class DeepNN3(nn.Module):
    """Fully connected regressor with five hidden layers (512-256-128-64-32).

    Maps ``input_dim`` features to a single output value. Every hidden layer
    uses a LeakyReLU activation; dropout with probability 0.3 follows the
    first four hidden layers.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.layer1 = nn.Linear(input_dim, 512)
        self.layer2 = nn.Linear(512, 256)
        self.layer3 = nn.Linear(256, 128)
        self.layer4 = nn.Linear(128, 64)
        self.layer5 = nn.Linear(64, 32)
        self.output = nn.Linear(32, 1)  # single-value output layer
        self.dropout = nn.Dropout(0.3)  # regularisation

    def forward(self, x):
        """Return the network output for a batch ``x`` whose last dim is ``input_dim``."""
        # Layers 1-4: linear -> LeakyReLU -> dropout.
        for hidden_layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = self.dropout(F.leaky_relu(hidden_layer(x)))
        # Layer 5 is activated but not followed by dropout.
        x = F.leaky_relu(self.layer5(x))
        return self.output(x)

# Create one model per number of available input features
model_3_features = DeepNN3(input_dim=3)
model_2_features = DeepNN12(input_dim=2)
model_1_feature = DeepNN12(input_dim=1)

# Load the trained weights.
# map_location='cpu' lets checkpoints that were saved on a GPU load on a
# CPU-only machine; the models above live on the CPU either way.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model_3_features.load_state_dict(torch.load('C:/Users/Personal/Documents/Datasets/DeepModels/DataSetDef3_deepnn_model.pth', map_location='cpu'))
model_2_features.load_state_dict(torch.load('C:/Users/Personal/Documents/Datasets/DeepModels/DataSetDef2_deepnn_model.pth', map_location='cpu'))
model_1_feature.load_state_dict(torch.load('C:/Users/Personal/Documents/Datasets/DeepModels/DataSetDef1_deepnn_model.pth', map_location='cpu'))

# Switch to evaluation mode so dropout is disabled during inference
model_3_features.eval()
model_2_features.eval()
model_1_feature.eval()

# Load the reference NetCDF file that provides the georeferencing
nc_file = 'C:/Users/Personal/Documents/Periodo5/Images2020WCO/ImPack4/Corregistradas/04_mapped_JPSS1_VIIRS.20200801T175400.L2.OC.x.nc'
# Open the dataset in a context manager so the file handle is released as
# soon as everything needed has been read. After this block `ds` and `crs1`
# remain bound but refer to a closed dataset.
with nc.Dataset(nc_file) as ds:
    # Extract the chlorophyll-a band
    chlorophyll_a = ds.variables['chlor_a'][:]
    crs1 = ds.variables['crs']

    latitude = ds.variables['latitude'][:]
    longitude = ds.variables['longitude'][:]

    # Read the CRS (coordinate reference system) WKT while the file is open
    crs = crs1.wkt

print(latitude.min(),latitude.max(),longitude.min(),longitude.max())

# Build an affine transform from the coordinate extent and the grid size
transform = rasterio.transform.from_bounds(
    west=longitude.min(),
    south=latitude.min(),
    east=longitude.max(),
    north=latitude.max(),
    width=chlorophyll_a.shape[1],
    height=chlorophyll_a.shape[0]
)
def load_image(file_path):
    """Read band 1 of the raster at ``file_path``.

    Returns a ``(band, profile)`` tuple: the first band as returned by
    rasterio and the dataset's profile.
    """
    with rasterio.open(file_path) as dataset:
        # Only the first band is read (the file is treated as single-band).
        return dataset.read(1), dataset.profile

# Load the May, June and July images plus the August target image
image_august, profile_august = load_image(
    'C:/Users/Personal/Documents/Periodo5/Images2020WCO/ImPack4/Corregistradas/modis_clo/clora_04_mapped_JPSS1_VIIRS.20200801T175400.L2.OC.x.nc.tif'
)
image_july, profile_july = load_image(
    'C:/Users/Personal/Documents/Periodo5/Images2020WCO/ImPack4/Corregistradas/modis_clo/clora_03_mapped_JPSS1_VIIRS.20200701T191800.L2.OC.x.nc.tif'
)
image_june, profile_june = load_image(
    'C:/Users/Personal/Documents/Periodo5/Images2020WCO/ImPack4/Corregistradas/modis_clo/clora_02_mapped_JPSS1_VIIRS.20200601T184201.L2.OC.x.nc.tif'
)
image_may, profile_may = load_image(
    'C:/Users/Personal/Documents/Periodo5/Images2020WCO/ImPack4/Corregistradas/modis_clo/clora_01_mapped_JPSS1_VIIRS.20200501T181801.L2.OC.x.nc.tif'
)

def _predict_in_batches(model, features, batch_size):
    """Run ``model`` over the rows of ``features`` in chunks; return a 1-D array."""
    outputs = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for start in range(0, len(features), batch_size):
            chunk = torch.as_tensor(features[start:start + batch_size], dtype=torch.float32)
            outputs.append(model(chunk).reshape(-1).numpy())
    return np.concatenate(outputs)


def predict_chlorophyll(image_target, image_july, image_june, image_may,
                        models=None, batch_size=65536):
    """Fill the cloudy pixels of ``image_target`` from the previous months.

    A pixel equal to 0.0 is treated as cloudy/missing. For each cloudy target
    pixel the prediction is chosen as follows:

    * July, June and May all non-zero -> 3-feature model on (july, june, may)
    * July and June non-zero, May zero -> 2-feature model on (july, june)
    * July non-zero, June zero         -> 1-feature model on (july,)
    * otherwise                        -> left at 0.0

    Non-cloudy target pixels are copied through unchanged.

    Pixels of each case are gathered with a boolean mask and pushed through
    the model in batches instead of one forward pass per pixel; results agree
    with per-pixel evaluation up to float32 rounding.

    Args:
        image_target: array to fill (August in the calling script).
        image_july, image_june, image_may: predictor arrays with the same
            shape as ``image_target``.
        models: optional mapping ``{3: model, 2: model, 1: model}`` keyed by
            number of input features. Defaults to the module-level
            ``model_3_features``, ``model_2_features`` and ``model_1_feature``.
        batch_size: maximum number of pixels per forward pass (bounds memory).

    Returns:
        A new array with the shape and dtype of ``image_target``.

    Raises:
        ValueError: if the image shapes differ or ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if models is None:
        # Fall back to the models loaded at module level.
        models = {3: model_3_features, 2: model_2_features, 1: model_1_feature}

    predictors = (image_july, image_june, image_may)
    for predictor in predictors:
        if predictor.shape != image_target.shape:
            raise ValueError("all input images must have the same shape")

    predicted_image = np.zeros_like(image_target)
    cloudy = image_target == 0.0
    # Clear-sky pixels keep their original value.
    predicted_image[~cloudy] = image_target[~cloudy]

    july_ok = image_july != 0.0
    june_ok = image_june != 0.0
    may_ok = image_may != 0.0

    # Mutually exclusive masks, mirroring the if/elif priority order above.
    cases = (
        (3, cloudy & july_ok & june_ok & may_ok),
        (2, cloudy & july_ok & june_ok & ~may_ok),
        (1, cloudy & july_ok & ~june_ok),
    )
    for n_features, mask in cases:
        if not mask.any():
            continue
        # Columns are ordered july, june, may - the order the models expect.
        features = np.stack(
            [predictor[mask] for predictor in predictors[:n_features]], axis=1
        ).astype(np.float32)
        predicted_image[mask] = _predict_in_batches(models[n_features], features, batch_size)

    return predicted_image

def save_image(image, file_path, profile):
    """Write ``image`` as band 1 of a new raster at ``file_path``.

    ``profile`` supplies the keyword arguments for ``rasterio.open`` (driver,
    size, count, crs, transform, ...). rasterio needs a ``dtype`` to create a
    dataset in write mode; if the profile does not carry one it is taken from
    ``image``. The caller's ``profile`` dict is not modified.
    """
    write_profile = dict(profile)
    write_profile.setdefault('dtype', np.asarray(image).dtype.name)
    with rasterio.open(file_path, 'w', **write_profile) as dst:
        dst.write(image, 1)

# Run the gap-filling prediction for the August image
predicted_august = predict_chlorophyll(image_august, image_july, image_june, image_may)
print('height',predicted_august.shape[0], 'width',predicted_august.shape[1])

# GeoTIFF profile built from the prediction size and the NetCDF georeferencing
profile = {
    'driver': 'GTiff',
    'height': predicted_august.shape[0],
    'width': predicted_august.shape[1],
    'count': 1,
    # rasterio requires an explicit dtype to create a dataset in 'w' mode
    'dtype': predicted_august.dtype.name,
    'crs': crs,
    'transform': transform
}
print(profile)

# Full output path for the predicted raster
output_path = 'C:/Users/Personal/Documents/Datasets/predicted_august2020def.tiff'
save_image(predicted_august, output_path, profile)