import numpy as np
import os
from multiprocessing.pool import Pool
from numba import njit
import matplotlib.pyplot as plt
from .util import resultsAnalyzer as resultsAnalyzer
from tqdm import tqdm
def run_xy_unpack(args):
    """Forward one packed tiled work item to ``SNR.run_xy``.

    ``Pool.imap_unordered`` hands each task to its callable as a single
    argument, so the ``(snr, byte, tile_x, tile_y)`` tuple built in
    ``SNR.run`` is spread positionally into ``run_xy`` here. Presumably kept
    at module level so worker processes can unpickle it by name.
    """
    tile_job = SNR.run_xy
    return tile_job(*args)
class SNR:
    """Signal-to-noise ratio of traces, partitioned by plaintext byte value.

    For every requested plaintext byte position, traces are grouped by the
    value (0-255) that byte takes. Per-value means and variances are
    accumulated with Welford's online algorithm. The SNR at each sample point
    is the variance of the per-value means divided by the mean of the
    per-value variances. Both terms use only the values that actually occurred.

    Constructor parameters:
        Tracehandler: trace handler object. This class reads its
            ``sample_length``, ``file_name``, ``hasTiles`` and ``grab()``.
            Tiled handlers are also used through ``tiles_coordinates`` and
            ``configure_tile(x, y)``.
        Bytes: list of plaintext byte positions to attack.
        bytes_to_graph: list of ``(byte_idx, graph_color)`` tuples selecting
            which untiled result rows ``publish()`` plots.
    """

    def __init__(self, Tracehandler, Bytes, bytes_to_graph=None) -> None:
        # Trace handler object that iterates the file and hands out batches.
        self.data = Tracehandler
        # Plaintext byte positions to compute.
        self.bytes = Bytes
        # (byte_idx, color) pairs to plot. Uses a fresh list per instance
        # instead of a mutable default shared between instances.
        self.bytes_to_graph = [] if bytes_to_graph is None else bytes_to_graph
        # Whether the dataset is split into tiles.
        self.tiles = self.data.hasTiles
        # Welford accumulators, one row per plaintext byte value (0-255).
        # NOTE: float32 counts can no longer be incremented by 1 past 2**24
        # traces for a single value.
        self.counts = np.zeros((256), dtype=np.float32)
        self.means = np.zeros((256, self.data.sample_length), dtype=np.float32)
        self.moments = np.zeros((256, self.data.sample_length), dtype=np.float32)
        # Publisher object for graphs and CSV output.
        self.resultsAnalyzer = resultsAnalyzer.ResultsAnalyzer(self.data.file_name, 'snr')
        if self.tiles:
            # Indexed as [tile, byte, sample].
            self.results = np.zeros((len(self.data.tiles_coordinates), len(self.bytes), self.data.sample_length))
        else:
            # Indexed as [byte, sample].
            self.results = np.zeros((len(self.bytes), self.data.sample_length), dtype=np.float32)
        print("Computing SNR: Tiles: ", self.tiles, " Bytes: ", self.bytes)

    @staticmethod
    def finalize(means, vars, counts=None):
        """Combine per-value means and variances into an SNR trace.

        Args:
            means: per-value mean traces, laid out like ``self.means``
                (value, sample).
            vars: per-value variance traces with the same layout.
            counts: optional number of traces seen per value, laid out like
                ``self.counts``. When given, only values with a non-zero count
                enter the signal and the noise term. When omitted, the legacy
                rule applies: the signal skips entries whose mean is exactly
                zero, and the noise averages every row.

        Returns:
            Per-sample array of signal / noise.
        """
        if counts is None:
            signals = np.var(means, axis=0, where=means != 0)
            noises = np.mean(vars, axis=0)
        else:
            seen = np.asarray(counts) != 0
            # Broadcast the per-value mask across every sample column.
            observed = np.broadcast_to(seen[:, np.newaxis], np.shape(means))
            signals = np.var(means, axis=0, where=observed)
            noises = np.mean(vars, axis=0, where=observed)
        return signals / noises

    def _class_variances(self):
        """Return the population variance per value; unseen values get 0."""
        divisor = self.counts[:, np.newaxis]
        # Without ``out=``, np.divide leaves the masked-out entries
        # uninitialized, so start from an explicit zero buffer.
        return np.divide(self.moments, divisor, out=np.zeros_like(self.moments), where=divisor != 0)

    def _accumulate(self, byte):
        """Reset the Welford accumulators and feed every batch for one byte position."""
        # Start from zero so repeated calls on one instance do not mix bytes.
        self.counts = np.zeros_like(self.counts)
        self.means = np.zeros_like(self.means)
        self.moments = np.zeros_like(self.moments)
        batch = self.data.grab()
        # Keep reading until grab() returns an empty batch.
        while len(batch) > 0:
            # batch[0] holds plaintexts; batch[1] holds the matching traces.
            ptxt_batch = batch[0][:, byte]
            self.counts, self.means, self.moments = self.update(batch[1], ptxt_batch, self.counts, self.means, self.moments)
            batch = self.data.grab()

    @staticmethod
    def run_1x1(self, byte):
        """Compute the SNR of one byte position on an untiled dataset.

        Declared static but takes the instance explicitly as ``self``.
        ``run`` ships the instance inside each pool work item.
        """
        self._accumulate(byte)
        return self.finalize(self.means, self._class_variances(), self.counts)

    @staticmethod
    def _run_1x1_unpack(args):
        """Single-argument adapter so ``Pool.imap`` can call ``run_1x1``."""
        return SNR.run_1x1(*args)

    @staticmethod
    def run_xy(self, byte, tile_x, tile_y):
        """Compute the SNR of one byte position on one tile.

        Declared static but takes the instance explicitly as ``self``.
        Returns ``(tile_x, tile_y, byte, snr)`` so unordered pool results can
        be routed back to the right slot in ``self.results``.
        """
        # Point the trace handler at the requested tile before reading.
        self.data.configure_tile(tile_x, tile_y)
        self._accumulate(byte)
        return tile_x, tile_y, byte, self.finalize(self.means, self._class_variances(), self.counts)

    def _tile_index_map(self):
        """Map each tile coordinate to its row in ``self.results``.

        The first occurrence wins, matching ``list.index`` if a coordinate
        repeats.
        """
        lookup = {}
        for index, tile in enumerate(self.data.tiles_coordinates):
            lookup.setdefault(tuple(tile), index)
        return lookup

    def run(self):
        """Compute the SNR for every requested byte (per tile when tiled) into ``self.results``.

        Each byte (or byte x tile) combination is one task in a
        ``multiprocessing`` pool. The instance itself travels to the workers
        inside each task.
        """
        self.bytes = np.array(self.bytes, dtype=int)
        if self.tiles:
            self._run_tiled()
        else:
            self._run_untiled()

    def _run_tiled(self):
        """Fan out one task per (tile, byte) and scatter results by coordinate."""
        tile_lookup = self._tile_index_map()
        with Pool() as pool:
            workload = [
                (self, byte_position, tile_x, tile_y)
                for tile_x, tile_y in self.data.tiles_coordinates
                for byte_position in self.bytes
            ]
            print('Processing...')
            with tqdm(total=len(workload)) as pbar:
                for tile_x, tile_y, byte_position, byte_result in pool.imap_unordered(run_xy_unpack, workload):
                    pbar.update()
                    # Results arrive unordered: locate the byte and tile slots.
                    byte_index = np.where(self.bytes == int(byte_position))[0]
                    tile_index = tile_lookup[(tile_x, tile_y)]
                    self.results[tile_index, byte_index, :] = byte_result

    def _run_untiled(self):
        """Fan out one task per byte and store results in ``self.bytes`` order."""
        with Pool() as pool:
            workload = [(self, byte_position) for byte_position in self.bytes]
            print('Processing...')
            with tqdm(total=len(workload)) as pbar:
                # imap (unlike starmap) yields each ordered result as soon as it
                # is ready, so the progress bar actually advances.
                for index, byte_result in enumerate(pool.imap(SNR._run_1x1_unpack, workload)):
                    pbar.update()
                    self.results[index, :] = byte_result

    @staticmethod
    @njit
    def update(traces: np.ndarray, plaintext: np.ndarray, counts: np.ndarray, means: np.ndarray, moments: np.ndarray):
        """Apply one Welford step per trace, keyed by that trace's plaintext value.

        ``counts``, ``means`` and ``moments`` are updated in place and also
        returned. ``moments`` accumulates summed squared deviations, so
        ``moments / counts`` is the population variance.
        """
        for i in range(traces.shape[0]):
            value = plaintext[i]
            counts[value] += 1
            delta = traces[i] - means[value]
            means[value] += delta / counts[value]
            moments[value] += delta * (traces[i] - means[value])
        return counts, means, moments

    def publish(self):
        """Graph and save results using the tiled or untiled layout."""
        if self.tiles:
            self.publish_xy()
        else:
            self.publish_1x1()

    def _data_file_stem(self):
        """Return the trace file's base name without directory or extension."""
        return os.path.splitext(os.path.basename(self.data.file_name))[0]

    def publish_1x1(self):
        """Graph the selected rows and write every untiled SNR row to one CSV.

        Should not be called from outside this class - use publish().
        """
        # NOTE(review): ``idx`` indexes rows of self.results (ordered like
        # self.bytes) but is labelled as a byte position. The two agree only
        # when self.bytes is 0..n-1 in order. Confirm which callers intend.
        graph_results = [
            (self.results[idx, :], color, f'Byte #{str(idx)}')
            for idx, color in self.bytes_to_graph
        ]
        if graph_results:
            self.resultsAnalyzer.graph_1x1(graph_results, 'Samples', 'SNR')
        data = np.asarray([self.results[row, :] for row in range(len(self.bytes))])
        fields = ['Trace'] + [f'Byte #{byte}' for byte in self.bytes]
        suffix = ''.join(f'_{str(byte)}' for byte in self.bytes)
        csv_file_name = f'{self._data_file_stem()}_snr{suffix}.csv'
        self.resultsAnalyzer.save_results(data, csv_file_name, fields)

    def publish_xy(self):
        """Draw one heat map per byte and write one CSV per tile.

        Should not be called from outside this class - use publish().
        """
        coordinates = self.data.tiles_coordinates
        # Grid shape: number of distinct x values by number of distinct y values.
        shape = (len({tile[0] for tile in coordinates}), len({tile[1] for tile in coordinates}))
        tile_lookup = self._tile_index_map()
        for byte_index, byte_value in enumerate(self.bytes):
            tile_results = [
                (tile, self.results[tile_lookup[tuple(tile)], byte_index, :])
                for tile in coordinates
            ]
            self.resultsAnalyzer.heat_map(tile_results, shape, 'X Axis', 'Y Axis', byte_value, self.resultsAnalyzer.snr_heat_map_max)
        stem = self._data_file_stem()
        fields = ['Trace'] + [f'Byte #{byte}' for byte in self.bytes]
        for tile in coordinates:
            x, y = tile
            tile_index = tile_lookup[tuple(tile)]
            data = np.asarray([self.results[tile_index, row, :] for row in range(len(self.bytes))])
            csv_file_name = f'{stem}_snr_{x}_{y}.csv'
            self.resultsAnalyzer.save_results(data, csv_file_name, fields)