// WiscSort / EMS / ext_sort.cc
#include "ext_sort.h"
#include <cerrno>
#include <chrono>
#include <cstdio>
#include <sys/types.h>
#include <sys/stat.h>
#include <assert.h>
#include <algorithm>
#include <sys/mman.h>
#include "timer.h"
#include "config.h"

#include <immintrin.h>

#include "ips4o/include/ips4o.hpp"

// Function-like min used by MergeRuns when sizing refills.
// NOTE(review): being a macro, it evaluates the chosen argument twice and it
// rewrites every later `min(` in this translation unit (e.g. a `std::min(...)`
// call would no longer compile), so avoid side-effecting arguments.
#define min(a, b) (((a) < (b)) ? (a) : (b))

// Timer shared by all phases in this file: InMemorySort, RunGeneration and
// MergeRuns record into it, and Sort prints the accumulated totals.
Timer timer;
// Create the working directory for this sort: <mount_path><unix time>.
// Run files ("run_N") and the final "sorted" output are created under it.
// NOTE(review): mount_path is concatenated as-is, so it presumably ends in '/'.
ExtSort::ExtSort(string mount_path)
{
    folder_path_ = mount_path + std::to_string(std::time(nullptr));
    // Report (but do not abort on) a failed mkdir: every later file lives in this
    // directory, so a silent failure would only surface much later. EEXIST is
    // tolerated because the directory may already exist (same timestamp).
    if (mkdir(folder_path_.c_str(), 0777) != 0 && errno != EEXIST)
        perror(("ExtSort: mkdir " + folder_path_).c_str());
    // DEBUG:
    // folder_path_ = "/mnt/pmem/1647924000";
}

// return number of blocks processed
size_t ExtSort::InMemorySort(size_t blk_off, DataManager &data_manger, string output_file)
{
    vector<in_record_t> keys_idx;
    size_t read_buff_sz = conf.read_buff_blk_count * conf.block_size;
    record_t *read_buffer = (record_t *)malloc(read_buff_sz);
    assert(read_buffer);

    timer.start("load keys");
    timer.start("RUN read");
    size_t read_records = data_manger.RunRead(blk_off, conf.read_buff_blk_count, read_buffer);

    // tmp load exps
    // delete[] read_buffer;
    // return (read_records * conf.record_size) / conf.block_size;

    timer.end("RUN read");
#ifdef checkpoints
    timer.end("checkpoints");
    printf("\t \t RUN read: %f\n", timer.get_overall_time("checkpoints"));
    timer.start("checkpoints");
#endif
    if (!read_records)
        return 0;

    keys_idx.resize(read_records);
    // Move keys and index from read buffer to array in parallel
    std::vector<std::future<void>> kresults;
    size_t thrd_chunk_rec = read_records / conf.read_thrds;
    // clang-format off
    for (int thrd = 0; thrd < conf.read_thrds; thrd++)
    {
        kresults.emplace_back(
            data_manger.pool.enqueue([=, &keys_idx]()
            {
                size_t idx = 0;
                while(idx < thrd_chunk_rec)
                {
                    keys_idx[idx + thrd_chunk_rec * thrd] = 
                                    in_record_t(read_buffer[idx + thrd_chunk_rec * thrd].k,
                                                idx + thrd_chunk_rec * thrd);
                    idx++;
                }
            })
        );
    }
    //clang-foramt on
    for (auto &&result : kresults)
        result.get();

    timer.end("load keys");

    timer.start("SORT");
#ifdef cips4o
    if (conf.sort_thrds > 1) 
    {
#ifdef bandwidth
        uint64_t sort_start_time = rdtsc();
#endif
        // parallel sort
        ips4o::parallel::sort(keys_idx.begin(), keys_idx.end(), std::less<>{});
#ifdef bandwidth
        uint64_t sort_delta = rdtsc() - sort_start_time;
        timer.end("checkpoints");
        printf("%f,SORT,%f,%lu\n", timer.get_overall_time("checkpoints"), ((float)(sort_delta) / NANOSECONDS_IN_SECOND), keys_idx.size() * (KEY_SIZE + sizeof(size_t)));
        timer.start("checkpoints");
#endif
    }
    else
#endif
        std::sort(keys_idx.begin(), keys_idx.end());
    timer.end("SORT");
#ifdef checkpoints
    timer.end("checkpoints");
    printf("\t \t SORT: %f\n", timer.get_overall_time("checkpoints"));
    timer.start("checkpoints");
#endif
    // printf("Finished sorting %lu!\n", read_records);

    data_manger.OpenAndMapOutputFile(output_file, read_records * conf.record_size);
    record_t *write_buffer = (record_t *)malloc(conf.write_buff_blk_count * conf.block_size);
    assert(write_buffer);

    size_t write_records = 0;
    size_t key_off = 0;
    off_t output_off = 0;
    // Write all the read records.
    while (write_records < read_records)
    {
        // Move this to data manager /staged - the single thrd version can then keep the
        // Move records from read buffer to write buffer by fetching by value index.
        data_manger.LoadRecordToOutputBuffer(write_buffer, read_buffer, keys_idx, key_off);
        key_off += conf.write_buff_rec_count;

        timer.start("RUN write");
        write_records += data_manger.RunWrite(write_buffer, output_off);
        timer.end("RUN write");
#ifdef checkpoints
        timer.end("checkpoints");
        printf("\t \t RUN write: %f\n", timer.get_overall_time("checkpoints"));
        timer.start("checkpoints");
#endif
        output_off += (conf.write_buff_blk_count * conf.block_size);
        // FIXME: memset write_buffer 0 - if READ_BUFFER IS NOT A MULTIPLE OF INPUT_BUFFER
    }

    // TODO:
    // unmap buffers!!
    delete[] read_buffer;
    delete[] write_buffer;
    // munmap(input_mapped_buffer_, records_num * conf.record_size);
    // munmap(output_mapped_buffer_, records_num * conf.record_size);
    return (write_records * conf.record_size) / conf.block_size;
}

