It’s time to implement the architecture described in this thesis. First, I will implement the circuit of this extension according to the paper. This is not too hard:
```verilog
module BNRV_4Mux(
    input  signed [7:0] x,
    input         [1:0] w,
    output signed [9:0] o
);
    reg  signed [9:0] t_o;
    // Sign-extend the 8-bit activation to 10 bits so that -x and -2x fit.
    wire signed [9:0] x_ext = { {2{x[7]}}, x };

    always @(*) begin
        case (w)
            2'b00:   t_o = 10'sd0;          // code 00 -> 0
            2'b01:   t_o = x_ext;           // code 01 -> +x
            2'b10:   t_o = -x_ext;          // code 10 -> -x
            2'b11:   t_o = -(x_ext << 1);   // code 11 -> -2x
            default: t_o = 10'sd0;
        endcase
    end

    assign o = t_o;
endmodule
```
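Before wiring the module into gem5, it helps to have a bit-exact software model to compare against. Here is a minimal C sketch of what BNRV_4Mux computes; the function name bnrv_4mux_ref is my own (hypothetical), and it simply mirrors the case statement above, including the 10-bit sign extension.

```c
#include <stdint.h>

/* Software model of BNRV_4Mux: the 2-bit code w selects 0, +x, -x or -2x.
 * Every result fits in the module's 10-bit signed output. */
int16_t bnrv_4mux_ref(int8_t x, unsigned w) {
    int16_t x_ext = x; /* sign extension, like {{2{x[7]}}, x} */
    switch (w & 0x3u) {
    case 0x1: return x_ext;
    case 0x2: return (int16_t)-x_ext;
    case 0x3: return (int16_t)-(x_ext * 2);
    default:  return 0;
    }
}
```

The extreme case x = -128 with code 11 gives 256, which does not fit in 9 signed bits (-256..255); that is why the output port is 10 bits wide.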
```verilog
module BNRV_SUM4 (
```
To avoid bit truncation on large matrices (a bug that confused me for a long time), I widened the intermediate bit widths by one or two bits.
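To make the width argument concrete, here is a small C sketch that counts how many signed bits a sum needs and emulates what a too-narrow register does to it. Both helper names (min_signed_bits, wrap_signed) are hypothetical. Each 4-mux output lies in [-256, 256], so one output needs 10 bits, the sum of four needs 12, and accumulating a full N = 8192 row would need 23.

```c
#include <stdint.h>

/* Smallest two's-complement width whose range covers [-max_abs, max_abs]. */
int min_signed_bits(int64_t max_abs) {
    int bits = 1;
    while (((INT64_C(1) << (bits - 1)) - 1) < max_abs) {
        bits++;
    }
    return bits;
}

/* Emulate storing v in a `bits`-wide signed register (1 <= bits <= 63):
 * keep the low bits, then sign-extend from the highest kept bit. */
int64_t wrap_signed(int64_t v, int bits) {
    uint64_t mask = (UINT64_C(1) << bits) - 1;
    uint64_t sign = UINT64_C(1) << (bits - 1);
    uint64_t u = (uint64_t)v & mask;
    return (int64_t)((u ^ sign) - sign);
}
```

For example, four outputs of +256 add up to 1024: stored in 10 bits this silently becomes 0, in 11 bits it becomes -1024, and only at 12 bits does it survive intact.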
We need to compile it with Verilator. This generates the C++ model files in obj_dir:

```sh
verilator -cc -j 0 -build -Wall infiles...
```
Then copy obj_dir into gem5/src/arch/riscv and rename it to BNRV. Caution: after copying, we also need to delete the library files such as *.a. Otherwise we will hit a multiple-definition link error when compiling gem5, because gem5 generates its own symbols from the sources we add to the SConscript, and these conflict with the ones already in the archives.
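Edit the SConscript to add the generated sources and some compile options. Do NOT add the *__ALL.cpp file if you have already added the other generated .cpp files, because *__ALL.cpp simply #includes all of them; alternatively, add ONLY *__ALL.cpp. Otherwise you will also run into a multiple-definition error when compiling.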
I also copy verilated.cpp and verilated_threads.cpp (from Verilator's include directory) into BNRV, since they are needed as well.
Then, let’s integrate the circuit into gem5. The ISA description files are all in src/arch/[architecture]/isa. For example, the x86 files (including x87) are in src/arch/x86/isa, and, of course, the RISC-V ISA files are in src/arch/riscv/isa.
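gem5 decodes instructions according to decoder.isa. According to the RISC-V manual, the lowest bits of an instruction encode its length: 16-bit compressed instructions have bits [1:0] other than 11, while 32-bit instructions have bits [1:0] = 11 and bits [4:2] other than 111. The manual continues: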
Standard instruction-set extensions encoded with more than 32 bits
have additional low-order bits set to 1, with the conventions for 48-bit
and 64-bit lengths shown in Table 1.
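The length rule can be sketched in C. rv_insn_length is a hypothetical helper that follows the base length-encoding scheme of the RISC-V unprivileged specification, looking only at the lowest 16-bit parcel; it returns 0 for the 80-bit-and-longer or reserved forms it does not handle.

```c
#include <stdint.h>

/* Instruction length in bits, decided from the lowest 16-bit parcel.
 * Returns 0 for >= 80-bit or reserved encodings. */
int rv_insn_length(uint16_t parcel) {
    if ((parcel & 0x03) != 0x03) return 16; /* bits [1:0] != 11: compressed */
    if ((parcel & 0x1c) != 0x1c) return 32; /* bits [4:2] != 111 */
    if ((parcel & 0x3f) == 0x1f) return 48; /* bits [5:0] == 011111 */
    if ((parcel & 0x7f) == 0x3f) return 64; /* bits [6:0] == 0111111 */
    return 0;
}
```

For the opcode 0001011 (0x0b) used below, bits [1:0] are 11 and bits [4:2] are 010, so it is a 32-bit instruction.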
In gem5, the decoder is organized as a tree that is then turned into nested C++ switch statements, so the top level of the tree decodes the lowest bits of the instruction. Our BNRV instruction uses the 32-bit encoding, so its implementation belongs under the 0x3 branch. I decided to encode bnsum4 with opcode 0001011 (the RISC-V custom-0 opcode) and funct3 000, so it is best to place the implementation under 0x3 and then OPCODE5 0x02, decoding funct3 0x0.

gem5 also lets users define their own instruction formats to customize how an operation is generated. For example, the FenceOp format is:
```
def format FenceOp(code, imm_type='int64_t', *opt_flags) {{
    regs = ['destRegIdx(0)', 'srcRegIdx(0)']
    iop = InstObjParams(name, Name, 'ImmOp<%s>' % imm_type,
        {'code': code, 'imm_code': 'imm = sext<12>(IMM12);',
         'regs': ','.join(regs)}, opt_flags)
    header_output = ImmDeclare.subst(iop)
    decoder_output = ImmConstructor.subst(iop)
    decode_block = BasicDecode.subst(iop)
    exec_output = FenceExecute.subst(iop)
}};
```

To use a format, we pass our implementation as code together with opt_flags, and gem5's ISA parser expands code into C++ statements. Because bnsum4 is a classic R-type instruction, we simply build our format on the default (Basic) templates:
```
// Format defined for BNRV
def format BNOP(code, *opt_flags) {{
    iop = InstObjParams(name, Name, 'RegOp', code, opt_flags)
    header_output = BasicDeclare.subst(iop)
    decoder_output = BasicConstructor.subst(iop)
    decode_block = BasicDecode.subst(iop)
    exec_output = BasicExecute.subst(iop)
}};
```

Similarly, includes.isa holds the include directives for the generated code, so we can put any headers we need there. These headers will be used by decoder.isa:
```cpp
#include "arch/riscv/BNRV/VBNRV_SUM4.h"
#include "arch/riscv/BNRV/VBNRV_SUM4___024root.h"
#include "arch/riscv/BNRV/sum4_wrapper.h"
```
Then we simply call invoke_bnsum4 in the decoder:
```
0x3: decode OPCODE5 {
```
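To see why bnsum4 lands under 0x3 and then OPCODE5 0x02, the encoding can be checked with a small C sketch. It assumes the standard R-type layout (funct7 | rs2 | rs1 | funct3 | rd | opcode) and a funct7 of 0 for bnsum4, which is not specified above; rtype_encode and insn_bits are hypothetical names. Bits [1:0] select the 0x3 branch and bits [6:2] give OPCODE5.

```c
#include <stdint.h>

/* Assemble a standard RISC-V R-type instruction word. */
uint32_t rtype_encode(uint32_t opcode, uint32_t rd, uint32_t funct3,
                      uint32_t rs1, uint32_t rs2, uint32_t funct7) {
    return ((funct7 & 0x7f) << 25) | ((rs2 & 0x1f) << 20) |
           ((rs1 & 0x1f) << 15) | ((funct3 & 0x7) << 12) |
           ((rd & 0x1f) << 7) | (opcode & 0x7f);
}

/* Extract bits [hi:lo] of an instruction word (requires hi - lo + 1 < 32). */
uint32_t insn_bits(uint32_t insn, int hi, int lo) {
    return (insn >> lo) & ((UINT32_C(1) << (hi - lo + 1)) - 1);
}
```

For example, a hypothetical "bnsum4 a0, a1, a2" (a0, a1, a2 are x10, x11, x12) encodes to 0x00C5850B: bits [1:0] are 0x3, OPCODE5 is 0x02 and funct3 is 0x0, matching the decode path chosen above.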
We need to write a kernel that uses the BNRV extension:
Do not use registers with predefined roles, such as the zero register or the stack pointer; otherwise you will encounter a virtual address error because you have corrupted the program's context.
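As a hedged stand-in for the macro, here is a portable C model of what one BITNET_STEP_HW call should compute, assuming BNRV_SUM4 adds the four BNRV_4Mux outputs into the accumulator. bitnet_step_ref and BITNET_STEP_SW are hypothetical names; a software version like this is also useful as a golden reference when checking the hardware path.

```c
#include <stdint.h>

/* Multiplier applied by BNRV_4Mux for each 2-bit weight code:
 * 00 -> 0, 01 -> +1, 10 -> -1, 11 -> -2. */
static const int32_t kMuxCoef[4] = {0, 1, -1, -2};

/* Dot product of four int8 activations with four 2-bit weights packed into
 * one byte (weight k in bits [2k+1:2k], the layout packed_fourweights uses). */
int32_t bitnet_step_ref(uint8_t packed, const int8_t *act) {
    int32_t sum = 0;
    for (int k = 0; k < 4; k++) {
        unsigned code = (packed >> (2 * k)) & 0x3u;
        sum += kMuxCoef[code] * (int32_t)act[k];
    }
    return sum;
}

/* Software drop-in with the same call shape as BITNET_STEP_HW in the kernel. */
#define BITNET_STEP_SW(packed, act_ptr, acc) \
    ((acc) += bitnet_step_ref((packed), (act_ptr)))
```

With the weight codes (1, 3, 0, 2) and four activations of 127, one step contributes 127 - 254 + 0 - 127 = -254.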
It’s time to write our kernel:
```c
#include <inttypes.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <gem5/m5ops.h>  // m5_reset_stats(), m5_dump_stats()

#define N 8192  // matrix dimension, as printed in the simulation log

// BITNET_STEP_HW(weight_byte, activation_ptr, acc) issues the BNRV
// instruction and must be defined before this point.

// Pack four 2-bit weight codes into one byte; w0 occupies bits [1:0].
uint8_t packed_fourweights(uint8_t w0, uint8_t w1, uint8_t w2, uint8_t w3) {
    uint8_t bytes = 0;
    bytes |= (w0 & 0x03) << 0;
    bytes |= (w1 & 0x03) << 2;
    bytes |= (w2 & 0x03) << 4;
    bytes |= (w3 & 0x03) << 6;
    return bytes;
}

// RMS normalization: output[i] = input[i] / sqrt(mean(input^2) + eps).
void rms_norm(const float* input, float* output, int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) sum += input[i] * input[i];
    float rms = sqrtf(sum / n + 1e-6f);
    for (int i = 0; i < n; i++) {
        output[i] = input[i] / rms;
    }
}

// Symmetric absmax quantization of N floats to int8 in [-127, 127].
void int_quant(const float* input, int8_t* output) {
    float max_abs = 0.0f;
    for (int i = 0; i < N; i++) {
        float val = fabsf(input[i]);
        if (val > max_abs) max_abs = val;
    }
    float scale = (max_abs > 1e-8f) ? (127.0f / max_abs) : 0.0f;
    for (int i = 0; i < N; i++) {
        output[i] = (int8_t)roundf(input[i] * scale);
    }
}

// out = W * x, with W stored as N rows of N/4 packed bytes (4 weights per byte).
void bitnet_matmul_kernel(int8_t* activations, uint8_t* packed_weights, int32_t* output) {
    for (int i = 0; i < N; i++) {
        int32_t row_sum = 0;
        uint8_t* row_weights = &packed_weights[(size_t)i * (N / 4)];
        for (int j = 0; j < N / 4; j++) {
            // One BNRV instruction: four weights times four activations.
            BITNET_STEP_HW(row_weights[j], &activations[j * 4], row_sum);
        }
        output[i] = row_sum;
    }
}

int main(void) {
    size_t total_weights_bytes = ((size_t)N * N) / 4;
    uint8_t* weights = (uint8_t*)malloc(total_weights_bytes);
    int8_t* input_i = (int8_t*)malloc(N * sizeof(int8_t));
    float* input = (float*)malloc(N * sizeof(float));
    float* input_n = (float*)malloc(N * sizeof(float));
    int32_t* out = (int32_t*)malloc(N * sizeof(int32_t));
    if (!weights || !input_i || !input || !input_n || !out) return 1;

    // Every byte holds the weight codes (1, 3, 0, 2).
    for (size_t i = 0; i < total_weights_bytes; i++) {
        weights[i] = packed_fourweights(1, 3, 0, 2);
    }
    for (int i = 0; i < N; i++) {
        input[i] = 0.25f;
    }
    rms_norm(input, input_n, N);
    int_quant(input_n, input_i);

    printf("Starting BitNet Hardware Kernel Simulation (N=%d)...\n", N);
    m5_reset_stats(0, 0);  // measure only the kernel
    bitnet_matmul_kernel(input_i, weights, out);
    m5_dump_stats(0, 0);

    int64_t final_checksum = 0;
    for (int i = 0; i < N; i++) {
        final_checksum += out[i];
    }
    printf("Output Checksum: %" PRId64 "\n", final_checksum);

    free(weights); free(input_i); free(input); free(input_n); free(out);
    return 0;
}
```
To verify that the optimized kernel produces the correct output, we compute a checksum of the output vector (the sum of all entries of out).
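The expected checksum can also be predicted by hand. A constant input of 0.25 normalizes to about 1.0 everywhere, so int_quant maps every activation to 127. Each weight byte holds the codes (1, 3, 0, 2), i.e. +x, -2x, 0 and -x, so one byte contributes 127 - 254 + 0 - 127 = -254; N/4 = 2,048 bytes per row give -520,192, and N = 8,192 rows give -4,261,412,864, exactly the checksum both runs print. A minimal sketch of that arithmetic (expected_checksum is a hypothetical helper):

```c
#include <stdint.h>

/* Expected checksum for the test in main(), assuming every quantized
 * activation equals `act` and every weight byte holds codes (1, 3, 0, 2),
 * which BNRV_4Mux maps to +x, -2x, 0 and -x. */
int64_t expected_checksum(int64_t n, int64_t act) {
    int64_t per_byte = act - 2 * act + 0 - act; /* one group of four weights */
    int64_t per_row = (n / 4) * per_byte;       /* n/4 packed bytes per row */
    return n * per_row;                         /* all n rows summed */
}
```

Note that the total does not fit in 32 bits, which is why final_checksum is an int64_t.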
After a few minutes, I got the awesome result. BNRV kernel:

```
Starting BitNet Hardware Kernel Simulation (N=8192)...
Output Checksum: -4261412864
Exiting @ tick 419616750000 because exiting with last active thread context
```

Normal kernel:

```
Starting BitNet Kernel Simulation...
Output Checksum: -4261412864
Exiting @ tick 829204726000 because exiting with last active thread context
```
You can see that the kernel accelerated by the BNRV hardware finishes much faster: the run ends at tick 419,616,750,000 instead of 829,204,726,000, roughly half as many ticks (about a 1.98× speedup).
I noticed that our BNRV kernel has a lower IPC (higher CPI), even though its simTicks and simSeconds are much lower.

Normal:

```
---------- Begin Simulation Statistics ----------
simSeconds 0.780732 # Number of seconds simulated (Second)
simTicks 780732408000 # Number of ticks simulated (Tick)
finalTick 829016574000 # Number of ticks from beginning of simulation (restored from checkpoints and never reset) (Tick)
simFreq 1000000000000 # The number of ticks per simulated second ((Tick/Second))
hostSeconds 454.10 # Real time elapsed on the host (Second)
hostTickRate 1719313855 # The number of ticks simulated per host second (ticks/s) ((Tick/Second))
hostMemory 8529416 # Number of bytes of host memory used (Byte)
simInsts 503373855 # Number of instructions simulated (Count)
simOps 503373855 # Number of ops (including micro ops) simulated (Count)
hostInstRate 1108520 # Simulator instruction rate (inst/s) ((Count/Second))
hostOpRate 1108520 # Simulator op (including micro ops) rate (op/s) ((Count/Second))
system.clk_domain.clock 2000 # Clock period in ticks (Tick)
system.clk_domain.voltage_domain.voltage 1 # Voltage in Volts (Volt)
system.cpu.numCycles 390366204 # Number of cpu cycles simulated (Cycle)
system.cpu.cpi 0.775500 # CPI: cycles per instruction (core level) ((Cycle/Count))
system.cpu.ipc 1.289491 # IPC: instructions per cycle (core level) ((Count/Cycle))
system.cpu.numWorkItemsStarted 0 # Number of work items this cpu started (Count)
system.cpu.numWorkItemsCompleted 0 # Number of work items this cpu completed (Count)
system.cpu.issuedInstType_0::No_OpClass 0 0.00% 0.00% # Number of instructions issued per FU type, per thread (Count)
system.cpu.issuedInstType_0::IntAlu 285261860 56.67% 56.67% # Number of instructions issued per FU type, per thread (Count)
system.cpu.issuedInstType_0::IntMult 67108864 13.33% 70.00% # Number of instructions issued per FU type, per thread (Count)
```
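I think this is because BNRV needs far fewer instructions for this kernel: one bnsum4 replaces four multiply-accumulate steps of the normal kernel, which issues 67,108,864 = 8192² IntMult operations (one per weight). So even at a higher CPI, the total time drops.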
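The apparent paradox follows from the usual performance equation, runtime = instructions × CPI × clock period. The normal run's statistics are consistent with it: simTicks / clock = 780,732,408,000 / 2,000 = 390,366,204 cycles, exactly numCycles, and 503,373,855 instructions over those cycles gives the reported IPC of about 1.2895. A kernel that retires far fewer instructions can therefore finish sooner even with a higher CPI. The struct and function names below are hypothetical; only the stat names come from gem5's stats.txt.

```c
#include <stdint.h>

/* A few values copied from gem5's stats.txt. */
typedef struct {
    uint64_t sim_ticks;   /* simTicks */
    uint64_t sim_insts;   /* simInsts */
    uint64_t clock_ticks; /* system.clk_domain.clock: ticks per CPU cycle */
    uint64_t sim_freq;    /* simFreq: ticks per simulated second */
} Gem5Stats;

uint64_t cycles_of(const Gem5Stats *s) { return s->sim_ticks / s->clock_ticks; }
double ipc_of(const Gem5Stats *s) { return (double)s->sim_insts / (double)cycles_of(s); }
double cpi_of(const Gem5Stats *s) { return (double)cycles_of(s) / (double)s->sim_insts; }
double seconds_of(const Gem5Stats *s) { return (double)s->sim_ticks / (double)s->sim_freq; }

/* Performance equation: runtime = instructions x CPI x clock period. */
double runtime_ticks(const Gem5Stats *s) {
    return (double)s->sim_insts * cpi_of(s) * (double)s->clock_ticks;
}
```

Filling the struct with the BNRV run's simTicks and simInsts gives the other side of the comparison.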
I also noticed that the BNRV kernel accesses the cache fewer times than the normal BitNet kernel.
This means the CPU does not need to store and reload intermediate results as often, which should also reduce energy consumption.