"""Unit tests for ``notscared.snr.SNR``."""

import unittest
from unittest.mock import patch

import numpy as np

import notscared.data.trace_handler as trace_handler
import notscared.snr as snr


class DummyTraceHandler(trace_handler.TraceHandler):
    """Minimal in-memory stand-in for ``TraceHandler`` used by the SNR tests.

    ``TraceHandler.__init__`` is intentionally not called. Presumably it opens
    a real trace source (TODO confirm). Only ``sample_length`` and ``grab()``
    are provided here.

    Args:
        sample_length: Number of sample points per trace; ``grab()`` returns
            traces exactly this wide.
        num_traces: Number of traces (rows) returned per ``grab()`` call.
        seed: Seed for the private RNG so generated batches are reproducible.
    """

    def __init__(self, sample_length=100, num_traces=10, seed=0):
        self.sample_length = sample_length
        self.num_traces = num_traces
        # Private, seeded generator: avoids touching numpy's global RNG state
        # and makes any failure reproducible.
        self._rng = np.random.default_rng(seed)

    def grab(self):
        """Return a random ``(traces, labels)`` batch.

        Returns:
            A tuple ``(traces, labels)``. ``traces`` is a float array of shape
            ``(num_traces, sample_length)`` with values in ``[0, 1)``.
            ``labels`` is an integer array of shape ``(num_traces,)`` with
            values in ``[0, 256)``.
        """
        traces = self._rng.random((self.num_traces, self.sample_length))
        labels = self._rng.integers(0, 256, size=(self.num_traces,))
        return traces, labels


class TestSNR(unittest.TestCase):
    """Tests for ``snr.SNR`` driven by a ``DummyTraceHandler``."""

    def setUp(self):
        """Build an ``SNR`` instance wired to an in-memory trace handler."""
        self.tracehandler = DummyTraceHandler()
        # Byte indices to attack (passed to SNR as ``Bytes``).
        self.bytes = [0, 1, 2]
        self.snr = snr.SNR(Tracehandler=self.tracehandler, Bytes=self.bytes)

    def test_run_1x1(self):
        """``run_1x1`` yields one non-negative value per sample point."""
        rng = np.random.default_rng(1234)
        n_traces = 10
        # Traces must be ``sample_length`` wide. The shape assertion below
        # expects one result per sample point, and a fixed width of 256 can
        # never match a ``sample_length`` of 100.
        batch = (
            rng.random((n_traces, self.tracehandler.sample_length)),
            rng.integers(0, 256, size=(n_traces,)),
        )

        # NOTE(review): assumes run_1x1 obtains its data via
        # ``tracehandler.grab()``; confirm against notscared.snr.
        with patch.object(self.tracehandler, "grab", return_value=batch):
            result = self.snr.run_1x1(0)

        self.assertEqual(result.shape, (self.tracehandler.sample_length,))
        self.assertTrue(np.all(result >= 0))


if __name__ == "__main__":
    unittest.main()