// Phase 1: split the input into sorted run files.
//   files - input file paths handed to DataManager.
// Calls InMemorySort repeatedly, advancing the input block offset by the number
// of blocks each call reports, until a call processes nothing.
// Returns the run file paths (<folder_path_>/run_N) in creation order.
vector<string> ExtSort::RunGeneration(vector<string> files)
{
    DataManager data_manager(files, &timer);
    uint64_t blk_off = 0;
    const string file_name = folder_path_ + "/run_";
    size_t run_num = 0;
    vector<string> run_names;
    while (true)
    {
        string run_name = file_name + std::to_string(run_num);
        uint64_t blk_processed = InMemorySort(blk_off, data_manager, run_name);
        if (!blk_processed)
            break;
        blk_off += blk_processed;
        run_names.push_back(std::move(run_name));
        run_num++;
    }

    // Release the input mapping and descriptor.
    // NOTE(review): only entry 0 is released; confirm whether DataManager opens
    // or maps the remaining input files as well.
    munmap(data_manager.input_mapped_buffer_, data_manager.file_size_[0]);
    close(data_manager.file_ptr_[0]);

    // Plain return: `return std::move(local)` would suppress NRVO.
    return run_names;
}

// Phase 2: k-way merge of every run file into <folder_path_>/sorted.
//   runs - paths of the sorted run files produced by RunGeneration.
// The read buffer is split into one equal slot per run. Each slot is refilled
// from its run file when drained, and the record with the smallest current key
// across runs is copied into the write buffer, which is flushed when full.
// Assumes all runs are merged in a single pass.
void ExtSort::MergeRuns(vector<string> runs)
{
    // Nothing to merge (and the per-run slot size below would divide by zero).
    if (runs.empty())
        return;

    // Game plan!
    // 1. Open all run files and create a new file of size sum of all runs.
    // 2. Create a read buffer that is evenly divisible by the number of files
    // 3. Create current pointer and end pointers (or atleast know end for each file's buffer)
    // 4. First mmap runs and fill the read buffer from their respective files.
    // 5. find min of keys pointed by current pointers (ideally a min on list of (key,file_num)
    // 6. Move the min to output buffer, if output buffer is full flush it and reset pointer.
    // 7. increment the pointer of the respective run_buffer
    // 8. if current pointer reach the end pointer load the new data to the part of that buffer.
    // 9. if the last block of a file is read and current pointer for that file buffer is at the
    // end then stop doing reads.
    // 10. If only one file buffer hasn't reached it's last block then just copy the rest of that
    // file to the output buffer.

    // [1]
    DataManager data_manager(runs, &timer);
    string output_file_name = folder_path_ + "/sorted";
    size_t output_file_size = 0;
    for (auto i : data_manager.file_size_)
        output_file_size += i;
    data_manager.OpenAndMapOutputFile(output_file_name, output_file_size);
    // Here the write_buffer size should evenly divide the output_file_size
    record_t *write_buffer = (record_t *)malloc(conf.write_buff_blk_count * conf.block_size);
    assert(write_buffer);

    // [2]
    size_t read_buff_sz = conf.read_buff_blk_count * conf.block_size;
    record_t *read_buffer = (record_t *)malloc(read_buff_sz);
    assert(read_buffer);

    // [3]
    vector<size_t> cur_rec_off(runs.size(), 0);   // Rec offset (in read_buffer) of the next key to compare, per run
    vector<size_t> end_rec_off(runs.size(), 0);   // Last valid rec offset (inclusive) currently loaded, per run
    vector<size_t> rfile_blk_off(runs.size(), 0); // Next block offset to read from each run file
    vector<char *> mapped_run_files(runs.size()); // mmap all the run files

    const size_t space_per_run = read_buff_sz / runs.size();
    const size_t slot_recs = space_per_run / conf.record_size;      // records per run slot
    const size_t read_size_blk = space_per_run / conf.block_size;   // blocks per full refill

    // Sentinel stored in min_vec for an exhausted run.
    // NOTE(review): the merge assumes this compares >= every real key, so an
    // exhausted run is never picked as the minimum while others still have data.
    k_t tmp_k;
    tmp_k = 0xff;

    // [4] mmap every run and fill its slot; [5] seed min_vec with its first key.
    vector<k_t> min_vec(runs.size());
    for (uint32_t file_num = 0; file_num < runs.size(); ++file_num)
    {
        cur_rec_off[file_num] = file_num * slot_recs;

        // mmap files
        data_manager.MMapFile(data_manager.file_size_[file_num], 0,
                              data_manager.file_ptr_[file_num], mapped_run_files[file_num]);

        // A run can be shorter than its slot: only read the blocks it really has,
        // otherwise the slot would be filled from beyond the end of the file.
        size_t run_blks = data_manager.file_size_[file_num] / conf.block_size;
        size_t read_size = min(read_size_blk, run_blks);
        if (read_size > 0)
        {
            timer.start("MERGE read");
            data_manager.MergeRead(read_buffer, cur_rec_off[file_num],
                                   mapped_run_files[file_num], rfile_blk_off[file_num],
                                   read_size);
            timer.end("MERGE read");
#ifdef checkpoints
            timer.end("checkpoints");
            printf("\t \t MERGE read: %f\n", timer.get_overall_time("checkpoints"));
            timer.start("checkpoints");
#endif
        }
        rfile_blk_off[file_num] = read_size;

        // The end offset covers only what was actually loaded.
        size_t recs_loaded = (read_size * conf.block_size) / conf.record_size;
        if (recs_loaded == 0)
        {
            end_rec_off[file_num] = cur_rec_off[file_num];
            min_vec[file_num] = tmp_k; // empty run: exhausted from the start
        }
        else
        {
            end_rec_off[file_num] = cur_rec_off[file_num] + recs_loaded - 1;
            min_vec[file_num] = read_buffer[cur_rec_off[file_num]].k;
        }
    }

    const size_t total_recs = output_file_size / conf.record_size;
    size_t write_buff_rec_off = 0;
    uint32_t min_index = 0;
    size_t recs_written = 0;
    // DEBUG: ////////////
    uint64_t check_time_1 = 0;
    uint64_t check_time_2 = 0;
    uint64_t merge_loop_time = 0;
    uint64_t small_branch_time = 0;
    uint64_t check1 = 0;
    uint64_t check2 = 0;
    uint64_t branch_t = 0;
    uint64_t count_check = 0;
    merge_loop_time = rdtsc();
    //////////////////////
    // Loop until all the recs are written
    while (recs_written < total_recs)
    {
        count_check++;
        // [5] Pick the run holding the smallest current key.
        min_index = std::min_element(min_vec.begin(), min_vec.end()) - min_vec.begin();

        // Stop once every run is exhausted. This must be checked BEFORE copying:
        // otherwise the sentinel slot would be copied out as an extra record.
        check2 = rdtsc();
        bool all_runs_done = std::count(min_vec.begin(), min_vec.end(), tmp_k) == (uint32_t)runs.size();
        check_time_2 += (rdtsc() - check2);
        if (all_runs_done)
        {
            // Flush whatever is left in the (partially filled) write buffer.
            if (write_buff_rec_off > 0)
            {
                timer.start("MERGE write");
                recs_written += data_manager.MergeWrite(write_buffer, write_buff_rec_off);
                timer.end("MERGE write");
#ifdef checkpoints
                timer.end("checkpoints");
                printf("\t \t MERGE write: %f\n", timer.get_overall_time("checkpoints"));
                timer.start("checkpoints");
#endif
            }
            break;
        }

        // [6] Move the min record to the write buffer.
        check1 = rdtsc();
        memcpy(&write_buffer[write_buff_rec_off], &read_buffer[cur_rec_off[min_index]], conf.record_size);
        write_buff_rec_off++;
        check_time_1 += (rdtsc() - check1);

        check2 = rdtsc();
        // Write buffer is full so flush it.
        if (write_buff_rec_off >= conf.write_buff_rec_count)
        {
            timer.start("MERGE write");
            recs_written += data_manager.MergeWrite(write_buffer, write_buff_rec_off);
            timer.end("MERGE write");
#ifdef checkpoints
            timer.end("checkpoints");
            printf("\t \t MERGE write: %f\n", timer.get_overall_time("checkpoints"));
            timer.start("checkpoints");
#endif
            write_buff_rec_off = 0;
        }

        // [7] Advance the run we just consumed from.
        cur_rec_off[min_index]++;
        check_time_2 += (rdtsc() - check2);

        branch_t = rdtsc();
        if (cur_rec_off[min_index] > end_rec_off[min_index])
        {
            // [9] Slot drained: refill it, or mark the run exhausted at end of file.
            size_t run_file_sz = data_manager.file_size_[min_index] / conf.block_size;
            if (rfile_blk_off[min_index] >= run_file_sz)
            {
                // [10] Possible optimization (needs MoveRemainingInputToOuputBuff):
                // once only one run is left, copy its remainder straight to output.
                min_vec[min_index] = tmp_k;
            }
            else
            {
                // [8] Reset the slot and read the next set of blocks from the file.
                cur_rec_off[min_index] = min_index * slot_recs;
                size_t read_size = min(read_size_blk, run_file_sz - rfile_blk_off[min_index]);
                timer.start("MERGE read");
                data_manager.MergeRead(read_buffer, cur_rec_off[min_index],
                                       mapped_run_files[min_index], rfile_blk_off[min_index],
                                       read_size);
                timer.end("MERGE read");
#ifdef checkpoints
                timer.end("checkpoints");
                printf("\t \t MERGE read: %f\n", timer.get_overall_time("checkpoints"));
                timer.start("checkpoints");
#endif
                // Measure the loaded range from the slot start; a final short read
                // holds fewer records than a full slot.
                end_rec_off[min_index] = cur_rec_off[min_index] + (read_size * conf.block_size) / conf.record_size - 1;
                rfile_blk_off[min_index] += read_size;
                // The key must come from the freshly loaded data, not the stale slot.
                min_vec[min_index] = read_buffer[cur_rec_off[min_index]].k;
            }
        }
        else
        {
            min_vec[min_index] = read_buffer[cur_rec_off[min_index]].k;
        }
        small_branch_time += (rdtsc() - branch_t);
    }

    printf("check 1: %f, check: %f, small_branch: %f\n merge_loop: %f\n",
            ((float)(check_time_1) / NANOSECONDS_IN_SECOND),
            ((float)(check_time_2) / NANOSECONDS_IN_SECOND),
            ((float)(small_branch_time)/NANOSECONDS_IN_SECOND),
            ((float)(rdtsc() - merge_loop_time)/NANOSECONDS_IN_SECOND));
    printf("count: %lu\n", count_check);

    // Release the run mappings and the merge buffers.
    for (size_t file_num = 0; file_num < runs.size(); ++file_num)
    {
        if (mapped_run_files[file_num] && mapped_run_files[file_num] != MAP_FAILED)
            munmap(mapped_run_files[file_num], data_manager.file_size_[file_num]);
    }
    free(read_buffer);
    free(write_buffer);
}

