From 157aabf5ef0d386aca58a2af9b84a0ff5fc92474 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 23 Apr 2025 02:01:27 +0800 Subject: [PATCH 01/62] try lms --- benchmarks/wasm/staged/push-drop.wat | 7 ++ src/main/scala/wasm/StagedMiniWasm.scala | 98 ++++++++++++++++++++ src/test/scala/genwasym/TestStagedEval.scala | 22 +++++ 3 files changed, 127 insertions(+) create mode 100644 benchmarks/wasm/staged/push-drop.wat create mode 100644 src/main/scala/wasm/StagedMiniWasm.scala create mode 100644 src/test/scala/genwasym/TestStagedEval.scala diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat new file mode 100644 index 00000000..eec1f816 --- /dev/null +++ b/benchmarks/wasm/staged/push-drop.wat @@ -0,0 +1,7 @@ +(module $push-drop + (func $real_main (type 1) (result i32) + i32.const 2 + i32.const 2 + drop + drop) + (start 0)) \ No newline at end of file diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala new file mode 100644 index 00000000..26691b0e --- /dev/null +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -0,0 +1,98 @@ +package gensym.wasm.miniwasm + +import scala.collection.mutable.ArrayBuffer + +import lms.core.stub.Adapter +import lms.core.virtualize +import lms.core.stub.Base +import lms.core.Backend.{Block => LMSBlock} + +import gensym.wasm.ast._ +import gensym.wasm.ast.{Const => ConstInstr} + +case class StagedEvaluator(module: ModuleInstance) extends Base { + // reset and initialize the internal state of Adapter + Adapter.resetState + Adapter.g = Adapter.mkGraphBuilder + + type Stack = Rep[List[Value]] + type Cont[A] = Stack => Rep[A] + type Trail[A] = List[Cont[A]] + + // Ans should be instantiated to something like Int, Unit, etc, which is the result type of staged program + def eval[Ans](insts: List[Instr], + stack: Stack, + frame: Rep[Frame], + kont: Cont[Ans], + trail: Trail[Ans]): Rep[Ans] = { + if (insts.isEmpty) return kont(stack) + val (inst, rest) = (insts.head, insts.tail) + inst match { + case Drop => eval(rest, stack.tail, frame, kont, trail) + // Why this cons operation compiled? does anything could be casted to Rep? + case ConstInstr(num) => eval(rest, num :: stack, frame, kont, trail) + case _ => "todo-op".reflectWith() + } + } + + def evalTop[Ans](kont: Cont[Ans], main: Option[String]): Rep[Ans] = { + val funBody: FuncBodyDef = main match { + case Some(func_name) => + module.defs.flatMap({ + case Export(`func_name`, ExportFunc(fid)) => + println(s"Entering function $main") + module.funcs(fid) match { + case FuncDef(_, body@FuncBodyDef(_,_,_,_)) => Some(body) + case _ => throw new Exception("Entry function has no concrete body") + } + case _ => None + }).head + case None => + val startIds = module.defs.flatMap { + case Start(id) => Some(id) + case _ => None + } + val startId = startIds.headOption.getOrElse { throw new Exception("No start function") } + module.funcs(startId) match { + case FuncDef(_, body@FuncBodyDef(_,_,_,_)) => body + case _ => + throw new Exception("Entry function has no concrete body") + } + } + val (instrs, localSize) = (funBody.body, funBody.locals.size) + val frame = Frame(ArrayBuffer.fill(localSize)(I32V(0))) + eval(instrs, List(), frame, kont, List(kont)) + } + + def evalTop(main: Option[String]): Rep[Unit] = { + val haltK: Cont[Unit] = stack => { + "no-op".reflectWith() + } + evalTop(haltK, main) + } + + def codegen(main: Option[String]): LMSBlock = { + Adapter.g.reify( { Unwrap(evalTop(main)) } ) + } + + // The stack should be allocated on the stack to get optimal performance + implicit class StackOps(stack: Stack) { + def tail(): Stack = { + "value-stack-tail".reflectWith(stack) + } + + def ::[A](v: Rep[A]): Stack = { + "value-stack-cons".reflectWith(v, stack) + } + } + + // directly specify the translated operation + implicit class StringOps(op: String) { + def reflectWith[T: Manifest](rs: Rep[_]*): Rep[T] = { + val result = rs.map(Unwrap) + Predef.println(s"reflectWith: $op, $result") + val result1 = Adapter.g.reflect(op, result:_*) + Wrap[T](result1) + } + } +} diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala new file mode 100644 index 00000000..39e5c198 --- /dev/null +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -0,0 +1,22 @@ +package gensym.wasm + +import org.scalatest.FunSuite + +import lms.core.stub.Adapter + +import gensym.wasm.parser._ +import gensym.wasm.miniwasm._ + +class TestStagedEval extends FunSuite { + def testFile(filename: String, main: Option[String] = None) = { + val module = Parser.parseFile(filename) + val partialEvaluator = StagedEvaluator(ModuleInstance(module)) + val block = partialEvaluator.codegen(main) + println(Adapter.g) + println(block) + } + + test("push-drop") { + testFile("./benchmarks/wasm/staged/push-drop.wat") + } +} From a11a7ff561175561786aa420768a68b8e629e07b Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Fri, 25 Apr 2025 23:44:05 +0800 Subject: [PATCH 02/62] compose all parts --- src/main/scala/wasm/StagedMiniWasm.scala | 126 +++++++++++++------ src/test/scala/genwasym/TestStagedEval.scala | 8 +- 2 files changed, 89 insertions(+), 45 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 26691b0e..f7416199 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -4,38 +4,47 @@ import scala.collection.mutable.ArrayBuffer import lms.core.stub.Adapter import lms.core.virtualize -import lms.core.stub.Base +import lms.macros.SourceContext +import lms.core.stub.{Base, ScalaGenBase} +import lms.core.Backend._ import lms.core.Backend.{Block => LMSBlock} import gensym.wasm.ast._ import gensym.wasm.ast.{Const => ConstInstr} +import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase} -case class StagedEvaluator(module: ModuleInstance) extends Base { +@virtualize +trait StagedWasmEvaluator extends SAIOps { + def module: ModuleInstance + // NOTE: we don't need the following statements anymore, but where are they initialized? // reset and initialize the internal state of Adapter - Adapter.resetState - Adapter.g = Adapter.mkGraphBuilder + // Adapter.resetState + // Adapter.g = Adapter.mkGraphBuilder - type Stack = Rep[List[Value]] - type Cont[A] = Stack => Rep[A] + trait Stack + type Cont[A] = Rep[Stack => A] type Trail[A] = List[Cont[A]] // Ans should be instantiated to something like Int, Unit, etc, which is the result type of staged program - def eval[Ans](insts: List[Instr], - stack: Stack, - frame: Rep[Frame], - kont: Cont[Ans], - trail: Trail[Ans]): Rep[Ans] = { - if (insts.isEmpty) return kont(stack) - val (inst, rest) = (insts.head, insts.tail) - inst match { - case Drop => eval(rest, stack.tail, frame, kont, trail) - // Why this cons operation compiled? does anything could be casted to Rep? - case ConstInstr(num) => eval(rest, num :: stack, frame, kont, trail) - case _ => "todo-op".reflectWith() - } + def eval(insts: List[Instr], + stack: Rep[Stack], + frame: Rep[Frame], + kont: Cont[Unit], + trail: Trail[Unit]): Rep[Unit] = { + if (insts.isEmpty) return kont(stack) + val (inst, rest) = (insts.head, insts.tail) + inst match { + case Drop => eval(rest, stack.tail, frame, kont, trail) + case ConstInstr(num) => eval(rest, (num: Rep[Num]) :: stack, frame, kont, trail) + // case LocalGet(i) => + // eval(rest, frame.locals(i) :: stack, frame, kont, trail) + case _ => + val noOp = "todo-op".reflectCtrlWith() + eval(rest, noOp :: stack, frame, kont, trail) + } } - def evalTop[Ans](kont: Cont[Ans], main: Option[String]): Rep[Ans] = { + def evalTop(kont: Cont[Unit], main: Option[String]): Rep[Unit] = { val funBody: FuncBodyDef = main match { case Some(func_name) => module.defs.flatMap({ @@ -47,7 +56,7 @@ case class StagedEvaluator(module: ModuleInstance) extends Base { } case _ => None }).head - case None => + case None => val startIds = module.defs.flatMap { case Start(id) => Some(id) case _ => None @@ -61,38 +70,75 @@ case class StagedEvaluator(module: ModuleInstance) extends Base { } val (instrs, localSize) = (funBody.body, funBody.locals.size) val frame = Frame(ArrayBuffer.fill(localSize)(I32V(0))) - eval(instrs, List(), frame, kont, List(kont)) + eval(instrs, emptyStack, unit(frame), kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error } def evalTop(main: Option[String]): Rep[Unit] = { - val haltK: Cont[Unit] = stack => { - "no-op".reflectWith() + val haltK: Rep[Stack] => Rep[Unit] = stack => { + "no-op".reflectCtrlWith() } - evalTop(haltK, main) + evalTop(fun(haltK), main) } - def codegen(main: Option[String]): LMSBlock = { - Adapter.g.reify( { Unwrap(evalTop(main)) } ) + def emptyStack: Rep[Stack] = { + "empty-stack".reflectWith() } - // The stack should be allocated on the stack to get optimal performance - implicit class StackOps(stack: Stack) { - def tail(): Stack = { - "value-stack-tail".reflectWith(stack) + // TODO: The stack should be allocated on the stack to get optimal performance + implicit class StackOps(stack: Rep[Stack]) { + def tail(): Rep[Stack] = { + "stack-tail".reflectCtrlWith(stack) + } + + def ::[A](v: Rep[A]): Rep[Stack] = { + "stack-cons".reflectCtrlWith(v, stack) } + } +} +trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { + override def traverse(n: Node): Unit = n match { + case _ => super.traverse(n) + } - def ::[A](v: Rep[A]): Stack = { - "value-stack-cons".reflectWith(v, stack) + // code generation for pure nodes + override def shallow(n: Node): Unit = n match { + case Node(_, "stack-cons", List(v, stack), _) => + shallow(stack); emit(".push("); shallow(v); emit(")") + case Node(_, "stack-tail", List(stack), _) => + shallow(stack); emit(".pop()") + case Node(_, "empty-stack", _, _) => + emit("new Stack()") + case _ => super.shallow(n) + } +} +trait WasmCompilerDriver[A, B] + extends SAIDriver[A, B] with StagedWasmEvaluator { q => + override val codegen = new StagedWasmScalaGen { + val IR: q.type = q + import IR._ + override def remap(m: Manifest[_]): String = { + if (m.toString.endsWith("Stack")) "Stack" + else super.remap(m) } } - // directly specify the translated operation - implicit class StringOps(op: String) { - def reflectWith[T: Manifest](rs: Rep[_]*): Rep[T] = { - val result = rs.map(Unwrap) - Predef.println(s"reflectWith: $op, $result") - val result1 = Adapter.g.reflect(op, result:_*) - Wrap[T](result1) + override val prelude = + """ + object Prelude { + } + import Prelude._ + """ +} + +object PartialEvaluator { + def apply(moduleInst: ModuleInstance, main: Option[String]): String = { + println(s"Now compiling wasm module with entry function $main") + val code = new WasmCompilerDriver[Unit, Unit] { + def module: ModuleInstance = moduleInst + def snippet(x: Rep[Unit]): Rep[Unit] = { + evalTop(main) + } } + code.code } } diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index 39e5c198..cc7197f0 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -9,11 +9,9 @@ import gensym.wasm.miniwasm._ class TestStagedEval extends FunSuite { def testFile(filename: String, main: Option[String] = None) = { - val module = Parser.parseFile(filename) - val partialEvaluator = StagedEvaluator(ModuleInstance(module)) - val block = partialEvaluator.codegen(main) - println(Adapter.g) - println(block) + val moduleInst = ModuleInstance(Parser.parseFile(filename)) + val code = PartialEvaluator(moduleInst, None) + println(code) } test("push-drop") { From 9408c85b848ea91b74610d4b8b9f12558abd6bec Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 26 Apr 2025 01:15:02 +0800 Subject: [PATCH 03/62] Frame should be opaque --- benchmarks/wasm/staged/push-drop.wat | 3 +++ src/main/scala/wasm/StagedMiniWasm.scala | 26 +++++++++++++++++++----- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index eec1f816..19a3897e 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -1,7 +1,10 @@ (module $push-drop (func $real_main (type 1) (result i32) + (local i32 i32) i32.const 2 i32.const 2 + local.get 0 + local.get 1 drop drop) (start 0)) \ No newline at end of file diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index f7416199..d1a3c095 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -25,6 +25,8 @@ trait StagedWasmEvaluator extends SAIOps { type Cont[A] = Rep[Stack => A] type Trail[A] = List[Cont[A]] + trait Frame + // Ans should be instantiated to something like Int, Unit, etc, which is the result type of staged program def eval(insts: List[Instr], stack: Rep[Stack], @@ -35,9 +37,9 @@ trait StagedWasmEvaluator extends SAIOps { val (inst, rest) = (insts.head, insts.tail) inst match { case Drop => eval(rest, stack.tail, frame, kont, trail) - case ConstInstr(num) => eval(rest, (num: Rep[Num]) :: stack, frame, kont, trail) - // case LocalGet(i) => - // eval(rest, frame.locals(i) :: stack, frame, kont, trail) + case ConstInstr(num) => eval(rest, num :: stack, frame, kont, trail) + case LocalGet(i) => + eval(rest, frame.locals(i) :: stack, frame, kont, trail) case _ => val noOp = "todo-op".reflectCtrlWith() eval(rest, noOp :: stack, frame, kont, trail) @@ -69,8 +71,8 @@ trait StagedWasmEvaluator extends SAIOps { } } val (instrs, localSize) = (funBody.body, funBody.locals.size) - val frame = Frame(ArrayBuffer.fill(localSize)(I32V(0))) - eval(instrs, emptyStack, unit(frame), kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error + val frame = frameOf(localSize) + eval(instrs, emptyStack, frame, kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error } def evalTop(main: Option[String]): Rep[Unit] = { @@ -80,6 +82,8 @@ trait StagedWasmEvaluator extends SAIOps { evalTop(fun(haltK), main) } + + // stack creation and operations def emptyStack: Rep[Stack] = { "empty-stack".reflectWith() } @@ -94,6 +98,18 @@ trait StagedWasmEvaluator extends SAIOps { "stack-cons".reflectCtrlWith(v, stack) } } + + // frame creation and operations + def frameOf(size: Int): Rep[Frame] = { + "frame-of".reflectWith(size) + } + + implicit class FrameOps(frame: Rep[Frame]) { + + def locals(i: Int): Rep[Num] = { + "frame-locals".reflectCtrlWith(frame, i) + } + } } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { override def traverse(n: Node): Unit = n match { From 60c782b36b90d3216c9f76068a459737f1d86a8f Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 26 Apr 2025 20:20:45 +0800 Subject: [PATCH 04/62] function call --- benchmarks/wasm/staged/push-drop.wat | 9 ++- src/main/scala/wasm/StagedMiniWasm.scala | 96 ++++++++++++++++++++++-- 2 files changed, 95 insertions(+), 10 deletions(-) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index 19a3897e..2c630da8 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -1,10 +1,15 @@ (module $push-drop - (func $real_main (type 1) (result i32) + (func (;0;) (type 1) (result i32) (local i32 i32) i32.const 2 i32.const 2 local.get 0 local.get 1 drop - drop) + drop + (call 1)) + (func (;1;) (type 1) (param i32 i32) (result i32) + (local i32 i32) + local.get 0 + local.get 1) (start 0)) \ No newline at end of file diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index d1a3c095..f968b76c 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -1,6 +1,6 @@ package gensym.wasm.miniwasm -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, HashMap} import lms.core.stub.Adapter import lms.core.virtualize @@ -22,16 +22,19 @@ trait StagedWasmEvaluator extends SAIOps { // Adapter.g = Adapter.mkGraphBuilder trait Stack - type Cont[A] = Rep[Stack => A] - type Trail[A] = List[Cont[A]] + type Cont[A] = Stack => A + type Trail[A] = List[Rep[Cont[A]]] trait Frame - // Ans should be instantiated to something like Int, Unit, etc, which is the result type of staged program + // a cache storing the compiled code for each function, to reduce re-compilation + val compileCache = new HashMap[Int, Rep[(Stack, Frame, Cont[Unit]) => Unit]] + + // NOTE: We don't support Ans type polymorphism yet def eval(insts: List[Instr], stack: Rep[Stack], frame: Rep[Frame], - kont: Cont[Unit], + kont: Rep[Cont[Unit]], trail: Trail[Unit]): Rep[Unit] = { if (insts.isEmpty) return kont(stack) val (inst, rest) = (insts.head, insts.tail) @@ -40,13 +43,68 @@ trait StagedWasmEvaluator extends SAIOps { case ConstInstr(num) => eval(rest, num :: stack, frame, kont, trail) case LocalGet(i) => eval(rest, frame.locals(i) :: stack, frame, kont, trail) - case _ => + case Call(f) => evalCall(rest, stack, frame, kont, trail, f, false) + case ReturnCall(f) => evalCall(rest, stack, frame, kont, trail, f, true) + case _ => val noOp = "todo-op".reflectCtrlWith() eval(rest, noOp :: stack, frame, kont, trail) } } - def evalTop(kont: Cont[Unit], main: Option[String]): Rep[Unit] = { + def evalCall(rest: List[Instr], + stack: Rep[Stack], + frame: Rep[Frame], + kont: Rep[Cont[Unit]], + trail: Trail[Unit], + funcIndex: Int, + isTail: Boolean): Rep[Unit] = { + module.funcs(funcIndex) match { + case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => + val args = stack.take(ty.inps.size).reverse + val newStack = stack.drop(ty.inps.size) + val newFrame = frameOf(ty.inps.size + locals.size).put(args) + val callee = + if (compileCache.contains(funcIndex)) { + compileCache(funcIndex) + } else { + val callee = fun( + (stack: Rep[Stack], frame: Rep[Frame], kont: Rep[Cont[Unit]]) => { + eval(body, stack, frame, kont, kont::Nil):Rep[Unit] + } + ) + compileCache(funcIndex) = callee + callee + } + if (isTail) + // when tail call, share the continuation for returning with the callee + callee(emptyStack, newFrame, kont) + else { + val restK = fun( + (retStack: Rep[Stack]) => + eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail) + ) + // We make a new trail by `restK`, since function creates a new block to escape + // (more or less like `return`) + callee(emptyStack, newFrame, kont) + } + // TODO: Support imported functions + // case Import("console", "log", _) => + // //println(s"[DEBUG] current stack: $stack") + // val I32V(v) :: newStack = stack + // println(v) + // eval(rest, newStack, frame, kont, trail) + // case Import("spectest", "print_i32", _) => + // //println(s"[DEBUG] current stack: $stack") + // val I32V(v) :: newStack = stack + // println(v) + // eval(rest, newStack, frame, kont, trail) + case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") + case _ => throw new Exception(s"Definition at $funcIndex is not callable") + } + } + + + def evalTop(kont: Rep[Cont[Unit]], main: Option[String]): Rep[Unit] = { val funBody: FuncBodyDef = main match { case Some(func_name) => module.defs.flatMap({ @@ -97,6 +155,22 @@ trait StagedWasmEvaluator extends SAIOps { def ::[A](v: Rep[A]): Rep[Stack] = { "stack-cons".reflectCtrlWith(v, stack) } + + def ++(v: Rep[Stack]): Rep[Stack] = { + "stack-append".reflectCtrlWith(stack, v) + } + + def take(n: Int): Rep[Stack] = { + "stack-take".reflectWith(stack, n) + } + + def drop(n: Int): Rep[Stack] = { + "stack-drop".reflectWith(stack, n) + } + + def reverse: Rep[Stack] = { + "stack-reverse".reflectWith(stack) + } } // frame creation and operations @@ -107,8 +181,13 @@ trait StagedWasmEvaluator extends SAIOps { implicit class FrameOps(frame: Rep[Frame]) { def locals(i: Int): Rep[Num] = { - "frame-locals".reflectCtrlWith(frame, i) + "frame-get".reflectCtrlWith(frame, i) + } + + def put(args: Rep[Stack]): Rep[Frame] = { + "frame-put".reflectCtrlWith(frame, args) } + } } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { @@ -134,6 +213,7 @@ trait WasmCompilerDriver[A, B] import IR._ override def remap(m: Manifest[_]): String = { if (m.toString.endsWith("Stack")) "Stack" + else if(m.toString.endsWith("Frame")) "Frame" else super.remap(m) } } From 442d8d1d7015645f96470a2a9385281fbbbf85fc Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 26 Apr 2025 21:20:11 +0800 Subject: [PATCH 05/62] factor out getFuncType --- src/main/scala/wasm/AST.scala | 11 ++++++++++- src/main/scala/wasm/ConcolicMiniWasm.scala | 21 ++++++--------------- src/main/scala/wasm/MiniWasm.scala | 15 +++------------ 3 files changed, 19 insertions(+), 28 deletions(-) diff --git a/src/main/scala/wasm/AST.scala b/src/main/scala/wasm/AST.scala index c59eefc9..274b2b50 100644 --- a/src/main/scala/wasm/AST.scala +++ b/src/main/scala/wasm/AST.scala @@ -270,7 +270,16 @@ case class RefType(kind: HeapType) extends ValueType case class GlobalType(ty: ValueType, mut: Boolean) extends WasmType -abstract class BlockType extends WIR +abstract class BlockType extends WIR { + def funcType: FuncType = + this match { + case VarBlockType(_, None) => + ??? // TODO: fill this branch until we handle type index correctly + case VarBlockType(_, Some(tipe)) => tipe + case ValBlockType(Some(tipe)) => FuncType(List(), List(), List(tipe)) + case ValBlockType(None) => FuncType(List(), List(), List()) + } +} case class VarBlockType(index: Int, tipe: Option[FuncType]) extends BlockType case class ValBlockType(tipe: Option[ValueType]) extends BlockType; diff --git a/src/main/scala/wasm/ConcolicMiniWasm.scala b/src/main/scala/wasm/ConcolicMiniWasm.scala index 849fd831..fef469ec 100644 --- a/src/main/scala/wasm/ConcolicMiniWasm.scala +++ b/src/main/scala/wasm/ConcolicMiniWasm.scala @@ -229,15 +229,6 @@ object Primitives { case NumType(F32Type) => F32V(rng.nextFloat()) case NumType(F64Type) => F64V(rng.nextDouble()) } - - def getFuncType(ty: BlockType): FuncType = - ty match { - case VarBlockType(_, None) => - ??? // TODO: fill this branch until we handle type index correctly - case VarBlockType(_, Some(tipe)) => tipe - case ValBlockType(Some(tipe)) => FuncType(List(), List(), List(tipe)) - case ValBlockType(None) => FuncType(List(), List(), List()) - } } case class Frame(module: ModuleInstance, locals: ArrayBuffer[Value], symLocals: ArrayBuffer[SymVal]) @@ -383,7 +374,7 @@ case class Evaluator(module: ModuleInstance) { eval(rest, concStack, symStack, frame, kont, trail) case Unreachable => throw new RuntimeException("Unreachable") case Block(ty, inner) => - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val (inputSize, outputSize) = (funcTy.inps.size, funcTy.out.size) val (inputs, restStack) = concStack.splitAt(inputSize) val (symInputs, restSymStack) = symStack.splitAt(inputSize) @@ -391,7 +382,7 @@ case class Evaluator(module: ModuleInstance) { eval(rest, retStack.take(outputSize) ++ restStack, retSymStack.take(outputSize) ++ restSymStack, frame, kont, trail)(tree) eval(inner, inputs, symInputs, frame, restK, restK :: trail) case Loop(ty, inner) => - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val (inputSize, outputSize) = (funcTy.inps.size, funcTy.out.size) val (inputs, restStack) = concStack.splitAt(inputSize) val (symInputs, restSymStack) = symStack.splitAt(inputSize) @@ -404,9 +395,9 @@ case class Evaluator(module: ModuleInstance) { val scnd :: newSymStack = symStack val I32V(cond) :: newStack = concStack val (ifNode, elseNode) = if (scnd.isInstanceOf[Concrete]) { - // if this is a concrete value, we don't need to put + // if this is a concrete value, we don't need to put (tree, tree) - } else { + } else { val ifElseNode = tree.fillWithIfElse(Not(CondEqz(scnd))) (ifElseNode.thenNode, ifElseNode.elseNode) } @@ -422,9 +413,9 @@ case class Evaluator(module: ModuleInstance) { val scnd :: newSymStack = symStack val I32V(cond) :: newStack = concStack val (ifNode, elseNode) = if (scnd.isInstanceOf[Concrete]) { - // if this is a concrete value, we don't need to put + // if this is a concrete value, we don't need to put (tree, tree) - } else { + } else { val ifElseNode = tree.fillWithIfElse(Not(CondEqz(scnd))) (ifElseNode.thenNode, ifElseNode.elseNode) } diff --git a/src/main/scala/wasm/MiniWasm.scala b/src/main/scala/wasm/MiniWasm.scala index 11eb301b..2a5abe6d 100644 --- a/src/main/scala/wasm/MiniWasm.scala +++ b/src/main/scala/wasm/MiniWasm.scala @@ -229,15 +229,6 @@ object Primtives { case VecType(kind) => ??? case RefType(kind) => RefNullV(kind) } - - def getFuncType(ty: BlockType): FuncType = - ty match { - case VarBlockType(_, None) => - ??? // TODO: fill this branch until we handle type index correctly - case VarBlockType(_, Some(tipe)) => tipe - case ValBlockType(Some(tipe)) => FuncType(List(), List(), List(tipe)) - case ValBlockType(None) => FuncType(List(), List(), List()) - } } case class Frame(locals: ArrayBuffer[Value]) @@ -380,7 +371,7 @@ case class Evaluator(module: ModuleInstance) { eval(rest, stack, frame, kont, trail) case Unreachable => throw Trap() case Block(ty, inner) => - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val (inputs, restStack) = stack.splitAt(funcTy.inps.size) val restK: Cont[Ans] = (retStack) => eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) @@ -389,7 +380,7 @@ case class Evaluator(module: ModuleInstance) { // We construct two continuations, one for the break (to the begining of the loop), // and one for fall-through to the next instruction following the syntactic structure // of the program. - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val (inputs, restStack) = stack.splitAt(funcTy.inps.size) val restK: Cont[Ans] = (retStack) => eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) @@ -397,7 +388,7 @@ case class Evaluator(module: ModuleInstance) { eval(inner, retStack.take(funcTy.inps.size), frame, restK, loop _ :: trail) loop(inputs) case If(ty, thn, els) => - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val I32V(cond) :: newStack = stack val inner = if (cond != 0) thn else els val (inputs, restStack) = newStack.splitAt(funcTy.inps.size) From ab679f36c4181fba8bf25866f992bc59ae35f71b Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 26 Apr 2025 22:15:15 +0800 Subject: [PATCH 06/62] fix: use restK when non-tail call --- src/main/scala/wasm/StagedMiniWasm.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index f968b76c..bb3a7291 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -85,7 +85,7 @@ trait StagedWasmEvaluator extends SAIOps { ) // We make a new trail by `restK`, since function creates a new block to escape // (more or less like `return`) - callee(emptyStack, newFrame, kont) + callee(emptyStack, newFrame, restK) } // TODO: Support imported functions // case Import("console", "log", _) => From 8ba8657782bf1a7cea7ce1ae7e8ff81847089692 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 26 Apr 2025 22:43:55 +0800 Subject: [PATCH 07/62] compile Block-like instructions(if-else, loop, block) --- benchmarks/wasm/staged/push-drop.wat | 11 +++++- src/main/scala/wasm/StagedMiniWasm.scala | 46 ++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index 2c630da8..a32fb24a 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -7,7 +7,16 @@ local.get 1 drop drop - (call 1)) + (call 1) + i32.const 3 + if (result i32) ;; label = @1 + i32.const 1 + else + local.get 1 + end + (loop + i32.const 4) + ) (func (;1;) (type 1) (param i32 i32) (result i32) (local i32 i32) local.get 0 diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index bb3a7291..025a55df 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -10,7 +10,7 @@ import lms.core.Backend._ import lms.core.Backend.{Block => LMSBlock} import gensym.wasm.ast._ -import gensym.wasm.ast.{Const => ConstInstr} +import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase} @virtualize @@ -40,9 +40,41 @@ trait StagedWasmEvaluator extends SAIOps { val (inst, rest) = (insts.head, insts.tail) inst match { case Drop => eval(rest, stack.tail, frame, kont, trail) - case ConstInstr(num) => eval(rest, num :: stack, frame, kont, trail) + case WasmConst(num) => eval(rest, num :: stack, frame, kont, trail) case LocalGet(i) => eval(rest, frame.locals(i) :: stack, frame, kont, trail) + case WasmBlock(ty, inner) => + val funcTy = ty.funcType + val (inputs, restStack) = stack.splitAt(funcTy.inps.size) + val restK = fun( + (retStack: Rep[Stack]) => + eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) + ) + eval(inner, inputs, frame, restK, restK :: trail) + case Loop(ty, inner) => + val funcTy = ty.funcType + val (inputs, restStack) = stack.splitAt(funcTy.inps.size) + val restK = fun( + (retStack: Rep[Stack]) => + eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) + ) + def loop(retStack: Rep[Stack]): Rep[Unit] = + eval(inner, retStack.take(funcTy.inps.size), frame, restK, fun(loop _) :: trail) + loop(inputs) + case If(ty, thn, els) => + val funcTy = ty.funcType + val (cond, newStack) = (stack.head, stack.tail) + val (inputs, restStack) = newStack.splitAt(funcTy.inps.size) + // TODO: can we avoid code duplication here? + val restK = fun( + (retStack: Rep[Stack]) => + eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) + ) + if (cond != 0) { + eval(thn, inputs, frame, restK, restK :: trail) + } else { + eval(els, inputs, frame, restK, restK :: trail) + } case Call(f) => evalCall(rest, stack, frame, kont, trail, f, false) case ReturnCall(f) => evalCall(rest, stack, frame, kont, trail, f, true) case _ => @@ -148,7 +180,11 @@ trait StagedWasmEvaluator extends SAIOps { // TODO: The stack should be allocated on the stack to get optimal performance implicit class StackOps(stack: Rep[Stack]) { - def tail(): Rep[Stack] = { + def head: Rep[Num] = { + "stack-head".reflectCtrlWith(stack) + } + + def tail: Rep[Stack] = { "stack-tail".reflectCtrlWith(stack) } @@ -171,6 +207,10 @@ trait StagedWasmEvaluator extends SAIOps { def reverse: Rep[Stack] = { "stack-reverse".reflectWith(stack) } + + def splitAt(n: Int): (Rep[Stack], Rep[Stack]) = { + (take(n), drop(n)) + } } // frame creation and operations From e2cc801df0b5998242de2be8d2f91d5fa737fd19 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 01:25:56 +0800 Subject: [PATCH 08/62] branching instructions --- benchmarks/wasm/staged/push-drop.wat | 13 ++++++++++++- src/main/scala/wasm/StagedMiniWasm.scala | 23 +++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index a32fb24a..196fbc19 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -14,8 +14,19 @@ else local.get 1 end + (block + (block + i32.const 4 + i32.const 2 + ;; br_table 0 0 ;; the compilation of br_table is problematic now + ) + ) + (loop - i32.const 4) + i32.const 5 + br 0) + return + i32.const 6 ) (func (;1;) (type 1) (param i32 i32) (result i32) (local i32 i32) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 025a55df..02c7dbac 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -75,6 +75,22 @@ trait StagedWasmEvaluator extends SAIOps { } else { eval(els, inputs, frame, restK, restK :: trail) } + case Br(label) => + trail(label)(stack) + case BrIf(label) => + val (cond, newStack) = (stack.head, stack.tail) + if (cond != 0) trail(label)(newStack) + else eval(rest, newStack, frame, kont, trail) + case BrTable(labels, default) => + val (cond, newStack) = (stack.head, stack.tail) + if (cond.toInt < labels.length) { + var targets: Rep[List[Cont[Unit]]] = List(labels.map(i => trail(i)): _*) + val goto: Rep[Cont[Unit]] = targets(cond.toInt) + goto(newStack) // TODO: this line will trigger an exception + } else { + trail(default)(newStack) + } + case Return => trail.last(stack) case Call(f) => evalCall(rest, stack, frame, kont, trail, f, false) case ReturnCall(f) => evalCall(rest, stack, frame, kont, trail, f, true) case _ => @@ -229,6 +245,13 @@ trait StagedWasmEvaluator extends SAIOps { } } + + // runtime Num type + implicit class NumOps(num: Rep[Num]) { + def toInt: Rep[Int] = { + "num-to-int".reflectWith(num) + } + } } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { override def traverse(n: Node): Unit = n match { From fa3d6280330bb357277323d0d774b1b8b7acf94f Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 14:13:48 +0800 Subject: [PATCH 09/62] local set --- benchmarks/wasm/staged/push-drop.wat | 2 ++ src/main/scala/wasm/StagedMiniWasm.scala | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index 196fbc19..5b6c500b 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -5,6 +5,8 @@ i32.const 2 local.get 0 local.get 1 + local.set 0 + local.tee 1 drop drop (call 1) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 02c7dbac..9f5ef991 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -43,6 +43,14 @@ trait StagedWasmEvaluator extends SAIOps { case WasmConst(num) => eval(rest, num :: stack, frame, kont, trail) case LocalGet(i) => eval(rest, frame.locals(i) :: stack, frame, kont, trail) + case LocalSet(i) => + val (v, newStack) = (stack.head, stack.tail) + frame(i) = v + eval(rest, newStack, frame, kont, trail) + case LocalTee(i) => + val (v, _) = (stack.head, stack.tail) + frame(i) = v + eval(rest, stack, frame, kont, trail) case WasmBlock(ty, inner) => val funcTy = ty.funcType val (inputs, restStack) = stack.splitAt(funcTy.inps.size) @@ -244,6 +252,9 @@ trait StagedWasmEvaluator extends SAIOps { "frame-put".reflectCtrlWith(frame, args) } + def update(i: Int, value: Rep[Num]) = { + "frame-update".reflectCtrlWith(frame, i, value) + } } // runtime Num type From d78a96b5afe5f619c02eb75e15ca875fd58f2835 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 15:18:44 +0800 Subject: [PATCH 10/62] operators --- benchmarks/wasm/staged/push-drop.wat | 2 + src/main/scala/wasm/StagedMiniWasm.scala | 110 ++++++++++++++++++++++- 2 files changed, 109 insertions(+), 3 deletions(-) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index 5b6c500b..903b771d 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -9,6 +9,8 @@ local.tee 1 drop drop + i32.add + nop (call 1) i32.const 3 if (result i32) ;; label = @1 diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 9f5ef991..5112cecb 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -51,6 +51,21 @@ trait StagedWasmEvaluator extends SAIOps { val (v, _) = (stack.head, stack.tail) frame(i) = v eval(rest, stack, frame, kont, trail) + case Nop => + eval(rest, stack, frame, kont, trail) + case Unreachable => unreachable() + case Test(op) => + val (v, newStack) = (stack.head, stack.tail) + eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail) + case Unary(op) => + val (v, newStack) = (stack.head, stack.tail) + eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail) + case Binary(op) => + val (v2, v1, newStack) = (stack.head, stack.tail.head, stack.tail.tail) + eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail) + case Compare(op) => + val (v2, v1, newStack) = (stack.head, stack.tail.head, stack.tail.tail) + eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail) case WasmBlock(ty, inner) => val funcTy = ty.funcType val (inputs, restStack) = stack.splitAt(funcTy.inps.size) @@ -159,6 +174,41 @@ trait StagedWasmEvaluator extends SAIOps { } } + def evalTestOp(op: TestOp, value: Rep[Num]): Rep[Num] = op match { + case Eqz(_) => if (value.toInt == 0) I32(1) else I32(0) + } + + def evalUnaryOp(op: UnaryOp, value: Rep[Num]): Rep[Num] = op match { + case Clz(_) => value.clz() + case Ctz(_) => value.ctz() + case Popcnt(_) => value.popcnt() + case _ => ??? + } + + def evalBinOp(op: BinOp, v1: Rep[Num], v2: Rep[Num]): Rep[Num] = op match { + case Add(_) => v1 + v2 + case Mul(_) => v1 * v2 + case Sub(_) => v1 - v2 + case Shl(_) => v1 << v2 + // case ShrS(_) => v1 >> v2 // TODO: signed shift right + case ShrU(_) => v1 >> v2 + case And(_) => v1 & v2 + case _ => ??? + } + + def evalRelOp(op: RelOp, v1: Rep[Num], v2: Rep[Num]): Rep[Num] = op match { + case Eq(_) => v1 numEq v2 + case Ne(_) => v1 numNe v2 + case LtS(_) => v1 < v2 + case LtU(_) => v1 ltu v2 + case GtS(_) => v1 > v2 + case GtU(_) => v1 gtu v2 + case LeS(_) => v1 <= v2 + case LeU(_) => v1 leu v2 + case GeS(_) => v1 >= v2 + case GeU(_) => v1 geu v2 + case _ => ??? + } def evalTop(kont: Rep[Cont[Unit]], main: Option[String]): Rep[Unit] = { val funBody: FuncBodyDef = main match { @@ -202,6 +252,19 @@ trait StagedWasmEvaluator extends SAIOps { "empty-stack".reflectWith() } + // call unreachable + def unreachable(): Rep[Unit] = { + "unreachable".reflectCtrlWith() + } + + def I32(i: Rep[Int]): Rep[Num] = { + "I32V".reflectWith(i) + } + + def I64(i: Rep[Long]): Rep[Num] = { + "I64V".reflectWith(i) + } + // TODO: The stack should be allocated on the stack to get optimal performance implicit class StackOps(stack: Rep[Stack]) { def head: Rep[Num] = { @@ -255,13 +318,54 @@ trait StagedWasmEvaluator extends SAIOps { def update(i: Int, value: Rep[Num]) = { "frame-update".reflectCtrlWith(frame, i, value) } + } // runtime Num type implicit class NumOps(num: Rep[Num]) { - def toInt: Rep[Int] = { - "num-to-int".reflectWith(num) - } + + def toInt: Rep[Int] = "num-to-int".reflectWith(num) + + def clz(): Rep[Num] = "unary-clz".reflectWith(num) + + def ctz(): Rep[Num] = "unary-ctz".reflectWith(num) + + def popcnt(): Rep[Num] = "unary-popcnt".reflectWith(num) + + def +(rhs: Rep[Num]): Rep[Num] = "binary-add".reflectWith(num, rhs) + + def -(rhs: Rep[Num]): Rep[Num] = "binary-sub".reflectWith(num, rhs) + + def *(rhs: Rep[Num]): Rep[Num] = "binary-mul".reflectWith(num, rhs) + + def /(rhs: Rep[Num]): Rep[Num] = "binary-div".reflectWith(num, rhs) + + def <<(rhs: Rep[Num]): Rep[Num] = "binary-shl".reflectWith(num, rhs) + + def >>(rhs: Rep[Num]): Rep[Num] = "binary-shr".reflectWith(num, rhs) + + def &(rhs: Rep[Num]): Rep[Num] = "binary-and".reflectWith(num, rhs) + + def numEq(rhs: Rep[Num]): Rep[Num] = "relation-eq".reflectWith(num, rhs) + + def numNe(rhs: Rep[Num]): Rep[Num] = "relation-ne".reflectWith(num, rhs) + + def <(rhs: Rep[Num]): Rep[Num] = "relation-lt".reflectWith(num, rhs) + + def ltu(rhs: Rep[Num]): Rep[Num] = "relation-ltu".reflectWith(num, rhs) + + def >(rhs: Rep[Num]): Rep[Num] = "relation-gt".reflectWith(num, rhs) + + def gtu(rhs: Rep[Num]): Rep[Num] = "relation-gtu".reflectWith(num, rhs) + + def <=(rhs: Rep[Num]): Rep[Num] = "relation-le".reflectWith(num, rhs) + + def leu(rhs: Rep[Num]): Rep[Num] = "relation-leu".reflectWith(num, rhs) + + def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectWith(num, rhs) + + def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectWith(num, rhs) + } } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { From dac7e1cef393609c1cf59ee75b3ec7837c35b085 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 15:46:35 +0800 Subject: [PATCH 11/62] global instructions --- benchmarks/wasm/staged/push-drop.wat | 4 ++++ src/main/scala/wasm/StagedMiniWasm.scala | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/push-drop.wat index 903b771d..db7b18bd 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/push-drop.wat @@ -1,4 +1,5 @@ (module $push-drop + (global (;0;) (mut i32) (i32.const 1048576)) (func (;0;) (type 1) (result i32) (local i32 i32) i32.const 2 @@ -12,7 +13,10 @@ i32.add nop (call 1) + global.get 1 i32.const 3 + global.set 2 ;; TODO: this line was compiled to global.get, fix the parser! + if (result i32) ;; label = @1 i32.const 1 else diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 5112cecb..b69a7fbc 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -51,6 +51,15 @@ trait StagedWasmEvaluator extends SAIOps { val (v, _) = (stack.head, stack.tail) frame(i) = v eval(rest, stack, frame, kont, trail) + case GlobalGet(i) => + eval(rest, Global.globalGet(i) :: stack, frame, kont, trail) + case GlobalSet(i) => + val (value, newStack) = (stack.head, stack.tail) + module.globals(i).ty match { + case GlobalType(tipe, true) => Global.globalSet(i, value) + case _ => throw new Exception("Cannot set immutable global") + } + eval(rest, newStack, frame, kont, trail) case Nop => eval(rest, stack, frame, kont, trail) case Unreachable => unreachable() @@ -265,6 +274,17 @@ trait StagedWasmEvaluator extends SAIOps { "I64V".reflectWith(i) } + // global read/write + object Global{ + def globalGet(i: Int): Rep[Num] = { + "global-get".reflectWith(i) + } + + def globalSet(i: Int, value: Rep[Num]): Rep[Unit] = { + "global-set".reflectCtrlWith(i, value) + } + } + // TODO: The stack should be allocated on the stack to get optimal performance implicit class StackOps(stack: Rep[Stack]) { def head: Rep[Num] = { From cc65f6358d80b7c38acdad3c43337344354e80d5 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 16:06:21 +0800 Subject: [PATCH 12/62] placeholder for mem instructions --- src/main/scala/wasm/StagedMiniWasm.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index b69a7fbc..0373f49f 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -60,6 +60,9 @@ trait StagedWasmEvaluator extends SAIOps { case _ => throw new Exception("Cannot set immutable global") } eval(rest, newStack, frame, kont, trail) + case MemorySize => ??? + case MemoryGrow => ??? + case MemoryFill => ??? case Nop => eval(rest, stack, frame, kont, trail) case Unreachable => unreachable() From cf3063a4bda619bac6106e05577c1fb5c544d297 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 16:26:11 +0800 Subject: [PATCH 13/62] scala code generation --- src/main/scala/wasm/StagedMiniWasm.scala | 61 +++++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 0373f49f..80bec0d9 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -42,7 +42,7 @@ trait StagedWasmEvaluator extends SAIOps { case Drop => eval(rest, stack.tail, frame, kont, trail) case WasmConst(num) => eval(rest, num :: stack, frame, kont, trail) case LocalGet(i) => - eval(rest, frame.locals(i) :: stack, frame, kont, trail) + eval(rest, frame.get(i) :: stack, frame, kont, trail) case LocalSet(i) => val (v, newStack) = (stack.head, stack.tail) frame(i) = v @@ -330,7 +330,7 @@ trait StagedWasmEvaluator extends SAIOps { implicit class FrameOps(frame: Rep[Frame]) { - def locals(i: Int): Rep[Num] = { + def get(i: Int): Rep[Num] = { "frame-get".reflectCtrlWith(frame, i) } @@ -393,17 +393,74 @@ trait StagedWasmEvaluator extends SAIOps { } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { override def traverse(n: Node): Unit = n match { + case Node(_, "frame-update", List(frame, i, value), _) => + // TODO: what is the protocol of automatic new line insertion? + shallow(frame); emit(".update("); shallow(i); emit(", "); shallow(value); emit(")\n") + case Node(_, "global-set", List(i, value), _) => + shallow(i); emit(".globalSet("); shallow(value); emit(")") case _ => super.traverse(n) } // code generation for pure nodes override def shallow(n: Node): Unit = n match { + case Node(_, "stack-take", List(stack, n), _) => + shallow(stack); emit(".take("); shallow(n); emit(")") + case Node(_, "stack-drop", List(stack, n), _) => + shallow(stack); emit(".drop("); shallow(n); emit(")") + case Node(_, "stack-append", List(stack1, stack2), _) => + shallow(stack1); emit(".++("); shallow(stack2); emit(")") + case Node(_, "stack-head", List(stack), _) => + shallow(stack); emit(".head") + case Node(_, "stack-reverse", List(stack), _) => + shallow(stack); emit(".reverse") case Node(_, "stack-cons", List(v, stack), _) => shallow(stack); emit(".push("); shallow(v); emit(")") case Node(_, "stack-tail", List(stack), _) => shallow(stack); emit(".pop()") case Node(_, "empty-stack", _, _) => emit("new Stack()") + case Node(_, "frame-of", List(size), _) => + emit("new Frame("); shallow(size); emit(")") + case Node(_, "frame-get", List(frame, i), _) => + shallow(frame); emit("("); shallow(i); emit(")") + case Node(_, "frame-put", List(frame, args), _) => + shallow(frame); emit(".put("); shallow(args); emit(")") + case Node(_, "global-get", List(i), _) => + emit("Global.globalGet("); shallow(i); emit(")") + case Node(_, "binary-add", List(lhs, rhs), _) => + shallow(lhs); emit(" + "); shallow(rhs) + case Node(_, "binary-sub", List(lhs, rhs), _) => + shallow(lhs); emit(" - "); shallow(rhs) + case Node(_, "binary-mul", List(lhs, rhs), _) => + shallow(lhs); emit(" * "); shallow(rhs) + case Node(_, "binary-div", List(lhs, rhs), _) => + shallow(lhs); emit(" / "); shallow(rhs) + case Node(_, "binary-shl", List(lhs, rhs), _) => + shallow(lhs); emit(" << "); shallow(rhs) + case Node(_, "binary-shr", List(lhs, rhs), _) => + shallow(lhs); emit(" >> "); shallow(rhs) + case Node(_, "binary-and", List(lhs, rhs), _) => + shallow(lhs); emit(" & "); shallow(rhs) + case Node(_, "relation-eq", List(lhs, rhs), _) => + shallow(lhs); emit(" == "); shallow(rhs) + case Node(_, "relation-ne", List(lhs, rhs), _) => + shallow(lhs); emit(" != "); shallow(rhs) + case Node(_, "relation-lt", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-ltu", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-gt", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-gtu", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-le", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-leu", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-ge", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "relation-geu", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) case _ => super.shallow(n) } } From 80bfa682c1d06d2cbcb42517e6b89239171129de Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 20:03:56 +0800 Subject: [PATCH 14/62] some imported function --- src/main/scala/wasm/StagedMiniWasm.scala | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 80bec0d9..4e6f7107 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -170,17 +170,12 @@ trait StagedWasmEvaluator extends SAIOps { // (more or less like `return`) callee(emptyStack, newFrame, restK) } - // TODO: Support imported functions - // case Import("console", "log", _) => - // //println(s"[DEBUG] current stack: $stack") - // val I32V(v) :: newStack = stack - // println(v) - // eval(rest, newStack, frame, kont, trail) - // case Import("spectest", "print_i32", _) => - // //println(s"[DEBUG] current stack: $stack") - // val I32V(v) :: newStack = stack - // println(v) - // eval(rest, newStack, frame, kont, trail) + case Import("console", "log", _) + | Import("spectest", "print_i32", _) => + //println(s"[DEBUG] current stack: $stack") + val (v, newStack) = (stack.head, stack.tail) + println(v) + eval(rest, newStack, frame, kont, trail) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") case _ => throw new Exception(s"Definition at $funcIndex is not callable") } From 881eda45df5c730a80938b2d7e0e1c82656febaa Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 20:04:26 +0800 Subject: [PATCH 15/62] polish --- .../staged/{push-drop.wat => scratch.wat} | 2 ++ src/main/scala/wasm/ConcolicMiniWasm.scala | 8 ++--- src/main/scala/wasm/StagedMiniWasm.scala | 31 +++++++++++-------- src/test/scala/genwasym/TestStagedEval.scala | 2 +- 4 files changed, 25 insertions(+), 18 deletions(-) rename benchmarks/wasm/staged/{push-drop.wat => scratch.wat} (87%) diff --git a/benchmarks/wasm/staged/push-drop.wat b/benchmarks/wasm/staged/scratch.wat similarity index 87% rename from benchmarks/wasm/staged/push-drop.wat rename to benchmarks/wasm/staged/scratch.wat index db7b18bd..2b0b3ede 100644 --- a/benchmarks/wasm/staged/push-drop.wat +++ b/benchmarks/wasm/staged/scratch.wat @@ -1,3 +1,5 @@ +;; this file contains some wasm instructions to test if the compiler works, +;; and its execution is meaningless. (module $push-drop (global (;0;) (mut i32) (i32.const 1048576)) (func (;0;) (type 1) (result i32) diff --git a/src/main/scala/wasm/ConcolicMiniWasm.scala b/src/main/scala/wasm/ConcolicMiniWasm.scala index fef469ec..fec869fe 100644 --- a/src/main/scala/wasm/ConcolicMiniWasm.scala +++ b/src/main/scala/wasm/ConcolicMiniWasm.scala @@ -395,9 +395,9 @@ case class Evaluator(module: ModuleInstance) { val scnd :: newSymStack = symStack val I32V(cond) :: newStack = concStack val (ifNode, elseNode) = if (scnd.isInstanceOf[Concrete]) { - // if this is a concrete value, we don't need to put + // if this is a concrete value, we don't need to put (tree, tree) - } else { + } else { val ifElseNode = tree.fillWithIfElse(Not(CondEqz(scnd))) (ifElseNode.thenNode, ifElseNode.elseNode) } @@ -413,9 +413,9 @@ case class Evaluator(module: ModuleInstance) { val scnd :: newSymStack = symStack val I32V(cond) :: newStack = concStack val (ifNode, elseNode) = if (scnd.isInstanceOf[Concrete]) { - // if this is a concrete value, we don't need to put + // if this is a concrete value, we don't need to put (tree, tree) - } else { + } else { val ifElseNode = tree.fillWithIfElse(Not(CondEqz(scnd))) (ifElseNode.thenNode, ifElseNode.elseNode) } diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 4e6f7107..b2c8a3eb 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -160,7 +160,7 @@ trait StagedWasmEvaluator extends SAIOps { } if (isTail) // when tail call, share the continuation for returning with the callee - callee(emptyStack, newFrame, kont) + callee(Stack.emptyStack, newFrame, kont) else { val restK = fun( (retStack: Rep[Stack]) => @@ -168,7 +168,7 @@ trait StagedWasmEvaluator extends SAIOps { ) // We make a new trail by `restK`, since function creates a new block to escape // (more or less like `return`) - callee(emptyStack, newFrame, restK) + callee(Stack.emptyStack, newFrame, restK) } case Import("console", "log", _) | Import("spectest", "print_i32", _) => @@ -182,7 +182,7 @@ trait StagedWasmEvaluator extends SAIOps { } def evalTestOp(op: TestOp, value: Rep[Num]): Rep[Num] = op match { - case Eqz(_) => if (value.toInt == 0) I32(1) else I32(0) + case Eqz(_) => if (value.toInt == 0) Values.I32(1) else Values.I32(0) } def evalUnaryOp(op: UnaryOp, value: Rep[Num]): Rep[Num] = op match { @@ -243,7 +243,7 @@ trait StagedWasmEvaluator extends SAIOps { } val (instrs, localSize) = (funBody.body, funBody.locals.size) val frame = frameOf(localSize) - eval(instrs, emptyStack, frame, kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error + eval(instrs, Stack.emptyStack, frame, kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error } def evalTop(main: Option[String]): Rep[Unit] = { @@ -253,10 +253,11 @@ trait StagedWasmEvaluator extends SAIOps { evalTop(fun(haltK), main) } - // stack creation and operations - def emptyStack: Rep[Stack] = { - "empty-stack".reflectWith() + object Stack { + def emptyStack: Rep[Stack] = { + "empty-stack".reflectWith() + } } // call unreachable @@ -264,12 +265,15 @@ trait StagedWasmEvaluator extends SAIOps { "unreachable".reflectCtrlWith() } - def I32(i: Rep[Int]): Rep[Num] = { - "I32V".reflectWith(i) - } + // runtime values + object Values { + def I32(i: Rep[Int]): Rep[Num] = { + "I32V".reflectWith(i) + } - def I64(i: Rep[Long]): Rep[Num] = { - "I64V".reflectWith(i) + def I64(i: Rep[Long]): Rep[Num] = { + "I64V".reflectWith(i) + } } // global read/write @@ -383,9 +387,9 @@ trait StagedWasmEvaluator extends SAIOps { def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectWith(num, rhs) def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectWith(num, rhs) - } } + trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { override def traverse(n: Node): Unit = n match { case Node(_, "frame-update", List(frame, i, value), _) => @@ -459,6 +463,7 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { case _ => super.shallow(n) } } + trait WasmCompilerDriver[A, B] extends SAIDriver[A, B] with StagedWasmEvaluator { q => override val codegen = new StagedWasmScalaGen { diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index cc7197f0..d8a12839 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -15,6 +15,6 @@ class TestStagedEval extends FunSuite { } test("push-drop") { - testFile("./benchmarks/wasm/staged/push-drop.wat") + testFile("./benchmarks/wasm/staged/scratch.wat") } } From 7a2bfd4ea55958ba6c397adac9c73b0496e2393e Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 20:04:35 +0800 Subject: [PATCH 16/62] ci --- .github/workflows/scala.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/scala.yml b/.github/workflows/scala.yml index 5610a96e..4677da77 100644 --- a/.github/workflows/scala.yml +++ b/.github/workflows/scala.yml @@ -78,3 +78,4 @@ jobs: sbt 'testOnly gensym.wasm.TestScriptRun' sbt 'testOnly gensym.wasm.TestConcolic' sbt 'testOnly gensym.wasm.TestDriver' + sbt 'testOnly gensym.wasm.TestStagedEval' From 294fcea1df3586ca71ab53e17a7dd277b9d55243 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 27 Apr 2025 20:13:53 +0800 Subject: [PATCH 17/62] tweak --- src/main/scala/wasm/StagedMiniWasm.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index b2c8a3eb..9eb5efaa 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -145,7 +145,8 @@ trait StagedWasmEvaluator extends SAIOps { case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => val args = stack.take(ty.inps.size).reverse val newStack = stack.drop(ty.inps.size) - val newFrame = frameOf(ty.inps.size + locals.size).put(args) + val newFrame = frameOf(ty.inps.size + locals.size) + newFrame.putAll(args) val callee = if (compileCache.contains(funcIndex)) { compileCache(funcIndex) @@ -333,8 +334,8 @@ trait StagedWasmEvaluator extends SAIOps { "frame-get".reflectCtrlWith(frame, i) } - def put(args: Rep[Stack]): Rep[Frame] = { - "frame-put".reflectCtrlWith(frame, args) + def putAll(args: Rep[Stack]) = { + "frame-putAll".reflectCtrlWith(frame, args) } def update(i: Int, value: Rep[Num]) = { @@ -422,8 +423,8 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { emit("new Frame("); shallow(size); emit(")") case Node(_, "frame-get", List(frame, i), _) => shallow(frame); emit("("); shallow(i); emit(")") - case Node(_, "frame-put", List(frame, args), _) => - shallow(frame); emit(".put("); shallow(args); emit(")") + case Node(_, "frame-putAll", List(frame, args), _) => + shallow(frame); emit(".putAll("); shallow(args); emit(")") case Node(_, "global-get", List(i), _) => emit("Global.globalGet("); shallow(i); emit(")") case Node(_, "binary-add", List(lhs, rhs), _) => From f450b5c68dbca114a5e20fd3fcb3ee60fc66c705 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 28 Apr 2025 00:18:24 +0800 Subject: [PATCH 18/62] try some simplification --- src/main/scala/wasm/StagedMiniWasm.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 9eb5efaa..2a50b6a5 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -307,11 +307,13 @@ trait StagedWasmEvaluator extends SAIOps { } def take(n: Int): Rep[Stack] = { - "stack-take".reflectWith(stack, n) + if (n == 0) Stack.emptyStack + else "stack-take".reflectWith(stack, n) } def drop(n: Int): Rep[Stack] = { - "stack-drop".reflectWith(stack, n) + if (n == 0) stack + else "stack-drop".reflectWith(stack, n) } def reverse: Rep[Stack] = { From 336eec590537a80d68ec751a4aee9f6bf7920fb7 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 28 Apr 2025 01:32:54 +0800 Subject: [PATCH 19/62] improve runtime(the prelude) --- src/main/scala/wasm/StagedMiniWasm.scala | 78 ++++++++++++++++++++---- 1 file changed, 67 insertions(+), 11 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 2a50b6a5..8fa27b2f 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -40,7 +40,7 @@ trait StagedWasmEvaluator extends SAIOps { val (inst, rest) = (insts.head, insts.tail) inst match { case Drop => eval(rest, stack.tail, frame, kont, trail) - case WasmConst(num) => eval(rest, num :: stack, frame, kont, trail) + case WasmConst(num) => eval(rest, Values.lift(num) :: stack, frame, kont, trail) case LocalGet(i) => eval(rest, frame.get(i) :: stack, frame, kont, trail) case LocalSet(i) => @@ -105,7 +105,7 @@ trait StagedWasmEvaluator extends SAIOps { (retStack: Rep[Stack]) => eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) ) - if (cond != 0) { + if (cond != Values.I32(0)) { eval(thn, inputs, frame, restK, restK :: trail) } else { eval(els, inputs, frame, restK, restK :: trail) @@ -268,6 +268,13 @@ trait StagedWasmEvaluator extends SAIOps { // runtime values object Values { + def lift(num: Num): Rep[Num] = { + num match { + case I32V(i) => I32(i) + case I64V(i) => I64(i) + } + } + def I32(i: Rep[Int]): Rep[Num] = { "I32V".reflectWith(i) } @@ -298,7 +305,7 @@ trait StagedWasmEvaluator extends SAIOps { "stack-tail".reflectCtrlWith(stack) } - def ::[A](v: Rep[A]): Rep[Stack] = { + def ::[A](v: Rep[Num]): Rep[Stack] = { "stack-cons".reflectCtrlWith(v, stack) } @@ -416,11 +423,11 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { case Node(_, "stack-reverse", List(stack), _) => shallow(stack); emit(".reverse") case Node(_, "stack-cons", List(v, stack), _) => - shallow(stack); emit(".push("); shallow(v); emit(")") + shallow(stack); emit(".::("); shallow(v); emit(")") case Node(_, "stack-tail", List(stack), _) => - shallow(stack); emit(".pop()") + shallow(stack); emit(".tail") case Node(_, "empty-stack", _, _) => - emit("new Stack()") + emit("Nil") case Node(_, "frame-of", List(size), _) => emit("new Frame("); shallow(size); emit(")") case Node(_, "frame-get", List(frame, i), _) => @@ -480,11 +487,60 @@ trait WasmCompilerDriver[A, B] } override val prelude = - """ - object Prelude { - } - import Prelude._ - """ + """ +object Prelude { + sealed abstract class Num { + def +(that: Num): Num = (this, that) match { + case (I32V(x), I32V(y)) => I32V(x + y) + case (I64V(x), I64V(y)) => I64V(x + y) + case _ => throw new RuntimeException("Invalid addition") + } + + def -(that: Num): Num = (this, that) match { + case (I32V(x), I32V(y)) => I32V(x - y) + case (I64V(x), I64V(y)) => I64V(x - y) + case _ => throw new RuntimeException("Invalid subtraction") + } + + def !=(that: Num): Num = (this, that) match { + case (I32V(x), I32V(y)) => I32V(if (x != y) 1 else 0) + case (I64V(x), I64V(y)) => I32V(if (x != y) 1 else 0) + case _ => throw new RuntimeException("Invalid inequality") + } + } + case class I32V(i: Int) extends Num + case class I64V(i: Long) extends Num + + + type Stack = List[Num] + + class Frame(val size: Int) { + private val data = new Array[Num](size) + def apply(i: Int): Num = data(i) + def update(i: Int, v: Num): Unit = data(i) = v + def putAll(xs: List[Num]): Unit = { + for (i <- 0 until xs.size) { + data(i) = xs(i) + } + } + } + + object Global { + // TODO: create global with specific size + private val globals = new Array[Num](10) + def globalGet(i: Int): Num = globals(i) + def globalSet(i: Int, v: Num): Unit = globals(i) = v + } +} +import Prelude._ + +object Main { + def main(args: Array[String]): Unit = { + val snippet = new Snippet() + snippet(()) + } +} +""" } object PartialEvaluator { From 6a666f35f584ab9abcdbd391f1d3de7a60ee5339 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 28 Apr 2025 22:16:16 +0800 Subject: [PATCH 20/62] some fixes --- src/main/scala/wasm/StagedMiniWasm.scala | 17 +++++++++++++---- src/test/scala/genwasym/TestStagedEval.scala | 4 ++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 8fa27b2f..208f8b95 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -114,7 +114,7 @@ trait StagedWasmEvaluator extends SAIOps { trail(label)(stack) case BrIf(label) => val (cond, newStack) = (stack.head, stack.tail) - if (cond != 0) trail(label)(newStack) + if (cond != Values.I32(0)) trail(label)(newStack) else eval(rest, newStack, frame, kont, trail) case BrTable(labels, default) => val (cond, newStack) = (stack.head, stack.tail) @@ -129,8 +129,8 @@ trait StagedWasmEvaluator extends SAIOps { case Call(f) => evalCall(rest, stack, frame, kont, trail, f, false) case ReturnCall(f) => evalCall(rest, stack, frame, kont, trail, f, true) case _ => - val noOp = "todo-op".reflectCtrlWith() - eval(rest, noOp :: stack, frame, kont, trail) + val todo = "todo-op".reflectCtrlWith() + eval(rest, todo :: stack, frame, kont, trail) } } @@ -249,7 +249,7 @@ trait StagedWasmEvaluator extends SAIOps { def evalTop(main: Option[String]): Rep[Unit] = { val haltK: Rep[Stack] => Rep[Unit] = stack => { - "no-op".reflectCtrlWith() + "no-op".reflectCtrlWith[Unit]() } evalTop(fun(haltK), main) } @@ -470,6 +470,10 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { shallow(lhs); emit(" >= "); shallow(rhs) case Node(_, "relation-geu", List(lhs, rhs), _) => shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "num-to-int", List(num), _) => + shallow(num); emit(".toInt") + case Node(_, "no-op", _, _) => + emit("()") case _ => super.shallow(n) } } @@ -507,6 +511,11 @@ object Prelude { case (I64V(x), I64V(y)) => I32V(if (x != y) 1 else 0) case _ => throw new RuntimeException("Invalid inequality") } + + def toInt: Int = this match { + case I32V(i) => i + case I64V(i) => i.toInt + } } case class I32V(i: Int) extends Num case class I64V(i: Long) extends Num diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index d8a12839..fa7e8c65 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -10,11 +10,11 @@ import gensym.wasm.miniwasm._ class TestStagedEval extends FunSuite { def testFile(filename: String, main: Option[String] = None) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) - val code = PartialEvaluator(moduleInst, None) + val code = PartialEvaluator(moduleInst, main) println(code) } - test("push-drop") { + test("scratch") { testFile("./benchmarks/wasm/staged/scratch.wat") } } From 9947becca7cfea42e597342fdfbe25ae90ef192d Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 29 Apr 2025 00:52:36 +0800 Subject: [PATCH 21/62] fix: Frame creation is not optimizable --- src/main/scala/wasm/StagedMiniWasm.scala | 53 ++++++++++++++++---- src/test/scala/genwasym/TestStagedEval.scala | 6 ++- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 208f8b95..e897cd57 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -48,7 +48,7 @@ trait StagedWasmEvaluator extends SAIOps { frame(i) = v eval(rest, newStack, frame, kont, trail) case LocalTee(i) => - val (v, _) = (stack.head, stack.tail) + val v = stack.head frame(i) = v eval(rest, stack, frame, kont, trail) case GlobalGet(i) => @@ -111,11 +111,19 @@ trait StagedWasmEvaluator extends SAIOps { eval(els, inputs, frame, restK, restK :: trail) } case Br(label) => + info(s"Jump to $label") trail(label)(stack) case BrIf(label) => val (cond, newStack) = (stack.head, stack.tail) - if (cond != Values.I32(0)) trail(label)(newStack) - else eval(rest, newStack, frame, kont, trail) + if (cond != Values.I32(0)) { + info("The br_if's condition is ", cond) + info(s"Jump to $label") + trail(label)(newStack) + } else { + info("The br_if's condition is ",cond) + info(s"Continue") + eval(rest, newStack, frame, kont, trail) + } case BrTable(labels, default) => val (cond, newStack) = (stack.head, stack.tail) if (cond.toInt < labels.length) { @@ -147,12 +155,14 @@ trait StagedWasmEvaluator extends SAIOps { val newStack = stack.drop(ty.inps.size) val newFrame = frameOf(ty.inps.size + locals.size) newFrame.putAll(args) + info("New frame:", newFrame) val callee = if (compileCache.contains(funcIndex)) { compileCache(funcIndex) } else { val callee = fun( (stack: Rep[Stack], frame: Rep[Frame], kont: Rep[Cont[Unit]]) => { + info(s"Entered the function at $funcIndex, stack =", stack, ", frame =", frame) eval(body, stack, frame, kont, kont::Nil):Rep[Unit] } ) @@ -223,7 +233,7 @@ trait StagedWasmEvaluator extends SAIOps { case Some(func_name) => module.defs.flatMap({ case Export(`func_name`, ExportFunc(fid)) => - println(s"Entering function $main") + Predef.println(s"Now compiling start with function $main") module.funcs(fid) match { case FuncDef(_, body@FuncBodyDef(_,_,_,_)) => Some(body) case _ => throw new Exception("Entry function has no concrete body") @@ -247,8 +257,12 @@ trait StagedWasmEvaluator extends SAIOps { eval(instrs, Stack.emptyStack, frame, kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error } - def evalTop(main: Option[String]): Rep[Unit] = { + def evalTop(main: Option[String], printRes: Boolean = false): Rep[Unit] = { val haltK: Rep[Stack] => Rep[Unit] = stack => { + if (printRes) { + print("Final stack: ") + println(stack) + } "no-op".reflectCtrlWith[Unit]() } evalTop(fun(haltK), main) @@ -266,6 +280,10 @@ trait StagedWasmEvaluator extends SAIOps { "unreachable".reflectCtrlWith() } + def info(xs: Rep[_]*): Rep[Unit] = { + "info".reflectCtrlWith(xs: _*) + } + // runtime values object Values { def lift(num: Num): Rep[Num] = { @@ -334,7 +352,7 @@ trait StagedWasmEvaluator extends SAIOps { // frame creation and operations def frameOf(size: Int): Rep[Frame] = { - "frame-of".reflectWith(size) + "frame-of".reflectCtrlWith(size) } implicit class FrameOps(frame: Rep[Frame]) { @@ -408,6 +426,11 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { case Node(_, "global-set", List(i, value), _) => shallow(i); emit(".globalSet("); shallow(value); emit(")") case _ => super.traverse(n) + case Node(_, "info", xs, _) => + emit("println("); xs.foreach { x => + shallow(x); emit(", ") + }; emit(")") + } // code generation for pure nodes @@ -525,13 +548,19 @@ object Prelude { class Frame(val size: Int) { private val data = new Array[Num](size) - def apply(i: Int): Num = data(i) + def apply(i: Int): Num = { + info(s"frame(${i}) = ${data(i)}") + data(i) + } def update(i: Int, v: Num): Unit = data(i) = v def putAll(xs: List[Num]): Unit = { for (i <- 0 until xs.size) { data(i) = xs(i) } } + override def toString: String = { + "Frame(" + data.mkString(", ") + ")" + } } object Global { @@ -540,6 +569,12 @@ object Prelude { def globalGet(i: Int): Num = globals(i) def globalSet(i: Int, v: Num): Unit = globals(i) = v } + + def info(xs: Any*): Unit = { + if (System.getenv("DEBUG") != null) { + println("[INFO] " + xs.mkString(" ")) + } + } } import Prelude._ @@ -553,12 +588,12 @@ object Main { } object PartialEvaluator { - def apply(moduleInst: ModuleInstance, main: Option[String]): String = { + def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { println(s"Now compiling wasm module with entry function $main") val code = new WasmCompilerDriver[Unit, Unit] { def module: ModuleInstance = moduleInst def snippet(x: Rep[Unit]): Rep[Unit] = { - evalTop(main) + evalTop(main, printRes) } } code.code diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index fa7e8c65..2d9b3693 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -8,13 +8,15 @@ import gensym.wasm.parser._ import gensym.wasm.miniwasm._ class TestStagedEval extends FunSuite { - def testFile(filename: String, main: Option[String] = None) = { + def testFile(filename: String, main: Option[String] = None, printRes: Boolean = false) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) - val code = PartialEvaluator(moduleInst, main) + val code = PartialEvaluator(moduleInst, main, true) println(code) } test("scratch") { testFile("./benchmarks/wasm/staged/scratch.wat") } + + test("ack") { testFile("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } } From e7da82323c419895b12bcef6e09cd9e23dd0ff13 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 29 Apr 2025 11:02:14 +0800 Subject: [PATCH 22/62] demo br_table's attempts --- benchmarks/wasm/staged/scratch.wat | 2 +- src/main/scala/wasm/StagedMiniWasm.scala | 29 ++++++++++++++++++++---- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/benchmarks/wasm/staged/scratch.wat b/benchmarks/wasm/staged/scratch.wat index 2b0b3ede..6b0a4c44 100644 --- a/benchmarks/wasm/staged/scratch.wat +++ b/benchmarks/wasm/staged/scratch.wat @@ -28,7 +28,7 @@ (block i32.const 4 i32.const 2 - ;; br_table 0 0 ;; the compilation of br_table is problematic now + br_table 0 1 0 ;; the compilation of br_table is problematic now ) ) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index e897cd57..53eca440 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -126,10 +126,31 @@ trait StagedWasmEvaluator extends SAIOps { } case BrTable(labels, default) => val (cond, newStack) = (stack.head, stack.tail) - if (cond.toInt < labels.length) { - var targets: Rep[List[Cont[Unit]]] = List(labels.map(i => trail(i)): _*) - val goto: Rep[Cont[Unit]] = targets(cond.toInt) - goto(newStack) // TODO: this line will trigger an exception + if (cond.toInt < unit(labels.length)) { + // Implementation 1(trigger runtime exception): + // var targets: Rep[List[Cont[Unit]]] = List(labels.map(i => trail(i)): _*) + // val goto: Rep[Cont[Unit]] = targets(cond.toInt) + // goto(newStack) // TODO: confirm why this line will trigger an exception + + // Implementation 2(if-expression is not generated at all): + // var goto: Rep[Cont[Unit]] = null + // for (i <- Range(0, labels.length)) { + // if (i != cond.toInt) { + // info(s"Jump(br_table) to ${labels(i)}") + // return trail(labels(i))(newStack) + // } + // } + + // Implementation 3(assignment to `goto` is not generated): + var goto: Rep[Cont[Unit]] = null + for (i <- Range(0, labels.length)) { + if (i != cond.toInt) { + info(s"Jump(br_table) to ${labels(i)}") + goto = trail(labels(i)) + } + } + info(s"Jump to goto target") + goto(newStack) } else { trail(default)(newStack) } From 2de28f5a336e87c2b04a97e56698e99861ff7447 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 29 Apr 2025 11:30:00 +0800 Subject: [PATCH 23/62] fix: tail call --- src/main/scala/wasm/MiniWasm.scala | 4 ++-- src/main/scala/wasm/StagedMiniWasm.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/scala/wasm/MiniWasm.scala b/src/main/scala/wasm/MiniWasm.scala index 2a5abe6d..84a8bd88 100644 --- a/src/main/scala/wasm/MiniWasm.scala +++ b/src/main/scala/wasm/MiniWasm.scala @@ -255,8 +255,8 @@ case class Evaluator(module: ModuleInstance) { val frameLocals = args ++ locals.map(zero(_)) val newFrame = Frame(ArrayBuffer(frameLocals: _*)) if (isTail) - // when tail call, share the continuation for returning with the callee - eval(body, List(), newFrame, kont, List(kont)) + // when tail call, return to the caller's return continuation + eval(body, List(), newFrame, trail.last, List(trail.last)) else { val restK: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 53eca440..be9526c1 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -191,8 +191,8 @@ trait StagedWasmEvaluator extends SAIOps { callee } if (isTail) - // when tail call, share the continuation for returning with the callee - callee(Stack.emptyStack, newFrame, kont) + // when tail call, return to the caller's return continuation + callee(Stack.emptyStack, newFrame, trail.last) else { val restK = fun( (retStack: Rep[Stack]) => From b5a69dca1314b14af65b0d8c5cea60747a37746b Mon Sep 17 00:00:00 2001 From: ahuoguo Date: Tue, 29 Apr 2025 15:13:20 +0200 Subject: [PATCH 24/62] fix global --- benchmarks/wasm/global.wat | 19 +++++++++++++++++++ benchmarks/wasm/staged/scratch.wat | 4 ++-- src/main/scala/wasm/Parser.scala | 2 +- src/test/scala/genwasym/TestEval.scala | 3 +++ 4 files changed, 25 insertions(+), 3 deletions(-) create mode 100644 benchmarks/wasm/global.wat diff --git a/benchmarks/wasm/global.wat b/benchmarks/wasm/global.wat new file mode 100644 index 00000000..236467ef --- /dev/null +++ b/benchmarks/wasm/global.wat @@ -0,0 +1,19 @@ +(module + (type (;0;) (func (result i32))) + (type (;1;) (func)) + + (func (;0;) (type 0) (result i32) + i32.const 42 + global.set 0 + global.get 0 + ) + (func (;1;) (type 1) + call 0 + ;; should be 42 + ;; drop + ) + (start 1) + (memory (;0;) 2) + (export "main" (func 1)) + (global (;0;) (mut i32) (i32.const 0)) +) \ No newline at end of file diff --git a/benchmarks/wasm/staged/scratch.wat b/benchmarks/wasm/staged/scratch.wat index 6b0a4c44..b725d770 100644 --- a/benchmarks/wasm/staged/scratch.wat +++ b/benchmarks/wasm/staged/scratch.wat @@ -15,9 +15,9 @@ i32.add nop (call 1) - global.get 1 + global.get 0 i32.const 3 - global.set 2 ;; TODO: this line was compiled to global.get, fix the parser! + global.set 0 if (result i32) ;; label = @1 i32.const 1 diff --git a/src/main/scala/wasm/Parser.scala b/src/main/scala/wasm/Parser.scala index 40b497e0..0ce9fa94 100644 --- a/src/main/scala/wasm/Parser.scala +++ b/src/main/scala/wasm/Parser.scala @@ -314,7 +314,7 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] { else if (ctx.LOCAL_GET() != null) LocalGet(getVar(ctx.idx(0)).toInt) else if (ctx.LOCAL_SET() != null) LocalSet(getVar(ctx.idx(0)).toInt) else if (ctx.LOCAL_TEE() != null) LocalTee(getVar(ctx.idx(0)).toInt) - else if (ctx.GLOBAL_SET() != null) GlobalGet(getVar(ctx.idx(0)).toInt) + else if (ctx.GLOBAL_SET() != null) GlobalSet(getVar(ctx.idx(0)).toInt) else if (ctx.GLOBAL_GET() != null) GlobalGet(getVar(ctx.idx(0)).toInt) else if (ctx.load() != null) { val ty = visitNumType(ctx.load.numType) diff --git a/src/test/scala/genwasym/TestEval.scala b/src/test/scala/genwasym/TestEval.scala index 38453996..2e358375 100644 --- a/src/test/scala/genwasym/TestEval.scala +++ b/src/test/scala/genwasym/TestEval.scala @@ -81,6 +81,9 @@ class TestEval extends FunSuite { test("loop block - poly br") { testFile("./benchmarks/wasm/loop_poly.wat", None, ExpStack(List(2, 1))) } + test("global") { + testFile("./benchmarks/wasm/global.wat", None, ExpInt(42)) + } // just a test for .bin.wast utility // the complete tests can be seen at https://github.com/Generative-Program-Analysis/wasm-cps/ From de8f18e2c51a6e1f4be5c9b0e1bd2bb19d3280b3 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 29 Apr 2025 21:20:47 +0800 Subject: [PATCH 25/62] fix: code generation for global.set --- src/main/scala/wasm/StagedMiniWasm.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index be9526c1..3cb0ddd7 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -445,7 +445,7 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { // TODO: what is the protocol of automatic new line insertion? shallow(frame); emit(".update("); shallow(i); emit(", "); shallow(value); emit(")\n") case Node(_, "global-set", List(i, value), _) => - shallow(i); emit(".globalSet("); shallow(value); emit(")") + emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(")\n") case _ => super.traverse(n) case Node(_, "info", xs, _) => emit("println("); xs.foreach { x => From 3bbd27e408323f66a6bd7c8752dbb7c1418bd749 Mon Sep 17 00:00:00 2001 From: Guannan Wei Date: Tue, 29 Apr 2025 16:16:10 +0200 Subject: [PATCH 26/62] brtable seems to work, but there is code duplication problem --- benchmarks/wasm/staged/brtable.wat | 11 +++++++++++ src/main/scala/wasm/StagedMiniWasm.scala | 13 +++++++++++-- src/test/scala/genwasym/TestStagedEval.scala | 6 ++++++ third-party/lms-clean | 2 +- 4 files changed, 29 insertions(+), 3 deletions(-) create mode 100644 benchmarks/wasm/staged/brtable.wat diff --git a/benchmarks/wasm/staged/brtable.wat b/benchmarks/wasm/staged/brtable.wat new file mode 100644 index 00000000..91133d70 --- /dev/null +++ b/benchmarks/wasm/staged/brtable.wat @@ -0,0 +1,11 @@ +(module $push-drop + (global (;0;) (mut i32) (i32.const 1048576)) + (func (;0;) (type 1) (result i32) + i32.const 2 + (block + (block + br_table 0 1 0 + ) + ) + ) + (start 0)) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 3cb0ddd7..bff5cca5 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -101,8 +101,7 @@ trait StagedWasmEvaluator extends SAIOps { val (cond, newStack) = (stack.head, stack.tail) val (inputs, restStack) = newStack.splitAt(funcTy.inps.size) // TODO: can we avoid code duplication here? - val restK = fun( - (retStack: Rep[Stack]) => + val restK = fun((retStack: Rep[Stack]) => eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) ) if (cond != Values.I32(0)) { @@ -126,6 +125,15 @@ trait StagedWasmEvaluator extends SAIOps { } case BrTable(labels, default) => val (cond, newStack) = (stack.head, stack.tail) + def aux(choices: List[Int], idx: Int): Rep[Unit] = { + if (choices.isEmpty) trail(default)(newStack) + else { + if (cond.toInt == idx) trail(choices.head)(newStack) + else aux(choices.tail, idx + 1) + } + } + aux(labels, 0) + /* if (cond.toInt < unit(labels.length)) { // Implementation 1(trigger runtime exception): // var targets: Rep[List[Cont[Unit]]] = List(labels.map(i => trail(i)): _*) @@ -154,6 +162,7 @@ trait StagedWasmEvaluator extends SAIOps { } else { trail(default)(newStack) } + */ case Return => trail.last(stack) case Call(f) => evalCall(rest, stack, frame, kont, trail, f, false) case ReturnCall(f) => evalCall(rest, stack, frame, kont, trail, f, true) diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index 2d9b3693..b572f90f 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -14,9 +14,15 @@ class TestStagedEval extends FunSuite { println(code) } + /* test("scratch") { testFile("./benchmarks/wasm/staged/scratch.wat") } test("ack") { testFile("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } + */ + + test("brtable") { + testFile("./benchmarks/wasm/staged/brtable.wat") + } } diff --git a/third-party/lms-clean b/third-party/lms-clean index b6f019ae..f3338d3a 160000 --- a/third-party/lms-clean +++ b/third-party/lms-clean @@ -1 +1 @@ -Subproject commit b6f019aef1df5f1f12bcd0b43a9136d7f9ce7704 +Subproject commit f3338d3ab0ea74e90e44acfdbbda7912c249a7fc From a83eb06f80e38e95a5a13722caceafbdc58f6e2e Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 5 May 2025 02:21:02 +0800 Subject: [PATCH 27/62] effectful staged interpreter --- src/main/scala/wasm/StagedMiniWasm.scala | 473 +++++++++++++---------- 1 file changed, 261 insertions(+), 212 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index bff5cca5..e04fe744 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -21,202 +21,186 @@ trait StagedWasmEvaluator extends SAIOps { // Adapter.resetState // Adapter.g = Adapter.mkGraphBuilder - trait Stack - type Cont[A] = Stack => A - type Trail[A] = List[Rep[Cont[A]]] + trait Slice trait Frame + type Cont[A] = Unit => A + type Trail[A] = List[Rep[Cont[A]]] + // a cache storing the compiled code for each function, to reduce re-compilation - val compileCache = new HashMap[Int, Rep[(Stack, Frame, Cont[Unit]) => Unit]] + val compileCache = new HashMap[Int, Rep[(Cont[Unit]) => Unit]] // NOTE: We don't support Ans type polymorphism yet def eval(insts: List[Instr], - stack: Rep[Stack], - frame: Rep[Frame], kont: Rep[Cont[Unit]], trail: Trail[Unit]): Rep[Unit] = { - if (insts.isEmpty) return kont(stack) + if (insts.isEmpty) return kont() val (inst, rest) = (insts.head, insts.tail) inst match { - case Drop => eval(rest, stack.tail, frame, kont, trail) - case WasmConst(num) => eval(rest, Values.lift(num) :: stack, frame, kont, trail) + case Drop => + Stack.pop() + eval(rest, kont, trail) + case WasmConst(num) => + Stack.push(num) + eval(rest, kont, trail) case LocalGet(i) => - eval(rest, frame.get(i) :: stack, frame, kont, trail) + Stack.push(Frames.get(i)) + eval(rest, kont, trail) case LocalSet(i) => - val (v, newStack) = (stack.head, stack.tail) - frame(i) = v - eval(rest, newStack, frame, kont, trail) + Frames.set(i, Stack.pop()) + eval(rest, kont, trail) case LocalTee(i) => - val v = stack.head - frame(i) = v - eval(rest, stack, frame, kont, trail) + Frames.set(i, Stack.peek) + eval(rest, kont, trail) case GlobalGet(i) => - eval(rest, Global.globalGet(i) :: stack, frame, kont, trail) + Stack.push(Global.globalGet(i)) + eval(rest, kont, trail) case GlobalSet(i) => - val (value, newStack) = (stack.head, stack.tail) + val value = Stack.pop() module.globals(i).ty match { case GlobalType(tipe, true) => Global.globalSet(i, value) case _ => throw new Exception("Cannot set immutable global") } - eval(rest, newStack, frame, kont, trail) + eval(rest, kont, trail) case MemorySize => ??? case MemoryGrow => ??? case MemoryFill => ??? case Nop => - eval(rest, stack, frame, kont, trail) + eval(rest, kont, trail) case Unreachable => unreachable() case Test(op) => - val (v, newStack) = (stack.head, stack.tail) - eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail) + val v = Stack.pop() + Stack.push(evalTestOp(op, v)) + eval(rest, kont, trail) case Unary(op) => - val (v, newStack) = (stack.head, stack.tail) - eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail) + val v = Stack.pop() + Stack.push(evalUnaryOp(op, v)) + eval(rest, kont, trail) case Binary(op) => - val (v2, v1, newStack) = (stack.head, stack.tail.head, stack.tail.tail) - eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail) + val v2 = Stack.pop() + val v1 = Stack.pop() + Stack.push(evalBinOp(op, v1, v2)) + eval(rest, kont, trail) case Compare(op) => - val (v2, v1, newStack) = (stack.head, stack.tail.head, stack.tail.tail) - eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail) + val v2 = Stack.pop() + val v1 = Stack.pop() + Stack.push(evalRelOp(op, v1, v2)) + eval(rest, kont, trail) case WasmBlock(ty, inner) => + // no need to modify the stack when entering a block + // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType - val (inputs, restStack) = stack.splitAt(funcTy.inps.size) - val restK = fun( - (retStack: Rep[Stack]) => - eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) - ) - eval(inner, inputs, frame, restK, restK :: trail) + val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size + val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + Stack.reset(exitSize) + eval(rest, kont, trail) + }) + eval(inner, restK, restK :: trail) case Loop(ty, inner) => val funcTy = ty.funcType - val (inputs, restStack) = stack.splitAt(funcTy.inps.size) - val restK = fun( - (retStack: Rep[Stack]) => - eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) - ) - def loop(retStack: Rep[Stack]): Rep[Unit] = - eval(inner, retStack.take(funcTy.inps.size), frame, restK, fun(loop _) :: trail) - loop(inputs) + val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size + val restK = fun((_: Rep[Unit]) => { + Stack.reset(exitSize) + eval(rest, kont, trail) + }) + def loop(_u: Rep[Unit]): Rep[Unit] = + eval(inner, restK, fun(loop _) :: trail) + loop(()) case If(ty, thn, els) => val funcTy = ty.funcType - val (cond, newStack) = (stack.head, stack.tail) - val (inputs, restStack) = newStack.splitAt(funcTy.inps.size) + val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size + val cond = Stack.pop() // TODO: can we avoid code duplication here? - val restK = fun((retStack: Rep[Stack]) => - eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) - ) + val restK = fun((_: Rep[Unit]) => { + Stack.reset(exitSize) + eval(rest, kont, trail) + }) if (cond != Values.I32(0)) { - eval(thn, inputs, frame, restK, restK :: trail) + eval(thn, restK, restK :: trail) } else { - eval(els, inputs, frame, restK, restK :: trail) + eval(els, restK, restK :: trail) } case Br(label) => info(s"Jump to $label") - trail(label)(stack) + trail(label)(()) case BrIf(label) => - val (cond, newStack) = (stack.head, stack.tail) + val cond = Stack.pop() + info(s"The br_if(${label})'s condition is ", cond) if (cond != Values.I32(0)) { - info("The br_if's condition is ", cond) info(s"Jump to $label") - trail(label)(newStack) + trail(label)(()) } else { - info("The br_if's condition is ",cond) info(s"Continue") - eval(rest, newStack, frame, kont, trail) + eval(rest, kont, trail) } case BrTable(labels, default) => - val (cond, newStack) = (stack.head, stack.tail) + val cond = Stack.pop() def aux(choices: List[Int], idx: Int): Rep[Unit] = { - if (choices.isEmpty) trail(default)(newStack) + if (choices.isEmpty) trail(default)(()) else { - if (cond.toInt == idx) trail(choices.head)(newStack) + if (cond.toInt == idx) trail(choices.head)(()) else aux(choices.tail, idx + 1) } } aux(labels, 0) - /* - if (cond.toInt < unit(labels.length)) { - // Implementation 1(trigger runtime exception): - // var targets: Rep[List[Cont[Unit]]] = List(labels.map(i => trail(i)): _*) - // val goto: Rep[Cont[Unit]] = targets(cond.toInt) - // goto(newStack) // TODO: confirm why this line will trigger an exception - - // Implementation 2(if-expression is not generated at all): - // var goto: Rep[Cont[Unit]] = null - // for (i <- Range(0, labels.length)) { - // if (i != cond.toInt) { - // info(s"Jump(br_table) to ${labels(i)}") - // return trail(labels(i))(newStack) - // } - // } - - // Implementation 3(assignment to `goto` is not generated): - var goto: Rep[Cont[Unit]] = null - for (i <- Range(0, labels.length)) { - if (i != cond.toInt) { - info(s"Jump(br_table) to ${labels(i)}") - goto = trail(labels(i)) - } - } - info(s"Jump to goto target") - goto(newStack) - } else { - trail(default)(newStack) - } - */ - case Return => trail.last(stack) - case Call(f) => evalCall(rest, stack, frame, kont, trail, f, false) - case ReturnCall(f) => evalCall(rest, stack, frame, kont, trail, f, true) + case Return => trail.last(()) + case Call(f) => evalCall(rest, kont, trail, f, false) + case ReturnCall(f) => evalCall(rest, kont, trail, f, true) case _ => val todo = "todo-op".reflectCtrlWith() - eval(rest, todo :: stack, frame, kont, trail) + eval(rest, kont, trail) } } def evalCall(rest: List[Instr], - stack: Rep[Stack], - frame: Rep[Frame], kont: Rep[Cont[Unit]], trail: Trail[Unit], funcIndex: Int, isTail: Boolean): Rep[Unit] = { module.funcs(funcIndex) match { case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => - val args = stack.take(ty.inps.size).reverse - val newStack = stack.drop(ty.inps.size) - val newFrame = frameOf(ty.inps.size + locals.size) - newFrame.putAll(args) - info("New frame:", newFrame) + val returnSize = Stack.size - ty.inps.size + ty.out.size + val args = Stack.take(ty.inps.size) + info("New frame:", Frames.top) val callee = if (compileCache.contains(funcIndex)) { compileCache(funcIndex) } else { - val callee = fun( - (stack: Rep[Stack], frame: Rep[Frame], kont: Rep[Cont[Unit]]) => { - info(s"Entered the function at $funcIndex, stack =", stack, ", frame =", frame) - eval(body, stack, frame, kont, kont::Nil):Rep[Unit] + val callee = topFun( + (kont: Rep[Cont[Unit]]) => { + info(s"Entered the function at $funcIndex, stackSize =", Stack.size, ", frame =", Frames.top) + eval(body, kont, kont::Nil): Rep[Unit] } ) compileCache(funcIndex) = callee callee } - if (isTail) + val frameSize = ty.inps.size + locals.size + if (isTail) { // when tail call, return to the caller's return continuation - callee(Stack.emptyStack, newFrame, trail.last) - else { - val restK = fun( - (retStack: Rep[Stack]) => - eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail) - ) + Frames.popFrame() + Frames.pushFrame(frameSize) + Frames.putAll(args) + callee(trail.last) + } else { + val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + Stack.reset(returnSize) + Frames.popFrame() + eval(rest, kont, trail) + }) // We make a new trail by `restK`, since function creates a new block to escape // (more or less like `return`) - callee(Stack.emptyStack, newFrame, restK) + Frames.pushFrame(frameSize) + Frames.putAll(args) + callee(restK) } case Import("console", "log", _) | Import("spectest", "print_i32", _) => //println(s"[DEBUG] current stack: $stack") - val (v, newStack) = (stack.head, stack.tail) + val v = Stack.pop() println(v) - eval(rest, newStack, frame, kont, trail) + eval(rest, kont, trail) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") case _ => throw new Exception(s"Definition at $funcIndex is not callable") } @@ -283,122 +267,125 @@ trait StagedWasmEvaluator extends SAIOps { } } val (instrs, localSize) = (funBody.body, funBody.locals.size) - val frame = frameOf(localSize) - eval(instrs, Stack.emptyStack, frame, kont, kont::Nil) // NOTE: simply use List(kont) here will cause compilation error + Stack.initialize() + Frames.pushFrame(localSize) + eval(instrs, kont, kont::Nil) + Frames.popFrame() } def evalTop(main: Option[String], printRes: Boolean = false): Rep[Unit] = { - val haltK: Rep[Stack] => Rep[Unit] = stack => { + val haltK: Rep[Unit] => Rep[Unit] = (_) => { if (printRes) { - print("Final stack: ") - println(stack) + Stack.print() } "no-op".reflectCtrlWith[Unit]() } - evalTop(fun(haltK), main) + val temp: Rep[Cont[Unit]] = fun(haltK) + evalTop(temp, main) } // stack creation and operations object Stack { - def emptyStack: Rep[Stack] = { - "empty-stack".reflectWith() + def initialize(): Rep[Unit] = { + "stack-init".reflectCtrlWith() } - } - // call unreachable - def unreachable(): Rep[Unit] = { - "unreachable".reflectCtrlWith() - } - - def info(xs: Rep[_]*): Rep[Unit] = { - "info".reflectCtrlWith(xs: _*) - } + def pop(): Rep[Num] = { + "stack-pop".reflectCtrlWith() + } - // runtime values - object Values { - def lift(num: Num): Rep[Num] = { - num match { - case I32V(i) => I32(i) - case I64V(i) => I64(i) - } + def peek: Rep[Num] = { + "stack-peek".reflectCtrlWith() } - def I32(i: Rep[Int]): Rep[Num] = { - "I32V".reflectWith(i) + def push(v: Rep[Num]): Rep[Unit] = { + "stack-push".reflectCtrlWith(v) } - def I64(i: Rep[Long]): Rep[Num] = { - "I64V".reflectWith(i) + def drop(n: Int): Rep[Unit] = { + "stack-drop".reflectCtrlWith(n) } - } - // global read/write - object Global{ - def globalGet(i: Int): Rep[Num] = { - "global-get".reflectWith(i) + def print(): Rep[Unit] = { + "stack-print".reflectCtrlWith() } - def globalSet(i: Int, value: Rep[Num]): Rep[Unit] = { - "global-set".reflectCtrlWith(i, value) + def size: Rep[Int] = { + "stack-size".reflectCtrlWith() } - } - // TODO: The stack should be allocated on the stack to get optimal performance - implicit class StackOps(stack: Rep[Stack]) { - def head: Rep[Num] = { - "stack-head".reflectCtrlWith(stack) + def reset(x: Rep[Int]): Rep[Unit] = { + "stack-reset".reflectCtrlWith(x) } - def tail: Rep[Stack] = { - "stack-tail".reflectCtrlWith(stack) + def take(n: Int): Rep[Slice] = { + "stack-take".reflectCtrlWith(n) } + } - def ::[A](v: Rep[Num]): Rep[Stack] = { - "stack-cons".reflectCtrlWith(v, stack) + object Frames { + def get(i: Int): Rep[Num] = { + "frame-get".reflectCtrlWith(i) } - def ++(v: Rep[Stack]): Rep[Stack] = { - "stack-append".reflectCtrlWith(stack, v) + def set(i: Int, v: Rep[Num]): Rep[Unit] = { + "frame-set".reflectCtrlWith(i, v) } - def take(n: Int): Rep[Stack] = { - if (n == 0) Stack.emptyStack - else "stack-take".reflectWith(stack, n) + def pushFrame(i: Int): Rep[Unit] = { + "frame-push".reflectCtrlWith(i) } - def drop(n: Int): Rep[Stack] = { - if (n == 0) stack - else "stack-drop".reflectWith(stack, n) + def popFrame(): Rep[Unit] = { + "frame-pop".reflectCtrlWith() } - def reverse: Rep[Stack] = { - "stack-reverse".reflectWith(stack) + def putAll(args: Rep[Slice]): Rep[Unit] = { + "frame-putAll".reflectCtrlWith(args) } - def splitAt(n: Int): (Rep[Stack], Rep[Stack]) = { - (take(n), drop(n)) + def top: Rep[Frame] = { + "frame-top".reflectCtrlWith() } } - // frame creation and operations - def frameOf(size: Int): Rep[Frame] = { - "frame-of".reflectCtrlWith(size) + + // call unreachable + def unreachable(): Rep[Unit] = { + "unreachable".reflectCtrlWith() } - implicit class FrameOps(frame: Rep[Frame]) { + def info(xs: Rep[_]*): Rep[Unit] = { + "info".reflectCtrlWith(xs: _*) + } - def get(i: Int): Rep[Num] = { - "frame-get".reflectCtrlWith(frame, i) + // runtime values + object Values { + def lift(num: Num): Rep[Num] = { + num match { + case I32V(i) => I32(i) + case I64V(i) => I64(i) + } + } + + def I32(i: Rep[Int]): Rep[Num] = { + "I32V".reflectWith(i) } - def putAll(args: Rep[Stack]) = { - "frame-putAll".reflectCtrlWith(frame, args) + def I64(i: Rep[Long]): Rep[Num] = { + "I64V".reflectWith(i) } + } - def update(i: Int, value: Rep[Num]) = { - "frame-update".reflectCtrlWith(frame, i, value) + // global read/write + object Global{ + def globalGet(i: Int): Rep[Num] = { + "global-get".reflectWith(i) } + def globalSet(i: Int, value: Rep[Num]): Rep[Unit] = { + "global-set".reflectCtrlWith(i, value) + } } // runtime Num type @@ -446,49 +433,56 @@ trait StagedWasmEvaluator extends SAIOps { def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectWith(num, rhs) } + implicit class SliceOps(slice: Rep[Slice]) { + def reverse: Rep[Slice] = "slice-reverse".reflectWith(slice) + } } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { override def traverse(n: Node): Unit = n match { - case Node(_, "frame-update", List(frame, i, value), _) => - // TODO: what is the protocol of automatic new line insertion? - shallow(frame); emit(".update("); shallow(i); emit(", "); shallow(value); emit(")\n") + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(")\n") + case Node(_, "stack-drop", List(n), _) => + emit("Stack.drop("); shallow(n); emit(")\n") + case Node(_, "stack-reset", List(n), _) => + emit("Stack.reset("); shallow(n); emit(")\n") + case Node(_, "stack-init", _, _) => + emit("Stack.initialize()\n") + case Node(_, "stack-print", _, _) => + emit("Stack.print()\n") + case Node(_, "frame-push", List(i), _) => + emit("Frames.pushFrame("); shallow(i); emit(")\n") + case Node(_, "frame-pop", _, _) => + emit("Frames.popFrame()\n") + case Node(_, "frame-putAll", List(args), _) => + emit("Frames.putAll("); shallow(args); emit(")\n") + case Node(_, "frame-set", List(i, value), _) => + emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(")\n") case Node(_, "global-set", List(i, value), _) => emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(")\n") case _ => super.traverse(n) - case Node(_, "info", xs, _) => - emit("println("); xs.foreach { x => - shallow(x); emit(", ") - }; emit(")") - } // code generation for pure nodes override def shallow(n: Node): Unit = n match { - case Node(_, "stack-take", List(stack, n), _) => - shallow(stack); emit(".take("); shallow(n); emit(")") - case Node(_, "stack-drop", List(stack, n), _) => - shallow(stack); emit(".drop("); shallow(n); emit(")") - case Node(_, "stack-append", List(stack1, stack2), _) => - shallow(stack1); emit(".++("); shallow(stack2); emit(")") - case Node(_, "stack-head", List(stack), _) => - shallow(stack); emit(".head") - case Node(_, "stack-reverse", List(stack), _) => - shallow(stack); emit(".reverse") - case Node(_, "stack-cons", List(v, stack), _) => - shallow(stack); emit(".::("); shallow(v); emit(")") - case Node(_, "stack-tail", List(stack), _) => - shallow(stack); emit(".tail") - case Node(_, "empty-stack", _, _) => - emit("Nil") - case Node(_, "frame-of", List(size), _) => - emit("new Frame("); shallow(size); emit(")") - case Node(_, "frame-get", List(frame, i), _) => - shallow(frame); emit("("); shallow(i); emit(")") - case Node(_, "frame-putAll", List(frame, args), _) => - shallow(frame); emit(".putAll("); shallow(args); emit(")") + case Node(_, "frame-get", List(i), _) => + emit("Frames.get("); shallow(i); emit(")") + case Node(_, "stack-pop", _, _) => + emit("Stack.pop()") + case Node(_, "frame-pop", _, _) => + emit("Frames.popFrame()") + case Node(_, "stack-peek", _, _) => + emit("Stack.peek\n") + case Node(_, "stack-take", List(n), _) => + emit("Stack.take("); shallow(n); emit(")") + case Node(_, "slice-reverse", List(slice), _) => + shallow(slice); emit(".reverse") + case Node(_, "stack-size", _, _) => + emit("Stack.size") case Node(_, "global-get", List(i), _) => emit("Global.globalGet("); shallow(i); emit(")") + case Node(_, "frame-top", _, _) => + emit("Frames.top") case Node(_, "binary-add", List(lhs, rhs), _) => shallow(lhs); emit(" + "); shallow(rhs) case Node(_, "binary-sub", List(lhs, rhs), _) => @@ -573,16 +567,48 @@ object Prelude { case class I32V(i: Int) extends Num case class I64V(i: Long) extends Num +object Stack { + private val buffer = new scala.collection.mutable.ArrayBuffer[Num]() + def push(v: Num): Unit = buffer.append(v) + def pop(): Num = { + buffer.remove(buffer.size - 1) + } + def peek: Num = { + buffer.last + } + def size: Int = buffer.size + def drop(n: Int): Unit = { + buffer.remove(buffer.size - n, n) + } + def take(n: Int): List[Num] = { + val xs = buffer.takeRight(n).toList + drop(n) + xs + } + def reset(size: Int): Unit = { + info(s"Reset stack to size $size") + while (buffer.size > size) { + buffer.remove(buffer.size - 1) + } + } + def initialize(): Unit = buffer.clear() + def print(): Unit = { + println("Stack: " + buffer.mkString(", ")) + } +} - type Stack = List[Num] + type Slice = List[Num] class Frame(val size: Int) { private val data = new Array[Num](size) def apply(i: Int): Num = { - info(s"frame(${i}) = ${data(i)}") + info(s"frame(${i}) is ${data(i)}") data(i) } - def update(i: Int, v: Num): Unit = data(i) = v + def update(i: Int, v: Num): Unit = { + info(s"set frame(${i}) to ${v}") + data(i) = v + } def putAll(xs: List[Num]): Unit = { for (i <- 0 until xs.size) { data(i) = xs(i) @@ -593,6 +619,28 @@ object Prelude { } } + object Frames { + private var frames = List[Frame]() + def pushFrame(size: Int): Unit = { + frames = new Frame(size) :: frames + } + def popFrame(): Unit = { + frames = frames.tail + } + def top: Frame = frames.head + def set(i: Int, v: Num): Unit = { + top(i) = v + } + def get(i: Int): Num = { + top(i) + } + def putAll(xs: Slice) = { + for (i <- 0 until xs.size) { + top(i) = xs(i) + } + } + } + object Global { // TODO: create global with specific size private val globals = new Array[Num](10) @@ -608,6 +656,7 @@ object Prelude { } import Prelude._ + object Main { def main(args: Array[String]): Unit = { val snippet = new Snippet() From b8a9aea960c30ac2a3e3134aa8458bbabe22cdd3 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 5 May 2025 02:21:13 +0800 Subject: [PATCH 28/62] remove non-sense tests --- benchmarks/wasm/staged/scratch.wat | 45 -------------------- src/test/scala/genwasym/TestStagedEval.scala | 6 --- 2 files changed, 51 deletions(-) delete mode 100644 benchmarks/wasm/staged/scratch.wat diff --git a/benchmarks/wasm/staged/scratch.wat b/benchmarks/wasm/staged/scratch.wat deleted file mode 100644 index b725d770..00000000 --- a/benchmarks/wasm/staged/scratch.wat +++ /dev/null @@ -1,45 +0,0 @@ -;; this file contains some wasm instructions to test if the compiler works, -;; and its execution is meaningless. -(module $push-drop - (global (;0;) (mut i32) (i32.const 1048576)) - (func (;0;) (type 1) (result i32) - (local i32 i32) - i32.const 2 - i32.const 2 - local.get 0 - local.get 1 - local.set 0 - local.tee 1 - drop - drop - i32.add - nop - (call 1) - global.get 0 - i32.const 3 - global.set 0 - - if (result i32) ;; label = @1 - i32.const 1 - else - local.get 1 - end - (block - (block - i32.const 4 - i32.const 2 - br_table 0 1 0 ;; the compilation of br_table is problematic now - ) - ) - - (loop - i32.const 5 - br 0) - return - i32.const 6 - ) - (func (;1;) (type 1) (param i32 i32) (result i32) - (local i32 i32) - local.get 0 - local.get 1) - (start 0)) \ No newline at end of file diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index b572f90f..4c46fc5b 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -14,13 +14,7 @@ class TestStagedEval extends FunSuite { println(code) } - /* - test("scratch") { - testFile("./benchmarks/wasm/staged/scratch.wat") - } - test("ack") { testFile("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } - */ test("brtable") { testFile("./benchmarks/wasm/staged/brtable.wat") From b7b87867a2d6e9ff5f0c029770b2898b89b6ef90 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 5 May 2025 21:57:31 +0800 Subject: [PATCH 29/62] scratch cpp backend --- src/main/scala/wasm/StagedMiniWasm.scala | 129 ++++++++++++++++++- src/test/scala/genwasym/TestStagedEval.scala | 23 +++- 2 files changed, 141 insertions(+), 11 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index e04fe744..867638e2 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -5,13 +5,13 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import lms.core.stub.Adapter import lms.core.virtualize import lms.macros.SourceContext -import lms.core.stub.{Base, ScalaGenBase} +import lms.core.stub.{Base, ScalaGenBase, CGenBase} import lms.core.Backend._ import lms.core.Backend.{Block => LMSBlock} import gensym.wasm.ast._ import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} -import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase} +import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase, CppSAIDriver, CppSAICodeGenBase} @virtualize trait StagedWasmEvaluator extends SAIOps { @@ -472,7 +472,7 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { case Node(_, "frame-pop", _, _) => emit("Frames.popFrame()") case Node(_, "stack-peek", _, _) => - emit("Stack.peek\n") + emit("Stack.peek") case Node(_, "stack-take", List(n), _) => emit("Stack.take("); shallow(n); emit(")") case Node(_, "slice-reverse", List(slice), _) => @@ -525,7 +525,7 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { } } -trait WasmCompilerDriver[A, B] +trait WasmToScalaCompilerDriver[A, B] extends SAIDriver[A, B] with StagedWasmEvaluator { q => override val codegen = new StagedWasmScalaGen { val IR: q.type = q @@ -533,6 +533,7 @@ trait WasmCompilerDriver[A, B] override def remap(m: Manifest[_]): String = { if (m.toString.endsWith("Stack")) "Stack" else if(m.toString.endsWith("Frame")) "Frame" + else if(m.toString.endsWith("Slice")) "Slice" else super.remap(m) } } @@ -666,10 +667,125 @@ object Main { """ } -object PartialEvaluator { + + +object WasmToScalaCompiler { + def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { + println(s"Now compiling wasm module with entry function $main") + val code = new WasmToScalaCompilerDriver[Unit, Unit] { + def module: ModuleInstance = moduleInst + def snippet(x: Rep[Unit]): Rep[Unit] = { + evalTop(main, printRes) + } + } + code.code + } +} + +trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { + // for now, the traverse/shallow is same as the scala backend's + override def traverse(n: Node): Unit = n match { + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(")\n") + case Node(_, "stack-drop", List(n), _) => + emit("Stack.drop("); shallow(n); emit(")\n") + case Node(_, "stack-reset", List(n), _) => + emit("Stack.reset("); shallow(n); emit(")\n") + case Node(_, "stack-init", _, _) => + emit("Stack.initialize()\n") + case Node(_, "stack-print", _, _) => + emit("Stack.print()\n") + case Node(_, "frame-push", List(i), _) => + emit("Frames.pushFrame("); shallow(i); emit(")\n") + case Node(_, "frame-pop", _, _) => + emit("Frames.popFrame()\n") + case Node(_, "frame-putAll", List(args), _) => + emit("Frames.putAll("); shallow(args); emit(")\n") + case Node(_, "frame-set", List(i, value), _) => + emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(")\n") + case Node(_, "global-set", List(i, value), _) => + emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(")\n") + case _ => super.traverse(n) + } + + // code generation for pure nodes + override def shallow(n: Node): Unit = n match { + case Node(_, "frame-get", List(i), _) => + emit("Frames.get("); shallow(i); emit(")") + case Node(_, "stack-pop", _, _) => + emit("Stack.pop()") + case Node(_, "frame-pop", _, _) => + emit("Frames.popFrame()") + case Node(_, "stack-peek", _, _) => + emit("Stack.peek") + case Node(_, "stack-take", List(n), _) => + emit("Stack.take("); shallow(n); emit(")") + case Node(_, "slice-reverse", List(slice), _) => + shallow(slice); emit(".reverse") + case Node(_, "stack-size", _, _) => + emit("Stack.size") + case Node(_, "global-get", List(i), _) => + emit("Global.globalGet("); shallow(i); emit(")") + case Node(_, "frame-top", _, _) => + emit("Frames.top") + case Node(_, "binary-add", List(lhs, rhs), _) => + shallow(lhs); emit(" + "); shallow(rhs) + case Node(_, "binary-sub", List(lhs, rhs), _) => + shallow(lhs); emit(" - "); shallow(rhs) + case Node(_, "binary-mul", List(lhs, rhs), _) => + shallow(lhs); emit(" * "); shallow(rhs) + case Node(_, "binary-div", List(lhs, rhs), _) => + shallow(lhs); emit(" / "); shallow(rhs) + case Node(_, "binary-shl", List(lhs, rhs), _) => + shallow(lhs); emit(" << "); shallow(rhs) + case Node(_, "binary-shr", List(lhs, rhs), _) => + shallow(lhs); emit(" >> "); shallow(rhs) + case Node(_, "binary-and", List(lhs, rhs), _) => + shallow(lhs); emit(" & "); shallow(rhs) + case Node(_, "relation-eq", List(lhs, rhs), _) => + shallow(lhs); emit(" == "); shallow(rhs) + case Node(_, "relation-ne", List(lhs, rhs), _) => + shallow(lhs); emit(" != "); shallow(rhs) + case Node(_, "relation-lt", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-ltu", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-gt", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-gtu", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-le", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-leu", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-ge", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "relation-geu", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "num-to-int", List(num), _) => + shallow(num); emit(".toInt") + case Node(_, "no-op", _, _) => + emit("()") + case _ => super.shallow(n) + } +} + + +trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEvaluator { q => + override val codegen = new StagedWasmCppGen { + val IR: q.type = q + import IR._ + override def remap(m: Manifest[_]): String = { + if (m.toString.endsWith("Num")) "Num" + else super.remap(m) + } + } +} + +object WasmToCppCompiler { def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { println(s"Now compiling wasm module with entry function $main") - val code = new WasmCompilerDriver[Unit, Unit] { + val code = new WasmToCppCompilerDriver[Unit, Unit] { def module: ModuleInstance = moduleInst def snippet(x: Rep[Unit]): Rep[Unit] = { evalTop(main, printRes) @@ -678,3 +794,4 @@ object PartialEvaluator { code.code } } + diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index 4c46fc5b..47afddce 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -8,15 +8,28 @@ import gensym.wasm.parser._ import gensym.wasm.miniwasm._ class TestStagedEval extends FunSuite { - def testFile(filename: String, main: Option[String] = None, printRes: Boolean = false) = { + def testFileToScala(filename: String, main: Option[String] = None, printRes: Boolean = false) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) - val code = PartialEvaluator(moduleInst, main, true) + val code = WasmToScalaCompiler(moduleInst, main, true) println(code) } - test("ack") { testFile("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } + test("ack-scala") { testFileToScala("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } - test("brtable") { - testFile("./benchmarks/wasm/staged/brtable.wat") + test("brtable-scala") { + testFileToScala("./benchmarks/wasm/staged/brtable.wat") } + + def testFileToCpp(filename: String, main: Option[String] = None, printRes: Boolean = false) = { + val moduleInst = ModuleInstance(Parser.parseFile(filename)) + val code = WasmToCppCompiler(moduleInst, main, true) + println(code) + } + + test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } + + test("brtable-cpp") { + testFileToCpp("./benchmarks/wasm/staged/brtable.wat") + } + } From 0a8339e46165459adf3d900f3fce8bfe1f6a0caf Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 7 May 2025 01:39:17 +0800 Subject: [PATCH 30/62] some tweaks --- src/main/scala/wasm/StagedMiniWasm.scala | 31 +++++++++++++++--------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 867638e2..aba6ec11 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -91,6 +91,7 @@ trait StagedWasmEvaluator extends SAIOps { // no need to modify the stack when entering a block // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType + // TODO: somehow the type of exitSize in residual program is nothing val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { Stack.reset(exitSize) @@ -686,25 +687,25 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { // for now, the traverse/shallow is same as the scala backend's override def traverse(n: Node): Unit = n match { case Node(_, "stack-push", List(value), _) => - emit("Stack.push("); shallow(value); emit(")\n") + emit("Stack.push("); shallow(value); emit(");\n") case Node(_, "stack-drop", List(n), _) => - emit("Stack.drop("); shallow(n); emit(")\n") + emit("Stack.drop("); shallow(n); emit(");\n") case Node(_, "stack-reset", List(n), _) => - emit("Stack.reset("); shallow(n); emit(")\n") + emit("Stack.reset("); shallow(n); emit(");\n") case Node(_, "stack-init", _, _) => - emit("Stack.initialize()\n") + emit("Stack.initialize();\n") case Node(_, "stack-print", _, _) => - emit("Stack.print()\n") + emit("Stack.print();\n") case Node(_, "frame-push", List(i), _) => - emit("Frames.pushFrame("); shallow(i); emit(")\n") + emit("Frames.pushFrame("); shallow(i); emit(");\n") case Node(_, "frame-pop", _, _) => - emit("Frames.popFrame()\n") + emit("Frames.popFrame();\n") case Node(_, "frame-putAll", List(args), _) => - emit("Frames.putAll("); shallow(args); emit(")\n") + emit("Frames.putAll("); shallow(args); emit(");\n") case Node(_, "frame-set", List(i, value), _) => - emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(")\n") + emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(");\n") case Node(_, "global-set", List(i, value), _) => - emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(")\n") + emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(");\n") case _ => super.traverse(n) } @@ -723,11 +724,11 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { case Node(_, "slice-reverse", List(slice), _) => shallow(slice); emit(".reverse") case Node(_, "stack-size", _, _) => - emit("Stack.size") + emit("Stack.size()") case Node(_, "global-get", List(i), _) => emit("Global.globalGet("); shallow(i); emit(")") case Node(_, "frame-top", _, _) => - emit("Frames.top") + emit("Frames.top()") case Node(_, "binary-add", List(lhs, rhs), _) => shallow(lhs); emit(" + "); shallow(rhs) case Node(_, "binary-sub", List(lhs, rhs), _) => @@ -777,6 +778,12 @@ trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEv import IR._ override def remap(m: Manifest[_]): String = { if (m.toString.endsWith("Num")) "Num" + else if (m.toString.endsWith("Slice")) "Slice" + else if (m.toString.endsWith("Frame")) "Frame" + else if (m.toString.endsWith("Stack")) "Stack" + else if (m.toString.endsWith("Global")) "Global" + else if (m.toString.endsWith("I32V")) "I32V" + else if (m.toString.endsWith("I64V")) "I64V" else super.remap(m) } } From b4703c7995c032dcfd7815e71dbd84c9c4c05d23 Mon Sep 17 00:00:00 2001 From: Guannan Wei Date: Mon, 12 May 2025 17:36:45 +0200 Subject: [PATCH 31/62] fix some of the nothing type --- src/main/scala/wasm/StagedMiniWasm.scala | 27 ++++++++++++------------ 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index aba6ec11..2098b6aa 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -292,7 +292,7 @@ trait StagedWasmEvaluator extends SAIOps { } def pop(): Rep[Num] = { - "stack-pop".reflectCtrlWith() + "stack-pop".reflectCtrlWith[Num]() } def peek: Rep[Num] = { @@ -312,7 +312,7 @@ trait StagedWasmEvaluator extends SAIOps { } def size: Rep[Int] = { - "stack-size".reflectCtrlWith() + "stack-size".reflectCtrlWith[Int]() } def reset(x: Rep[Int]): Rep[Unit] = { @@ -392,7 +392,7 @@ trait StagedWasmEvaluator extends SAIOps { // runtime Num type implicit class NumOps(num: Rep[Num]) { - def toInt: Rep[Int] = "num-to-int".reflectWith(num) + def toInt: Rep[Int] = "num-to-int".reflectWith[Int](num) def clz(): Rep[Num] = "unary-clz".reflectWith(num) @@ -684,6 +684,17 @@ object WasmToScalaCompiler { } trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { + override def remap(m: Manifest[_]): String = { + if (m.toString.endsWith("Num")) "Num" + else if (m.toString.endsWith("Slice")) "Slice" + else if (m.toString.endsWith("Frame")) "Frame" + else if (m.toString.endsWith("Stack")) "Stack" + else if (m.toString.endsWith("Global")) "Global" + else if (m.toString.endsWith("I32V")) "I32V" + else if (m.toString.endsWith("I64V")) "I64V" + else super.remap(m) + } + // for now, the traverse/shallow is same as the scala backend's override def traverse(n: Node): Unit = n match { case Node(_, "stack-push", List(value), _) => @@ -776,16 +787,6 @@ trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEv override val codegen = new StagedWasmCppGen { val IR: q.type = q import IR._ - override def remap(m: Manifest[_]): String = { - if (m.toString.endsWith("Num")) "Num" - else if (m.toString.endsWith("Slice")) "Slice" - else if (m.toString.endsWith("Frame")) "Frame" - else if (m.toString.endsWith("Stack")) "Stack" - else if (m.toString.endsWith("Global")) "Global" - else if (m.toString.endsWith("I32V")) "I32V" - else if (m.toString.endsWith("I64V")) "I64V" - else super.remap(m) - } } } From 29acef0c678f29824bf02c285c417872b2e2a3dc Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 13 May 2025 11:06:08 +0800 Subject: [PATCH 32/62] manually supply the reflect's type arguments --- src/main/scala/wasm/StagedMiniWasm.scala | 80 ++++++++++++------------ 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 2098b6aa..c1358894 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -149,7 +149,7 @@ trait StagedWasmEvaluator extends SAIOps { case Call(f) => evalCall(rest, kont, trail, f, false) case ReturnCall(f) => evalCall(rest, kont, trail, f, true) case _ => - val todo = "todo-op".reflectCtrlWith() + val todo = "todo-op".reflectCtrlWith[Unit]() eval(rest, kont, trail) } } @@ -288,7 +288,7 @@ trait StagedWasmEvaluator extends SAIOps { // stack creation and operations object Stack { def initialize(): Rep[Unit] = { - "stack-init".reflectCtrlWith() + "stack-init".reflectCtrlWith[Unit]() } def pop(): Rep[Num] = { @@ -296,19 +296,19 @@ trait StagedWasmEvaluator extends SAIOps { } def peek: Rep[Num] = { - "stack-peek".reflectCtrlWith() + "stack-peek".reflectCtrlWith[Num]() } def push(v: Rep[Num]): Rep[Unit] = { - "stack-push".reflectCtrlWith(v) + "stack-push".reflectCtrlWith[Unit](v) } def drop(n: Int): Rep[Unit] = { - "stack-drop".reflectCtrlWith(n) + "stack-drop".reflectCtrlWith[Unit](n) } def print(): Rep[Unit] = { - "stack-print".reflectCtrlWith() + "stack-print".reflectCtrlWith[Unit]() } def size: Rep[Int] = { @@ -316,17 +316,17 @@ trait StagedWasmEvaluator extends SAIOps { } def reset(x: Rep[Int]): Rep[Unit] = { - "stack-reset".reflectCtrlWith(x) + "stack-reset".reflectCtrlWith[Unit](x) } def take(n: Int): Rep[Slice] = { - "stack-take".reflectCtrlWith(n) + "stack-take".reflectCtrlWith[Slice](n) } } object Frames { def get(i: Int): Rep[Num] = { - "frame-get".reflectCtrlWith(i) + "frame-get".reflectCtrlWith[Num](i) } def set(i: Int, v: Rep[Num]): Rep[Unit] = { @@ -334,30 +334,30 @@ trait StagedWasmEvaluator extends SAIOps { } def pushFrame(i: Int): Rep[Unit] = { - "frame-push".reflectCtrlWith(i) + "frame-push".reflectCtrlWith[Unit](i) } def popFrame(): Rep[Unit] = { - "frame-pop".reflectCtrlWith() + "frame-pop".reflectCtrlWith[Unit]() } def putAll(args: Rep[Slice]): Rep[Unit] = { - "frame-putAll".reflectCtrlWith(args) + "frame-putAll".reflectCtrlWith[Unit](args) } def top: Rep[Frame] = { - "frame-top".reflectCtrlWith() + "frame-top".reflectCtrlWith[Frame]() } } // call unreachable def unreachable(): Rep[Unit] = { - "unreachable".reflectCtrlWith() + "unreachable".reflectCtrlWith[Unit]() } def info(xs: Rep[_]*): Rep[Unit] = { - "info".reflectCtrlWith(xs: _*) + "info".reflectCtrlWith[Unit](xs: _*) } // runtime values @@ -370,22 +370,22 @@ trait StagedWasmEvaluator extends SAIOps { } def I32(i: Rep[Int]): Rep[Num] = { - "I32V".reflectWith(i) + "I32V".reflectWith[Num](i) } def I64(i: Rep[Long]): Rep[Num] = { - "I64V".reflectWith(i) + "I64V".reflectWith[Num](i) } } // global read/write object Global{ def globalGet(i: Int): Rep[Num] = { - "global-get".reflectWith(i) + "global-get".reflectWith[Num](i) } def globalSet(i: Int, value: Rep[Num]): Rep[Unit] = { - "global-set".reflectCtrlWith(i, value) + "global-set".reflectCtrlWith[Unit](i, value) } } @@ -394,48 +394,48 @@ trait StagedWasmEvaluator extends SAIOps { def toInt: Rep[Int] = "num-to-int".reflectWith[Int](num) - def clz(): Rep[Num] = "unary-clz".reflectWith(num) + def clz(): Rep[Num] = "unary-clz".reflectWith[Num](num) - def ctz(): Rep[Num] = "unary-ctz".reflectWith(num) + def ctz(): Rep[Num] = "unary-ctz".reflectWith[Num](num) - def popcnt(): Rep[Num] = "unary-popcnt".reflectWith(num) + def popcnt(): Rep[Num] = "unary-popcnt".reflectWith[Num](num) - def +(rhs: Rep[Num]): Rep[Num] = "binary-add".reflectWith(num, rhs) + def +(rhs: Rep[Num]): Rep[Num] = "binary-add".reflectWith[Num](num, rhs) - def -(rhs: Rep[Num]): Rep[Num] = "binary-sub".reflectWith(num, rhs) + def -(rhs: Rep[Num]): Rep[Num] = "binary-sub".reflectWith[Num](num, rhs) - def *(rhs: Rep[Num]): Rep[Num] = "binary-mul".reflectWith(num, rhs) + def *(rhs: Rep[Num]): Rep[Num] = "binary-mul".reflectWith[Num](num, rhs) - def /(rhs: Rep[Num]): Rep[Num] = "binary-div".reflectWith(num, rhs) + def /(rhs: Rep[Num]): Rep[Num] = "binary-div".reflectWith[Num](num, rhs) - def <<(rhs: Rep[Num]): Rep[Num] = "binary-shl".reflectWith(num, rhs) + def <<(rhs: Rep[Num]): Rep[Num] = "binary-shl".reflectWith[Num](num, rhs) - def >>(rhs: Rep[Num]): Rep[Num] = "binary-shr".reflectWith(num, rhs) + def >>(rhs: Rep[Num]): Rep[Num] = "binary-shr".reflectWith[Num](num, rhs) - def &(rhs: Rep[Num]): Rep[Num] = "binary-and".reflectWith(num, rhs) + def &(rhs: Rep[Num]): Rep[Num] = "binary-and".reflectWith[Num](num, rhs) - def numEq(rhs: Rep[Num]): Rep[Num] = "relation-eq".reflectWith(num, rhs) + def numEq(rhs: Rep[Num]): Rep[Num] = "relation-eq".reflectWith[Num](num, rhs) - def numNe(rhs: Rep[Num]): Rep[Num] = "relation-ne".reflectWith(num, rhs) + def numNe(rhs: Rep[Num]): Rep[Num] = "relation-ne".reflectWith[Num](num, rhs) - def <(rhs: Rep[Num]): Rep[Num] = "relation-lt".reflectWith(num, rhs) + def <(rhs: Rep[Num]): Rep[Num] = "relation-lt".reflectWith[Num](num, rhs) - def ltu(rhs: Rep[Num]): Rep[Num] = "relation-ltu".reflectWith(num, rhs) + def ltu(rhs: Rep[Num]): Rep[Num] = "relation-ltu".reflectWith[Num](num, rhs) - def >(rhs: Rep[Num]): Rep[Num] = "relation-gt".reflectWith(num, rhs) + def >(rhs: Rep[Num]): Rep[Num] = "relation-gt".reflectWith[Num](num, rhs) - def gtu(rhs: Rep[Num]): Rep[Num] = "relation-gtu".reflectWith(num, rhs) + def gtu(rhs: Rep[Num]): Rep[Num] = "relation-gtu".reflectWith[Num](num, rhs) - def <=(rhs: Rep[Num]): Rep[Num] = "relation-le".reflectWith(num, rhs) + def <=(rhs: Rep[Num]): Rep[Num] = "relation-le".reflectWith[Num](num, rhs) - def leu(rhs: Rep[Num]): Rep[Num] = "relation-leu".reflectWith(num, rhs) + def leu(rhs: Rep[Num]): Rep[Num] = "relation-leu".reflectWith[Num](num, rhs) - def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectWith(num, rhs) + def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectWith[Num](num, rhs) - def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectWith(num, rhs) + def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectWith[Num](num, rhs) } implicit class SliceOps(slice: Rep[Slice]) { - def reverse: Rep[Slice] = "slice-reverse".reflectWith(slice) + def reverse: Rep[Slice] = "slice-reverse".reflectWith[Slice](slice) } } From 67b077bb02c903bd2d069cc0b81c8f777c744fc2 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 14 May 2025 21:27:55 +0800 Subject: [PATCH 33/62] lift every function to top level & avoid lms's common subexpr elimination --- src/main/scala/wasm/StagedMiniWasm.scala | 90 ++++++++++++++---------- 1 file changed, 51 insertions(+), 39 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index c1358894..dfbc4a7f 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -16,10 +16,6 @@ import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase, CppSAIDriver, @virtualize trait StagedWasmEvaluator extends SAIOps { def module: ModuleInstance - // NOTE: we don't need the following statements anymore, but where are they initialized? - // reset and initialize the internal state of Adapter - // Adapter.resetState - // Adapter.g = Adapter.mkGraphBuilder trait Slice @@ -93,7 +89,7 @@ trait StagedWasmEvaluator extends SAIOps { val funcTy = ty.funcType // TODO: somehow the type of exitSize in residual program is nothing val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size - val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + def restK: Rep[Cont[Unit]] = topFun((_: Rep[Unit]) => { Stack.reset(exitSize) eval(rest, kont, trail) }) @@ -101,19 +97,20 @@ trait StagedWasmEvaluator extends SAIOps { case Loop(ty, inner) => val funcTy = ty.funcType val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size - val restK = fun((_: Rep[Unit]) => { + def restK = topFun((_: Rep[Unit]) => { Stack.reset(exitSize) eval(rest, kont, trail) }) - def loop(_u: Rep[Unit]): Rep[Unit] = - eval(inner, restK, fun(loop _) :: trail) + def loop : Rep[Unit => Unit] = topFun((_u: Rep[Unit]) => { + eval(inner, restK, loop :: trail) + }) loop(()) case If(ty, thn, els) => val funcTy = ty.funcType val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size val cond = Stack.pop() // TODO: can we avoid code duplication here? - val restK = fun((_: Rep[Unit]) => { + def restK = topFun((_: Rep[Unit]) => { Stack.reset(exitSize) eval(rest, kont, trail) }) @@ -185,7 +182,7 @@ trait StagedWasmEvaluator extends SAIOps { Frames.putAll(args) callee(trail.last) } else { - val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + val restK: Rep[Cont[Unit]] = topFun((_: Rep[Unit]) => { Stack.reset(returnSize) Frames.popFrame() eval(rest, kont, trail) @@ -281,7 +278,7 @@ trait StagedWasmEvaluator extends SAIOps { } "no-op".reflectCtrlWith[Unit]() } - val temp: Rep[Cont[Unit]] = fun(haltK) + val temp: Rep[Cont[Unit]] = topFun(haltK) evalTop(temp, main) } @@ -370,18 +367,18 @@ trait StagedWasmEvaluator extends SAIOps { } def I32(i: Rep[Int]): Rep[Num] = { - "I32V".reflectWith[Num](i) + "I32V".reflectCtrlWith[Num](i) } def I64(i: Rep[Long]): Rep[Num] = { - "I64V".reflectWith[Num](i) + "I64V".reflectCtrlWith[Num](i) } } // global read/write object Global{ def globalGet(i: Int): Rep[Num] = { - "global-get".reflectWith[Num](i) + "global-get".reflectCtrlWith[Num](i) } def globalSet(i: Int, value: Rep[Num]): Rep[Unit] = { @@ -392,50 +389,50 @@ trait StagedWasmEvaluator extends SAIOps { // runtime Num type implicit class NumOps(num: Rep[Num]) { - def toInt: Rep[Int] = "num-to-int".reflectWith[Int](num) + def toInt: Rep[Int] = "num-to-int".reflectCtrlWith[Int](num) - def clz(): Rep[Num] = "unary-clz".reflectWith[Num](num) + def clz(): Rep[Num] = "unary-clz".reflectCtrlWith[Num](num) - def ctz(): Rep[Num] = "unary-ctz".reflectWith[Num](num) + def ctz(): Rep[Num] = "unary-ctz".reflectCtrlWith[Num](num) - def popcnt(): Rep[Num] = "unary-popcnt".reflectWith[Num](num) + def popcnt(): Rep[Num] = "unary-popcnt".reflectCtrlWith[Num](num) - def +(rhs: Rep[Num]): Rep[Num] = "binary-add".reflectWith[Num](num, rhs) + def +(rhs: Rep[Num]): Rep[Num] = "binary-add".reflectCtrlWith[Num](num, rhs) - def -(rhs: Rep[Num]): Rep[Num] = "binary-sub".reflectWith[Num](num, rhs) + def -(rhs: Rep[Num]): Rep[Num] = "binary-sub".reflectCtrlWith[Num](num, rhs) - def *(rhs: Rep[Num]): Rep[Num] = "binary-mul".reflectWith[Num](num, rhs) + def *(rhs: Rep[Num]): Rep[Num] = "binary-mul".reflectCtrlWith[Num](num, rhs) - def /(rhs: Rep[Num]): Rep[Num] = "binary-div".reflectWith[Num](num, rhs) + def /(rhs: Rep[Num]): Rep[Num] = "binary-div".reflectCtrlWith[Num](num, rhs) - def <<(rhs: Rep[Num]): Rep[Num] = "binary-shl".reflectWith[Num](num, rhs) + def <<(rhs: Rep[Num]): Rep[Num] = "binary-shl".reflectCtrlWith[Num](num, rhs) - def >>(rhs: Rep[Num]): Rep[Num] = "binary-shr".reflectWith[Num](num, rhs) + def >>(rhs: Rep[Num]): Rep[Num] = "binary-shr".reflectCtrlWith[Num](num, rhs) - def &(rhs: Rep[Num]): Rep[Num] = "binary-and".reflectWith[Num](num, rhs) + def &(rhs: Rep[Num]): Rep[Num] = "binary-and".reflectCtrlWith[Num](num, rhs) - def numEq(rhs: Rep[Num]): Rep[Num] = "relation-eq".reflectWith[Num](num, rhs) + def numEq(rhs: Rep[Num]): Rep[Num] = "relation-eq".reflectCtrlWith[Num](num, rhs) - def numNe(rhs: Rep[Num]): Rep[Num] = "relation-ne".reflectWith[Num](num, rhs) + def numNe(rhs: Rep[Num]): Rep[Num] = "relation-ne".reflectCtrlWith[Num](num, rhs) - def <(rhs: Rep[Num]): Rep[Num] = "relation-lt".reflectWith[Num](num, rhs) + def <(rhs: Rep[Num]): Rep[Num] = "relation-lt".reflectCtrlWith[Num](num, rhs) - def ltu(rhs: Rep[Num]): Rep[Num] = "relation-ltu".reflectWith[Num](num, rhs) + def ltu(rhs: Rep[Num]): Rep[Num] = "relation-ltu".reflectCtrlWith[Num](num, rhs) - def >(rhs: Rep[Num]): Rep[Num] = "relation-gt".reflectWith[Num](num, rhs) + def >(rhs: Rep[Num]): Rep[Num] = "relation-gt".reflectCtrlWith[Num](num, rhs) - def gtu(rhs: Rep[Num]): Rep[Num] = "relation-gtu".reflectWith[Num](num, rhs) + def gtu(rhs: Rep[Num]): Rep[Num] = "relation-gtu".reflectCtrlWith[Num](num, rhs) - def <=(rhs: Rep[Num]): Rep[Num] = "relation-le".reflectWith[Num](num, rhs) + def <=(rhs: Rep[Num]): Rep[Num] = "relation-le".reflectCtrlWith[Num](num, rhs) - def leu(rhs: Rep[Num]): Rep[Num] = "relation-leu".reflectWith[Num](num, rhs) + def leu(rhs: Rep[Num]): Rep[Num] = "relation-leu".reflectCtrlWith[Num](num, rhs) - def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectWith[Num](num, rhs) + def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectCtrlWith[Num](num, rhs) - def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectWith[Num](num, rhs) + def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectCtrlWith[Num](num, rhs) } implicit class SliceOps(slice: Rep[Slice]) { - def reverse: Rep[Slice] = "slice-reverse".reflectWith[Slice](slice) + def reverse: Rep[Slice] = "slice-reverse".reflectCtrlWith[Slice](slice) } } @@ -729,7 +726,7 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { case Node(_, "frame-pop", _, _) => emit("Frames.popFrame()") case Node(_, "stack-peek", _, _) => - emit("Stack.peek") + emit("Stack.peek()") case Node(_, "stack-take", List(n), _) => emit("Stack.take("); shallow(n); emit(")") case Node(_, "slice-reverse", List(slice), _) => @@ -775,11 +772,26 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { case Node(_, "relation-geu", List(lhs, rhs), _) => shallow(lhs); emit(" >= "); shallow(rhs) case Node(_, "num-to-int", List(num), _) => - shallow(num); emit(".toInt") + shallow(num); emit(".toInt()") case Node(_, "no-op", _, _) => - emit("()") + emit("std::monostate()") case _ => super.shallow(n) } + + override def registerTopLevelFunction(id: String, streamId: String = "general")(f: => Unit) = + if (!registeredFunctions(id)) { + //if (ongoingFun(streamId)) ??? + //ongoingFun += streamId + registeredFunctions += id + withStream(functionsStreams.getOrElseUpdate(id, { + val functionsStream = new java.io.ByteArrayOutputStream() + val functionsWriter = new java.io.PrintStream(functionsStream) + (functionsWriter, functionsStream) + })._1)(f) + //ongoingFun -= streamId + } else { + withStream(functionsStreams(id)._1)(f) + } } From 6e41521a7704cbb0b7c7b55c58f78f3a81decd45 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 14 May 2025 21:50:51 +0800 Subject: [PATCH 34/62] stack pop example --- benchmarks/wasm/staged/pop.wat | 8 ++++++++ src/main/scala/wasm/StagedMiniWasm.scala | 4 ++-- src/test/scala/genwasym/TestStagedEval.scala | 4 ++++ 3 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 benchmarks/wasm/staged/pop.wat diff --git a/benchmarks/wasm/staged/pop.wat b/benchmarks/wasm/staged/pop.wat new file mode 100644 index 00000000..691839b7 --- /dev/null +++ b/benchmarks/wasm/staged/pop.wat @@ -0,0 +1,8 @@ +(module $push-drop + (global (;0;) (mut i32) (i32.const 1048576)) + (func (;0;) (type 1) (result) + i32.const 2 + i32.const 2 + i32.add + ) + (start 0)) \ No newline at end of file diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index dfbc4a7f..8fb12001 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -442,6 +442,8 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { emit("Stack.push("); shallow(value); emit(")\n") case Node(_, "stack-drop", List(n), _) => emit("Stack.drop("); shallow(n); emit(")\n") + case Node(_, "stack-pop", _, _) => + emit("Stack.pop()\n") case Node(_, "stack-reset", List(n), _) => emit("Stack.reset("); shallow(n); emit(")\n") case Node(_, "stack-init", _, _) => @@ -465,8 +467,6 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { override def shallow(n: Node): Unit = n match { case Node(_, "frame-get", List(i), _) => emit("Frames.get("); shallow(i); emit(")") - case Node(_, "stack-pop", _, _) => - emit("Stack.pop()") case Node(_, "frame-pop", _, _) => emit("Frames.popFrame()") case Node(_, "stack-peek", _, _) => diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index 47afddce..c96f9b7e 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -20,6 +20,10 @@ class TestStagedEval extends FunSuite { testFileToScala("./benchmarks/wasm/staged/brtable.wat") } + test("drop-scala") { + testFileToScala("./benchmarks/wasm/staged/pop.wat") + } + def testFileToCpp(filename: String, main: Option[String] = None, printRes: Boolean = false) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) val code = WasmToCppCompiler(moduleInst, main, true) From 9f04722faa780ed516b78d1936bd24a2a9bab443 Mon Sep 17 00:00:00 2001 From: Guannan Wei Date: Thu, 15 May 2025 23:11:19 +0200 Subject: [PATCH 35/62] not inlining + shallow --- src/main/scala/wasm/StagedMiniWasm.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 8fb12001..d35318f5 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -437,13 +437,16 @@ trait StagedWasmEvaluator extends SAIOps { } trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { + override def mayInline(n: Node): Boolean = n match { + case Node(s, "stack-pop", _, _) => false + case _ => super.mayInline(n) + } + override def traverse(n: Node): Unit = n match { case Node(_, "stack-push", List(value), _) => emit("Stack.push("); shallow(value); emit(")\n") case Node(_, "stack-drop", List(n), _) => emit("Stack.drop("); shallow(n); emit(")\n") - case Node(_, "stack-pop", _, _) => - emit("Stack.pop()\n") case Node(_, "stack-reset", List(n), _) => emit("Stack.reset("); shallow(n); emit(")\n") case Node(_, "stack-init", _, _) => @@ -469,6 +472,8 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { emit("Frames.get("); shallow(i); emit(")") case Node(_, "frame-pop", _, _) => emit("Frames.popFrame()") + case Node(_, "stack-pop", _, _) => + emit("Stack.pop()") case Node(_, "stack-peek", _, _) => emit("Stack.peek") case Node(_, "stack-take", List(n), _) => @@ -666,7 +671,6 @@ object Main { } - object WasmToScalaCompiler { def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { println(s"Now compiling wasm module with entry function $main") From ed9c8e42b6948ef74839ba5c204602245fb9e3f8 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 17 May 2025 11:58:36 +0800 Subject: [PATCH 36/62] an almost work runtime --- src/main/scala/wasm/StagedMiniWasm.scala | 323 ++++++++++++++++++++++- 1 file changed, 312 insertions(+), 11 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index d35318f5..aecd6391 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -8,6 +8,7 @@ import lms.macros.SourceContext import lms.core.stub.{Base, ScalaGenBase, CGenBase} import lms.core.Backend._ import lms.core.Backend.{Block => LMSBlock} +import lms.core.Graph import gensym.wasm.ast._ import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} @@ -88,20 +89,17 @@ trait StagedWasmEvaluator extends SAIOps { // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType // TODO: somehow the type of exitSize in residual program is nothing - val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size - def restK: Rep[Cont[Unit]] = topFun((_: Rep[Unit]) => { - Stack.reset(exitSize) + def restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { eval(rest, kont, trail) }) eval(inner, restK, restK :: trail) case Loop(ty, inner) => val funcTy = ty.funcType val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size - def restK = topFun((_: Rep[Unit]) => { - Stack.reset(exitSize) + def restK = fun((_: Rep[Unit]) => { eval(rest, kont, trail) }) - def loop : Rep[Unit => Unit] = topFun((_u: Rep[Unit]) => { + def loop : Rep[Unit => Unit] = fun((_u: Rep[Unit]) => { eval(inner, restK, loop :: trail) }) loop(()) @@ -110,8 +108,7 @@ trait StagedWasmEvaluator extends SAIOps { val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size val cond = Stack.pop() // TODO: can we avoid code duplication here? - def restK = topFun((_: Rep[Unit]) => { - Stack.reset(exitSize) + def restK = fun((_: Rep[Unit]) => { eval(rest, kont, trail) }) if (cond != Values.I32(0)) { @@ -182,8 +179,7 @@ trait StagedWasmEvaluator extends SAIOps { Frames.putAll(args) callee(trail.last) } else { - val restK: Rep[Cont[Unit]] = topFun((_: Rep[Unit]) => { - Stack.reset(returnSize) + val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { Frames.popFrame() eval(rest, kont, trail) }) @@ -278,7 +274,7 @@ trait StagedWasmEvaluator extends SAIOps { } "no-op".reflectCtrlWith[Unit]() } - val temp: Rep[Cont[Unit]] = topFun(haltK) + val temp: Rep[Cont[Unit]] = fun(haltK) evalTop(temp, main) } @@ -796,6 +792,310 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { } else { withStream(functionsStreams(id)._1)(f) } + + override def emitAll(g: Graph, name: String)(m1: Manifest[_], m2: Manifest[_]): Unit = { + val ng = init(g) + emitln(prelude) + emitln(""" + |/***************************************** + |Emitting Generated Code + |*******************************************/ + """.stripMargin) + emitln(""" +#include +#include +#include +#include +#include """) + val src = run(name, ng) + emit(src) + emitln(""" + |/***************************************** + |End of Generated Code + |*******************************************/ + |int main(int argc, char *argv[]) { + | Snippet(std::monostate{}); + | return 0; + |}""".stripMargin) + } + + val prelude = """ +#include +#include +#include +#include +#include +#include +#include +#include + +#define info(x, ...) + +class Num_t { +public: + virtual std::unique_ptr clone() const = 0; + + virtual void display() = 0; + virtual int32_t toInt() = 0; + virtual int64_t toLong() = 0; +}; + +class I32V_t : public Num_t { +public: + I32V_t(int32_t value) : value_(value) {} + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + + void display() override { std::cout << value_ << std::endl; } + + int32_t toInt() override { return value_; } + + int64_t toLong() override { return static_cast(value_); } + +private: + int32_t value_; +}; + +class I64V_t : public Num_t { +public: + I64V_t(int64_t value) : value_(value) {} + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + + void display() override { std::cout << value_ << std::endl; } + + int32_t toInt() override { return static_cast(value_); } + + int64_t toLong() override { return value_; } + +private: + int64_t value_; +}; + +struct Num { + std::unique_ptr num_ptr; + + // Constructions and destruction + Num() : num_ptr(nullptr) {} + + Num(std::unique_ptr num_ptr_) : num_ptr(std::move(num_ptr_)) {} + + Num &operator=(const Num &other) { + if (this != &other) { + num_ptr = other.num_ptr ? other.num_ptr->clone() : nullptr; + } + return *this; + } + + Num(const Num &other) { + num_ptr = other.num_ptr ? other.num_ptr->clone() : nullptr; + } + + Num(Num &&other) noexcept = default; + + Num &operator=(Num &&other) noexcept = default; + + ~Num() = default; + + int32_t toInt() const { return num_ptr->toInt(); } + + int32_t toLong() const { return num_ptr->toLong(); } + + void display() const { num_ptr->display(); } + + Num operator+(const Num &other) const { + if (dynamic_cast(num_ptr.get()) && + dynamic_cast(other.num_ptr.get())) { + return Num( + std::make_unique(I32V_t(this->toInt() + other.toInt()))); + } else if (dynamic_cast(num_ptr.get()) && + dynamic_cast(other.num_ptr.get())) { + return Num( + std::make_unique(I64V_t(this->toLong() + other.toLong()))); + } else { + throw std::runtime_error("Operands are of different types"); + } + } + + Num operator-(const Num &other) const { + if (dynamic_cast(num_ptr.get()) && + dynamic_cast(other.num_ptr.get())) { + return Num( + std::make_unique(I32V_t(this->toInt() - other.toInt()))); + } else if (dynamic_cast(num_ptr.get()) && + dynamic_cast(other.num_ptr.get())) { + return Num( + std::make_unique(I64V_t(this->toLong() - other.toLong()))); + } else { + throw std::runtime_error("Operands are of different types"); + } + } + + bool operator==(const Num &other) const { + if (dynamic_cast(num_ptr.get()) && + dynamic_cast(other.num_ptr.get())) { + return this->toInt() == other.toInt(); + } else if (dynamic_cast(num_ptr.get()) && + dynamic_cast(other.num_ptr.get())) { + return this->toLong() == other.toLong(); + } else { + throw std::runtime_error("Operands are of different types"); + } + } + + bool operator!=(const Num &other) const { return !(this->operator==(other)); } +}; + +static Num I32V(int v) { return Num(std::make_unique(v)); } + +static Num I64V(int64_t v) { return Num(std::make_unique(v)); } + +// struct Slice { +// int32_t start; +// int32_t end; +// Slice(int32_t start_, int32_t end_) : start(start_), end(end_) {} +// }; + +using Slice = std::vector; + +class Stack_t { +public: + void push(Num &&num) { + assert(num.num_ptr != nullptr); + stack_.push_back(std::move(num)); + } + + void push(Num &num) { + assert(num.num_ptr != nullptr); + stack_.push_back(num); + } + + Num pop() { + if (stack_.empty()) { + throw std::runtime_error("Stack underflow"); + } + Num num = std::move(stack_.back()); + assert(num.num_ptr != nullptr); + stack_.pop_back(); + return num; + } + + Num peek() { + if (stack_.empty()) { + throw std::runtime_error("Stack underflow"); + } + return stack_.back(); + } + + Num get(int32_t index) { + assert(index >= 0); + assert(index < stack_.size()); + return stack_[index]; +} + + int32_t size() { return stack_.size(); } + + void reset(int32_t size) { + if (size > stack_.size()) { + throw std::out_of_range("Invalid size"); + } + while (stack_.size() > size) { + stack_.pop_back(); + } + } + + Slice take(int32_t size) { + if (size > stack_.size()) { + throw std::out_of_range("Invalid size"); + } + // todo: avoid re-allocation + Slice slice(stack_.end() - size, stack_.end()); + stack_.resize(stack_.size() - size); + return slice; + } + + void print() { + std::cout << "Stack contents: " << std::endl; + for (const auto &num : stack_) { + num.display(); + } + } + + void initialize() { stack_.clear(); } + +private: + std::vector stack_; +}; +static Stack_t Stack; + +struct Frame_t { + std::vector locals; + + Frame_t(std::int32_t size) : locals() { locals.resize(size); } + Num &operator[](std::int32_t index) { + assert(index >= 0); + if (index >= locals.size()) { + throw std::out_of_range("Index out of range"); + } + return locals[index]; + } + void putAll(Slice slice) { + for (std::int32_t i = 0; i < slice.size(); ++i) { + locals[i] = slice[i]; + } + } +}; + +class Frames_t { +public: + std::monostate popFrame() { + if (!frames.empty()) { + frames.pop_back(); + return std::monostate{}; + } else { + std::cout << "No frames to pop." << std::endl; + throw std::runtime_error("No frames to pop."); + } + } + + Num get(std::int32_t index) { + auto ret = top()[index]; + assert(ret.num_ptr != nullptr); + return ret; + } + + void set(std::int32_t index, Num num) { frames.back()[index] = num; } + + Frame_t &top() { + if (frames.empty()) { + throw std::runtime_error("No frames available"); + } + return frames.back(); + } + + void pushFrame(std::int32_t size) { + Frame_t frame(size); + frames.push_back(frame); + } + + void putAll(Slice slice) { + top().putAll(slice); + } + +private: + std::vector frames; +}; + +static Frames_t Frames; + +static void initRand() { + // for now, just do nothing +} + """ } @@ -817,5 +1117,6 @@ object WasmToCppCompiler { } code.code } + } From d5ed20d52078a28bf66ba87139b9571ae558d049 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 17 May 2025 14:57:20 +0800 Subject: [PATCH 37/62] emit functions --- src/main/scala/wasm/StagedMiniWasm.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index aecd6391..774b921d 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -801,6 +801,7 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { |Emitting Generated Code |*******************************************/ """.stripMargin) + emitln(""" #include #include @@ -808,6 +809,9 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { #include #include """) val src = run(name, ng) + emitFunctionDecls(stream) + emitDatastructures(stream) + emitFunctions(stream) emit(src) emitln(""" |/***************************************** From 4fb5424e076019d2910849047fe3aafb4b06aa79 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 18 May 2025 10:36:42 +0800 Subject: [PATCH 38/62] read a dummy node to avoid lambda lifting it seems that the lambda lifting is unsound --- src/main/scala/wasm/StagedMiniWasm.scala | 35 +++++++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 774b921d..25037e27 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -88,18 +88,26 @@ trait StagedWasmEvaluator extends SAIOps { // no need to modify the stack when entering a block // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType + val dummy = "dummy".reflectCtrlWith[Unit]() // TODO: somehow the type of exitSize in residual program is nothing def restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + info(s"Exiting the block, stackSize =", Stack.size) + "dummy-op".reflectCtrlWith[Unit](dummy) eval(rest, kont, trail) }) eval(inner, restK, restK :: trail) case Loop(ty, inner) => val funcTy = ty.funcType val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size + val dummy = "dummy".reflectCtrlWith[Unit]() def restK = fun((_: Rep[Unit]) => { + "dummy-op".reflectCtrlWith[Unit](dummy) + info(s"Exiting the loop, stackSize =", Stack.size) eval(rest, kont, trail) }) def loop : Rep[Unit => Unit] = fun((_u: Rep[Unit]) => { + "dummy-op".reflectCtrlWith[Unit](dummy) + info(s"Entered the loop, stackSize =", Stack.size) eval(inner, restK, loop :: trail) }) loop(()) @@ -107,8 +115,11 @@ trait StagedWasmEvaluator extends SAIOps { val funcTy = ty.funcType val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size val cond = Stack.pop() + val dummy = "dummy".reflectCtrlWith[Unit]() // TODO: can we avoid code duplication here? def restK = fun((_: Rep[Unit]) => { + "dummy-op".reflectCtrlWith[Unit](dummy) + info(s"Exiting the if, stackSize =", Stack.size) eval(rest, kont, trail) }) if (cond != Values.I32(0)) { @@ -121,7 +132,7 @@ trait StagedWasmEvaluator extends SAIOps { trail(label)(()) case BrIf(label) => val cond = Stack.pop() - info(s"The br_if(${label})'s condition is ", cond) + info(s"The br_if(${label})'s condition is ", cond.toInt) if (cond != Values.I32(0)) { info(s"Jump to $label") trail(label)(()) @@ -157,14 +168,14 @@ trait StagedWasmEvaluator extends SAIOps { case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => val returnSize = Stack.size - ty.inps.size + ty.out.size val args = Stack.take(ty.inps.size) - info("New frame:", Frames.top) + // info("New frame:", Frames.top) val callee = if (compileCache.contains(funcIndex)) { compileCache(funcIndex) } else { val callee = topFun( (kont: Rep[Cont[Unit]]) => { - info(s"Entered the function at $funcIndex, stackSize =", Stack.size, ", frame =", Frames.top) + info(s"Entered the function at $funcIndex, stackSize =", Stack.size) eval(body, kont, kont::Nil): Rep[Unit] } ) @@ -180,6 +191,7 @@ trait StagedWasmEvaluator extends SAIOps { callee(trail.last) } else { val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + info(s"Exiting the function at $funcIndex, stackSize =", Stack.size) Frames.popFrame() eval(rest, kont, trail) }) @@ -269,6 +281,7 @@ trait StagedWasmEvaluator extends SAIOps { def evalTop(main: Option[String], printRes: Boolean = false): Rep[Unit] = { val haltK: Rep[Unit] => Rep[Unit] = (_) => { + info("Exiting the program...") if (printRes) { Stack.print() } @@ -773,6 +786,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { shallow(lhs); emit(" >= "); shallow(rhs) case Node(_, "num-to-int", List(num), _) => shallow(num); emit(".toInt()") + case Node(_, "dummy", _, _) => emit("std::monostate()") + case Node(_, "dummy-op", _, _) => emit("std::monostate()") case Node(_, "no-op", _, _) => emit("std::monostate()") case _ => super.shallow(n) @@ -833,7 +848,19 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { #include #include -#define info(x, ...) +void info() { +#ifdef DEBUG + std::cout << std::endl; +#endif +} + +template +void info(const T &first, const Args &...args) { +#ifdef DEBUG + std::cout << first << " "; + info(args...); +#endif +} class Num_t { public: From 8e293b8dbe5648d84669e39fbdefcd771c9f2cb0 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 18 May 2025 10:49:51 +0800 Subject: [PATCH 39/62] capture by value is not friendly with recursion --- src/main/scala/wasm/StagedMiniWasm.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 25037e27..7920c0dc 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -727,6 +727,14 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(");\n") case Node(_, "global-set", List(i, value), _) => emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(");\n") + case n @ Node(f, "λ", (b: LMSBlock)::rest, _) => + // Node: This code is copied from the traverse of CppSAICodeGenBase.scala, try to avoid code duplication + val retType = remap(typeBlockRes(b.res)) + val argTypes = b.in.map(a => remap(typeMap(a))).mkString(", ") + emitln(s"std::function<$retType(${argTypes})> ${quote(f)};") + emit(quote(f)); emit(" = ") + quoteTypedBlock(b, false, true, capture = "&") + emitln(";") case _ => super.traverse(n) } From 51dd632c3fb990f0cabdb7134b2e63138b09a14c Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 18 May 2025 11:00:08 +0800 Subject: [PATCH 40/62] redirect generated code to a file --- src/main/scala/wasm/StagedMiniWasm.scala | 4 ++-- src/test/scala/genwasym/TestStagedEval.scala | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 7920c0dc..2bea6cd5 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -681,7 +681,7 @@ object Main { object WasmToScalaCompiler { - def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { + def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { println(s"Now compiling wasm module with entry function $main") val code = new WasmToScalaCompilerDriver[Unit, Unit] { def module: ModuleInstance = moduleInst @@ -1146,7 +1146,7 @@ trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEv } object WasmToCppCompiler { - def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { + def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { println(s"Now compiling wasm module with entry function $main") val code = new WasmToCppCompilerDriver[Unit, Unit] { def module: ModuleInstance = moduleInst diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index c96f9b7e..a74fbc6f 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -10,7 +10,7 @@ import gensym.wasm.miniwasm._ class TestStagedEval extends FunSuite { def testFileToScala(filename: String, main: Option[String] = None, printRes: Boolean = false) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) - val code = WasmToScalaCompiler(moduleInst, main, true) + val code = WasmToScalaCompiler.compile(moduleInst, main, true) println(code) } @@ -26,7 +26,15 @@ class TestStagedEval extends FunSuite { def testFileToCpp(filename: String, main: Option[String] = None, printRes: Boolean = false) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) - val code = WasmToCppCompiler(moduleInst, main, true) + val code = WasmToCppCompiler.compile(moduleInst, main, true) + if (printRes) { + val writer = new java.io.PrintWriter(new java.io.File(s"$filename.cpp")) + try { + writer.write(code) + } finally { + writer.close() + } + } println(code) } From 2c6d5f66822832e3e39b548d9f53f7feb619f697 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 19 May 2025 11:06:38 +0800 Subject: [PATCH 41/62] fix printing logic in test --- src/test/scala/genwasym/TestStagedEval.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index a74fbc6f..db45f1e5 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -28,14 +28,14 @@ class TestStagedEval extends FunSuite { val moduleInst = ModuleInstance(Parser.parseFile(filename)) val code = WasmToCppCompiler.compile(moduleInst, main, true) if (printRes) { - val writer = new java.io.PrintWriter(new java.io.File(s"$filename.cpp")) - try { - writer.write(code) - } finally { - writer.close() - } + println(code) + } + val writer = new java.io.PrintWriter(new java.io.File(s"$filename.cpp")) + try { + writer.write(code) + } finally { + writer.close() } - println(code) } test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } From 5dbc219919bb1dfd50360130db7afc7c5370fc36 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 19 May 2025 14:07:01 +0800 Subject: [PATCH 42/62] extract the dummy writing pattern as a function --- src/main/scala/wasm/StagedMiniWasm.scala | 26 +++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 2bea6cd5..4067e5e5 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -28,6 +28,14 @@ trait StagedWasmEvaluator extends SAIOps { // a cache storing the compiled code for each function, to reduce re-compilation val compileCache = new HashMap[Int, Rep[(Cont[Unit]) => Unit]] + def funHere[A:Manifest,B:Manifest](f: Rep[A] => Rep[B], dummy: Rep[Unit] = "dummy".reflectCtrlWith[Unit]()): Rep[A => B] = { + // to avoid LMS lifting a function, we create a dummy node and read it inside function + fun((x: Rep[A]) => { + "dummy-op".reflectCtrlWith[Unit](dummy) + f(x) + }) + } + // NOTE: We don't support Ans type polymorphism yet def eval(insts: List[Instr], kont: Rep[Cont[Unit]], @@ -88,37 +96,31 @@ trait StagedWasmEvaluator extends SAIOps { // no need to modify the stack when entering a block // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType - val dummy = "dummy".reflectCtrlWith[Unit]() // TODO: somehow the type of exitSize in residual program is nothing - def restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + def restK: Rep[Cont[Unit]] = funHere((_: Rep[Unit]) => { info(s"Exiting the block, stackSize =", Stack.size) - "dummy-op".reflectCtrlWith[Unit](dummy) eval(rest, kont, trail) }) eval(inner, restK, restK :: trail) case Loop(ty, inner) => val funcTy = ty.funcType val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size - val dummy = "dummy".reflectCtrlWith[Unit]() - def restK = fun((_: Rep[Unit]) => { - "dummy-op".reflectCtrlWith[Unit](dummy) + def restK = funHere((_: Rep[Unit]) => { info(s"Exiting the loop, stackSize =", Stack.size) eval(rest, kont, trail) }) - def loop : Rep[Unit => Unit] = fun((_u: Rep[Unit]) => { - "dummy-op".reflectCtrlWith[Unit](dummy) + val dummy = "dummy".reflectCtrlWith[Unit]() + def loop : Rep[Unit => Unit] = funHere((_u: Rep[Unit]) => { info(s"Entered the loop, stackSize =", Stack.size) eval(inner, restK, loop :: trail) - }) + }, dummy) // <-- if we don't pass this dummy argument, lots of code will be generated loop(()) case If(ty, thn, els) => val funcTy = ty.funcType val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size val cond = Stack.pop() - val dummy = "dummy".reflectCtrlWith[Unit]() // TODO: can we avoid code duplication here? - def restK = fun((_: Rep[Unit]) => { - "dummy-op".reflectCtrlWith[Unit](dummy) + def restK = funHere((_: Rep[Unit]) => { info(s"Exiting the if, stackSize =", Stack.size) eval(rest, kont, trail) }) From a0d31e54cc07bc6bc279ad5ca26fef14416abe40 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 19 May 2025 15:09:34 +0800 Subject: [PATCH 43/62] don't inline stack-pop to improve readability --- src/main/scala/wasm/StagedMiniWasm.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 4067e5e5..3ffedbfd 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -696,6 +696,11 @@ object WasmToScalaCompiler { } trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { + override def mayInline(n: Node): Boolean = n match { + case Node(s, "stack-pop", _, _) => false + case _ => super.mayInline(n) + } + override def remap(m: Manifest[_]): String = { if (m.toString.endsWith("Num")) "Num" else if (m.toString.endsWith("Slice")) "Slice" From 39baa4aa6d5187ca347db5c25b56c3f201684735 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 20 May 2025 10:45:27 +0800 Subject: [PATCH 44/62] make topFun work --- src/main/scala/wasm/StagedMiniWasm.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 3ffedbfd..aa33a0d7 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -7,7 +7,7 @@ import lms.core.virtualize import lms.macros.SourceContext import lms.core.stub.{Base, ScalaGenBase, CGenBase} import lms.core.Backend._ -import lms.core.Backend.{Block => LMSBlock} +import lms.core.Backend.{Block => LMSBlock, Const => LMSConst} import lms.core.Graph import gensym.wasm.ast._ @@ -734,8 +734,11 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(");\n") case Node(_, "global-set", List(i, value), _) => emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(");\n") + // Note: The following code is copied from the traverse of CppBackend.scala, try to avoid duplicated code + case n @ Node(f, "λ", (b: LMSBlock)::LMSConst(0)::rest, _) => + // TODO: Is a leading block followed by 0 a hint for top function? + super.traverse(n) case n @ Node(f, "λ", (b: LMSBlock)::rest, _) => - // Node: This code is copied from the traverse of CppSAICodeGenBase.scala, try to avoid code duplication val retType = remap(typeBlockRes(b.res)) val argTypes = b.in.map(a => remap(typeMap(a))).mkString(", ") emitln(s"std::function<$retType(${argTypes})> ${quote(f)};") From 598580378aefa59e8b701066a2f4f557178d3d5c Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 20 May 2025 20:26:33 +0800 Subject: [PATCH 45/62] update runtime --- src/main/scala/wasm/StagedMiniWasm.scala | 104 ++++------------------- 1 file changed, 16 insertions(+), 88 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index aa33a0d7..015f36a7 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -926,82 +926,20 @@ private: }; struct Num { - std::unique_ptr num_ptr; - - // Constructions and destruction - Num() : num_ptr(nullptr) {} - - Num(std::unique_ptr num_ptr_) : num_ptr(std::move(num_ptr_)) {} - - Num &operator=(const Num &other) { - if (this != &other) { - num_ptr = other.num_ptr ? other.num_ptr->clone() : nullptr; - } - return *this; - } - - Num(const Num &other) { - num_ptr = other.num_ptr ? other.num_ptr->clone() : nullptr; - } - - Num(Num &&other) noexcept = default; - - Num &operator=(Num &&other) noexcept = default; - - ~Num() = default; - - int32_t toInt() const { return num_ptr->toInt(); } - - int32_t toLong() const { return num_ptr->toLong(); } - - void display() const { num_ptr->display(); } - - Num operator+(const Num &other) const { - if (dynamic_cast(num_ptr.get()) && - dynamic_cast(other.num_ptr.get())) { - return Num( - std::make_unique(I32V_t(this->toInt() + other.toInt()))); - } else if (dynamic_cast(num_ptr.get()) && - dynamic_cast(other.num_ptr.get())) { - return Num( - std::make_unique(I64V_t(this->toLong() + other.toLong()))); - } else { - throw std::runtime_error("Operands are of different types"); - } - } - - Num operator-(const Num &other) const { - if (dynamic_cast(num_ptr.get()) && - dynamic_cast(other.num_ptr.get())) { - return Num( - std::make_unique(I32V_t(this->toInt() - other.toInt()))); - } else if (dynamic_cast(num_ptr.get()) && - dynamic_cast(other.num_ptr.get())) { - return Num( - std::make_unique(I64V_t(this->toLong() - other.toLong()))); - } else { - throw std::runtime_error("Operands are of different types"); - } - } - - bool operator==(const Num &other) const { - if (dynamic_cast(num_ptr.get()) && - dynamic_cast(other.num_ptr.get())) { - return this->toInt() == other.toInt(); - } else if (dynamic_cast(num_ptr.get()) && - dynamic_cast(other.num_ptr.get())) { - return this->toLong() == other.toLong(); - } else { - throw std::runtime_error("Operands are of different types"); - } - } - - bool operator!=(const Num &other) const { return !(this->operator==(other)); } + Num(int64_t value) : value(value) {} + Num() : value(0) {} + int64_t value; + int32_t toInt() { return static_cast(value); } + + bool operator==(const Num &other) const { return value == other.value; } + bool operator!=(const Num &other) const { return !(*this == other); } + Num operator+(const Num &other) const { return Num(value + other.value); } + Num operator-(const Num &other) const { return Num(value - other.value); } }; -static Num I32V(int v) { return Num(std::make_unique(v)); } +static Num I32V(int v) { return v; } -static Num I64V(int64_t v) { return Num(std::make_unique(v)); } +static Num I64V(int64_t v) { return v; } // struct Slice { // int32_t start; @@ -1013,22 +951,15 @@ using Slice = std::vector; class Stack_t { public: - void push(Num &&num) { - assert(num.num_ptr != nullptr); - stack_.push_back(std::move(num)); - } + void push(Num &&num) { stack_.push_back(std::move(num)); } - void push(Num &num) { - assert(num.num_ptr != nullptr); - stack_.push_back(num); - } + void push(Num &num) { stack_.push_back(num); } Num pop() { if (stack_.empty()) { throw std::runtime_error("Stack underflow"); } Num num = std::move(stack_.back()); - assert(num.num_ptr != nullptr); stack_.pop_back(); return num; } @@ -1044,7 +975,7 @@ public: assert(index >= 0); assert(index < stack_.size()); return stack_[index]; -} + } int32_t size() { return stack_.size(); } @@ -1070,7 +1001,7 @@ public: void print() { std::cout << "Stack contents: " << std::endl; for (const auto &num : stack_) { - num.display(); + std::cout << num.value << " "; } } @@ -1113,7 +1044,6 @@ public: Num get(std::int32_t index) { auto ret = top()[index]; - assert(ret.num_ptr != nullptr); return ret; } @@ -1131,9 +1061,7 @@ public: frames.push_back(frame); } - void putAll(Slice slice) { - top().putAll(slice); - } + void putAll(Slice slice) { top().putAll(slice); } private: std::vector frames; From d314ebede766f9ceac37055eb6430f2bee1679e1 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 20 May 2025 23:05:38 +0800 Subject: [PATCH 46/62] add all passed test cases --- src/main/scala/wasm/StagedMiniWasm.scala | 78 ++++++-------------- src/test/scala/genwasym/TestStagedEval.scala | 66 +++++++++++++++-- 2 files changed, 83 insertions(+), 61 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 015f36a7..f89573d8 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -68,6 +68,8 @@ trait StagedWasmEvaluator extends SAIOps { case _ => throw new Exception("Cannot set immutable global") } eval(rest, kont, trail) + case Store(op) => ??? + case Load(op) => ??? case MemorySize => ??? case MemoryGrow => ??? case MemoryFill => ??? @@ -156,6 +158,7 @@ trait StagedWasmEvaluator extends SAIOps { case Call(f) => evalCall(rest, kont, trail, f, false) case ReturnCall(f) => evalCall(rest, kont, trail, f, true) case _ => + ??? val todo = "todo-op".reflectCtrlWith[Unit]() eval(rest, kont, trail) } @@ -192,7 +195,7 @@ trait StagedWasmEvaluator extends SAIOps { Frames.putAll(args) callee(trail.last) } else { - val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + val restK: Rep[Cont[Unit]] = funHere((_: Rep[Unit]) => { info(s"Exiting the function at $funcIndex, stackSize =", Stack.size) Frames.popFrame() eval(rest, kont, trail) @@ -207,7 +210,7 @@ trait StagedWasmEvaluator extends SAIOps { | Import("spectest", "print_i32", _) => //println(s"[DEBUG] current stack: $stack") val v = Stack.pop() - println(v) + println(v.toInt) eval(rest, kont, trail) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") case _ => throw new Exception(s"Definition at $funcIndex is not callable") @@ -880,51 +883,6 @@ void info(const T &first, const Args &...args) { #endif } -class Num_t { -public: - virtual std::unique_ptr clone() const = 0; - - virtual void display() = 0; - virtual int32_t toInt() = 0; - virtual int64_t toLong() = 0; -}; - -class I32V_t : public Num_t { -public: - I32V_t(int32_t value) : value_(value) {} - - std::unique_ptr clone() const override { - return std::make_unique(*this); - } - - void display() override { std::cout << value_ << std::endl; } - - int32_t toInt() override { return value_; } - - int64_t toLong() override { return static_cast(value_); } - -private: - int32_t value_; -}; - -class I64V_t : public Num_t { -public: - I64V_t(int64_t value) : value_(value) {} - - std::unique_ptr clone() const override { - return std::make_unique(*this); - } - - void display() override { std::cout << value_ << std::endl; } - - int32_t toInt() override { return static_cast(value_); } - - int64_t toLong() override { return value_; } - -private: - int64_t value_; -}; - struct Num { Num(int64_t value) : value(value) {} Num() : value(0) {} @@ -935,18 +893,23 @@ struct Num { bool operator!=(const Num &other) const { return !(*this == other); } Num operator+(const Num &other) const { return Num(value + other.value); } Num operator-(const Num &other) const { return Num(value - other.value); } + Num operator*(const Num &other) const { return Num(value * other.value); } + Num operator/(const Num &other) const { + if (other.value == 0) { + throw std::runtime_error("Division by zero"); + } + return Num(value / other.value); + } + Num operator<(const Num &other) const { return Num(value < other.value); } + Num operator<=(const Num &other) const { return Num(value <= other.value); } + Num operator>(const Num &other) const { return Num(value > other.value); } + Num operator>=(const Num &other) const { return Num(value >= other.value); } }; static Num I32V(int v) { return v; } static Num I64V(int64_t v) { return v; } -// struct Slice { -// int32_t start; -// int32_t end; -// Slice(int32_t start_, int32_t end_) : start(start_), end(end_) {} -// }; - using Slice = std::vector; class Stack_t { @@ -1000,8 +963,8 @@ public: void print() { std::cout << "Stack contents: " << std::endl; - for (const auto &num : stack_) { - std::cout << num.value << " "; + for (auto it = stack_.rbegin(); it != stack_.rend(); ++it) { + std::cout << it->value << std::endl; } } @@ -1072,6 +1035,11 @@ static Frames_t Frames; static void initRand() { // for now, just do nothing } + +static std::monostate unreachable() { + std::cout << "Unreachable code reached!" << std::endl; + throw std::runtime_error("Unreachable code reached"); +} """ } diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index db45f1e5..ecc27b12 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -24,21 +24,75 @@ class TestStagedEval extends FunSuite { testFileToScala("./benchmarks/wasm/staged/pop.wat") } - def testFileToCpp(filename: String, main: Option[String] = None, printRes: Boolean = false) = { + def testFileToCpp(filename: String, main: Option[String] = None, expect: Option[List[Float]]=None) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) val code = WasmToCppCompiler.compile(moduleInst, main, true) - if (printRes) { - println(code) - } - val writer = new java.io.PrintWriter(new java.io.File(s"$filename.cpp")) + + val cppFile = s"$filename.cpp" + + val writer = new java.io.PrintWriter(new java.io.File(cppFile)) try { writer.write(code) } finally { writer.close() } + import sys.process._ + + val exe = s"$cppFile.exe" + val command = s"g++ -o $exe $cppFile" + + if (command.! != 0) { + throw new RuntimeException(s"Compilation failed for $cppFile") + } + + val result = s"./$exe".!! + println(result) + + expect.map(vs => { + val stackValues = result + .split("Stack contents: \n")(1) + .split("\n") + .map(_.toFloat) + .toList + assert(vs == stackValues) + }) + } + + test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main"), expect=Some(List(7))) } + test("power") { testFileToCpp("./benchmarks/wasm/pow.wat", Some("real_main"), expect=Some(List(1024))) } + test("start") { testFileToCpp("./benchmarks/wasm/start.wat") } + test("fact") { testFileToCpp("./benchmarks/wasm/fact.wat", None, expect=Some(List(120))) } + test("loop") { testFileToCpp("./benchmarks/wasm/loop.wat", None, expect=Some(List(10))) } + test("even-odd") { testFileToCpp("./benchmarks/wasm/even_odd.wat", None, expect=Some(List(1))) } + // test("load") { testFileToCpp("./benchmarks/wasm/load.wat", None, expect=Some(List(1))) } + // test("btree") { testFileToCpp("./benchmarks/wasm/btree/2o1u-unlabeled.wat") } + test("fib") { testFileToCpp("./benchmarks/wasm/fib.wat", None, expect=Some(List(144))) } + test("tribonacci") { testFileToCpp("./benchmarks/wasm/tribonacci.wat", None, expect=Some(List(504))) } + + test("return") { + intercept[java.lang.RuntimeException] { + testFileToCpp("./benchmarks/wasm/return.wat", Some("$real_main")) + } + } + test("return_call") { + testFileToCpp("./benchmarks/wasm/sum.wat", Some("sum10"), expect=Some(List(55))) } - test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } + test("block input") { + testFileToCpp("./benchmarks/wasm/block.wat", Some("real_main"), expect=Some(List(9))) + } + test("loop block input") { + testFileToCpp("./benchmarks/wasm/block.wat", Some("test_loop_input"), expect=Some(List(55))) + } + test("if block input") { + testFileToCpp("./benchmarks/wasm/block.wat", Some("test_if_input"), expect=Some(List(25))) + } + // test("block input - poly br") { + // testFileToCpp("./benchmarks/wasm/block.wat", Some("test_poly_br"), expect=Some(List(0))) + // } + // test("loop block - poly br") { + // testFileToCpp("./benchmarks/wasm/loop_poly.wat", None, expect=Some(List(2, 1))) + // } test("brtable-cpp") { testFileToCpp("./benchmarks/wasm/staged/brtable.wat") From 1f902e0b26e6d8d14e4a79cae85641342cebeee1 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 20 May 2025 23:42:55 +0800 Subject: [PATCH 47/62] store/load operation --- src/main/scala/wasm/StagedMiniWasm.scala | 49 +++++++++++++++++++- src/test/scala/genwasym/TestStagedEval.scala | 2 +- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index f89573d8..f167618b 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -68,8 +68,16 @@ trait StagedWasmEvaluator extends SAIOps { case _ => throw new Exception("Cannot set immutable global") } eval(rest, kont, trail) - case Store(op) => ??? - case Load(op) => ??? + case Store(StoreOp(align, offset, ty, None)) => + val value = Stack.pop() + val addr = Stack.pop() + Memory.storeInt(addr.toInt, offset, value.toInt) + eval(rest, kont, trail) + case Load(LoadOp(align, offset, ty, None, None)) => + val addr = Stack.pop() + val value = Memory.loadInt(addr.toInt, offset) + Stack.push(Values.I32(value)) + eval(rest, kont, trail) case MemorySize => ??? case MemoryGrow => ??? case MemoryFill => ??? @@ -361,6 +369,15 @@ trait StagedWasmEvaluator extends SAIOps { } } + object Memory { + def storeInt(base: Rep[Int], offset: Int, value: Rep[Int]): Rep[Unit] = { + "memory-store-int".reflectCtrlWith[Unit](base, offset, value) + } + + def loadInt(base: Rep[Int], offset: Int): Rep[Int] = { + "memory-load-int".reflectCtrlWith[Int](base, offset) + } + } // call unreachable def unreachable(): Rep[Unit] = { @@ -765,6 +782,10 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Stack.take("); shallow(n); emit(")") case Node(_, "slice-reverse", List(slice), _) => shallow(slice); emit(".reverse") + case Node(_, "memory-store-int", List(base, offset, value), _) => + emit("Memory.storeInt("); shallow(base); emit(", "); shallow(offset); emit(", "); shallow(value); emit(")") + case Node(_, "memory-load-int", List(base, offset), _) => + emit("Memory.loadInt("); shallow(base); emit(", "); shallow(offset); emit(")") case Node(_, "stack-size", _, _) => emit("Stack.size()") case Node(_, "global-get", List(i), _) => @@ -1040,6 +1061,30 @@ static std::monostate unreachable() { std::cout << "Unreachable code reached!" << std::endl; throw std::runtime_error("Unreachable code reached"); } + +struct Memory_t { + void *memory; + Memory_t(size_t size) { + memory = malloc(size); + if (!memory) { + throw std::runtime_error("Memory allocation failed"); + } + } + ~Memory_t() { free(memory); } + + int32_t loadInt(int32_t base, int32_t offset) { + return *reinterpret_cast(static_cast(memory) + base + + offset); + } + + std::monostate storeInt(int32_t base, int32_t offset, int32_t value) { + *reinterpret_cast(static_cast(memory) + base + + offset) = value; + return std::monostate{}; + } +}; + +static Memory_t Memory(1024 * 1024); // 1MB memory """ } diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index ecc27b12..051562c7 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -64,7 +64,7 @@ class TestStagedEval extends FunSuite { test("fact") { testFileToCpp("./benchmarks/wasm/fact.wat", None, expect=Some(List(120))) } test("loop") { testFileToCpp("./benchmarks/wasm/loop.wat", None, expect=Some(List(10))) } test("even-odd") { testFileToCpp("./benchmarks/wasm/even_odd.wat", None, expect=Some(List(1))) } - // test("load") { testFileToCpp("./benchmarks/wasm/load.wat", None, expect=Some(List(1))) } + test("load") { testFileToCpp("./benchmarks/wasm/load.wat", None, expect=Some(List(1))) } // test("btree") { testFileToCpp("./benchmarks/wasm/btree/2o1u-unlabeled.wat") } test("fib") { testFileToCpp("./benchmarks/wasm/fib.wat", None, expect=Some(List(144))) } test("tribonacci") { testFileToCpp("./benchmarks/wasm/tribonacci.wat", None, expect=Some(List(504))) } From 6a1db1db2f1134f3410ca5b4e6395298b48860bd Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 21 May 2025 00:46:21 +0800 Subject: [PATCH 48/62] more memory operations --- src/main/scala/wasm/StagedMiniWasm.scala | 89 +++++++++++++------- src/test/scala/genwasym/TestStagedEval.scala | 3 +- 2 files changed, 60 insertions(+), 32 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index f167618b..9f24a76f 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -79,7 +79,9 @@ trait StagedWasmEvaluator extends SAIOps { Stack.push(Values.I32(value)) eval(rest, kont, trail) case MemorySize => ??? - case MemoryGrow => ??? + case MemoryGrow => + val delta = Stack.pop() + Stack.push(Values.I32(Memory.grow(delta.toInt))) case MemoryFill => ??? case Nop => eval(rest, kont, trail) @@ -244,7 +246,10 @@ trait StagedWasmEvaluator extends SAIOps { // case ShrS(_) => v1 >> v2 // TODO: signed shift right case ShrU(_) => v1 >> v2 case And(_) => v1 & v2 - case _ => ??? + case DivS(_) => v1 / v2 + case DivU(_) => v1 / v2 + case _ => + throw new Exception(s"Unknown binary operation $op") } def evalRelOp(op: RelOp, v1: Rep[Num], v2: Rep[Num]): Rep[Num] = op match { @@ -377,6 +382,10 @@ trait StagedWasmEvaluator extends SAIOps { def loadInt(base: Rep[Int], offset: Int): Rep[Int] = { "memory-load-int".reflectCtrlWith[Int](base, offset) } + + def grow(delta: Rep[Int]): Rep[Int] = { + "memory-grow".reflectCtrlWith[Int](delta) + } } // call unreachable @@ -474,8 +483,6 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { } override def traverse(n: Node): Unit = n match { - case Node(_, "stack-push", List(value), _) => - emit("Stack.push("); shallow(value); emit(")\n") case Node(_, "stack-drop", List(n), _) => emit("Stack.drop("); shallow(n); emit(")\n") case Node(_, "stack-reset", List(n), _) => @@ -503,6 +510,8 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { emit("Frames.get("); shallow(i); emit(")") case Node(_, "frame-pop", _, _) => emit("Frames.popFrame()") + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(")\n") case Node(_, "stack-pop", _, _) => emit("Stack.pop()") case Node(_, "stack-peek", _, _) => @@ -734,10 +743,6 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { // for now, the traverse/shallow is same as the scala backend's override def traverse(n: Node): Unit = n match { - case Node(_, "stack-push", List(value), _) => - emit("Stack.push("); shallow(value); emit(");\n") - case Node(_, "stack-drop", List(n), _) => - emit("Stack.drop("); shallow(n); emit(");\n") case Node(_, "stack-reset", List(n), _) => emit("Stack.reset("); shallow(n); emit(");\n") case Node(_, "stack-init", _, _) => @@ -755,16 +760,16 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { case Node(_, "global-set", List(i, value), _) => emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(");\n") // Note: The following code is copied from the traverse of CppBackend.scala, try to avoid duplicated code - case n @ Node(f, "λ", (b: LMSBlock)::LMSConst(0)::rest, _) => - // TODO: Is a leading block followed by 0 a hint for top function? - super.traverse(n) - case n @ Node(f, "λ", (b: LMSBlock)::rest, _) => - val retType = remap(typeBlockRes(b.res)) - val argTypes = b.in.map(a => remap(typeMap(a))).mkString(", ") - emitln(s"std::function<$retType(${argTypes})> ${quote(f)};") - emit(quote(f)); emit(" = ") - quoteTypedBlock(b, false, true, capture = "&") - emitln(";") + // case n @ Node(f, "λ", (b: LMSBlock)::LMSConst(0)::rest, _) => + // // TODO: Is a leading block followed by 0 a hint for top function? + // super.traverse(n) + // case n @ Node(f, "λ", (b: LMSBlock)::rest, _) => + // val retType = remap(typeBlockRes(b.res)) + // val argTypes = b.in.map(a => remap(typeMap(a))).mkString(", ") + // emitln(s"std::function<$retType(${argTypes})> ${quote(f)};") + // emit(quote(f)); emit(" = ") + // quoteTypedBlock(b, false, true, capture = "&") + // emitln(";") case _ => super.traverse(n) } @@ -772,6 +777,10 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { override def shallow(n: Node): Unit = n match { case Node(_, "frame-get", List(i), _) => emit("Frames.get("); shallow(i); emit(")") + case Node(_, "stack-drop", List(n), _) => + emit("Stack.drop("); shallow(n); emit(")") + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(")") case Node(_, "stack-pop", _, _) => emit("Stack.pop()") case Node(_, "frame-pop", _, _) => @@ -786,6 +795,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Memory.storeInt("); shallow(base); emit(", "); shallow(offset); emit(", "); shallow(value); emit(")") case Node(_, "memory-load-int", List(base, offset), _) => emit("Memory.loadInt("); shallow(base); emit(", "); shallow(offset); emit(")") + case Node(_, "memory-grow", List(delta), _) => + emit("Memory.grow("); shallow(delta); emit(")") case Node(_, "stack-size", _, _) => emit("Stack.size()") case Node(_, "global-get", List(i), _) => @@ -935,9 +946,15 @@ using Slice = std::vector; class Stack_t { public: - void push(Num &&num) { stack_.push_back(std::move(num)); } + std::monostate push(Num &&num) { + stack_.push_back(std::move(num)); + return std::monostate{}; + } - void push(Num &num) { stack_.push_back(num); } + std::monostate push(Num &num) { + stack_.push_back(num); + return std::monostate{}; + } Num pop() { if (stack_.empty()) { @@ -1063,27 +1080,37 @@ static std::monostate unreachable() { } struct Memory_t { - void *memory; - Memory_t(size_t size) { - memory = malloc(size); - if (!memory) { - throw std::runtime_error("Memory allocation failed"); - } - } - ~Memory_t() { free(memory); } + std::vector memory; + Memory_t(size_t size) : memory(size) {} int32_t loadInt(int32_t base, int32_t offset) { - return *reinterpret_cast(static_cast(memory) + base + - offset); + return *reinterpret_cast(static_cast(memory.data()) + + base + offset); } std::monostate storeInt(int32_t base, int32_t offset, int32_t value) { - *reinterpret_cast(static_cast(memory) + base + + *reinterpret_cast(static_cast(memory.data()) + base + offset) = value; return std::monostate{}; } + + // grow memory by delta bytes when bytes > 0. return -1 if failed, return old + // size when success + int32_t grow(int32_t delta) { + if (delta <= 0) { + return memory.size(); + } + + try { + memory.resize(memory.size() + delta); + return memory.size(); + } catch (const std::bad_alloc &e) { + return -1; + } + } }; + static Memory_t Memory(1024 * 1024); // 1MB memory """ } diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index 051562c7..d851f070 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -65,7 +65,8 @@ class TestStagedEval extends FunSuite { test("loop") { testFileToCpp("./benchmarks/wasm/loop.wat", None, expect=Some(List(10))) } test("even-odd") { testFileToCpp("./benchmarks/wasm/even_odd.wat", None, expect=Some(List(1))) } test("load") { testFileToCpp("./benchmarks/wasm/load.wat", None, expect=Some(List(1))) } - // test("btree") { testFileToCpp("./benchmarks/wasm/btree/2o1u-unlabeled.wat") } + // TODO: this case will fail because of some undefined variables + test("btree") { testFileToCpp("./benchmarks/wasm/btree/2o1u-unlabeled.wat") } test("fib") { testFileToCpp("./benchmarks/wasm/fib.wat", None, expect=Some(List(144))) } test("tribonacci") { testFileToCpp("./benchmarks/wasm/tribonacci.wat", None, expect=Some(List(504))) } From a2b63f9f133e05230cccdfc147cae0947322dd57 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 21 May 2025 14:30:19 +0800 Subject: [PATCH 49/62] some fixes --- src/main/scala/wasm/StagedMiniWasm.scala | 50 ++++++++++++++---------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 9f24a76f..2a5b82a2 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -40,7 +40,7 @@ trait StagedWasmEvaluator extends SAIOps { def eval(insts: List[Instr], kont: Rep[Cont[Unit]], trail: Trail[Unit]): Rep[Unit] = { - if (insts.isEmpty) return kont() + if (insts.isEmpty) return kont(()) val (inst, rest) = (insts.head, insts.tail) inst match { case Drop => @@ -82,10 +82,13 @@ trait StagedWasmEvaluator extends SAIOps { case MemoryGrow => val delta = Stack.pop() Stack.push(Values.I32(Memory.grow(delta.toInt))) + eval(rest, kont, trail) case MemoryFill => ??? case Nop => eval(rest, kont, trail) - case Unreachable => unreachable() + case Unreachable => + unreachable() + eval(rest, kont, trail) case Test(op) => val v = Stack.pop() Stack.push(evalTestOp(op, v)) @@ -108,7 +111,6 @@ trait StagedWasmEvaluator extends SAIOps { // no need to modify the stack when entering a block // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType - // TODO: somehow the type of exitSize in residual program is nothing def restK: Rep[Cont[Unit]] = funHere((_: Rep[Unit]) => { info(s"Exiting the block, stackSize =", Stack.size) eval(rest, kont, trail) @@ -116,7 +118,6 @@ trait StagedWasmEvaluator extends SAIOps { eval(inner, restK, restK :: trail) case Loop(ty, inner) => val funcTy = ty.funcType - val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size def restK = funHere((_: Rep[Unit]) => { info(s"Exiting the loop, stackSize =", Stack.size) eval(rest, kont, trail) @@ -129,10 +130,10 @@ trait StagedWasmEvaluator extends SAIOps { loop(()) case If(ty, thn, els) => val funcTy = ty.funcType - val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size val cond = Stack.pop() // TODO: can we avoid code duplication here? - def restK = funHere((_: Rep[Unit]) => { + // NOTE: if we define restK by `def` rather than val, some errors will be triggered + val restK = funHere((_: Rep[Unit]) => { info(s"Exiting the if, stackSize =", Stack.size) eval(rest, kont, trail) }) @@ -141,6 +142,7 @@ trait StagedWasmEvaluator extends SAIOps { } else { eval(els, restK, restK :: trail) } + () case Br(label) => info(s"Jump to $label") trail(label)(()) @@ -760,16 +762,16 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { case Node(_, "global-set", List(i, value), _) => emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(");\n") // Note: The following code is copied from the traverse of CppBackend.scala, try to avoid duplicated code - // case n @ Node(f, "λ", (b: LMSBlock)::LMSConst(0)::rest, _) => - // // TODO: Is a leading block followed by 0 a hint for top function? - // super.traverse(n) - // case n @ Node(f, "λ", (b: LMSBlock)::rest, _) => - // val retType = remap(typeBlockRes(b.res)) - // val argTypes = b.in.map(a => remap(typeMap(a))).mkString(", ") - // emitln(s"std::function<$retType(${argTypes})> ${quote(f)};") - // emit(quote(f)); emit(" = ") - // quoteTypedBlock(b, false, true, capture = "&") - // emitln(";") + case n @ Node(f, "λ", (b: LMSBlock)::LMSConst(0)::rest, _) => + // TODO: Is a leading block followed by 0 a hint for top function? + super.traverse(n) + case n @ Node(f, "λ", (b: LMSBlock)::rest, _) => + val retType = remap(typeBlockRes(b.res)) + val argTypes = b.in.map(a => remap(typeMap(a))).mkString(", ") + emitln(s"std::function<$retType(${argTypes})> ${quote(f)};") + emit(quote(f)); emit(" = ") + quoteTypedBlock(b, false, true, capture = "&") + emitln(";") case _ => super.traverse(n) } @@ -936,6 +938,7 @@ struct Num { Num operator<=(const Num &other) const { return Num(value <= other.value); } Num operator>(const Num &other) const { return Num(value > other.value); } Num operator>=(const Num &other) const { return Num(value >= other.value); } + Num operator&(const Num &other) const { return Num(value & other.value); } }; static Num I32V(int v) { return v; } @@ -991,7 +994,7 @@ public: Slice take(int32_t size) { if (size > stack_.size()) { - throw std::out_of_range("Invalid size"); + throw std::out_of_range("Invalid size: requested " + std::to_string(size) + ", stack size is " + std::to_string(stack_.size())); } // todo: avoid re-allocation Slice slice(stack_.end() - size, stack_.end()); @@ -1079,9 +1082,12 @@ static std::monostate unreachable() { throw std::runtime_error("Unreachable code reached"); } +static int32_t pagesize = 65536; +static int32_t page_count = 0; + struct Memory_t { std::vector memory; - Memory_t(size_t size) : memory(size) {} + Memory_t(int32_t init_page_count) : memory(init_page_count * pagesize) {} int32_t loadInt(int32_t base, int32_t offset) { return *reinterpret_cast(static_cast(memory.data()) + @@ -1102,7 +1108,9 @@ struct Memory_t { } try { - memory.resize(memory.size() + delta); + memory.resize(memory.size() + delta * pagesize); + auto old_page_count = page_count; + page_count += delta; return memory.size(); } catch (const std::bad_alloc &e) { return -1; @@ -1111,8 +1119,8 @@ struct Memory_t { }; -static Memory_t Memory(1024 * 1024); // 1MB memory - """ +static Memory_t Memory(1); // 1 page memory +""" } From d1ba8991f0dd251edac28b813b738eca1877ce33 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 21 May 2025 15:34:24 +0800 Subject: [PATCH 50/62] some little polish --- src/main/scala/wasm/StagedMiniWasm.scala | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 2a5b82a2..e52ae0e6 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -84,11 +84,8 @@ trait StagedWasmEvaluator extends SAIOps { Stack.push(Values.I32(Memory.grow(delta.toInt))) eval(rest, kont, trail) case MemoryFill => ??? - case Nop => - eval(rest, kont, trail) - case Unreachable => - unreachable() - eval(rest, kont, trail) + case Nop => eval(rest, kont, trail) + case Unreachable => unreachable() case Test(op) => val v = Stack.pop() Stack.push(evalTestOp(op, v)) @@ -132,7 +129,7 @@ trait StagedWasmEvaluator extends SAIOps { val funcTy = ty.funcType val cond = Stack.pop() // TODO: can we avoid code duplication here? - // NOTE: if we define restK by `def` rather than val, some errors will be triggered + // NOTE: if we define restK by `def` rather than val, the generated code will contain some undefined variables val restK = funHere((_: Rep[Unit]) => { info(s"Exiting the if, stackSize =", Stack.size) eval(rest, kont, trail) @@ -170,7 +167,6 @@ trait StagedWasmEvaluator extends SAIOps { case Call(f) => evalCall(rest, kont, trail, f, false) case ReturnCall(f) => evalCall(rest, kont, trail, f, true) case _ => - ??? val todo = "todo-op".reflectCtrlWith[Unit]() eval(rest, kont, trail) } @@ -185,7 +181,6 @@ trait StagedWasmEvaluator extends SAIOps { case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => val returnSize = Stack.size - ty.inps.size + ty.out.size val args = Stack.take(ty.inps.size) - // info("New frame:", Frames.top) val callee = if (compileCache.contains(funcIndex)) { compileCache(funcIndex) @@ -220,7 +215,6 @@ trait StagedWasmEvaluator extends SAIOps { } case Import("console", "log", _) | Import("spectest", "print_i32", _) => - //println(s"[DEBUG] current stack: $stack") val v = Stack.pop() println(v.toInt) eval(rest, kont, trail) @@ -311,7 +305,7 @@ trait StagedWasmEvaluator extends SAIOps { evalTop(temp, main) } - // stack creation and operations + // stack operations object Stack { def initialize(): Rep[Unit] = { "stack-init".reflectCtrlWith[Unit]() From 89d9a771cf56a94f5f344c82473dc6253c841c9f Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 21 May 2025 17:28:47 +0800 Subject: [PATCH 51/62] shift stack elements when exiting block instructions --- src/main/scala/wasm/StagedMiniWasm.scala | 37 ++++++++++++++++---- src/test/scala/genwasym/TestStagedEval.scala | 12 +++---- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index e52ae0e6..b33a05a2 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -108,29 +108,41 @@ trait StagedWasmEvaluator extends SAIOps { // no need to modify the stack when entering a block // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType - def restK: Rep[Cont[Unit]] = funHere((_: Rep[Unit]) => { + val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size + def restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { info(s"Exiting the block, stackSize =", Stack.size) + val offset = Stack.size - exitSize + Stack.shift(offset, funcTy.out.size) eval(rest, kont, trail) }) eval(inner, restK, restK :: trail) case Loop(ty, inner) => val funcTy = ty.funcType + val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size def restK = funHere((_: Rep[Unit]) => { info(s"Exiting the loop, stackSize =", Stack.size) + val offset = Stack.size - exitSize + Stack.shift(offset, funcTy.out.size) eval(rest, kont, trail) }) + val enterSize = Stack.size val dummy = "dummy".reflectCtrlWith[Unit]() def loop : Rep[Unit => Unit] = funHere((_u: Rep[Unit]) => { info(s"Entered the loop, stackSize =", Stack.size) + val offset = Stack.size - enterSize + Stack.shift(offset, funcTy.inps.size) eval(inner, restK, loop :: trail) }, dummy) // <-- if we don't pass this dummy argument, lots of code will be generated loop(()) case If(ty, thn, els) => val funcTy = ty.funcType val cond = Stack.pop() + val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size // TODO: can we avoid code duplication here? // NOTE: if we define restK by `def` rather than val, the generated code will contain some undefined variables val restK = funHere((_: Rep[Unit]) => { + val offset = Stack.size - exitSize + Stack.shift(offset, funcTy.out.size) info(s"Exiting the if, stackSize =", Stack.size) eval(rest, kont, trail) }) @@ -327,6 +339,12 @@ trait StagedWasmEvaluator extends SAIOps { "stack-drop".reflectCtrlWith[Unit](n) } + def shift(offset: Rep[Int], size: Rep[Int]): Rep[Unit] = { + if (offset > 0) { + "stack-shift".reflectCtrlWith[Unit](offset, size) + } + } + def print(): Rep[Unit] = { "stack-print".reflectCtrlWith[Unit]() } @@ -777,6 +795,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Stack.drop("); shallow(n); emit(")") case Node(_, "stack-push", List(value), _) => emit("Stack.push("); shallow(value); emit(")") + case Node(_, "stack-shift", List(offset, size), _) => + emit("Stack.shift("); shallow(offset); emit(", "); shallow(size); emit(")") case Node(_, "stack-pop", _, _) => emit("Stack.pop()") case Node(_, "frame-pop", _, _) => @@ -977,13 +997,18 @@ public: int32_t size() { return stack_.size(); } - void reset(int32_t size) { - if (size > stack_.size()) { - throw std::out_of_range("Invalid size"); + void shift(int32_t offset, int32_t size) { + if (offset < 0) { + throw std::out_of_range("Invalid offset: " + std::to_string(offset)); + } + if (size < 0) { + throw std::out_of_range("Invalid size: " + std::to_string(size)); } - while (stack_.size() > size) { - stack_.pop_back(); + // shift last `size` of numbers forward of `offset` + for (int32_t i = stack_.size() - size; i < stack_.size(); ++i) { + stack_[i - offset] = stack_[i]; } + stack_.resize(stack_.size() - offset); } Slice take(int32_t size) { diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index d851f070..039da366 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -88,12 +88,12 @@ class TestStagedEval extends FunSuite { test("if block input") { testFileToCpp("./benchmarks/wasm/block.wat", Some("test_if_input"), expect=Some(List(25))) } - // test("block input - poly br") { - // testFileToCpp("./benchmarks/wasm/block.wat", Some("test_poly_br"), expect=Some(List(0))) - // } - // test("loop block - poly br") { - // testFileToCpp("./benchmarks/wasm/loop_poly.wat", None, expect=Some(List(2, 1))) - // } + test("block input - poly br") { + testFileToCpp("./benchmarks/wasm/block.wat", Some("test_poly_br"), expect=Some(List(0))) + } + test("loop block - poly br") { + testFileToCpp("./benchmarks/wasm/loop_poly.wat", None, expect=Some(List(2, 1))) + } test("brtable-cpp") { testFileToCpp("./benchmarks/wasm/staged/brtable.wat") From 8b4429f9061f67b6cae0e65375b4ce122bbdd83a Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Fri, 23 May 2025 16:53:09 +0800 Subject: [PATCH 52/62] fix: evalTop should be aware of frame size --- src/main/scala/wasm/MiniWasm.scala | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/main/scala/wasm/MiniWasm.scala b/src/main/scala/wasm/MiniWasm.scala index 84a8bd88..0fb12790 100644 --- a/src/main/scala/wasm/MiniWasm.scala +++ b/src/main/scala/wasm/MiniWasm.scala @@ -417,31 +417,42 @@ case class Evaluator(module: ModuleInstance) { // If `main` is given, then we use that function as the entry point of the program; // otherwise, we look up the top-level `start` instruction to locate the entry point. def evalTop[Ans](halt: Cont[Ans], main: Option[String] = None): Ans = { - val instrs = main match { + val entryFuncDefs = main match { case Some(func_name) => module.defs.flatMap({ case Export(`func_name`, ExportFunc(fid)) => println(s"Entering function $main") module.funcs(fid) match { - case FuncDef(_, FuncBodyDef(_, _, _, body)) => body + case FuncDef(_, funcDef @ FuncBodyDef(_, _, _, _)) => Some(funcDef) case _ => throw new Exception("Entry function has no concrete body") } - case _ => List() + case _ => None }) case None => module.defs.flatMap({ case Start(id) => println(s"Entering unnamed function $id") module.funcs(id) match { - case FuncDef(_, FuncBodyDef(_, _, _, body)) => body + case FuncDef(_, funcDef @ FuncBodyDef(_, _, _, _)) => Some(funcDef) case _ => throw new Exception("Entry function has no concrete body") } - case _ => List() + case _ => None }) } - if (instrs.isEmpty) println("Warning: nothing is executed") - eval(instrs, List(), Frame(ArrayBuffer(I32V(0))), halt, List(halt)) + + entryFuncDefs match { + case FuncBodyDef(_, _, locals, body) :: Nil => + val frame = Frame(ArrayBuffer(locals.map(zero(_)): _*)) + if (body.isEmpty) println("Warning: nothing is executed") + eval(body, List(), frame, halt, List(halt)) + case Nil => + println("Warning: no entry point found") + halt(List()) + case _ => + println("Warning: multiple entry points found") + halt(List()) + } } def evalTop(m: ModuleInstance): Unit = evalTop(stack => ()) From 8f4d6e027e6019ed679790636c515ce7490848ec Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Fri, 23 May 2025 17:04:36 +0800 Subject: [PATCH 53/62] comment IO statements --- src/main/scala/wasm/MiniWasm.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/scala/wasm/MiniWasm.scala b/src/main/scala/wasm/MiniWasm.scala index 0fb12790..4b672bd4 100644 --- a/src/main/scala/wasm/MiniWasm.scala +++ b/src/main/scala/wasm/MiniWasm.scala @@ -421,7 +421,7 @@ case class Evaluator(module: ModuleInstance) { case Some(func_name) => module.defs.flatMap({ case Export(`func_name`, ExportFunc(fid)) => - println(s"Entering function $main") + // println(s"Entering function $main") module.funcs(fid) match { case FuncDef(_, funcDef @ FuncBodyDef(_, _, _, _)) => Some(funcDef) case _ => throw new Exception("Entry function has no concrete body") @@ -431,7 +431,7 @@ case class Evaluator(module: ModuleInstance) { case None => module.defs.flatMap({ case Start(id) => - println(s"Entering unnamed function $id") + // println(s"Entering unnamed function $id") module.funcs(id) match { case FuncDef(_, funcDef @ FuncBodyDef(_, _, _, _)) => Some(funcDef) case _ => From e68218cf65ea9873240565a726a0b30f584b0ca5 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Fri, 23 May 2025 17:11:19 +0800 Subject: [PATCH 54/62] benchmark code --- benchmarks/wasm/performance/ack.wat | 70 ++++++++++++++++++++ benchmarks/wasm/performance/pow.wat | 55 +++++++++++++++ src/test/scala/genwasym/TestStagedEval.scala | 63 ++++++++++++++++++ 3 files changed, 188 insertions(+) create mode 100644 benchmarks/wasm/performance/ack.wat create mode 100644 benchmarks/wasm/performance/pow.wat diff --git a/benchmarks/wasm/performance/ack.wat b/benchmarks/wasm/performance/ack.wat new file mode 100644 index 00000000..f25dbead --- /dev/null +++ b/benchmarks/wasm/performance/ack.wat @@ -0,0 +1,70 @@ +(module $ack.wat.temp + (type (;0;) (func (param i32 i32) (result i32))) + (type (;1;) (func (result i32))) + (func $ack (type 0) (param i32 i32) (result i32) + local.get 0 + local.set 0 + local.get 1 + local.set 1 + block ;; label = @1 + loop ;; label = @2 + local.get 1 + local.set 1 + local.get 0 + local.tee 0 + i32.eqz + br_if 1 (;@1;) + block ;; label = @3 + block ;; label = @4 + local.get 1 + br_if 0 (;@4;) + i32.const 1 + local.set 1 + br 1 (;@3;) + end + local.get 0 + local.get 1 + i32.const -1 + i32.add + call $ack + local.set 1 + end + local.get 0 + i32.const -1 + i32.add + local.set 0 + local.get 1 + local.set 1 + br 0 (;@2;) + end + end + local.get 1 + i32.const 1 + i32.add) + (func $real_main (type 1) (result i32) + (local i32 i32) + i32.const 10000 + local.set 0 + loop + i32.const 2 + i32.const 1 + call $ack + local.set 1 + local.get 0 + i32.const 1 + i32.sub + local.tee 0 + br_if 0 + end + local.get 1) + (table (;0;) 1 1 funcref) + (memory (;0;) 16) + (global $__stack_pointer (mut i32) (i32.const 1048576)) + (global (;1;) i32 (i32.const 1048576)) + (global (;2;) i32 (i32.const 1048576)) + (export "memory" (memory 0)) + (export "ack" (func 0)) + (export "real_main" (func 1)) + (start 1) + (export "__data_end" (global 1)) + (export "__heap_base" (global 2))) diff --git a/benchmarks/wasm/performance/pow.wat b/benchmarks/wasm/performance/pow.wat new file mode 100644 index 00000000..b6595beb --- /dev/null +++ b/benchmarks/wasm/performance/pow.wat @@ -0,0 +1,55 @@ +(module $pow.temp + (type (;0;) (func (param i32 i32) (result i32))) + (type (;1;) (func (result i32))) + (func $power (type 0) (param i32 i32) (result i32) + (local i32) + i32.const 1 + local.set 2 + local.get 1 + local.set 1 + block ;; label = @1 + loop ;; label = @2 + local.get 2 + local.set 2 + local.get 1 + local.tee 1 + i32.eqz + br_if 1 (;@1;) + local.get 2 + local.get 0 + i32.mul + local.set 2 + local.get 1 + i32.const -1 + i32.add + local.set 1 + br 0 (;@2;) + end + end + local.get 2) + (func $real_main (type 1) (result i32) + (local i32 i32) + i32.const 10000 ;; loop counter + local.set 0 ;; reuse param 0 as loop counter + loop ;; label = @2 + i32.const 2 + i32.const 20 + call $power + local.set 1 + local.get 0 + i32.const 1 + i32.sub + local.tee 0 + br_if 0 ;; continue loop if counter != 0 + end + local.get 1) + (table (;0;) 1 1 funcref) + (memory (;0;) 16) + (global $__stack_pointer (mut i32) (i32.const 1048576)) + (global (;1;) i32 (i32.const 1048576)) + (global (;2;) i32 (i32.const 1048576)) + (export "memory" (memory 0)) + (export "power" (func 0)) + (export "real_main" (func 1)) + (export "__data_end" (global 1)) + (export "__heap_base" (global 2))) diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index 039da366..eabccb1d 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -100,3 +100,66 @@ class TestStagedEval extends FunSuite { } } + +object Benchmark extends App { + + def bench(f: => Unit): Double = { + import gensym.utils.Utils._ + // run a function f 20 times and return the average time taken + val times = for (i <- 1 to 20) yield { + time(f)._2 + } + times.sum / times.size.toDouble + } + + def benchmarkWasmInterpreter(filePath: String, main: Option[String] = None): Double = { + val moduleInst = ModuleInstance(Parser.parseFile(filePath)) + val evaluator = Evaluator(moduleInst) + val haltK: evaluator.Cont[Unit] = stack => () + bench { evaluator.evalTop(haltK, main) } + } + + def benchmarkWasmToCpp(filePath: String, main: Option[String] = None): Double = { + val moduleInst = ModuleInstance(Parser.parseFile(filePath)) + val code = WasmToCppCompiler.compile(moduleInst, main, false) + + val cppFile = s"$filePath.cpp" + + val writer = new java.io.PrintWriter(new java.io.File(cppFile)) + try { + writer.write(code) + } finally { + writer.close() + } + import sys.process._ + + val exe = s"$cppFile.exe" + // use -O0 optimization to more accurately inspect the interpretation overhead that we reduced by compilation + val command = s"g++ -o $exe $cppFile -O0" + + if (command.! != 0) { + throw new RuntimeException(s"Compilation failed for $cppFile") + } + + println(s"Running $exe") + bench { s"./$exe".! } + } + + case class BenchmarkResult(filePath: String, interpretExecutionTime: Double, compiledExecutionTime: Double) + + def benchmarkFile(filePath: String, main: Option[String] = None): Unit = { + val interpretExecutionTime = benchmarkWasmInterpreter(filePath, main) + val compiledExecutionTime = benchmarkWasmToCpp(filePath, main) + val result = BenchmarkResult(filePath, interpretExecutionTime, compiledExecutionTime) + println(s"Benchmark result for $filePath:") + println(s" Average interpreter execution time: $interpretExecutionTime ms") + println(s" Average compiled execution time: $compiledExecutionTime ms") + println(s" Speedup: ${interpretExecutionTime / compiledExecutionTime}x") + println() + } + + override def main(args: Array[String]): Unit = { + benchmarkFile("./benchmarks/wasm/performance/ack.wat", Some("real_main")) + benchmarkFile("./benchmarks/wasm/performance/pow.wat", Some("real_main")) + } +} From 46b4894cf527da617f39c0a5210d2ffaf7abf2e0 Mon Sep 17 00:00:00 2001 From: Guannan Wei Date: Fri, 23 May 2025 17:34:33 +0200 Subject: [PATCH 55/62] with std c++17 --- src/test/scala/genwasym/TestStagedEval.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index eabccb1d..fd8b75ba 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -135,7 +135,7 @@ object Benchmark extends App { val exe = s"$cppFile.exe" // use -O0 optimization to more accurately inspect the interpretation overhead that we reduced by compilation - val command = s"g++ -o $exe $cppFile -O0" + val command = s"g++ -std=c++17 -o $exe $cppFile -O0" if (command.! != 0) { throw new RuntimeException(s"Compilation failed for $cppFile") From 4dd702f85694f16f10f60b685a4cfce76cffbe3a Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 24 May 2025 00:12:32 +0800 Subject: [PATCH 56/62] ensure the compiled program is executed correctly --- src/test/scala/genwasym/TestStagedEval.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index fd8b75ba..581e6e45 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -142,7 +142,7 @@ object Benchmark extends App { } println(s"Running $exe") - bench { s"./$exe".! } + bench { assert(s"./$exe".! == 0, s"Execution of $exe failed") } } case class BenchmarkResult(filePath: String, interpretExecutionTime: Double, compiledExecutionTime: Double) From f3861b1ed8cecf6c811ad76514ccc9c5aeeb250a Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 17 May 2025 00:12:14 +0800 Subject: [PATCH 57/62] utilize type information --- benchmarks/wasm/staged/brtable.wat | 3 +- src/main/scala/wasm/StagedMiniWasm.scala | 603 +++++++++++++++-------- 2 files changed, 404 insertions(+), 202 deletions(-) diff --git a/benchmarks/wasm/staged/brtable.wat b/benchmarks/wasm/staged/brtable.wat index 91133d70..5d0d4856 100644 --- a/benchmarks/wasm/staged/brtable.wat +++ b/benchmarks/wasm/staged/brtable.wat @@ -4,7 +4,8 @@ i32.const 2 (block (block - br_table 0 1 0 + i32.const 1 + br_table 0 1 0 ;; br_table will consume an element from the stack ) ) ) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index b33a05a2..7deb98c8 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -18,17 +18,76 @@ import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase, CppSAIDriver, trait StagedWasmEvaluator extends SAIOps { def module: ModuleInstance - trait Slice + trait ReturnSite + + trait StagedNum { + def tipe: ValueType = this match { + case I32(_) => NumType(I32Type) + case I64(_) => NumType(I64Type) + case F32(_) => NumType(F32Type) + case F64(_) => NumType(F64Type) + } + + def i: Rep[Num] + } + case class I32(i: Rep[Num]) extends StagedNum + case class I64(i: Rep[Num]) extends StagedNum + case class F32(i: Rep[Num]) extends StagedNum + case class F64(i: Rep[Num]) extends StagedNum + + implicit def toStagedNum(num: Num): StagedNum = { + num match { + case I32V(_) => I32(num) + case I64V(_) => I64(num) + case F32V(_) => F32(num) + case F64V(_) => F64(num) + } + } + + implicit class ValueTypeOps(ty: ValueType) { + def size: Int = ty match { + case NumType(I32Type) => 4 + case NumType(I64Type) => 8 + case NumType(F32Type) => 4 + case NumType(F64Type) => 8 + } + } + + case class Context( + stackTypes: List[ValueType], + frameTypes: List[ValueType] + ) { + def push(ty: ValueType): Context = { + Context(ty :: stackTypes, frameTypes) + } - trait Frame + def pop(): (ValueType, Context) = { + val (ty :: rest) = stackTypes + (ty, Context(rest, frameTypes)) + } + + def shift(offset: Int, size: Int): Context = { + // Predef.println(s"[DEBUG] Shifting stack by $offset, size $size, $this") + Predef.assert(offset >= 0, s"Context shift offset must be non-negative, get $offset") + if (offset == 0) { + this + } else { + this.copy( + stackTypes = stackTypes.take(size) ++ stackTypes.drop(offset + size) + ) + } + } + } type Cont[A] = Unit => A - type Trail[A] = List[Rep[Cont[A]]] + type Trail[A] = List[Context => Rep[Cont[A]]] // a cache storing the compiled code for each function, to reduce re-compilation val compileCache = new HashMap[Int, Rep[(Cont[Unit]) => Unit]] - def funHere[A:Manifest,B:Manifest](f: Rep[A] => Rep[B], dummy: Rep[Unit] = "dummy".reflectCtrlWith[Unit]()): Rep[A => B] = { + def makeDummy: Rep[Unit] = "dummy".reflectCtrlWith[Unit]() + + def funHere[A:Manifest,B:Manifest](f: Rep[A] => Rep[B], dummy: Rep[Unit]): Rep[A => B] = { // to avoid LMS lifting a function, we create a dummy node and read it inside function fun((x: Rep[A]) => { "dummy-op".reflectCtrlWith[Unit](dummy) @@ -36,146 +95,156 @@ trait StagedWasmEvaluator extends SAIOps { }) } - // NOTE: We don't support Ans type polymorphism yet + def eval(insts: List[Instr], - kont: Rep[Cont[Unit]], - trail: Trail[Unit]): Rep[Unit] = { - if (insts.isEmpty) return kont(()) + kont: Context => Rep[Cont[Unit]], + trail: Trail[Unit]) + (implicit ctx: Context): Rep[Unit] = { + if (insts.isEmpty) return kont(ctx)(()) + + // Predef.println(s"[DEBUG] Evaluating instructions: ${insts.mkString(", ")}") + // Predef.println(s"[DEBUG] Current context: $ctx") + val (inst, rest) = (insts.head, insts.tail) inst match { case Drop => - Stack.pop() - eval(rest, kont, trail) + val (_, newCtx) = Stack.pop() + eval(rest, kont, trail)(newCtx) case WasmConst(num) => - Stack.push(num) - eval(rest, kont, trail) + val newCtx = Stack.push(num) + eval(rest, kont, trail)(newCtx) case LocalGet(i) => - Stack.push(Frames.get(i)) - eval(rest, kont, trail) + val newCtx = Stack.push(Frames.get(i)) + eval(rest, kont, trail)(newCtx) case LocalSet(i) => - Frames.set(i, Stack.pop()) - eval(rest, kont, trail) + val (num, newCtx) = Stack.pop() + Frames.set(i, num)(newCtx) + eval(rest, kont, trail)(newCtx) case LocalTee(i) => - Frames.set(i, Stack.peek) - eval(rest, kont, trail) + val (num, newCtx) = Stack.peek + Frames.set(i, num) + eval(rest, kont, trail)(newCtx) case GlobalGet(i) => - Stack.push(Global.globalGet(i)) - eval(rest, kont, trail) + val newCtx = Stack.push(Globals(i)) + eval(rest, kont, trail)(newCtx) case GlobalSet(i) => - val value = Stack.pop() + val (value, newCtx) = Stack.pop() module.globals(i).ty match { - case GlobalType(tipe, true) => Global.globalSet(i, value) + case GlobalType(tipe, true) => Globals(i) = value case _ => throw new Exception("Cannot set immutable global") } - eval(rest, kont, trail) + eval(rest, kont, trail)(newCtx) case Store(StoreOp(align, offset, ty, None)) => - val value = Stack.pop() - val addr = Stack.pop() + val (value, newCtx1) = Stack.pop() + val (addr, newCtx2) = Stack.pop()(newCtx1) Memory.storeInt(addr.toInt, offset, value.toInt) - eval(rest, kont, trail) + eval(rest, kont, trail)(newCtx2) + case Nop => eval(rest, kont, trail) case Load(LoadOp(align, offset, ty, None, None)) => - val addr = Stack.pop() + val (addr, newCtx1) = Stack.pop() val value = Memory.loadInt(addr.toInt, offset) - Stack.push(Values.I32(value)) - eval(rest, kont, trail) + val newCtx2 = Stack.push(Values.I32V(value))(newCtx1) + eval(rest, kont, trail)(newCtx2) case MemorySize => ??? case MemoryGrow => - val delta = Stack.pop() - Stack.push(Values.I32(Memory.grow(delta.toInt))) - eval(rest, kont, trail) + val (delta, newCtx1) = Stack.pop() + val newCtx2 = Stack.push(Values.I32V(Memory.grow(delta.toInt)))(newCtx1) + eval(rest, kont, trail)(newCtx2) case MemoryFill => ??? case Nop => eval(rest, kont, trail) case Unreachable => unreachable() case Test(op) => - val v = Stack.pop() - Stack.push(evalTestOp(op, v)) - eval(rest, kont, trail) + val (v, newCtx1) = Stack.pop() + val newCtx2 = Stack.push(evalTestOp(op, v))(newCtx1) + eval(rest, kont, trail)(newCtx2) case Unary(op) => - val v = Stack.pop() - Stack.push(evalUnaryOp(op, v)) - eval(rest, kont, trail) + val (v, newCtx1) = Stack.pop() + val newCtx2 = Stack.push(evalUnaryOp(op, v))(newCtx1) + eval(rest, kont, trail)(newCtx2) case Binary(op) => - val v2 = Stack.pop() - val v1 = Stack.pop() - Stack.push(evalBinOp(op, v1, v2)) - eval(rest, kont, trail) + val (v2, newCtx1) = Stack.pop() + val (v1, newCtx2) = Stack.pop()(newCtx1) + val newCtx3 = Stack.push(evalBinOp(op, v1, v2))(newCtx2) + eval(rest, kont, trail)(newCtx3) case Compare(op) => - val v2 = Stack.pop() - val v1 = Stack.pop() - Stack.push(evalRelOp(op, v1, v2)) - eval(rest, kont, trail) + val (v2, newCtx1) = Stack.pop() + val (v1, newCtx2) = Stack.pop()(newCtx1) + val newCtx3 = Stack.push(evalRelOp(op, v1, v2))(newCtx2) + eval(rest, kont, trail)(newCtx3) case WasmBlock(ty, inner) => // no need to modify the stack when entering a block // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType - val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size - def restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size + val dummy = makeDummy + def restK(restCtx: Context): Rep[Cont[Unit]] = funHere((_: Rep[Unit]) => { info(s"Exiting the block, stackSize =", Stack.size) - val offset = Stack.size - exitSize - Stack.shift(offset, funcTy.out.size) - eval(rest, kont, trail) - }) - eval(inner, restK, restK :: trail) + val offset = restCtx.stackTypes.size - exitSize + val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) + eval(rest, kont, trail)(newRestCtx) + }, dummy) + eval(inner, restK _, restK _ :: trail) case Loop(ty, inner) => val funcTy = ty.funcType - val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size - def restK = funHere((_: Rep[Unit]) => { + val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size + val dummy = makeDummy + def restK(restCtx: Context): Rep[Cont[Unit]] = funHere((_: Rep[Unit]) => { info(s"Exiting the loop, stackSize =", Stack.size) - val offset = Stack.size - exitSize - Stack.shift(offset, funcTy.out.size) - eval(rest, kont, trail) - }) - val enterSize = Stack.size - val dummy = "dummy".reflectCtrlWith[Unit]() - def loop : Rep[Unit => Unit] = funHere((_u: Rep[Unit]) => { + val offset = restCtx.stackTypes.size - exitSize + val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) + eval(rest, kont, trail)(newRestCtx) + }, dummy) + val enterSize = ctx.stackTypes.size + def loop(restCtx: Context): Rep[Unit => Unit] = funHere((_u: Rep[Unit]) => { info(s"Entered the loop, stackSize =", Stack.size) - val offset = Stack.size - enterSize - Stack.shift(offset, funcTy.inps.size) - eval(inner, restK, loop :: trail) + val offset = restCtx.stackTypes.size - enterSize + val newRestCtx = Stack.shift(offset, funcTy.inps.size)(restCtx) + eval(inner, restK _, loop _ :: trail)(newRestCtx) }, dummy) // <-- if we don't pass this dummy argument, lots of code will be generated - loop(()) + loop(ctx)(()) case If(ty, thn, els) => val funcTy = ty.funcType - val cond = Stack.pop() - val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size + val (cond, newCtx) = Stack.pop() + val exitSize = newCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size // TODO: can we avoid code duplication here? - // NOTE: if we define restK by `def` rather than val, the generated code will contain some undefined variables - val restK = funHere((_: Rep[Unit]) => { - val offset = Stack.size - exitSize - Stack.shift(offset, funcTy.out.size) + val dummy = makeDummy + def restK(restCtx: Context): Rep[Cont[Unit]] = funHere((_: Rep[Unit]) => { info(s"Exiting the if, stackSize =", Stack.size) - eval(rest, kont, trail) - }) - if (cond != Values.I32(0)) { - eval(thn, restK, restK :: trail) + val offset = restCtx.stackTypes.size - exitSize + val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) + eval(rest, kont, trail)(newRestCtx) + }, dummy) + if (cond.toInt != 0) { + eval(thn, restK _, restK _ :: trail)(newCtx) } else { - eval(els, restK, restK :: trail) + eval(els, restK _, restK _ :: trail)(newCtx) } () case Br(label) => info(s"Jump to $label") - trail(label)(()) + trail(label)(ctx)(()) case BrIf(label) => - val cond = Stack.pop() + val (cond, newCtx) = Stack.pop() info(s"The br_if(${label})'s condition is ", cond.toInt) - if (cond != Values.I32(0)) { + if (cond.toInt != 0) { info(s"Jump to $label") - trail(label)(()) + trail(label)(newCtx)(()) } else { info(s"Continue") - eval(rest, kont, trail) + eval(rest, kont, trail)(newCtx) } + () case BrTable(labels, default) => - val cond = Stack.pop() + val (cond, newCtx) = Stack.pop() def aux(choices: List[Int], idx: Int): Rep[Unit] = { - if (choices.isEmpty) trail(default)(()) + if (choices.isEmpty) trail(default)(newCtx)(()) else { - if (cond.toInt == idx) trail(choices.head)(()) + if (cond.toInt == idx) trail(choices.head)(newCtx)(()) else aux(choices.tail, idx + 1) } } aux(labels, 0) - case Return => trail.last(()) + case Return => trail.last(ctx)(()) case Call(f) => evalCall(rest, kont, trail, f, false) case ReturnCall(f) => evalCall(rest, kont, trail, f, true) case _ => @@ -185,68 +254,70 @@ trait StagedWasmEvaluator extends SAIOps { } def evalCall(rest: List[Instr], - kont: Rep[Cont[Unit]], + kont: Context => Rep[Cont[Unit]], trail: Trail[Unit], funcIndex: Int, - isTail: Boolean): Rep[Unit] = { + isTail: Boolean) + (implicit ctx: Context): Rep[Unit] = { module.funcs(funcIndex) match { - case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => - val returnSize = Stack.size - ty.inps.size + ty.out.size - val args = Stack.take(ty.inps.size) + case FuncDef(_, FuncBodyDef(ty, _, bodyLocals, body)) => + val locals = bodyLocals ++ ty.inps val callee = if (compileCache.contains(funcIndex)) { compileCache(funcIndex) } else { - val callee = topFun( - (kont: Rep[Cont[Unit]]) => { - info(s"Entered the function at $funcIndex, stackSize =", Stack.size) - eval(body, kont, kont::Nil): Rep[Unit] - } - ) + val callee = topFun((kont: Rep[Cont[Unit]]) => { + info(s"Entered the function at $funcIndex, stackSize =", Stack.size) + // we can do some check here to ensure the function returns correct size of stack + eval(body, (_: Context) => kont, ((_: Context) => kont)::Nil)(Context(Nil, locals)) + }) compileCache(funcIndex) = callee callee } - val frameSize = ty.inps.size + locals.size + // Predef.println(s"[DEBUG] locals size: ${locals.size}") + val (args, newCtx) = Stack.take(ty.inps.size) if (isTail) { // when tail call, return to the caller's return continuation Frames.popFrame() - Frames.pushFrame(frameSize) + Frames.pushFrame(locals) Frames.putAll(args) - callee(trail.last) + callee(trail.last(ctx)) } else { + val dummy = makeDummy val restK: Rep[Cont[Unit]] = funHere((_: Rep[Unit]) => { info(s"Exiting the function at $funcIndex, stackSize =", Stack.size) Frames.popFrame() - eval(rest, kont, trail) - }) + eval(rest, kont, trail)(newCtx.copy(stackTypes = ty.out.reverse ++ ctx.stackTypes.drop(ty.inps.size))) + }, dummy) // We make a new trail by `restK`, since function creates a new block to escape // (more or less like `return`) - Frames.pushFrame(frameSize) + Frames.pushFrame(locals) Frames.putAll(args) callee(restK) } case Import("console", "log", _) | Import("spectest", "print_i32", _) => - val v = Stack.pop() + //println(s"[DEBUG] current stack: $stack") + val (v, newCtx) = Stack.pop() println(v.toInt) - eval(rest, kont, trail) + eval(rest, kont, trail)(newCtx) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") case _ => throw new Exception(s"Definition at $funcIndex is not callable") } } - def evalTestOp(op: TestOp, value: Rep[Num]): Rep[Num] = op match { - case Eqz(_) => if (value.toInt == 0) Values.I32(1) else Values.I32(0) + def evalTestOp(op: TestOp, value: StagedNum): StagedNum = op match { + case Eqz(_) => Values.I32V(if (value.toInt == 0) 1 else 0) } - def evalUnaryOp(op: UnaryOp, value: Rep[Num]): Rep[Num] = op match { + def evalUnaryOp(op: UnaryOp, value: StagedNum): StagedNum = op match { case Clz(_) => value.clz() case Ctz(_) => value.ctz() case Popcnt(_) => value.popcnt() case _ => ??? } - def evalBinOp(op: BinOp, v1: Rep[Num], v2: Rep[Num]): Rep[Num] = op match { + def evalBinOp(op: BinOp, v1: StagedNum, v2: StagedNum): StagedNum = op match { case Add(_) => v1 + v2 case Mul(_) => v1 * v2 case Sub(_) => v1 - v2 @@ -260,7 +331,7 @@ trait StagedWasmEvaluator extends SAIOps { throw new Exception(s"Unknown binary operation $op") } - def evalRelOp(op: RelOp, v1: Rep[Num], v2: Rep[Num]): Rep[Num] = op match { + def evalRelOp(op: RelOp, v1: StagedNum, v2: StagedNum): StagedNum = op match { case Eq(_) => v1 numEq v2 case Ne(_) => v1 numNe v2 case LtS(_) => v1 < v2 @@ -298,10 +369,10 @@ trait StagedWasmEvaluator extends SAIOps { throw new Exception("Entry function has no concrete body") } } - val (instrs, localSize) = (funBody.body, funBody.locals.size) + val (instrs, locals) = (funBody.body, funBody.locals) Stack.initialize() - Frames.pushFrame(localSize) - eval(instrs, kont, kont::Nil) + Frames.pushFrame(locals) + eval(instrs, (_: Context) => kont, ((_: Context) => kont)::Nil)(Context(Nil, locals)) Frames.popFrame() } @@ -319,24 +390,59 @@ trait StagedWasmEvaluator extends SAIOps { // stack operations object Stack { + def shift(offset: Int, size: Int)(ctx: Context): Context = { + if (offset > 0) { + "stack-shift".reflectCtrlWith[Unit](offset, size) + } + ctx.shift(offset, size) + } + def initialize(): Rep[Unit] = { "stack-init".reflectCtrlWith[Unit]() } - def pop(): Rep[Num] = { - "stack-pop".reflectCtrlWith[Num]() + def pop()(implicit ctx: Context): (StagedNum, Context) = { + val (ty, newContext) = ctx.pop() + val num = ty match { + case NumType(I32Type) => I32("stack-pop".reflectCtrlWith[Num]()) + case NumType(I64Type) => I64("stack-pop".reflectCtrlWith[Num]()) + case NumType(F32Type) => F32("stack-pop".reflectCtrlWith[Num]()) + case NumType(F32Type) => F64("stack-pop".reflectCtrlWith[Num]()) + } + (num, newContext) } - def peek: Rep[Num] = { - "stack-peek".reflectCtrlWith[Num]() + def peek(implicit ctx: Context): (StagedNum, Context) = { + val ty = ctx.stackTypes.head + val num = ty match { + case NumType(I32Type) => I32("stack-peek".reflectCtrlWith[Num]()) + case NumType(I64Type) => I64("stack-peek".reflectCtrlWith[Num]()) + case NumType(F32Type) => F32("stack-peek".reflectCtrlWith[Num]()) + case NumType(F32Type) => F64("stack-peek".reflectCtrlWith[Num]()) + } + (num, ctx) + } + + def push(v: StagedNum)(implicit ctx: Context): Context = { + v match { + case I32(v) => NumType(I32Type); "stack-push".reflectCtrlWith[Unit](v) + case I64(v) => NumType(I64Type); "stack-push".reflectCtrlWith[Unit](v) + case F32(v) => NumType(F32Type); "stack-push".reflectCtrlWith[Unit](v) + case F64(v) => NumType(F64Type); "stack-push".reflectCtrlWith[Unit](v) + } + ctx.push(v.tipe) } - def push(v: Rep[Num]): Rep[Unit] = { - "stack-push".reflectCtrlWith[Unit](v) + def take(n: Int)(implicit ctx: Context): (List[StagedNum], Context) = n match { + case 0 => (Nil, ctx) + case n => + val (v, newCtx1) = pop() + val (rest, newCtx2) = take(n - 1) + (v::rest, newCtx2) } - def drop(n: Int): Rep[Unit] = { - "stack-drop".reflectCtrlWith[Unit](n) + def drop(n: Int)(implicit ctx: Context): Context = { + take(n)._2 } def shift(offset: Rep[Int], size: Rep[Int]): Rep[Unit] = { @@ -352,39 +458,43 @@ trait StagedWasmEvaluator extends SAIOps { def size: Rep[Int] = { "stack-size".reflectCtrlWith[Int]() } - - def reset(x: Rep[Int]): Rep[Unit] = { - "stack-reset".reflectCtrlWith[Unit](x) - } - - def take(n: Int): Rep[Slice] = { - "stack-take".reflectCtrlWith[Slice](n) - } } object Frames { - def get(i: Int): Rep[Num] = { - "frame-get".reflectCtrlWith[Num](i) + def get(i: Int)(implicit ctx: Context): StagedNum = { + // val offset = ctx.frameTypes.take(i).map(_.size).sum + ctx.frameTypes(i) match { + case NumType(I32Type) => I32("frame-get".reflectCtrlWith[Num](i)) + case NumType(I64Type) => I64("frame-get".reflectCtrlWith[Num](i)) + case NumType(F32Type) => F32("frame-get".reflectCtrlWith[Num](i)) + case NumType(F64Type) => F64("frame-get".reflectCtrlWith[Num](i)) + } } - def set(i: Int, v: Rep[Num]): Rep[Unit] = { - "frame-set".reflectCtrlWith(i, v) + def set(i: Int, v: StagedNum)(implicit ctx: Context): Rep[Unit] = { + // val offset = ctx.frameTypes.take(i).map(_.size).sum + v match { + case I32(v) => "frame-set".reflectCtrlWith[Unit](i, v) + case I64(v) => "frame-set".reflectCtrlWith[Unit](i, v) + case F32(v) => "frame-set".reflectCtrlWith[Unit](i, v) + case F64(v) => "frame-set".reflectCtrlWith[Unit](i, v) + } } - def pushFrame(i: Int): Rep[Unit] = { - "frame-push".reflectCtrlWith[Unit](i) + def pushFrame(locals: List[ValueType]): Rep[Unit] = { + // Predef.println(s"[DEBUG] push frame: $locals") + val size = locals.map(_.size).sum + "frame-push".reflectCtrlWith[Unit](size) } def popFrame(): Rep[Unit] = { "frame-pop".reflectCtrlWith[Unit]() } - def putAll(args: Rep[Slice]): Rep[Unit] = { - "frame-putAll".reflectCtrlWith[Unit](args) - } - - def top: Rep[Frame] = { - "frame-top".reflectCtrlWith[Frame]() + def putAll(args: List[StagedNum])(implicit ctx: Context): Rep[Unit] = { + for ((arg, i) <- args.view.reverse.zipWithIndex) { + Frames.set(i, arg) + } } } @@ -413,80 +523,182 @@ trait StagedWasmEvaluator extends SAIOps { // runtime values object Values { - def lift(num: Num): Rep[Num] = { - num match { - case I32V(i) => I32(i) - case I64V(i) => I64(i) - } - } - - def I32(i: Rep[Int]): Rep[Num] = { - "I32V".reflectCtrlWith[Num](i) + def I32V(i: Rep[Int]): StagedNum = { + I32("I32V".reflectCtrlWith[Num](i)) } - def I64(i: Rep[Long]): Rep[Num] = { - "I64V".reflectCtrlWith[Num](i) + def I64V(i: Rep[Long]): StagedNum = { + I64("I64V".reflectCtrlWith[Num](i)) } } // global read/write - object Global{ - def globalGet(i: Int): Rep[Num] = { - "global-get".reflectCtrlWith[Num](i) + object Globals { + def apply(i: Int): StagedNum = { + module.globals(i).ty match { + case GlobalType(NumType(I32Type), _) => I32("global-get".reflectCtrlWith[Num](i)) + case GlobalType(NumType(I64Type), _) => I64("global-get".reflectCtrlWith[Num](i)) + case GlobalType(NumType(F32Type), _) => F32("global-get".reflectCtrlWith[Num](i)) + case GlobalType(NumType(F64Type), _) => F64("global-get".reflectCtrlWith[Num](i)) + } } - def globalSet(i: Int, value: Rep[Num]): Rep[Unit] = { - "global-set".reflectCtrlWith[Unit](i, value) + def update(i: Int, v: StagedNum): Rep[Unit] = { + module.globals(i).ty match { + case GlobalType(NumType(I32Type), _) => "global-set".reflectCtrlWith[Unit](i) + case GlobalType(NumType(I64Type), _) => "global-set".reflectCtrlWith[Unit](i) + case GlobalType(NumType(F32Type), _) => "global-set".reflectCtrlWith[Unit](i) + case GlobalType(NumType(F64Type), _) => "global-set".reflectCtrlWith[Unit](i) + } } } // runtime Num type - implicit class NumOps(num: Rep[Num]) { + implicit class StagedNumOps(num: StagedNum) { - def toInt: Rep[Int] = "num-to-int".reflectCtrlWith[Int](num) + def toInt: Rep[Int] = "num-to-int".reflectCtrlWith[Int](num.i) - def clz(): Rep[Num] = "unary-clz".reflectCtrlWith[Num](num) + def clz(): StagedNum = num match { + case I32(i) => I32("clz".reflectCtrlWith[Num](i)) + case I64(i) => I64("clz".reflectCtrlWith[Num](i)) + } - def ctz(): Rep[Num] = "unary-ctz".reflectCtrlWith[Num](num) + def ctz(): StagedNum = num match { + case I32(i) => I32("ctz".reflectCtrlWith[Num](i)) + case I64(i) => I64("ctz".reflectCtrlWith[Num](i)) + } - def popcnt(): Rep[Num] = "unary-popcnt".reflectCtrlWith[Num](num) + def popcnt(): StagedNum = num match { + case I32(i) => I32("popcnt".reflectCtrlWith[Num](i)) + case I64(i) => I64("popcnt".reflectCtrlWith[Num](i)) + } - def +(rhs: Rep[Num]): Rep[Num] = "binary-add".reflectCtrlWith[Num](num, rhs) + def +(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-add".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-add".reflectCtrlWith[Num](x, y)) + case (F32(x), F32(y)) => F32("binary-add".reflectCtrlWith[Num](x, y)) + case (F64(x), F64(y)) => F64("binary-add".reflectCtrlWith[Num](x, y)) + } + } - def -(rhs: Rep[Num]): Rep[Num] = "binary-sub".reflectCtrlWith[Num](num, rhs) + def -(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-sub".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-sub".reflectCtrlWith[Num](x, y)) + case (F32(x), F32(y)) => F32("binary-sub".reflectCtrlWith[Num](x, y)) + case (F64(x), F64(y)) => F64("binary-sub".reflectCtrlWith[Num](x, y)) + } + } - def *(rhs: Rep[Num]): Rep[Num] = "binary-mul".reflectCtrlWith[Num](num, rhs) + def *(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-mul".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-mul".reflectCtrlWith[Num](x, y)) + case (F32(x), F32(y)) => F32("binary-mul".reflectCtrlWith[Num](x, y)) + case (F64(x), F64(y)) => F64("binary-mul".reflectCtrlWith[Num](x, y)) + } + } - def /(rhs: Rep[Num]): Rep[Num] = "binary-div".reflectCtrlWith[Num](num, rhs) + def /(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-div".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-div".reflectCtrlWith[Num](x, y)) + case (F32(x), F32(y)) => F32("binary-div".reflectCtrlWith[Num](x, y)) + case (F64(x), F64(y)) => F64("binary-div".reflectCtrlWith[Num](x, y)) + } + } - def <<(rhs: Rep[Num]): Rep[Num] = "binary-shl".reflectCtrlWith[Num](num, rhs) + def <<(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-shl".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-shl".reflectCtrlWith[Num](x, y)) + } + } - def >>(rhs: Rep[Num]): Rep[Num] = "binary-shr".reflectCtrlWith[Num](num, rhs) + def >>(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-shr".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-shr".reflectCtrlWith[Num](x, y)) + } + } - def &(rhs: Rep[Num]): Rep[Num] = "binary-and".reflectCtrlWith[Num](num, rhs) + def &(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-and".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-and".reflectCtrlWith[Num](x, y)) + } + } - def numEq(rhs: Rep[Num]): Rep[Num] = "relation-eq".reflectCtrlWith[Num](num, rhs) + def numEq(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-eq".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-eq".reflectCtrlWith[Num](x, y)) + } + } - def numNe(rhs: Rep[Num]): Rep[Num] = "relation-ne".reflectCtrlWith[Num](num, rhs) + def numNe(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-ne".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-ne".reflectCtrlWith[Num](x, y)) + } + } - def <(rhs: Rep[Num]): Rep[Num] = "relation-lt".reflectCtrlWith[Num](num, rhs) + def <(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-lt".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-lt".reflectCtrlWith[Num](x, y)) + } + } - def ltu(rhs: Rep[Num]): Rep[Num] = "relation-ltu".reflectCtrlWith[Num](num, rhs) + def ltu(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-ltu".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-ltu".reflectCtrlWith[Num](x, y)) + } + } - def >(rhs: Rep[Num]): Rep[Num] = "relation-gt".reflectCtrlWith[Num](num, rhs) + def >(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-gt".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-gt".reflectCtrlWith[Num](x, y)) + } + } - def gtu(rhs: Rep[Num]): Rep[Num] = "relation-gtu".reflectCtrlWith[Num](num, rhs) + def gtu(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-gtu".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-gtu".reflectCtrlWith[Num](x, y)) + } + } - def <=(rhs: Rep[Num]): Rep[Num] = "relation-le".reflectCtrlWith[Num](num, rhs) + def <=(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-le".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-le".reflectCtrlWith[Num](x, y)) + } + } - def leu(rhs: Rep[Num]): Rep[Num] = "relation-leu".reflectCtrlWith[Num](num, rhs) + def leu(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-leu".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-leu".reflectCtrlWith[Num](x, y)) + } + } - def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectCtrlWith[Num](num, rhs) + def >=(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-ge".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-ge".reflectCtrlWith[Num](x, y)) + } + } - def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectCtrlWith[Num](num, rhs) - } - implicit class SliceOps(slice: Rep[Slice]) { - def reverse: Rep[Slice] = "slice-reverse".reflectCtrlWith[Slice](slice) + def geu(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-geu".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-geu".reflectCtrlWith[Num](x, y)) + } + } } } @@ -525,21 +737,17 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { case Node(_, "frame-pop", _, _) => emit("Frames.popFrame()") case Node(_, "stack-push", List(value), _) => - emit("Stack.push("); shallow(value); emit(")\n") + emit("Stack.push("); shallow(value); emit(")") case Node(_, "stack-pop", _, _) => emit("Stack.pop()") case Node(_, "stack-peek", _, _) => emit("Stack.peek") case Node(_, "stack-take", List(n), _) => emit("Stack.take("); shallow(n); emit(")") - case Node(_, "slice-reverse", List(slice), _) => - shallow(slice); emit(".reverse") case Node(_, "stack-size", _, _) => emit("Stack.size") case Node(_, "global-get", List(i), _) => emit("Global.globalGet("); shallow(i); emit(")") - case Node(_, "frame-top", _, _) => - emit("Frames.top") case Node(_, "binary-add", List(lhs, rhs), _) => shallow(lhs); emit(" + "); shallow(rhs) case Node(_, "binary-sub", List(lhs, rhs), _) => @@ -590,7 +798,6 @@ trait WasmToScalaCompilerDriver[A, B] override def remap(m: Manifest[_]): String = { if (m.toString.endsWith("Stack")) "Stack" else if(m.toString.endsWith("Frame")) "Frame" - else if(m.toString.endsWith("Slice")) "Slice" else super.remap(m) } } @@ -655,8 +862,6 @@ object Stack { } } - type Slice = List[Num] - class Frame(val size: Int) { private val data = new Array[Num](size) def apply(i: Int): Num = { @@ -692,11 +897,6 @@ object Stack { def get(i: Int): Num = { top(i) } - def putAll(xs: Slice) = { - for (i <- 0 until xs.size) { - top(i) = xs(i) - } - } } object Global { @@ -740,13 +940,14 @@ object WasmToScalaCompiler { trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { override def mayInline(n: Node): Boolean = n match { - case Node(s, "stack-pop", _, _) => false + case Node(_, "stack-pop", _, _) + | Node(_, "stack-peek", _, _) + => false case _ => super.mayInline(n) } override def remap(m: Manifest[_]): String = { if (m.toString.endsWith("Num")) "Num" - else if (m.toString.endsWith("Slice")) "Slice" else if (m.toString.endsWith("Frame")) "Frame" else if (m.toString.endsWith("Stack")) "Stack" else if (m.toString.endsWith("Global")) "Global" @@ -757,8 +958,10 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { // for now, the traverse/shallow is same as the scala backend's override def traverse(n: Node): Unit = n match { - case Node(_, "stack-reset", List(n), _) => - emit("Stack.reset("); shallow(n); emit(");\n") + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(");\n") + case Node(_, "stack-drop", List(n), _) => + emit("Stack.drop("); shallow(n); emit(");\n") case Node(_, "stack-init", _, _) => emit("Stack.initialize();\n") case Node(_, "stack-print", _, _) => @@ -817,8 +1020,6 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Stack.size()") case Node(_, "global-get", List(i), _) => emit("Global.globalGet("); shallow(i); emit(")") - case Node(_, "frame-top", _, _) => - emit("Frames.top()") case Node(_, "binary-add", List(lhs, rhs), _) => shallow(lhs); emit(" + "); shallow(rhs) case Node(_, "binary-sub", List(lhs, rhs), _) => From 4e907a3a1defc1f4c2c4067ae225363af2b26ea4 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 11 Jun 2025 12:06:46 +0800 Subject: [PATCH 58/62] lifting to the top --- src/main/scala/wasm/StagedMiniWasm.scala | 122 ++++++++++++----------- 1 file changed, 65 insertions(+), 57 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 7deb98c8..87f57b70 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -79,11 +79,12 @@ trait StagedWasmEvaluator extends SAIOps { } } - type Cont[A] = Unit => A + type MCont[A] = Unit => A + type Cont[A] = (MCont[A]) => A type Trail[A] = List[Context => Rep[Cont[A]]] // a cache storing the compiled code for each function, to reduce re-compilation - val compileCache = new HashMap[Int, Rep[(Cont[Unit]) => Unit]] + val compileCache = new HashMap[Int, Rep[(MCont[Unit]) => Unit]] def makeDummy: Rep[Unit] = "dummy".reflectCtrlWith[Unit]() @@ -98,9 +99,10 @@ trait StagedWasmEvaluator extends SAIOps { def eval(insts: List[Instr], kont: Context => Rep[Cont[Unit]], + mkont: Rep[MCont[Unit]], trail: Trail[Unit]) (implicit ctx: Context): Rep[Unit] = { - if (insts.isEmpty) return kont(ctx)(()) + if (insts.isEmpty) return kont(ctx)(mkont) // Predef.println(s"[DEBUG] Evaluating instructions: ${insts.mkString(", ")}") // Predef.println(s"[DEBUG] Current context: $ctx") @@ -109,152 +111,155 @@ trait StagedWasmEvaluator extends SAIOps { inst match { case Drop => val (_, newCtx) = Stack.pop() - eval(rest, kont, trail)(newCtx) + eval(rest, kont, mkont, trail)(newCtx) case WasmConst(num) => val newCtx = Stack.push(num) - eval(rest, kont, trail)(newCtx) + eval(rest, kont, mkont, trail)(newCtx) case LocalGet(i) => val newCtx = Stack.push(Frames.get(i)) - eval(rest, kont, trail)(newCtx) + eval(rest, kont, mkont, trail)(newCtx) case LocalSet(i) => val (num, newCtx) = Stack.pop() Frames.set(i, num)(newCtx) - eval(rest, kont, trail)(newCtx) + eval(rest, kont, mkont, trail)(newCtx) case LocalTee(i) => val (num, newCtx) = Stack.peek Frames.set(i, num) - eval(rest, kont, trail)(newCtx) + eval(rest, kont, mkont, trail)(newCtx) case GlobalGet(i) => val newCtx = Stack.push(Globals(i)) - eval(rest, kont, trail)(newCtx) + eval(rest, kont, mkont, trail)(newCtx) case GlobalSet(i) => val (value, newCtx) = Stack.pop() module.globals(i).ty match { case GlobalType(tipe, true) => Globals(i) = value case _ => throw new Exception("Cannot set immutable global") } - eval(rest, kont, trail)(newCtx) + eval(rest, kont, mkont, trail)(newCtx) case Store(StoreOp(align, offset, ty, None)) => val (value, newCtx1) = Stack.pop() val (addr, newCtx2) = Stack.pop()(newCtx1) Memory.storeInt(addr.toInt, offset, value.toInt) - eval(rest, kont, trail)(newCtx2) - case Nop => eval(rest, kont, trail) + eval(rest, kont, mkont, trail)(newCtx2) + case Nop => eval(rest, kont, mkont, trail) case Load(LoadOp(align, offset, ty, None, None)) => val (addr, newCtx1) = Stack.pop() val value = Memory.loadInt(addr.toInt, offset) val newCtx2 = Stack.push(Values.I32V(value))(newCtx1) - eval(rest, kont, trail)(newCtx2) + eval(rest, kont, mkont, trail)(newCtx2) case MemorySize => ??? case MemoryGrow => val (delta, newCtx1) = Stack.pop() val newCtx2 = Stack.push(Values.I32V(Memory.grow(delta.toInt)))(newCtx1) - eval(rest, kont, trail)(newCtx2) + eval(rest, kont, mkont, trail)(newCtx2) case MemoryFill => ??? - case Nop => eval(rest, kont, trail) case Unreachable => unreachable() case Test(op) => val (v, newCtx1) = Stack.pop() val newCtx2 = Stack.push(evalTestOp(op, v))(newCtx1) - eval(rest, kont, trail)(newCtx2) + eval(rest, kont, mkont, trail)(newCtx2) case Unary(op) => val (v, newCtx1) = Stack.pop() val newCtx2 = Stack.push(evalUnaryOp(op, v))(newCtx1) - eval(rest, kont, trail)(newCtx2) + eval(rest, kont, mkont, trail)(newCtx2) case Binary(op) => val (v2, newCtx1) = Stack.pop() val (v1, newCtx2) = Stack.pop()(newCtx1) val newCtx3 = Stack.push(evalBinOp(op, v1, v2))(newCtx2) - eval(rest, kont, trail)(newCtx3) + eval(rest, kont, mkont, trail)(newCtx3) case Compare(op) => val (v2, newCtx1) = Stack.pop() val (v1, newCtx2) = Stack.pop()(newCtx1) val newCtx3 = Stack.push(evalRelOp(op, v1, v2))(newCtx2) - eval(rest, kont, trail)(newCtx3) + eval(rest, kont, mkont, trail)(newCtx3) case WasmBlock(ty, inner) => // no need to modify the stack when entering a block // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size val dummy = makeDummy - def restK(restCtx: Context): Rep[Cont[Unit]] = funHere((_: Rep[Unit]) => { + def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the block, stackSize =", Stack.size) val offset = restCtx.stackTypes.size - exitSize val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) - eval(rest, kont, trail)(newRestCtx) - }, dummy) - eval(inner, restK _, restK _ :: trail) + eval(rest, kont, mk, trail)(newRestCtx) + }) + eval(inner, restK _, mkont, restK _ :: trail) case Loop(ty, inner) => val funcTy = ty.funcType val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size val dummy = makeDummy - def restK(restCtx: Context): Rep[Cont[Unit]] = funHere((_: Rep[Unit]) => { + def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the loop, stackSize =", Stack.size) val offset = restCtx.stackTypes.size - exitSize val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) - eval(rest, kont, trail)(newRestCtx) - }, dummy) + eval(rest, kont, mk, trail)(newRestCtx) + }) val enterSize = ctx.stackTypes.size - def loop(restCtx: Context): Rep[Unit => Unit] = funHere((_u: Rep[Unit]) => { + def loop(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Entered the loop, stackSize =", Stack.size) val offset = restCtx.stackTypes.size - enterSize val newRestCtx = Stack.shift(offset, funcTy.inps.size)(restCtx) - eval(inner, restK _, loop _ :: trail)(newRestCtx) - }, dummy) // <-- if we don't pass this dummy argument, lots of code will be generated - loop(ctx)(()) + eval(inner, restK _, mk, loop _ :: trail)(newRestCtx) + }) + loop(ctx)(mkont) case If(ty, thn, els) => val funcTy = ty.funcType val (cond, newCtx) = Stack.pop() val exitSize = newCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size // TODO: can we avoid code duplication here? val dummy = makeDummy - def restK(restCtx: Context): Rep[Cont[Unit]] = funHere((_: Rep[Unit]) => { + def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the if, stackSize =", Stack.size) val offset = restCtx.stackTypes.size - exitSize val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) - eval(rest, kont, trail)(newRestCtx) - }, dummy) + eval(rest, kont, mk, trail)(newRestCtx) + }) if (cond.toInt != 0) { - eval(thn, restK _, restK _ :: trail)(newCtx) + eval(thn, restK _, mkont, restK _ :: trail)(newCtx) } else { - eval(els, restK _, restK _ :: trail)(newCtx) + eval(els, restK _, mkont, restK _ :: trail)(newCtx) } () case Br(label) => info(s"Jump to $label") - trail(label)(ctx)(()) + trail(label)(ctx)(mkont) case BrIf(label) => val (cond, newCtx) = Stack.pop() info(s"The br_if(${label})'s condition is ", cond.toInt) if (cond.toInt != 0) { info(s"Jump to $label") - trail(label)(newCtx)(()) + trail(label)(newCtx)(mkont) } else { info(s"Continue") - eval(rest, kont, trail)(newCtx) + eval(rest, kont, mkont, trail)(newCtx) } () case BrTable(labels, default) => val (cond, newCtx) = Stack.pop() def aux(choices: List[Int], idx: Int): Rep[Unit] = { - if (choices.isEmpty) trail(default)(newCtx)(()) + if (choices.isEmpty) trail(default)(newCtx)(mkont) else { - if (cond.toInt == idx) trail(choices.head)(newCtx)(()) + if (cond.toInt == idx) trail(choices.head)(newCtx)(mkont) else aux(choices.tail, idx + 1) } } aux(labels, 0) - case Return => trail.last(ctx)(()) - case Call(f) => evalCall(rest, kont, trail, f, false) - case ReturnCall(f) => evalCall(rest, kont, trail, f, true) + case Return => trail.last(ctx)(mkont) + case Call(f) => evalCall(rest, kont, mkont, trail, f, false) + case ReturnCall(f) => evalCall(rest, kont, mkont, trail, f, true) case _ => val todo = "todo-op".reflectCtrlWith[Unit]() - eval(rest, kont, trail) + eval(rest, kont, mkont, trail) } } + def forwardKont: Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => mk(())) + + def evalCall(rest: List[Instr], kont: Context => Rep[Cont[Unit]], + mkont: Rep[MCont[Unit]], trail: Trail[Unit], funcIndex: Int, isTail: Boolean) @@ -266,10 +271,10 @@ trait StagedWasmEvaluator extends SAIOps { if (compileCache.contains(funcIndex)) { compileCache(funcIndex) } else { - val callee = topFun((kont: Rep[Cont[Unit]]) => { + val callee = topFun((mk: Rep[MCont[Unit]]) => { info(s"Entered the function at $funcIndex, stackSize =", Stack.size) // we can do some check here to ensure the function returns correct size of stack - eval(body, (_: Context) => kont, ((_: Context) => kont)::Nil)(Context(Nil, locals)) + eval(body, (_: Context) => forwardKont, mk, ((_: Context) => forwardKont)::Nil)(Context(Nil, locals)) }) compileCache(funcIndex) = callee callee @@ -281,26 +286,29 @@ trait StagedWasmEvaluator extends SAIOps { Frames.popFrame() Frames.pushFrame(locals) Frames.putAll(args) - callee(trail.last(ctx)) + callee(mkont) } else { - val dummy = makeDummy - val restK: Rep[Cont[Unit]] = funHere((_: Rep[Unit]) => { + // We make a new trail by `restK`, since function creates a new block to escape + // (more or less like `return`) + val restK: Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the function at $funcIndex, stackSize =", Stack.size) Frames.popFrame() - eval(rest, kont, trail)(newCtx.copy(stackTypes = ty.out.reverse ++ ctx.stackTypes.drop(ty.inps.size))) + eval(rest, kont, mk, trail)(newCtx.copy(stackTypes = ty.out.reverse ++ ctx.stackTypes.drop(ty.inps.size))) + }) + val dummy = makeDummy + val newMKont: Rep[MCont[Unit]] = funHere((_u: Rep[Unit]) => { + restK(mkont) }, dummy) - // We make a new trail by `restK`, since function creates a new block to escape - // (more or less like `return`) Frames.pushFrame(locals) Frames.putAll(args) - callee(restK) + callee(newMKont) } case Import("console", "log", _) | Import("spectest", "print_i32", _) => //println(s"[DEBUG] current stack: $stack") val (v, newCtx) = Stack.pop() println(v.toInt) - eval(rest, kont, trail)(newCtx) + eval(rest, kont, mkont, trail)(newCtx) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") case _ => throw new Exception(s"Definition at $funcIndex is not callable") } @@ -345,7 +353,7 @@ trait StagedWasmEvaluator extends SAIOps { case _ => ??? } - def evalTop(kont: Rep[Cont[Unit]], main: Option[String]): Rep[Unit] = { + def evalTop(mkont: Rep[MCont[Unit]], main: Option[String]): Rep[Unit] = { val funBody: FuncBodyDef = main match { case Some(func_name) => module.defs.flatMap({ @@ -372,7 +380,7 @@ trait StagedWasmEvaluator extends SAIOps { val (instrs, locals) = (funBody.body, funBody.locals) Stack.initialize() Frames.pushFrame(locals) - eval(instrs, (_: Context) => kont, ((_: Context) => kont)::Nil)(Context(Nil, locals)) + eval(instrs, (_: Context) => forwardKont, mkont, ((_: Context) => forwardKont)::Nil)(Context(Nil, locals)) Frames.popFrame() } @@ -384,7 +392,7 @@ trait StagedWasmEvaluator extends SAIOps { } "no-op".reflectCtrlWith[Unit]() } - val temp: Rep[Cont[Unit]] = fun(haltK) + val temp: Rep[MCont[Unit]] = topFun(haltK) evalTop(temp, main) } From 3adc60c97a0ee25ffbc4508969ab9e33ab249bd9 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 11 Jun 2025 16:34:22 +0800 Subject: [PATCH 59/62] avoid re-registering top function --- src/main/scala/wasm/StagedMiniWasm.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 87f57b70..53f10052 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -1083,7 +1083,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { })._1)(f) //ongoingFun -= streamId } else { - withStream(functionsStreams(id)._1)(f) + // If a function is registered, don't re-register it. + // withStream(functionsStreams(id)._1)(f) } override def emitAll(g: Graph, name: String)(m1: Manifest[_], m2: Manifest[_]): Unit = { From 86061f12629df141a88c293580dedf4cd0c86d28 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 16 Jun 2025 17:03:05 +0800 Subject: [PATCH 60/62] remove std::vector usages & use O3 in benchmark --- src/main/scala/wasm/StagedMiniWasm.scala | 141 ++++++++----------- src/test/scala/genwasym/TestStagedEval.scala | 4 +- 2 files changed, 59 insertions(+), 86 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 53f10052..794ec4bb 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -283,7 +283,7 @@ trait StagedWasmEvaluator extends SAIOps { val (args, newCtx) = Stack.take(ty.inps.size) if (isTail) { // when tail call, return to the caller's return continuation - Frames.popFrame() + Frames.popFrame(ctx.frameTypes.size) Frames.pushFrame(locals) Frames.putAll(args) callee(mkont) @@ -292,7 +292,7 @@ trait StagedWasmEvaluator extends SAIOps { // (more or less like `return`) val restK: Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the function at $funcIndex, stackSize =", Stack.size) - Frames.popFrame() + Frames.popFrame(locals.size) eval(rest, kont, mk, trail)(newCtx.copy(stackTypes = ty.out.reverse ++ ctx.stackTypes.drop(ty.inps.size))) }) val dummy = makeDummy @@ -381,7 +381,7 @@ trait StagedWasmEvaluator extends SAIOps { Stack.initialize() Frames.pushFrame(locals) eval(instrs, (_: Context) => forwardKont, mkont, ((_: Context) => forwardKont)::Nil)(Context(Nil, locals)) - Frames.popFrame() + Frames.popFrame(locals.size) } def evalTop(main: Option[String], printRes: Boolean = false): Rep[Unit] = { @@ -491,12 +491,12 @@ trait StagedWasmEvaluator extends SAIOps { def pushFrame(locals: List[ValueType]): Rep[Unit] = { // Predef.println(s"[DEBUG] push frame: $locals") - val size = locals.map(_.size).sum + val size = locals.size "frame-push".reflectCtrlWith[Unit](size) } - def popFrame(): Rep[Unit] = { - "frame-pop".reflectCtrlWith[Unit]() + def popFrame(size: Int): Rep[Unit] = { + "frame-pop".reflectCtrlWith[Unit](size) } def putAll(args: List[StagedNum])(implicit ctx: Context): Rep[Unit] = { @@ -727,8 +727,8 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { emit("Stack.print()\n") case Node(_, "frame-push", List(i), _) => emit("Frames.pushFrame("); shallow(i); emit(")\n") - case Node(_, "frame-pop", _, _) => - emit("Frames.popFrame()\n") + case Node(_, "frame-pop", List(i), _) => + emit("Frames.popFrame("); shallow(i); emit(")\n") case Node(_, "frame-putAll", List(args), _) => emit("Frames.putAll("); shallow(args); emit(")\n") case Node(_, "frame-set", List(i, value), _) => @@ -742,8 +742,8 @@ trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { override def shallow(n: Node): Unit = n match { case Node(_, "frame-get", List(i), _) => emit("Frames.get("); shallow(i); emit(")") - case Node(_, "frame-pop", _, _) => - emit("Frames.popFrame()") + case Node(_, "frame-pop", List(i), _) => + emit("Frames.popFrame("); shallow(i); emit(")") case Node(_, "stack-push", List(value), _) => emit("Stack.push("); shallow(value); emit(")") case Node(_, "stack-pop", _, _) => @@ -976,8 +976,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Stack.print();\n") case Node(_, "frame-push", List(i), _) => emit("Frames.pushFrame("); shallow(i); emit(");\n") - case Node(_, "frame-pop", _, _) => - emit("Frames.popFrame();\n") + case Node(_, "frame-pop", List(i), _) => + emit("Frames.popFrame("); shallow(i); emit(");\n") case Node(_, "frame-putAll", List(args), _) => emit("Frames.putAll("); shallow(args); emit(");\n") case Node(_, "frame-set", List(i, value), _) => @@ -1010,8 +1010,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Stack.shift("); shallow(offset); emit(", "); shallow(size); emit(")") case Node(_, "stack-pop", _, _) => emit("Stack.pop()") - case Node(_, "frame-pop", _, _) => - emit("Frames.popFrame()") + case Node(_, "frame-pop", List(i), _) => + emit("Frames.popFrame("); shallow(i); emit(")") case Node(_, "stack-peek", _, _) => emit("Stack.peek()") case Node(_, "stack-take", List(n), _) => @@ -1171,133 +1171,106 @@ static Num I64V(int64_t v) { return v; } using Slice = std::vector; +const int STACK_SIZE = 1024 * 64; + class Stack_t { public: + Stack_t() : count(0), stack_ptr(new Num[STACK_SIZE]) {} + std::monostate push(Num &&num) { - stack_.push_back(std::move(num)); + stack_ptr[count] = num; + count++; return std::monostate{}; } std::monostate push(Num &num) { - stack_.push_back(num); + stack_ptr[count] = num; + count++; return std::monostate{}; } Num pop() { - if (stack_.empty()) { +#ifdef DEBUG + if (count == 0) { throw std::runtime_error("Stack underflow"); } - Num num = std::move(stack_.back()); - stack_.pop_back(); +#endif + Num num = stack_ptr[count - 1]; + count--; return num; } Num peek() { - if (stack_.empty()) { +#ifdef DEBUG + if (count == 0) { throw std::runtime_error("Stack underflow"); } - return stack_.back(); - } - - Num get(int32_t index) { - assert(index >= 0); - assert(index < stack_.size()); - return stack_[index]; +#endif + return stack_ptr[count - 1]; } - int32_t size() { return stack_.size(); } + int32_t size() { return count; } void shift(int32_t offset, int32_t size) { +#ifdef DEBUG if (offset < 0) { throw std::out_of_range("Invalid offset: " + std::to_string(offset)); } if (size < 0) { throw std::out_of_range("Invalid size: " + std::to_string(size)); } +#endif // shift last `size` of numbers forward of `offset` - for (int32_t i = stack_.size() - size; i < stack_.size(); ++i) { - stack_[i - offset] = stack_[i]; + for (int32_t i = count - size; i < count; ++i) { + stack_ptr[i - offset] = stack_ptr[i]; } - stack_.resize(stack_.size() - offset); - } - - Slice take(int32_t size) { - if (size > stack_.size()) { - throw std::out_of_range("Invalid size: requested " + std::to_string(size) + ", stack size is " + std::to_string(stack_.size())); - } - // todo: avoid re-allocation - Slice slice(stack_.end() - size, stack_.end()); - stack_.resize(stack_.size() - size); - return slice; + count -= offset; } void print() { std::cout << "Stack contents: " << std::endl; - for (auto it = stack_.rbegin(); it != stack_.rend(); ++it) { - std::cout << it->value << std::endl; + for (int32_t i = 0; i < count; ++i) { + std::cout << stack_ptr[count - i - 1].value << std::endl; } } - void initialize() { stack_.clear(); } + void initialize() { + // do nothing for now + } private: - std::vector stack_; + int32_t count; + Num *stack_ptr; }; static Stack_t Stack; -struct Frame_t { - std::vector locals; - - Frame_t(std::int32_t size) : locals() { locals.resize(size); } - Num &operator[](std::int32_t index) { - assert(index >= 0); - if (index >= locals.size()) { - throw std::out_of_range("Index out of range"); - } - return locals[index]; - } - void putAll(Slice slice) { - for (std::int32_t i = 0; i < slice.size(); ++i) { - locals[i] = slice[i]; - } - } -}; +const int FRAME_SIZE = 1024; class Frames_t { public: - std::monostate popFrame() { - if (!frames.empty()) { - frames.pop_back(); - return std::monostate{}; - } else { - std::cout << "No frames to pop." << std::endl; - throw std::runtime_error("No frames to pop."); - } + Frames_t() : count(0), stack_ptr(new Num[FRAME_SIZE]) {} + + std::monostate popFrame(std::int32_t size) { + assert(size >= 0); + count -= size; + return std::monostate{}; } Num get(std::int32_t index) { - auto ret = top()[index]; + auto ret = stack_ptr[count - 1 - index]; return ret; } - void set(std::int32_t index, Num num) { frames.back()[index] = num; } - - Frame_t &top() { - if (frames.empty()) { - throw std::runtime_error("No frames available"); - } - return frames.back(); - } + void set(std::int32_t index, Num num) { stack_ptr[count - 1 - index] = num; } void pushFrame(std::int32_t size) { - Frame_t frame(size); - frames.push_back(frame); + assert(size >= 0); + count += size; } - void putAll(Slice slice) { top().putAll(slice); } - private: - std::vector frames; + int32_t count; + Num *stack_ptr; }; static Frames_t Frames; diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index 581e6e45..d9fa44dd 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -134,8 +134,7 @@ object Benchmark extends App { import sys.process._ val exe = s"$cppFile.exe" - // use -O0 optimization to more accurately inspect the interpretation overhead that we reduced by compilation - val command = s"g++ -std=c++17 -o $exe $cppFile -O0" + val command = s"g++ -std=c++17 -o $exe $cppFile -O3" if (command.! != 0) { throw new RuntimeException(s"Compilation failed for $cppFile") @@ -149,6 +148,7 @@ object Benchmark extends App { def benchmarkFile(filePath: String, main: Option[String] = None): Unit = { val interpretExecutionTime = benchmarkWasmInterpreter(filePath, main) + // val interpretExecutionTime = 0.0 val compiledExecutionTime = benchmarkWasmToCpp(filePath, main) val result = BenchmarkResult(filePath, interpretExecutionTime, compiledExecutionTime) println(s"Benchmark result for $filePath:") From 333a8d6a9420d1113d6feb1b0413024d9002daa1 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Fri, 20 Jun 2025 16:15:46 +0800 Subject: [PATCH 61/62] split header from prelude --- headers/wasm.hpp | 6 + headers/wasm/concrete_rt.hpp | 203 +++++++++++++++ src/main/scala/wasm/StagedMiniWasm.scala | 257 +++---------------- src/test/scala/genwasym/TestStagedEval.scala | 36 +-- 4 files changed, 251 insertions(+), 251 deletions(-) create mode 100644 headers/wasm.hpp create mode 100644 headers/wasm/concrete_rt.hpp diff --git a/headers/wasm.hpp b/headers/wasm.hpp new file mode 100644 index 00000000..21da2ff7 --- /dev/null +++ b/headers/wasm.hpp @@ -0,0 +1,6 @@ +#ifndef WASM_HEADERS +#define WASM_HEADERS + +#include "wasm/concrete_rt.hpp" + +#endif \ No newline at end of file diff --git a/headers/wasm/concrete_rt.hpp b/headers/wasm/concrete_rt.hpp new file mode 100644 index 00000000..34d739f4 --- /dev/null +++ b/headers/wasm/concrete_rt.hpp @@ -0,0 +1,203 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +void info() { +#ifdef DEBUG + std::cout << std::endl; +#endif +} + +template +void info(const T &first, const Args &...args) { +#ifdef DEBUG + std::cout << first << " "; + info(args...); +#endif +} + +struct Num { + Num(int64_t value) : value(value) {} + Num() : value(0) {} + int64_t value; + int32_t toInt() { return static_cast(value); } + + bool operator==(const Num &other) const { return value == other.value; } + bool operator!=(const Num &other) const { return !(*this == other); } + Num operator+(const Num &other) const { return Num(value + other.value); } + Num operator-(const Num &other) const { return Num(value - other.value); } + Num operator*(const Num &other) const { return Num(value * other.value); } + Num operator/(const Num &other) const { + if (other.value == 0) { + throw std::runtime_error("Division by zero"); + } + return Num(value / other.value); + } + Num operator<(const Num &other) const { return Num(value < other.value); } + Num operator<=(const Num &other) const { return Num(value <= other.value); } + Num operator>(const Num &other) const { return Num(value > other.value); } + Num operator>=(const Num &other) const { return Num(value >= other.value); } + Num operator&(const Num &other) const { return Num(value & other.value); } +}; + +static Num I32V(int v) { return v; } + +static Num I64V(int64_t v) { return v; } + +using Slice = std::vector; + +const int STACK_SIZE = 1024 * 64; + +class Stack_t { +public: + Stack_t() : count(0), stack_ptr(new Num[STACK_SIZE]) {} + + std::monostate push(Num &&num) { + stack_ptr[count] = num; + count++; + return std::monostate{}; + } + + std::monostate push(Num &num) { + stack_ptr[count] = num; + count++; + return std::monostate{}; + } + + Num pop() { +#ifdef DEBUG + if (count == 0) { + throw std::runtime_error("Stack underflow"); + } +#endif + Num num = stack_ptr[count - 1]; + count--; + return num; + } + + Num peek() { +#ifdef DEBUG + if (count == 0) { + throw std::runtime_error("Stack underflow"); + } +#endif + return stack_ptr[count - 1]; + } + + int32_t size() { return count; } + + void shift(int32_t offset, int32_t size) { +#ifdef DEBUG + if (offset < 0) { + throw std::out_of_range("Invalid offset: " + std::to_string(offset)); + } + if (size < 0) { + throw std::out_of_range("Invalid size: " + std::to_string(size)); + } +#endif + // shift last `size` of numbers forward of `offset` + for (int32_t i = count - size; i < count; ++i) { + stack_ptr[i - offset] = stack_ptr[i]; + } + count -= offset; + } + + void print() { + std::cout << "Stack contents: " << std::endl; + for (int32_t i = 0; i < count; ++i) { + std::cout << stack_ptr[count - i - 1].value << std::endl; + } + } + + void initialize() { + // do nothing for now + } + +private: + int32_t count; + Num *stack_ptr; +}; +static Stack_t Stack; + +const int FRAME_SIZE = 1024; + +class Frames_t { +public: + Frames_t() : count(0), stack_ptr(new Num[FRAME_SIZE]) {} + + std::monostate popFrame(std::int32_t size) { + assert(size >= 0); + count -= size; + return std::monostate{}; + } + + Num get(std::int32_t index) { + auto ret = stack_ptr[count - 1 - index]; + return ret; + } + + void set(std::int32_t index, Num num) { stack_ptr[count - 1 - index] = num; } + + void pushFrame(std::int32_t size) { + assert(size >= 0); + count += size; + } + +private: + int32_t count; + Num *stack_ptr; +}; + +static Frames_t Frames; + +static void initRand() { + // for now, just do nothing +} + +static std::monostate unreachable() { + std::cout << "Unreachable code reached!" << std::endl; + throw std::runtime_error("Unreachable code reached"); +} + +static int32_t pagesize = 65536; +static int32_t page_count = 0; + +struct Memory_t { + std::vector memory; + Memory_t(int32_t init_page_count) : memory(init_page_count * pagesize) {} + + int32_t loadInt(int32_t base, int32_t offset) { + return *reinterpret_cast(static_cast(memory.data()) + + base + offset); + } + + std::monostate storeInt(int32_t base, int32_t offset, int32_t value) { + *reinterpret_cast(static_cast(memory.data()) + base + + offset) = value; + return std::monostate{}; + } + + // grow memory by delta bytes when bytes > 0. return -1 if failed, return old + // size when success + int32_t grow(int32_t delta) { + if (delta <= 0) { + return memory.size(); + } + + try { + memory.resize(memory.size() + delta * pagesize); + auto old_page_count = page_count; + page_count += delta; + return memory.size(); + } catch (const std::bad_alloc &e) { + return -1; + } + } +}; + +static Memory_t Memory(1); // 1 page memory \ No newline at end of file diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 794ec4bb..2e22e7a7 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -947,6 +947,16 @@ object WasmToScalaCompiler { } trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { + // clear include path and headers by first + includePaths.clear() + headers.clear() + + registerHeader("headers", "\"wasm.hpp\"") + registerHeader("") + registerHeader("") + registerHeader("") + registerHeader("") + override def mayInline(n: Node): Boolean = n match { case Node(_, "stack-pop", _, _) | Node(_, "stack-peek", _, _) @@ -1089,19 +1099,12 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { override def emitAll(g: Graph, name: String)(m1: Manifest[_], m2: Manifest[_]): Unit = { val ng = init(g) - emitln(prelude) + emitHeaders(stream) emitln(""" |/***************************************** |Emitting Generated Code |*******************************************/ """.stripMargin) - - emitln(""" -#include -#include -#include -#include -#include """) val src = run(name, ng) emitFunctionDecls(stream) emitDatastructures(stream) @@ -1116,216 +1119,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { | return 0; |}""".stripMargin) } - - val prelude = """ -#include -#include -#include -#include -#include -#include -#include -#include - -void info() { -#ifdef DEBUG - std::cout << std::endl; -#endif -} - -template -void info(const T &first, const Args &...args) { -#ifdef DEBUG - std::cout << first << " "; - info(args...); -#endif -} - -struct Num { - Num(int64_t value) : value(value) {} - Num() : value(0) {} - int64_t value; - int32_t toInt() { return static_cast(value); } - - bool operator==(const Num &other) const { return value == other.value; } - bool operator!=(const Num &other) const { return !(*this == other); } - Num operator+(const Num &other) const { return Num(value + other.value); } - Num operator-(const Num &other) const { return Num(value - other.value); } - Num operator*(const Num &other) const { return Num(value * other.value); } - Num operator/(const Num &other) const { - if (other.value == 0) { - throw std::runtime_error("Division by zero"); - } - return Num(value / other.value); - } - Num operator<(const Num &other) const { return Num(value < other.value); } - Num operator<=(const Num &other) const { return Num(value <= other.value); } - Num operator>(const Num &other) const { return Num(value > other.value); } - Num operator>=(const Num &other) const { return Num(value >= other.value); } - Num operator&(const Num &other) const { return Num(value & other.value); } -}; - -static Num I32V(int v) { return v; } - -static Num I64V(int64_t v) { return v; } - -using Slice = std::vector; - -const int STACK_SIZE = 1024 * 64; - -class Stack_t { -public: - Stack_t() : count(0), stack_ptr(new Num[STACK_SIZE]) {} - - std::monostate push(Num &&num) { - stack_ptr[count] = num; - count++; - return std::monostate{}; - } - - std::monostate push(Num &num) { - stack_ptr[count] = num; - count++; - return std::monostate{}; - } - - Num pop() { -#ifdef DEBUG - if (count == 0) { - throw std::runtime_error("Stack underflow"); - } -#endif - Num num = stack_ptr[count - 1]; - count--; - return num; - } - - Num peek() { -#ifdef DEBUG - if (count == 0) { - throw std::runtime_error("Stack underflow"); - } -#endif - return stack_ptr[count - 1]; - } - - int32_t size() { return count; } - - void shift(int32_t offset, int32_t size) { -#ifdef DEBUG - if (offset < 0) { - throw std::out_of_range("Invalid offset: " + std::to_string(offset)); - } - if (size < 0) { - throw std::out_of_range("Invalid size: " + std::to_string(size)); - } -#endif - // shift last `size` of numbers forward of `offset` - for (int32_t i = count - size; i < count; ++i) { - stack_ptr[i - offset] = stack_ptr[i]; - } - count -= offset; - } - - void print() { - std::cout << "Stack contents: " << std::endl; - for (int32_t i = 0; i < count; ++i) { - std::cout << stack_ptr[count - i - 1].value << std::endl; - } - } - - void initialize() { - // do nothing for now - } - -private: - int32_t count; - Num *stack_ptr; -}; -static Stack_t Stack; - -const int FRAME_SIZE = 1024; - -class Frames_t { -public: - Frames_t() : count(0), stack_ptr(new Num[FRAME_SIZE]) {} - - std::monostate popFrame(std::int32_t size) { - assert(size >= 0); - count -= size; - return std::monostate{}; - } - - Num get(std::int32_t index) { - auto ret = stack_ptr[count - 1 - index]; - return ret; - } - - void set(std::int32_t index, Num num) { stack_ptr[count - 1 - index] = num; } - - void pushFrame(std::int32_t size) { - assert(size >= 0); - count += size; - } - -private: - int32_t count; - Num *stack_ptr; -}; - -static Frames_t Frames; - -static void initRand() { - // for now, just do nothing } -static std::monostate unreachable() { - std::cout << "Unreachable code reached!" << std::endl; - throw std::runtime_error("Unreachable code reached"); -} - -static int32_t pagesize = 65536; -static int32_t page_count = 0; - -struct Memory_t { - std::vector memory; - Memory_t(int32_t init_page_count) : memory(init_page_count * pagesize) {} - - int32_t loadInt(int32_t base, int32_t offset) { - return *reinterpret_cast(static_cast(memory.data()) + - base + offset); - } - - std::monostate storeInt(int32_t base, int32_t offset, int32_t value) { - *reinterpret_cast(static_cast(memory.data()) + base + - offset) = value; - return std::monostate{}; - } - - // grow memory by delta bytes when bytes > 0. return -1 if failed, return old - // size when success - int32_t grow(int32_t delta) { - if (delta <= 0) { - return memory.size(); - } - - try { - memory.resize(memory.size() + delta * pagesize); - auto old_page_count = page_count; - page_count += delta; - return memory.size(); - } catch (const std::bad_alloc &e) { - return -1; - } - } -}; - - -static Memory_t Memory(1); // 1 page memory -""" -} - - trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEvaluator { q => override val codegen = new StagedWasmCppGen { val IR: q.type = q @@ -1334,15 +1129,39 @@ trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEv } object WasmToCppCompiler { - def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { + case class GeneratedCpp(source: String, headerFolders: List[String]) + + def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): GeneratedCpp = { println(s"Now compiling wasm module with entry function $main") - val code = new WasmToCppCompilerDriver[Unit, Unit] { + val driver = new WasmToCppCompilerDriver[Unit, Unit] { def module: ModuleInstance = moduleInst def snippet(x: Rep[Unit]): Rep[Unit] = { evalTop(main, printRes) } } - code.code + GeneratedCpp(driver.code, driver.codegen.includePaths.toList) + } + + def compileToExe(moduleInst: ModuleInstance, + main: Option[String], + outputCpp: String, + outputExe: String, + printRes: Boolean = false): Unit = { + val generated = compile(moduleInst, main, printRes) + val code = generated.source + + val writer = new java.io.PrintWriter(new java.io.File(outputCpp)) + try { + writer.write(code) + } finally { + writer.close() + } + + import sys.process._ + val command = s"g++ -std=c++17 $outputCpp -o $outputExe -O3 -g " + generated.headerFolders.map(f => s"-I$f").mkString(" ") + if (command.! != 0) { + throw new RuntimeException(s"Compilation failed for $outputCpp") + } } } diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index d9fa44dd..d4d1e960 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -26,25 +26,11 @@ class TestStagedEval extends FunSuite { def testFileToCpp(filename: String, main: Option[String] = None, expect: Option[List[Float]]=None) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) - val code = WasmToCppCompiler.compile(moduleInst, main, true) - val cppFile = s"$filename.cpp" - - val writer = new java.io.PrintWriter(new java.io.File(cppFile)) - try { - writer.write(code) - } finally { - writer.close() - } - import sys.process._ - val exe = s"$cppFile.exe" - val command = s"g++ -o $exe $cppFile" - - if (command.! != 0) { - throw new RuntimeException(s"Compilation failed for $cppFile") - } + WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, true) + import sys.process._ val result = s"./$exe".!! println(result) @@ -121,25 +107,11 @@ object Benchmark extends App { def benchmarkWasmToCpp(filePath: String, main: Option[String] = None): Double = { val moduleInst = ModuleInstance(Parser.parseFile(filePath)) - val code = WasmToCppCompiler.compile(moduleInst, main, false) - val cppFile = s"$filePath.cpp" - - val writer = new java.io.PrintWriter(new java.io.File(cppFile)) - try { - writer.write(code) - } finally { - writer.close() - } - import sys.process._ - val exe = s"$cppFile.exe" - val command = s"g++ -std=c++17 -o $exe $cppFile -O3" - - if (command.! != 0) { - throw new RuntimeException(s"Compilation failed for $cppFile") - } + WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, false) + import sys.process._ println(s"Running $exe") bench { assert(s"./$exe".! == 0, s"Execution of $exe failed") } } From 251e01402b58f11c58bdd59941ac72d44ad8657a Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Thu, 3 Jul 2025 20:41:38 +0800 Subject: [PATCH 62/62] move NewStagedEvalCPS.scala to attic --- .../wasm/{StagedEvalCPS.scala => attic/NewStagedEvalCPS.scala} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/main/scala/wasm/{StagedEvalCPS.scala => attic/NewStagedEvalCPS.scala} (100%) diff --git a/src/main/scala/wasm/StagedEvalCPS.scala b/src/main/scala/wasm/attic/NewStagedEvalCPS.scala similarity index 100% rename from src/main/scala/wasm/StagedEvalCPS.scala rename to src/main/scala/wasm/attic/NewStagedEvalCPS.scala