#include "main.h"
#include "buffer.h"
#include <algorithm>
#include <stdlib.h>
#include <string>
#include <unistd.h> 
#include <fcntl.h> 
#include <cstdio>
#include <string>
#include <sys/stat.h>
#include <iostream>
#include <filesystem>
#include <cmath>
#include <cassert>
#include <fstream>
#include "xxhash32.h"
#include <limits.h>
#include <cstring>
#define NO_KEY INT_MIN
#define TOMBSTONE INT_MAX
#define PAGE_SIZE 4096
#define B 512 // Can be changed, but should be 2^n for some n such that B DB entries will fit within one page.
#define BIN_SEARCH 0
#define BTREE_SEARCH 1
#ifndef O_DIRECT // On machines that have O_DIRECT, our code is designed to use it.
#define O_DIRECT O_RDONLY // On machines where O_DIRECT is disallowed, use O_RDONLY instead.
#endif
#define DEBUG 0 // Set to 0 for experimental use / real-world , 1 for testing.
// Builds a DB handle for the database called `name`. The memtable pointer starts out null
// and buffer_args is reset; the remaining members are filled in by Open().
DB::DB(std::string name): name(name), mem_table(nullptr) {
    // buffer_args remembers the most recent query: slot 0 is the operation
    // (0 = none yet, 1 = get, 2 = scan) and slots 1-2 are that call's key arguments.
    buffer_args[2] = NO_KEY;
    buffer_args[1] = NO_KEY;
    buffer_args[0] = 0;
}
// Returns ceil(log2(size)): the smallest depth d such that 2^d >= size.
// Any size of 1 or less (including 0 and negative values) maps to depth 0, so callers may
// pass 0 (e.g. a disabled buffer pool) without hitting log2(0) = -infinity.
int sizeToDepth(int size) {
    if (size <= 1) return 0;
    int depth = 0;
    // Integer doubling avoids floating-point rounding; the 64-bit counter cannot overflow
    // for any int input.
    for (long long capacity = 1; capacity < size; capacity *= 2) {
        depth++;
    }
    return depth;
}
// Adds up arr[start], arr[start + 1], ..., arr[size - 1] and returns the total.
// `size` acts as an exclusive end index rather than an element count; the result is 0
// whenever start >= size.
int array_sum(int* arr, int start, int size) {
    int total = 0;
    int pos = start;
    while (pos < size) {
        total += arr[pos];
        ++pos;
    }
    return total;
}
// Public Methods
// Creates a DB handle for the database stored in directory `name`.
//   max_size               - capacity handed to the AVLTree memtable (also stored on the DB).
//   initial_directory_size - converted with sizeToDepth() and passed to the BufferPool.
//   max_directory_size     - converted with sizeToDepth() and passed to the BufferPool;
//                            0 disables the buffer pool (no_buffer = true).
//   search_alg             - BIN_SEARCH or BTREE_SEARCH; stored for get() / scan().
//   bloomfilter_bits       - stored as bloomfilter_num_bits.
// If the directory already exists, sst_counter is set to one more than the largest number
// found at the end of any file name in it (ignoring the extension and sizes.bin).
// Returns a heap-allocated DB; DB::Close() deletes it.
DB* Open(std::string name, int max_size, int initial_directory_size, int max_directory_size, int search_alg, int bloomfilter_bits) {
    DB* db = new DB(name);
    db->mem_table = new AVLTree(max_size, name);
    db->buffer = new BufferPool(sizeToDepth(initial_directory_size), sizeToDepth(max_directory_size));
    db->max_size = max_size;
    db->buffer->no_buffer = (max_directory_size == 0);
    db->bloomfilter_num_bits = bloomfilter_bits;
    db->search_alg = search_alg;
    db->sst_counter = 0; // recomputed below from the files already on disk

    // Nothing to recover unless `name` is an existing directory. Checking the file type
    // keeps directory_iterator from throwing when `name` is a regular file.
    struct stat sb;
    if (stat(name.c_str(), &sb) != 0 || !S_ISDIR(sb.st_mode)) return db;

    for (const auto& entry : std::filesystem::directory_iterator(name)) {
        // Compare the file name only, so a trailing '/' in `name` cannot defeat the check.
        if (entry.path().filename().string() == "sizes.bin") continue;

        // Pull the run of digits at the end of the name, extension removed ("sst12.bin" -> "12").
        const std::string stem = entry.path().stem().string();
        const size_t last_non_digit = stem.find_last_not_of("0123456789");
        const std::string digits = (last_non_digit == std::string::npos) ? stem : stem.substr(last_non_digit + 1);

        // Skip files without a numeric suffix (stoi would throw on "") and suffixes too long
        // to fit in an int (stoi would throw out_of_range).
        if (digits.empty() || digits.size() > 9) continue;

        db->sst_counter = std::max(db->sst_counter, std::stoi(digits) + 1);
    }
    return db;
}
int DB::Close() {
    int new_size = transform_to_sst(); // Flush the current memtable to disk
    if (sst_counter == 0) sst_counter++;
    sst_counter = std::max(sst_counter, new_size + 1);  
    delete(DB::mem_table);
    DB:buffer->destroyBuffer(); // Currently is freeing nothing, and failing a malloc.
    delete this;  // suicide, be careful
    return 0;
}
// Validates a key/value pair and stores it. NO_KEY and TOMBSTONE are reserved sentinels, so
// neither may appear as the key or the value.
// Returns -1 (after printing a message to stderr) for a reserved key or value; otherwise
// returns whatever insert() returns.
int DB::put(int key, int value) {
    const bool key_reserved = (key == NO_KEY) || (key == TOMBSTONE);
    const bool value_reserved = (value == NO_KEY) || (value == TOMBSTONE);
    if (!key_reserved && !value_reserved) {
        return insert(key, value);
    }
    std::cerr << "Entry (" << key << ", " << value << ") not inserted. \nThe values " << NO_KEY << " and " << TOMBSTONE << " are not allowed as keys or values in this DB.\n";
    return -1;
}
// Returns the value matching key, or NO_KEY if the key is not in the database or deleted.
int DB::get(int key) {
    // first searches the tree itself, and then searches the SSTs in-order if the result is not in the tree.
    int tree_search_result = mem_table->get(key);
   
    if(tree_search_result == TOMBSTONE){ // key has been deleted.
        std::cerr << "Key " << key << " not found.\n";
        return NO_KEY; 
    }
    if (tree_search_result != NO_KEY) return tree_search_result;
    
    // Iterate over SSTs
    bool buf = true;
    if ((DB::buffer_args[0] == 1) && (DB::buffer_args[1] == key)) buf = false;
    
    for(int sst_num = 0; sst_num < sst_counter; sst_num++){
        std::string sst_name = DB::name + "/sst" + std::to_string(sst_num) + ".bin";
        if (std::filesystem::exists(sst_name)){
            if (check_bloom_filters(sst_num, key)) {
                int sst_search_result;
                if (search_alg == BIN_SEARCH) {
                    sst_search_result = get_from_SST_Binary(sst_num, key, buf);
                } else { // btree search
                    sst_search_result = get_from_SST_BTree(sst_num, key, buf);
                }
                if ((sst_search_result != NO_KEY) && (sst_search_result != TOMBSTONE)){
                    DB::buffer_args[0] = 1;
                    DB::buffer_args[1] = key;
                    DB::buffer_args[2] = NO_KEY;
                    return sst_search_result;
                }
            }
        }
    }
    DB::buffer_args[0] = 1;
    DB::buffer_args[1] = key;
    DB::buffer_args[2] = NO_KEY;
    
    // Key not found
    std::cerr << "Key " << key << " not found.\n";
    return NO_KEY; 
}
// Overwrites the value stored for `key`. The DB treats an update exactly like an insertion,
// so this forwards to put() and returns its status (-1 for a reserved key or value).
int DB::update(int key, int value){
    const int status = put(key, value);
    return status;
}
// Deletes `key` by inserting a TOMBSTONE as its value.
// The reserved sentinels NO_KEY and TOMBSTONE can never be stored as keys, so asking to
// delete one of them is reported on stderr and rejected with -1. Otherwise returns the
// result of insert().
int DB::remove(int key){
    const bool key_reserved = (key == NO_KEY) || (key == TOMBSTONE);
    if (!key_reserved) {
        return insert(key, TOMBSTONE);
    }
    std::cerr << "Entry with key " << key << " is not permitted in the DB, and can not be deleted. \nThe values " << NO_KEY << " and " << TOMBSTONE << " are not allowed as keys or values in this DB.\n";
    return -1;
}
// Retrieves all KV-pairs in a key range in key order.
struct kv_pairs* DB::scan(int key1, int key2) {  
    kv_pairs *scanned_kv = mem_table->scan(key1, key2); // Scan the memtable
    kv_pairs *SST_kv = new kv_pairs; // temporary struct for each SST to use
    kv_pairs *merged_kv = new kv_pairs; // temporary struct to hold the merged kv pairs before it is re-assigned to scanned_kv
    
    SST_kv->kv = (int**) malloc(sizeof(int*) * mem_table->max_size * sst_counter * 2); // todo: find a better bound
    bool buf = true;
    if ((DB::buffer_args[0] == 2) && (DB::buffer_args[1] == key1) && (DB::buffer_args[2] == key2)) buf = false;
    
