@@ -14,10 +14,12 @@ def __init__(self, backend: str, n_cpu: int, block_size: int) -> None:
14
14
self .block_dim = block_size
15
15
ti .init (arch = getattr (ti , backend .split ("-" )[- 1 ]))
16
16
self .N = 0
17
+ self .fb : ti .FieldsBuilder
18
+ self .fbst : ti ._snode .snode_tree .SNodeTree
19
+ self .terr = ti .field (ti .f32 , (3 ,))
17
20
self .tA = ti .field (ti .i32 )
18
21
self .tB = ti .field (ti .f32 )
19
22
self .tX = ti .field (ti .f32 )
20
- self .terr = ti .field (ti .f32 , (3 ,))
21
23
self .tmp = ti .field (ti .f32 )
22
24
23
25
def partition (self , mask : np .ndarray ) -> np .ndarray :
@@ -29,10 +31,19 @@ def reset(self, N: int, A: np.ndarray, X: np.ndarray, B: np.ndarray) -> None:
29
31
self .A = A
30
32
self .B = B
31
33
self .X = X
32
- ti .root .dense (ti .ij , A .shape ).place (self .tA )
33
- ti .root .dense (ti .ij , B .shape ).place (self .tB )
34
- ti .root .dense (ti .ij , X .shape ).place (self .tX )
35
- ti .root .dense (ti .ij , X .shape ).place (self .tmp )
34
+ if hasattr (self , "fbst" ):
35
+ self .fbst .destroy ()
36
+ self .tA = ti .field (ti .i32 )
37
+ self .tB = ti .field (ti .f32 )
38
+ self .tX = ti .field (ti .f32 )
39
+ self .tmp = ti .field (ti .f32 )
40
+ self .fb = ti .FieldsBuilder ()
41
+ layout = self .fb .dense (ti .i , N )
42
+ layout .dense (ti .j , 4 ).place (self .tA )
43
+ layout .dense (ti .j , 3 ).place (self .tB )
44
+ layout .dense (ti .j , 3 ).place (self .tX )
45
+ layout .dense (ti .j , 3 ).place (self .tmp )
46
+ self .fbst = self .fb .finalize ()
36
47
self .tA .from_numpy (A )
37
48
self .tB .from_numpy (B )
38
49
self .tX .from_numpy (X )
@@ -112,11 +123,13 @@ def __init__(
112
123
self .parallelize = n_cpu
113
124
self .block_dim = block_size
114
125
ti .init (arch = getattr (ti , backend .split ("-" )[- 1 ]))
126
+ self .fb : ti .FieldsBuilder
127
+ self .fbst : ti ._snode .snode_tree .SNodeTree
128
+ self .terr = ti .field (ti .f32 , (3 ,))
115
129
self .tmask = ti .field (ti .i32 )
116
130
self .ttgt = ti .field (ti .f32 )
117
131
self .tgrad = ti .field (ti .f32 )
118
132
self .tmp = ti .field (ti .f32 )
119
- self .terr = ti .field (ti .f32 , (3 ,))
120
133
121
134
def reset (
122
135
self , N : int , mask : np .ndarray , tgt : np .ndarray , grad : np .ndarray
@@ -132,15 +145,23 @@ def reset(
132
145
133
146
self .N , self .M = N , M = mask .shape
134
147
bx , by = N // gx , M // gy
135
- layout = ti .root .dense (ti .ij , (bx , by )).dense (ti .ij , (gx , gy ))
136
148
self .mask = mask
137
149
self .tgt = tgt
138
150
self .grad = grad
139
151
152
+ if hasattr (self , "fbst" ):
153
+ self .fbst .destroy ()
154
+ self .tmask = ti .field (ti .i32 )
155
+ self .ttgt = ti .field (ti .f32 )
156
+ self .tgrad = ti .field (ti .f32 )
157
+ self .tmp = ti .field (ti .f32 )
158
+ self .fb = ti .FieldsBuilder ()
159
+ layout = self .fb .dense (ti .ij , (bx , by )).dense (ti .ij , (gx , gy ))
140
160
layout .place (self .tmask )
141
161
layout .dense (ti .k , 3 ).place (self .ttgt )
142
162
layout .dense (ti .k , 3 ).place (self .tgrad )
143
163
layout .dense (ti .k , 3 ).place (self .tmp )
164
+ self .fbst = self .fb .finalize ()
144
165
self .tmask .from_numpy (mask )
145
166
self .ttgt .from_numpy (tgt )
146
167
self .tgrad .from_numpy (grad )
0 commit comments