
It’s time to implement the architecture described in this thesis. First, I will implement the circuit of this extension according to the paper. This is not too hard:

module BNRV_4Mux(
    input signed [7:0] x,
    input [1:0] w,
    output signed [9:0] o
);
    reg signed [9:0] t_o;
    // Sign-extend x to 10 bits so that -2x (as large as +256) still fits
    wire signed [9:0] x_ext = { {2{x[7]}}, x };
    always @(*) begin
        case (w)
            2'b00: t_o = 10'sd0;         // 0
            2'b01: t_o = x_ext;          // +x
            2'b10: t_o = -x_ext;         // -x
            2'b11: t_o = -(x_ext << 1);  // -2x
            default: t_o = 10'sd0;
        endcase
    end
    assign o = t_o;
endmodule

module BNRV_SUM4 (
    input signed [7:0] x1, x2, x3, x4,
    input [1:0] w1, w2, w3, w4,
    output signed [31:0] o
    );
    wire signed [9:0] mo1, mo2, mo3, mo4;
    wire signed [11:0] res_temp;

    BNRV_4Mux Mux1 (.x(x1), .w(w1), .o(mo1));
    BNRV_4Mux Mux2 (.x(x2), .w(w2), .o(mo2));
    BNRV_4Mux Mux3 (.x(x3), .w(w3), .o(mo3));
    BNRV_4Mux Mux4 (.x(x4), .w(w4), .o(mo4));

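    // Sign-extend each 10-bit lane to 12 bits so the 4-term sum cannot overflow
    assign res_temp = ( { {2{mo1[9]}}, mo1 } + { {2{mo2[9]}}, mo2 } ) +
                      ( { {2{mo3[9]}}, mo3 } + { {2{mo4[9]}}, mo4 } );
    // Sign-extend the 12-bit sum to the 32-bit register width
    assign o = { {20{res_temp[11]}}, res_temp };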
endmodule

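To avoid bit truncation on large matrices, a bug that confused me for a long time, I widened the intermediate signals by one or two bits. Each mux turns the 8-bit activation into a 10-bit result. The four results are added in 12 bits before being sign-extended to the 32-bit output.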
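As a sanity check on those widths, here is a small C model of BNRV_4Mux. It is my own sketch, not part of the design files, and the helpers bnrv_mux, lane_bound, bits_needed and check_bnrv_widths are hypothetical names. The model scans every 8-bit activation and every 2-bit weight code, finds the range of one lane, and computes the smallest two's-complement width for one lane and for the sum of four lanes.

```c
#include <assert.h>
#include <stdint.h>

/* C model of BNRV_4Mux: weight code 00 -> 0, 01 -> +x, 10 -> -x, 11 -> -2x. */
static int32_t bnrv_mux(int8_t x, unsigned w) {
    switch (w & 0x3u) {
    case 1: return x;
    case 2: return -(int32_t)x;
    case 3: return -2 * (int32_t)x;
    default: return 0;
    }
}

/* Smallest (want_max == 0) or largest value one lane can produce,
 * scanning all 256 activations x 4 weight codes. */
static int32_t lane_bound(int want_max) {
    int32_t best = 0;
    for (int x = -128; x <= 127; x++) {
        for (unsigned w = 0; w < 4; w++) {
            int32_t v = bnrv_mux((int8_t)x, w);
            if (want_max ? v > best : v < best) best = v;
        }
    }
    return best;
}

/* Smallest two's-complement width that can hold every value in [lo, hi]. */
static int bits_needed(int32_t lo, int32_t hi) {
    for (int bits = 1; ; bits++) {
        int64_t half = (int64_t)1 << (bits - 1);   /* 2^(bits-1) */
        if (lo >= -half && hi <= half - 1) return bits;
    }
}

/* The Verilog uses 10-bit lanes (mo1..mo4) and a 12-bit sum (res_temp). */
static void check_bnrv_widths(void) {
    int32_t lo = lane_bound(0), hi = lane_bound(1);
    assert(bits_needed(lo, hi) <= 10);           /* one lane */
    assert(bits_needed(4 * lo, 4 * hi) <= 12);   /* four independent lanes */
}
```

The scan gives a per-lane range of [-254, 256], where 256 comes from -2 × -128, and that range needs 10 bits. The four-lane range is [-1016, 1024]. Because +1024 is just outside the 11-bit range [-1024, 1023], the sum really does need 12 bits. So the 10-bit lanes and the 12-bit res_temp are the minimum safe widths.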

Next, compile it with Verilator, which generates C++ model files in obj_dir:

verilator -cc -j 0 -build -Wall infiles...
Copy obj_dir to gem5/src/arch/riscv and rename it to BNRV. Caution: after copying, delete the library files such as *.a. Otherwise, linking gem5 fails with multiple-definition errors. Once we edit the SConscript, gem5 compiles these sources itself, and the resulting symbols conflict with the ones already in the archives.

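Edit the SConscript to add the generated sources and some compile options. Do NOT add the *__ALL.cpp files if you have already added the individual .cpp files. Alternatively, add ONLY the *__ALL.cpp files. Mixing both also causes multiple-definition errors, because each *__ALL.cpp simply #includes the individual generated .cpp files.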

I also copy verilated.cpp and verilated_threads.cpp from Verilator's include directory into BNRV, since the model needs them too.

Next, let’s integrate the circuit into gem5. The ISA description files live in src/arch/[architecture]/isa. For example, the x86 ones (including x87) are in src/arch/x86/isa, and the RISC-V ones are, of course, in src/arch/riscv/isa.

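gem5 decodes instructions according to decoder.isa. According to the RISC-V manual, the lowest bits of an instruction encode its length:
- If the lowest two bits are not 11, it is a 16-bit compressed instruction.
- If the lowest two bits are 11 and bits [4:2] are not 111, it is 32 bits long.
- Otherwise, "Standard instruction-set extensions encoded with more than 32 bits have additional low-order bits set to 1, with the conventions for 48-bit and 64-bit lengths shown in Table 1."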
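The length rule can be written as a small C function that looks at the first 16-bit parcel of an instruction. This is a sketch: riscv_insn_length and check_lengths are hypothetical helpers, and the function only covers the 16-, 32-, 48- and 64-bit cases that the manual spells out.

```c
#include <assert.h>
#include <stdint.h>

/* Length in bytes of a RISC-V instruction, judged from its lowest 16-bit
 * parcel, following the length-encoding table of the base ISA manual.
 * Returns 0 for the longer/reserved encodings this sketch does not cover. */
static int riscv_insn_length(uint16_t parcel) {
    if ((parcel & 0x03) != 0x03) return 2;  /* bits[1:0] != 11: 16-bit (compressed) */
    if ((parcel & 0x1c) != 0x1c) return 4;  /* bits[4:2] != 111: 32-bit */
    if ((parcel & 0x3f) == 0x1f) return 6;  /* bits[5:0] == 011111: 48-bit */
    if ((parcel & 0x7f) == 0x3f) return 8;  /* bits[6:0] == 0111111: 64-bit */
    return 0;
}

static void check_lengths(void) {
    assert(riscv_insn_length(0x028b) == 4);  /* low parcel of 0x0073028b (bnsum4) */
    assert(riscv_insn_length(0x4501) == 2);  /* bits[1:0] = 01: compressed */
}
```

The kernel in this post emits the word 0x0073028b, whose low parcel is 0x028b. Bits [1:0] are 11 and bits [4:2] are 010, so it is a 32-bit instruction. That is why it sits under the 0x3 branch of gem5's decoder.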

In gem5, the decoder is organized as a tree that the ISA parser turns into nested C++ switch statements, with the lowest bits of the instruction tested at the top. BNRV instructions are 32 bits long and use the standard R-type layout (rd, rs1, rs2). Their lowest two bits are therefore 11, and the implementation belongs under the 0x3 branch.

I decided to encode bnsum4 with opcode 0001011 (custom-0, the opcode reserved for custom extensions) and funct3 000. Bits [6:2] of that opcode are 00010, so the implementation goes under 0x3, then OPCODE5 0x02, then FUNCT3 0x0.

gem5 also lets users define their own instruction formats to customize how an operation is declared and executed. For example, FenceOp is defined as:

def format FenceOp(code, imm_type='int64_t', *opt_flags) {{
    regs = ['destRegIdx(0)','srcRegIdx(0)']
    iop = InstObjParams(name, Name, 'ImmOp<%s>' % imm_type,
        {'code': code, 'imm_code': 'imm = sext<12>(IMM12);',
         'regs': ','.join(regs)}, opt_flags)
    header_output = ImmDeclare.subst(iop)
    decoder_output = ImmConstructor.subst(iop)
    decode_block = BasicDecode.subst(iop)
    exec_output = FenceExecute.subst(iop)
}};

We pass our implementation in the code parameter, along with optional flags (opt_flags), and the ISA parser expands that code into C++ statements. Because bnsum4 is a classic R-type instruction, we can simply reuse the basic templates:
// Format for BNRV instructions (same templates as a plain R-type op)
def format BNOP(code, *opt_flags) {{
    iop = InstObjParams(name, Name, 'RegOp', code, opt_flags)
    header_output = BasicDeclare.subst(iop)
    decoder_output = BasicConstructor.subst(iop)
    decode_block = BasicDecode.subst(iop)
    exec_output = BasicExecute.subst(iop)
}};
To keep the implementation modular and easy to debug, we can write a wrapper around the Verilator model and include it in includes.isa. Any other headers the implementation needs can go there too, and the code generated from decoder.isa will see them:
#include "arch/riscv/BNRV/VBNRV_SUM4.h"
#include "arch/riscv/BNRV/VBNRV_SUM4___024root.h"
#include "arch/riscv/BNRV/sum4_wrapper.h"

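#include "VBNRV_SUM4.h"
#include "VBNRV_SUM4___024root.h"
#include "verilated.h"
#include <stdint.h>
#include <memory>

// Evaluate the Verilated BNRV_SUM4 model for one bnsum4 instruction.
// rs1 packs four signed 8-bit activations (x1 in the lowest byte);
// rs2 packs four 2-bit weights (w1 in bits [1:0]).
extern "C" int32_t invoke_bnsum4(uint32_t rs1, uint32_t rs2) {
    // Build the Verilator context and model once, then reuse them
    static VerilatedContext* contextp = new VerilatedContext;
    static VBNRV_SUM4* vmodel = nullptr;
    if (!vmodel) {
        vmodel = new VBNRV_SUM4{contextp};
    }
    vmodel->x1 = (uint8_t)(rs1 & 0xFF);
    vmodel->x2 = (uint8_t)((rs1 >> 8) & 0xFF);
    vmodel->x3 = (uint8_t)((rs1 >> 16) & 0xFF);
    vmodel->x4 = (uint8_t)((rs1 >> 24) & 0xFF);
    vmodel->w1 = (uint8_t)(rs2 & 0x03);
    vmodel->w2 = (uint8_t)((rs2 >> 2) & 0x03);
    vmodel->w3 = (uint8_t)((rs2 >> 4) & 0x03);
    vmodel->w4 = (uint8_t)((rs2 >> 6) & 0x03);
    vmodel->eval();                           // combinational: one eval() settles o
    return static_cast<int32_t>(vmodel->o);   // already sign-extended to 32 bits
}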
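When chasing truncation bugs, it helps to have a golden model that does not depend on Verilator. The following sketch reimplements BNRV_SUM4 in plain C, using the same operand packing as invoke_bnsum4, so the two can be compared on random rs1/rs2 values. ref_bnsum4 and check_ref are hypothetical names.

```c
#include <assert.h>
#include <stdint.h>

/* Software golden model of BNRV_SUM4 with the same operand packing as
 * invoke_bnsum4: activation i is byte i of rs1 (signed 8-bit),
 * weight code i is bits [2i+1:2i] of rs2. */
static int32_t ref_bnsum4(uint32_t rs1, uint32_t rs2) {
    int32_t sum = 0;
    for (int i = 0; i < 4; i++) {
        int32_t x = (int32_t)((rs1 >> (8 * i)) & 0xffu);
        if (x > 127) x -= 256;                 /* reinterpret as signed 8-bit */
        switch ((rs2 >> (2 * i)) & 0x3u) {     /* same cases as BNRV_4Mux */
        case 1: sum += x; break;               /* 01: +x */
        case 2: sum -= x; break;               /* 10: -x */
        case 3: sum -= 2 * x; break;           /* 11: -2x */
        default: break;                        /* 00: 0 */
        }
    }
    return sum;
}

static void check_ref(void) {
    assert(ref_bnsum4(0x7f7f7f7fu, 0x8du) == -254);  /* codes 01, 11, 00, 10 */
    assert(ref_bnsum4(0x80808080u, 0xffu) == 1024);  /* 4 * (-2 * -128) */
}
```

The weight byte 0x8d packs the codes 01, 11, 00, 10, so the first check computes 127 - 254 + 0 - 127 = -254. The second check exercises the +1024 corner case, which a sum narrower than 12 bits would wrap. In a debug build, one could call both invoke_bnsum4 and ref_bnsum4 on random inputs and stop at the first mismatch.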

Then we just call invoke_bnsum4 from the decoder:

0x3: decode OPCODE5 {
    0x01: decode FUNCT3 {
        ...
    }
    0x02: decode FUNCT3 {
        format BNOP {
            0x0: bnsum4({{
                auto raw_result = invoke_bnsum4(Rs1_uw, Rs2_uw);
                Rd_sw = raw_result;
            }}, IsInteger);
        }
    }
}
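To double-check the encoding, here is a small C sketch. It assembles the R-type word for bnsum4 x5, x6, x7 and then walks the same nested switch that the decoder tree turns into. The helper names are hypothetical. QUADRANT is my assumption for the name of gem5's top-level field (bits [1:0]); OPCODE5 and FUNCT3 are the fields used in the decoder snippet above.

```c
#include <assert.h>
#include <stdint.h>

/* Assemble a standard R-type word: funct7 | rs2 | rs1 | funct3 | rd | opcode. */
static uint32_t encode_rtype(uint32_t opcode, uint32_t rd, uint32_t funct3,
                             uint32_t rs1, uint32_t rs2, uint32_t funct7) {
    return ((funct7 & 0x7fu) << 25) | ((rs2 & 0x1fu) << 20) | ((rs1 & 0x1fu) << 15) |
           ((funct3 & 0x7u) << 12) | ((rd & 0x1fu) << 7) | (opcode & 0x7fu);
}

/* Same nested decode as decoder.isa:
 * QUADRANT (bits 1:0) -> OPCODE5 (bits 6:2) -> FUNCT3 (bits 14:12).
 * Returns 1 when the word reaches the bnsum4 case. */
static int reaches_bnsum4(uint32_t insn) {
    switch (insn & 0x3u) {                  /* QUADRANT */
    case 0x3:
        switch ((insn >> 2) & 0x1fu) {      /* OPCODE5 */
        case 0x02:
            switch ((insn >> 12) & 0x7u) {  /* FUNCT3 */
            case 0x0:
                return 1;
            }
            break;
        }
        break;
    }
    return 0;
}

static void check_encoding(void) {
    uint32_t word = encode_rtype(0x0b, 5, 0, 6, 7, 0);  /* bnsum4 x5, x6, x7 */
    assert(word == 0x0073028bu);
    assert(reaches_bnsum4(word));
}
```

encode_rtype(0x0b, 5, 0, 6, 7, 0) yields 0x0073028b, the word the kernel emits with .word. That word reaches the bnsum4 case, while an ordinary add (opcode 0x33) or a different funct3 does not.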

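Next, we need a kernel that uses the BNRV extension. The compiler does not know bnsum4, so the macro emits its encoding directly with .word: 0x0073028b is bnsum4 with rd = x5, rs1 = x6 and rs2 = x7.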

#define BITNET_STEP_HW(weights_byte, act_ptr, sum_accumulator) do { \
    uint32_t rs1_val = *((uint32_t*)(act_ptr)); \
    uint32_t rs2_val = (uint32_t)(weights_byte); \
    uint32_t rd_raw; \
    asm volatile ( \
        "mv x6, %1\n\t" \
        "mv x7, %2\n\t" \
        ".word 0x0073028b\n\t" \
        "mv %0, x5\n\t" \
        : "=r" (rd_raw) \
        : "r" (rs1_val), "r" (rs2_val) \
        : "x5", "x6", "x7", "memory" \
    ); \
    int32_t rd_val = (int32_t)rd_raw; /* BNRV_SUM4 already sign-extends its 12-bit sum */ \
    sum_accumulator += rd_val; \
} while(0)

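Do not hard-code registers with predefined roles, such as the zero register (x0) or the stack pointer (sp). Otherwise, you will get a virtual address error because the program context is corrupted. The temporaries x5 to x7 (t0 to t2) used here are safe as long as they are listed as clobbers.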
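An alternative to hard-coding x5 to x7 is the GNU assembler's .insn directive. It emits a custom R-type instruction while letting the compiler choose the registers, so neither the extra mv instructions nor the clobber list are needed. This is a sketch, not what was used for the measurements in this post. bnsum4_step and check_step are hypothetical helpers. The sketch assumes a little-endian target and an assembler that supports `.insn r`. On non-RISC-V hosts it falls back to a software model so it can still run.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* One BNRV step: bnsum4 on four packed int8 activations and one byte of
 * packed 2-bit weights. Assumes a little-endian target (act[0] -> x1). */
static inline int32_t bnsum4_step(const int8_t *act, uint8_t weights_byte) {
    uint32_t rs1;
    memcpy(&rs1, act, sizeof rs1);      /* avoids an unaligned/aliasing pointer cast */
    uint32_t rs2 = weights_byte;
#if defined(__riscv)
    uint32_t rd;
    /* .insn r OPCODE, FUNCT3, FUNCT7, rd, rs1, rs2 -> custom-0 (0x0b), 0, 0.
     * The compiler allocates the registers, so no fixed registers or mv.
     * Only meaningful on a core (e.g. the modified gem5) implementing bnsum4. */
    __asm__ volatile(".insn r 0x0b, 0, 0, %0, %1, %2"
                     : "=r"(rd)
                     : "r"(rs1), "r"(rs2));
    return (int32_t)rd;
#else
    /* Host fallback: software model of BNRV_SUM4 so the code runs off-target. */
    static const int32_t coeff[4] = {0, 1, -1, -2};   /* codes 00, 01, 10, 11 */
    int32_t sum = 0;
    for (int i = 0; i < 4; i++) {
        int32_t x = (int32_t)((rs1 >> (8 * i)) & 0xffu);
        if (x > 127) x -= 256;                          /* signed 8-bit lane */
        sum += coeff[(rs2 >> (2 * i)) & 0x3u] * x;
    }
    return sum;
#endif
}

static void check_step(void) {
    const int8_t act[4] = {127, 127, 127, 127};
    assert(bnsum4_step(act, 0x8d) == -254);   /* codes 01, 11, 00, 10 */
}
```

The .word version executes three mv instructions per step in addition to bnsum4. This variant could therefore cut the instruction count further, but that has not been measured here.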

It’s time to write our kernel:

#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <math.h>
#include <float.h>
#include <gem5/m5ops.h>

#define N 8192

#define BITNET_STEP_HW(weights_byte, act_ptr, sum_accumulator) do { \
    uint32_t rs1_val = *((uint32_t*)(act_ptr)); \
    uint32_t rs2_val = (uint32_t)(weights_byte); \
    uint32_t rd_raw; \
    asm volatile ( \
        "mv x6, %1\n\t" \
        "mv x7, %2\n\t" \
        ".word 0x0073028b\n\t" \
        "mv %0, x5\n\t" \
        : "=r" (rd_raw) \
        : "r" (rs1_val), "r" (rs2_val) \
        : "x5", "x6", "x7", "memory" \
    ); \
    sum_accumulator += (int32_t)rd_raw; /* already sign-extended by BNRV_SUM4 */ \
} while(0)


int8_t packed_fourweights(int8_t w0, int8_t w1, int8_t w2, int8_t w3) {
    int8_t bytes = 0;
    bytes |= (w0 & 0x03) << 0;
    bytes |= (w1 & 0x03) << 2;
    bytes |= (w2 & 0x03) << 4;
    bytes |= (w3 & 0x03) << 6;
    return bytes;
}

void rms_norm(float* input, float* output, int n) {
    float sum = 0;
    for (int i = 0; i < n; i++) sum += input[i] * input[i];
    float rms = sqrtf(sum / n + 1e-6f);
    for (int i = 0; i < n; i++) {
        output[i] = input[i] / rms;
    }
}

void int_quant(float* input, int8_t* output) {
    float max_abs = 0.0f;
    for (int i = 0; i < N; i++) {
        float val = fabsf(input[i]);
        if (val > max_abs) max_abs = val;
    }
    float scale = (max_abs > 1e-8f) ? (127.0f / max_abs) : 0.0f;
    for (int i = 0; i < N; i++) {
        output[i] = (int8_t)roundf(input[i] * scale);
    }
}

void bitnet_matmul_kernel(int8_t* activations, int8_t* packed_weights, int32_t* output) {
    for (int i = 0; i < N; i++) {
        int32_t row_sum = 0;
        uint8_t* row_weights = (uint8_t*)&packed_weights[i * (N / 4)];
        for (int j = 0; j < N / 4; j++) {
            BITNET_STEP_HW(row_weights[j], &activations[j * 4], row_sum);
        }
        output[i] = row_sum;
    }
}


int main() {
    size_t total_weights_bytes = (N * N) / 4;
    int8_t* weights = (int8_t*)malloc(total_weights_bytes);
    int8_t* input_i = (int8_t*)malloc(N * sizeof(int8_t));
    float* input   = (float*)malloc(N * sizeof(float));
    float* input_n = (float*)malloc(N * sizeof(float));
    int32_t* out    = (int32_t*)malloc(N * sizeof(int32_t));
    for (size_t i = 0; i < total_weights_bytes; i++) {
        weights[i] = packed_fourweights(1, 3, 0, 2);
    }
    for (int i = 0; i < N; i++) {
        input[i] = 0.25f;
    }
    rms_norm(input, input_n, N);  
    int_quant(input_n, input_i);
    printf("Starting BitNet Hardware Kernel Simulation (N=%d)...\n", N);  
    m5_reset_stats(0, 0);
    bitnet_matmul_kernel(input_i, weights, out);
    m5_dump_stats(0, 0);
    int64_t final_checksum = 0;
    for (int i = 0; i < N; i++) {
        final_checksum += out[i];
    }  
    printf("Output Checksum: %lld\n", (long long)final_checksum);
    free(weights); free(input_i); free(input); free(input_n); free(out);
    return 0;
}

To check that the optimized kernel produces the right output, we compute a checksum over the output and compare it with the normal kernel.
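The expected checksum can also be derived by hand. With every input equal to 0.25, rms_norm maps each element to about 1.0, and int_quant scales the maximum to 127, so every activation should become 127. Each weight byte is packed_fourweights(1, 3, 0, 2), so each group of four lanes contributes x - 2x + 0 - x = -2x. The sketch below turns that into a closed form. expected_checksum, lane_coeff and check_post_checksum are hypothetical names.

```c
#include <assert.h>
#include <stdint.h>

/* Per-lane multipliers of BNRV_4Mux, indexed by 2-bit weight code. */
static const int64_t lane_coeff[4] = {0, 1, -1, -2};

/* Closed-form checksum of the test in main(): n rows, n/4 weight bytes per
 * row, every activation equal to act, every weight byte holding the codes
 * (c0, c1, c2, c3) as packed by packed_fourweights. */
static int64_t expected_checksum(int64_t n, int64_t act,
                                 unsigned c0, unsigned c1, unsigned c2, unsigned c3) {
    int64_t per_group = act * (lane_coeff[c0 & 3] + lane_coeff[c1 & 3] +
                               lane_coeff[c2 & 3] + lane_coeff[c3 & 3]);
    int64_t per_row = per_group * (n / 4);   /* N/4 bnsum4 steps per row */
    return per_row * n;                      /* summed over N rows */
}

static void check_post_checksum(void) {
    assert(expected_checksum(8192, 127, 1, 3, 0, 2) == INT64_C(-4261412864));
}
```

expected_checksum(8192, 127, 1, 3, 0, 2) is -254 × 2048 × 8192 = -4261412864. A run that prints anything else points to a bug in the circuit, the wrapper or the packing.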

After a few minutes, I got the results, first for the BNRV kernel and then for the normal kernel:

Starting BitNet Hardware Kernel Simulation (N=8192)...
Output Checksum: -4261412864
Exiting @ tick 419616750000 because exiting with last active thread context
Starting BitNet Kernel Simulation...
Output Checksum: -4261412864
Exiting @ tick 829204726000 because exiting with last active thread context

Both kernels produce the same checksum, and the hardware-accelerated kernel finishes in roughly half the ticks: 419616750000 versus 829204726000, about 1.98x faster.

I noticed that the BNRV kernel has a lower IPC (higher CPI), even though its simTicks and simSeconds are much lower:

Normal:
---------- Begin Simulation Statistics ----------
simSeconds 0.780732 # Number of seconds simulated (Second)
simTicks 780732408000 # Number of ticks simulated (Tick)
finalTick 829016574000 # Number of ticks from beginning of simulation (restored from checkpoints and never reset) (Tick)
simFreq 1000000000000 # The number of ticks per simulated second ((Tick/Second))
hostSeconds 454.10 # Real time elapsed on the host (Second)
hostTickRate 1719313855 # The number of ticks simulated per host second (ticks/s) ((Tick/Second))
hostMemory 8529416 # Number of bytes of host memory used (Byte)
simInsts 503373855 # Number of instructions simulated (Count)
simOps 503373855 # Number of ops (including micro ops) simulated (Count)
hostInstRate 1108520 # Simulator instruction rate (inst/s) ((Count/Second))
hostOpRate 1108520 # Simulator op (including micro ops) rate (op/s) ((Count/Second))
system.clk_domain.clock 2000 # Clock period in ticks (Tick)
system.clk_domain.voltage_domain.voltage 1 # Voltage in Volts (Volt)
system.cpu.numCycles 390366204 # Number of cpu cycles simulated (Cycle)
system.cpu.cpi 0.775500 # CPI: cycles per instruction (core level) ((Cycle/Count))
system.cpu.ipc 1.289491 # IPC: instructions per cycle (core level) ((Count/Cycle))
system.cpu.numWorkItemsStarted 0 # Number of work items this cpu started (Count)
system.cpu.numWorkItemsCompleted 0 # Number of work items this cpu completed (Count)
system.cpu.issuedInstType_0::No_OpClass 0 0.00% 0.00% # Number of instructions issued per FU type, per thread (Count)
system.cpu.issuedInstType_0::IntAlu 285261860 56.67% 56.67% # Number of instructions issued per FU type, per thread (Count)
system.cpu.issuedInstType_0::IntMult 67108864 13.33% 70.00% # Number of instructions issued per FU type, per thread (Count)

BNRV:
---------- Begin Simulation Statistics ----------
simSeconds 0.371144 # Number of seconds simulated (Second)
simTicks 371144234000 # Number of ticks simulated (Tick)
finalTick 419434078000 # Number of ticks from beginning of simulation (restored from checkpoints and never reset) (Tick)
simFreq 1000000000000 # The number of ticks per simulated second ((Tick/Second))
hostSeconds 164.60 # Real time elapsed on the host (Second)
hostTickRate 2254856087 # The number of ticks simulated per host second (ticks/s) ((Tick/Second))
hostMemory 8783488 # Number of bytes of host memory used (Byte)
simInsts 167829523 # Number of instructions simulated (Count)
simOps 167829523 # Number of ops (including micro ops) simulated (Count)
hostInstRate 1019634 # Simulator instruction rate (inst/s) ((Count/Second))
hostOpRate 1019634 # Simulator op (including micro ops) rate (op/s) ((Count/Second))
system.clk_domain.clock 2000 # Clock period in ticks (Tick)
system.clk_domain.voltage_domain.voltage 1 # Voltage in Volts (Volt)
system.cpu.numCycles 185572117 # Number of cpu cycles simulated (Cycle)
system.cpu.cpi 1.105718 # CPI: cycles per instruction (core level) ((Cycle/Count))
system.cpu.ipc 0.904390 # IPC: instructions per cycle (core level) ((Count/Cycle))
system.cpu.numWorkItemsStarted 0 # Number of work items this cpu started (Count)
system.cpu.numWorkItemsCompleted 0 # Number of work items this cpu completed (Count)
system.cpu.issuedInstType_0::No_OpClass 0 0.00% 0.00% # Number of instructions issued per FU type, per thread (Count)
system.cpu.issuedInstType_0::IntAlu 134266915 80.00% 80.00% # Number of instructions issued per FU type, per thread (Count)
system.cpu.issuedInstType_0::IntMult 0 0.00% 80.00% # Number of instructions issued per FU type, per thread (Count)

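I think this is because BNRV needs far fewer instructions to compute the same kernel: 167,829,523 instead of 503,373,855, about 3x fewer. The baseline's 67,108,864 integer multiplies (exactly 8192², which looks like one per weight) also disappear entirely. The cycle count only drops about 2.1x, from 390,366,204 to 185,572,117. The speedup therefore comes from the lower instruction count rather than from a higher IPC.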
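The numbers fit the iron law of performance: time = instructions × CPI × cycle time. The sketch below copies simTicks, numCycles and simInsts from the two dumps above and checks that arithmetic. The struct and function names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Values copied from the two gem5 stats dumps (Normal and BNRV). */
struct run_stats {
    uint64_t sim_ticks;   /* simTicks */
    uint64_t num_cycles;  /* system.cpu.numCycles */
    uint64_t sim_insts;   /* simInsts */
};

static const struct run_stats normal_run = {780732408000ULL, 390366204ULL, 503373855ULL};
static const struct run_stats bnrv_run = {371144234000ULL, 185572117ULL, 167829523ULL};

static const uint64_t ticks_per_cycle = 2000;  /* system.clk_domain.clock */

static double cpi(const struct run_stats *r) {
    return (double)r->num_cycles / (double)r->sim_insts;
}

static double speedup(const struct run_stats *base, const struct run_stats *opt) {
    return (double)base->sim_ticks / (double)opt->sim_ticks;
}

static double inst_ratio(const struct run_stats *base, const struct run_stats *opt) {
    return (double)base->sim_insts / (double)opt->sim_insts;
}

/* Iron law: ticks = insts * CPI * ticks_per_cycle, so the speedup should
 * equal inst_ratio * (CPI_base / CPI_opt). Returns the relative mismatch. */
static double iron_law_error(const struct run_stats *base, const struct run_stats *opt) {
    double predicted = inst_ratio(base, opt) * (cpi(base) / cpi(opt));
    double diff = predicted - speedup(base, opt);
    return (diff < 0 ? -diff : diff) / speedup(base, opt);
}

static void check_runs(void) {
    assert(normal_run.num_cycles * ticks_per_cycle == normal_run.sim_ticks);
    assert(bnrv_run.num_cycles * ticks_per_cycle == bnrv_run.sim_ticks);
    assert(iron_law_error(&normal_run, &bnrv_run) < 1e-9);
}
```

The check confirms three things:
- simTicks is exactly numCycles × 2000 for both runs.
- The speedup is about 2.10x.
- That speedup equals the roughly 3.0x reduction in instructions multiplied by the CPI ratio (0.7755 / 1.1057 ≈ 0.70).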

I also noticed that BNRV accesses the data cache far less often than the normal BitNet kernel:

Normal:
system.cpu.dcache.demandHits::cpu.data 150755193 # number of demand (read+write) hits (Count)
system.cpu.dcache.demandHits::total 150755193 # number of demand (read+write) hits (Count)
system.cpu.dcache.overallHits::cpu.data 150755193 # number of overall hits (Count)
system.cpu.dcache.overallHits::total 150755193 # number of overall hits (Count)
system.cpu.dcache.demandMisses::cpu.data 264337 # number of demand (read+write) misses (Count)
system.cpu.dcache.demandMisses::total 264337 # number of demand (read+write) misses (Count)
system.cpu.dcache.overallMisses::cpu.data 264337 # number of overall misses (Count)
system.cpu.dcache.overallMisses::total 264337 # number of overall misses (Count)
system.cpu.dcache.demandMissLatency::cpu.data 42365002000 # number of demand (read+
BNRV:
system.cpu.dcache.demandHits::cpu.data 33299587 # number of demand (read+write) hits (Count)
system.cpu.dcache.demandHits::total 33299587 # number of demand (read+write) hits (Count)
system.cpu.dcache.overallHits::cpu.data 33299587 # number of overall hits (Count)
system.cpu.dcache.overallHits::total 33299587 # number of overall hits (Count)
system.cpu.dcache.demandMisses::cpu.data 263037 # number of demand (read+write) misses (Count)
system.cpu.dcache.demandMisses::total 263037 # number of demand (read+write) misses (Count)
system.cpu.dcache.overallMisses::cpu.data 263037 # number of overall misses (Count)
system.cpu.dcache.overallMisses::total 263037 # number of overall misses (Count)
system.cpu.dcache.demandMissLatency::cpu.data 44365840000 # number of demand (rea

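Data cache hits drop from 150,755,193 to 33,299,587 (about 4.5x fewer), while misses stay almost the same (264,337 versus 263,037). I think this means the CPU no longer needs to store and reload intermediate results, which should also lower energy consumption.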
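The BNRV cache numbers can be cross-checked by counting memory operations in bitnet_matmul_kernel. Assuming row_sum and the loop variables stay in registers, each inner step performs one weight-byte load and one 4-byte activation load, and each row stores one output. The following sketch computes that count. It also computes the misses expected from streaming the 16 MiB weight matrix once, assuming 64-byte cache lines. All helper names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Expected data-side accesses of bitnet_matmul_kernel (BNRV version),
 * assuming loop variables and row_sum live in registers:
 * per inner step, one 1-byte weight load + one 4-byte activation load;
 * N/4 steps per row; N rows; one 4-byte output store per row. */
static uint64_t bnrv_dcache_accesses(uint64_t n) {
    return 2 * n * (n / 4) + n;
}

/* Misses from streaming the packed weights (N*N/4 bytes) exactly once,
 * assuming no reuse of weight lines and line_bytes-sized cache lines. */
static uint64_t weight_stream_misses(uint64_t n, uint64_t line_bytes) {
    return (n * n / 4) / line_bytes;
}

static double miss_rate(uint64_t hits, uint64_t misses) {
    return (double)misses / (double)(hits + misses);
}

static void check_bnrv_counts(void) {
    const uint64_t hits = 33299587, misses = 263037;  /* BNRV dcache stats */
    assert(bnrv_dcache_accesses(8192) == hits + misses);
}
```

For N = 8192, the access count is 2 × 8192 × 2048 + 8192 = 33,562,624. That is exactly BNRV's hits plus misses (33,299,587 + 263,037). Streaming the weights gives 262,144 line misses, close to the roughly 263k misses of both runs. This suggests that most misses come from the weight matrix both kernels read, and that BNRV removes cache hits rather than misses. Because misses stay constant while hits fall, the miss rate rises from about 0.18% to about 0.78%.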