// Top-level entry point: sort `files` into <folder_path_>/sorted.
// Runs RunGeneration, then a single MergeRuns pass, then prints a timing summary
// built from the shared timer.
void ExtSort::Sort(vector<string> &files)
{
    // `checkpoints` / `bandwidth` are used as on/off flags (#ifdef) elsewhere, so
    // test them with defined(); `#elif bandwidth` fails when it is defined empty.
#if defined(checkpoints) || defined(bandwidth)
    timer.start("checkpoints");
#endif

    // DEBUG:
    // sleep(5);

    timer.start("RUN");
    vector<string> runs = RunGeneration(files);
    timer.end("RUN");

    // Assuming only one merge phase!
    // DEBUG:
    // vector<string> runs{"/mnt/pmem/1647924000/run_0", "/mnt/pmem/1647924000/run_1"};
    timer.start("MERGE");
    MergeRuns(std::move(runs)); // runs is not used afterwards; avoid the copy
    timer.end("MERGE");

#if defined(checkpoints) || defined(bandwidth)
    timer.end("checkpoints");
#endif

    printf("====================================\n");
    printf("\t RUN read: %f\n", timer.get_overall_time("RUN read"));
    printf("\t RUN sort: %f\n", timer.get_overall_time("SORT"));
    printf("\t RUN write: %f\n", timer.get_overall_time("RUN write"));
    printf("Total RUN: %f\n", timer.get_time("RUN"));
    printf("\t MERGE read: %f\n", timer.get_overall_time("MERGE read"));
    printf("\t MERGE write: %f\n", timer.get_overall_time("MERGE write"));
    printf("Total MERGE: %f\n", timer.get_time("MERGE"));
    printf("Total: %f\n", timer.get_time("RUN") + timer.get_time("MERGE"));
    // printf("Find min: %f\n", timer.get_overall_time("Find min"));
    // printf("Count: %f\n", timer.get_overall_time("count"));
    // printf("Copy: %f\n", timer.get_overall_time("copy"));
    // printf("replace vec: %f\n", timer.get_overall_time("replace vec"));
    // printf("load keys: %f\n", timer.get_overall_time("load keys"));   
}