Quantized neural network (QNN) inference is a special workload on modern CPUs. On resource-limited devices, quantization reduces energy consumption, for example the cost of moving data from memory to the ALU, by compressing values from 32 bits to 8 bits. It also reduces chip area by simplifying the multiplication circuits, and it increases throughput.
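In a regular linear layer, the computation between the weight $W \in \mathbb{R}^{d \times m}$ and the input $x \in \mathbb{R}^{1 \times m}$ can be expressed as a matrix multiplication (MatMul):

$$y = xW^{\top},$$

where $y \in \mathbb{R}^{1 \times d}$. The $i$-th element of $y$ is computed as:

$$y_i = \sum_{j=1}^{m} x_j W_{ij}.$$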
When the MatMul operates on floating-point data, the weights can be compressed to a very low bit width, for example the ternary set {−1, 0, 1}, which substantially reduces the cost of the computation. This is the idea behind BitNet (the ternary variant is known as BitNet b1.58). With ternary weights, every product $x_j W_{ij}$ is $x_j$, $-x_j$ or $0$, so BitNet removes the need for multiplication circuits altogether: the MatMul reduces to additions and subtractions.
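The claim that ternary weights remove multiplication can be made concrete with a minimal C sketch. It assumes int8 activations (as in the BitNet pipeline described below) and ternary weights stored one per `int8_t`; `dot_mul` and `dot_ternary` are hypothetical names. The second function computes the same dot product as the first, but replaces each product with an addition, a subtraction, or nothing.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Reference dot product: one multiplication per element. */
int32_t dot_mul(const int8_t *x, const int8_t *w, size_t m) {
    int32_t acc = 0;
    for (size_t j = 0; j < m; j++)
        acc += (int32_t)x[j] * (int32_t)w[j];
    return acc;
}

/* Ternary dot product: each w[j] must be -1, 0 or +1, so each term is
   +x[j], -x[j] or nothing. Only additions and subtractions are used. */
int32_t dot_ternary(const int8_t *x, const int8_t *w, size_t m) {
    int32_t acc = 0;
    for (size_t j = 0; j < m; j++) {
        assert(w[j] >= -1 && w[j] <= 1);
        if (w[j] > 0)
            acc += x[j];
        else if (w[j] < 0)
            acc -= x[j];
    }
    return acc;
}
```

For example, with `x = {10, -3, 7, 127}` and `w = {1, -1, 0, -1}`, both functions return 10 + 3 − 127 = −114.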
Modern ISAs, RISC-V in particular, provide various forms of support for QNNs, such as SIMD extensions for common neural-network computations.
Suppose a BitNet layer computes with 32-bit floating-point (FP32) inputs and 2-bit weights, producing FP32 outputs. The FP32 activations are first normalized with root mean square normalization (RMSNorm) and then quantized to 8-bit integers according to their absolute maximum value. After the integer MatMul, the results are dequantized back to FP32 by multiplying them by the scaling factors of the activations and the weights. Throughout this procedure, the CPU must unpack the weights, which are stored as 2-bit values, into 32-bit or wider registers and execute additional instructions to align them; otherwise it cannot carry out the MatMul, which ternary quantization has already reduced to simple additions. This unpacking step causes most of the performance loss, because it adds extra instructions for every weight.
A ternary weight carries log₂3 ≈ 1.58 bits of information, and storage comes in whole bits, so each weight is stored in 2 bits and one byte holds four weights. However, unpacking a byte into four weights incurs significant performance overhead. To measure it, I simulate the BitNet kernel on RISC-V:
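```c
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <gem5/m5ops.h> /* m5_reset_stats, m5_dump_stats (link with libm5) */

/* Matrix dimension (N x N weights, N activations). The value used for the
   reported run is not shown; 2048 is an assumption. */
#ifndef N
#define N 2048
#endif

/* Decoding table: 2-bit code -> weight value. Codes 0..2 cover the ternary
   set; code 3 decodes to -2. BITNET_STEP below implements this mapping. */
static const int8_t SCALE_LUT[4] = {0, 1, -1, -2};

/* Hypothetical reconstruction: the original BITNET_STEP definition is not
   shown. It extracts the 2-bit code at bit offset `shift` of weight_byte and
   accumulates act_ptr[idx] using additions and subtractions only. */
#define BITNET_STEP(shift, idx)                                  \
    switch ((weight_byte >> (shift)) & 0x03) {                   \
    case 1: sum += act_ptr[idx]; break;                /* +1 */  \
    case 2: sum -= act_ptr[idx]; break;                /* -1 */  \
    case 3: sum -= act_ptr[idx] + act_ptr[idx]; break; /* -2 */  \
    default: break;                                    /*  0 */  \
    }

/* Pack four 2-bit codes into one byte; w0 occupies the lowest two bits. */
uint8_t packed_fourweights(uint8_t w0, uint8_t w1, uint8_t w2, uint8_t w3) {
    uint8_t bytes = 0;
    bytes |= (uint8_t)((w0 & 0x03) << 0);
    bytes |= (uint8_t)((w1 & 0x03) << 2);
    bytes |= (uint8_t)((w2 & 0x03) << 4);
    bytes |= (uint8_t)((w3 & 0x03) << 6);
    return bytes;
}

/* RMSNorm without a learnable gain: output = input / sqrt(mean(input^2) + eps). */
void rms_norm(const float* input, float* output, int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) sum += input[i] * input[i];
    float rms = sqrtf(sum / (float)n + 1e-6f);
    for (int i = 0; i < n; i++) {
        output[i] = input[i] / rms;
    }
}

/* Absmax quantization to int8: the largest |input| maps to 127. */
void int_quant(const float* input, int8_t* output) {
    float max_abs = 0.0f;
    for (int i = 0; i < N; i++) {
        float val = fabsf(input[i]);
        if (val > max_abs) max_abs = val;
    }
    float scale = (max_abs > 1e-6f) ? (127.0f / max_abs) : 0.0f;
    for (int i = 0; i < N; i++) {
        output[i] = (int8_t)roundf(input[i] * scale);
    }
}

/* BitNet MatMul: each row of N weights occupies N/4 packed bytes. */
void bitnet_matmul_kernel(const int8_t* activations, const uint8_t* packed_weights, int32_t* output) {
    for (int i = 0; i < N; i++) {
        int32_t sum = 0;
        const uint8_t* row_weights = &packed_weights[i * (N / 4)];
        for (int j = 0; j < N / 4; j++) {
            uint8_t weight_byte = row_weights[j];
            const int8_t* act_ptr = &activations[j * 4];
            /* Unpack and accumulate the four 2-bit weights of this byte. */
            BITNET_STEP(0, 0)
            BITNET_STEP(2, 1)
            BITNET_STEP(4, 2)
            BITNET_STEP(6, 3)
        }
        output[i] = sum;
    }
}

int main(void) {
    size_t total_weights_bytes = ((size_t)N * N) / 4;
    uint8_t* weights = malloc(total_weights_bytes);
    int8_t* input_i = malloc(N * sizeof(int8_t));
    float* input = malloc(N * sizeof(float));
    float* input_n = malloc(N * sizeof(float));
    int32_t* out = malloc(N * sizeof(int32_t));
    if (!weights || !input_i || !input || !input_n || !out) return 1;

    /* Synthetic weights: every byte holds the codes (1, 3, 0, 2). */
    for (size_t i = 0; i < total_weights_bytes; i++) {
        weights[i] = packed_fourweights(1, 3, 0, 2);
    }
    /* Constant FP32 input, normalized and quantized outside the measured region. */
    for (int i = 0; i < N; i++) {
        input[i] = 0.25f;
    }
    rms_norm(input, input_n, N);
    int_quant(input_n, input_i);

    printf("Starting BitNet Kernel Simulation...\n");
    m5_reset_stats(0, 0);   /* start of the measured region */
    bitnet_matmul_kernel(input_i, weights, out);
    m5_dump_stats(0, 0);    /* dump statistics for the kernel region */

    /* The checksum keeps the output live so the kernel is not optimized away. */
    int64_t final_checksum = 0;
    for (int i = 0; i < N; i++) {
        final_checksum += out[i];
    }
    printf("Output Checksum: %ld\n", (long)final_checksum);

    free(weights);
    free(input_i);
    free(input);
    free(input_n);
    free(out);
    return 0;
}
```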
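The 2-bit storage format can be checked on the host without gem5. The sketch below uses hypothetical helper names and copies the `SCALE_LUT` mapping. It packs the pattern `(1, 3, 0, 2)` that `main()` writes into every byte, which decodes to the weights +1, −2, 0 and −1, and compares a multiply-based reference with an add/subtract-only accumulation. Every weight costs a shift and a mask before it can be used, which is the kind of unpacking work the kernel repeats for each of its N²/4 bytes.

```c
#include <assert.h>
#include <stdint.h>

/* Same 2-bit code -> weight mapping as SCALE_LUT in the kernel listing. */
static const int8_t DECODE_LUT[4] = {0, 1, -1, -2};

/* Pack four 2-bit codes into one byte; code k sits in bits [2k+1:2k]. */
uint8_t pack4(uint8_t c0, uint8_t c1, uint8_t c2, uint8_t c3) {
    return (uint8_t)((c0 & 0x03) | ((c1 & 0x03) << 2) |
                     ((c2 & 0x03) << 4) | ((c3 & 0x03) << 6));
}

/* Extract code k (0..3): one shift and one mask per weight. */
uint8_t unpack_code(uint8_t byte, int k) {
    assert(k >= 0 && k < 4);
    return (uint8_t)((byte >> (2 * k)) & 0x03);
}

/* Multiply-based reference for one packed byte and four activations. */
int32_t byte_dot_ref(uint8_t byte, const int8_t act[4]) {
    int32_t acc = 0;
    for (int k = 0; k < 4; k++)
        acc += DECODE_LUT[unpack_code(byte, k)] * act[k];
    return acc;
}

/* The same computation using only additions and subtractions. */
int32_t byte_dot_addsub(uint8_t byte, const int8_t act[4]) {
    int32_t acc = 0;
    for (int k = 0; k < 4; k++) {
        switch (unpack_code(byte, k)) {
        case 1: acc += act[k]; break;           /* +1 */
        case 2: acc -= act[k]; break;           /* -1 */
        case 3: acc -= act[k] + act[k]; break;  /* -2 */
        default: break;                         /*  0 */
        }
    }
    return acc;
}
```

With activations `{10, 20, 30, 40}`, the byte `pack4(1, 3, 0, 2)` (0x8D) gives 10 − 40 + 0 − 40 = −70 in both functions.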
The simulation runs in gem5's syscall-emulation mode on an in-order `RiscvMinorCPU` clocked at 500 MHz, with a 32 kB L1 instruction cache, a 64 kB L1 data cache, a 256 kB L2 cache and DDR3 memory:

```python
import m5
from m5.objects import *

# Private L1 cache; the size is set per instance below.
class L1Cache(Cache):
    assoc = 2
    tag_latency = 2
    data_latency = 2
    response_latency = 2
    mshrs = 4
    tgts_per_mshr = 20

system = System()
system.clk_domain = SrcClockDomain()
system.clk_domain.clock = '500MHz'
system.clk_domain.voltage_domain = VoltageDomain()
system.mem_mode = 'timing'
system.mem_ranges = [AddrRange('8192MB')]
system.cpu = RiscvMinorCPU()

# the main bus
system.membus = SystemXBar()

# L1 caches
system.cpu.icache = L1Cache(size='32kB')
system.cpu.dcache = L1Cache(size='64kB')
system.cpu.icache.cpu_side = system.cpu.icache_port
system.cpu.dcache.cpu_side = system.cpu.dcache_port

# L2 bus and L2 cache
system.l2bus = L2XBar()
system.cpu.icache.mem_side = system.l2bus.cpu_side_ports
system.cpu.dcache.mem_side = system.l2bus.cpu_side_ports
system.l2cache = Cache(size='256kB', assoc=8, tag_latency=20, data_latency=20,
                       response_latency=20, mshrs=20, tgts_per_mshr=12)
system.l2cache.cpu_side = system.l2bus.mem_side_ports
system.l2cache.mem_side = system.membus.cpu_side_ports

# interrupt controller
system.cpu.createInterruptController()
system.system_port = system.membus.cpu_side_ports

# DDR3 memory controller
system.mem_ctrl = MemCtrl()
system.mem_ctrl.dram = DDR3_2133_8x8()
system.mem_ctrl.dram.range = system.mem_ranges[0]
system.mem_ctrl.port = system.membus.mem_side_ports

# RISC-V binary of the BitNet kernel
binary = 'BitNet'

# for gem5 V21 and beyond
system.workload = SEWorkload.init_compatible(binary)
process = Process()
process.cmd = [binary]
system.cpu.workload = process
system.cpu.createThreads()

root = Root(full_system=False, system=system)
m5.instantiate()

print("Beginning simulation!")
exit_event = m5.simulate()
print('Exiting @ tick {} because {}'.format(m5.curTick(), exit_event.getCause()))
```
After running the kernel on this CPU, I found that the number of IntMult operations is 0, meaning that the MatMul has been reduced to additions. IntAlu operations account for the largest share of issued instructions, 26,272,675 or 61.12% of the total:
```text
---------- Begin Simulation Statistics ----------
simSeconds 0.095827 # Number of seconds simulated (Second)
simTicks 95827216000 # Number of ticks simulated (Tick)
finalTick 132332018000 # Number of ticks from beginning of simulation (restored from checkpoints and never reset) (Tick)
simFreq 1000000000000 # The number of ticks per simulated second ((Tick/Second))
hostSeconds 47.36 # Real time elapsed on the host (Second)
hostTickRate 2023284131 # The number of ticks simulated per host second (ticks/s) ((Tick/Second))
hostMemory 8529176 # Number of bytes of host memory used (Byte)
simInsts 42980208 # Number of instructions simulated (Count)
simOps 42980208 # Number of ops (including micro ops) simulated (Count)
hostInstRate 907478 # Simulator instruction rate (inst/s) ((Count/Second))
hostOpRate 907477 # Simulator op (including micro ops) rate (op/s) ((Count/Second))
system.clk_domain.clock 2000 # Clock period in ticks (Tick)
system.clk_domain.voltage_domain.voltage 1 # Voltage in Volts (Volt)
system.cpu.numCycles 47913608 # Number of cpu cycles simulated (Cycle)
system.cpu.cpi 1.114783 # CPI: cycles per instruction (core level) ((Cycle/Count))
system.cpu.ipc 0.897036 # IPC: instructions per cycle (core level) ((Count/Cycle))
system.cpu.numWorkItemsStarted 0 # Number of work items this cpu started (Count)
system.cpu.numWorkItemsCompleted 0 # Number of work items this cpu completed (Count)
system.cpu.issuedInstType_0::No_OpClass 0 0.00% 0.00% # Number of instructions issued per FU type, per thread (Count)
system.cpu.issuedInstType_0::IntAlu 26272675 61.12% 61.12% # Number of instructions issued per FU type, per thread (Count)
```
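For comparison, the INT8 baseline stores one weight per byte and computes the MatMul with ordinary multiply-accumulate operations: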
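```c
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <gem5/m5ops.h> /* m5_reset_stats, m5_dump_stats (link with libm5) */

/* Same (assumed) matrix dimension as the BitNet kernel. */
#ifndef N
#define N 2048
#endif

/* Reference INT8 MatMul: one multiply-accumulate per weight. */
void int8_matmul_kernel(const int8_t* activations, const int8_t* weights, int32_t* output) {
    for (int i = 0; i < N; i++) {
        int32_t sum = 0;
        const int8_t* row_weights = &weights[i * N];
        for (int j = 0; j < N; j++) {
            sum += (int32_t)activations[j] * (int32_t)row_weights[j];
        }
        output[i] = sum;
    }
}

int main(void) {
    int8_t* weights = malloc((size_t)N * N * sizeof(int8_t));
    int8_t* input_i = malloc(N * sizeof(int8_t));
    int32_t* out = malloc(N * sizeof(int32_t));
    if (!weights || !input_i || !out) return 1;

    /* Random ternary weights (one per byte) and random int8 activations. */
    for (int i = 0; i < N * N; i++) weights[i] = (int8_t)(rand() % 3 - 1);
    for (int i = 0; i < N; i++) input_i[i] = (int8_t)(rand() % 256 - 128);

    /* Print before resetting stats so that, as in the BitNet run,
       the I/O stays outside the measured region. */
    printf("Starting INT8 Baseline Simulation...\n");
    m5_reset_stats(0, 0);
    int8_matmul_kernel(input_i, weights, out);
    m5_dump_stats(0, 0);

    /* The checksum keeps the output live so the kernel is not optimized away. */
    int64_t final_checksum = 0;
    for (int i = 0; i < N; i++) final_checksum += out[i];
    printf("Output Checksum: %ld\n", (long)final_checksum);

    free(weights);
    free(input_i);
    free(out);
    return 0;
}
```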
The INT8 baseline produces the following statistics:

```text
---------- Begin Simulation Statistics ----------
simSeconds 0.055076 # Number of seconds simulated (Second)
simTicks 55075592000 # Number of ticks simulated (Tick)
finalTick 281435572000 # Number of ticks from beginning of simulation (restored from checkpoints and never reset) (Tick)
simFreq 1000000000000 # The number of ticks per simulated second ((Tick/Second))
hostSeconds 26.91 # Real time elapsed on the host (Second)
hostTickRate 2046673084 # The number of ticks simulated per host second (ticks/s) ((Tick/Second))
hostMemory 8529176 # Number of bytes of host memory used (Byte)
simInsts 22050565 # Number of instructions simulated (Count)
simOps 22050565 # Number of ops (including micro ops) simulated (Count)
hostInstRate 819422 # Simulator instruction rate (inst/s) ((Count/Second))
hostOpRate 819422 # Simulator op (including micro ops) rate (op/s) ((Count/Second))
system.clk_domain.clock 2000 # Clock period in ticks (Tick)
system.clk_domain.voltage_domain.voltage 1 # Voltage in Volts (Volt)
system.cpu.numCycles 27537796 # Number of cpu cycles simulated (Cycle)
system.cpu.cpi 1.248848 # CPI: cycles per instruction (core level) ((Cycle/Count))
system.cpu.ipc 0.800738 # IPC: instructions per cycle (core level) ((Count/Cycle))
```
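These statistics show that the BitNet kernel consists mostly of shift and add operations. Even so, it executes 42,980,208 instructions in 47,913,608 cycles, compared with 22,050,565 instructions in 27,537,796 cycles for the INT8 baseline: about 1.95× the instructions and 1.74× the cycles, although the baseline performs real multiplications. This supports the following statement from the BNRV paper [1]: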
“Despite the easy-to-compute nature of the BitNet MatMul, the operation becomes the bottleneck on regular CPUs due to the lack of native support for low-bitwidth operations, which motivates us to extend the ISA and unleash the full potential of BitNet-based models.”
“The official RISC-V Packed SIMD Extension provides general-purpose SIMD operations primarily for 8-bit and 16-bit data types, including addition, subtraction, and multiplication. However, its architecture is not optimized for the ultra-low-bitwidth weights characteristic of BitNet models.” [1]
As a result, a RISC-V processor must first widen the low-bitwidth weights of BitNet models to a larger bit width before computing with them, which leads to suboptimal performance.
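To make the widening cost concrete, here is a scalar C sketch of expanding one packed byte into four signed 8-bit lanes of a 32-bit word, the element width that 8-bit packed-SIMD additions operate on. The helper name is hypothetical, and this is not the instruction sequence of any particular extension. It assumes the ternary codes 00 → 0, 01 → +1 and 10 → −1 (matching `SCALE_LUT` for these three codes) and treats 11 as invalid.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: widen four 2-bit ternary codes into four int8 lanes
   of a 32-bit word (lane k holds weight k, lowest lane first). */
uint32_t widen_2bit_to_int8x4(uint8_t packed) {
    uint32_t lanes = 0;
    for (int k = 0; k < 4; k++) {
        uint8_t code = (uint8_t)((packed >> (2 * k)) & 0x03);    /* shift + mask */
        assert(code != 0x03);                                     /* unused code */
        int8_t w = (code == 0x01) ? 1 : (code == 0x02) ? -1 : 0;  /* decode */
        lanes |= (uint32_t)(uint8_t)w << (8 * k);                 /* insert lane */
    }
    return lanes;
}
```

For example, the byte 0x49 (codes 01, 10, 00, 01 from the lowest bits up) widens to 0x0100FF01, i.e. lanes +1, −1, 0, +1. Each byte needs four shift, mask, decode and insert steps before a single 8-bit SIMD addition can consume it; this is the overhead that motivates extending the ISA.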
Other extensions, such as XPulpNN and MPIC, also lose performance on BitNet workloads. XPulpNN's computational paradigm is not aligned with BitNet, which combines 8-bit activations with weights of 1 to 2 bits, leaving a gap in hardware acceleration for such mixed-precision operations. MPIC supports weights only down to 2 bits, so binary and ternary weights are neglected.
[1] Z. Jiang and Y. Lyu, “BNRV: A Lightweight SIMD Extension for Efficient BitNet Inference on RISC-V CPUs,” in 2025 IEEE 43rd International Conference on Computer Design (ICCD), 2025. DOI: 10.1109/ICCD65941.2025.00100.