@@ -28,6 +28,14 @@ trait StagedWasmEvaluator extends SAIOps {
28
28
// a cache storing the compiled code for each function, to reduce re-compilation
29
29
val compileCache = new HashMap [Int , Rep [(Cont [Unit ]) => Unit ]]
30
30
31
+ def funHere [A : Manifest ,B : Manifest ](f : Rep [A ] => Rep [B ], dummy : Rep [Unit ] = " dummy" .reflectCtrlWith[Unit ]()): Rep [A => B ] = {
32
+ // to avoid LMS lifting a function, we create a dummy node and read it inside function
33
+ fun((x : Rep [A ]) => {
34
+ " dummy-op" .reflectCtrlWith[Unit ](dummy)
35
+ f(x)
36
+ })
37
+ }
38
+
31
39
// NOTE: We don't support Ans type polymorphism yet
32
40
def eval (insts : List [Instr ],
33
41
kont : Rep [Cont [Unit ]],
@@ -88,37 +96,31 @@ trait StagedWasmEvaluator extends SAIOps {
88
96
// no need to modify the stack when entering a block
89
97
// the type system guarantees that we will never take more than the input size from the stack
90
98
val funcTy = ty.funcType
91
- val dummy = " dummy" .reflectCtrlWith[Unit ]()
92
99
// TODO: somehow the type of exitSize in residual program is nothing
93
- def restK : Rep [Cont [Unit ]] = fun ((_ : Rep [Unit ]) => {
100
+ def restK : Rep [Cont [Unit ]] = funHere ((_ : Rep [Unit ]) => {
94
101
info(s " Exiting the block, stackSize = " , Stack .size)
95
- " dummy-op" .reflectCtrlWith[Unit ](dummy)
96
102
eval(rest, kont, trail)
97
103
})
98
104
eval(inner, restK, restK :: trail)
99
105
case Loop (ty, inner) =>
100
106
val funcTy = ty.funcType
101
107
val exitSize = Stack .size - funcTy.inps.size + funcTy.out.size
102
- val dummy = " dummy" .reflectCtrlWith[Unit ]()
103
- def restK = fun((_ : Rep [Unit ]) => {
104
- " dummy-op" .reflectCtrlWith[Unit ](dummy)
108
+ def restK = funHere((_ : Rep [Unit ]) => {
105
109
info(s " Exiting the loop, stackSize = " , Stack .size)
106
110
eval(rest, kont, trail)
107
111
})
108
- def loop : Rep [ Unit => Unit ] = fun(( _u : Rep [Unit ]) => {
109
- " dummy-op " .reflectCtrlWith [Unit ](dummy)
112
+ val dummy = " dummy " .reflectCtrlWith [Unit ]()
113
+ def loop : Rep [Unit => Unit ] = funHere(( _u : Rep [ Unit ]) => {
110
114
info(s " Entered the loop, stackSize = " , Stack .size)
111
115
eval(inner, restK, loop :: trail)
112
- })
116
+ }, dummy) // <-- if we don't pass this dummy argument, lots of code will be generated
113
117
loop(())
114
118
case If (ty, thn, els) =>
115
119
val funcTy = ty.funcType
116
120
val exitSize = Stack .size - funcTy.inps.size + funcTy.out.size
117
121
val cond = Stack .pop()
118
- val dummy = " dummy" .reflectCtrlWith[Unit ]()
119
122
// TODO: can we avoid code duplication here?
120
- def restK = fun((_ : Rep [Unit ]) => {
121
- " dummy-op" .reflectCtrlWith[Unit ](dummy)
123
+ def restK = funHere((_ : Rep [Unit ]) => {
122
124
info(s " Exiting the if, stackSize = " , Stack .size)
123
125
eval(rest, kont, trail)
124
126
})
0 commit comments