// FPGA-RISC-V-CPU / hardware / src / riscv_core / cpu.v
`include "opcode.vh"

// Top-level CPU module. Instantiates the BIOS, data and instruction memories,
// the register file and the on-chip UART, and connects them through two sets
// of pipeline registers: f_ex_* (fetch -> execute) and ex_wb_* (execute ->
// writeback).
module cpu #(
    // Passed to the UART instance as CLOCK_FREQ.
    parameter CPU_CLOCK_FREQ = 50_000_000,
    // Value driven onto pc_next while rst is asserted.
    parameter RESET_PC = 32'h4000_0000,
    // Passed to the UART instance as BAUD_RATE.
    parameter BAUD_RATE = 115200
) (
    // Clock for every sequential block in this module.
    input clk,
    // Reset; sampled on clk edges by the logic in this module and also
    // forwarded to the UART and the branch predictor.
    input rst,
    // 1: redirect fetch on the branch predictor's guess; 0: redirect only
    // once a branch has been resolved in the writeback stage.
    input bp_enable,
    // Serial lines of the on-chip UART.
    input serial_in,
    output serial_out
);
    // BIOS Memory
    // Two ports are connected, both read-only from this module's side:
    //   port A supplies fetched instructions, port B supplies load data.
    // Synchronous read: read takes one cycle
    wire        bios_ena;
    wire        bios_enb;
    wire [11:0] bios_addra;
    wire [11:0] bios_addrb;
    wire [31:0] bios_douta;
    wire [31:0] bios_doutb;
    bios_mem bios_mem (
      .clk  (clk),
      // Port A
      .ena  (bios_ena),
      .addra(bios_addra),
      .douta(bios_douta),
      // Port B
      .enb  (bios_enb),
      .addrb(bios_addrb),
      .doutb(bios_doutb)
    );

    // Data Memory
    // Single port shared by loads and stores.
    // Synchronous read: read takes one cycle
    // Synchronous write: write takes one cycle
    // Write-byte-enable: one bit of dmem_we per byte of the 32-bit word
    wire        dmem_en;
    wire [3:0]  dmem_we;
    wire [13:0] dmem_addr;
    wire [31:0] dmem_din;
    wire [31:0] dmem_dout;
    dmem dmem (
      .clk (clk),
      .en  (dmem_en),
      .we  (dmem_we),
      .addr(dmem_addr),
      .din (dmem_din),
      .dout(dmem_dout)
    );

    // Instruction Memory
    // Port A is the write side (driven from the execute stage), port B is
    // the read side (driven by instruction fetch).
    // Synchronous read: read takes one cycle
    // Synchronous write: write takes one cycle
    // Write-byte-enable: one bit of imem_wea per byte of the 32-bit word
    wire        imem_ena;
    wire [3:0]  imem_wea;
    wire [13:0] imem_addra;
    wire [31:0] imem_dina;
    wire [13:0] imem_addrb;
    wire [31:0] imem_doutb;
    imem imem (
      .clk  (clk),
      // Port A (write)
      .ena  (imem_ena),
      .wea  (imem_wea),
      .addra(imem_addra),
      .dina (imem_dina),
      // Port B (read)
      .addrb(imem_addrb),
      .doutb(imem_doutb)
    );

    // Register file
    // Asynchronous read: read data is available in the same cycle
    // Synchronous write: write takes one cycle
    // Read addresses come from the fetched instruction; the write port is
    // driven from the writeback stage.
    wire [4:0]  ra1;
    wire [4:0]  ra2;
    wire [31:0] rd1;
    wire [31:0] rd2;
    wire        we;
    wire [4:0]  wa;
    wire [31:0] wd;
    reg_file rf (
        .clk(clk),
        // Read ports
        .ra1(ra1),
        .rd1(rd1),
        .ra2(ra2),
        .rd2(rd2),
        // Write port
        .we(we),
        .wa(wa),
        .wd(wd)
    );

    // On-chip UART
    // Both directions use a data/valid/ready handshake.
    //// UART Receiver (UART -> CPU)
    wire [7:0] uart_rx_data_out;
    wire       uart_rx_data_out_valid;
    wire       uart_rx_data_out_ready;
    //// UART Transmitter (CPU -> UART)
    wire [7:0] uart_tx_data_in;
    wire       uart_tx_data_in_valid;
    wire       uart_tx_data_in_ready;
    uart #(
        .CLOCK_FREQ(CPU_CLOCK_FREQ),
        .BAUD_RATE (BAUD_RATE)
    ) on_chip_uart (
        .clk  (clk),
        .reset(rst),

        // Serial pins
        .serial_in (serial_in),
        .serial_out(serial_out),

        // Receive channel
        .data_out      (uart_rx_data_out),
        .data_out_valid(uart_rx_data_out_valid),
        .data_out_ready(uart_rx_data_out_ready),

        // Transmit channel
        .data_in      (uart_tx_data_in),
        .data_in_valid(uart_tx_data_in_valid),
        .data_in_ready(uart_tx_data_in_ready)
    );

    // Target of CSR writes issued from the execute stage; starts at 0.
    reg [31:0] tohost_csr = 0;

    // TODO: Your code to implement a fully functioning RISC-V core
    // Add as many modules as you want
    // Feel free to move the memory modules around

    // Note: wiring naming conventions
    // name = module_name + i/o name

    // Counters (read back through the 0x8000_00xx I/O addresses)
    reg [31:0] cycle_cnt;
    reg [31:0] insts_cnt;
    reg [31:0] br_inst_cnt;
    reg [31:0] br_corr_cnt;
    wire cnt_reset;

    // Shared values
    //   inst / pc        : fetch stage
    //   f_ex_*           : execute stage
    //   ex_wb_* / wb_*   : writeback stage
    reg [31:0] inst, f_ex_inst, ex_wb_inst;
    reg [31:0] pc, f_ex_pc, ex_wb_pc;
    // ex_alu is driven by the ALU instance's output port. Verilog-2001 only
    // allows a net on an output port connection, so it is a wire, not a reg.
    wire [31:0] ex_alu;
    reg  [31:0] wb_alu;
    reg [31:0] wb_mux;
    // mispred is declared but not referenced anywhere else in this module.
    wire stall, br, jal, jalr, mispred;
    // Nothing in this module produces a stall request. Tie the net low so
    // the PC mux reads a defined 0 instead of a floating (Z) value.
    assign stall = 1'b0;
    reg ex_br_taken, wb_br_taken;
    reg [31:0] f_ex_rd1, f_ex_rd2;
    reg [31:0] f_ex_imm;

    // Control Signals
    // Branch comparator results for the instruction in EX.
    wire BrEq;
    wire BrLt;
    // Outputs of the execute-stage control decoder.
    wire [2:0] ImmSel;
    wire BrUn;
    wire ASel;          // 1: ALU operand A = f_ex_pc, 0: forwarded rs1
    wire BSel;          // 1: ALU operand B = f_ex_imm, 0: forwarded rs2
    wire [3:0] ALUSel;
    wire [3:0] MemRW;   // byte-enable pattern, shifted by the address offset
    // Outputs of the writeback-stage control decoder.
    wire RegWEn;
    wire [1:0] WBSel;   // 0: memory, 1: ALU, 2: PC+4
    // Outputs of the CSR decoder.
    wire CSRSel;        // 1: write the instruction's [19:15] field, 0: forwarded rs1
    wire CSRWEn;

    // Flush Signals
    //   ex_flush : clear the fetch -> execute pipeline register
    //   wb_flush : clear the execute -> writeback pipeline register
    wire br_pred_taken;
    reg ex_flush, wb_flush;

    // Conditions evaluated on the instruction sitting in writeback.
    wire flush_jump       = jal | jalr;
    wire flush_wb_branch  = (ex_wb_inst[6:0] == `OPC_BRANCH);
    wire flush_mispredict = (wb_br_taken != br) && flush_wb_branch;

    always @(*) begin
      if (bp_enable) begin
        // With prediction: squash on a wrong guess; a taken guess for the
        // instruction in EX additionally squashes the fetched instruction.
        wb_flush = rst | flush_jump | flush_mispredict;
        ex_flush = rst | flush_jump | flush_mispredict | br_pred_taken;
      end else begin
        // Without prediction: every resolved-taken branch squashes both.
        wb_flush = rst | flush_jump | br;
        ex_flush = rst | flush_jump | br;
      end
    end

     // ======== Branch Predictor =========
    // The guess side looks at the instruction in EX, the check side at the
    // instruction in WB together with its resolved outcome (br).
    wire bp_guess_is_branch = (f_ex_inst[6:0] == `OPC_BRANCH);
    wire bp_check_is_branch = (ex_wb_inst[6:0] == `OPC_BRANCH);

    branch_predictor br_pred (
      .clk(clk),
      .reset(rst),
      // Guess port
      .pc_guess(f_ex_pc),
      .is_br_guess(bp_guess_is_branch),
      .br_pred_taken(br_pred_taken),
      // Check port
      .pc_check(ex_wb_pc),
      .is_br_check(bp_check_is_branch),
      .br_taken_check(br)
    );

    //-------- Fetch/Decode Stage -------
    // Immediate generator, fed by the freshly fetched instruction
    wire [31:0] imm;
    immediate_gen imm_gen (
      .inst(inst),
      .imm (imm)
    );

    // Program Counter
    // pc_next selection, highest priority first:
    //   1. rst                          -> RESET_PC
    //   2. stall                        -> hold pc
    //   3. jal / jalr in WB             -> wb_alu
    //   4. bp_enable set:
    //        a. branch in WB was guessed not-taken but is taken -> wb_alu
    //        b. branch in WB was guessed taken but is not taken -> ex_wb_pc + 4
    //        c. branch in EX is guessed taken                   -> f_ex_pc + f_ex_imm
    //        d. otherwise                                       -> pc + 4
    //   5. bp_enable clear: taken branch in WB -> wb_alu, else pc + 4
    //
    // Recovery (4a/4b) must outrank a new guess (4c): when the branch in WB
    // was mispredicted, the instruction in EX is on the wrong path and is
    // being flushed, so its predicted target must not steer the PC.
    reg [31:0] pc_next, inst_addr;
    always @(*) begin
      if (rst) begin
        pc_next = RESET_PC;
      end else if (stall) begin
        pc_next = pc;
      end else if (jal | jalr) begin
        pc_next = wb_alu;
      end else if (bp_enable) begin
        if (br && !wb_br_taken && (ex_wb_inst[6:0] == `OPC_BRANCH)) begin
          pc_next = wb_alu; // Guessed not taken, actually taken: go to target
        end else if (!br && wb_br_taken && (ex_wb_inst[6:0] == `OPC_BRANCH)) begin
          pc_next = ex_wb_pc + 32'd4; // Guessed taken, actually not: fall through
        end else if (br_pred_taken) begin
          pc_next = f_ex_pc + f_ex_imm; // Follow the guess for the branch in EX
        end else begin
          pc_next = pc + 32'd4;
        end
      end else if (br) begin
        pc_next = wb_alu;
      end else begin
        pc_next = pc + 32'd4;
      end
      // The instruction memories are addressed with the next PC.
      inst_addr = pc_next;
    end
    always @(posedge clk) begin
      pc <= pc_next;
    end
    // Fetched-instruction mux: pick the memory output by the top nibble of
    // the fetch address; any other region yields an all-zero instruction.
    // NOTE(review): the select uses inst_addr (the next PC), not pc.
    always @(*) begin
      if (inst_addr[31:28] == 4'b0100) begin
        inst = bios_douta;
      end else if (inst_addr[31:28] == 4'b0001) begin
        inst = imem_doutb;
      end else begin
        inst = 32'd0;
      end
    end

    // Word addresses into the two instruction sources
    assign imem_addrb = inst_addr[15:2];
    assign bios_addra = inst_addr[13:2];

    // Regfile read addresses (rs1 / rs2 fields of the fetched instruction)
    assign ra2 = inst[24:20];
    assign ra1 = inst[19:15];

    

    // Fetch -> Execute pipeline register
    // On ex_flush every field is zeroed, which turns the slot into an
    // all-zero instruction; otherwise the fetch-stage values are captured.
    always @(posedge clk) begin
      if (ex_flush) begin
        {f_ex_pc, f_ex_inst, f_ex_rd1, f_ex_rd2, f_ex_imm} <= 160'd0;
        // ex_br_taken is only ever cleared; this block never loads it.
        ex_br_taken <= 1'd0;
      end else begin
        {f_ex_pc, f_ex_inst, f_ex_rd1, f_ex_rd2, f_ex_imm} <= {pc, inst, rd1, rd2, imm};
      end
    end
  
    //-------- Execute Stage --------
    // Control decode for the instruction currently in EX
    exec_mem_ctrl ex_mem_ctrl (
      .inst  (f_ex_inst),
      .ImmSel(ImmSel),
      .BrUn  (BrUn),
      .ASel  (ASel),
      .BSel  (BSel),
      .ALUSel(ALUSel),
      .MemRW (MemRW)
    );

    // One-cycle-old copy of the writeback instruction and its result, kept
    // so the forwarding unit can also see the instruction before the
    // previous one.
    reg [31:0] pprev_inst, pprev_data;
    always @(posedge clk) begin
      if (rst) begin
        {pprev_inst, pprev_data} <= 64'd0;
      end else begin
        {pprev_inst, pprev_data} <= {ex_wb_inst, wb_mux};
      end
    end

    // Operand forwarding
    wire [31:0] fwd_a, fwd_b;
    forwarding fwd (
      .inst      (f_ex_inst),
      .rd1       (f_ex_rd1),
      .rd2       (f_ex_rd2),
      .prev_inst (ex_wb_inst),
      .alu       (wb_alu),
      .mem       (wb_mux),
      .pprev_inst(pprev_inst),
      .prev      (pprev_data),
      .outa      (fwd_a),
      .outb      (fwd_b)
    );

    // ALU operand selection: A = PC or rs1, B = immediate or rs2
    wire [31:0] a_mux = ASel ? f_ex_pc  : fwd_a;
    wire [31:0] b_mux = BSel ? f_ex_imm : fwd_b;
    alu alu (
      .ALUSel(ALUSel),
      .a     (a_mux),
      .b     (b_mux),
      .out   (ex_alu)
    );

    // Branch comparator works on the forwarded register values
    branch_comp comp (
      .BrUn(BrUn),
      .a   (fwd_a),
      .b   (fwd_b),
      .BrEq(BrEq),
      .BrLt(BrLt)
    );
    

    // Memory & UART (store side, driven from the execute stage)
    wire [31:0] str_data;
    store_data str (
      .inst(f_ex_inst),
      .addr(ex_alu),
      .din(fwd_b),
      .dout(str_data)
    );

    // When wb_flush is high the instruction in EX is discarded at the next
    // clock edge (the ex_wb_* registers are cleared instead of loaded), so
    // none of its architectural side effects may take place. Every write
    // strobe below is therefore qualified with ex_commit.
    wire ex_commit = ~wb_flush;

    // Byte enables for the addressed word: MemRW moved to the byte offset.
    wire [3:0] ex_byte_we = ex_commit ? (MemRW << ex_alu[1:0]) : 4'd0;
    wire ex_is_store = (f_ex_inst[6:2] == `OPC_STORE_5);

    assign bios_addrb = ex_alu[13:2];

    // Data memory is written when address bits [31:30] are 00 and bit 28 is set.
    assign dmem_addr = ex_alu[15:2];
    assign dmem_din = str_data;
    assign dmem_we = (((~ex_alu[31])&(~ex_alu[30])&ex_alu[28])==1'b1) ? ex_byte_we : 4'd0;

    // Instruction memory is written when address bits [31:30] are 00 and bit 29 is set.
    assign imem_addra = ex_alu[15:2];
    assign imem_dina = str_data;
    assign imem_wea = (((~ex_alu[31])&(~ex_alu[30])&ex_alu[29])==1'b1) ? ex_byte_we : 4'd0;

    // A store to 0x8000_0008 hands the low byte to the UART transmitter.
    assign uart_tx_data_in = str_data[7:0];
    assign uart_tx_data_in_valid = (ex_commit && ex_is_store && (ex_alu == 32'h8000_0008));

    // A store to 0x8000_0018 clears the performance counters.
    assign cnt_reset = (ex_commit && ex_is_store && (ex_alu == 32'h8000_0018));

    // CSR
    csr_ctrl csr_ctrl (
      .inst(f_ex_inst),
      .CSRSel(CSRSel),
      .CSRWEn(CSRWEn)
    );
    // tohost_csr is a clocked register. A combinational block that assigns
    // it only when CSRWEn is high would infer a level-sensitive latch.
    // The write is skipped while wb_flush is high because the instruction in
    // EX is being squashed in that cycle (wb_flush also covers rst).
    always @(posedge clk) begin
      if (rst) begin
        tohost_csr <= 32'd0;
      end else if (CSRWEn && !wb_flush) begin
        // CSRSel picks the zero-extended 5-bit field from the instruction,
        // otherwise the forwarded rs1 value is written.
        tohost_csr <= (CSRSel ? {27'd0, f_ex_inst[19:15]} : fwd_a);
      end
    end

    // Execute -> Writeback pipeline register
    // Cleared on reset or wb_flush; otherwise captures the execute-stage
    // PC, instruction, ALU result, comparator flags and the predictor's
    // guess for that instruction.
    reg wb_BrEq, wb_BrLt;
    always @(posedge clk) begin
      if (rst | wb_flush) begin
        {ex_wb_pc, ex_wb_inst, wb_alu}  <= 96'd0;
        {wb_BrEq, wb_BrLt, wb_br_taken} <= 3'd0;
      end else begin
        {ex_wb_pc, ex_wb_inst, wb_alu}  <= {f_ex_pc, f_ex_inst, ex_alu};
        {wb_BrEq, wb_BrLt, wb_br_taken} <= {BrEq, BrLt, br_pred_taken};
      end
    end

    // Memory port enables
    assign bios_ena = 1'b1;
    assign bios_enb = 1'b1;
    assign dmem_en  = 1'b1;
    // The instruction-memory write port is enabled only while bit 30 of the
    // execute-stage PC is set.
    assign imem_ena = f_ex_pc[30];

    //-------- Writeback Stage --------
    writeback_ctrl wb_ctrl (
      .inst  (ex_wb_inst),
      .RegWEn(RegWEn),
      .WBSel (WBSel)
    );

    // Address-region decode of the load address held in wb_alu
    wire wb_in_dmem = (~wb_alu[31]) & (~wb_alu[30]) & (wb_alu[28]);
    wire wb_in_bios = (~wb_alu[31]) & (wb_alu[30]) & (~wb_alu[29]) & (~wb_alu[28]);
    wire wb_in_io   = (wb_alu[31]) & (~wb_alu[30]) & (~wb_alu[29]) & (~wb_alu[28]);

    // Raw load word: data memory, BIOS, or a memory-mapped I/O register.
    // Anything unmapped reads as zero.
    reg [31:0] mem_wb_mux;
    always @(*) begin
      mem_wb_mux = 32'd0;
      if (wb_in_dmem) begin
        mem_wb_mux = dmem_dout;
      end else if (wb_in_bios) begin
        mem_wb_mux = bios_doutb;
      end else if (wb_in_io) begin
        case (wb_alu)
          32'h8000_0000: mem_wb_mux = {30'd0, uart_rx_data_out_valid, uart_tx_data_in_ready}; // UART status
          32'h8000_0004: mem_wb_mux = {24'd0, uart_rx_data_out};                              // UART receive byte
          32'h8000_0010: mem_wb_mux = cycle_cnt;
          32'h8000_0014: mem_wb_mux = insts_cnt;
          32'h8000_001c: mem_wb_mux = br_inst_cnt;
          32'h8000_0020: mem_wb_mux = br_corr_cnt;
          default:       mem_wb_mux = 32'd0;
        endcase
      end
    end

    // A load from 0x8000_0004 consumes the received UART byte.
    assign uart_rx_data_out_ready = ((ex_wb_inst[6:2] == `OPC_LOAD_5) && wb_alu == 32'h8000_0004);

    // Post-process the raw word for the load instruction in WB
    wire [31:0] ld_data;
    load_data ld (
      .inst(ex_wb_inst),
      .addr(wb_alu),
      .din (mem_wb_mux),
      .dout(ld_data)
    );

    // Writeback source select
    always @(*) begin
      case (WBSel)
        2'b00:   wb_mux = ld_data;           // memory
        2'b01:   wb_mux = wb_alu;            // ALU result
        2'b10:   wb_mux = ex_wb_pc + 32'd4;  // link address
        default: wb_mux = 32'd0;
      endcase
    end

    // Register-file write port
    assign wa = ex_wb_inst[11:7];
    assign wd = wb_mux;
    assign we = RegWEn;

    // Jal/Jalr detection on the instruction in WB
    assign jalr = (ex_wb_inst[6:2] == `OPC_JALR_5);
    assign jal  = (ex_wb_inst[6:2] == `OPC_JAL_5);

    // Branch resolution from the comparator flags latched with the instruction
    branch_ctrl br_ctrl (
      .inst  (ex_wb_inst),
      .BrEq  (wb_BrEq),
      .BrLt  (wb_BrLt),
      .branch(br)
    );

    // ======== Counters ========
    // All four counters are cleared by rst (so they never start as X in
    // simulation) and by a store to the counter-reset address (cnt_reset).
    wire cnt_clear = rst | cnt_reset;

    // Branch statistics are taken in the writeback stage, where the
    // instruction has survived every flush and its outcome (br) is known.
    // Both counts use the same qualifier so that br_corr_cnt / br_inst_cnt
    // is a meaningful accuracy ratio.
    wire cnt_wb_is_branch = (ex_wb_inst[6:0] == `OPC_BRANCH);

    // Clock cycles since the last clear
    always @(posedge clk) begin
      if (cnt_clear) cycle_cnt <= 32'd0;
      else cycle_cnt <= cycle_cnt + 32'd1;
    end

    // Instructions reaching writeback (an all-zero slot is a bubble)
    always @(posedge clk) begin
      if (cnt_clear) insts_cnt <= 32'd0;
      else if ((ex_wb_inst != 32'd0)) insts_cnt <= insts_cnt + 32'd1;
    end

    // Branch instructions reaching writeback
    always @(posedge clk) begin
      if (cnt_clear) br_inst_cnt <= 32'd0;
      else if (cnt_wb_is_branch) br_inst_cnt <= br_inst_cnt + 32'd1;
    end

    // Branches whose predicted direction matched the resolved direction.
    // Only counted for real branches, and only while prediction is enabled.
    always @(posedge clk) begin
      if (cnt_clear) br_corr_cnt <= 32'd0;
      else if (cnt_wb_is_branch && (br == wb_br_taken) && bp_enable) br_corr_cnt <= br_corr_cnt + 32'd1;
    end

endmodule