"""snr.py — signal-to-noise ratio (SNR) computation over side-channel trace datasets."""
import numpy as np
import os
from multiprocessing.pool import Pool
from numba import njit
import matplotlib.pyplot as plt
from .util import resultsAnalyzer as resultsAnalyzer

from tqdm import tqdm


def run_xy_unpack(args):
    """Unpack an argument tuple and forward it to SNR.run_xy.

    Lives at module level so multiprocessing.Pool can pickle the callable.
    """
    worker = SNR.run_xy
    return worker(*args)


class SNR:
    """Signal-to-noise-ratio computation over side-channel power traces.

    For each requested plaintext byte position, traces are bucketed by the
    256 possible byte values. Per sample point, SNR is defined as
    var(per-value means) / mean(per-value variances). Statistics are
    accumulated batch-by-batch with Welford's online algorithm so
    arbitrarily large datasets can be streamed. Tiled datasets yield one
    SNR curve per (tile, byte) pair; work is fanned out over a process pool.
    """

    # Constructor Parameters:
    #   Tracehandler: Tracehandler object that iterates the dataset in batches
    #   Bytes: List of desired byte positions to attack
    #   bytes_to_graph: List of tuples of bytes to graph in form of (byte_idx, graph_color)
    def __init__(self, Tracehandler, Bytes, bytes_to_graph=None) -> None:

        # Trace handler object that iterates file and grabs data
        self.data = Tracehandler
        # Byte positions to compute
        self.bytes = Bytes
        # Bytes to show results of. BUGFIX: default was a shared mutable
        # list ([]); use a None sentinel so instances never alias one object.
        self.bytes_to_graph = [] if bytes_to_graph is None else bytes_to_graph
        # Whether or not we are computing a tiled dataset
        self.tiles = self.data.hasTiles
        # Number of traces accumulated per hex value (0..255)
        self.counts = np.zeros((256), dtype=np.float32)
        # Running mean for each hex value at each sample point
        self.means = np.zeros((256, self.data.sample_length), dtype=np.float32)
        # Running second moment (Welford M2: sum of squared deviations) per hex value/sample
        self.moments = np.zeros((256, self.data.sample_length), dtype=np.float32)
        # Publisher object for saving results
        self.resultsAnalyzer = resultsAnalyzer.ResultsAnalyzer(self.data.file_name, 'snr')

        # Shape of the results array depends on whether the data is tiled
        if self.tiles:
            self.results = np.zeros((len(self.data.tiles_coordinates), len(self.bytes), self.data.sample_length))

        else:
            self.results = np.zeros((len(self.bytes), self.data.sample_length), dtype=np.float32)

        print("Computing SNR: Tiles: ", self.tiles, " Bytes: ", self.bytes)

    # Combine per-hex-value means and variances into a single SNR curve
    @staticmethod
    def finalize(means, vars):
        """Return SNR per sample point: var of class means / mean of class vars.

        `where=means != 0` excludes never-observed (all-zero) classes from
        the signal estimate. NOTE(review): a class whose true mean happens
        to be exactly 0.0 is also excluded — presumably negligible for real
        trace data; confirm against the acquisition pipeline.
        """
        # Signal: variance across the 256 per-value mean traces
        signals = np.var(means, 0, where=means != 0)
        # Noise: average of the 256 per-value variance traces
        noises = np.apply_along_axis(np.mean, 0, vars)

        # pass back SNR to main run function
        return signals / noises

    # Worker for non-tiled datasets: SNR for a single byte position.
    # Declared @staticmethod but takes the instance explicitly as `self`
    # because Pool.starmap ships (instance, byte) tuples to workers.
    @staticmethod
    def run_1x1(self, byte):
        """Stream all batches, fold them into the Welford accumulators for
        the given byte position, and return the finalized SNR curve."""

        # Grab initial batch
        batch = self.data.grab()

        # Keep consuming until the handler returns an empty batch
        while len(batch) > 0:
            # Plaintext batch is always first
            ptxt_batch = batch[0]
            # Select the byte position under attack
            ptxt_batch = ptxt_batch[:, byte]
            # Update counts, means and moments using Welford's online algorithm
            self.counts, self.means, self.moments = self.update(batch[1], ptxt_batch, self.counts, self.means, self.moments)
            # grab next batch
            batch = self.data.grab()

        # Convert M2 moments to variances. BUGFIX: np.divide with `where=`
        # but no `out=` leaves masked entries uninitialized (garbage); a
        # zeroed `out` makes unseen hex values contribute 0 variance.
        variances = np.divide(self.moments, self.counts[:, np.newaxis],
                              out=np.zeros_like(self.moments),
                              where=self.counts[:, np.newaxis] != 0)
        # Pass results back to main run function
        return self.finalize(self.means, variances)

    # Worker for tiled datasets: SNR for a single (tile, byte) combination.
    # Same @staticmethod-with-explicit-self pattern as run_1x1, invoked via
    # the module-level run_xy_unpack so Pool can pickle the call.
    @staticmethod
    def run_xy(self, byte, tile_x, tile_y):
        """Compute the SNR curve for one byte position on one tile; returns
        (tile_x, tile_y, byte, snr_curve) so unordered results can be
        routed back to their slot."""
        # Make sure the trace handler is configured to the given tile
        self.data.configure_tile(tile_x, tile_y)

        # Grab initial batch
        batch = self.data.grab()

        # Keep consuming until the handler returns an empty batch
        while len(batch) > 0:

            # Plaintext batch is always first
            ptxt_batch = batch[0]
            # Select the byte position under attack
            ptxt_batch = ptxt_batch[:, byte]
            # Update counts, means and moments using Welford's online algorithm
            self.counts, self.means, self.moments = self.update(batch[1], ptxt_batch, self.counts, self.means, self.moments)
            # Grab next batch
            batch = self.data.grab()

        # Convert M2 moments to variances (same `out=` fix as run_1x1)
        variances = np.divide(self.moments, self.counts[:, np.newaxis],
                              out=np.zeros_like(self.moments),
                              where=self.counts[:, np.newaxis] != 0)
        # Pass results back to main run function
        return tile_x, tile_y, byte, self.finalize(self.means, variances)

    # Split the workload across a multiprocessing Pool, dispatching per
    # (tile, byte) for tiled data or per byte otherwise, filling self.results
    def run(self):
        # Normalize byte positions so the np.where comparison below works
        self.bytes = np.array(self.bytes, dtype=int)

        # If calculating on a tiled dataset cut workload into bytes X each tile coordinate
        if self.tiles:
            with Pool() as pool:
                workload = []
                # One task per (tile, byte); each task carries a pickled copy
                # of `self`, so worker accumulators start fresh and unshared
                for tile in self.data.tiles_coordinates:
                    (tile_x, tile_y) = tile
                    for byte_position in self.bytes:
                        workload.append((self, byte_position, tile_x, tile_y))
                print('Processing...')
                # Run workload; imap_unordered yields results as they finish
                with tqdm(total=len(workload)) as pbar:
                    # For each result from the workload
                    for tile_x, tile_y, byte_position, byte_result in pool.imap_unordered(run_xy_unpack, workload):
                        # Update progress bar
                        pbar.update()
                        # Results arrive out of order: map (tile, byte) back
                        # to its slot in the results array
                        byte_position = int(byte_position)
                        byte_indices = np.where(self.bytes == byte_position)
                        byte_index = byte_indices[0]
                        tile_index = self.data.tiles_coordinates.index((tile_x, tile_y))
                        # Establish results in results array
                        self.results[tile_index, byte_index, :] = byte_result
        else:
            with Pool() as pool:
                # Create workload
                workload = [(self, x) for x in self.bytes]
                print('Processing...')
                # starmap preserves input order, so a running index suffices
                index = 0
                # Run workload
                with tqdm(total=len(workload)) as pbar:
                    # For each result from the workload
                    for byte_result in pool.starmap(self.run_1x1, workload):
                        # Update progress bar
                        pbar.update()
                        # Establish results in results array
                        self.results[index, :] = byte_result
                        # Increment index
                        index += 1

    @staticmethod
    @njit
    def update(traces: np.ndarray, plaintext: np.ndarray, counts: np.ndarray, means: np.ndarray, moments: np.ndarray):
        """Welford's online update: fold a batch of traces into the running
        count/mean/M2 accumulators, bucketed by plaintext byte value.
        Mutates and returns (counts, means, moments)."""
        for i in range(traces.shape[0]):

            counts[plaintext[i]] += 1
            delta1 = traces[i] - means[plaintext[i]]
            means[plaintext[i]] += delta1 / counts[plaintext[i]]
            # M2 update multiplies the deviation before and after the mean update
            moments[plaintext[i]] += delta1*(traces[i] - means[plaintext[i]])

        return counts, means, moments

    def publish(self):
        """Dispatch to the tiled or non-tiled publishing routine."""
        if (self.tiles):
            self.publish_xy()
        else:
            self.publish_1x1()

    # Should not be called from outside this class - use publish()
    def publish_1x1(self):
        """Graph the requested bytes and save all SNR curves to one CSV."""

        # Create a graph for desired bytes
        graph_results = []
        for graph_byte in self.bytes_to_graph:
            idx, color = graph_byte
            graph_results.append((self.results[idx, :], color, f'Byte #{str(idx)}'))
        if len(graph_results) > 0:
            self.resultsAnalyzer.graph_1x1(graph_results, 'Samples', 'SNR')

        # Save snr data to csv file: one column per byte, preceded by 'Trace'
        data = []
        fields = ['Trace']
        for i in range(len(self.bytes)):
            data.append(self.results[i, :])
            fields.append(f'Byte #{self.bytes[i]}')
        data = np.asarray(data)
        data_file_name = os.path.splitext(os.path.basename(self.data.file_name))[0]  # Full zarr file name minus zarr extension
        csv_file_name = f'{data_file_name}_snr'
        for byte in self.bytes:
            csv_file_name += f'_{str(byte)}'
        csv_file_name += '.csv'
        self.resultsAnalyzer.save_results(data, csv_file_name, fields)

    # Should not be called from outside this class - use publish()
    def publish_xy(self):
        """Render one heat map per byte position and save one CSV per tile."""

        def count_distinct_coordinates(coordinates):
            # Heat-map grid shape: number of distinct x and y values
            x_coords = set()
            y_coords = set()

            for coord in coordinates:
                x_coords.add(coord[0])
                y_coords.add(coord[1])

            return len(x_coords), len(y_coords)

        shape = count_distinct_coordinates(self.data.tiles_coordinates)

        # Create heat map for each byte pos
        for byte_index in range(len(self.bytes)):

            # Pair each tile coordinate with its SNR curve at this byte.
            # enumerate avoids the O(n^2) .index() lookup of the original.
            tile_results = []
            for tile_index, tile in enumerate(self.data.tiles_coordinates):
                tile_results.append((tile, self.results[tile_index, byte_index, :]))

            # Print to heat map
            self.resultsAnalyzer.heat_map(tile_results, shape, 'X Axis', 'Y Axis', self.bytes[byte_index], self.resultsAnalyzer.snr_heat_map_max)

        # Create csv file for each tile, one column per byte
        for tile_index, tile in enumerate(self.data.tiles_coordinates):
            x, y = tile
            data = []
            fields = ['Trace']
            for byte_index in range(len(self.bytes)):
                data.append(self.results[tile_index, byte_index, :])
                fields.append(f'Byte #{self.bytes[byte_index]}')
            data = np.asarray(data)
            data_file_name = os.path.splitext(os.path.basename(self.data.file_name))[0]  # Full zarr file name minus zarr extension
            csv_file_name = f'{data_file_name}_snr_{x}_{y}.csv'
            self.resultsAnalyzer.save_results(data, csv_file_name, fields)