classificationRice / toolsSeed.py
toolsSeed.py
Raw
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 17 10:51:31 2022

@author: VIDO6
"""

import os
import numpy as np
import matplotlib.pyplot as plt
from spectral import open_image, envi
from   skimage.measure import perimeter, label
from skimage.morphology import erosion, square
from math import pi, ceil
#import torch
#import torchvision
#from   torchvision import transforms
from torch.utils.data import Dataset
#from   torch.utils.data.dataloader import DataLoader
#from   torch.utils.data import random_split
#from   sklearn.metrics import confusion_matrix, accuracy_score

#C:/Users/VIDO6/Downloads/class3


class SeedDataset(Dataset):
    
    
    def __init__(self, img_dir, transform=None, tipo_semilla = "", 
                 bandRGB=(0,0,0), subClase = "None", ww = ""):
        
        #if WW is str:
            
        vmax = SeedDataset.contarhdr(img_dir)
        
        if not type(ww) is str:
        
            ww.bProgreso.configure(maximum=vmax)
            
            ww.bProgreso['value'] = 0
        
        
        
        self.classes = os.listdir(img_dir)
        
        self.listaIndCla = list(range(len(self.classes)))
        
        self.semillas = []
        
        subClases = {}
      
        
        for clase in self.classes:
            subClases[clase] = []
     
        
        listaCTN=[]
        
        num = 0
        
        
        for ind, carpeta in enumerate(os.listdir(img_dir)):
          

            dirCarpeta = os.path.join(img_dir, carpeta)
            
            archivos = os.listdir(dirCarpeta)
            
            
            for arc in archivos:
    
                if arc[-3:] != 'hdr':
                    
                    continue
            
                
                dirArhivo = os.path.join(dirCarpeta, arc)
      
                #self.direcciones.append(dirArhivo)
                
                #self.img_labels.append(ind)
                
                subClases[carpeta].append(arc.rsplit('_')[-2])
                
                #listaTiempo.append()
                
                listaCTN.append((arc.rsplit('_')[-3],arc.rsplit('_')[-2],arc.rsplit('_')[-1][:-4]))
                
                
                
                
                
                self.semillas.append( Semilla(dirArhivo, num, 
                                              tipo_semilla=tipo_semilla, 
                                              indiceClase=ind, 
                                              clase=self.classes[ind], 
                                              bandRGB=bandRGB, 
                                              subClase = arc.rsplit('_')[-2]   ) )
                
                if not type(ww) is str:
                    ww.bProgreso['value'] = num+1
                    
                    ww.estadoCarg.set( "Loading\n" + str(num+1) + " of " + str(vmax) + "\nimages" )
                    
                    ww.top.update()
                
                num+=1
                

        todasSubClases = []   
        
        for clase in self.classes:
            subClases[clase] = list(set(subClases[clase]))
            todasSubClases.extend(subClases[clase])
            
        
        self.img_dir = img_dir
        
        self.subClasses = subClases
        
        self.todasSubClases = list( set(todasSubClases))
        
        self.todasSubClases.sort()
        
        self.listaCTN = listaCTN

        self.transform = transform



    def __len__(self):
        
        return len(self.semillas)



    def __getitem__(self, idx):
        
        seed = self.semillas[idx]

        #img_path = self.direcciones[idx]
        
        #image = open_image(img_path).load()
        
        #label = self.img_labels[idx]
        
        if self.transform:
            
            seed = self.transform(seed)
        #if self.target_transform:
            #label = self.target_transform(label)
            
        return seed
    
    
    
    def extraerNumerosCT(self, clase, tiempo):
        listaNumeros = []
        for tupla in self.listaCTN:
            if tupla[0]==clase and tupla[1]==tiempo:
                listaNumeros.append(int(tupla[2]))
                
        listaNumeros.sort()
        
        for i, num in enumerate(listaNumeros):
            listaNumeros[i]=str(num)
        
        return listaNumeros




    def contarhdr(mainDir):
        
        vmax = 0
        
        carpetas = os.listdir(mainDir)
        
      
        for car in carpetas:
      
            dirCarpeta = os.path.join(mainDir, car)
            
            archivos = os.listdir(dirCarpeta)
            
            
            for arc in archivos:
        
                if arc[-3:] == 'hdr':
                    
                    vmax = vmax + 1
        
        return vmax    
    
    
    
    
    def mostrarEspectroColeccion(self, ancho = 5 , altura = 5.5, ymax = 1.65, filas=1):
    
        M = len(self.classes)
        
        columnas = ceil(M/filas)
      
        plt.figure(figsize = (ancho*columnas, altura*filas))

      
        for i, sma in enumerate(self):
            
              
            plt.subplot(filas, columnas, sma.c+1 )
            
            sma.mostrarEspectro(promedio = "Si", formato = 0)
          
            
        for i in range(M):
            
            plt.subplot(filas, columnas, i+1 )
        
            plt.grid()
              
            plt.xlabel("Bands")
            
            plt.ylabel("Intensity")
              
            plt.title( f"Spectrum of {self.classes[i]}" )    
              
            plt.ylim(0, ymax) 
     
    
    def mostrarEspectroClaseYSubclase(self, ancho = 5 , altura = 5.5, ymax = 1.65):
    
        columnas = len(self.classes)
        
        filas = len(self.todasSubClases)
      
        plt.figure(figsize = (ancho*columnas, altura*filas))

      
        for i, sma in enumerate(self):
            
            
            j = self.todasSubClases.index(sma.subClase)
              
            plt.subplot(filas, columnas, columnas*j + sma.c+1 )
            
            
            
            sma.mostrarEspectro(promedio = "Si", formato = 0)
          
        '''   
        for i in range(M):
            
            plt.subplot(filas, columnas, i+1 )
        
            plt.grid()
              
            plt.xlabel("Bands")
            
            plt.ylabel("Intensity")
              
            plt.title( f"Spectrum of {self.classes[i]}" )    
              
            plt.ylim(0, ymax) 
        '''



class Semilla():


  ##############        Constructor           ##################


    def __init__(self, direccion, numero, tipo_semilla = "", indiceClase = 0, clase = "Control", bandRGB=(0,0,0), subClase = "" ):

            
        self.direccion = direccion
        
        imagen = self.imagen()
        
        self.numero, self.c = numero, indiceClase
        
        self.clase = clase

        self.titulo = "Class: {}, image[{}]".format(self.clase, self.numero)
    
        self.mascara = Semilla.hacerMascara(imagen)
  
        self.m, self.n, self.b = imagen.shape
        
        self.tipo_semilla = tipo_semilla
        
        self.indiceRGB = bandRGB
        
        self.RGB = self( self.indiceRGB )
        
        self.subClase = subClase
        
        self.area = self.area()
        
        self.espectroMedio = self.espectro(promedio = "Si", imagen=imagen)
        
        del imagen



    def imagen(self):
        
        return open_image(self.direccion).load()
        



    def __call__(self, banda, imagen = ""):
        
        if type(imagen) is str:
            
            imagen = self.imagen()
            
        if type(banda) is int:
            
            return imagen[:,:,banda].reshape(self.m, self.n)
        
        else:
            
            return imagen[:,:,banda].reshape(self.m, self.n, len(banda))
    
          

    def __str__(self):
        
        return self.titulo



    def hacerMascara(imagen):

        mascara = np.max(imagen, axis=2)
    
        mascara = (mascara>0)
    
        return mascara



    def espectro(self, promedio = "No", imagen = ""):
  
        '''
        Calcula el espectro de cada pixel en la imagen 
        Si promedio == "No" entonces no calcula el promedio de todos los pixeles
        en caso contrario si los calcula.
        '''
        
        if type(imagen) is str:
            
            imagen = self.imagen()   
        
        
        if promedio == "No":
            
            R = self.mascara   
        
            lmax = self.area
            
            tab_spec = np.zeros((lmax, self.b))
            
            
        
            for j in range(self.b):
              
                S = imagen[:,:,j].reshape(self.m, self.n)                
              
                tab_spec[:,j] = S[R]
          
            return tab_spec
    
        else:
            
            mediaConZeros = np.mean(imagen, axis=(0,1))
    
            return mediaConZeros * self.m * self.n / self.area



  
    def mostrarImagen(self, mostrarTitulo=True):

        imp = self.RGB
        
        plt.axis('off')
    
        plt.imshow(imp / np.max(imp) )
        
        
        if mostrarTitulo:
    
            #plt.title( self.titulo + ", Imagen [{}]".format(self.numero))
            plt.title( self.titulo )
    
    
    def mostrarEspectroxy(self,x,y):
        
        x, y = y, x
        
        img = self.imagen()
        
        plt.plot(img[x,y,:].reshape(self.b))
        
        plt.grid()
    
        plt.xlabel("Bands")
        
        plt.ylabel("Intensity")
        
        x, y = y, x
        
        plt.title(f"Spectrum x = {x}, y = {y}")

    def mostrarEspectro(self, promedio = "No", formato = 1):
        
        if promedio == "No":
            
            plt.plot(np.transpose( self.espectro(promedio=promedio) ) )
            
        else:
            
            plt.plot(self.espectroMedio)
    
        if formato == 1:
    
            plt.grid()
        
            plt.xlabel("Bands")
            
            plt.ylabel("Intensity")
    
            plt.title( self.titulo + ", Spectrum" )
            
    '''        

    def mostrarImagenYEspectro(self, ymax = 2):

        plt.figure(figsize=(10,5))
    
        plt.subplot(1,2,1)
        
        self.mostrarImagen()
    
        plt.subplot(1,2,2)
        
        self.mostrarEspectro( promedio = "Si" )
        
        plt.title("Mean Spectrum")
        
        plt.ylim(0, ymax) 
        
        
    '''
    
    def mostrarImagenYEspectro(self, ymax = 2):

        plt.figure(figsize=(15,5))
    
        plt.subplot(1,3,1)
        self.mostrarImagen()
    
        plt.subplot(1,3,2)
        self.mostrarEspectro() 
        plt.ylim(0,ymax)   
    
        plt.subplot(1,3,3)
        self.mostrarEspectro( promedio = "Si" )
        plt.title("Mean Spectrum")
        plt.ylim(0, ymax) 


    '''
    def modificarImagen(self, tam = 60):

        img = self.imagen
    
        if img.ndim == 3:
            
            m, n, b = img.shape
      
            imgMod = np.zeros((tam, tam, b))
            
            for i in range( min(m,tam) ):
                for j in range( min(n,tam) ):
                    imgMod[i,j, :] = img[i, j, :]
        
        elif img.ndim == 2:
          
              m, n = img.shape
              imgMod = np.zeros((tam, tam))
              
              for i in range( min(m,tam) ):
                  for j in range( min(n,tam) ):
                      imgMod[i,j] = img[i, j]    
    
    
        return imgMod
    '''



###############################################################################
############################  Metodos de clase   ##############################
###############################################################################



    def modificarImagen(img, tam = 60):
      
        posx = 0
        posy = 0
        
        if img.ndim == 3:
            
            m, n, b = img.shape
          
            if m < tam:
                posx = (tam - m) // 2
            
            if n < tam:
                posy = (tam - n) // 2
          
            imgMod = np.zeros((tam, tam, b))
            
            for i in range( min(m,tam) ):
                for j in range( min(n,tam) ):
                    imgMod[i + posx, j + posy, :] = img[i, j, :]
        
        elif img.ndim == 2:
          
            m, n = img.shape
          
            if m < tam:
                posx = (tam - m) // 2
            
            if n < tam:
                posy = (tam - n) // 2
          
            imgMod = np.zeros((tam, tam))
            
            for i in range( min(m,tam) ):
                for j in range( min(n,tam) ):
                    imgMod[i + posx, j + posy] = img[i, j]    
        
        
        return imgMod
      




    def minMaxLim(img):
        #img = erosion(img, square(3))
        (M,N)= img.shape
        imin = M
        imax = 0
        jmin = N
        jmax = 0
    
    
        for i in range(M):
            for j in range(N):
                if img[i,j]>0:
                    if i<imin:
                        imin = i
                    if i>imax:
                        imax = i
                    if j<jmin:
                        jmin = j
                    if j>jmax:
                        jmax = j
        
    
        if imin - 1< 0:
            imax = imax + 2
        elif imax + 1 > M:
            imin = imin - 2
        else:
            imin = imin - 1
            imax = imax + 1
        
    
        if jmin - 1 < 0:
            jmax = jmax + 2
        elif jmax + 1 > N:
            jmin = jmin - 2  
        else:
            jmin = jmin - 1
            jmax = jmax + 1  
    
    
        return (imin, imax, jmin, jmax)


    def listImageRec( img, th=0.25, ww = "" ):

        (m,n,b) = img.shape
        imgseg = img[:,:,170].reshape(m,n) > img[:,:,210].reshape(m,n)+th
        imgseg = erosion(imgseg, square(1))
        lista = []
       
        segLabel = label(imgseg)
        lmax = np.max(segLabel)
        
        vmax = lmax 
        
        if not type(ww) is str:
            ww.bProgreso2.configure(maximum=vmax)
            
            ww.bProgreso2['value'] = 0
        
        
        j = 0
       
    
        for i in range(1, lmax+1):
    
            R = (segLabel==i)
            
            if np.sum(R) < 500:
                
                if not type(ww) is str:
                    vmax -= 1
                    ww.bProgreso2.configure(maximum=vmax)
                
                continue
            
            (imin,imax,jmin,jmax)=Semilla.minMaxLim(R)
      
            imgR = Semilla.modificarImagen( img[imin:imax, jmin:jmax,:] )
            imgsegR = Semilla.modificarImagen( imgseg[imin:imax, jmin:jmax].reshape(imax-imin, jmax-jmin, 1) )
      
      
      
            lista.append(imgR*imgsegR)
            
            j += 1
            
            if not type(ww) is str:
            
                ww.bProgreso2['value'] = j
                
                ww.estadoCarg2.set( "Saving\n" + str(j) + " of " + str(vmax) + "\nimages" )
                
                ww.top.update()
            
            
          
    
        return lista
    
    
    def cutAndSegmentation(img, mainDir, clase, subclase="None", ww=""):
        
        listaImage = Semilla.listImageRec(img, ww = ww)
        
        clases = os.listdir(mainDir)
        
        mainDirCar = os.path.join(mainDir, clase)
        
        
        if not clase in clases:
            
            os.mkdir(mainDirCar)
    
        lgt = ceil(len(os.listdir(mainDirCar)) / 2)
    
        for i, imagenCut in enumerate(listaImage):
            
            nuevoNombre = "imagen_{}_{}_{}.hdr".format(clase, subclase, str(i+lgt)) 
            
            newFile = os.path.join(mainDirCar, nuevoNombre)
            
            envi.save_image(newFile, imagenCut)
   

             

  ################# Metodos para Colección   #####################
    '''
    def ordenarColeccion(coleccion):
    
        coleOrd = []
      
      
      
        for c in range(5):
            coleOrd.append([])
            for t in range(6):
                coleOrd[c].append([])
                if c == 0:
                    break
        
      
        for arroz in coleccion:
      
            c = arroz.c
            t = arroz.t
            coleOrd[c][t].append(arroz)
        
        for c in range(5):
          
            for t in range(6):
                coleOrd[c][t] = sorted(coleOrd[c][t], key=lambda grupoct : grupoct.numero)
            
                if c == 0:
                    break
        
        return coleOrd

#------------------------------------------------------------------------


    def hacerColeccion( mainDir =  "/content/drive/MyDrive/DataHyperspectral/rice/Procesado/class2" ):
    
        coleccion = []
      
        carpetas = os.listdir(mainDir)
      
        for car in carpetas:
      
            dirCarpeta = os.path.join(mainDir, car)
            
            archivos = os.listdir(dirCarpeta)
            
            for arc in archivos:
        
                if arc[-3:] != 'hdr':
                    continue
          
                dirArhivo = os.path.join(dirCarpeta, arc)
          
                objArroz = Semilla( dirArhivo )
          
                coleccion.append( objArroz )
      
        return Semilla.ordenarColeccion(coleccion)   
        #return coleccion
        '''
#----------------------------------------------------------------

    '''
    def mostrarEspectroct(coleccion, c, t):
        for arrozk in coleccion[c][t]:
            arrozk.mostrarEspectro(promedio = "Si", formato = 0)
      
        plt.grid()
      
        plt.xlabel("Bands")
        plt.ylabel("Intensity")
      
        plt.title( arrozk.titulo )    
    


#----------------------------------------------------------------------

  
    def mostrarEspectroColeccion(coleccion, indc = [0, 1, 2, 3, 4], indt = [0, 1, 2, 3, 4, 5], ancho = 5 , altura = 5.5, ymax = 1.65):
    
        M = len(indc)
        N = len(indt)
        i=0
      
        plt.figure(figsize = (ancho*N, altura*M))
      
        for i, c in enumerate(indc):
      
            for j, t in enumerate(indt):
            
                plt.subplot(M, N, N*i+j+1)
          
                Semilla.mostrarEspectroct(coleccion, c, t)
          
                plt.ylim(0, ymax) 
          
                if c == 0:
                    break


#--------------------------------------------------------------------


    def extraerEspectro(coleccion, c, t, promedio = "Si"):
      
        listaArroz = coleccion[c][t]
      
        lmax = len(listaArroz)
        b = listaArroz[0].b
      
        tab_spec = np.zeros((lmax, b))
      
        for i in range(lmax):
          
            tab_spec[i,:] = listaArroz[i].espectroMedio
      
        
        if promedio == "Si":
          
            return np.mean(tab_spec, axis=0)
      
        else:
      
            return tab_spec

      
#----------------------------------------------------------------------

    def mostrarEspectroClaseOtiempo(coleccion, c = -1, t=-1):
      
        if c != -1:
      
            for listaTiempo in coleccion[c]:
              
                arroz = listaTiempo[0]
          
                esp = Semilla.extraerEspectro(coleccion, arroz.c, arroz.t )
          
                plt.plot(esp, label = arroz.tiempo)
        
        
            plt.legend()
        
            plt.title(arroz.clase)
        
            plt.grid()
            plt.xlabel("Bands")
            plt.ylabel("Intensity")
      
        
        elif t != -1:
      
            cmax = len(coleccion)
        
            for cc in range(cmax):
        
                if cc == 0:
          
                    esp = Semilla.extraerEspectro(coleccion, cc, 0 )
          
                else:
          
                    esp = Semilla.extraerEspectro(coleccion, cc, t )
          
          
                plt.plot(esp, label = coleccion[cc][0][0].clase)
        
            plt.legend()
        
            plt.title('Time = {} h'.format(coleccion[1][t][0].tiempo) )
        
            plt.grid()
            plt.xlabel("Bands")
            plt.ylabel("Intensity")
        

    '''
    


#------------------------------------------------------------------





  #########   Metodos para calcular propiedades geometricas ###########

    def perimetro(self):
        return perimeter(self.mascara)
    
    
    def area(self):
    
        return np.sum(self.mascara)
    
    
    def compactness(self):
    
        return self.perimetro()**2 / self.area
    
    
    def roundness(self):
    
        return 4*pi*self.area / self.perimetro()**2