Skip to content

Commit da08343

Browse files
committed
fix taichi reset error in fpie-gui
1 parent 3782d25 commit da08343

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

fpie/gui.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import time
12
from typing import Any
23

34
import cv2
@@ -84,11 +85,13 @@ def target_callback(
8485
1,
8586
)
8687
mask = np.zeros([mask_y, mask_x], np.uint8) + 255
88+
t = time.time()
8789
self.proc.reset(self.src, mask, self.tgt, (self.y0, self.x0), (y, x))
8890
self.gui_out, err = self.proc.step(self.n) # type: ignore
91+
t = time.time() - t
8992
print(
90-
f"Args: -h0 {self.y0} -w0 {self.x0} -h1 {y} -w1 {x} -n {self.n}\t"
91-
f"mask size {mask.shape}, abs Error: {err}"
93+
f"Time elapsed: {t:.2f}s, mask size {mask.shape}, abs Error: {err}\t"
94+
f"Args: -n {self.n} -h0 {self.y0} -w0 {self.x0} -h1 {y} -w1 {x}"
9295
)
9396

9497

fpie/taichi_solver.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ def __init__(self, backend: str, n_cpu: int, block_size: int) -> None:
1414
self.block_dim = block_size
1515
ti.init(arch=getattr(ti, backend.split("-")[-1]))
1616
self.N = 0
17+
self.fb: ti.FieldsBuilder
18+
self.fbst: ti._snode.snode_tree.SNodeTree
19+
self.terr = ti.field(ti.f32, (3,))
1720
self.tA = ti.field(ti.i32)
1821
self.tB = ti.field(ti.f32)
1922
self.tX = ti.field(ti.f32)
20-
self.terr = ti.field(ti.f32, (3,))
2123
self.tmp = ti.field(ti.f32)
2224

2325
def partition(self, mask: np.ndarray) -> np.ndarray:
@@ -29,10 +31,19 @@ def reset(self, N: int, A: np.ndarray, X: np.ndarray, B: np.ndarray) -> None:
2931
self.A = A
3032
self.B = B
3133
self.X = X
32-
ti.root.dense(ti.ij, A.shape).place(self.tA)
33-
ti.root.dense(ti.ij, B.shape).place(self.tB)
34-
ti.root.dense(ti.ij, X.shape).place(self.tX)
35-
ti.root.dense(ti.ij, X.shape).place(self.tmp)
34+
if hasattr(self, "fbst"):
35+
self.fbst.destroy()
36+
self.tA = ti.field(ti.i32)
37+
self.tB = ti.field(ti.f32)
38+
self.tX = ti.field(ti.f32)
39+
self.tmp = ti.field(ti.f32)
40+
self.fb = ti.FieldsBuilder()
41+
layout = self.fb.dense(ti.i, N)
42+
layout.dense(ti.j, 4).place(self.tA)
43+
layout.dense(ti.j, 3).place(self.tB)
44+
layout.dense(ti.j, 3).place(self.tX)
45+
layout.dense(ti.j, 3).place(self.tmp)
46+
self.fbst = self.fb.finalize()
3647
self.tA.from_numpy(A)
3748
self.tB.from_numpy(B)
3849
self.tX.from_numpy(X)
@@ -112,11 +123,13 @@ def __init__(
112123
self.parallelize = n_cpu
113124
self.block_dim = block_size
114125
ti.init(arch=getattr(ti, backend.split("-")[-1]))
126+
self.fb: ti.FieldsBuilder
127+
self.fbst: ti._snode.snode_tree.SNodeTree
128+
self.terr = ti.field(ti.f32, (3,))
115129
self.tmask = ti.field(ti.i32)
116130
self.ttgt = ti.field(ti.f32)
117131
self.tgrad = ti.field(ti.f32)
118132
self.tmp = ti.field(ti.f32)
119-
self.terr = ti.field(ti.f32, (3,))
120133

121134
def reset(
122135
self, N: int, mask: np.ndarray, tgt: np.ndarray, grad: np.ndarray
@@ -132,15 +145,23 @@ def reset(
132145

133146
self.N, self.M = N, M = mask.shape
134147
bx, by = N // gx, M // gy
135-
layout = ti.root.dense(ti.ij, (bx, by)).dense(ti.ij, (gx, gy))
136148
self.mask = mask
137149
self.tgt = tgt
138150
self.grad = grad
139151

152+
if hasattr(self, "fbst"):
153+
self.fbst.destroy()
154+
self.tmask = ti.field(ti.i32)
155+
self.ttgt = ti.field(ti.f32)
156+
self.tgrad = ti.field(ti.f32)
157+
self.tmp = ti.field(ti.f32)
158+
self.fb = ti.FieldsBuilder()
159+
layout = self.fb.dense(ti.ij, (bx, by)).dense(ti.ij, (gx, gy))
140160
layout.place(self.tmask)
141161
layout.dense(ti.k, 3).place(self.ttgt)
142162
layout.dense(ti.k, 3).place(self.tgrad)
143163
layout.dense(ti.k, 3).place(self.tmp)
164+
self.fbst = self.fb.finalize()
144165
self.tmask.from_numpy(mask)
145166
self.ttgt.from_numpy(tgt)
146167
self.tgrad.from_numpy(grad)

0 commit comments

Comments
 (0)