Skip to content

Commit 3dc019d

Browse files
authored
Merge pull request #55 from estodi/fix_thread_matmul_ext_onchip_ram
fix and update thread_matmul_ext_onchip_ram
2 parents 8d25bf8 + d8df881 commit 3dc019d

File tree

1 file changed: +89 additions, -72 deletions.

examples/thread_matmul_ext_onchip_ram/thread_matmul_ext_onchip_ram.py

Lines changed: 89 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import sys
44
import os
55
import numpy as np
6+
import math
67

78
# the next line can be removed after installation
89
sys.path.insert(0, os.path.dirname(os.path.dirname(
@@ -12,13 +13,23 @@
1213
import veriloggen.thread as vthread
1314
import veriloggen.types.axi as axi
1415

15-
datawidth = 8
16+
mem_datawidth = 8
17+
datawidth = 16
1618
addrwidth = 8
1719

18-
matrix_size = 8
20+
matrix_size = 10
21+
22+
num_pack = math.ceil(datawidth / mem_datawidth)
23+
addr_pack = math.ceil((addrwidth + math.ceil(np.log2(datawidth / mem_datawidth)))
24+
/ mem_datawidth)
25+
26+
matrix_size_addr = 0
27+
a_offset_addr = 4
28+
b_offset_addr = 8
29+
c_offset_addr = 12
1930
a_offset = 16
20-
b_offset = a_offset + matrix_size * matrix_size
21-
c_offset = b_offset + matrix_size * matrix_size
31+
b_offset = a_offset + matrix_size * matrix_size * num_pack
32+
c_offset = b_offset + matrix_size * matrix_size * num_pack
2233

2334

2435
def mkLed():
@@ -28,7 +39,8 @@ def mkLed():
2839
start = m.Input('start')
2940
busy = m.OutputReg('busy', initval=0)
3041

31-
ram = vthread.ExtRAM(m, 'ram', clk, rst, datawidth, addrwidth)
42+
ram = vthread.ExtRAM(m, 'ram', clk, rst, mem_datawidth,
43+
addrwidth + math.ceil(np.log2(datawidth / mem_datawidth)))
3244

3345
def matmul():
3446
while True:
@@ -46,27 +58,31 @@ def wait():
4658
busy.value = 1
4759

4860
def read_matrix_size():
49-
size0 = ram.read(0)
50-
size1 = ram.read(1)
51-
size = (size1 << 8) | size0
61+
size = 0
62+
for i in range(addr_pack):
63+
size |= ((ram.read(matrix_size_addr + i) & ((1 << mem_datawidth) - 1))
64+
<< (mem_datawidth * i))
5265
return size
5366

5467
def read_matrix_a_offset():
55-
offset0 = ram.read(4) & 0xff
56-
offset1 = ram.read(5) & 0xff
57-
offset = (offset1 << 8) | offset0
68+
offset = 0
69+
for i in range(addr_pack):
70+
offset |= ((ram.read(a_offset_addr + i) & ((1 << mem_datawidth) - 1))
71+
<< (mem_datawidth * i))
5872
return offset
5973

6074
def read_matrix_b_offset():
61-
offset0 = ram.read(8) & 0xff
62-
offset1 = ram.read(9) & 0xff
63-
offset = (offset1 << 8) | offset0
75+
offset = 0
76+
for i in range(addr_pack):
77+
offset |= ((ram.read(b_offset_addr + i) & ((1 << mem_datawidth) - 1))
78+
<< (mem_datawidth * i))
6479
return offset
6580

6681
def read_matrix_c_offset():
67-
offset0 = ram.read(12) & 0xff
68-
offset1 = ram.read(13) & 0xff
69-
offset = (offset1 << 8) | offset0
82+
offset = 0
83+
for i in range(addr_pack):
84+
offset |= ((ram.read(c_offset_addr + i) & ((1 << mem_datawidth) - 1))
85+
<< (mem_datawidth * i))
7086
return offset
7187

7288
def comp(matrix_size, a_offset, b_offset, c_offset):
@@ -77,15 +93,24 @@ def comp(matrix_size, a_offset, b_offset, c_offset):
7793
for j in range(matrix_size):
7894
sum = 0
7995
for k in range(matrix_size):
80-
x = ram.read(a_addr + k)
81-
y = ram.read(b_addr + k)
96+
x = int(0, base=2)
97+
y = 0
98+
for l in range(num_pack):
99+
x |= ((ram.read(a_addr + k * num_pack + l)
100+
& ((1 << mem_datawidth) - 1))
101+
<< (mem_datawidth * l))
102+
y |= ((ram.read(b_addr + k * num_pack + l)
103+
& ((1 << mem_datawidth) - 1))
104+
<< (mem_datawidth * l))
82105
sum += x * y
83-
ram.write(c_addr + j, sum)
106+
for l in range(num_pack):
107+
ram.write(c_addr + j * num_pack + l,
108+
(sum >> (mem_datawidth * l)) & (1<<mem_datawidth)-1)
84109

85-
b_addr += matrix_size * (datawidth // 8)
110+
b_addr += matrix_size * num_pack
86111

87-
a_addr += matrix_size * (datawidth // 8)
88-
c_addr += matrix_size * (datawidth // 8)
112+
a_addr += matrix_size * num_pack
113+
c_addr += matrix_size * num_pack
89114

90115
def done():
91116
busy.value = 0
@@ -128,13 +153,11 @@ def mkTest(memimg_name=None):
128153
b[y][x] = 0
129154

130155
a_addr = a_offset
131-
size_a = n_a * datawidth // 8
132156
b_addr = b_offset
133-
size_b = n_b * datawidth // 8
134157

135-
mem = np.zeros([2 ** addrwidth * (8 // datawidth)], dtype=np.int64)
136-
axi.set_memory(mem, a, datawidth, datawidth, a_addr)
137-
axi.set_memory(mem, b, datawidth, datawidth, b_addr)
158+
mem = np.zeros([(2 ** addrwidth) * num_pack], dtype=np.int64)
159+
axi.set_memory(mem, a, mem_datawidth, datawidth, a_addr)
160+
axi.set_memory(mem, b, mem_datawidth, datawidth, b_addr)
138161

139162
led = mkLed()
140163

@@ -149,7 +172,8 @@ def mkTest(memimg_name=None):
149172

150173
start.initval = 0
151174

152-
memory = vthread.RAM(m, 'memory', clk, rst, datawidth, addrwidth,
175+
memory = vthread.RAM(m, 'memory', clk, rst, mem_datawidth,
176+
addrwidth + math.ceil(np.log2(datawidth / mem_datawidth)),
153177
numports=2, initvals=mem.tolist())
154178
memory.connect_rtl(0, ports['ram_0_addr'], ports['ram_0_wdata'],
155179
ports['ram_0_wenable'], ports['ram_0_rdata'],
@@ -166,45 +190,33 @@ def ctrl():
166190
for i in range(100):
167191
pass
168192

169-
awaddr = 0
170-
v = (matrix_size & 0xff)
171-
print('# matrix_size[7:0] = %d' % v)
172-
memory.write(awaddr, v, port=1)
173-
174-
awaddr = 1
175-
v = ((matrix_size >> 8) & 0xff)
176-
print('# matrix_size[15:8] = %d' % v)
177-
memory.write(awaddr, v, port=1)
178-
179-
awaddr = 4
180-
v = (a_offset & 0xff)
181-
print('# a_offset[7:0] = %d' % v)
182-
memory.write(awaddr, v, port=1)
183-
184-
awaddr = 5
185-
v = ((a_offset >> 8) & 0xff)
186-
print('# a_offset[15:8] = %d' % v)
187-
memory.write(awaddr, v, port=1)
188-
189-
awaddr = 8
190-
v = (b_offset & 0xff)
191-
print('# b_offset[7:0] = %d' % v)
192-
memory.write(awaddr, v, port=1)
193-
194-
awaddr = 9
195-
v = ((b_offset >> 8) & 0xff)
196-
print('# b_offset[15:8] = %d' % v)
197-
memory.write(awaddr, v, port=1)
198-
199-
awaddr = 12
200-
v = (c_offset & 0xff)
201-
print('# c_offset[7:0] = %d' % v)
202-
memory.write(awaddr, v, port=1)
203-
204-
awaddr = 13
205-
v = ((c_offset >> 8) & 0xff)
206-
print('# c_offset[15:8] = %d' % v)
207-
memory.write(awaddr, v, port=1)
193+
for i in range(addr_pack):
194+
awaddr = matrix_size_addr + i
195+
v = (matrix_size >> (mem_datawidth * i)) & ((1 << mem_datawidth) - 1)
196+
print('# matrix_size[%d:%d] = %d' %
197+
(mem_datawidth * (i+1) - 1, mem_datawidth * i, v))
198+
memory.write(awaddr, v, port=1)
199+
200+
for i in range(addr_pack):
201+
awaddr = a_offset_addr + i
202+
v = (a_offset >> (mem_datawidth * i)) & ((1 << mem_datawidth) - 1)
203+
print('# a_offset[%d:%d] = %d' %
204+
(mem_datawidth * (i+1) - 1, mem_datawidth * i, v))
205+
memory.write(awaddr, v, port=1)
206+
207+
for i in range(addr_pack):
208+
awaddr = b_offset_addr + i
209+
v = (b_offset >> (mem_datawidth * i)) & ((1 << mem_datawidth) - 1)
210+
print('# b_offset[%d:%d] = %d' %
211+
(mem_datawidth * (i+1) - 1, mem_datawidth * i, v))
212+
memory.write(awaddr, v, port=1)
213+
214+
for i in range(addr_pack):
215+
awaddr = c_offset_addr + i
216+
v = (c_offset >> (mem_datawidth * i)) & ((1 << mem_datawidth) - 1)
217+
print('# c_offset[%d:%d] = %d' %
218+
(mem_datawidth * (i+1) - 1, mem_datawidth * i, v))
219+
memory.write(awaddr, v, port=1)
208220

209221
start_time = counter
210222
print('# start time = %d' % start_time)
@@ -227,14 +239,19 @@ def ctrl():
227239
all_ok = True
228240
for y in range(matrix_size):
229241
for x in range(matrix_size):
230-
v = memory.read(
231-
c_offset + (y * matrix_size + x) * datawidth // 8, port=1)
242+
v = 0
243+
v_addr = c_offset + (y * matrix_size + x) * num_pack
244+
for l in range(num_pack):
245+
v |= memory.read(v_addr + l, port=1) << (mem_datawidth * l)
246+
v |= ((memory.read(v_addr + l, port=1)
247+
& ((1 << mem_datawidth) - 1))
248+
<< (mem_datawidth * l))
232249
if y == x and vthread.verilog.NotEql(v, (y + 1) * 2):
233250
all_ok = False
234-
print("NG [%d,%d] = %d" % (y, x, v))
251+
print("NG [%d,%d] = %d (expected: %d)" % (y, x, v, (y + 1) * 2))
235252
if y != x and vthread.verilog.NotEql(v, 0):
236253
all_ok = False
237-
print("NG [%d,%d] = %d" % (y, x, v))
254+
print("NG [%d,%d] = %d (expected: %d)" % (y, x, v, 0))
238255

239256
if all_ok:
240257
print('# verify: PASSED')

0 commit comments

Comments
 (0)