Skip to content

Commit 81dd5bd

Browse files
Merge branch 'zdh/staged-eval' of github.com:Generative-Program-Analysis/GenSym into zdh/staged-eval
2 parents 74dfb1b + 39baa4a commit 81dd5bd

File tree

2 files changed

+31
-21
lines changed

2 files changed

+31
-21
lines changed

src/main/scala/wasm/StagedMiniWasm.scala

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import lms.core.virtualize
77
import lms.macros.SourceContext
88
import lms.core.stub.{Base, ScalaGenBase, CGenBase}
99
import lms.core.Backend._
10-
import lms.core.Backend.{Block => LMSBlock}
10+
import lms.core.Backend.{Block => LMSBlock, Const => LMSConst}
1111
import lms.core.Graph
1212

1313
import gensym.wasm.ast._
@@ -28,6 +28,14 @@ trait StagedWasmEvaluator extends SAIOps {
2828
// a cache storing the compiled code for each function, to reduce re-compilation
2929
val compileCache = new HashMap[Int, Rep[(Cont[Unit]) => Unit]]
3030

31+
def funHere[A:Manifest,B:Manifest](f: Rep[A] => Rep[B], dummy: Rep[Unit] = "dummy".reflectCtrlWith[Unit]()): Rep[A => B] = {
32+
// to avoid LMS lifting a function, we create a dummy node and read it inside function
33+
fun((x: Rep[A]) => {
34+
"dummy-op".reflectCtrlWith[Unit](dummy)
35+
f(x)
36+
})
37+
}
38+
3139
// NOTE: We don't support Ans type polymorphism yet
3240
def eval(insts: List[Instr],
3341
kont: Rep[Cont[Unit]],
@@ -88,37 +96,31 @@ trait StagedWasmEvaluator extends SAIOps {
8896
// no need to modify the stack when entering a block
8997
// the type system guarantees that we will never take more than the input size from the stack
9098
val funcTy = ty.funcType
91-
val dummy = "dummy".reflectCtrlWith[Unit]()
9299
// TODO: somehow the type of exitSize in residual program is nothing
93-
def restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => {
100+
def restK: Rep[Cont[Unit]] = funHere((_: Rep[Unit]) => {
94101
info(s"Exiting the block, stackSize =", Stack.size)
95-
"dummy-op".reflectCtrlWith[Unit](dummy)
96102
eval(rest, kont, trail)
97103
})
98104
eval(inner, restK, restK :: trail)
99105
case Loop(ty, inner) =>
100106
val funcTy = ty.funcType
101107
val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size
102-
val dummy = "dummy".reflectCtrlWith[Unit]()
103-
def restK = fun((_: Rep[Unit]) => {
104-
"dummy-op".reflectCtrlWith[Unit](dummy)
108+
def restK = funHere((_: Rep[Unit]) => {
105109
info(s"Exiting the loop, stackSize =", Stack.size)
106110
eval(rest, kont, trail)
107111
})
108-
def loop : Rep[Unit => Unit] = fun((_u: Rep[Unit]) => {
109-
"dummy-op".reflectCtrlWith[Unit](dummy)
112+
val dummy = "dummy".reflectCtrlWith[Unit]()
113+
def loop : Rep[Unit => Unit] = funHere((_u: Rep[Unit]) => {
110114
info(s"Entered the loop, stackSize =", Stack.size)
111115
eval(inner, restK, loop :: trail)
112-
})
116+
}, dummy) // <-- if we don't pass this dummy argument, lots of code will be generated
113117
loop(())
114118
case If(ty, thn, els) =>
115119
val funcTy = ty.funcType
116120
val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size
117121
val cond = Stack.pop()
118-
val dummy = "dummy".reflectCtrlWith[Unit]()
119122
// TODO: can we avoid code duplication here?
120-
def restK = fun((_: Rep[Unit]) => {
121-
"dummy-op".reflectCtrlWith[Unit](dummy)
123+
def restK = funHere((_: Rep[Unit]) => {
122124
info(s"Exiting the if, stackSize =", Stack.size)
123125
eval(rest, kont, trail)
124126
})
@@ -694,6 +696,11 @@ object WasmToScalaCompiler {
694696
}
695697

696698
trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase {
699+
override def mayInline(n: Node): Boolean = n match {
700+
case Node(s, "stack-pop", _, _) => false
701+
case _ => super.mayInline(n)
702+
}
703+
697704
override def remap(m: Manifest[_]): String = {
698705
if (m.toString.endsWith("Num")) "Num"
699706
else if (m.toString.endsWith("Slice")) "Slice"
@@ -727,8 +734,11 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase {
727734
emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(");\n")
728735
case Node(_, "global-set", List(i, value), _) =>
729736
emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(");\n")
737+
// Note: The following code is copied from the traverse of CppBackend.scala, try to avoid duplicated code
738+
case n @ Node(f, "λ", (b: LMSBlock)::LMSConst(0)::rest, _) =>
739+
// TODO: Is a leading block followed by 0 a hint for top function?
740+
super.traverse(n)
730741
case n @ Node(f, "λ", (b: LMSBlock)::rest, _) =>
731-
// Node: This code is copied from the traverse of CppSAICodeGenBase.scala, try to avoid code duplication
732742
val retType = remap(typeBlockRes(b.res))
733743
val argTypes = b.in.map(a => remap(typeMap(a))).mkString(", ")
734744
emitln(s"std::function<$retType(${argTypes})> ${quote(f)};")

src/test/scala/genwasym/TestStagedEval.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ class TestStagedEval extends FunSuite {
2828
val moduleInst = ModuleInstance(Parser.parseFile(filename))
2929
val code = WasmToCppCompiler.compile(moduleInst, main, true)
3030
if (printRes) {
31-
val writer = new java.io.PrintWriter(new java.io.File(s"$filename.cpp"))
32-
try {
33-
writer.write(code)
34-
} finally {
35-
writer.close()
36-
}
31+
println(code)
32+
}
33+
val writer = new java.io.PrintWriter(new java.io.File(s"$filename.cpp"))
34+
try {
35+
writer.write(code)
36+
} finally {
37+
writer.close()
3738
}
38-
println(code)
3939
}
4040

4141
test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) }

0 commit comments

Comments
 (0)