import numpy as np
import os
from multiprocessing.pool import Pool
from .util import resultsAnalyzer as resultsAnalyzer
from tqdm import tqdm
class CPA:
    """Streaming Correlation Power Analysis (CPA) attack for AES-128 traces.

    Consumes batches of (plaintext, samples) from a trace handler and keeps
    running sums so that, for every key hypothesis (0-255) of every requested
    plaintext byte, the Pearson correlation between a Hamming-weight power
    model of the first-round S-box operation and the observed samples can be
    computed in a single pass. Tiled datasets are processed per tile
    coordinate; the workload is fanned out with a multiprocessing Pool.

    Constructor parameters:
        Tracehandler: trace-handler object that iterates the dataset; must
            expose sample_length, hasTiles, file_name, grab(), and — for
            tiled data — tiles_coordinates and configure_tile().
        Bytes: list of plaintext byte indices to attack.
        Bytes_To_Graph: optional list of (byte_idx, graph_color) tuples for
            the results analyzer to graph.
    """

    def __init__(self, Tracehandler, Bytes, Bytes_To_Graph=None) -> None:
        # Tracehandler object that iterates the file and grabs data
        self.data = Tracehandler
        # Length of observed samples, used for array shapes
        self.trace_length = self.data.sample_length
        # Bytes to be computed
        self.bytes = Bytes
        # Whether tile coordinates are to be computed
        self.tiles = self.data.hasTiles
        # AES-128 sbox used to compute model values
        self.sbox = np.array([99,124,119,123,242,107,111,197,48,1,103,43,254,215,171,118,
                              202,130,201,125,250,89,71,240,173,212,162,175,156,164,114,192,
                              183,253,147,38,54,63,247,204,52,165,229,241,113,216,49,21,
                              4,199,35,195,24,150,5,154,7,18,128,226,235,39,178,117,
                              9,131,44,26,27,110,90,160,82,59,214,179,41,227,47,132,
                              83,209,0,237,32,252,177,91,106,203,190,57,74,76,88,207,
                              208,239,170,251,67,77,51,133,69,249,2,127,80,60,159,168,
                              81,163,64,143,146,157,56,245,188,182,218,33,16,255,243,210,
                              205,12,19,236,95,151,68,23,196,167,126,61,100,93,25,115,
                              96,129,79,220,34,42,144,136,70,238,184,20,222,94,11,219,
                              224,50,58,10,73,6,36,92,194,211,172,98,145,149,228,121,
                              231,200,55,109,141,213,78,169,108,86,244,234,101,122,174,8,
                              186,120,37,46,28,166,180,198,232,221,116,31,75,189,139,138,
                              112,62,181,102,72,3,246,14,97,53,87,185,134,193,29,158,
                              225,248,152,17,105,217,142,148,155,30,135,233,206,85,40,223,
                              140,161,137,13,191,230,66,104,65,153,45,15,176,84,187,22])
        # The key hypotheses, which are the values 0-255
        self.keys = np.arange(256)
        # Hamming weights of the values 0-255, used for model values
        self.weights = np.array([0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4,
                                 1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,
                                 1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,
                                 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                                 1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,
                                 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                                 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                                 3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
                                 1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,
                                 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                                 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                                 3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
                                 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                                 3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
                                 3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
                                 4,5,5,6,5,6,6,7,5,6,6,7,6,7,7,8], np.float32)
        # Number of trace rows accumulated so far
        self.count = 0
        # Running sum of the observed samples
        self.sample_sum = np.zeros((self.trace_length), dtype=np.float32)
        # Running sum of the squared samples
        self.sample_sq_sum = np.zeros((self.trace_length), dtype=np.float32)
        # Running sum of the model values per key hypothesis
        self.model_sum = np.zeros((256), dtype=np.float32)
        # Running sum of the squared model values per key hypothesis
        self.model_sq_sum = np.zeros((256), dtype=np.float32)
        # Running sum of the (model x sample) products
        self.prod_sum = np.zeros((256, self.trace_length), dtype=np.float32)
        # Analyzer used for saving/graphing results
        self.resultsAnalyzer = resultsAnalyzer.ResultsAnalyzer(self.data.file_name, 'cpa')
        # Bytes for the analyzer to graph; None default avoids the shared
        # mutable-default-argument pitfall while keeping caller behavior
        self.bytes_to_graph = [] if Bytes_To_Graph is None else Bytes_To_Graph
        # Shape of the results array depends on whether the data is tiled
        if self.tiles:
            self.results = np.zeros((len(self.data.tiles_coordinates), len(self.bytes), 256, self.data.sample_length), dtype=np.float32)
        else:
            self.results = np.zeros((len(self.bytes), 256, self.trace_length), dtype=np.float32)
        print("Computing CPA: Tiles: ", self.tiles, " Bytes: ", self.bytes)

    # Turn the accumulated sums into the final correlation matrix.
    @staticmethod
    def finalize(self):
        """Return the (256, trace_length) Pearson correlation matrix.

        Static so it can be applied to a pickled worker copy of the
        instance; call as CPA.finalize(inst) or self.finalize(self).
        """
        # Sample mean
        sample_mean = np.divide(self.sample_sum, self.count)
        # Model mean
        model_mean = np.divide(self.model_sum, self.count)
        # Correlation numerator: E[XY] - E[X]E[Y]
        numerator = np.subtract(np.divide(self.prod_sum, self.count), model_mean[:, None] * sample_mean)
        # Denominator, sample part: sqrt(E[Y^2] - E[Y]^2)
        denom_sample = np.sqrt(np.subtract(np.divide(self.sample_sq_sum, self.count), np.square(sample_mean)))
        # Denominator, model part: sqrt(E[X^2] - E[X]^2)
        denom_model = np.sqrt(np.subtract(np.divide(self.model_sq_sum, self.count), np.square(model_mean)))
        # Finalized correlation coefficients, passed back to run
        return np.divide(numerator, denom_model[:, None] * denom_sample)

    def get_model(self, plaintexts: np.ndarray):
        """Build the modeled power draw for one plaintext byte.

        Returns an (n_traces, 256) array: the Hamming weight of
        (sbox input XOR sbox output) for every key hypothesis.
        """
        # S-box inputs for each hypothesized key value
        inputs = np.bitwise_xor(plaintexts[:, np.newaxis], self.keys)
        # S-box outputs for each input
        outputs = self.sbox[inputs]
        # Modeled power draw, passed back to the caller
        return self.weights[np.bitwise_xor(inputs, outputs)]

    def update(self, traces: np.ndarray, model: np.ndarray):
        """Fold one batch of traces and its model into the accumulators."""
        # Number of rows processed
        self.count += traces.shape[0]
        # Sample accumulator
        self.sample_sum += np.sum(traces, axis=0)
        # Sample squared accumulator
        self.sample_sq_sum += np.sum(np.square(traces), axis=0)
        # Model accumulator
        self.model_sum += np.sum(model, axis=0)
        # Model squared accumulator
        self.model_sq_sum += np.sum(np.square(model), axis=0)
        # Product accumulator
        self.prod_sum += np.matmul(model.T, traces)

    # Worker for tiled datasets: one (tile, byte) combination.
    @staticmethod
    def run_xy(self, byte, tile_x, tile_y):
        """Accumulate one tile/byte pair; return (x, y, byte, correlation).

        Static so starmap can hand a pickled copy of the instance in as
        `self`; every worker copy starts with zeroed accumulators.
        """
        # Point the trace handler at the requested tile
        self.data.configure_tile(tile_x, tile_y)
        # Consume batches until the handler is exhausted (empty batch)
        batch = self.data.grab()
        while len(batch) > 0:
            # Batch layout: plaintexts first, then samples
            plaintexts = batch[0]
            samples = batch[1].astype(np.float32)
            # Model the power draw for every key hypothesis of this byte
            model = self.get_model(plaintexts[:, byte])
            # Fold the batch into the accumulators
            self.update(samples, model)
            batch = self.data.grab()
        # Pass back to the main run function
        return (tile_x, tile_y, byte, self.finalize(self))

    # Worker for non-tiled datasets: one byte.
    @staticmethod
    def run_1x1(self, byte):
        """Accumulate one byte over the whole dataset; return correlation."""
        # Consume batches until the handler is exhausted (empty batch)
        batch = self.data.grab()
        while len(batch) > 0:
            # Batch layout: plaintexts first, then samples
            plaintexts = batch[0]
            samples = batch[1].astype(np.float32)
            # Model the power draw for every key hypothesis of this byte
            model = self.get_model(plaintexts[:, byte])
            # Fold the batch into the accumulators
            self.update(samples, model)
            batch = self.data.grab()
        # Pass back to the main run function
        return self.finalize(self)

    def run(self):
        """Split the workload over a process pool and fill self.results."""
        # Bytes to process, as an integer array for index lookups below
        self.bytes = np.array(self.bytes, dtype=int)
        if self.tiles:
            # Tiled dataset: one task per (tile coordinate, byte) pair
            with Pool() as pool:
                workload = []
                for (tile_x, tile_y) in self.data.tiles_coordinates:
                    for byte_position in self.bytes:
                        workload.append((self, byte_position, tile_x, tile_y))
                with tqdm(total=len(workload)) as pbar:
                    for tile_x, tile_y, byte_position, byte_result in pool.starmap(self.run_xy, workload):
                        pbar.update()
                        # Map the returned (tile, byte) back to result indices;
                        # extract a scalar index rather than a length-1 array
                        byte_index = int(np.where(self.bytes == int(byte_position))[0][0])
                        tile_index = self.data.tiles_coordinates.index((tile_x, tile_y))
                        self.results[tile_index, byte_index, :] = byte_result
        else:
            # Non-tiled dataset: one task per byte
            with Pool() as pool:
                workload = [(self, byte_position) for byte_position in self.bytes]
                with tqdm(total=len(workload)) as pbar:
                    # starmap preserves workload order, so enumerate gives the slot
                    for index, byte_result in enumerate(pool.starmap(self.run_1x1, workload)):
                        pbar.update()
                        self.results[index, :] = byte_result

    def publish(self):
        """Publish results in the form matching the dataset layout."""
        if (self.tiles):
            self.publish_xy()
        else:
            self.publish_1x1()

    # Should not be called from outside this class - use publish()
    def publish_1x1(self):
        """Graph the strongest key hypothesis and dump per-byte CSVs."""
        # (key, sample) position of the strongest absolute correlation
        # for the first attacked byte
        peak = np.unravel_index(np.abs(self.results[0, :, :]).argmax(), self.results.shape[1:])
        graph_results = [(self.results[0, :, :].T, 'gray', 'All Keys'),
                         (self.results[0, peak[0], :], 'red', 'Highest Key')]
        self.resultsAnalyzer.graph_1x1(graph_results, 'Samples', 'CPA')
        # Save CPA data to one CSV file per attacked byte
        for byte_index in range(len(self.bytes)):
            data = []
            fields = ['Trace']
            for key_hypothesis in range(256):
                data.append(self.results[byte_index, key_hypothesis, :])
                fields.append(f'Key Hyp #{key_hypothesis}')
            data = np.asarray(data)
            data_file_name = os.path.splitext(os.path.basename(self.data.file_name))[0]  # Full zarr file name minus zarr extension
            csv_file_name = f'{data_file_name}_cpa_{byte_index}.csv'
            self.resultsAnalyzer.save_results(data, csv_file_name, fields)

    # Should not be called from outside this class - use publish()
    def publish_xy(self):
        """Render per-byte heat maps and dump per-tile, per-byte CSVs."""
        def count_distinct_coordinates(coordinates):
            # Grid shape = (#distinct x values, #distinct y values)
            x_coords = {coord[0] for coord in coordinates}
            y_coords = {coord[1] for coord in coordinates}
            return len(x_coords), len(y_coords)
        shape = count_distinct_coordinates(self.data.tiles_coordinates)
        # Create a heat map for each byte position
        for byte_index in range(len(self.bytes)):
            # Grab data from every tile at this byte position
            tile_results = []
            for tile_index, tile in enumerate(self.data.tiles_coordinates):
                tile_results.append((tile, self.results[tile_index, byte_index, :]))
            # Print to heat map
            self.resultsAnalyzer.heat_map(tile_results, shape, 'X Axis', 'Y Axis', self.bytes[byte_index], self.resultsAnalyzer.cpa_heat_map_max)
        # Create a CSV file for each tile and byte
        for tile_index, (x, y) in enumerate(self.data.tiles_coordinates):
            for byte_index in range(len(self.bytes)):
                data = []
                fields = ['Trace']
                for key_hypothesis in range(256):
                    data.append(self.results[tile_index, byte_index, key_hypothesis, :])
                    fields.append(f'Key Hyp #{key_hypothesis}')
                data = np.asarray(data)
                data_file_name = os.path.splitext(os.path.basename(self.data.file_name))[0]  # Full zarr file name minus zarr extension
                csv_file_name = f'{data_file_name}_cpa_{x}_{y}_{byte_index}.csv'
                self.resultsAnalyzer.save_results(data, csv_file_name, fields)