# cpa.py — Correlation Power Analysis (CPA) attack implementation
import numpy as np
import os
from multiprocessing.pool import Pool
from .util import resultsAnalyzer as resultsAnalyzer
from tqdm import tqdm


class CPA:
    """Correlation Power Analysis (CPA) attack on AES-128 traces.

    Streams batches of (plaintext, sample) data from a trace handler,
    models the power draw of the first-round SBox for all 256 key
    hypotheses per attacked byte, and accumulates running sums so the
    Pearson correlation can be computed incrementally (one pass over
    the data). Work is fanned out across processes per byte (and per
    tile for tiled datasets).

    Constructor parameters:
        Tracehandler: trace-handler object that iterates the data file in
            batches; must expose ``sample_length``, ``hasTiles``,
            ``file_name``, ``grab()`` and, for tiled data,
            ``tiles_coordinates`` and ``configure_tile()``.
        Bytes: list of plaintext byte indices to attack.
        Bytes_To_Graph: optional list of (byte_idx, graph_color) tuples
            handed to the results analyzer for graphing.
    """

    def __init__(self, Tracehandler, Bytes, Bytes_To_Graph=None) -> None:
        # Tracehandler object that iterates file and grabs data
        self.data = Tracehandler

        # Length of observed samples used for array shapes
        self.trace_length = self.data.sample_length

        # Bytes to be computed
        self.bytes = Bytes
        # Whether the dataset is tiled (per-tile coordinates to compute)
        self.tiles = self.data.hasTiles

        # AES-128 sbox used to compute model values
        self.sbox = np.array([99,124,119,123,242,107,111,197,48,1,103,43,254,215,171,118,
                              202,130,201,125,250,89,71,240,173,212,162,175,156,164,114,192,
                              183,253,147,38,54,63,247,204,52,165,229,241,113,216,49,21,
                              4,199,35,195,24,150,5,154,7,18,128,226,235,39,178,117,
                              9,131,44,26,27,110,90,160,82,59,214,179,41,227,47,132,
                              83,209,0,237,32,252,177,91,106,203,190,57,74,76,88,207,
                              208,239,170,251,67,77,51,133,69,249,2,127,80,60,159,168,
                              81,163,64,143,146,157,56,245,188,182,218,33,16,255,243,210,
                              205,12,19,236,95,151,68,23,196,167,126,61,100,93,25,115,
                              96,129,79,220,34,42,144,136,70,238,184,20,222,94,11,219,
                              224,50,58,10,73,6,36,92,194,211,172,98,145,149,228,121,
                              231,200,55,109,141,213,78,169,108,86,244,234,101,122,174,8,
                              186,120,37,46,28,166,180,198,232,221,116,31,75,189,139,138,
                              112,62,181,102,72,3,246,14,97,53,87,185,134,193,29,158,
                              225,248,152,17,105,217,142,148,155,30,135,233,206,85,40,223,
                              140,161,137,13,191,230,66,104,65,153,45,15,176,84,187,22])

        # The key hypotheses, which are the byte values 0-255
        self.keys = np.arange(256)

        # Hamming weights of the values 0-255 used for model values
        self.weights = np.array([0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4,
                                 1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,
                                 1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,
                                 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                                 1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,
                                 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                                 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                                 3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
                                 1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,
                                 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                                 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                                 3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
                                 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                                 3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
                                 3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
                                 4,5,5,6,5,6,6,7,5,6,6,7,6,7,7,8], np.float32)
        # Number of trace rows accumulated so far
        self.count = 0
        # Running sum of the observed samples
        self.sample_sum = np.zeros((self.trace_length), dtype=np.float32)
        # Running sum of the observed samples, squared
        # NOTE(review): float32 accumulators may lose precision on very
        # large trace counts — confirm acceptable for expected dataset sizes.
        self.sample_sq_sum = np.zeros((self.trace_length), dtype=np.float32)
        # Running sum of the model values per key hypothesis
        self.model_sum = np.zeros((256), dtype=np.float32)
        # Running sum of the model values squared per key hypothesis
        self.model_sq_sum = np.zeros((256), dtype=np.float32)
        # Running sum of the model x sample products
        self.prod_sum = np.zeros((256, self.trace_length), dtype=np.float32)
        # Analyzer for saving results
        self.resultsAnalyzer = resultsAnalyzer.ResultsAnalyzer(self.data.file_name, 'cpa')
        # Bytes for analyzer to graph.
        # Default handled via None sentinel: a mutable [] default would be
        # shared across every CPA instance (classic Python pitfall).
        self.bytes_to_graph = [] if Bytes_To_Graph is None else Bytes_To_Graph

        # Setting up the shape of results based on whether data is tiled or not
        if self.tiles:
            self.results = np.zeros((len(self.data.tiles_coordinates), len(self.bytes), 256, self.data.sample_length), dtype=np.float32)
        else:
            self.results = np.zeros((len(self.bytes), 256, self.trace_length), dtype=np.float32)

        print("Computing CPA: Tiles: ", self.tiles, " Bytes: ", self.bytes)

    # Static (explicit-self) so worker processes can invoke it on the pickled
    # instance copy they received through the Pool workload tuples.
    @staticmethod
    def finalize(self):
        """Turn the accumulated sums into Pearson correlation coefficients.

        Returns a (256, trace_length) array: correlation of each key
        hypothesis' model with each sample point. Must only be called
        after at least one ``update()`` (``count`` == 0 divides by zero).
        """
        # Mean of the observed samples per sample point
        sample_mean = np.divide(self.sample_sum, self.count)
        # Mean of the model per key hypothesis
        model_mean = np.divide(self.model_sum, self.count)
        # Covariance numerator: E[XY] - E[X]E[Y]
        numerator = np.subtract(np.divide(self.prod_sum, self.count), model_mean[:, None] * sample_mean)
        # Standard deviation of the samples (denominator, sample part)
        denom_sample = np.sqrt(np.subtract(np.divide(self.sample_sq_sum, self.count), np.square(sample_mean)))
        # Standard deviation of the model (denominator, model part)
        denom_model = np.sqrt(np.subtract(np.divide(self.model_sq_sum, self.count), np.square(model_mean)))

        # Finalize and pass result back to run
        return np.divide(numerator, denom_model[:, None] * denom_sample)

    def get_model(self, plaintexts: np.ndarray) -> np.ndarray:
        """Build the Hamming-distance power model for one plaintext byte.

        Args:
            plaintexts: 1-D array of the attacked plaintext byte, one entry
                per trace row.

        Returns:
            (len(plaintexts), 256) float32 array: modeled power draw
            (Hamming weight of SBox input XOR output) for every key
            hypothesis.
        """
        # SBox inputs for each hypothesized key value: p XOR k
        inputs = np.bitwise_xor(plaintexts[:, np.newaxis], self.keys)
        # SBox outputs for each input
        outputs = self.sbox[inputs]
        # Modeled power draw: Hamming weight of the input->output transition
        return self.weights[np.bitwise_xor(inputs, outputs)]

    def update(self, traces: np.ndarray, model: np.ndarray) -> None:
        """Fold one batch into the running accumulators.

        Args:
            traces: (batch, trace_length) float32 sample batch.
            model: (batch, 256) float32 model batch from ``get_model``.
        """
        # Update the number of rows processed
        self.count += traces.shape[0]
        # Update sample accumulator
        self.sample_sum += np.sum(traces, axis=0)
        # Update sample squared accumulator
        self.sample_sq_sum += np.sum(np.square(traces), axis=0)
        # Update model accumulator
        self.model_sum += np.sum(model, axis=0)
        # Update model squared accumulator
        self.model_sq_sum += np.sum(np.square(model), axis=0)
        # Update product accumulator: (256, batch) @ (batch, trace_length)
        self.prod_sum += np.matmul(model.T, traces)

    # Worker for tiled datasets: one (tile, byte) combination per process.
    # Static with explicit self so the Pool can pickle it cleanly alongside
    # the instance in the workload tuple.
    @staticmethod
    def run_xy(self, byte, tile_x, tile_y):
        """Run the full CPA for one byte of one tile; returns
        ``(tile_x, tile_y, byte, correlation_array)``."""
        # Make sure that trace handler is configured to given tile
        self.data.configure_tile(tile_x, tile_y)
        # Grab initial batch
        batch = self.data.grab()

        # While batches have length greater than 0
        while len(batch) > 0:
            # Plaintext batch is always first
            plaintexts = batch[0]
            # Samples are always the second batch
            samples = batch[1].astype(np.float32)

            # Calculate model from given plaintexts
            model = self.get_model(plaintexts[:, byte])
            # Update accumulators
            self.update(samples, model)
            # Grab new batch
            batch = self.data.grab()

        # Pass back to main run function
        return (tile_x, tile_y, byte, self.finalize(self))

    # Worker for non-tiled datasets: one byte per process.
    @staticmethod
    def run_1x1(self, byte):
        """Run the full CPA for one byte; returns the correlation array."""
        # Grab initial batch
        batch = self.data.grab()

        # While batches have a length greater than 0
        while len(batch) > 0:
            # Plaintext batch is always first
            plaintexts = batch[0]
            # Samples are always the second batch
            samples = batch[1].astype(np.float32)

            # Calculate the model from the given plaintexts
            model = self.get_model(plaintexts[:, byte])
            # Update accumulators
            self.update(samples, model)
            # Grab new batch
            batch = self.data.grab()

        # Pass back to main run function
        return self.finalize(self)

    def run(self) -> None:
        """Split the workload across processes and fill ``self.results``.

        Dispatches to per-(tile, byte) workers for tiled datasets, or
        per-byte workers otherwise. Each worker gets its own pickled copy
        of this instance, so accumulators never race.
        """
        # Bytes to process
        self.bytes = np.array(self.bytes, dtype=int)

        # If calculating on a tiled dataset cut workload into bytes X each tile coordinate
        if self.tiles:
            with Pool() as pool:
                workload = []
                # Create workload: one task per (tile, byte) pair
                for tile in self.data.tiles_coordinates:
                    (tile_x, tile_y) = tile
                    for byte_position in self.bytes:
                        workload.append((self, byte_position, tile_x, tile_y))

                # Run workload (starmap blocks, so the bar fills at the end)
                with tqdm(total=len(workload)) as pbar:
                    # For each result from the workload
                    for tile_x, tile_y, byte_position, byte_result in pool.starmap(self.run_xy, workload):
                        # Update progress bar
                        pbar.update()
                        # Map the returned byte back to its index in self.bytes
                        # (scalar extraction — np.where returns an array)
                        byte_position = int(byte_position)
                        byte_index = int(np.where(self.bytes == byte_position)[0][0])
                        tile_index = self.data.tiles_coordinates.index((tile_x, tile_y))
                        # Establish results in results array
                        self.results[tile_index, byte_index, :] = byte_result

        else:
            with Pool() as pool:
                workload = []
                # Create workload: one task per byte
                for x in self.bytes:
                    workload.append((self, x))

                # Index for keeping track of location for byte results
                # (starmap preserves workload order, so this is safe)
                index = 0

                with tqdm(total=len(workload)) as pbar:
                    for byte_result in pool.starmap(self.run_1x1, workload):
                        # Update progress bar
                        pbar.update()
                        # Establish results in result array
                        self.results[index, :] = byte_result
                        # Increment index
                        index += 1

    def publish(self) -> None:
        """Publish graphs and CSVs, dispatching on tiled vs non-tiled data."""
        if (self.tiles):
            self.publish_xy()
        else:
            self.publish_1x1()

    # Should not be called from outside this class - use publish()
    def publish_1x1(self) -> None:
        """Graph byte 0's correlations and dump every byte's results to CSV."""
        # Locate the (key, sample) cell with the strongest correlation
        key_index = np.unravel_index(np.abs(self.results[0, :, :]).argmax(), self.results.shape[1:])
        graph_results = [(self.results[0, :, :].T, 'gray', 'All Keys'),
                         (self.results[0, key_index[0], :], 'red', 'Highest Key')]
        self.resultsAnalyzer.graph_1x1(graph_results, 'Samples', 'CPA')

        # Save cpa data to csv file (loop var renamed so it no longer
        # shadows key_index above; matches publish_xy's naming)
        for byte_index in range(len(self.bytes)):
            data = []
            fields = ['Trace']
            for key_hypothesis in range(256):
                data.append(self.results[byte_index, key_hypothesis, :])
                fields.append(f'Key Hyp #{key_hypothesis}')
            data = np.asarray(data)
            data_file_name = os.path.splitext(os.path.basename(self.data.file_name))[0]  # Full zarr file name minus zarr extension
            csv_file_name = f'{data_file_name}_cpa_{byte_index}.csv'
            self.resultsAnalyzer.save_results(data, csv_file_name, fields)

    # Should not be called from outside this class - use publish()
    def publish_xy(self) -> None:
        """Render a heat map per byte and dump per-tile results to CSV."""

        def count_distinct_coordinates(coordinates):
            # Grid shape = (# distinct x values, # distinct y values)
            x_coords = set()
            y_coords = set()

            for coord in coordinates:
                x_coords.add(coord[0])
                y_coords.add(coord[1])

            return len(x_coords), len(y_coords)

        shape = count_distinct_coordinates(self.data.tiles_coordinates)

        # Create heat map for each byte pos
        for byte_index in range(len(self.bytes)):

            # Grab data from tiles at this byte pos
            tile_results = []
            for tile in self.data.tiles_coordinates:
                tile_index = self.data.tiles_coordinates.index(tile)
                tile_results.append((tile, self.results[tile_index, byte_index, :]))

            # Print to heat map
            self.resultsAnalyzer.heat_map(tile_results, shape, 'X Axis', 'Y Axis', self.bytes[byte_index], self.resultsAnalyzer.cpa_heat_map_max)

        # Create csv file for each tile
        for tile in self.data.tiles_coordinates:
            x, y = tile
            tile_index = self.data.tiles_coordinates.index(tile)
            for byte_index in range(len(self.bytes)):
                data = []
                fields = ['Trace']
                for key_hypothesis in range(256):
                    data.append(self.results[tile_index, byte_index, key_hypothesis, :])
                    fields.append(f'Key Hyp #{key_hypothesis}')
                data = np.asarray(data)
                data_file_name = os.path.splitext(os.path.basename(self.data.file_name))[0]  # Full zarr file name minus zarr extension
                csv_file_name = f'{data_file_name}_cpa_{x}_{y}_{byte_index}.csv'
                self.resultsAnalyzer.save_results(data, csv_file_name, fields)