import unittest
import numpy as np
from unittest.mock import patch
import notscared.snr as snr
import notscared.data.trace_handler as trace_handler
class TestSNR(unittest.TestCase):
    """Unit tests for ``snr.SNR`` driven by an in-memory trace handler."""

    def setUp(self):
        """Create a fresh dummy trace handler and SNR instance for each test."""
        # Create a dummy TraceHandler object (avoids touching real trace storage).
        self.tracehandler = DummyTraceHandler()
        # Define the bytes to attack.
        self.bytes = [0, 1, 2]
        # Create an instance of SNR.
        self.snr = snr.SNR(Tracehandler=self.tracehandler, Bytes=self.bytes)

    def test_run_1x1(self):
        """run_1x1 yields one non-negative value per trace sample."""
        # Seeded generator so a failing run can be reproduced exactly.
        rng = np.random.default_rng(0)
        n_traces = 10
        n_samples = self.tracehandler.sample_length
        # The trace width must equal sample_length: the assertion below expects
        # one result per sample, so a batch of any other width would make the
        # test contradict its own expectation.
        # NOTE(review): plaintexts are one integer in [0, 256) per trace, as in
        # the original test — confirm this matches what SNR.run_1x1 expects.
        batch = (
            rng.random((n_traces, n_samples)),
            rng.integers(0, 256, size=(n_traces,)),
        )
        # Patch the grab method so SNR consumes the dummy batch above.
        with patch.object(self.tracehandler, 'grab', return_value=batch):
            result = self.snr.run_1x1(0)
        # Assertions
        self.assertEqual(result.shape, (n_samples,))
        self.assertTrue(np.all(result >= 0))


# Define a dummy TraceHandler class for testing
class DummyTraceHandler(trace_handler.TraceHandler):
    """In-memory stand-in for ``trace_handler.TraceHandler``.

    Produces random ``(traces, plaintexts)`` batches instead of reading them
    from a real trace source.
    """

    def __init__(self, sample_length=100, n_traces=10, seed=None):
        """Configure the fake handler.

        Args:
            sample_length: number of samples per trace; also the width of the
                trace array returned by ``grab``.
            n_traces: number of traces per batch returned by ``grab``.
            seed: optional seed for the random generator (``None`` gives
                non-deterministic batches).
        """
        # NOTE(review): TraceHandler.__init__ is intentionally not called, so
        # no real trace source is opened — confirm SNR does not rely on
        # attributes the base initializer would set.
        self.sample_length = sample_length
        self.n_traces = n_traces
        self._rng = np.random.default_rng(seed)

    def grab(self):
        """Return a random ``(traces, plaintexts)`` batch.

        ``traces`` has shape ``(n_traces, sample_length)`` so its width agrees
        with the advertised ``sample_length``; ``plaintexts`` has shape
        ``(n_traces,)`` with integer values in ``[0, 256)``.
        """
        traces = self._rng.random((self.n_traces, self.sample_length))
        plaintexts = self._rng.integers(0, 256, size=(self.n_traces,))
        return traces, plaintexts


if __name__ == '__main__':
    # Discover and run every TestCase defined in this module when executed
    # directly as a script (module='__main__' is unittest.main's default).
    unittest.main(module='__main__')