# notscared/converters.py — conversion utilities between SQLite, HDF5 and Zarr trace stores.
import sys
import os
import h5py
import sqlite3
import numpy as np
import zarr
from tqdm import tqdm


class Converter:
    """Utilities for converting side-channel trace data between SQLite
    databases, HDF5 files and Zarr directory stores.

    All conversion methods derive the output path from the input path by
    swapping the file extension (``.db`` -> ``.h5`` / ``.zarr``,
    ``.h5`` -> ``.zarr``).
    """

    def __init__(self) -> None:
        # Stateless; kept so existing `Converter()` call sites keep working.
        pass

    @staticmethod
    def get_directory_size(directory):
        """Return the total size of all files under *directory*, in GiB.

        Walks the tree recursively; directories themselves contribute no size.
        """
        total_size = 0
        for path, _dirs, files in os.walk(directory):
            for f in files:
                fp = os.path.join(path, f)
                total_size += os.path.getsize(fp)
        return total_size / (pow(1024, 3))  # Return dir size in GiB

    def copy_h5_to_zarr(self, h5_group, zarr_group, chunk_size):
        """Recursively copy an HDF5 group tree into a Zarr group.

        ``chunk_size`` is accepted for API compatibility, but each Zarr array
        reuses the source dataset's own chunking (``value.chunks``).
        Datasets are loaded fully into memory before writing.
        """
        for key, value in h5_group.items():
            if isinstance(value, h5py.Group):
                zarr_subgroup = zarr_group.create_group(key)
                # BUG FIX: chunk_size was previously omitted from this
                # recursive call, raising TypeError on any nested group.
                self.copy_h5_to_zarr(value, zarr_subgroup, chunk_size)
            elif isinstance(value, h5py.Dataset):
                print("copying ", key, " to zarr")
                compressor = zarr.Blosc(cname='zstd', clevel=3, shuffle=zarr.Blosc.SHUFFLE)
                zarr_group.array(key, data=np.array(value), chunks=value.chunks, compressor=compressor)

    def h5_to_zarr(self, h5file, chunk_size):
        """Convert *h5file* into a Zarr directory store next to it.

        Any existing ``.zarr`` store at the target path is removed first.
        """
        zarr_path = h5file.replace(".h5", ".zarr")

        if os.path.exists(zarr_path):
            print("Deleting existing Zarr file...")
            # DirectoryStore.rmdir() with no path removes the whole store.
            zarr.DirectoryStore(zarr_path).rmdir()

        h5_file = h5py.File(h5file, 'r')
        try:
            zarr_dir = zarr.DirectoryStore(zarr_path)
            zarr_group = zarr.hierarchy.group(store=zarr_dir)
            self.copy_h5_to_zarr(h5_file, zarr_group, chunk_size)
            zarr_dir.close()
        finally:
            h5_file.close()

    def db_to_zarr(self, db_path, table_name, output_columns, chunk_size=10000, compressor=None):
        """Stream BLOB columns from a SQLite table into a tiled Zarr store.

        Rows are read ``chunk_size`` at a time; each row's blobs are appended
        to per-column arrays under the group ``<tile_x>/<tile_y>``.

        ``compressor=None`` (the default) selects Blosc/zstd level 3 with
        byte shuffling. Resolving the default lazily avoids constructing a
        zarr object at class-definition time.

        NOTE(review): table/column names are interpolated into SQL via
        f-strings (identifiers cannot be bound as parameters in sqlite3) —
        only pass trusted names.
        """
        if compressor is None:
            compressor = zarr.Blosc(cname='zstd', clevel=3, shuffle=zarr.Blosc.SHUFFLE)

        zarr_file = db_path.replace(".db", ".zarr")

        if os.path.exists(zarr_file):
            print("Deleting existing Zarr file...")
            zarr.DirectoryStore(zarr_file).rmdir()

        zarr_dir = zarr.DirectoryStore(zarr_file)
        zarr_group = zarr.hierarchy.group(store=zarr_dir)

        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            num_rows = cursor.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]

            # Prepare SQL query
            columns_str = ",".join(output_columns)
            query = f"SELECT tile_x, tile_y, {columns_str} FROM {table_name}"

            for offset in tqdm(range(0, num_rows, chunk_size), desc='Converting database to Zarr'):
                cursor.execute(f"{query} LIMIT {chunk_size} OFFSET {offset}")
                rows = cursor.fetchall()
                if not rows:
                    continue

                # BUG FIX: group blobs per (tile_x, tile_y). Previously every
                # blob in the chunk was written under the LAST row's tile, and
                # an empty chunk raised NameError on tile_x.
                per_tile = {}
                for tile_x, tile_y, *blobs in rows:
                    tile_key = (int(tile_x), int(tile_y))
                    tile_columns = per_tile.setdefault(tile_key, [[] for _ in output_columns])
                    for idx, blob in enumerate(blobs):
                        # sqlite3 returns BLOBs as bytes; read them directly.
                        tile_columns[idx].append(np.frombuffer(blob, dtype=np.uint8))

                for (tile_x, tile_y), blobs_for_columns in per_tile.items():
                    zarr_group_tile = zarr_group.require_group(f"{tile_x}/{tile_y}")

                    for blobs, column_name in zip(blobs_for_columns, output_columns):
                        blobs_array = np.array(blobs)

                        if column_name in zarr_group_tile.array_keys():
                            zarr_group_tile[column_name].append(blobs_array)
                        else:
                            # Create an empty, resizable array and append, so
                            # later chunks for the same tile can keep growing it.
                            zarr_group_tile.zeros(
                                name=column_name,
                                shape=(0,) + blobs_array.shape[1:],
                                chunks=(chunk_size,) + blobs_array.shape[1:],
                                dtype=blobs_array.dtype,
                                compressor=compressor
                            ).append(blobs_array)
        finally:
            conn.close()

    def db_to_h5(self, db_path, table_name, output_columns, chunk_size=10000):
        """Stream columns from a SQLite table into an HDF5 file.

        Each output column becomes one (N, 1) resizable dataset, filled
        ``chunk_size`` rows at a time. The dataset dtype is inferred from the
        first fetched chunk.

        NOTE(review): identifiers are f-string interpolated into SQL — only
        pass trusted table/column names.
        """
        hdf5_file = db_path.replace(".db", ".h5")
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            num_rows = cursor.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]

            with h5py.File(hdf5_file, "w") as f:
                for col_name in output_columns:
                    print("Converting column ", col_name)

                    dataset = None
                    for offset in range(0, num_rows, chunk_size):
                        query = f"SELECT {col_name} FROM {table_name} LIMIT {chunk_size} OFFSET {offset}"
                        cursor.execute(query)
                        rows = np.array(cursor.fetchall())

                        if len(rows) == 0:
                            print(f"No data found for column {col_name}")
                            break

                        if dataset is None:
                            print("Creating dataset for column: ", col_name)
                            dataset = f.create_dataset(
                                col_name,
                                (0, 1),
                                maxshape=(None, 1),
                                dtype=rows.dtype,
                                chunks=True)

                        dataset.resize((dataset.shape[0] + len(rows), 1))
                        dataset[-len(rows):] = rows
        finally:
            conn.close()


if __name__ == '__main__':
    # CLI entry point: convert a SQLite trace database to a Zarr store.
    if len(sys.argv) != 2:
        print("Usage: python convert.py <db_file>")
        sys.exit(1)

    converter = Converter()
    db_path = sys.argv[1]
    # Removed dead local `h5_path` and the stale commented-out h5_to_zarr
    # call that referenced it.
    table_name = 'traces'
    columns = ['samples', 'ptxt', 'ctxt']

    converter.db_to_zarr(db_path, table_name, columns)