# notscared / notscared2-main / tests / snr_test.py
# Unit tests for notscared.snr.SNR.
import unittest
import numpy as np
from unittest.mock import patch

import notscared.snr as snr
import notscared.data.trace_handler as trace_handler


class TestSNR(unittest.TestCase):
    """Unit tests for ``notscared.snr.SNR`` driven by a dummy trace handler."""

    # Number of synthetic traces in each dummy batch.
    N_TRACES = 10

    def setUp(self):
        """Build an SNR instance over a dummy handler and fixed target bytes."""
        # Dummy handler (defined in this module) that skips the real
        # TraceHandler initialisation and only sets ``sample_length``.
        self.tracehandler = DummyTraceHandler()
        # Byte positions handed to SNR via its ``Bytes`` argument.
        self.bytes = [0, 1, 2]
        self.snr = snr.SNR(Tracehandler=self.tracehandler, Bytes=self.bytes)
        # Seeded generator so the synthetic batch, and any failure it
        # triggers, is reproducible from run to run.
        self.rng = np.random.default_rng(0)

    def test_run_1x1(self):
        """run_1x1 over one batch yields one non-negative value per sample."""
        n_samples = self.tracehandler.sample_length
        # The trace width must equal the handler's sample_length. The result
        # shape is asserted against that value below. Before this fix the
        # traces were a hard-coded 256 wide while sample_length was 100.
        traces = self.rng.random((self.N_TRACES, n_samples))
        plaintexts = self.rng.integers(0, 256, size=(self.N_TRACES,))
        batch = (traces, plaintexts)

        # Replace grab() so SNR consumes exactly this batch.
        with patch.object(self.tracehandler, 'grab', return_value=batch):
            result = self.snr.run_1x1(0)

        self.assertEqual(result.shape, (n_samples,))
        self.assertTrue(np.all(result >= 0))

class DummyTraceHandler(trace_handler.TraceHandler):
    """Minimal in-memory stand-in for ``TraceHandler`` used by the SNR tests.

    ``__init__`` intentionally does not call ``TraceHandler.__init__``. It
    only sets the attributes these tests read. Presumably this avoids the
    base class's real trace-source setup; confirm against TraceHandler.
    """

    # Number of traces returned by each call to grab().
    N_TRACES = 10

    def __init__(self):
        self.sample_length = 100
        # Seeded so repeated grab() calls produce a reproducible sequence.
        self._rng = np.random.default_rng(0)

    def grab(self):
        """Return one synthetic batch ``(traces, plaintexts)``.

        ``traces`` is a float array of shape ``(N_TRACES, sample_length)``.
        Its width matches ``sample_length``, unlike the previous hard-coded
        256. ``plaintexts`` is an integer array of shape ``(N_TRACES,)``
        with values in ``[0, 256)``.
        """
        traces = self._rng.random((self.N_TRACES, self.sample_length))
        plaintexts = self._rng.integers(0, 256, size=(self.N_TRACES,))
        return (traces, plaintexts)

if __name__ == "__main__":
    # Allow running the suite directly: `python snr_test.py`.
    unittest.main()