    // Iterate over SSTs
    for (int sst_num = 0; sst_num < sst_counter; sst_num++) {
        // some layers of the sst may not exist due to lsm structure
        if (std::filesystem::exists((name + "/sst" + std::to_string(sst_num) + ".bin").c_str()) == 0) continue;
        SST_kv->kv_size = 0;
        if (search_alg == BIN_SEARCH) {
            scan_SST_Binary(sst_num, key1, key2, SST_kv, buf);
        } else {
            scan_SST_BTree(sst_num, key1, key2, SST_kv, buf);
        }
        merged_kv->kv = (int**) malloc(sizeof(int*) * (scanned_kv->kv_size + SST_kv->kv_size)); // allocate enough space for all pairs in both
        merge_kv_pairs(scanned_kv, SST_kv, merged_kv);
        free(scanned_kv->kv); // free the old scanned kv so it can be replaced by the new merged one
        scanned_kv->kv = merged_kv->kv;
        scanned_kv->kv_size = merged_kv->kv_size;
    }
    free(SST_kv->kv);
    delete(SST_kv); // no longer needed
    delete(merged_kv);
    remove_scanned_tombstones(scanned_kv); // clean scanned kv of tombstones for returning
    
    DB::buffer_args[0] = 2;
    DB::buffer_args[1] = key1;
    DB::buffer_args[2] = key2;
    return scanned_kv;
}
// Convenience for callers of scan(): hands a kv_pairs object that is no longer needed to
// the memtable's freeKVpairs() routine.
void DB::freeKVpairs(kv_pairs* p) {
    DB::mem_table->freeKVpairs(p);
}
// Changes the buffer pool's maximum directory size. `size` is converted to a depth with
// sizeToDepth() before being passed to the pool.
void DB::set_max_dir_size(int size) {
    DB::buffer->setMaxDirectoryDepth(sizeToDepth(size));
}
// DB private methods
// Stores a KV pair without validating it; put() and remove() do the validation and then
// call this. When the memtable is already at capacity it is first written out with
// transform_to_sst() and its root reset, and sst_counter is raised if needed. Always returns 0.
int DB::insert(int key, int value){
    const bool memtable_full = (mem_table->cur_size == mem_table->max_size);
    if (memtable_full) {
        const int flushed = transform_to_sst();
        if (sst_counter == 0) {
            sst_counter = 1;
        }
        sst_counter = std::max(sst_counter, flushed + 1);
        mem_table->root = nullptr; // the flushed tree is gone; start over from an empty root
    }
    mem_table->put(key, value);
    return 0;
}
// Returns the number of bits in the bloom filter for LSM level `lvl`.
// The level is assumed to hold 2^lvl * (memtable max_size) entries. The lowest level
// (lvl >= sst_counter - 1) gets bloomfilter_num_bits - 1 bits per entry; every level above
// it gets bloomfilter_num_bits plus one extra bit per level of distance from the lowest.
int DB::get_bloom_size(int lvl) {
    const int last_level = DB::sst_counter - 1;
    const int entries = pow(2, lvl) * (DB::mem_table->max_size);
    const int bits_per_entry = (lvl >= last_level)
        ? (DB::bloomfilter_num_bits - 1)
        : (DB::bloomfilter_num_bits + (last_level - lvl));
    return entries * bits_per_entry;
}
// For a given key and LSM tree level, return the ln(2) * m hashes of our key.
std::vector<uint32_t> DB::bloom_hash(int lvl, int key){
    int m;
    if (lvl >= DB::sst_counter - 1){
        m = (DB::bloomfilter_num_bits - 1);
    } else {
        m = (DB::bloomfilter_num_bits + (DB::sst_counter - 1 - lvl));
    }
    int num_hash = round(m * log(2));
    int k = key;
    std::vector<uint32_t> hashes;
    for (int i = 0; i < num_hash; i++){
        uint32_t keyHash = (XXHash32::hash(&k, sizeof(int), i)) % (get_bloom_size(lvl));
        hashes.push_back(keyHash);
    }
    return hashes;
}
// Builds the bloom filter for LSM level `lvl` from `scanned_kv`, which is expected to hold
// every pair of that level. The filter has get_bloom_size(lvl) bits, all initially clear;
// for each key, every position returned by bloom_hash() is set.
std::vector<bool> DB::create_bloom_filter(int lvl, kv_pairs* scanned_kv){
    const int num_bits = DB::get_bloom_size(lvl);
    std::vector<bool> bits(num_bits, false);
    for (int row = 0; row < scanned_kv->kv_size; ++row) {
        const int key = scanned_kv->kv[row][0]; // element 0 of a pair is the key
        for (const uint32_t pos : bloom_hash(lvl, key)) {
            bits[pos] = true;
        }
    }
    return bits;
}
// Translates the position `index` of a KV pair within an SST (0 = smallest key) into the
// byte offset of that pair's key inside the SST file.
//   sst_num        - not used by this function.
//   index          - position of the pair among the SST's leaves.
//   num_leaf_pages - number of leaf pages in the SST.
//   sst_size       - number of pairs in the SST.
//   leaf_page      - page number at which the leaf pages begin.
// Each pair occupies 2 * sizeof(int) bytes. In the common case pair `index` sits in leaf
// page index / B at slot index % B.
int DB::get_offset(int sst_num, int index, int num_leaf_pages, int sst_size, int leaf_page) {
    int offset;
    if (num_leaf_pages <= 1) {
        // Single leaf page: plain page/slot arithmetic.
        offset = leaf_page * PAGE_SIZE + floor(index/B) * PAGE_SIZE + (index % B) * sizeof(int) * 2;
    }
    // The last two pages needed to be split so that they'd have between B/2 and B leaves each
    // Check if the element on index is at one of the last two pages
    // NOTE(review): `index` is a pair index but is compared against a page count here;
    // `index >= (num_leaf_pages - 2) * B` may have been intended. The balancing test also
    // uses max_size % B, whereas the B-tree readers in this file test sst_size % B. Confirm
    // against the code that writes the SST.
    else if (index >= num_leaf_pages - 2 && sst_size % B != 0 && max_size % B != 0 && max_size % B < ceil(float(B)/float(2)))  { 
        if (index < (num_leaf_pages - 2) * B + ceil(float(B)/float(2))) { // on the first half --> second last page
            offset = leaf_page * PAGE_SIZE + (num_leaf_pages - 2) * PAGE_SIZE + (index % B) * sizeof(int) * 2;
        } else { // on last page
            // NOTE(review): the slot is computed as index minus the number of pairs on the
            // last page, not index minus the index of the last page's first pair; verify.
            offset = leaf_page * PAGE_SIZE + (num_leaf_pages - 1) * PAGE_SIZE + (index - (sst_size - (num_leaf_pages-2) * B  - ceil(float(B)/float(2)))) * sizeof(int) * 2;
        }
    }
    else {  
        // No balancing applies: plain page/slot arithmetic.
        offset = leaf_page * PAGE_SIZE + floor(index/B) * PAGE_SIZE + (index % B) * sizeof(int) * 2;
    }
    return offset;
}
// Returns the number of entries recorded for SST number `sst_num` in <name>/sizes.bin,
// which stores one int per SST at byte offset sst_num * sizeof(int).
// Returns 0 when the file cannot be opened or the entry cannot be read in full, so callers
// treat the SST as empty instead of acting on an uninitialized value.
int DB::get_SST_size(int sst_num) {
    const int fd = open((name + "/sizes.bin").c_str(), O_RDONLY);
    if (fd < 0) return 0;

    int sst_size = 0;
    const ssize_t bytes_read = pread(fd, &sst_size, sizeof(int), sizeof(int) * sst_num);
    close(fd);

    if (bytes_read != static_cast<ssize_t>(sizeof(int))) return 0; // error or short read
    return sst_size;
}
// Returns the byte offset, from the start of the SST file, of slot `index` within page
// `cur_page`. Slots are 2 * sizeof(int) bytes wide.
int DB::get_internal_node_offset(int cur_page, int index) {
    const auto slot_bytes = sizeof(int) * 2;
    return cur_page * PAGE_SIZE + index * slot_bytes;
}
// Scans SST `sst_num` for keys in [key1, key2] by walking its B-tree, and appends every pair
// found (tombstones included) to scanned_kv as a freshly allocated int[2] {key, value}.
// scanned_kv->kv must already have room for the pairs; kv_size is advanced per pair.
// `args` is forwarded to read_SST (true = pages that had to be read may be inserted into the
// buffer pool).
//
// File layout as read by this function: internal-node pages come first, leaf pages start at
// page `leaf_page`. The first int of an internal node is its key count; key i (1-indexed)
// is at slot i, the child page for keys <= key i is stored in the int just before key i,
// and the child page for keys greater than the last key is stored in the int just after it.
// Unused leaf slots hold NO_KEY.
void DB::scan_SST_BTree(int sst_num, int key1, int key2, struct kv_pairs *scanned_kv, bool args) {
    // An empty (or unreadable) SST has nothing to scan. This must be checked before taking
    // log(sst_size) below.
    const int sst_size = get_SST_size(sst_num);
    if (sst_size < 1) return;

    const int num_layers = ceil(log(sst_size) / log(B));
    const int num_leaf_pages = ceil((float) sst_size / (float) B);

    // Count the internal-node pages that precede the leaves: each layer above the leaves
    // needs ceil(pages below / B) pages, until a layer fits in a single page. Only the
    // previous layer's page count is needed, so no per-layer array is kept (a stack array
    // of num_layers elements would have zero length when sst_size == 1).
    int leaf_page = 0;
    int pages_in_layer = num_leaf_pages;
    for (int layer = 1; layer < num_layers; layer++) {
        if (pages_in_layer == 1) break; // already reached the root
        pages_in_layer = ceil(float(pages_in_layer) / float(B));
        leaf_page += pages_in_layer;
    }

    const std::string sst_name = name + "/sst" + std::to_string(sst_num) + ".bin";

    // Descend through the internal nodes until a leaf page is reached.
    int cur_layer = 1;
    int cur_page = 0;
    while (cur_layer != num_layers) {
        if (cur_page >= leaf_page) break;

        const int node_size = read_SST(sst_name, sst_num, cur_page * PAGE_SIZE, args);
        int left = 1; // 1-indexed, as the node size occupies the first int of the page
        int right = node_size;

        // key1 is beyond the node's last key: follow the child stored after that key.
        if (key1 > read_SST(sst_name, sst_num, get_internal_node_offset(cur_page, right), args)) {
            cur_layer++;
            cur_page = read_SST(sst_name, sst_num, get_internal_node_offset(cur_page, right) + sizeof(int), args);
            continue;
        }

        // Otherwise binary search for the first key >= key1. The check above guarantees one
        // exists in a well-formed node; idx defaults to the last key so it is never read
        // uninitialized.
        int idx = node_size;
        while (left <= right) {
            const int middle = (left + right) / 2;
            const int middle_value = read_SST(sst_name, sst_num, get_internal_node_offset(cur_page, middle), args);
            if (middle_value < key1) {
                left = middle + 1;
            } else {
                idx = middle;
                right = middle - 1;
            }
        }
        // The child page for key idx is stored in the int just before the key.
        cur_page = read_SST(sst_name, sst_num, get_internal_node_offset(cur_page, idx) - sizeof(int), args);
        cur_layer++;
    }

    // cur_page is now the leaf page where key1 belongs. Work out how many pairs it holds.
    // NOTE(review): the last two branches look swapped/incorrect (a last page of exactly B
    // pairs, and num_leaf_pages % B for the second-last page); kept as-is pending a check
    // against the code that writes the SST.
    int page_size;
    if (num_leaf_pages <= 1) { // if there's only one leaf page, must contain all the leaves
        page_size = sst_size;
    }
    else if (cur_page < num_leaf_pages + leaf_page - 2 || sst_size % B == 0) {
        // not on last two pages, or the sst size is a multiple of B, so page size must be B
        page_size = B;
    }
    else if (sst_size % B < ceil(float(B)/float(2))) { // the last two pages needed to be balanced
        if (cur_page == num_leaf_pages + leaf_page - 1) { // last page
            page_size = B - ceil(float(B)/float(2));
        } else { // second last page
            page_size = ceil(float(B)/float(2));
        }
    }
    else if (cur_page == num_leaf_pages + leaf_page - 1) {
        // on last page, did not need to be balanced
        page_size = B;
    }
    else {
        // on second last page, did not need to be balanced
        page_size = num_leaf_pages % B;
    }

    // Lower-bound search inside the leaf: `left` ends as the first slot whose key is >= key1
    // (or page_size if there is none on this page).
    int left = 0;
    int right = page_size;
    while (left < right) {
        const int middle = (left + right) / 2;
        const int middle_value = read_SST(sst_name, sst_num, cur_page * PAGE_SIZE + middle * sizeof(int) * 2, args);
        if (middle_value < key1) {
            left = middle + 1; // nothing up to middle is in range
        } else {
            right = middle;    // middle could be the smallest in-range element
        }
    }

    // Walk forward through the leaves, collecting pairs until the range is passed or the
    // leaf pages end.
    const int leaves_end = leaf_page * PAGE_SIZE + num_leaf_pages * PAGE_SIZE;
    int offset = cur_page * PAGE_SIZE + left * sizeof(int) * 2;
    while (offset < leaves_end) {
        const int cur_key = read_SST(sst_name, sst_num, offset, args);
        if (cur_key > key2) break; // stop reading from this SST once we have surpassed range

        if (cur_key == NO_KEY) {
            // Rest of this page is unused: jump to the start of the next page. The page is
            // derived from `offset` itself because the walk may already have crossed page
            // boundaries since cur_page was set; using cur_page + 1 would rewind the scan
            // and emit pairs twice.
            offset = (offset / PAGE_SIZE + 1) * PAGE_SIZE;
            continue;
        }

        int* pair = new int[2];
        pair[0] = cur_key;
        pair[1] = read_SST(sst_name, sst_num, offset + sizeof(int), args);
        scanned_kv->kv[scanned_kv->kv_size] = pair;
        scanned_kv->kv_size += 1;
        offset += sizeof(int) * 2;
    }
}
// Scans the SST with sst_num via Binary search for keys in range key1 to key2, and adds valid pairs into scanned_kv.
void DB::scan_SST_Binary(int sst_num, int key1, int key2, struct kv_pairs *scanned_kv, bool args){
    int sst_size = get_SST_size(sst_num);
    if (sst_size < 1) return; // confirm that SST is non-empty.
    // we are going to binary search to find the smallest element that is in-range, and then 
    // iterate through each element to add to our list of kv-pairs until we reach one out of range.
    // set up binary search
    int left = 0; // left boundry of binary search
    int right = sst_size - 1; // right boundry of binary search
    int middle; // current index to query in binary search
    int middle_value; // buffer to hold value of middle element. This will be the key.
    int offset; // offset in SST to read at.
    std::string sst_name = (name + "/sst"  + std::to_string(sst_num) + ".bin").c_str();
    int num_leaf_pages = ceil((float) sst_size / (float) B);
    // Calculate the page number of the leaf page
    int num_layers = ceil(log(sst_size) / log(B));  // O(log_B(max_size)) internal node layers
    int num_pages [num_layers];               // Store the number of pages needed in each layer, starting with leaves and going up
    num_pages[0] = num_leaf_pages;
    int leaf_page = 0;
    // Count the number of pages that we need on each layer
    for (int i = 1; i < num_layers; i++) {
        // Check if we have already found a layer with the root node
        if (num_pages[i-1] == 1) break;
        num_pages[i] = ceil(float(num_pages[i-1]) / float(B));
        leaf_page += num_pages[i];
    }
    // Once we exit the loop, we will have left = right = smallest element in-range.
    while(left < right){
        middle = (left + right) / 2;
        // Get the offset of the middle index
        offset = get_offset(sst_num, middle, num_leaf_pages, sst_size, leaf_page);
        middle_value = read_SST(sst_name, sst_num, offset, args);
        if (middle_value < key1){left = middle + 1;} // if value is less than key, nothing before it is in-range.
        else {right = middle;} // otherwise, the value could possibly be the smallest in-range element. 
    }
    int kvpair_buffer[2]; // buffer to hold current KV pair being read in
    // Get the page that left is on
    int cur_page = floor(float(get_offset(sst_num, left, num_leaf_pages, sst_size, leaf_page)) / PAGE_SIZE);
    
    for (int i = left; i < sst_size; i++) {
        offset = get_offset(sst_num, i, num_leaf_pages, sst_size, leaf_page);
        kvpair_buffer[0] = read_SST(sst_name, sst_num, offset, args);
        if(kvpair_buffer[0] > key2) break; // stop reading from this SST once we have surpassed range
        // move to next page
        if (kvpair_buffer[0] == INT_MIN) {
            cur_page += 1;
            offset = cur_page * PAGE_SIZE;
            continue;
        }
        kvpair_buffer[1] = read_SST(sst_name, sst_num, offset + sizeof(int), args);
        int * pair;
        pair = new int [2];
        pair[0] = kvpair_buffer[0];
        pair[1] = kvpair_buffer[1];
        scanned_kv->kv[scanned_kv->kv_size] = pair;
        scanned_kv->kv_size += 1;
    }
}
// Merges two key-sorted pair lists into merged_kv, keeping key order.
// kv_1 is treated as the newer list: when both lists contain the same key, the pair from
// kv_1 is kept and the pair from kv_2 is dropped. merged_kv->kv must already have room for
// kv_1->kv_size + kv_2->kv_size pointers; merged_kv->kv_size is set to the number written.
// Pair pointers are moved, not copied.
// NOTE(review): the dropped kv_2 pair is not released here; confirm who owns it.
void DB::merge_kv_pairs(struct kv_pairs *kv_1, struct kv_pairs *kv_2, struct kv_pairs *merged_kv){
    int newer = 0; // next unread position in kv_1
    int older = 0; // next unread position in kv_2
    int out = 0;   // next free position in merged_kv

    for (;;) {
        const bool newer_left = newer < kv_1->kv_size;
        const bool older_left = older < kv_2->kv_size;
        if (!newer_left && !older_left) break;

        // One side exhausted: drain the other.
        if (!older_left) {
            merged_kv->kv[out++] = kv_1->kv[newer++];
            continue;
        }
        if (!newer_left) {
            merged_kv->kv[out++] = kv_2->kv[older++];
            continue;
        }

        const int newer_key = kv_1->kv[newer][0];
        const int older_key = kv_2->kv[older][0];
        if (older_key < newer_key) {
            merged_kv->kv[out++] = kv_2->kv[older++];
        } else {
            // Equal keys: the newer pair wins and the older one is skipped.
            if (newer_key == older_key) older++;
            merged_kv->kv[out++] = kv_1->kv[newer++];
        }
    }
    merged_kv->kv_size = out;
}
// Strips deleted entries from a finished scan result. Pairs whose value is TOMBSTONE are
// left out of a newly allocated pointer array, which then replaces scanned_kv->kv (the old
// array is freed) and kv_size is updated to the number of surviving pairs.
// NOTE(review): the tombstone pairs themselves are not released here; confirm ownership.
void DB::remove_scanned_tombstones(struct kv_pairs *scanned_kv){
    int** survivors = (int**) malloc(sizeof(int*) * (scanned_kv->kv_size));
    int kept = 0;
    for (int row = 0; row < scanned_kv->kv_size; ++row) {
        int* pair = scanned_kv->kv[row];
        if (pair[1] == TOMBSTONE) continue; // element 1 of a pair is the value
        survivors[kept] = pair;
        ++kept;
    }
    free(scanned_kv->kv); // the old pointer array is no longer needed
    scanned_kv->kv_size = kept;
    scanned_kv->kv = survivors;
}
// Looks up `key` in SST `sst_num` by walking its B-tree.
// Returns the stored value (which may be TOMBSTONE), or NO_KEY when the SST is empty or does
// not contain the key. `args` is forwarded to read_SST (true = pages that had to be read may
// be inserted into the buffer pool).
//
// File layout as read by this function: internal-node pages come first, leaf pages start at
// page `leaf_page`. The first int of an internal node is its key count; key i (1-indexed)
// is at slot i, the child page for keys <= key i is stored in the int just before key i,
// and the child page for keys greater than the last key is stored in the int just after it.
int DB::get_from_SST_BTree(int sst_num, int key, bool args) {
    // An empty (or unreadable) SST cannot contain the key. This must be checked before
    // taking log(sst_size) below.
    const int sst_size = get_SST_size(sst_num);
    if (sst_size < 1) return NO_KEY;

    const int num_layers = ceil(log(sst_size) / log(B));
    const int num_leaf_pages = ceil((float) sst_size / (float) B);

    // Count the internal-node pages that precede the leaves: each layer above the leaves
    // needs ceil(pages below / B) pages, until a layer fits in a single page. Only the
    // previous layer's page count is needed, so no per-layer array is kept (a stack array
    // of num_layers elements would have zero length when sst_size == 1).
    int leaf_page = 0;
    int pages_in_layer = num_leaf_pages;
    for (int layer = 1; layer < num_layers; layer++) {
        if (pages_in_layer == 1) break; // already reached the root
        pages_in_layer = ceil(float(pages_in_layer) / float(B));
        leaf_page += pages_in_layer;
    }

    const std::string sst_name = name + "/sst" + std::to_string(sst_num) + ".bin";

    // Descend through the internal nodes until a leaf page is reached.
    int cur_layer = 1;
    int cur_page = 0;
    while (cur_layer != num_layers) {
        if (cur_page >= leaf_page) break;

        const int node_size = read_SST(sst_name, sst_num, cur_page * PAGE_SIZE, args);
        int left = 1; // 1-indexed, as the node size occupies the first int of the page
        int right = node_size;

        // key is beyond the node's last key: follow the child stored after that key.
        if (key > read_SST(sst_name, sst_num, get_internal_node_offset(cur_page, right), args)) {
            cur_layer++;
            cur_page = read_SST(sst_name, sst_num, get_internal_node_offset(cur_page, right) + sizeof(int), args);
            continue;
        }

        // Otherwise binary search for the first node key >= key. The check above guarantees
        // one exists in a well-formed node; idx defaults to the last key so it is never read
        // uninitialized.
        int idx = node_size;
        while (left <= right) {
            const int middle = (left + right) / 2;
            const int middle_value = read_SST(sst_name, sst_num, get_internal_node_offset(cur_page, middle), args);
            if (middle_value < key) {
                left = middle + 1;
            } else {
                idx = middle;
                right = middle - 1;
            }
        }
        // The child page for key idx is stored in the int just before the key.
        cur_page = read_SST(sst_name, sst_num, get_internal_node_offset(cur_page, idx) - sizeof(int), args);
        cur_layer++;
    }

    // cur_page is now the leaf page where the key belongs. Work out how many pairs it holds.
    // NOTE(review): the last two branches look swapped/incorrect (a last page of exactly B
    // pairs, and num_leaf_pages % B for the second-last page); kept as-is pending a check
    // against the code that writes the SST.
    int page_size;
    if (num_leaf_pages <= 1) { // if there's only one leaf page, must contain all the leaves
        page_size = sst_size;
    }
    else if (cur_page < num_leaf_pages + leaf_page - 2 || sst_size % B == 0) {
        // not on last two pages, or the sst size is a multiple of B, so page size must be B
        page_size = B;
    }
    else if (sst_size % B < ceil(float(B)/float(2))) { // the last two pages needed to be balanced
        if (cur_page == num_leaf_pages + leaf_page - 1) { // last page
            page_size = B - ceil(float(B)/float(2));
        } else { // second last page
            page_size = ceil(float(B)/float(2));
        }
    }
    else if (cur_page == num_leaf_pages + leaf_page - 1) {
        // on last page, did not need to be balanced
        page_size = B;
    }
    else {
        // on second last page, did not need to be balanced
        page_size = num_leaf_pages % B;
    }

    // Binary search the leaf's slots [0, page_size - 1]. The upper bound is the last valid
    // slot; probing slot page_size would read past the page's entries (into the next page,
    // or past the end of the file for a full last page).
    int left = 0;
    int right = page_size - 1;
    while (left <= right) {
        const int middle = (left + right) / 2;
        const int entry_offset = cur_page * PAGE_SIZE + middle * sizeof(int) * 2;
        const int middle_value = read_SST(sst_name, sst_num, entry_offset, args);
        if (middle_value < key) {
            left = middle + 1;
        } else if (middle_value > key) {
            right = middle - 1;
        } else {
            // Found: the value is stored in the int right after the key.
            return read_SST(sst_name, sst_num, entry_offset + sizeof(int), args);
        }
    }
    return NO_KEY;
}
// Looks up `key` in SST `sst_num` with a binary search over all of its leaf entries.
// Returns the stored value (which may be TOMBSTONE), or NO_KEY when the SST is empty or does
// not contain the key. `args` is forwarded to read_SST (true = pages that had to be read may
// be inserted into the buffer pool).
int DB::get_from_SST_Binary(int sst_num, int key, bool args){
    // An empty (or unreadable) SST cannot contain the key. This must be checked before
    // taking log(sst_size) below.
    const int sst_size = get_SST_size(sst_num);
    if (sst_size < 1) return NO_KEY;

    const int num_leaf_pages = ceil((float) sst_size / (float) B);
    const int num_layers = ceil(log(sst_size) / log(B));
    const std::string sst_name = name + "/sst" + std::to_string(sst_num) + ".bin";

    // Count the internal-node pages that precede the leaves: each layer above the leaves
    // needs ceil(pages below / B) pages, until a layer fits in a single page. Only the
    // previous layer's page count is needed, so no per-layer array is kept (a stack array
    // of num_layers elements would have zero length when sst_size == 1).
    int leaf_page = 0;
    int pages_in_layer = num_leaf_pages;
    for (int layer = 1; layer < num_layers; layer++) {
        if (pages_in_layer == 1) break; // already reached the root
        pages_in_layer = ceil(float(pages_in_layer) / float(B));
        leaf_page += pages_in_layer;
    }

    // Classic binary search over entry indices [0, sst_size - 1]; get_offset() maps an index
    // to the byte offset of its key in the file.
    int left = 0;
    int right = sst_size - 1;
    while (left <= right) {
        const int middle = left + (right - left) / 2;
        const int offset = get_offset(sst_num, middle, num_leaf_pages, sst_size, leaf_page);
        const int middle_value = read_SST(sst_name, sst_num, offset, args);
        if (middle_value < key) {
            left = middle + 1;
        } else if (middle_value > key) {
            right = middle - 1;
        } else {
            // Found: the value is stored in the int right after the key.
            return read_SST(sst_name, sst_num, offset + sizeof(int), args);
        }
    }
    return NO_KEY;
}
// Reads the value at offset in SST number sst_num. First checks the buffer, and only reads from file if it is not in the buffer.
// Note that the buffer page is defined by memtable_maxsize * 2^(sst_num) + (offset / PAGE_SIZE).
int DB::read_SST(std::string sst_name, int sst_num, int offset, bool args){
    int page_num = (((1 << sst_num) - 1) * mem_table->max_size) + (offset / PAGE_SIZE); // page we are reading from
    
    if(DEBUG) assert(page_num < (((1 << (sst_num + 1)) - 1) *  mem_table->max_size) ); // assert all reads from SST N do not bleed into SST N + 1 (page numbering is valid)
    int page_offset = offset % PAGE_SIZE; // offset within page
    // Case where there is no buffer, so we read directly from file. DOES NOT USE DIRECT I/O.
    if(buffer->no_buffer){
        int read_val;
        int fd = open(sst_name.c_str(), O_RDONLY);
        pread(fd, &read_val, sizeof(int), offset);
        close(fd);
        return read_val;
    }
    // Case where we have a buffer. USES DIRECT I/O
    char* page;
    page = buffer->getPage(page_num);
    if (page == nullptr){ // if page is not in buffer
        posix_memalign(reinterpret_cast<void**>(&page), PAGE_SIZE, PAGE_SIZE); // Allocate 4KB of page-alligned memory for new page
        // Read in page
        int fd = open((name + "/sst"  + std::to_string(sst_num) + ".bin").c_str(), O_DIRECT);
        pread(fd, page, PAGE_SIZE, (offset / PAGE_SIZE) * PAGE_SIZE);
        close(fd);
        if (args) buffer->insertPage(page_num, page);
    }
    // We now have the needed page.
    int* value = (int*) (page + page_offset); // is this legal? "I will make it so"
    return *value;
}
// Records `sst_size` as the entry count of SST number `sst_num` in <db_name>/sizes.bin.
// The file holds one int per SST at byte offset sst_num * sizeof(int) and is created if it
// does not exist. Failures are reported on stderr; nothing is returned.
void AVLTree::update_size(int sst_num, int sst_size) {
    const std::string sizes_path = db_name + "/sizes.bin";
    const int fd = open(sizes_path.c_str(), O_RDWR | O_CREAT, 0777);
    if (fd < 0) {
        std::cerr << "Could not open " << sizes_path << " to record the size of SST " << sst_num << ".\n";
        return;
    }
    const ssize_t written = pwrite(fd, &sst_size, sizeof(int), sst_num * sizeof(int));
    close(fd);
    if (written != static_cast<ssize_t>(sizeof(int))) {
        std::cerr << "Could not write the size of SST " << sst_num << " to " << sizes_path << ".\n";
    }
}
// Returns the page number at which the leaf pages start in an SST holding `sst_size`
// entries, i.e. the number of internal-node pages that come before the leaves.
// Returns 0 for an empty SST (sst_size < 1).
int AVLTree::get_leaf_page(int sst_size) {
    // log() is undefined for sizes below 1, and there are no internal nodes anyway.
    if (sst_size < 1) return 0;

    const int num_layers = ceil(log(sst_size) / log(B));

    // Each layer above the leaves needs ceil(pages below / B) pages, until a layer fits in a
    // single page. Only the previous layer's page count is needed, so no per-layer array is
    // kept (a stack array of num_layers elements would have zero length when sst_size == 1).
    int leaf_page = 0;
    int pages_in_layer = ceil((float) sst_size / (float) B); // number of leaf pages
    for (int layer = 1; layer < num_layers; layer++) {
        if (pages_in_layer == 1) break; // already reached the root
        pages_in_layer = ceil(float(pages_in_layer) / float(B));
        leaf_page += pages_in_layer;
    }
    return leaf_page;
}
// Dry run of a two-way sorted merge over the leaf entries of two SSTs: nothing is written,
// the function only counts how many keys occur in both inputs.
//   size1, size2 - number of key/value entries to consume from each SST
//   fd1, fd2     - open file descriptors of the two SST files (only read via pread)
//   in1, in2     - the current (first) key of each SST, already read by the caller
// Each leaf entry is two ints (key, value), so a cursor advances by sizeof(int) * 2 per entry.
// Returns the number of keys found in both SSTs.
int AVLTree::count_num_repeats(int size1, int size2, int fd1, int fd2, int in1, int in2) {
    int num_read1 = 0;
    int num_read2 = 0;
    // Leaves start after the internal-node pages.
    int page1 = get_leaf_page(size1);
    int page2 = get_leaf_page(size2); 
    int offset1 = PAGE_SIZE * page1;
    int offset2 = PAGE_SIZE * page2;
    int num_repeats = 0;
    while (num_read1 < size1 || num_read2 < size2) {
        // Reads the next key into the buffer if we are out of data in the page.
        // INT_MIN means nothing is there, need to move to the next page to find more data.
        if (in1 == INT_MIN && num_read1 < size1) {
            page1 += 1;
            offset1 = page1 * PAGE_SIZE;
            pread(fd1, &in1, sizeof(int), offset1);
        } 
        if (in2 == INT_MIN && num_read2 < size2) {
            page2 += 1;
            offset2 = page2 * PAGE_SIZE;
            pread(fd2, &in2, sizeof(int), offset2);
        }
        // Consume the smaller key. If the keys are equal, the entry of sst2 is skipped as well
        // and counted as a repeat.
        if ((in1 <= in2 && num_read1 < size1) || num_read2 >= size2) { // Read in1 if in1 <= in2 or sst2 has already been fully read
            if (in1 == in2 && num_read2 < size2) { // If equal, we need to skip over input2 and read the next input from SST2
                num_read2 += 1;
                num_repeats++;
                offset2 += sizeof(int) * 2;
                pread(fd2, &in2, sizeof(int), offset2);
            }
            // NOTE(review): this reads the value slot into in1; the result is not used here and is
            // overwritten by the next key read below whenever entries remain - looks redundant.
            pread(fd1, &in1, sizeof(int), offset1 + sizeof(int));
            num_read1 += 1;
            offset1 += sizeof(int) * 2;
            if (num_read1 < size1) { pread(fd1, &in1, sizeof(int), offset1); }
        } 
        else if ((in1 > in2 && num_read2 < size2) || num_read1 >= size1) {   // Read in2 if in2 < in1 or sst1 has already been fully read
            // NOTE(review): same as above - the value read into in2 is not used by this function.
            pread(fd2, &in2, sizeof(int), offset2 + sizeof(int));
            num_read2 += 1;
            offset2 += sizeof(int) * 2;
            if (num_read2 < size2) { pread(fd2, &in2, sizeof(int), offset2); } 
        } 
    }
    return num_repeats;
}
// Sort-merges two SSTs into one new SST (leaf pages plus rebuilt B-tree internal nodes) and
// returns the sst_num of the highest SST produced by the compaction.
//
// Inputs:
//   sst_num1  - number of the younger SST; only used when has_temp == 0
//   sst_num2  - number of the older SST; the merged output targets sst<sst_num2 + 1>.bin
//   has_temp  - 0: the first input is sst<sst_num1>.bin
//               1: the first input is temp1.bin
//               otherwise: the first input is temp2.bin
//   temp_size - entry count of the first input when has_temp != 0 (otherwise read from sizes.bin)
//
// NOTE: always pass the younger SST (the one on the lower level) as the first input; on equal
// keys the entry of the first input wins and the entry of sst_num2 is dropped.
//
// If sst<sst_num2 + 1>.bin already exists, the merge result is written to a temp file
// (temp2.bin when the first input is temp1.bin, temp1.bin otherwise) and compact_sst calls itself
// to merge that temp file with sst<sst_num2 + 1>.bin. Two temp names are needed because the
// merge may have to continue upwards while the previous temp file is still an input.
// Otherwise the result is written to sst<sst_num2 + 1>.bin and a bloom filter is built for it.
//
// Side effects: deletes both input files and the bloom filter files named after sst_num1 and
// sst_num2, updates sizes.bin and empties the buffer pool.
int DB::compact_sst(int sst_num1, int sst_num2, int has_temp, int temp_size) {
    // Variables needed for compaction
    int size1;
    if (has_temp == 0) { size1 = get_SST_size(sst_num1); } else {size1 = temp_size; }
    int size2 = get_SST_size(sst_num2);
    int new_size = size1 + size2;
    // First leaf page of each input (leaves are stored after the internal-node pages)
    int page1 = DB::mem_table->get_leaf_page(size1);
    int page2 = DB::mem_table->get_leaf_page(size2); 
    int offset1 = PAGE_SIZE * page1;
    int offset2 = PAGE_SIZE * page2;
    int num_read1 = 0, num_read2 = 0;  // The number of elements read from each SST
    bool final_out = false;  // Is this the final call that outputs sst file
    int final_level = -1;
    std::vector<bool> Bfilter;
    // Names of the SST files
    std::string sst1_name;
    std::string sst2_name = (DB::name + "/sst" + std::to_string(sst_num2) + ".bin").c_str();
    std::string sst_out_name;
    if (has_temp == 0) {
        sst1_name = (DB::name + "/sst" + std::to_string(sst_num1) + ".bin").c_str();
        DB::mem_table->update_size(sst_num1, 0); // we are merging this one upwards, so it is now 0 sized
    } else if (has_temp == 1) {
        sst1_name = (DB::name + "/temp1.bin").c_str();
    } else {
        sst1_name = (DB::name + "/temp2.bin").c_str();
    }
    // If there is already an SST on the level we want to sort-merge to, we need to create a temp
    // file to hold the sort-merged SST before sort-merging it on the next level
    int temp_file_num = 0;
    bool need_temp = false;
    int fd_new;
    if (std::filesystem::exists((DB::name + "/sst" + std::to_string(sst_num2 + 1) + ".bin").c_str()) != 0) {
        need_temp = true;
        // Pick the temp file that is not currently being read as the first input
        if (has_temp == 1) {
            temp_file_num = 2;
        } else {
            temp_file_num = 1;
        }
        sst_out_name = (DB::name + "/temp" + std::to_string(temp_file_num) + ".bin").c_str();
    } 
    else {
        // Target level is free: this call produces the final SST and its bloom filter
        final_out = true;
        final_level = sst_num2 + 1;
        Bfilter.assign(get_bloom_size(final_level), false);
        sst_out_name = (DB::name + "/sst" + std::to_string(sst_num2 + 1) + ".bin").c_str();
    }
    // NOTE(review): none of the three descriptors below is checked for open() failure.
    fd_new = open(sst_out_name.c_str(), O_RDWR | O_CREAT | O_TRUNC, 0777); 
    int fd1 = open((sst1_name).c_str(), O_RDONLY);
    int fd2 = open((sst2_name).c_str(), O_RDONLY);
    // Buffer related variables
    int kvpair_buffer[2]; // NOTE(review): not referenced anywhere else in this function
    int in1; int in2;     // current key of each input
    pread(fd1, &in1, sizeof(int), offset1);
    pread(fd2, &in2, sizeof(int), offset2);
    int out [B * 2];   // Output buffer size B as there are B leaves per page
    int curr_out = 0; // The current size of the output buffer
    // Keys present in both inputs are written only once, so they must not be counted twice
    int num_repeats = DB::mem_table->count_num_repeats(size1, size2, fd1, fd2, in1, in2);
    new_size -= num_repeats;
    // Variables needed for constructing the BTree in the merged SST
    int out_leaf_page = DB::mem_table->get_leaf_page(new_size); // The location where we start writing the leaf pages in the new SST
    int cur_out_page = 0;
    int num_leaf_pages = ceil((float) (new_size) / (float)B);
    int max_keys [num_leaf_pages + 1];  // the max value of each leaf page, so that we can assign internal nodes later
    int max_offsets [num_leaf_pages + 1]; // the page where each value in max_keys is contained (then do offset * PAGE_SIZE on retrieval)
    int cur_idx = 0; // index for max keys and max offsets
    // NOTE(review): for new_size == 1 this evaluates to 0, which makes num_pages a zero-length
    // array while num_pages[0] is still written below - confirm whether that case can occur.
    int num_layers = ceil(log(new_size) / log(B));  // O(log_B(max_size)) internal node layers
    int num_pages [num_layers];               // Store the number of pages needed in each layer, starting with leaves and going up
    num_pages[0] = num_leaf_pages;
    int no_idx = INT_MIN; // end-of-data marker written after the last entry of a leaf page
    // Count the number of pages that we need on each layer
    for (int i = 1; i < num_layers; i++) {
        // Check if we have already found a layer with the root node
        // NOTE(review): entries from index i upwards stay unset when this break is taken.
        if (num_pages[i-1] == 1) break;
        num_pages[i] = ceil(float(num_pages[i-1]) / float(B));
    }
    int nodes_per_page = B;
    while (num_read1 < size1 || num_read2 < size2) {
        // Reads the next key into the buffer if we are out of data in the page.
        // INT_MIN means nothing is there, need to move to the next page to find more data.
        if (in1 == INT_MIN && num_read1 < size1) {
            page1 += 1;
            offset1 = page1 * PAGE_SIZE;
            pread(fd1, &in1, sizeof(int), offset1);
        } 
        if (in2 == INT_MIN && num_read2 < size2) {
            page2 += 1;
            offset2 = page2 * PAGE_SIZE;
            pread(fd2, &in2, sizeof(int), offset2);
        }
        // Compare the keys and put the lesser one into the output buffer. If they are equal, we ignore sst2 (as it is older)
        // Then, read the next key of the input that was consumed.
        if ((in1 <= in2 && num_read1 < size1) || num_read2 >= size2) { // Read in1 if in1 <= in2 or sst2 has already been fully read
            if (in1 == in2 && num_read2 < size2) { // If equal, we need to skip over input2 and read the next input from SST2
                num_read2 += 1;
                offset2 += sizeof(int) * 2;
                pread(fd2, &in2, sizeof(int), offset2);
            }
            out[curr_out * 2] = in1; // Key
            if (final_out){
                // Register the key in the bloom filter of the final SST
                std::vector<uint32_t> hashes = bloom_hash(final_level, in1);
                for (uint32_t h : hashes){
                    Bfilter[h] = true;
                }
            }
            // in1 is reused to fetch the value that belongs to the key just emitted
            pread(fd1, &in1, sizeof(int), offset1 + sizeof(int));
            out[curr_out * 2 + 1] = in1; // Value
            curr_out += 1;
            num_read1 += 1;
            offset1 += sizeof(int) * 2;
            if (num_read1 < size1) { pread(fd1, &in1, sizeof(int), offset1); }
        } 
        else if ((in1 > in2 && num_read2 < size2) || num_read1 >= size1) {   // Read in2 if in2 < in1 or sst1 has already been fully read
            out[curr_out * 2] = in2; // Key
            if (final_out){
                std::vector<uint32_t> hashes = bloom_hash(final_level, in2);
                for (uint32_t h : hashes){
                    Bfilter[h] = true;
                }
            }
            // in2 is reused to fetch the value that belongs to the key just emitted
            pread(fd2, &in2, sizeof(int), offset2 + sizeof(int));
            out[curr_out * 2 + 1] = in2; // Value
            curr_out += 1;
            num_read2 += 1;
            offset2 += sizeof(int) * 2;
            if (num_read2 < size2) { pread(fd2, &in2, sizeof(int), offset2); } 
        } 
        // Flush the buffer if it is full. The second-to-last leaf page is flushed early, at
        // ceil(B/2) entries, when the last page would otherwise hold fewer than ceil(B/2) entries.
        if ((curr_out >= B) || (cur_out_page == num_leaf_pages - 2 && new_size % B != 0 && new_size % B < ceil(float(B) / float(2)) && curr_out >= ceil(float(B) / float(2)))) {
            // Store max keys and max offsets
            max_keys[cur_idx] = out[curr_out * 2 - 2];
            max_offsets[cur_idx] = out_leaf_page + cur_out_page;
            cur_idx += 1;
            // Flush
            pwrite(fd_new, out, sizeof(int) * curr_out * 2, (out_leaf_page + cur_out_page) * PAGE_SIZE);
            // Need to fill the next slot with INT_MIN
            pwrite(fd_new, &no_idx, sizeof(int), (out_leaf_page + cur_out_page) * PAGE_SIZE + (curr_out * 2) * sizeof(int));
            cur_out_page += 1;
            curr_out = 0;
        }
    } 
    // If there is still data in the buffer, flush it
    if (curr_out != 0) {
        // Store max keys and max offsets
        max_keys[cur_idx] = out[curr_out * 2 - 2];
        max_offsets[cur_idx] = out_leaf_page + cur_out_page;
        cur_idx += 1;
        // Flush
        pwrite(fd_new, out, sizeof(int) * curr_out * 2, (out_leaf_page + cur_out_page) * PAGE_SIZE);
        pwrite(fd_new, &no_idx, sizeof(int), (out_leaf_page + cur_out_page) * PAGE_SIZE + (curr_out * 2) * sizeof(int));
        cur_out_page += 1;
        curr_out = 0;
    }
    // Remake the internal nodes
    // Loop through the remaining layers to write their internal nodes.
    // Internal page layout as written below: the first int holds the key count of the page,
    // followed by alternating child page numbers and keys; the last child has no key after it.
    int cur_psize;
    int cur_page;
    int offset;
    for (int i = 1; i < num_layers; i++) {
        if (num_pages[i] == 0) break;   // no elements in this layer
        // Get the layer offset: the root page is page 0, other layers start after all layers above them
        if (num_pages[i] == 1) {
            offset = 0;
        } else {
            offset = array_sum(num_pages, i+1, num_layers);
        }
        cur_psize = 0;
        cur_page = 0;
        nodes_per_page = B-1;    
        // Write to the SST: one child entry per page of the layer below
        for (int j=0; j < num_pages[i-1]; j++) {
            // Check if we're on the second last page and if we need to split it to avoid imbalance
            if (cur_page == num_pages[i] - 2 && (num_pages[i-1] % (B) != 0) && ((num_pages[i-1] % (B)) < (B)/2)) {
                nodes_per_page = ceil(float(B) / float(2));
            }
            // If we have filled up this page (or this is the last child), close the page
            if (cur_psize >= nodes_per_page || j == num_pages[i-1] - 1 || j == num_pages[i-1]-1) {
                // Write the offset as we need it to access this last node
                pwrite(fd_new, &max_offsets[j], sizeof(int), (offset * PAGE_SIZE) + (PAGE_SIZE * cur_page) + (cur_psize * sizeof(int) * 2) + sizeof(int));
                // Write the number of keys stored in this page into its first slot
                pwrite(fd_new, &cur_psize, sizeof(int), (offset * PAGE_SIZE) + (PAGE_SIZE * cur_page));
                // Also set the max value of this internal node to its rightmost max value.
                // max_keys/max_offsets are reused in place as the input of the next layer up.
                max_keys[cur_page] = max_keys[j];
                max_offsets[cur_page] = cur_page + offset;
                cur_page ++;
                cur_psize = 0;
            }
            else {
                // Write to (Start of internal node pages + page offset + element offset)
                // Write alternating offsets and keys
                pwrite(fd_new, &max_offsets[j], sizeof(int), (offset * PAGE_SIZE) + (PAGE_SIZE * cur_page) + (cur_psize * sizeof(int) * 2) + sizeof(int));
                pwrite(fd_new, &max_keys[j], sizeof(int), (offset * PAGE_SIZE) + (PAGE_SIZE * cur_page) + (cur_psize * sizeof(int) * 2) + sizeof(int) * 2);
                cur_psize++;
            }
        }
    }
    // If we have a temp file, we need to merge upwards again with tempfile and (sst2 + 1).bin
    int rv;
    if (need_temp) {
        rv = compact_sst(-1, sst_num2 + 1, temp_file_num, new_size); // merge temp sst and sst_num2+1 sst
    } else {
        rv = sst_num2 + 1;
    }
    if (final_out) write_bloom_filters(Bfilter, final_level);
    close(fd1);
    close(fd2);
    close(fd_new);
    std::string fp_bf1 = DB::name + "/bloomFilter" + std::to_string(sst_num1) + ".bin";
    std::string fp_bf2 = DB::name + "/bloomFilter" + std::to_string(sst_num2) + ".bin";
    if (std::filesystem::exists(fp_bf1)) std::remove(fp_bf1.c_str()); // Delete bloom filter file for sst1
    if (std::filesystem::exists(fp_bf2)) std::remove(fp_bf2.c_str()); // Delete bloom filter file for sst2
    std::remove(sst1_name.c_str()); // Delete sst1 file
    std::remove(sst2_name.c_str()); // Delete sst2 file
    DB::mem_table->update_size(sst_num2, 0); // we merged this one upwards as well, so nothing is left here
    // NOTE(review): when need_temp is true the recursive call above has already merged level
    // sst_num2 + 1 upwards and set its size to 0; this line then stores new_size for that level
    // again - confirm that this is intended.
    DB::mem_table->update_size(sst_num2 + 1, new_size); // the sort-merged file on the new level that we created
    buffer->emptyBuffer(); // clear the buffer, to ensure that there are no stale pages within it
    return rv;
}
// Serializes a bloom filter bit vector to "<name>/bloomFilter<lvl>.bin".
// Bits are packed eight per byte, least-significant bit first; a trailing partial byte is
// padded with zero bits. Prints an error and returns if the file cannot be opened.
//   data - the filter bits
//   lvl  - level / SST number the filter belongs to (only used for the file name)
void DB::write_bloom_filters(const std::vector<bool>& data, int lvl) {
    const std::string path = name + "/bloomFilter" + std::to_string(lvl) + ".bin";
    std::ofstream sink(path, std::ios::binary);
    if (!sink.is_open()) {
        std::cerr << "Error opening file for writing." << std::endl;
        return;
    }

    const std::size_t total_bits = data.size();
    std::size_t bit = 0;
    // Emit one byte per group of up to eight consecutive bits.
    while (bit < total_bits) {
        char packed = 0;
        for (std::size_t shift = 0; shift < 8 && bit < total_bits; ++shift, ++bit) {
            if (data[bit]) {
                packed |= (1 << shift);
            }
        }
        sink.put(packed);
    }
    sink.close();
}
// Maps (level num1, bloom-filter page num2) to the id under which that page is kept in the
// buffer pool. The id is the negated value of 2^num1 * ceil(max_size * 0.001) + num2.
// NOTE(review): presumably negative so that it cannot collide with SST page ids - confirm.
int DB::bloom_filter_id(int num1, int num2){
    const double pages_per_unit = ceil((DB::mem_table->max_size) * 0.001);
    const double level_base = pow(2, num1) * pages_per_unit;
    return - (level_base + num2);  // magic, empirically impossible to overflow
}
// Tests `key` against the bloom filter of SST sst_num.
// For every hash of the key the page of "<name>/bloomFilter<sst_num>.bin" that contains the
// bit is taken from the buffer pool, or read from disk with O_DIRECT and inserted into the pool.
//   sst_num - SST / level whose filter is consulted
//   key     - key to test
// Returns false only when a bit belonging to the key is clear (the key is definitely absent).
// Returns true when all bits are set, and also when a filter page cannot be allocated or read:
// an unreadable filter must not rule a key out.
bool DB::check_bloom_filters(int sst_num, int key) {
    std::vector<uint32_t> hashes = bloom_hash(sst_num, key);
    for (uint32_t h : hashes) {
        int page = h / (PAGE_SIZE*8);                  // filter page holding bit h
        int page_id = bloom_filter_id(sst_num, page);  // buffer-pool id of that page
        int page_offset = h % (PAGE_SIZE*8);           // bit index inside the page
        std::size_t byteIndex = page_offset / 8;
        std::size_t bitIndex = page_offset % 8;

        char* data = buffer->getPage(page_id);
        if (data == nullptr){
            void* raw = nullptr;
            // Allocate 4KB of page-aligned memory for the new page (alignment is required by O_DIRECT)
            if (posix_memalign(&raw, PAGE_SIZE, PAGE_SIZE) != 0) return true;
            data = static_cast<char*>(raw);

            // Read in page
            ssize_t got = -1;
            int fd = open((DB::name + "/bloomFilter" + std::to_string(sst_num) + ".bin").c_str(), O_DIRECT);
            if (fd >= 0) {
                got = pread(fd, data, PAGE_SIZE, (h / (PAGE_SIZE*8)) * PAGE_SIZE);
                close(fd);
            }
            // The last filter page may be short; the byte we need must have been read.
            if (got <= static_cast<ssize_t>(byteIndex)) {
                free(data);
                return true;
            }
            // NOTE(review): assumes insertPage takes ownership of the allocation - confirm in BufferPool.
            buffer->insertPage(page_id, data);
        }

        // Extract the bit using bitwise operations
        char byte = data[byteIndex];
        if ((byte & (1 << bitIndex)) == 0) return false;
    }
    return true;
}
// Flushes the memtable to disk by converting it to an SST stored in B-tree layout
// (internal-node pages first, leaf pages after them).
//
// If sst0.bin does not exist the memtable becomes sst0.bin, together with its bloom filter and
// its entry in sizes.bin. Otherwise it is written to temp1.bin and merged upwards with
// compact_sst. Afterwards the memtable is emptied.
//
// Return value is the sst_num of the SST created by flushing to disk, after compaction (if
// needed) has been performed. Returns 0 without touching the disk when the memtable is empty.
int DB::transform_to_sst() {
    if(mem_table->cur_size == 0) return 0; // If memtable is empty, there is nothing to flush.
    // Create struct perfectly size-aligned to memtable
    kv_pairs *scanned_kv = new kv_pairs;
    scanned_kv->kv = (int**) malloc(sizeof(int*) * DB::mem_table->cur_size);
    scanned_kv->kv_size = 0;
    
    DB::mem_table->scan_node(DB::mem_table->root, 0, 0, scanned_kv, true); // Scan the full memtable
    // The max size of the directory when we transform the SST
    max_size = scanned_kv->kv_size;
    // Write the memtable to a file
    mkdir(DB::name.c_str(), 0777);
    
    int fd;
    std::string sst0 = (DB::name + "/sst0.bin").c_str();
    std::string sst_name;
    bool to_compact = false;
    // If sst0.bin does not exist, put this into sst0
    if (std::filesystem::exists(sst0.c_str()) == 0) {
        std::vector<bool> filter = create_bloom_filter(0, scanned_kv);
        write_bloom_filters(filter, 0);
        sst_name = (DB::name + "/sst0.bin").c_str();
        // Store the number of leaves that are being transformed to SST
        fd = open((DB::name + "/sizes.bin").c_str(), O_RDWR | O_CREAT, 0777);
        pwrite(fd, &(scanned_kv->kv_size), sizeof(int), 0); // because this is SST 0, we start at offset 0
        close(fd);
    } else {
        // If sst0.bin already exists, we need to create a temp file and then merge them together to new file sst1.bin
        sst_name = (DB::name + "/temp1.bin");
        to_compact = true;
    }
    fd = open(sst_name.c_str(), O_RDWR | O_CREAT | O_TRUNC, 0777); // open the output file (sst0.bin or temp1.bin)
    // Converts to a B-Tree format before storing it into the SST
    
    // Counting the number of elements per layer
    int num_layers = ceil(log(max_size) / log(B));  // O(log_B(max_size)) internal node layers
    // A single-entry memtable gives log(1) == 0 layers, but the leaf layer always exists and
    // num_pages[0] is written below, so keep room for at least one layer.
    if (num_layers < 1) num_layers = 1;
    // Number of pages needed in each layer, starting with leaves and going up. Zero-initialised
    // so that layers above the root read as empty.
    std::vector<int> num_pages(num_layers, 0);
    int offset = 0;
    
    int num_leaf_pages = ceil((float)max_size / (float)B);
    num_pages[0] = num_leaf_pages;
    int max_keys [num_leaf_pages + 1];  // the max value of each leaf page, so that we can assign internal nodes later
    int max_offsets [num_leaf_pages + 1]; // the page where each value in max_keys is contained (then do offset * PAGE_SIZE on retrieval)
    // Count the number of pages that we need on each layer; their sum is the first leaf page
    for (int i = 1; i < num_layers; i++) {
        // Check if we have already found a layer with the root node
        if (num_pages[i-1] == 1) break;
        num_pages[i] = ceil(float(num_pages[i-1]) / float(B));
        offset += num_pages[i];
    }
    int cur_psize = 0;                      // current size of the page (number of kv pairs currently in the page)
    int cur_page = 0;                       // the leaf page we are currently writing to
    int nodes_per_page = B;                 // the maximum number of nodes we can write to the page
    
    // Iterate through leaves and write them to their proper page with the proper offset
    for (int i = 0; i < scanned_kv->kv_size; i++) {
        // If the last page needs to be split (to avoid being unbalanced), then we only write B/2 elements on the second last page
        if (cur_page == num_leaf_pages - 2 && (max_size % B != 0) && ((max_size % B) < ceil(float(B)/float(2)))) {
            nodes_per_page = ceil(float(B) / float(2));
        } 
        // Write to (Start of leaf pages + page offset + element offset)
        pwrite(fd, scanned_kv->kv[i], sizeof(int) * 2, (offset * PAGE_SIZE) + (PAGE_SIZE * cur_page) + (cur_psize * sizeof(int) * 2));
        cur_psize++;
        // If we have filled up this page, move to write to the next page and store the max value and its offset to use in internal nodes
        if (cur_psize >= nodes_per_page) {
            max_keys[cur_page] = scanned_kv->kv[i][0];
            max_offsets[cur_page] = offset + cur_page;
            // Write INT_MIN to indicate that we have filled up the page (if we have not entered the next page yet)
            if ((cur_psize * sizeof(int) * 2) % PAGE_SIZE != 0) {
                int no_key = NO_KEY;
                pwrite(fd, &(no_key), sizeof(int), offset * PAGE_SIZE + cur_page * PAGE_SIZE + (cur_psize - 1) * sizeof(int) * 2 + sizeof(int) * 2);
            }
            cur_page++;
            cur_psize = 0;
        }
        else if (i == scanned_kv->kv_size -1) { // last element, also add it to the max_keys and offset
            max_keys[cur_page] = scanned_kv->kv[i][0];
            max_offsets[cur_page] = offset + cur_page;
            if ((cur_psize * sizeof(int) * 2) % PAGE_SIZE != 0) {
                int no_key = NO_KEY;
                pwrite(fd, &(no_key), sizeof(int), offset * PAGE_SIZE + cur_page * PAGE_SIZE + (cur_psize - 1) * sizeof(int) * 2 + sizeof(int) * 2);
            }
        }
        delete[] scanned_kv->kv[i]; // the pair has been written; release it
    }
    // Loop through the remaining layers to write their internal nodes.
    // Internal page layout as written below: the first int holds the key count of the page,
    // followed by alternating child page numbers and keys; the last child has no key after it.
    for (int i = 1; i < num_layers; i++) {
        if (num_pages[i] == 0) break;  // no elements in this layer
        // Get the layer offset: the root page is page 0, other layers start after all layers above them
        if (num_pages[i] == 1) {
            offset = 0;
        } else {
            offset = array_sum(num_pages.data(), i+1, num_layers);
        }
        cur_psize = 0;
        cur_page = 0;
        nodes_per_page = B-1;    
    
        // Write to the SST: one child entry per page of the layer below
        for (int j=0; j < num_pages[i-1]; j++) {
            // Check if we're on the second last page and if we need to split it to avoid imbalance
            if (cur_page == num_pages[i] - 2 && (num_pages[i-1] % (B) != 0) && ((num_pages[i-1] % (B)) < (B)/2)) {
                nodes_per_page = ceil(float(B) / float(2));
            }
            // If we have filled up this page (or this is the last child), close the page
            if (cur_psize >= nodes_per_page || j == num_pages[i-1] - 1) {
                // Write the offset as we need it to access this last node
                pwrite(fd, &max_offsets[j], sizeof(int), (offset * PAGE_SIZE) + (PAGE_SIZE * cur_page) + (cur_psize * sizeof(int) * 2) + sizeof(int));
                // Write the number of keys stored in this page into its first slot
                pwrite(fd, &cur_psize, sizeof(int), (offset * PAGE_SIZE) + (PAGE_SIZE * cur_page));
                // Also set the max value of this internal node to its rightmost max value.
                // max_keys/max_offsets are reused in place as the input of the next layer up.
                max_keys[cur_page] = max_keys[j];
                max_offsets[cur_page] = cur_page + offset;
                cur_page ++;
                cur_psize = 0;
            }
            else {
                // Write to (Start of internal node pages + page offset + element offset)
                // Write alternating offsets and keys
                pwrite(fd, &max_offsets[j], sizeof(int), (offset * PAGE_SIZE) + (PAGE_SIZE * cur_page) + (cur_psize * sizeof(int) * 2) + sizeof(int));
                pwrite(fd, &max_keys[j], sizeof(int), (offset * PAGE_SIZE) + (PAGE_SIZE * cur_page) + (cur_psize * sizeof(int) * 2) + sizeof(int) * 2);
                cur_psize++;
            }
        }
    }
    int rv = 0;
    // If sst0.bin exists, we need to do a compaction of temp1.bin with sst0.bin
    if (to_compact) rv = compact_sst(-1, 0, 1, scanned_kv->kv_size); 
    close(fd);
    DB::mem_table->cur_size = 0;
    free(scanned_kv->kv);
    delete(scanned_kv);
    DB::mem_table->deleteTree(DB::mem_table->root);
    DB::mem_table->root = nullptr; // deleteTree frees the nodes but does not reset the root pointer
    return rv;
}
// Builds a leaf node holding (key, val): height 1 and no children.
TreeNode::TreeNode(int key, int val)
    : key(key),
      value(val),
      height(1),
      left(nullptr),
      right(nullptr)
{
}
// Creates an empty memtable (no root, cur_size 0) that remembers its capacity and the
// name of the database it belongs to.
AVLTree::AVLTree(int max_size, std::string db_name)
    : root(nullptr),
      max_size(max_size),
      cur_size(0),
      db_name(db_name)
{
}
// public methods
int AVLTree::put(int key, int value) {
    root = insert(root, key, value);
    return 0;
}
int AVLTree::get(int key) {
    return get_node(root, key);
}
// Collects every pair whose key lies in [key1, key2], in ascending key order.
// The returned kv_pairs and the pairs inside it are heap-allocated; the pointer array is
// sized for max_size entries. Release the result with freeKVpairs.
struct kv_pairs* AVLTree::scan(int key1, int key2){
    kv_pairs* result = new kv_pairs;
    result->kv_size = 0;
    // Room for the largest number of pairs the memtable can hold
    result->kv = static_cast<int**>(malloc(max_size * sizeof(int*)));
    // Range scan (not a full scan) driven by the recursive helper
    scan_node(root, key1, key2, result, false);
    return result;
}
// private methods
// In-order traversal helper that appends matching pairs to scanned_kv.
//   node       - subtree to traverse (nullptr ends the recursion)
//   key1, key2 - inclusive key range; ignored when fullscan is true
//   scanned_kv - output; each appended pair is a new int[2] {key, value}
//   fullscan   - when true every node is visited and appended
// In range mode a subtree is only entered when it can still contain keys of the range.
void AVLTree::scan_node(TreeNode* node, int key1, int key2, struct kv_pairs *scanned_kv, bool fullscan) {
    if (!node) {
        return;
    }

    // Smaller keys first
    const bool visit_left = fullscan || key1 < node->key;
    if (visit_left) {
        scan_node(node->left, key1, key2, scanned_kv, fullscan);
    }

    // Then the node itself, if it qualifies
    const bool within_range = node->key >= key1 && node->key <= key2;
    if (fullscan || within_range) {
        int* entry = new int[2];
        entry[0] = node->key;
        entry[1] = node->value;
        scanned_kv->kv[scanned_kv->kv_size] = entry;
        scanned_kv->kv_size += 1;
    }

    // Larger keys last
    const bool visit_right = fullscan || key2 > node->key;
    if (visit_right) {
        scan_node(node->right, key1, key2, scanned_kv, fullscan);
    }
}
// Height stored in `node`; an empty subtree (nullptr) has height 0.
int AVLTree::getHeight(TreeNode* node) {
    return node ? node->height : 0;
}
// Balance factor of `node`: height of the left subtree minus height of the right subtree.
// Positive means left-heavy, negative means right-heavy; nullptr yields 0.
int AVLTree::getBalanceFactor(TreeNode* node) {
    if (!node) {
        return 0;
    }
    const int left_height = getHeight(node->left);
    const int right_height = getHeight(node->right);
    return left_height - right_height;
}
// Right rotation around `node`; `node` must have a left child.
//
//        node              pivot
//        /                 /   \
//     pivot       ==>     a    node
//     /   \                    /
//    a     b                  b
//
// Returns the new root of the subtree (the former left child).
TreeNode* AVLTree::rotateRight(TreeNode* node) {
    TreeNode* pivot = node->left;
    TreeNode* carried = pivot->right;  // subtree that changes parent

    node->left = carried;
    pivot->right = node;

    // node now sits below pivot, so its height has to be refreshed first
    node->height = 1 + std::max(getHeight(node->right), getHeight(node->left));
    pivot->height = 1 + std::max(getHeight(pivot->right), getHeight(pivot->left));
    return pivot;
}
// Left rotation around `node`; `node` must have a right child.
//
//     node                   pivot
//        \                   /   \
//       pivot     ==>     node    b
//       /   \                \
//      a     b                a
//
// Returns the new root of the subtree (the former right child).
TreeNode* AVLTree::rotateLeft(TreeNode* node) {
    TreeNode* pivot = node->right;
    TreeNode* carried = pivot->left;  // subtree that changes parent

    node->right = carried;
    pivot->left = node;

    // node now sits below pivot, so its height has to be refreshed first
    node->height = 1 + std::max(getHeight(node->right), getHeight(node->left));
    pivot->height = 1 + std::max(getHeight(pivot->right), getHeight(pivot->left));
    return pivot;
}
// Recursively inserts (key, val) into the subtree rooted at `node` and rebalances on the way
// back up. An existing key only has its value replaced. cur_size grows by one for every new
// node. Returns the (possibly new) root of the subtree.
TreeNode* AVLTree::insert(TreeNode* node, int key, int val) {
    // Empty spot reached: this is where the new node goes
    if (!node) {
        ++cur_size;
        return new TreeNode(key, val);
    }

    if (key == node->key) {
        // Key already present: update in place, the shape of the tree is unchanged
        node->value = val;
        return node;
    }

    if (key < node->key) {
        node->left = insert(node->left, key, val);
    } else {
        node->right = insert(node->right, key, val);
    }

    // A descendant may have been added, so refresh the height and check the balance
    node->height = 1 + std::max(getHeight(node->right), getHeight(node->left));
    const int balance = getBalanceFactor(node);

    if (balance < -1) {
        // Right-heavy. A left-leaning right child needs a right rotation first.
        if (getBalanceFactor(node->right) > 0) {
            node->right = rotateRight(node->right);
        }
        return rotateLeft(node);
    }

    if (balance > 1) {
        // Left-heavy. A right-leaning left child needs a left rotation first.
        if (getBalanceFactor(node->left) < 0) {
            node->left = rotateLeft(node->left);
        }
        return rotateRight(node);
    }

    return node;
}
// Searches the subtree rooted at `node` for key k.
// Returns the value stored under k, or NO_KEY when k is not present.
int AVLTree::get_node(TreeNode* node, int k) {
    TreeNode* cursor = node;
    while (cursor != nullptr) {
        if (k < cursor->key) {
            cursor = cursor->left;
        } else if (k > cursor->key) {
            cursor = cursor->right;
        } else {
            return cursor->value;
        }
    }
    return NO_KEY;
}
// Frees every node of the subtree rooted at `node`, children before their parent.
// The pointer passed in is not reset.
void AVLTree::deleteTree(TreeNode* node){
    if (node != nullptr) {
        deleteTree(node->left);
        deleteTree(node->right);
        delete node;
    }
}
// Releases a kv_pairs result: every stored pair, then the pointer array, then the struct.
void AVLTree::freeKVpairs(kv_pairs* p) {
    int** const pairs = p->kv;
    for (int idx = 0; idx < p->kv_size; ++idx) {
        delete[] pairs[idx];
    }
    free(pairs);
    delete p;
}