From c420a2a6ab4f17733396d7eddacb6c299f6cbb90 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 5 Jul 2025 21:54:20 +0800 Subject: [PATCH 01/21] remove unused expressions --- src/main/scala/wasm/StagedMiniWasm.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 2e22e7a7..bfa2082d 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -433,10 +433,10 @@ trait StagedWasmEvaluator extends SAIOps { def push(v: StagedNum)(implicit ctx: Context): Context = { v match { - case I32(v) => NumType(I32Type); "stack-push".reflectCtrlWith[Unit](v) - case I64(v) => NumType(I64Type); "stack-push".reflectCtrlWith[Unit](v) - case F32(v) => NumType(F32Type); "stack-push".reflectCtrlWith[Unit](v) - case F64(v) => NumType(F64Type); "stack-push".reflectCtrlWith[Unit](v) + case I32(v) => "stack-push".reflectCtrlWith[Unit](v) + case I64(v) => "stack-push".reflectCtrlWith[Unit](v) + case F32(v) => "stack-push".reflectCtrlWith[Unit](v) + case F64(v) => "stack-push".reflectCtrlWith[Unit](v) } ctx.push(v.tipe) } From fb2a2c4e203fba7f571b3142ed2f739e6a7ed5ad Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 7 Jul 2025 14:31:16 +0800 Subject: [PATCH 02/21] let's start from the staged miniwasm interpreter --- .../scala/wasm/StagedConcolicMiniWasm.scala | 1170 +++++++++++++++++ 1 file changed, 1170 insertions(+) create mode 100644 src/main/scala/wasm/StagedConcolicMiniWasm.scala diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala new file mode 100644 index 00000000..6b14bcf6 --- /dev/null +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -0,0 +1,1170 @@ +package gensym.wasm.stagedconcolicminiwasm + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +import lms.core.stub.Adapter +import lms.core.virtualize +import lms.macros.SourceContext +import lms.core.stub.{Base, ScalaGenBase, CGenBase} +import lms.core.Backend._ +import lms.core.Backend.{Block => LMSBlock, Const => LMSConst} +import lms.core.Graph + +import gensym.wasm.ast._ +import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} +import gensym.wasm.miniwasm.{ModuleInstance} +import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} +import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase, CppSAIDriver, CppSAICodeGenBase} + +@virtualize +trait StagedWasmEvaluator extends SAIOps { + def module: ModuleInstance + + trait ReturnSite + + trait StagedNum { + def tipe: ValueType = this match { + case I32(_) => NumType(I32Type) + case I64(_) => NumType(I64Type) + case F32(_) => NumType(F32Type) + case F64(_) => NumType(F64Type) + } + + def i: Rep[Num] + } + case class I32(i: Rep[Num]) extends StagedNum + case class I64(i: Rep[Num]) extends StagedNum + case class F32(i: Rep[Num]) extends StagedNum + case class F64(i: Rep[Num]) extends StagedNum + + implicit def toStagedNum(num: Num): StagedNum = { + num match { + case I32V(_) => I32(num) + case I64V(_) => I64(num) + case F32V(_) => F32(num) + case F64V(_) => F64(num) + } + } + + implicit class ValueTypeOps(ty: ValueType) { + def size: Int = ty match { + case NumType(I32Type) => 4 + case NumType(I64Type) => 8 + case NumType(F32Type) => 4 + case NumType(F64Type) => 8 + } + } + + case class Context( + stackTypes: List[ValueType], + frameTypes: List[ValueType] + ) { + def push(ty: ValueType): Context = { + Context(ty :: stackTypes, frameTypes) + } + + def pop(): (ValueType, Context) = { + val (ty :: rest) = stackTypes + (ty, Context(rest, frameTypes)) + } + + def shift(offset: Int, size: Int): Context = { + // Predef.println(s"[DEBUG] Shifting stack by $offset, size $size, $this") + Predef.assert(offset >= 0, s"Context shift offset must be non-negative, get $offset") + if (offset == 0) { + this + } else { + this.copy( + stackTypes = stackTypes.take(size) ++ stackTypes.drop(offset + size) + ) + } + } + } + + type MCont[A] = Unit => A + type Cont[A] = (MCont[A]) => A + type Trail[A] = List[Context => Rep[Cont[A]]] + + // a cache storing the compiled code for each function, to reduce re-compilation + val compileCache = new HashMap[Int, Rep[(MCont[Unit]) => Unit]] + + def makeDummy: Rep[Unit] = "dummy".reflectCtrlWith[Unit]() + + def funHere[A:Manifest,B:Manifest](f: Rep[A] => Rep[B], dummy: Rep[Unit]): Rep[A => B] = { + // to avoid LMS lifting a function, we create a dummy node and read it inside function + fun((x: Rep[A]) => { + "dummy-op".reflectCtrlWith[Unit](dummy) + f(x) + }) + } + + + def eval(insts: List[Instr], + kont: Context => Rep[Cont[Unit]], + mkont: Rep[MCont[Unit]], + trail: Trail[Unit]) + (implicit ctx: Context): Rep[Unit] = { + if (insts.isEmpty) return kont(ctx)(mkont) + + // Predef.println(s"[DEBUG] Evaluating instructions: ${insts.mkString(", ")}") + // Predef.println(s"[DEBUG] Current context: $ctx") + + val (inst, rest) = (insts.head, insts.tail) + inst match { + case Drop => + val (_, newCtx) = Stack.pop() + eval(rest, kont, mkont, trail)(newCtx) + case WasmConst(num) => + val newCtx = Stack.push(num) + eval(rest, kont, mkont, trail)(newCtx) + case LocalGet(i) => + val newCtx = Stack.push(Frames.get(i)) + eval(rest, kont, mkont, trail)(newCtx) + case LocalSet(i) => + val (num, newCtx) = Stack.pop() + Frames.set(i, num)(newCtx) + eval(rest, kont, mkont, trail)(newCtx) + case LocalTee(i) => + val (num, newCtx) = Stack.peek + Frames.set(i, num) + eval(rest, kont, mkont, trail)(newCtx) + case GlobalGet(i) => + val newCtx = Stack.push(Globals(i)) + eval(rest, kont, mkont, trail)(newCtx) + case GlobalSet(i) => + val (value, newCtx) = Stack.pop() + module.globals(i).ty match { + case GlobalType(tipe, true) => Globals(i) = value + case _ => throw new Exception("Cannot set immutable global") + } + eval(rest, kont, mkont, trail)(newCtx) + case Store(StoreOp(align, offset, ty, None)) => + val (value, newCtx1) = Stack.pop() + val (addr, newCtx2) = Stack.pop()(newCtx1) + Memory.storeInt(addr.toInt, offset, value.toInt) + eval(rest, kont, mkont, trail)(newCtx2) + case Nop => eval(rest, kont, mkont, trail) + case Load(LoadOp(align, offset, ty, None, None)) => + val (addr, newCtx1) = Stack.pop() + val value = Memory.loadInt(addr.toInt, offset) + val newCtx2 = Stack.push(Values.I32V(value))(newCtx1) + eval(rest, kont, mkont, trail)(newCtx2) + case MemorySize => ??? + case MemoryGrow => + val (delta, newCtx1) = Stack.pop() + val newCtx2 = Stack.push(Values.I32V(Memory.grow(delta.toInt)))(newCtx1) + eval(rest, kont, mkont, trail)(newCtx2) + case MemoryFill => ??? + case Unreachable => unreachable() + case Test(op) => + val (v, newCtx1) = Stack.pop() + val newCtx2 = Stack.push(evalTestOp(op, v))(newCtx1) + eval(rest, kont, mkont, trail)(newCtx2) + case Unary(op) => + val (v, newCtx1) = Stack.pop() + val newCtx2 = Stack.push(evalUnaryOp(op, v))(newCtx1) + eval(rest, kont, mkont, trail)(newCtx2) + case Binary(op) => + val (v2, newCtx1) = Stack.pop() + val (v1, newCtx2) = Stack.pop()(newCtx1) + val newCtx3 = Stack.push(evalBinOp(op, v1, v2))(newCtx2) + eval(rest, kont, mkont, trail)(newCtx3) + case Compare(op) => + val (v2, newCtx1) = Stack.pop() + val (v1, newCtx2) = Stack.pop()(newCtx1) + val newCtx3 = Stack.push(evalRelOp(op, v1, v2))(newCtx2) + eval(rest, kont, mkont, trail)(newCtx3) + case WasmBlock(ty, inner) => + // no need to modify the stack when entering a block + // the type system guarantees that we will never take more than the input size from the stack + val funcTy = ty.funcType + val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size + val dummy = makeDummy + def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + info(s"Exiting the block, stackSize =", Stack.size) + val offset = restCtx.stackTypes.size - exitSize + val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) + eval(rest, kont, mk, trail)(newRestCtx) + }) + eval(inner, restK _, mkont, restK _ :: trail) + case Loop(ty, inner) => + val funcTy = ty.funcType + val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size + val dummy = makeDummy + def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + info(s"Exiting the loop, stackSize =", Stack.size) + val offset = restCtx.stackTypes.size - exitSize + val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) + eval(rest, kont, mk, trail)(newRestCtx) + }) + val enterSize = ctx.stackTypes.size + def loop(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + info(s"Entered the loop, stackSize =", Stack.size) + val offset = restCtx.stackTypes.size - enterSize + val newRestCtx = Stack.shift(offset, funcTy.inps.size)(restCtx) + eval(inner, restK _, mk, loop _ :: trail)(newRestCtx) + }) + loop(ctx)(mkont) + case If(ty, thn, els) => + val funcTy = ty.funcType + val (cond, newCtx) = Stack.pop() + val exitSize = newCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size + // TODO: can we avoid code duplication here? + val dummy = makeDummy + def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + info(s"Exiting the if, stackSize =", Stack.size) + val offset = restCtx.stackTypes.size - exitSize + val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) + eval(rest, kont, mk, trail)(newRestCtx) + }) + if (cond.toInt != 0) { + eval(thn, restK _, mkont, restK _ :: trail)(newCtx) + } else { + eval(els, restK _, mkont, restK _ :: trail)(newCtx) + } + () + case Br(label) => + info(s"Jump to $label") + trail(label)(ctx)(mkont) + case BrIf(label) => + val (cond, newCtx) = Stack.pop() + info(s"The br_if(${label})'s condition is ", cond.toInt) + if (cond.toInt != 0) { + info(s"Jump to $label") + trail(label)(newCtx)(mkont) + } else { + info(s"Continue") + eval(rest, kont, mkont, trail)(newCtx) + } + () + case BrTable(labels, default) => + val (cond, newCtx) = Stack.pop() + def aux(choices: List[Int], idx: Int): Rep[Unit] = { + if (choices.isEmpty) trail(default)(newCtx)(mkont) + else { + if (cond.toInt == idx) trail(choices.head)(newCtx)(mkont) + else aux(choices.tail, idx + 1) + } + } + aux(labels, 0) + case Return => trail.last(ctx)(mkont) + case Call(f) => evalCall(rest, kont, mkont, trail, f, false) + case ReturnCall(f) => evalCall(rest, kont, mkont, trail, f, true) + case _ => + val todo = "todo-op".reflectCtrlWith[Unit]() + eval(rest, kont, mkont, trail) + } + } + + def forwardKont: Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => mk(())) + + + def evalCall(rest: List[Instr], + kont: Context => Rep[Cont[Unit]], + mkont: Rep[MCont[Unit]], + trail: Trail[Unit], + funcIndex: Int, + isTail: Boolean) + (implicit ctx: Context): Rep[Unit] = { + module.funcs(funcIndex) match { + case FuncDef(_, FuncBodyDef(ty, _, bodyLocals, body)) => + val locals = bodyLocals ++ ty.inps + val callee = + if (compileCache.contains(funcIndex)) { + compileCache(funcIndex) + } else { + val callee = topFun((mk: Rep[MCont[Unit]]) => { + info(s"Entered the function at $funcIndex, stackSize =", Stack.size) + // we can do some check here to ensure the function returns correct size of stack + eval(body, (_: Context) => forwardKont, mk, ((_: Context) => forwardKont)::Nil)(Context(Nil, locals)) + }) + compileCache(funcIndex) = callee + callee + } + // Predef.println(s"[DEBUG] locals size: ${locals.size}") + val (args, newCtx) = Stack.take(ty.inps.size) + if (isTail) { + // when tail call, return to the caller's return continuation + Frames.popFrame(ctx.frameTypes.size) + Frames.pushFrame(locals) + Frames.putAll(args) + callee(mkont) + } else { + // We make a new trail by `restK`, since function creates a new block to escape + // (more or less like `return`) + val restK: Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + info(s"Exiting the function at $funcIndex, stackSize =", Stack.size) + Frames.popFrame(locals.size) + eval(rest, kont, mk, trail)(newCtx.copy(stackTypes = ty.out.reverse ++ ctx.stackTypes.drop(ty.inps.size))) + }) + val dummy = makeDummy + val newMKont: Rep[MCont[Unit]] = funHere((_u: Rep[Unit]) => { + restK(mkont) + }, dummy) + Frames.pushFrame(locals) + Frames.putAll(args) + callee(newMKont) + } + case Import("console", "log", _) + | Import("spectest", "print_i32", _) => + //println(s"[DEBUG] current stack: $stack") + val (v, newCtx) = Stack.pop() + println(v.toInt) + eval(rest, kont, mkont, trail)(newCtx) + case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") + case _ => throw new Exception(s"Definition at $funcIndex is not callable") + } + } + + def evalTestOp(op: TestOp, value: StagedNum): StagedNum = op match { + case Eqz(_) => Values.I32V(if (value.toInt == 0) 1 else 0) + } + + def evalUnaryOp(op: UnaryOp, value: StagedNum): StagedNum = op match { + case Clz(_) => value.clz() + case Ctz(_) => value.ctz() + case Popcnt(_) => value.popcnt() + case _ => ??? + } + + def evalBinOp(op: BinOp, v1: StagedNum, v2: StagedNum): StagedNum = op match { + case Add(_) => v1 + v2 + case Mul(_) => v1 * v2 + case Sub(_) => v1 - v2 + case Shl(_) => v1 << v2 + // case ShrS(_) => v1 >> v2 // TODO: signed shift right + case ShrU(_) => v1 >> v2 + case And(_) => v1 & v2 + case DivS(_) => v1 / v2 + case DivU(_) => v1 / v2 + case _ => + throw new Exception(s"Unknown binary operation $op") + } + + def evalRelOp(op: RelOp, v1: StagedNum, v2: StagedNum): StagedNum = op match { + case Eq(_) => v1 numEq v2 + case Ne(_) => v1 numNe v2 + case LtS(_) => v1 < v2 + case LtU(_) => v1 ltu v2 + case GtS(_) => v1 > v2 + case GtU(_) => v1 gtu v2 + case LeS(_) => v1 <= v2 + case LeU(_) => v1 leu v2 + case GeS(_) => v1 >= v2 + case GeU(_) => v1 geu v2 + case _ => ??? + } + + def evalTop(mkont: Rep[MCont[Unit]], main: Option[String]): Rep[Unit] = { + val funBody: FuncBodyDef = main match { + case Some(func_name) => + module.defs.flatMap({ + case Export(`func_name`, ExportFunc(fid)) => + Predef.println(s"Now compiling start with function $main") + module.funcs(fid) match { + case FuncDef(_, body@FuncBodyDef(_,_,_,_)) => Some(body) + case _ => throw new Exception("Entry function has no concrete body") + } + case _ => None + }).head + case None => + val startIds = module.defs.flatMap { + case Start(id) => Some(id) + case _ => None + } + val startId = startIds.headOption.getOrElse { throw new Exception("No start function") } + module.funcs(startId) match { + case FuncDef(_, body@FuncBodyDef(_,_,_,_)) => body + case _ => + throw new Exception("Entry function has no concrete body") + } + } + val (instrs, locals) = (funBody.body, funBody.locals) + Stack.initialize() + Frames.pushFrame(locals) + eval(instrs, (_: Context) => forwardKont, mkont, ((_: Context) => forwardKont)::Nil)(Context(Nil, locals)) + Frames.popFrame(locals.size) + } + + def evalTop(main: Option[String], printRes: Boolean = false): Rep[Unit] = { + val haltK: Rep[Unit] => Rep[Unit] = (_) => { + info("Exiting the program...") + if (printRes) { + Stack.print() + } + "no-op".reflectCtrlWith[Unit]() + } + val temp: Rep[MCont[Unit]] = topFun(haltK) + evalTop(temp, main) + } + + // stack operations + object Stack { + def shift(offset: Int, size: Int)(ctx: Context): Context = { + if (offset > 0) { + "stack-shift".reflectCtrlWith[Unit](offset, size) + } + ctx.shift(offset, size) + } + + def initialize(): Rep[Unit] = { + "stack-init".reflectCtrlWith[Unit]() + } + + def pop()(implicit ctx: Context): (StagedNum, Context) = { + val (ty, newContext) = ctx.pop() + val num = ty match { + case NumType(I32Type) => I32("stack-pop".reflectCtrlWith[Num]()) + case NumType(I64Type) => I64("stack-pop".reflectCtrlWith[Num]()) + case NumType(F32Type) => F32("stack-pop".reflectCtrlWith[Num]()) + case NumType(F32Type) => F64("stack-pop".reflectCtrlWith[Num]()) + } + (num, newContext) + } + + def peek(implicit ctx: Context): (StagedNum, Context) = { + val ty = ctx.stackTypes.head + val num = ty match { + case NumType(I32Type) => I32("stack-peek".reflectCtrlWith[Num]()) + case NumType(I64Type) => I64("stack-peek".reflectCtrlWith[Num]()) + case NumType(F32Type) => F32("stack-peek".reflectCtrlWith[Num]()) + case NumType(F32Type) => F64("stack-peek".reflectCtrlWith[Num]()) + } + (num, ctx) + } + + def push(v: StagedNum)(implicit ctx: Context): Context = { + v match { + case I32(v) => "stack-push".reflectCtrlWith[Unit](v) + case I64(v) => "stack-push".reflectCtrlWith[Unit](v) + case F32(v) => "stack-push".reflectCtrlWith[Unit](v) + case F64(v) => "stack-push".reflectCtrlWith[Unit](v) + } + ctx.push(v.tipe) + } + + def take(n: Int)(implicit ctx: Context): (List[StagedNum], Context) = n match { + case 0 => (Nil, ctx) + case n => + val (v, newCtx1) = pop() + val (rest, newCtx2) = take(n - 1) + (v::rest, newCtx2) + } + + def drop(n: Int)(implicit ctx: Context): Context = { + take(n)._2 + } + + def shift(offset: Rep[Int], size: Rep[Int]): Rep[Unit] = { + if (offset > 0) { + "stack-shift".reflectCtrlWith[Unit](offset, size) + } + } + + def print(): Rep[Unit] = { + "stack-print".reflectCtrlWith[Unit]() + } + + def size: Rep[Int] = { + "stack-size".reflectCtrlWith[Int]() + } + } + + object Frames { + def get(i: Int)(implicit ctx: Context): StagedNum = { + // val offset = ctx.frameTypes.take(i).map(_.size).sum + ctx.frameTypes(i) match { + case NumType(I32Type) => I32("frame-get".reflectCtrlWith[Num](i)) + case NumType(I64Type) => I64("frame-get".reflectCtrlWith[Num](i)) + case NumType(F32Type) => F32("frame-get".reflectCtrlWith[Num](i)) + case NumType(F64Type) => F64("frame-get".reflectCtrlWith[Num](i)) + } + } + + def set(i: Int, v: StagedNum)(implicit ctx: Context): Rep[Unit] = { + // val offset = ctx.frameTypes.take(i).map(_.size).sum + v match { + case I32(v) => "frame-set".reflectCtrlWith[Unit](i, v) + case I64(v) => "frame-set".reflectCtrlWith[Unit](i, v) + case F32(v) => "frame-set".reflectCtrlWith[Unit](i, v) + case F64(v) => "frame-set".reflectCtrlWith[Unit](i, v) + } + } + + def pushFrame(locals: List[ValueType]): Rep[Unit] = { + // Predef.println(s"[DEBUG] push frame: $locals") + val size = locals.size + "frame-push".reflectCtrlWith[Unit](size) + } + + def popFrame(size: Int): Rep[Unit] = { + "frame-pop".reflectCtrlWith[Unit](size) + } + + def putAll(args: List[StagedNum])(implicit ctx: Context): Rep[Unit] = { + for ((arg, i) <- args.view.reverse.zipWithIndex) { + Frames.set(i, arg) + } + } + } + + object Memory { + def storeInt(base: Rep[Int], offset: Int, value: Rep[Int]): Rep[Unit] = { + "memory-store-int".reflectCtrlWith[Unit](base, offset, value) + } + + def loadInt(base: Rep[Int], offset: Int): Rep[Int] = { + "memory-load-int".reflectCtrlWith[Int](base, offset) + } + + def grow(delta: Rep[Int]): Rep[Int] = { + "memory-grow".reflectCtrlWith[Int](delta) + } + } + + // call unreachable + def unreachable(): Rep[Unit] = { + "unreachable".reflectCtrlWith[Unit]() + } + + def info(xs: Rep[_]*): Rep[Unit] = { + "info".reflectCtrlWith[Unit](xs: _*) + } + + // runtime values + object Values { + def I32V(i: Rep[Int]): StagedNum = { + I32("I32V".reflectCtrlWith[Num](i)) + } + + def I64V(i: Rep[Long]): StagedNum = { + I64("I64V".reflectCtrlWith[Num](i)) + } + } + + // global read/write + object Globals { + def apply(i: Int): StagedNum = { + module.globals(i).ty match { + case GlobalType(NumType(I32Type), _) => I32("global-get".reflectCtrlWith[Num](i)) + case GlobalType(NumType(I64Type), _) => I64("global-get".reflectCtrlWith[Num](i)) + case GlobalType(NumType(F32Type), _) => F32("global-get".reflectCtrlWith[Num](i)) + case GlobalType(NumType(F64Type), _) => F64("global-get".reflectCtrlWith[Num](i)) + } + } + + def update(i: Int, v: StagedNum): Rep[Unit] = { + module.globals(i).ty match { + case GlobalType(NumType(I32Type), _) => "global-set".reflectCtrlWith[Unit](i) + case GlobalType(NumType(I64Type), _) => "global-set".reflectCtrlWith[Unit](i) + case GlobalType(NumType(F32Type), _) => "global-set".reflectCtrlWith[Unit](i) + case GlobalType(NumType(F64Type), _) => "global-set".reflectCtrlWith[Unit](i) + } + } + } + + // runtime Num type + implicit class StagedNumOps(num: StagedNum) { + + def toInt: Rep[Int] = "num-to-int".reflectCtrlWith[Int](num.i) + + def clz(): StagedNum = num match { + case I32(i) => I32("clz".reflectCtrlWith[Num](i)) + case I64(i) => I64("clz".reflectCtrlWith[Num](i)) + } + + def ctz(): StagedNum = num match { + case I32(i) => I32("ctz".reflectCtrlWith[Num](i)) + case I64(i) => I64("ctz".reflectCtrlWith[Num](i)) + } + + def popcnt(): StagedNum = num match { + case I32(i) => I32("popcnt".reflectCtrlWith[Num](i)) + case I64(i) => I64("popcnt".reflectCtrlWith[Num](i)) + } + + def +(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-add".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-add".reflectCtrlWith[Num](x, y)) + case (F32(x), F32(y)) => F32("binary-add".reflectCtrlWith[Num](x, y)) + case (F64(x), F64(y)) => F64("binary-add".reflectCtrlWith[Num](x, y)) + } + } + + def -(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-sub".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-sub".reflectCtrlWith[Num](x, y)) + case (F32(x), F32(y)) => F32("binary-sub".reflectCtrlWith[Num](x, y)) + case (F64(x), F64(y)) => F64("binary-sub".reflectCtrlWith[Num](x, y)) + } + } + + def *(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-mul".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-mul".reflectCtrlWith[Num](x, y)) + case (F32(x), F32(y)) => F32("binary-mul".reflectCtrlWith[Num](x, y)) + case (F64(x), F64(y)) => F64("binary-mul".reflectCtrlWith[Num](x, y)) + } + } + + def /(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-div".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-div".reflectCtrlWith[Num](x, y)) + case (F32(x), F32(y)) => F32("binary-div".reflectCtrlWith[Num](x, y)) + case (F64(x), F64(y)) => F64("binary-div".reflectCtrlWith[Num](x, y)) + } + } + + def <<(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-shl".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-shl".reflectCtrlWith[Num](x, y)) + } + } + + def >>(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-shr".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-shr".reflectCtrlWith[Num](x, y)) + } + } + + def &(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("binary-and".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I64("binary-and".reflectCtrlWith[Num](x, y)) + } + } + + def numEq(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-eq".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-eq".reflectCtrlWith[Num](x, y)) + } + } + + def numNe(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-ne".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-ne".reflectCtrlWith[Num](x, y)) + } + } + + def <(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-lt".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-lt".reflectCtrlWith[Num](x, y)) + } + } + + def ltu(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-ltu".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-ltu".reflectCtrlWith[Num](x, y)) + } + } + + def >(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-gt".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-gt".reflectCtrlWith[Num](x, y)) + } + } + + def gtu(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-gtu".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-gtu".reflectCtrlWith[Num](x, y)) + } + } + + def <=(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-le".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-le".reflectCtrlWith[Num](x, y)) + } + } + + def leu(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-leu".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-leu".reflectCtrlWith[Num](x, y)) + } + } + + def >=(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-ge".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-ge".reflectCtrlWith[Num](x, y)) + } + } + + def geu(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x), I32(y)) => I32("relation-geu".reflectCtrlWith[Num](x, y)) + case (I64(x), I64(y)) => I32("relation-geu".reflectCtrlWith[Num](x, y)) + } + } + } +} + +trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { + override def mayInline(n: Node): Boolean = n match { + case Node(s, "stack-pop", _, _) => false + case _ => super.mayInline(n) + } + + override def traverse(n: Node): Unit = n match { + case Node(_, "stack-drop", List(n), _) => + emit("Stack.drop("); shallow(n); emit(")\n") + case Node(_, "stack-reset", List(n), _) => + emit("Stack.reset("); shallow(n); emit(")\n") + case Node(_, "stack-init", _, _) => + emit("Stack.initialize()\n") + case Node(_, "stack-print", _, _) => + emit("Stack.print()\n") + case Node(_, "frame-push", List(i), _) => + emit("Frames.pushFrame("); shallow(i); emit(")\n") + case Node(_, "frame-pop", List(i), _) => + emit("Frames.popFrame("); shallow(i); emit(")\n") + case Node(_, "frame-putAll", List(args), _) => + emit("Frames.putAll("); shallow(args); emit(")\n") + case Node(_, "frame-set", List(i, value), _) => + emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(")\n") + case Node(_, "global-set", List(i, value), _) => + emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(")\n") + case _ => super.traverse(n) + } + + // code generation for pure nodes + override def shallow(n: Node): Unit = n match { + case Node(_, "frame-get", List(i), _) => + emit("Frames.get("); shallow(i); emit(")") + case Node(_, "frame-pop", List(i), _) => + emit("Frames.popFrame("); shallow(i); emit(")") + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(")") + case Node(_, "stack-pop", _, _) => + emit("Stack.pop()") + case Node(_, "stack-peek", _, _) => + emit("Stack.peek") + case Node(_, "stack-take", List(n), _) => + emit("Stack.take("); shallow(n); emit(")") + case Node(_, "stack-size", _, _) => + emit("Stack.size") + case Node(_, "global-get", List(i), _) => + emit("Global.globalGet("); shallow(i); emit(")") + case Node(_, "binary-add", List(lhs, rhs), _) => + shallow(lhs); emit(" + "); shallow(rhs) + case Node(_, "binary-sub", List(lhs, rhs), _) => + shallow(lhs); emit(" - "); shallow(rhs) + case Node(_, "binary-mul", List(lhs, rhs), _) => + shallow(lhs); emit(" * "); shallow(rhs) + case Node(_, "binary-div", List(lhs, rhs), _) => + shallow(lhs); emit(" / "); shallow(rhs) + case Node(_, "binary-shl", List(lhs, rhs), _) => + shallow(lhs); emit(" << "); shallow(rhs) + case Node(_, "binary-shr", List(lhs, rhs), _) => + shallow(lhs); emit(" >> "); shallow(rhs) + case Node(_, "binary-and", List(lhs, rhs), _) => + shallow(lhs); emit(" & "); shallow(rhs) + case Node(_, "relation-eq", List(lhs, rhs), _) => + shallow(lhs); emit(" == "); shallow(rhs) + case Node(_, "relation-ne", List(lhs, rhs), _) => + shallow(lhs); emit(" != "); shallow(rhs) + case Node(_, "relation-lt", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-ltu", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-gt", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-gtu", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-le", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-leu", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-ge", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "relation-geu", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "num-to-int", List(num), _) => + shallow(num); emit(".toInt") + case Node(_, "no-op", _, _) => + emit("()") + case _ => super.shallow(n) + } +} + +trait WasmToScalaCompilerDriver[A, B] + extends SAIDriver[A, B] with StagedWasmEvaluator { q => + override val codegen = new StagedWasmScalaGen { + val IR: q.type = q + import IR._ + override def remap(m: Manifest[_]): String = { + if (m.toString.endsWith("Stack")) "Stack" + else if(m.toString.endsWith("Frame")) "Frame" + else super.remap(m) + } + } + + override val prelude = + """ +object Prelude { + sealed abstract class Num { + def +(that: Num): Num = (this, that) match { + case (I32V(x), I32V(y)) => I32V(x + y) + case (I64V(x), I64V(y)) => I64V(x + y) + case _ => throw new RuntimeException("Invalid addition") + } + + def -(that: Num): Num = (this, that) match { + case (I32V(x), I32V(y)) => I32V(x - y) + case (I64V(x), I64V(y)) => I64V(x - y) + case _ => throw new RuntimeException("Invalid subtraction") + } + + def !=(that: Num): Num = (this, that) match { + case (I32V(x), I32V(y)) => I32V(if (x != y) 1 else 0) + case (I64V(x), I64V(y)) => I32V(if (x != y) 1 else 0) + case _ => throw new RuntimeException("Invalid inequality") + } + + def toInt: Int = this match { + case I32V(i) => i + case I64V(i) => i.toInt + } + } + case class I32V(i: Int) extends Num + case class I64V(i: Long) extends Num + +object Stack { + private val buffer = new scala.collection.mutable.ArrayBuffer[Num]() + def push(v: Num): Unit = buffer.append(v) + def pop(): Num = { + buffer.remove(buffer.size - 1) + } + def peek: Num = { + buffer.last + } + def size: Int = buffer.size + def drop(n: Int): Unit = { + buffer.remove(buffer.size - n, n) + } + def take(n: Int): List[Num] = { + val xs = buffer.takeRight(n).toList + drop(n) + xs + } + def reset(size: Int): Unit = { + info(s"Reset stack to size $size") + while (buffer.size > size) { + buffer.remove(buffer.size - 1) + } + } + def initialize(): Unit = buffer.clear() + def print(): Unit = { + println("Stack: " + buffer.mkString(", ")) + } +} + + class Frame(val size: Int) { + private val data = new Array[Num](size) + def apply(i: Int): Num = { + info(s"frame(${i}) is ${data(i)}") + data(i) + } + def update(i: Int, v: Num): Unit = { + info(s"set frame(${i}) to ${v}") + data(i) = v + } + def putAll(xs: List[Num]): Unit = { + for (i <- 0 until xs.size) { + data(i) = xs(i) + } + } + override def toString: String = { + "Frame(" + data.mkString(", ") + ")" + } + } + + object Frames { + private var frames = List[Frame]() + def pushFrame(size: Int): Unit = { + frames = new Frame(size) :: frames + } + def popFrame(): Unit = { + frames = frames.tail + } + def top: Frame = frames.head + def set(i: Int, v: Num): Unit = { + top(i) = v + } + def get(i: Int): Num = { + top(i) + } + } + + object Global { + // TODO: create global with specific size + private val globals = new Array[Num](10) + def globalGet(i: Int): Num = globals(i) + def globalSet(i: Int, v: Num): Unit = globals(i) = v + } + + def info(xs: Any*): Unit = { + if (System.getenv("DEBUG") != null) { + println("[INFO] " + xs.mkString(" ")) + } + } +} +import Prelude._ + + +object Main { + def main(args: Array[String]): Unit = { + val snippet = new Snippet() + snippet(()) + } +} +""" +} + + +object WasmToScalaCompiler { + def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { + println(s"Now compiling wasm module with entry function $main") + val code = new WasmToScalaCompilerDriver[Unit, Unit] { + def module: ModuleInstance = moduleInst + def snippet(x: Rep[Unit]): Rep[Unit] = { + evalTop(main, printRes) + } + } + code.code + } +} + +trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { + // clear include path and headers by first + includePaths.clear() + headers.clear() + + registerHeader("headers", "\"wasm.hpp\"") + registerHeader("") + registerHeader("") + registerHeader("") + registerHeader("") + + override def mayInline(n: Node): Boolean = n match { + case Node(_, "stack-pop", _, _) + | Node(_, "stack-peek", _, _) + => false + case _ => super.mayInline(n) + } + + override def remap(m: Manifest[_]): String = { + if (m.toString.endsWith("Num")) "Num" + else if (m.toString.endsWith("Frame")) "Frame" + else if (m.toString.endsWith("Stack")) "Stack" + else if (m.toString.endsWith("Global")) "Global" + else if (m.toString.endsWith("I32V")) "I32V" + else if (m.toString.endsWith("I64V")) "I64V" + else super.remap(m) + } + + // for now, the traverse/shallow is same as the scala backend's + override def traverse(n: Node): Unit = n match { + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(");\n") + case Node(_, "stack-drop", List(n), _) => + emit("Stack.drop("); shallow(n); emit(");\n") + case Node(_, "stack-init", _, _) => + emit("Stack.initialize();\n") + case Node(_, "stack-print", _, _) => + emit("Stack.print();\n") + case Node(_, "frame-push", List(i), _) => + emit("Frames.pushFrame("); shallow(i); emit(");\n") + case Node(_, "frame-pop", List(i), _) => + emit("Frames.popFrame("); shallow(i); emit(");\n") + case Node(_, "frame-putAll", List(args), _) => + emit("Frames.putAll("); shallow(args); emit(");\n") + case Node(_, "frame-set", List(i, value), _) => + emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(");\n") + case Node(_, "global-set", List(i, value), _) => + emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(");\n") + // Note: The following code is copied from the traverse of CppBackend.scala, try to avoid duplicated code + case n @ Node(f, "λ", (b: LMSBlock)::LMSConst(0)::rest, _) => + // TODO: Is a leading block followed by 0 a hint for top function? + super.traverse(n) + case n @ Node(f, "λ", (b: LMSBlock)::rest, _) => + val retType = remap(typeBlockRes(b.res)) + val argTypes = b.in.map(a => remap(typeMap(a))).mkString(", ") + emitln(s"std::function<$retType(${argTypes})> ${quote(f)};") + emit(quote(f)); emit(" = ") + quoteTypedBlock(b, false, true, capture = "&") + emitln(";") + case _ => super.traverse(n) + } + + // code generation for pure nodes + override def shallow(n: Node): Unit = n match { + case Node(_, "frame-get", List(i), _) => + emit("Frames.get("); shallow(i); emit(")") + case Node(_, "stack-drop", List(n), _) => + emit("Stack.drop("); shallow(n); emit(")") + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(")") + case Node(_, "stack-shift", List(offset, size), _) => + emit("Stack.shift("); shallow(offset); emit(", "); shallow(size); emit(")") + case Node(_, "stack-pop", _, _) => + emit("Stack.pop()") + case Node(_, "frame-pop", List(i), _) => + emit("Frames.popFrame("); shallow(i); emit(")") + case Node(_, "stack-peek", _, _) => + emit("Stack.peek()") + case Node(_, "stack-take", List(n), _) => + emit("Stack.take("); shallow(n); emit(")") + case Node(_, "slice-reverse", List(slice), _) => + shallow(slice); emit(".reverse") + case Node(_, "memory-store-int", List(base, offset, value), _) => + emit("Memory.storeInt("); shallow(base); emit(", "); shallow(offset); emit(", "); shallow(value); emit(")") + case Node(_, "memory-load-int", List(base, offset), _) => + emit("Memory.loadInt("); shallow(base); emit(", "); shallow(offset); emit(")") + case Node(_, "memory-grow", List(delta), _) => + emit("Memory.grow("); shallow(delta); emit(")") + case Node(_, "stack-size", _, _) => + emit("Stack.size()") + case Node(_, "global-get", List(i), _) => + emit("Global.globalGet("); shallow(i); emit(")") + case Node(_, "binary-add", List(lhs, rhs), _) => + shallow(lhs); emit(" + "); shallow(rhs) + case Node(_, "binary-sub", List(lhs, rhs), _) => + shallow(lhs); emit(" - "); shallow(rhs) + case Node(_, "binary-mul", List(lhs, rhs), _) => + shallow(lhs); emit(" * "); shallow(rhs) + case Node(_, "binary-div", List(lhs, rhs), _) => + shallow(lhs); emit(" / "); shallow(rhs) + case Node(_, "binary-shl", List(lhs, rhs), _) => + shallow(lhs); emit(" << "); shallow(rhs) + case Node(_, "binary-shr", List(lhs, rhs), _) => + shallow(lhs); emit(" >> "); shallow(rhs) + case Node(_, "binary-and", List(lhs, rhs), _) => + shallow(lhs); emit(" & "); shallow(rhs) + case Node(_, "relation-eq", List(lhs, rhs), _) => + shallow(lhs); emit(" == "); shallow(rhs) + case Node(_, "relation-ne", List(lhs, rhs), _) => + shallow(lhs); emit(" != "); shallow(rhs) + case Node(_, "relation-lt", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-ltu", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-gt", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-gtu", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-le", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-leu", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-ge", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "relation-geu", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "num-to-int", List(num), _) => + shallow(num); emit(".toInt()") + case Node(_, "dummy", _, _) => emit("std::monostate()") + case Node(_, "dummy-op", _, _) => emit("std::monostate()") + case Node(_, "no-op", _, _) => + emit("std::monostate()") + case _ => super.shallow(n) + } + + override def registerTopLevelFunction(id: String, streamId: String = "general")(f: => Unit) = + if (!registeredFunctions(id)) { + //if (ongoingFun(streamId)) ??? + //ongoingFun += streamId + registeredFunctions += id + withStream(functionsStreams.getOrElseUpdate(id, { + val functionsStream = new java.io.ByteArrayOutputStream() + val functionsWriter = new java.io.PrintStream(functionsStream) + (functionsWriter, functionsStream) + })._1)(f) + //ongoingFun -= streamId + } else { + // If a function is registered, don't re-register it. + // withStream(functionsStreams(id)._1)(f) + } + + override def emitAll(g: Graph, name: String)(m1: Manifest[_], m2: Manifest[_]): Unit = { + val ng = init(g) + emitHeaders(stream) + emitln(""" + |/***************************************** + |Emitting Generated Code + |*******************************************/ + """.stripMargin) + val src = run(name, ng) + emitFunctionDecls(stream) + emitDatastructures(stream) + emitFunctions(stream) + emit(src) + emitln(""" + |/***************************************** + |End of Generated Code + |*******************************************/ + |int main(int argc, char *argv[]) { + | Snippet(std::monostate{}); + | return 0; + |}""".stripMargin) + } +} + +trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEvaluator { q => + override val codegen = new StagedWasmCppGen { + val IR: q.type = q + import IR._ + } +} + +object WasmToCppCompiler { + case class GeneratedCpp(source: String, headerFolders: List[String]) + + def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): GeneratedCpp = { + println(s"Now compiling wasm module with entry function $main") + val driver = new WasmToCppCompilerDriver[Unit, Unit] { + def module: ModuleInstance = moduleInst + def snippet(x: Rep[Unit]): Rep[Unit] = { + evalTop(main, printRes) + } + } + GeneratedCpp(driver.code, driver.codegen.includePaths.toList) + } + + def compileToExe(moduleInst: ModuleInstance, + main: Option[String], + outputCpp: String, + outputExe: String, + printRes: Boolean = false): Unit = { + val generated = compile(moduleInst, main, printRes) + val code = generated.source + + val writer = new java.io.PrintWriter(new java.io.File(outputCpp)) + try { + writer.write(code) + } finally { + writer.close() + } + + import sys.process._ + val command = s"g++ -std=c++17 $outputCpp -o $outputExe -O3 -g " + generated.headerFolders.map(f => s"-I$f").mkString(" ") + if (command.! != 0) { + throw new RuntimeException(s"Compilation failed for $outputCpp") + } + } + +} + From 27e3e32f9269303d21d65583f6b2d14a74b04dbe Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 7 Jul 2025 14:34:24 +0800 Subject: [PATCH 03/21] dup all concrete operations to symbolic --- .../scala/wasm/StagedConcolicMiniWasm.scala | 444 +++++------------- 1 file changed, 111 insertions(+), 333 deletions(-) diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 6b14bcf6..22b593a9 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -13,8 +13,9 @@ import lms.core.Graph import gensym.wasm.ast._ import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} import gensym.wasm.miniwasm.{ModuleInstance} -import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} +import gensym.wasm.symbolic.{SymVal} import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase, CppSAIDriver, CppSAICodeGenBase} +import gensym.wasm.symbolic.Concrete @virtualize trait StagedWasmEvaluator extends SAIOps { @@ -24,25 +25,27 @@ trait StagedWasmEvaluator extends SAIOps { trait StagedNum { def tipe: ValueType = this match { - case I32(_) => NumType(I32Type) - case I64(_) => NumType(I64Type) - case F32(_) => NumType(F32Type) - case F64(_) => NumType(F64Type) + case I32(_, _) => NumType(I32Type) + case I64(_, _) => NumType(I64Type) + case F32(_, _) => NumType(F32Type) + case F64(_, _) => NumType(F64Type) } def i: Rep[Num] + + def s: Rep[SymVal] } - case class I32(i: Rep[Num]) extends StagedNum - case class I64(i: Rep[Num]) extends StagedNum - case class F32(i: Rep[Num]) extends StagedNum - case class F64(i: Rep[Num]) extends StagedNum + case class I32(i: Rep[Num], s: Rep[SymVal]) extends StagedNum + case class I64(i: Rep[Num], s: Rep[SymVal]) extends StagedNum + case class F32(i: Rep[Num], s: Rep[SymVal]) extends StagedNum + case class F64(i: Rep[Num], s: Rep[SymVal]) extends StagedNum implicit def toStagedNum(num: Num): StagedNum = { num match { - case I32V(_) => I32(num) - case I64V(_) => I64(num) - case F32V(_) => F32(num) - case F64V(_) => F64(num) + case I32V(_) => I32(num, Concrete(num)) + case I64V(_) => I64(num, Concrete(num)) + case F32V(_) => F32(num, Concrete(num)) + case F64V(_) => F64(num, Concrete(num)) } } @@ -147,7 +150,7 @@ trait StagedWasmEvaluator extends SAIOps { case Load(LoadOp(align, offset, ty, None, None)) => val (addr, newCtx1) = Stack.pop() val value = Memory.loadInt(addr.toInt, offset) - val newCtx2 = Stack.push(Values.I32V(value))(newCtx1) + val newCtx2 = Stack.push(value)(newCtx1) eval(rest, kont, mkont, trail)(newCtx2) case MemorySize => ??? case MemoryGrow => @@ -414,10 +417,10 @@ trait StagedWasmEvaluator extends SAIOps { def pop()(implicit ctx: Context): (StagedNum, Context) = { val (ty, newContext) = ctx.pop() val num = ty match { - case NumType(I32Type) => I32("stack-pop".reflectCtrlWith[Num]()) - case NumType(I64Type) => I64("stack-pop".reflectCtrlWith[Num]()) - case NumType(F32Type) => F32("stack-pop".reflectCtrlWith[Num]()) - case NumType(F32Type) => F64("stack-pop".reflectCtrlWith[Num]()) + case NumType(I32Type) => I32("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]()) + case NumType(I64Type) => I64("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]()) + case NumType(F32Type) => F32("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]()) + case NumType(F32Type) => F64("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]()) } (num, newContext) } @@ -425,22 +428,22 @@ trait StagedWasmEvaluator extends SAIOps { def peek(implicit ctx: Context): (StagedNum, Context) = { val ty = ctx.stackTypes.head val num = ty match { - case NumType(I32Type) => I32("stack-peek".reflectCtrlWith[Num]()) - case NumType(I64Type) => I64("stack-peek".reflectCtrlWith[Num]()) - case NumType(F32Type) => F32("stack-peek".reflectCtrlWith[Num]()) - case NumType(F32Type) => F64("stack-peek".reflectCtrlWith[Num]()) + case NumType(I32Type) => I32("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]()) + case NumType(I64Type) => I64("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]()) + case NumType(F32Type) => F32("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]()) + case NumType(F32Type) => F64("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]()) } (num, ctx) } - def push(v: StagedNum)(implicit ctx: Context): Context = { - v match { - case I32(v) => "stack-push".reflectCtrlWith[Unit](v) - case I64(v) => "stack-push".reflectCtrlWith[Unit](v) - case F32(v) => "stack-push".reflectCtrlWith[Unit](v) - case F64(v) => "stack-push".reflectCtrlWith[Unit](v) + def push(num: StagedNum)(implicit ctx: Context): Context = { + num match { + case I32(v, s) => "stack-push".reflectCtrlWith[Unit](v); "sym-stack-push".reflectCtrlWith[Unit](s) + case I64(v, s) => "stack-push".reflectCtrlWith[Unit](v); "sym-stack-push".reflectCtrlWith[Unit](s) + case F32(v, s) => "stack-push".reflectCtrlWith[Unit](v); "sym-stack-push".reflectCtrlWith[Unit](s) + case F64(v, s) => "stack-push".reflectCtrlWith[Unit](v); "sym-stack-push".reflectCtrlWith[Unit](s) } - ctx.push(v.tipe) + ctx.push(num.tipe) } def take(n: Int)(implicit ctx: Context): (List[StagedNum], Context) = n match { @@ -458,6 +461,7 @@ trait StagedWasmEvaluator extends SAIOps { def shift(offset: Rep[Int], size: Rep[Int]): Rep[Unit] = { if (offset > 0) { "stack-shift".reflectCtrlWith[Unit](offset, size) + "sym-stack-shift".reflectCtrlWith[Unit](offset, size) } } @@ -474,20 +478,20 @@ trait StagedWasmEvaluator extends SAIOps { def get(i: Int)(implicit ctx: Context): StagedNum = { // val offset = ctx.frameTypes.take(i).map(_.size).sum ctx.frameTypes(i) match { - case NumType(I32Type) => I32("frame-get".reflectCtrlWith[Num](i)) - case NumType(I64Type) => I64("frame-get".reflectCtrlWith[Num](i)) - case NumType(F32Type) => F32("frame-get".reflectCtrlWith[Num](i)) - case NumType(F64Type) => F64("frame-get".reflectCtrlWith[Num](i)) + case NumType(I32Type) => I32("frame-get".reflectCtrlWith[Num](i), "sym-frame-get".reflectCtrlWith[SymVal](i)) + case NumType(I64Type) => I64("frame-get".reflectCtrlWith[Num](i), "sym-frame-get".reflectCtrlWith[SymVal](i)) + case NumType(F32Type) => F32("frame-get".reflectCtrlWith[Num](i), "sym-frame-get".reflectCtrlWith[SymVal](i)) + case NumType(F64Type) => F64("frame-get".reflectCtrlWith[Num](i), "sym-frame-get".reflectCtrlWith[SymVal](i)) } } def set(i: Int, v: StagedNum)(implicit ctx: Context): Rep[Unit] = { // val offset = ctx.frameTypes.take(i).map(_.size).sum v match { - case I32(v) => "frame-set".reflectCtrlWith[Unit](i, v) - case I64(v) => "frame-set".reflectCtrlWith[Unit](i, v) - case F32(v) => "frame-set".reflectCtrlWith[Unit](i, v) - case F64(v) => "frame-set".reflectCtrlWith[Unit](i, v) + case I32(v, s) => "frame-set".reflectCtrlWith[Unit](i, v); "sym-frame-set".reflectCtrlWith[Unit](i, s) + case I64(v, s) => "frame-set".reflectCtrlWith[Unit](i, v); "sym-frame-set".reflectCtrlWith[Unit](i, s) + case F32(v, s) => "frame-set".reflectCtrlWith[Unit](i, v); "sym-frame-set".reflectCtrlWith[Unit](i, s) + case F64(v, s) => "frame-set".reflectCtrlWith[Unit](i, v); "sym-frame-set".reflectCtrlWith[Unit](i, s) } } @@ -495,10 +499,12 @@ trait StagedWasmEvaluator extends SAIOps { // Predef.println(s"[DEBUG] push frame: $locals") val size = locals.size "frame-push".reflectCtrlWith[Unit](size) + "sym-frame-push".reflectCtrlWith[Unit](size) } def popFrame(size: Int): Rep[Unit] = { "frame-pop".reflectCtrlWith[Unit](size) + "sym-frame-pop".reflectCtrlWith[Unit](size) } def putAll(args: List[StagedNum])(implicit ctx: Context): Rep[Unit] = { @@ -513,8 +519,8 @@ trait StagedWasmEvaluator extends SAIOps { "memory-store-int".reflectCtrlWith[Unit](base, offset, value) } - def loadInt(base: Rep[Int], offset: Int): Rep[Int] = { - "memory-load-int".reflectCtrlWith[Int](base, offset) + def loadInt(base: Rep[Int], offset: Int): StagedNum = { + I32("I32V".reflectCtrlWith[Num]("memory-load-int".reflectCtrlWith[Int](base, offset)), "sym-load-int-todo".reflectCtrlWith[SymVal](base, offset)) } def grow(delta: Rep[Int]): Rep[Int] = { @@ -534,11 +540,11 @@ trait StagedWasmEvaluator extends SAIOps { // runtime values object Values { def I32V(i: Rep[Int]): StagedNum = { - I32("I32V".reflectCtrlWith[Num](i)) + I32("I32V".reflectCtrlWith[Num](i), "Concrete".reflectCtrlWith[SymVal]("I32V".reflectCtrlWith[Num](i))) } def I64V(i: Rep[Long]): StagedNum = { - I64("I64V".reflectCtrlWith[Num](i)) + I64("I64V".reflectCtrlWith[Num](i), "Concrete".reflectCtrlWith[SymVal]("I64V".reflectCtrlWith[Num](i))) } } @@ -546,19 +552,19 @@ trait StagedWasmEvaluator extends SAIOps { object Globals { def apply(i: Int): StagedNum = { module.globals(i).ty match { - case GlobalType(NumType(I32Type), _) => I32("global-get".reflectCtrlWith[Num](i)) - case GlobalType(NumType(I64Type), _) => I64("global-get".reflectCtrlWith[Num](i)) - case GlobalType(NumType(F32Type), _) => F32("global-get".reflectCtrlWith[Num](i)) - case GlobalType(NumType(F64Type), _) => F64("global-get".reflectCtrlWith[Num](i)) + case GlobalType(NumType(I32Type), _) => I32("global-get".reflectCtrlWith[Num](i), "sym-global-get".reflectCtrlWith[SymVal](i)) + case GlobalType(NumType(I64Type), _) => I64("global-get".reflectCtrlWith[Num](i), "sym-global-get".reflectCtrlWith[SymVal](i)) + case GlobalType(NumType(F32Type), _) => F32("global-get".reflectCtrlWith[Num](i), "sym-global-get".reflectCtrlWith[SymVal](i)) + case GlobalType(NumType(F64Type), _) => F64("global-get".reflectCtrlWith[Num](i), "sym-global-get".reflectCtrlWith[SymVal](i)) } } def update(i: Int, v: StagedNum): Rep[Unit] = { module.globals(i).ty match { - case GlobalType(NumType(I32Type), _) => "global-set".reflectCtrlWith[Unit](i) - case GlobalType(NumType(I64Type), _) => "global-set".reflectCtrlWith[Unit](i) - case GlobalType(NumType(F32Type), _) => "global-set".reflectCtrlWith[Unit](i) - case GlobalType(NumType(F64Type), _) => "global-set".reflectCtrlWith[Unit](i) + case GlobalType(NumType(I32Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i);"sym-global-set".reflectCtrlWith[Unit](i, v.s) + case GlobalType(NumType(I64Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i);"sym-global-set".reflectCtrlWith[Unit](i, v.s) + case GlobalType(NumType(F32Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i);"sym-global-set".reflectCtrlWith[Unit](i, v.s) + case GlobalType(NumType(F64Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i);"sym-global-set".reflectCtrlWith[Unit](i, v.s) } } } @@ -569,382 +575,153 @@ trait StagedWasmEvaluator extends SAIOps { def toInt: Rep[Int] = "num-to-int".reflectCtrlWith[Int](num.i) def clz(): StagedNum = num match { - case I32(i) => I32("clz".reflectCtrlWith[Num](i)) - case I64(i) => I64("clz".reflectCtrlWith[Num](i)) + case I32(x_c, x_s) => I32("clz".reflectCtrlWith[Num](x_c), "sym-clz".reflectCtrlWith[SymVal](x_s)) + case I64(x_c, x_s) => I64("clz".reflectCtrlWith[Num](x_c), "sym-clz".reflectCtrlWith[SymVal](x_s)) } def ctz(): StagedNum = num match { - case I32(i) => I32("ctz".reflectCtrlWith[Num](i)) - case I64(i) => I64("ctz".reflectCtrlWith[Num](i)) + case I32(x_c, x_s) => I32("ctz".reflectCtrlWith[Num](x_c), "sym-ctz".reflectCtrlWith[SymVal](x_s)) + case I64(x_c, x_s) => I64("ctz".reflectCtrlWith[Num](x_c), "sym-ctz".reflectCtrlWith[SymVal](x_s)) } def popcnt(): StagedNum = num match { - case I32(i) => I32("popcnt".reflectCtrlWith[Num](i)) - case I64(i) => I64("popcnt".reflectCtrlWith[Num](i)) + case I32(x_c, x_s) => I32("popcnt".reflectCtrlWith[Num](x_c), "sym-popcnt".reflectCtrlWith[SymVal](x_s)) + case I64(x_c, x_s) => I64("popcnt".reflectCtrlWith[Num](x_c), "sym-popcnt".reflectCtrlWith[SymVal](x_s)) } def +(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("binary-add".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I64("binary-add".reflectCtrlWith[Num](x, y)) - case (F32(x), F32(y)) => F32("binary-add".reflectCtrlWith[Num](x, y)) - case (F64(x), F64(y)) => F64("binary-add".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) } } + def -(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("binary-sub".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I64("binary-sub".reflectCtrlWith[Num](x, y)) - case (F32(x), F32(y)) => F32("binary-sub".reflectCtrlWith[Num](x, y)) - case (F64(x), F64(y)) => F64("binary-sub".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) } } def *(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("binary-mul".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I64("binary-mul".reflectCtrlWith[Num](x, y)) - case (F32(x), F32(y)) => F32("binary-mul".reflectCtrlWith[Num](x, y)) - case (F64(x), F64(y)) => F64("binary-mul".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) } } def /(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("binary-div".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I64("binary-div".reflectCtrlWith[Num](x, y)) - case (F32(x), F32(y)) => F32("binary-div".reflectCtrlWith[Num](x, y)) - case (F64(x), F64(y)) => F64("binary-div".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) } } def <<(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("binary-shl".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I64("binary-shl".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) } } def >>(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("binary-shr".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I64("binary-shr".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) } } def &(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("binary-and".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I64("binary-and".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) } } def numEq(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("relation-eq".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I32("relation-eq".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-eq".reflectCtrlWith[Num](x_c, y_c), "sym-relation-eq".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-eq".reflectCtrlWith[Num](x_c, y_c), "sym-relation-eq".reflectCtrlWith[SymVal](x_s, y_s)) } } def numNe(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("relation-ne".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I32("relation-ne".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-ne".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ne".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-ne".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ne".reflectCtrlWith[SymVal](x_s, y_s)) } } def <(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("relation-lt".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I32("relation-lt".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-lt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-lt".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-lt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-lt".reflectCtrlWith[SymVal](x_s, y_s)) } } def ltu(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("relation-ltu".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I32("relation-ltu".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-ltu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ltu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-ltu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ltu".reflectCtrlWith[SymVal](x_s, y_s)) } } def >(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("relation-gt".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I32("relation-gt".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-gt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gt".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-gt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gt".reflectCtrlWith[SymVal](x_s, y_s)) } } def gtu(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("relation-gtu".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I32("relation-gtu".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-gtu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gtu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-gtu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gtu".reflectCtrlWith[SymVal](x_s, y_s)) } } def <=(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("relation-le".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I32("relation-le".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-le".reflectCtrlWith[Num](x_c, y_c), "sym-relation-le".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-le".reflectCtrlWith[Num](x_c, y_c), "sym-relation-le".reflectCtrlWith[SymVal](x_s, y_s)) } } def leu(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("relation-leu".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I32("relation-leu".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-leu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-leu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-leu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-leu".reflectCtrlWith[SymVal](x_s, y_s)) } } def >=(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("relation-ge".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I32("relation-ge".reflectCtrlWith[Num](x, y)) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-ge".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ge".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-ge".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ge".reflectCtrlWith[SymVal](x_s, y_s)) } } def geu(rhs: StagedNum): StagedNum = { (num, rhs) match { - case (I32(x), I32(y)) => I32("relation-geu".reflectCtrlWith[Num](x, y)) - case (I64(x), I64(y)) => I32("relation-geu".reflectCtrlWith[Num](x, y)) - } - } - } -} - -trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { - override def mayInline(n: Node): Boolean = n match { - case Node(s, "stack-pop", _, _) => false - case _ => super.mayInline(n) - } - - override def traverse(n: Node): Unit = n match { - case Node(_, "stack-drop", List(n), _) => - emit("Stack.drop("); shallow(n); emit(")\n") - case Node(_, "stack-reset", List(n), _) => - emit("Stack.reset("); shallow(n); emit(")\n") - case Node(_, "stack-init", _, _) => - emit("Stack.initialize()\n") - case Node(_, "stack-print", _, _) => - emit("Stack.print()\n") - case Node(_, "frame-push", List(i), _) => - emit("Frames.pushFrame("); shallow(i); emit(")\n") - case Node(_, "frame-pop", List(i), _) => - emit("Frames.popFrame("); shallow(i); emit(")\n") - case Node(_, "frame-putAll", List(args), _) => - emit("Frames.putAll("); shallow(args); emit(")\n") - case Node(_, "frame-set", List(i, value), _) => - emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(")\n") - case Node(_, "global-set", List(i, value), _) => - emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(")\n") - case _ => super.traverse(n) - } - - // code generation for pure nodes - override def shallow(n: Node): Unit = n match { - case Node(_, "frame-get", List(i), _) => - emit("Frames.get("); shallow(i); emit(")") - case Node(_, "frame-pop", List(i), _) => - emit("Frames.popFrame("); shallow(i); emit(")") - case Node(_, "stack-push", List(value), _) => - emit("Stack.push("); shallow(value); emit(")") - case Node(_, "stack-pop", _, _) => - emit("Stack.pop()") - case Node(_, "stack-peek", _, _) => - emit("Stack.peek") - case Node(_, "stack-take", List(n), _) => - emit("Stack.take("); shallow(n); emit(")") - case Node(_, "stack-size", _, _) => - emit("Stack.size") - case Node(_, "global-get", List(i), _) => - emit("Global.globalGet("); shallow(i); emit(")") - case Node(_, "binary-add", List(lhs, rhs), _) => - shallow(lhs); emit(" + "); shallow(rhs) - case Node(_, "binary-sub", List(lhs, rhs), _) => - shallow(lhs); emit(" - "); shallow(rhs) - case Node(_, "binary-mul", List(lhs, rhs), _) => - shallow(lhs); emit(" * "); shallow(rhs) - case Node(_, "binary-div", List(lhs, rhs), _) => - shallow(lhs); emit(" / "); shallow(rhs) - case Node(_, "binary-shl", List(lhs, rhs), _) => - shallow(lhs); emit(" << "); shallow(rhs) - case Node(_, "binary-shr", List(lhs, rhs), _) => - shallow(lhs); emit(" >> "); shallow(rhs) - case Node(_, "binary-and", List(lhs, rhs), _) => - shallow(lhs); emit(" & "); shallow(rhs) - case Node(_, "relation-eq", List(lhs, rhs), _) => - shallow(lhs); emit(" == "); shallow(rhs) - case Node(_, "relation-ne", List(lhs, rhs), _) => - shallow(lhs); emit(" != "); shallow(rhs) - case Node(_, "relation-lt", List(lhs, rhs), _) => - shallow(lhs); emit(" < "); shallow(rhs) - case Node(_, "relation-ltu", List(lhs, rhs), _) => - shallow(lhs); emit(" < "); shallow(rhs) - case Node(_, "relation-gt", List(lhs, rhs), _) => - shallow(lhs); emit(" > "); shallow(rhs) - case Node(_, "relation-gtu", List(lhs, rhs), _) => - shallow(lhs); emit(" > "); shallow(rhs) - case Node(_, "relation-le", List(lhs, rhs), _) => - shallow(lhs); emit(" <= "); shallow(rhs) - case Node(_, "relation-leu", List(lhs, rhs), _) => - shallow(lhs); emit(" <= "); shallow(rhs) - case Node(_, "relation-ge", List(lhs, rhs), _) => - shallow(lhs); emit(" >= "); shallow(rhs) - case Node(_, "relation-geu", List(lhs, rhs), _) => - shallow(lhs); emit(" >= "); shallow(rhs) - case Node(_, "num-to-int", List(num), _) => - shallow(num); emit(".toInt") - case Node(_, "no-op", _, _) => - emit("()") - case _ => super.shallow(n) - } -} - -trait WasmToScalaCompilerDriver[A, B] - extends SAIDriver[A, B] with StagedWasmEvaluator { q => - override val codegen = new StagedWasmScalaGen { - val IR: q.type = q - import IR._ - override def remap(m: Manifest[_]): String = { - if (m.toString.endsWith("Stack")) "Stack" - else if(m.toString.endsWith("Frame")) "Frame" - else super.remap(m) - } - } - - override val prelude = - """ -object Prelude { - sealed abstract class Num { - def +(that: Num): Num = (this, that) match { - case (I32V(x), I32V(y)) => I32V(x + y) - case (I64V(x), I64V(y)) => I64V(x + y) - case _ => throw new RuntimeException("Invalid addition") - } - - def -(that: Num): Num = (this, that) match { - case (I32V(x), I32V(y)) => I32V(x - y) - case (I64V(x), I64V(y)) => I64V(x - y) - case _ => throw new RuntimeException("Invalid subtraction") - } - - def !=(that: Num): Num = (this, that) match { - case (I32V(x), I32V(y)) => I32V(if (x != y) 1 else 0) - case (I64V(x), I64V(y)) => I32V(if (x != y) 1 else 0) - case _ => throw new RuntimeException("Invalid inequality") - } - - def toInt: Int = this match { - case I32V(i) => i - case I64V(i) => i.toInt - } - } - case class I32V(i: Int) extends Num - case class I64V(i: Long) extends Num - -object Stack { - private val buffer = new scala.collection.mutable.ArrayBuffer[Num]() - def push(v: Num): Unit = buffer.append(v) - def pop(): Num = { - buffer.remove(buffer.size - 1) - } - def peek: Num = { - buffer.last - } - def size: Int = buffer.size - def drop(n: Int): Unit = { - buffer.remove(buffer.size - n, n) - } - def take(n: Int): List[Num] = { - val xs = buffer.takeRight(n).toList - drop(n) - xs - } - def reset(size: Int): Unit = { - info(s"Reset stack to size $size") - while (buffer.size > size) { - buffer.remove(buffer.size - 1) - } - } - def initialize(): Unit = buffer.clear() - def print(): Unit = { - println("Stack: " + buffer.mkString(", ")) - } -} - - class Frame(val size: Int) { - private val data = new Array[Num](size) - def apply(i: Int): Num = { - info(s"frame(${i}) is ${data(i)}") - data(i) - } - def update(i: Int, v: Num): Unit = { - info(s"set frame(${i}) to ${v}") - data(i) = v - } - def putAll(xs: List[Num]): Unit = { - for (i <- 0 until xs.size) { - data(i) = xs(i) + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-geu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-geu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-geu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-geu".reflectCtrlWith[SymVal](x_s, y_s)) } } - override def toString: String = { - "Frame(" + data.mkString(", ") + ")" - } - } - - object Frames { - private var frames = List[Frame]() - def pushFrame(size: Int): Unit = { - frames = new Frame(size) :: frames - } - def popFrame(): Unit = { - frames = frames.tail - } - def top: Frame = frames.head - def set(i: Int, v: Num): Unit = { - top(i) = v - } - def get(i: Int): Num = { - top(i) - } - } - - object Global { - // TODO: create global with specific size - private val globals = new Array[Num](10) - def globalGet(i: Int): Num = globals(i) - def globalSet(i: Int, v: Num): Unit = globals(i) = v - } - - def info(xs: Any*): Unit = { - if (System.getenv("DEBUG") != null) { - println("[INFO] " + xs.mkString(" ")) - } - } -} -import Prelude._ - - -object Main { - def main(args: Array[String]): Unit = { - val snippet = new Snippet() - snippet(()) - } -} -""" -} - - -object WasmToScalaCompiler { - def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { - println(s"Now compiling wasm module with entry function $main") - val code = new WasmToScalaCompilerDriver[Unit, Unit] { - def module: ModuleInstance = moduleInst - def snippet(x: Rep[Unit]): Rep[Unit] = { - evalTop(main, printRes) - } - } - code.code } } @@ -1168,3 +945,4 @@ object WasmToCppCompiler { } + From 2143050a320d3fa182aa361bc6b9c154aff9045d Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 7 Jul 2025 21:46:15 +0800 Subject: [PATCH 04/21] maintain a symbolic stack during the execution --- headers/wasm.hpp | 1 + headers/wasm/concrete_rt.hpp | 7 +- headers/wasm/symbolic_rt.hpp | 72 +++++++++++++++++++ .../scala/wasm/StagedConcolicMiniWasm.scala | 52 ++++++++++---- .../genwasym/TestStagedConcolicEval.scala | 33 +++++++++ 5 files changed, 152 insertions(+), 13 deletions(-) create mode 100644 headers/wasm/symbolic_rt.hpp create mode 100644 src/test/scala/genwasym/TestStagedConcolicEval.scala diff --git a/headers/wasm.hpp b/headers/wasm.hpp index 21da2ff7..c7e98b6e 100644 --- a/headers/wasm.hpp +++ b/headers/wasm.hpp @@ -2,5 +2,6 @@ #define WASM_HEADERS #include "wasm/concrete_rt.hpp" +#include "wasm/symbolic_rt.hpp" #endif \ No newline at end of file diff --git a/headers/wasm/concrete_rt.hpp b/headers/wasm/concrete_rt.hpp index 34d739f4..e994cbde 100644 --- a/headers/wasm/concrete_rt.hpp +++ b/headers/wasm/concrete_rt.hpp @@ -1,3 +1,6 @@ +#ifndef WASM_CONCRETE_RT_HPP +#define WASM_CONCRETE_RT_HPP + #include #include #include @@ -200,4 +203,6 @@ struct Memory_t { } }; -static Memory_t Memory(1); // 1 page memory \ No newline at end of file +static Memory_t Memory(1); // 1 page memory + +#endif // WASM_CONCRETE_RT_HPP \ No newline at end of file diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp new file mode 100644 index 00000000..d03509ac --- /dev/null +++ b/headers/wasm/symbolic_rt.hpp @@ -0,0 +1,72 @@ +#ifndef WASM_SYMBOLIC_RT_HPP +#define WASM_SYMBOLIC_RT_HPP + +#include "concrete_rt.hpp" +#include + +class SymVal { +public: + SymVal operator+(const SymVal &other) const { + // Define how to add two symbolic values + // Not implemented yet + return SymVal(); + } + + SymVal is_zero() const { + // Check if the symbolic value is zero + // Not implemented yet + return SymVal(); + } +}; + +class SymStack_t { +public: + void push(SymVal val) { + // Push a symbolic value to the stack + // Not implemented yet + } + + SymVal pop() { + // Pop a symbolic value from the stack + // Not implemented yet + return SymVal(); + } + + SymVal peek() { return SymVal(); } +}; + +static SymStack_t SymStack; + +class SymFrames_t { +public: + void pushFrame(int size) { + // Push a new frame with the given size + // Not implemented yet + } + std::monostate popFrame(int size) { + // Pop the frame of the given size + // Not implemented yet + return std::monostate(); + } + + SymVal get(int index) { + // Get the symbolic value at the given index + // Not implemented yet + return SymVal(); + } + + void set(int index, SymVal val) { + // Set the symbolic value at the given index + // Not implemented yet + } +}; + +static SymFrames_t SymFrames; + +static SymVal Concrete(Num num) { + // Convert a concrete number to a symbolic value + // Not implemented yet + return SymVal(); +} + +#endif // WASM_SYMBOLIC_RT_HPP \ No newline at end of file diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 22b593a9..c41c7c42 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -40,7 +40,7 @@ trait StagedWasmEvaluator extends SAIOps { case class F32(i: Rep[Num], s: Rep[SymVal]) extends StagedNum case class F64(i: Rep[Num], s: Rep[SymVal]) extends StagedNum - implicit def toStagedNum(num: Num): StagedNum = { + def toStagedNum(num: Num): StagedNum = { num match { case I32V(_) => I32(num, Concrete(num)) case I64V(_) => I64(num, Concrete(num)) @@ -118,7 +118,7 @@ trait StagedWasmEvaluator extends SAIOps { val (_, newCtx) = Stack.pop() eval(rest, kont, mkont, trail)(newCtx) case WasmConst(num) => - val newCtx = Stack.push(num) + val newCtx = Stack.push(toStagedNum(num)) eval(rest, kont, mkont, trail)(newCtx) case LocalGet(i) => val newCtx = Stack.push(Frames.get(i)) @@ -155,7 +155,10 @@ trait StagedWasmEvaluator extends SAIOps { case MemorySize => ??? case MemoryGrow => val (delta, newCtx1) = Stack.pop() - val newCtx2 = Stack.push(Values.I32V(Memory.grow(delta.toInt)))(newCtx1) + val ret = Memory.grow(delta.toInt) + val retNum = Values.I32V(ret) + val retSym = "Concrete".reflectCtrlWith[SymVal](retNum) + val newCtx2 = Stack.push(I32(retNum, retSym))(newCtx1) eval(rest, kont, mkont, trail)(newCtx2) case MemoryFill => ??? case Unreachable => unreachable() @@ -220,6 +223,7 @@ trait StagedWasmEvaluator extends SAIOps { val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) eval(rest, kont, mk, trail)(newRestCtx) }) + // TODO: put the cond.s to path condition if (cond.toInt != 0) { eval(thn, restK _, mkont, restK _ :: trail)(newCtx) } else { @@ -232,6 +236,7 @@ trait StagedWasmEvaluator extends SAIOps { case BrIf(label) => val (cond, newCtx) = Stack.pop() info(s"The br_if(${label})'s condition is ", cond.toInt) + // TODO: put the cond.s to path condition if (cond.toInt != 0) { info(s"Jump to $label") trail(label)(newCtx)(mkont) @@ -320,7 +325,7 @@ trait StagedWasmEvaluator extends SAIOps { } def evalTestOp(op: TestOp, value: StagedNum): StagedNum = op match { - case Eqz(_) => Values.I32V(if (value.toInt == 0) 1 else 0) + case Eqz(_) => value.isZero } def evalUnaryOp(op: UnaryOp, value: StagedNum): StagedNum = op match { @@ -523,6 +528,7 @@ trait StagedWasmEvaluator extends SAIOps { I32("I32V".reflectCtrlWith[Num]("memory-load-int".reflectCtrlWith[Int](base, offset)), "sym-load-int-todo".reflectCtrlWith[SymVal](base, offset)) } + // Returns the previous memory size on success, or -1 if the memory cannot be grown. def grow(delta: Rep[Int]): Rep[Int] = { "memory-grow".reflectCtrlWith[Int](delta) } @@ -539,12 +545,12 @@ trait StagedWasmEvaluator extends SAIOps { // runtime values object Values { - def I32V(i: Rep[Int]): StagedNum = { - I32("I32V".reflectCtrlWith[Num](i), "Concrete".reflectCtrlWith[SymVal]("I32V".reflectCtrlWith[Num](i))) + def I32V(i: Rep[Int]): Rep[Num] = { + "I32V".reflectCtrlWith[Num](i) } - def I64V(i: Rep[Long]): StagedNum = { - I64("I64V".reflectCtrlWith[Num](i), "Concrete".reflectCtrlWith[SymVal]("I64V".reflectCtrlWith[Num](i))) + def I64V(i: Rep[Long]): Rep[Num] = { + "I64V".reflectCtrlWith[Num](i) } } @@ -574,6 +580,10 @@ trait StagedWasmEvaluator extends SAIOps { def toInt: Rep[Int] = "num-to-int".reflectCtrlWith[Int](num.i) + def isZero(): StagedNum = num match { + case I32(x_c, x_s) => I32(Values.I32V("is-zero".reflectCtrlWith[Int](num.toInt)), "sym-is-zero".reflectCtrlWith[SymVal](x_s)) + } + def clz(): StagedNum = num match { case I32(x_c, x_s) => I32("clz".reflectCtrlWith[Num](x_c), "sym-clz".reflectCtrlWith[SymVal](x_s)) case I64(x_c, x_s) => I64("clz".reflectCtrlWith[Num](x_c), "sym-clz".reflectCtrlWith[SymVal](x_s)) @@ -750,13 +760,16 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { else if (m.toString.endsWith("Global")) "Global" else if (m.toString.endsWith("I32V")) "I32V" else if (m.toString.endsWith("I64V")) "I64V" + else if (m.toString.endsWith("SymVal")) "SymVal" + else super.remap(m) } - // for now, the traverse/shallow is same as the scala backend's override def traverse(n: Node): Unit = n match { case Node(_, "stack-push", List(value), _) => emit("Stack.push("); shallow(value); emit(");\n") + case Node(_, "sym-stack-push", List(s_value), _) => + emit("SymStack.push("); shallow(s_value); emit(");\n") case Node(_, "stack-drop", List(n), _) => emit("Stack.drop("); shallow(n); emit(");\n") case Node(_, "stack-init", _, _) => @@ -765,12 +778,14 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Stack.print();\n") case Node(_, "frame-push", List(i), _) => emit("Frames.pushFrame("); shallow(i); emit(");\n") + case Node(_, "sym-frame-push", List(i), _) => + emit("SymFrames.pushFrame("); shallow(i); emit(");\n") case Node(_, "frame-pop", List(i), _) => emit("Frames.popFrame("); shallow(i); emit(");\n") - case Node(_, "frame-putAll", List(args), _) => - emit("Frames.putAll("); shallow(args); emit(");\n") case Node(_, "frame-set", List(i, value), _) => emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(");\n") + case Node(_, "sym-frame-set", List(i, s_value), _) => + emit("SymFrames.set("); shallow(i); emit(", "); shallow(s_value); emit(");\n") case Node(_, "global-set", List(i, value), _) => emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(");\n") // Note: The following code is copied from the traverse of CppBackend.scala, try to avoid duplicated code @@ -787,10 +802,11 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { case _ => super.traverse(n) } - // code generation for pure nodes override def shallow(n: Node): Unit = n match { case Node(_, "frame-get", List(i), _) => emit("Frames.get("); shallow(i); emit(")") + case Node(_, "sym-frame-get", List(i), _) => + emit("SymFrames.get("); shallow(i); emit(")") case Node(_, "stack-drop", List(n), _) => emit("Stack.drop("); shallow(n); emit(")") case Node(_, "stack-push", List(value), _) => @@ -799,10 +815,16 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Stack.shift("); shallow(offset); emit(", "); shallow(size); emit(")") case Node(_, "stack-pop", _, _) => emit("Stack.pop()") + case Node(_, "sym-stack-pop", _, _) => + emit("SymStack.pop()") case Node(_, "frame-pop", List(i), _) => emit("Frames.popFrame("); shallow(i); emit(")") + case Node(_, "sym-frame-pop", List(i), _) => + emit("SymFrames.popFrame("); shallow(i); emit(")") case Node(_, "stack-peek", _, _) => emit("Stack.peek()") + case Node(_, "sym-stack-peek", _, _) => + emit("SymStack.peek()") case Node(_, "stack-take", List(n), _) => emit("Stack.take("); shallow(n); emit(")") case Node(_, "slice-reverse", List(slice), _) => @@ -817,8 +839,14 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Stack.size()") case Node(_, "global-get", List(i), _) => emit("Global.globalGet("); shallow(i); emit(")") + case Node(_, "is-zero", List(num), _) => + emit("(0 == "); shallow(num); emit(")") + case Node(_, "sym-is-zero", List(s_num), _) => + shallow(s_num); emit(".is_zero()") case Node(_, "binary-add", List(lhs, rhs), _) => shallow(lhs); emit(" + "); shallow(rhs) + case Node(_, "sym-binary-add", List(lhs, rhs), _) => + shallow(lhs); emit(" + "); shallow(rhs) case Node(_, "binary-sub", List(lhs, rhs), _) => shallow(lhs); emit(" - "); shallow(rhs) case Node(_, "binary-mul", List(lhs, rhs), _) => diff --git a/src/test/scala/genwasym/TestStagedConcolicEval.scala b/src/test/scala/genwasym/TestStagedConcolicEval.scala new file mode 100644 index 00000000..eef6ab01 --- /dev/null +++ b/src/test/scala/genwasym/TestStagedConcolicEval.scala @@ -0,0 +1,33 @@ +package gensym.wasm + +import org.scalatest.FunSuite + +import lms.core.stub.Adapter + +import gensym.wasm.miniwasm.{ModuleInstance} +import gensym.wasm.parser._ +import gensym.wasm.stagedconcolicminiwasm._ + +class TestStagedConcolicEval extends FunSuite { + def testFileToCpp(filename: String, main: Option[String] = None, expect: Option[List[Float]]=None) = { + val moduleInst = ModuleInstance(Parser.parseFile(filename)) + val cppFile = s"$filename.cpp" + val exe = s"$cppFile.exe" + WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, true) + + import sys.process._ + val result = s"./$exe".!! + println(result) + + expect.map(vs => { + val stackValues = result + .split("Stack contents: \n")(1) + .split("\n") + .map(_.toFloat) + .toList + assert(vs == stackValues) + }) + } + + test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main"), expect=Some(List(7))) } +} From 8d81fbe2e61a1f7e792be3fb5463d7ead9954ea4 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 9 Jul 2025 19:24:24 +0800 Subject: [PATCH 05/21] record path conditions --- headers/wasm/symbolic_rt.hpp | 18 ++++++++++++++++ .../scala/wasm/StagedConcolicMiniWasm.scala | 21 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index d03509ac..05fc41d4 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -17,6 +17,12 @@ class SymVal { // Not implemented yet return SymVal(); } + + SymVal negate() const { + // negate the symbolic condition by creating a new symbolic value + // not implemented yet + return SymVal(); + } }; class SymStack_t { @@ -69,4 +75,16 @@ static SymVal Concrete(Num num) { return SymVal(); } +class ExploreTree_t { +public: + std::monostate fillIfElseNode(SymVal s, bool branch) { + // fill the current node with the branch condition s + // parameter branch is redundant, to hint which branch we've entered + // Not implemented yet + return std::monostate(); + } +}; + +static ExploreTree_t ExploreTree; + #endif // WASM_SYMBOLIC_RT_HPP \ No newline at end of file diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index c41c7c42..1eb717e3 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -225,8 +225,10 @@ trait StagedWasmEvaluator extends SAIOps { }) // TODO: put the cond.s to path condition if (cond.toInt != 0) { + ExploreTree.fillWithIfElse(cond.s, true) eval(thn, restK _, mkont, restK _ :: trail)(newCtx) } else { + ExploreTree.fillWithIfElse(cond.s.not, false) eval(els, restK _, mkont, restK _ :: trail)(newCtx) } () @@ -239,9 +241,11 @@ trait StagedWasmEvaluator extends SAIOps { // TODO: put the cond.s to path condition if (cond.toInt != 0) { info(s"Jump to $label") + ExploreTree.fillWithIfElse(cond.s, true) trail(label)(newCtx)(mkont) } else { info(s"Continue") + ExploreTree.fillWithIfElse(cond.s.not, false) eval(rest, kont, mkont, trail)(newCtx) } () @@ -575,6 +579,13 @@ trait StagedWasmEvaluator extends SAIOps { } } + // Exploration tree, + object ExploreTree { + def fillWithIfElse(s: Rep[SymVal], branch: Boolean): Rep[Unit] = { + "tree-fill-if-else".reflectCtrlWith[Unit](s, branch) + } + } + // runtime Num type implicit class StagedNumOps(num: StagedNum) { @@ -733,6 +744,12 @@ trait StagedWasmEvaluator extends SAIOps { } } } + + implicit class SymbolicOps(s: Rep[SymVal]) { + def not(): Rep[SymVal] = { + "sym-not".reflectCtrlWith(s) + } + } } trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { @@ -881,6 +898,10 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { shallow(lhs); emit(" >= "); shallow(rhs) case Node(_, "num-to-int", List(num), _) => shallow(num); emit(".toInt()") + case Node(_, "tree-fill-if-else", List(s, b), _) => + emit("ExploreTree.fillIfElseNode("); shallow(s); emit(", "); shallow(b); emit(")") + case Node(_, "sym-not", List(s), _) => + shallow(s); emit(".negate()") case Node(_, "dummy", _, _) => emit("std::monostate()") case Node(_, "dummy-op", _, _) => emit("std::monostate()") case Node(_, "no-op", _, _) => From 61215b6c87a2ba1ffa3f88838c17dc3d7e3cfb86 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 9 Jul 2025 23:06:44 +0800 Subject: [PATCH 06/21] The branch node only needs to remember the positive condition. use the sub-nodes of the branch to classify whether the execution is true or false --- headers/wasm/symbolic_rt.hpp | 6 ++++- .../scala/wasm/StagedConcolicMiniWasm.scala | 24 ++++++++++++------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index 05fc41d4..a0199db4 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -77,12 +77,16 @@ static SymVal Concrete(Num num) { class ExploreTree_t { public: - std::monostate fillIfElseNode(SymVal s, bool branch) { + std::monostate fillIfElseNode(SymVal s) { // fill the current node with the branch condition s // parameter branch is redundant, to hint which branch we've entered // Not implemented yet return std::monostate(); } + + std::monostate moveCursor(bool branch) { + return std::monostate(); + } }; static ExploreTree_t ExploreTree; diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 1eb717e3..d31af3cb 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -224,11 +224,12 @@ trait StagedWasmEvaluator extends SAIOps { eval(rest, kont, mk, trail)(newRestCtx) }) // TODO: put the cond.s to path condition + ExploreTree.fillWithIfElse(cond.s) if (cond.toInt != 0) { - ExploreTree.fillWithIfElse(cond.s, true) + ExploreTree.moveCursor(true) eval(thn, restK _, mkont, restK _ :: trail)(newCtx) } else { - ExploreTree.fillWithIfElse(cond.s.not, false) + ExploreTree.moveCursor(false) eval(els, restK _, mkont, restK _ :: trail)(newCtx) } () @@ -239,13 +240,14 @@ trait StagedWasmEvaluator extends SAIOps { val (cond, newCtx) = Stack.pop() info(s"The br_if(${label})'s condition is ", cond.toInt) // TODO: put the cond.s to path condition + ExploreTree.fillWithIfElse(cond.s) if (cond.toInt != 0) { info(s"Jump to $label") - ExploreTree.fillWithIfElse(cond.s, true) + ExploreTree.moveCursor(true) trail(label)(newCtx)(mkont) } else { info(s"Continue") - ExploreTree.fillWithIfElse(cond.s.not, false) + ExploreTree.moveCursor(false) eval(rest, kont, mkont, trail)(newCtx) } () @@ -581,8 +583,12 @@ trait StagedWasmEvaluator extends SAIOps { // Exploration tree, object ExploreTree { - def fillWithIfElse(s: Rep[SymVal], branch: Boolean): Rep[Unit] = { - "tree-fill-if-else".reflectCtrlWith[Unit](s, branch) + def fillWithIfElse(s: Rep[SymVal]): Rep[Unit] = { + "tree-fill-if-else".reflectCtrlWith[Unit](s) + } + + def moveCursor(branch: Boolean): Rep[Unit] = { + "tree-move-cursor".reflectCtrlWith[Unit](branch) } } @@ -898,8 +904,10 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { shallow(lhs); emit(" >= "); shallow(rhs) case Node(_, "num-to-int", List(num), _) => shallow(num); emit(".toInt()") - case Node(_, "tree-fill-if-else", List(s, b), _) => - emit("ExploreTree.fillIfElseNode("); shallow(s); emit(", "); shallow(b); emit(")") + case Node(_, "tree-fill-if-else", List(s), _) => + emit("ExploreTree.fillIfElseNode("); shallow(s); emit(")") + case Node(_, "tree-move-cursor", List(b), _) => + emit("ExploreTree.moveCursor("); shallow(b); emit(")") case Node(_, "sym-not", List(s), _) => shallow(s); emit(".negate()") case Node(_, "dummy", _, _) => emit("std::monostate()") From d18b5f7ca62dc6a96baca04eb38f638d06e41b0b Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 13 Jul 2025 14:57:36 +0800 Subject: [PATCH 07/21] symbolic runtime for explore tree --- headers/wasm/symbolic_rt.hpp | 86 ++++++++++++++++++- .../scala/wasm/StagedConcolicMiniWasm.scala | 8 ++ 2 files changed, 90 insertions(+), 4 deletions(-) diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index a0199db4..4373af84 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -2,6 +2,9 @@ #define WASM_SYMBOLIC_RT_HPP #include "concrete_rt.hpp" +#include +#include +#include #include class SymVal { @@ -75,18 +78,93 @@ static SymVal Concrete(Num num) { return SymVal(); } +struct Node; + +struct NodeBox { + explicit NodeBox(); + std::unique_ptr node; + NodeBox *parent; +}; + +struct Node { + virtual ~Node(){}; + virtual std::string to_string() = 0; +}; + +struct IfElseNode : Node { + SymVal cond; + std::unique_ptr true_branch; + std::unique_ptr false_branch; + + IfElseNode(SymVal cond) + : cond(cond), true_branch(std::make_unique()), + false_branch(std::make_unique()) {} + + std::string to_string() override { + std::string result = "IfElseNode {\n"; + result += " true_branch: "; + if (true_branch) { + result += true_branch->node->to_string(); + } else { + result += "nullptr"; + } + result += "\n"; + + result += " false_branch: "; + if (false_branch) { + result += false_branch->node->to_string(); + } else { + result += "nullptr"; + } + result += "\n"; + result += "}"; + return result; + } +}; + +struct UnExploredNode : Node { + UnExploredNode() {} + std::string to_string() override { return "UnexploredNode"; } +}; + +static UnExploredNode unexplored; + +inline NodeBox::NodeBox() + : node(std::make_unique< + UnExploredNode>() /* TODO: avoid allocation of unexplored node */) {} + class ExploreTree_t { public: - std::monostate fillIfElseNode(SymVal s) { - // fill the current node with the branch condition s - // parameter branch is redundant, to hint which branch we've entered - // Not implemented yet + explicit ExploreTree_t() + : root(std::make_unique()), cursor(root.get()) {} + std::monostate fillIfElseNode(SymVal cond) { + // fill the current node with an ifelse branch node + cursor->node = std::make_unique(cond); return std::monostate(); } std::monostate moveCursor(bool branch) { + assert(cursor != nullptr); + auto if_else_node = dynamic_cast(cursor->node.get()); + assert( + if_else_node != nullptr && + "Can't move cursor when the branch node is not initialized correctly!"); + if (branch) { + cursor = if_else_node->true_branch.get(); + } else { + cursor = if_else_node->false_branch.get(); + } return std::monostate(); } + + std::monostate print() { + std::cout << root->node->to_string() << std::endl; + return std::monostate(); + } + +private: + std::unique_ptr root; + NodeBox *cursor; }; static ExploreTree_t ExploreTree; diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index d31af3cb..4366ac78 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -16,6 +16,7 @@ import gensym.wasm.miniwasm.{ModuleInstance} import gensym.wasm.symbolic.{SymVal} import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase, CppSAIDriver, CppSAICodeGenBase} import gensym.wasm.symbolic.Concrete +import gensym.wasm.symbolic.ExploreTree @virtualize trait StagedWasmEvaluator extends SAIOps { @@ -405,6 +406,7 @@ trait StagedWasmEvaluator extends SAIOps { info("Exiting the program...") if (printRes) { Stack.print() + ExploreTree.print() } "no-op".reflectCtrlWith[Unit]() } @@ -590,6 +592,10 @@ trait StagedWasmEvaluator extends SAIOps { def moveCursor(branch: Boolean): Rep[Unit] = { "tree-move-cursor".reflectCtrlWith[Unit](branch) } + + def print(): Rep[Unit] = { + "tree-print".reflectCtrlWith[Unit]() + } } // runtime Num type @@ -908,6 +914,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("ExploreTree.fillIfElseNode("); shallow(s); emit(")") case Node(_, "tree-move-cursor", List(b), _) => emit("ExploreTree.moveCursor("); shallow(b); emit(")") + case Node(_, "tree-print", List(), _) => + emit("ExploreTree.print()") case Node(_, "sym-not", List(s), _) => shallow(s); emit(".negate()") case Node(_, "dummy", _, _) => emit("std::monostate()") From 92ab8ba0d03ec8ef8f3dd9a3b43a32d0c1bfc158 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 14 Jul 2025 01:30:20 +0800 Subject: [PATCH 08/21] add a to graphviz method, enhancing debug experience --- headers/wasm/symbolic_rt.hpp | 73 ++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index 4373af84..02c3021c 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -3,8 +3,10 @@ #include "concrete_rt.hpp" #include +#include #include #include +#include #include class SymVal { @@ -89,8 +91,29 @@ struct NodeBox { struct Node { virtual ~Node(){}; virtual std::string to_string() = 0; + void to_graphviz(std::ostream &os) { + os << "digraph G {\n"; + os << " rankdir=TB;\n"; + os << " node [shape=box, style=filled, fillcolor=lightblue];\n"; + current_id = 0; + generate_dot(os, -1, ""); + + os << "}\n"; + } + int get_next_id(int &id_counter) { return id_counter++; } + virtual int generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) = 0; + +protected: + // Counter for unique node IDs across the entire graph, only for generating + // graphviz purpose + static int current_id; }; +// TODO: use this header file in multiple compilation units will cause problems +// during linking +int Node::current_id = 0; + struct IfElseNode : Node { SymVal cond; std::unique_ptr true_branch; @@ -120,11 +143,56 @@ struct IfElseNode : Node { result += "}"; return result; } + + int generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id; + current_id += 1; + + os << " node" << current_node_dot_id << " [label=\"If\"," + << "shape=diamond, fillcolor=lightyellow];\n"; + + // Draw edge from parent if this is not the root node + if (parent_dot_id != -1) { + os << " node" << parent_dot_id << " -> node" << current_node_dot_id; + if (!edge_label.empty()) { + os << " [label=\"" << edge_label << "\"]"; + } + os << ";\n"; + } + assert(true_branch != nullptr); + assert(true_branch->node != nullptr); + true_branch->node->generate_dot(os, current_node_dot_id, "true"); + assert(false_branch != nullptr); + assert(false_branch->node != nullptr); + false_branch->node->generate_dot(os, current_node_dot_id, "false"); + return current_node_dot_id; + } }; struct UnExploredNode : Node { UnExploredNode() {} std::string to_string() override { return "UnexploredNode"; } + +protected: + int generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + + os << " node" << current_node_dot_id + << " [label=\"Unexplored\", shape=octagon, style=filled, " + "fillcolor=lightgrey];\n"; + + if (parent_dot_id != -1) { + os << " node" << parent_dot_id << " -> node" << current_node_dot_id; + if (!edge_label.empty()) { + os << " [label=\"" << edge_label << "\"]"; + } + os << ";\n"; + } + + return current_node_dot_id; + } }; static UnExploredNode unexplored; @@ -162,6 +230,11 @@ class ExploreTree_t { return std::monostate(); } + std::monostate to_graphviz(std::ostream &os) { + root->node->to_graphviz(os); + return std::monostate(); + } + private: std::unique_ptr root; NodeBox *cursor; From e1d7fc8fe801a877c8ef7e09d904e8714f6bbd17 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 14 Jul 2025 23:49:20 +0800 Subject: [PATCH 09/21] put symbolic expression on the SymStack --- headers/wasm/symbolic_rt.hpp | 126 +++++++++++++++++++++++++---------- 1 file changed, 92 insertions(+), 34 deletions(-) diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index 02c3021c..fac3ce6d 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -8,42 +8,104 @@ #include #include #include +#include -class SymVal { +class Symbolic {}; + +class SymConcrete : public Symbolic { public: - SymVal operator+(const SymVal &other) const { - // Define how to add two symbolic values - // Not implemented yet - return SymVal(); - } + Num value; + SymConcrete(Num num) : value(num) {} +}; - SymVal is_zero() const { - // Check if the symbolic value is zero - // Not implemented yet - return SymVal(); - } +struct SymBinary; - SymVal negate() const { - // negate the symbolic condition by creating a new symbolic value - // not implemented yet - return SymVal(); - } +struct SymVal { + std::shared_ptr symptr; + + SymVal() : symptr(nullptr) {} + SymVal(std::shared_ptr symptr) : symptr(symptr) {} + + SymVal add(const SymVal &other) const; + SymVal minus(const SymVal &other) const; + SymVal mul(const SymVal &other) const; + SymVal div(const SymVal &other) const; + SymVal eq(const SymVal &other) const; + SymVal neq(const SymVal &other) const; + SymVal lt(const SymVal &other) const; + SymVal leq(const SymVal &other) const; + SymVal gt(const SymVal &other) const; + SymVal geq(const SymVal &other) const; +}; + +inline SymVal Concrete(Num num) { + return SymVal(std::make_shared(num)); +} + +enum Operation { ADD, SUB, MUL, DIV, EQ, NEQ, LT, LEQ, GT, GEQ }; + +struct SymBinary : Symbolic { + Operation op; + SymVal lhs; + SymVal rhs; + + SymBinary(Operation op, SymVal lhs, SymVal rhs) + : op(op), lhs(lhs), rhs(rhs) {} }; +inline SymVal SymVal::add(const SymVal &other) const { + return SymVal(std::make_shared(ADD, this, other)); +} + +inline SymVal SymVal::minus(const SymVal &other) const { + return SymVal(std::make_shared(SUB, this, other)); +} + +inline SymVal SymVal::mul(const SymVal &other) const { + return SymVal(std::make_shared(MUL, this, other)); +} + +inline SymVal SymVal::div(const SymVal &other) const { + return SymVal(std::make_shared(DIV, this, other)); +} + +inline SymVal SymVal::eq(const SymVal &other) const { + return SymVal(std::make_shared(EQ, this, other)); +} + +inline SymVal SymVal::neq(const SymVal &other) const { + return SymVal(std::make_shared(NEQ, this, other)); +} +inline SymVal SymVal::lt(const SymVal &other) const { + return SymVal(std::make_shared(LT, this, other)); +} +inline SymVal SymVal::leq(const SymVal &other) const { + return SymVal(std::make_shared(LEQ, this, other)); +} +inline SymVal SymVal::gt(const SymVal &other) const { + return SymVal(std::make_shared(GT, this, other)); +} +inline SymVal SymVal::geq(const SymVal &other) const { + return SymVal(std::make_shared(GEQ, this, other)); +} + class SymStack_t { public: void push(SymVal val) { // Push a symbolic value to the stack - // Not implemented yet + stack.push_back(val); } SymVal pop() { // Pop a symbolic value from the stack - // Not implemented yet - return SymVal(); + auto ret = stack.back(); + stack.pop_back(); + return ret; } - SymVal peek() { return SymVal(); } + SymVal peek() { return stack.back(); } + + std::vector stack; }; static SymStack_t SymStack; @@ -52,34 +114,30 @@ class SymFrames_t { public: void pushFrame(int size) { // Push a new frame with the given size - // Not implemented yet + stack.resize(size + stack.size()); } std::monostate popFrame(int size) { // Pop the frame of the given size - // Not implemented yet + stack.resize(stack.size() - size); return std::monostate(); } SymVal get(int index) { - // Get the symbolic value at the given index - // Not implemented yet - return SymVal(); + // Get the symbolic value at the given frame index + return stack[stack.size() - 1 - index]; } void set(int index, SymVal val) { // Set the symbolic value at the given index // Not implemented yet + stack[stack.size() - 1 - index] = val; } + + std::vector stack; }; static SymFrames_t SymFrames; -static SymVal Concrete(Num num) { - // Convert a concrete number to a symbolic value - // Not implemented yet - return SymVal(); -} - struct Node; struct NodeBox { @@ -115,11 +173,11 @@ struct Node { int Node::current_id = 0; struct IfElseNode : Node { - SymVal cond; + Symbolic cond; std::unique_ptr true_branch; std::unique_ptr false_branch; - IfElseNode(SymVal cond) + IfElseNode(Symbolic cond) : cond(cond), true_branch(std::make_unique()), false_branch(std::make_unique()) {} @@ -205,7 +263,7 @@ class ExploreTree_t { public: explicit ExploreTree_t() : root(std::make_unique()), cursor(root.get()) {} - std::monostate fillIfElseNode(SymVal cond) { + std::monostate fillIfElseNode(Symbolic cond) { // fill the current node with an ifelse branch node cursor->node = std::make_unique(cond); return std::monostate(); From 77a4e6f58547a71a5544779f06f41e92b5398375 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 16 Jul 2025 22:50:08 +0800 Subject: [PATCH 10/21] `type.symbolic` instruction --- headers/wasm/symbolic_rt.hpp | 70 +++++++++++++++---- .../scala/wasm/StagedConcolicMiniWasm.scala | 62 +++++++++++++++- .../genwasym/TestStagedConcolicEval.scala | 4 ++ 3 files changed, 120 insertions(+), 16 deletions(-) diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index fac3ce6d..e0d3feef 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -10,7 +10,19 @@ #include #include -class Symbolic {}; +class Symbolic { +public: + Symbolic() {} // TODO: remove this default constructor later + virtual ~Symbolic() = default; // Make Symbolic polymorphic +}; + +class Symbol : public Symbolic { +public: + Symbol(int id) : id(id) {} + +private: + int id; +}; class SymConcrete : public Symbolic { public: @@ -26,6 +38,11 @@ struct SymVal { SymVal() : symptr(nullptr) {} SymVal(std::shared_ptr symptr) : symptr(symptr) {} + // data structure operations + SymVal makeSymbolic() const; + + // arithmetic operations + SymVal is_zero() const; SymVal add(const SymVal &other) const; SymVal minus(const SymVal &other) const; SymVal mul(const SymVal &other) const; @@ -54,39 +71,53 @@ struct SymBinary : Symbolic { }; inline SymVal SymVal::add(const SymVal &other) const { - return SymVal(std::make_shared(ADD, this, other)); + return SymVal(std::make_shared(ADD, *this, other)); } inline SymVal SymVal::minus(const SymVal &other) const { - return SymVal(std::make_shared(SUB, this, other)); + return SymVal(std::make_shared(SUB, *this, other)); } inline SymVal SymVal::mul(const SymVal &other) const { - return SymVal(std::make_shared(MUL, this, other)); + return SymVal(std::make_shared(MUL, *this, other)); } inline SymVal SymVal::div(const SymVal &other) const { - return SymVal(std::make_shared(DIV, this, other)); + return SymVal(std::make_shared(DIV, *this, other)); } inline SymVal SymVal::eq(const SymVal &other) const { - return SymVal(std::make_shared(EQ, this, other)); + return SymVal(std::make_shared(EQ, *this, other)); } inline SymVal SymVal::neq(const SymVal &other) const { - return SymVal(std::make_shared(NEQ, this, other)); + return SymVal(std::make_shared(NEQ, *this, other)); } inline SymVal SymVal::lt(const SymVal &other) const { - return SymVal(std::make_shared(LT, this, other)); + return SymVal(std::make_shared(LT, *this, other)); } inline SymVal SymVal::leq(const SymVal &other) const { - return SymVal(std::make_shared(LEQ, this, other)); + return SymVal(std::make_shared(LEQ, *this, other)); } inline SymVal SymVal::gt(const SymVal &other) const { - return SymVal(std::make_shared(GT, this, other)); + return SymVal(std::make_shared(GT, *this, other)); } inline SymVal SymVal::geq(const SymVal &other) const { - return SymVal(std::make_shared(GEQ, this, other)); + return SymVal(std::make_shared(GEQ, *this, other)); +} +inline SymVal SymVal::is_zero() const { + return SymVal(std::make_shared(EQ, *this, Concrete(I32V(0)))); +} + +inline SymVal SymVal::makeSymbolic() const { + auto concrete = dynamic_cast(symptr.get()); + if (concrete) { + // If the symbolic value is a concrete value, use it to create a symbol + return SymVal(std::make_shared(concrete->value.toInt())); + } else { + throw std::runtime_error( + "Cannot make symbolic a non-concrete symbolic value"); + } } class SymStack_t { @@ -173,11 +204,11 @@ struct Node { int Node::current_id = 0; struct IfElseNode : Node { - Symbolic cond; + SymVal cond; std::unique_ptr true_branch; std::unique_ptr false_branch; - IfElseNode(Symbolic cond) + IfElseNode(SymVal cond) : cond(cond), true_branch(std::make_unique()), false_branch(std::make_unique()) {} @@ -263,7 +294,7 @@ class ExploreTree_t { public: explicit ExploreTree_t() : root(std::make_unique()), cursor(root.get()) {} - std::monostate fillIfElseNode(Symbolic cond) { + std::monostate fillIfElseNode(SymVal cond) { // fill the current node with an ifelse branch node cursor->node = std::make_unique(cond); return std::monostate(); @@ -300,4 +331,15 @@ class ExploreTree_t { static ExploreTree_t ExploreTree; +class SymEnv_t { +public: + Num read(SymVal sym) { + // Read a symbolic value from the symbolic environment + // For now, we just return a zero + return Num(0); + } +}; + +static SymEnv_t SymEnv; + #endif // WASM_SYMBOLIC_RT_HPP \ No newline at end of file diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 4366ac78..01bb91c2 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -57,6 +57,15 @@ trait StagedWasmEvaluator extends SAIOps { case NumType(F32Type) => 4 case NumType(F64Type) => 8 } + + def toTagger: (Rep[Num], Rep[SymVal]) => StagedNum = { + ty match { + case NumType(I32Type) => I32 + case NumType(I64Type) => I64 + case NumType(F32Type) => F32 + case NumType(F64Type) => F64 + } + } } case class Context( @@ -121,6 +130,14 @@ trait StagedWasmEvaluator extends SAIOps { case WasmConst(num) => val newCtx = Stack.push(toStagedNum(num)) eval(rest, kont, mkont, trail)(newCtx) + case Symbolic(ty) => + val (id, newCtx1) = Stack.pop() + val symVal = id.makeSymbolic() + val concVal = SymEnv.read(symVal) + val tagger = ty.toTagger + val value = tagger(concVal, symVal) + val newCtx2 = Stack.push(value)(newCtx1) + eval(rest, kont, mkont, trail)(newCtx2) case LocalGet(i) => val newCtx = Stack.push(Frames.get(i)) eval(rest, kont, mkont, trail)(newCtx) @@ -326,6 +343,10 @@ trait StagedWasmEvaluator extends SAIOps { val (v, newCtx) = Stack.pop() println(v.toInt) eval(rest, kont, mkont, trail)(newCtx) + case Import("console", "assert", _) => + val (v, newCtx) = Stack.pop() + runtimeAssert(v.toInt != 0) + eval(rest, kont, mkont, trail)(newCtx) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") case _ => throw new Exception(s"Definition at $funcIndex is not callable") } @@ -414,6 +435,10 @@ trait StagedWasmEvaluator extends SAIOps { evalTop(temp, main) } + def runtimeAssert(b: Rep[Boolean]): Rep[Unit] = { + "assert-true".reflectCtrlWith[Unit](b) + } + // stack operations object Stack { def shift(offset: Int, size: Int)(ctx: Context): Context = { @@ -598,6 +623,12 @@ trait StagedWasmEvaluator extends SAIOps { } } + object SymEnv { + def read(sym: Rep[SymVal]): Rep[Num] = { + "sym-env-read".reflectCtrlWith[Num](sym) + } + } + // runtime Num type implicit class StagedNumOps(num: StagedNum) { @@ -622,6 +653,10 @@ trait StagedWasmEvaluator extends SAIOps { case I64(x_c, x_s) => I64("popcnt".reflectCtrlWith[Num](x_c), "sym-popcnt".reflectCtrlWith[SymVal](x_s)) } + def makeSymbolic(): Rep[SymVal] = { + "make-symbolic".reflectCtrlWith[SymVal](num.s) + } + def +(rhs: StagedNum): StagedNum = { (num, rhs) match { case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) @@ -778,6 +813,7 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { override def mayInline(n: Node): Boolean = n match { case Node(_, "stack-pop", _, _) | Node(_, "stack-peek", _, _) + | Node(_, "sym-stack-pop", _, _) => false case _ => super.mayInline(n) } @@ -874,8 +910,6 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { shallow(s_num); emit(".is_zero()") case Node(_, "binary-add", List(lhs, rhs), _) => shallow(lhs); emit(" + "); shallow(rhs) - case Node(_, "sym-binary-add", List(lhs, rhs), _) => - shallow(lhs); emit(" + "); shallow(rhs) case Node(_, "binary-sub", List(lhs, rhs), _) => shallow(lhs); emit(" - "); shallow(rhs) case Node(_, "binary-mul", List(lhs, rhs), _) => @@ -908,8 +942,32 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { shallow(lhs); emit(" >= "); shallow(rhs) case Node(_, "relation-geu", List(lhs, rhs), _) => shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "sym-binary-add", List(lhs, rhs), _) => + shallow(lhs); emit(".add("); shallow(rhs); emit(")") + case Node(_, "sym-binary-mul", List(lhs, rhs), _) => + shallow(lhs); emit(".mul("); shallow(rhs); emit(")") + case Node(_, "sym-binary-div", List(lhs, rhs), _) => + shallow(lhs); emit(".div("); shallow(rhs); emit(")") + case Node(_, "sym-relation-le", List(lhs, rhs), _) => + shallow(lhs); emit(".leq("); shallow(rhs); emit(")") + case Node(_, "sym-relation-leu", List(lhs, rhs), _) => + shallow(lhs); emit(".leu("); shallow(rhs); emit(")") + case Node(_, "sym-relation-ge", List(lhs, rhs), _) => + shallow(lhs); emit(".ge("); shallow(rhs); emit(")") + case Node(_, "sym-relation-geu", List(lhs, rhs), _) => + shallow(lhs); emit(".geu("); shallow(rhs); emit(")") + case Node(_, "sym-relation-eq", List(lhs, rhs), _) => + shallow(lhs); emit(".eq("); shallow(rhs); emit(")") + case Node(_, "sym-relation-ne", List(lhs, rhs), _) => + shallow(lhs); emit(".neq("); shallow(rhs); emit(")") case Node(_, "num-to-int", List(num), _) => shallow(num); emit(".toInt()") + case Node(_, "make-symbolic", List(num), _) => + shallow(num); emit(".makeSymbolic()") + case Node(_, "sym-env-read", List(sym), _) => + emit("SymEnv.read("); shallow(sym); emit(")") + case Node(_, "assert-true", List(cond), _) => + emit("assert("); shallow(cond); emit(")") case Node(_, "tree-fill-if-else", List(s), _) => emit("ExploreTree.fillIfElseNode("); shallow(s); emit(")") case Node(_, "tree-move-cursor", List(b), _) => diff --git a/src/test/scala/genwasym/TestStagedConcolicEval.scala b/src/test/scala/genwasym/TestStagedConcolicEval.scala index eef6ab01..409bd07a 100644 --- a/src/test/scala/genwasym/TestStagedConcolicEval.scala +++ b/src/test/scala/genwasym/TestStagedConcolicEval.scala @@ -30,4 +30,8 @@ class TestStagedConcolicEval extends FunSuite { } test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main"), expect=Some(List(7))) } + + test("bug-finding") { + testFileToCpp("./benchmarks/wasm/branch-strip-buggy.wat", Some("real_main")) + } } From 314ff5fe8b848cb1f40ee774283489c0760eec35 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 16 Jul 2025 23:11:33 +0800 Subject: [PATCH 11/21] test staged concolic compilation in CI --- .github/workflows/scala.yml | 1 + benchmarks/wasm/branch-strip-buggy.wat | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/scala.yml b/.github/workflows/scala.yml index 4677da77..d2f8348b 100644 --- a/.github/workflows/scala.yml +++ b/.github/workflows/scala.yml @@ -79,3 +79,4 @@ jobs: sbt 'testOnly gensym.wasm.TestConcolic' sbt 'testOnly gensym.wasm.TestDriver' sbt 'testOnly gensym.wasm.TestStagedEval' + sbt 'testOnly gensym.wasm.TestStagedConcolicEval' diff --git a/benchmarks/wasm/branch-strip-buggy.wat b/benchmarks/wasm/branch-strip-buggy.wat index c957db7f..0685f0be 100644 --- a/benchmarks/wasm/branch-strip-buggy.wat +++ b/benchmarks/wasm/branch-strip-buggy.wat @@ -29,6 +29,7 @@ else i32.const 0 call 2 + i32.const 1 ;; to satisfy the type checker, this line will never be reached end end ) From 873936902fe34538c698fbd664df370b9133003c Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Thu, 17 Jul 2025 00:18:51 +0800 Subject: [PATCH 12/21] dump graphviz by default --- headers/wasm/symbolic_rt.hpp | 10 +++++++++ .../scala/wasm/StagedConcolicMiniWasm.scala | 22 ++++++++++++++----- .../genwasym/TestStagedConcolicEval.scala | 3 ++- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index e0d3feef..a97fd0dd 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -4,6 +4,7 @@ #include "concrete_rt.hpp" #include #include +#include #include #include #include @@ -324,6 +325,15 @@ class ExploreTree_t { return std::monostate(); } + std::monostate dump_graphviz(std::string filepath) { + std::ofstream ofs(filepath); + if (!ofs.is_open()) { + throw std::runtime_error("Failed to open explore_tree.dot for writing"); + } + to_graphviz(ofs); + return std::monostate(); + } + private: std::unique_ptr root; NodeBox *cursor; diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 01bb91c2..3c342177 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -422,12 +422,15 @@ trait StagedWasmEvaluator extends SAIOps { Frames.popFrame(locals.size) } - def evalTop(main: Option[String], printRes: Boolean = false): Rep[Unit] = { + def evalTop(main: Option[String], printRes: Boolean, dumpTree: Option[String]): Rep[Unit] = { val haltK: Rep[Unit] => Rep[Unit] = (_) => { info("Exiting the program...") if (printRes) { Stack.print() - ExploreTree.print() + } + dumpTree match { + case Some(filePath) => ExploreTree.dumpGraphiviz(filePath) + case None => () } "no-op".reflectCtrlWith[Unit]() } @@ -621,6 +624,10 @@ trait StagedWasmEvaluator extends SAIOps { def print(): Rep[Unit] = { "tree-print".reflectCtrlWith[Unit]() } + + def dumpGraphiviz(filePath: String): Rep[Unit] = { + "tree-dump-graphviz".reflectCtrlWith[Unit](filePath) + } } object SymEnv { @@ -974,6 +981,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("ExploreTree.moveCursor("); shallow(b); emit(")") case Node(_, "tree-print", List(), _) => emit("ExploreTree.print()") + case Node(_, "tree-dump-graphviz", List(f), _) => + emit("ExploreTree.dump_graphviz("); shallow(f); emit(")") case Node(_, "sym-not", List(s), _) => shallow(s); emit(".negate()") case Node(_, "dummy", _, _) => emit("std::monostate()") @@ -1033,12 +1042,12 @@ trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEv object WasmToCppCompiler { case class GeneratedCpp(source: String, headerFolders: List[String]) - def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): GeneratedCpp = { + def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean, dumpTree: Option[String]): GeneratedCpp = { println(s"Now compiling wasm module with entry function $main") val driver = new WasmToCppCompilerDriver[Unit, Unit] { def module: ModuleInstance = moduleInst def snippet(x: Rep[Unit]): Rep[Unit] = { - evalTop(main, printRes) + evalTop(main, printRes, dumpTree) } } GeneratedCpp(driver.code, driver.codegen.includePaths.toList) @@ -1048,8 +1057,9 @@ object WasmToCppCompiler { main: Option[String], outputCpp: String, outputExe: String, - printRes: Boolean = false): Unit = { - val generated = compile(moduleInst, main, printRes) + printRes: Boolean, + dumpTree: Option[String]): Unit = { + val generated = compile(moduleInst, main, printRes, dumpTree) val code = generated.source val writer = new java.io.PrintWriter(new java.io.File(outputCpp)) diff --git a/src/test/scala/genwasym/TestStagedConcolicEval.scala b/src/test/scala/genwasym/TestStagedConcolicEval.scala index 409bd07a..77868e2c 100644 --- a/src/test/scala/genwasym/TestStagedConcolicEval.scala +++ b/src/test/scala/genwasym/TestStagedConcolicEval.scala @@ -13,7 +13,8 @@ class TestStagedConcolicEval extends FunSuite { val moduleInst = ModuleInstance(Parser.parseFile(filename)) val cppFile = s"$filename.cpp" val exe = s"$cppFile.exe" - WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, true) + val exploreTreeFile = s"$filename.tree.dot" + WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, true, Some(exploreTreeFile)) import sys.process._ val result = s"./$exe".!! From 9a9988c560f58961b2fb09892912623c26b9691e Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Thu, 17 Jul 2025 18:15:48 +0800 Subject: [PATCH 13/21] concolic driver --- headers/wasm.hpp | 3 +- headers/wasm/concolic_driver.hpp | 98 +++++++++ headers/wasm/concrete_rt.hpp | 9 +- headers/wasm/smt_solver.hpp | 28 +++ headers/wasm/symbolic_rt.hpp | 205 +++++++++++++++--- headers/wasm/utils.hpp | 15 ++ .../scala/wasm/StagedConcolicMiniWasm.scala | 20 +- 7 files changed, 338 insertions(+), 40 deletions(-) create mode 100644 headers/wasm/concolic_driver.hpp create mode 100644 headers/wasm/smt_solver.hpp create mode 100644 headers/wasm/utils.hpp diff --git a/headers/wasm.hpp b/headers/wasm.hpp index c7e98b6e..36fe3849 100644 --- a/headers/wasm.hpp +++ b/headers/wasm.hpp @@ -3,5 +3,6 @@ #include "wasm/concrete_rt.hpp" #include "wasm/symbolic_rt.hpp" - +#include "wasm/concolic_driver.hpp" +#include "wasm/utils.hpp" #endif \ No newline at end of file diff --git a/headers/wasm/concolic_driver.hpp b/headers/wasm/concolic_driver.hpp new file mode 100644 index 00000000..4307413b --- /dev/null +++ b/headers/wasm/concolic_driver.hpp @@ -0,0 +1,98 @@ +#ifndef CONCOLIC_DRIVER_HPP +#define CONCOLIC_DRIVER_HPP + +#include "smt_solver.hpp" +#include "symbolic_rt.hpp" +#include +#include +#include + +class ConcolicDriver { + friend class ManagedConcolicCleanup; + +public: + ConcolicDriver(std::function entrypoint, std::string tree_file) + : entrypoint(entrypoint), tree_file(tree_file) {} + ConcolicDriver(std::function entrypoint) + : entrypoint(entrypoint), tree_file(std::nullopt) {} + void run(); + +private: + Solver solver; + std::function entrypoint; + std::optional tree_file; +}; + +class ManagedConcolicCleanup { + const ConcolicDriver &driver; + +public: + ManagedConcolicCleanup(const ConcolicDriver &driver) : driver(driver) {} + ~ManagedConcolicCleanup() { + if (driver.tree_file.has_value()) + ExploreTree.dump_graphviz(driver.tree_file.value()); + } +}; + +inline void ConcolicDriver::run() { + ManagedConcolicCleanup cleanup{*this}; + while (true) { + auto cond = ExploreTree.get_unexplored_conditions(); + ExploreTree.reset_cursor(); + + if (!cond.has_value()) { + std::cout << "No unexplored conditions found, exiting..." << std::endl; + return; + } + auto new_env = solver.solve(cond.value()); + if (!new_env.has_value()) { + std::cout << "All unexplored paths are unreachable, exiting..." + << std::endl; + return; + } + SymEnv.update(std::move(new_env.value())); + try { + entrypoint(); + std::cout << "Execution finished successfully with symbolic environment:" + << std::endl; + std::cout << SymEnv.to_string() << std::endl; + } catch (...) { + ExploreTree.fillFailedNode(); + std::cout << "Caught runtime error with symbolic environment:" + << std::endl; + std::cout << SymEnv.to_string() << std::endl; + return; + } + } +} + +static std::monostate reset_stacks() { + Stack.reset(); + Frames.reset(); + SymStack.reset(); + SymFrames.reset(); + initRand(); + Memory = Memory_t(1); + return std::monostate{}; +} + +static void start_concolic_execution_with( + std::function entrypoint, + std::string tree_file) { + ConcolicDriver driver([=]() { entrypoint(std::monostate{}); }, tree_file); + driver.run(); +} + +static void start_concolic_execution_with( + std::function entrypoint) { + + const char *env_tree_file = std::getenv("TREE_FILE"); + + ConcolicDriver driver = + env_tree_file ? ConcolicDriver([=]() { entrypoint(std::monostate{}); }, + env_tree_file) + : ConcolicDriver([=]() { entrypoint(std::monostate{}); }); + driver.run(); +} + +#endif // CONCOLIC_DRIVER_HPP \ No newline at end of file diff --git a/headers/wasm/concrete_rt.hpp b/headers/wasm/concrete_rt.hpp index e994cbde..a0961453 100644 --- a/headers/wasm/concrete_rt.hpp +++ b/headers/wasm/concrete_rt.hpp @@ -52,8 +52,6 @@ static Num I32V(int v) { return v; } static Num I64V(int64_t v) { return v; } -using Slice = std::vector; - const int STACK_SIZE = 1024 * 64; class Stack_t { @@ -118,9 +116,12 @@ class Stack_t { } void initialize() { - // do nothing for now + // todo: remove this method + reset(); } + void reset() { count = 0; } + private: int32_t count; Num *stack_ptr; @@ -151,6 +152,8 @@ class Frames_t { count += size; } + void reset() { count = 0; } + private: int32_t count; Num *stack_ptr; diff --git a/headers/wasm/smt_solver.hpp b/headers/wasm/smt_solver.hpp new file mode 100644 index 00000000..a3bbf78d --- /dev/null +++ b/headers/wasm/smt_solver.hpp @@ -0,0 +1,28 @@ +#ifndef SMT_SOLVER_HPP +#define SMT_SOLVER_HPP + +#include "concrete_rt.hpp" +#include "symbolic_rt.hpp" +#include +#include + +class Solver { +public: + Solver() : count(0) { + envs[0] = {Num(0), Num(0)}; + envs[1] = {Num(1), Num(2)}; + } + std::optional> solve(const std::vector &conditions) { + // here is just a placeholder implementation to simulate solving result + if (count >= envs.size()) { + return std::nullopt; // No more environments to return + } + return envs[count++ % envs.size()]; + } + +private: + std::array, 5> envs; + int count; +}; + +#endif // SMT_SOLVER_HPP \ No newline at end of file diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index a97fd0dd..920bcad4 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -7,7 +7,9 @@ #include #include #include +#include #include +#include #include #include @@ -20,6 +22,7 @@ class Symbolic { class Symbol : public Symbolic { public: Symbol(int id) : id(id) {} + int get_id() const { return id; } private: int id; @@ -54,6 +57,7 @@ struct SymVal { SymVal leq(const SymVal &other) const; SymVal gt(const SymVal &other) const; SymVal geq(const SymVal &other) const; + SymVal negate() const; }; inline SymVal Concrete(Num num) { @@ -109,6 +113,9 @@ inline SymVal SymVal::geq(const SymVal &other) const { inline SymVal SymVal::is_zero() const { return SymVal(std::make_shared(EQ, *this, Concrete(I32V(0)))); } +inline SymVal SymVal::negate() const { + return SymVal(std::make_shared(EQ, *this, Concrete(I32V(0)))); +} inline SymVal SymVal::makeSymbolic() const { auto concrete = dynamic_cast(symptr.get()); @@ -137,6 +144,11 @@ class SymStack_t { SymVal peek() { return stack.back(); } + void reset() { + // Reset the symbolic stack + stack.clear(); + } + std::vector stack; }; @@ -165,6 +177,11 @@ class SymFrames_t { stack[stack.size() - 1 - index] = val; } + void reset() { + // Reset the symbolic frames + stack.clear(); + } + std::vector stack; }; @@ -190,14 +207,28 @@ struct Node { os << "}\n"; } - int get_next_id(int &id_counter) { return id_counter++; } - virtual int generate_dot(std::ostream &os, int parent_dot_id, - const std::string &edge_label) = 0; + virtual void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) = 0; protected: // Counter for unique node IDs across the entire graph, only for generating // graphviz purpose static int current_id; + void graphviz_node(std::ostream &os, const int node_id, + const std::string &label, const std::string &shape, + const std::string &fillcolor) { + os << " node" << node_id << " [label=\"" << label << "\", shape=" << shape + << ", style=filled, fillcolor=" << fillcolor << "];\n"; + } + + void graphviz_edge(std::ostream &os, int from_id, int target_id, + const std::string &edge_label) { + os << " node" << from_id << " -> node" << target_id; + if (!edge_label.empty()) { + os << " [label=\"" << edge_label << "\"]"; + } + os << ";\n"; + } }; // TODO: use this header file in multiple compilation units will cause problems @@ -234,21 +265,16 @@ struct IfElseNode : Node { return result; } - int generate_dot(std::ostream &os, int parent_dot_id, - const std::string &edge_label) override { + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { int current_node_dot_id = current_id; current_id += 1; - os << " node" << current_node_dot_id << " [label=\"If\"," - << "shape=diamond, fillcolor=lightyellow];\n"; + graphviz_node(os, current_node_dot_id, "If", "diamond", "lightyellow"); // Draw edge from parent if this is not the root node if (parent_dot_id != -1) { - os << " node" << parent_dot_id << " -> node" << current_node_dot_id; - if (!edge_label.empty()) { - os << " [label=\"" << edge_label << "\"]"; - } - os << ";\n"; + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); } assert(true_branch != nullptr); assert(true_branch->node != nullptr); @@ -256,7 +282,6 @@ struct IfElseNode : Node { assert(false_branch != nullptr); assert(false_branch->node != nullptr); false_branch->node->generate_dot(os, current_node_dot_id, "false"); - return current_node_dot_id; } }; @@ -265,39 +290,92 @@ struct UnExploredNode : Node { std::string to_string() override { return "UnexploredNode"; } protected: - int generate_dot(std::ostream &os, int parent_dot_id, - const std::string &edge_label) override { + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "Unexplored", "octagon", + "lightgrey"); + + if (parent_dot_id != -1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } + } +}; + +struct Finished : Node { + Finished() {} + std::string to_string() override { return "FinishedNode"; } - os << " node" << current_node_dot_id - << " [label=\"Unexplored\", shape=octagon, style=filled, " - "fillcolor=lightgrey];\n"; +protected: + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "Finished", "box", "lightgreen"); if (parent_dot_id != -1) { - os << " node" << parent_dot_id << " -> node" << current_node_dot_id; - if (!edge_label.empty()) { - os << " [label=\"" << edge_label << "\"]"; - } - os << ";\n"; + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); } + } +}; + +struct Failed : Node { + Failed() {} + std::string to_string() override { return "FailedNode"; } - return current_node_dot_id; +protected: + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "Failed", "box", "red"); + + if (parent_dot_id != -1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } } }; static UnExploredNode unexplored; inline NodeBox::NodeBox() - : node(std::make_unique< - UnExploredNode>() /* TODO: avoid allocation of unexplored node */) {} + : node(std::make_unique()), + /* TODO: avoid allocation of unexplored node */ + parent(nullptr) {} class ExploreTree_t { public: explicit ExploreTree_t() : root(std::make_unique()), cursor(root.get()) {} + + void reset_cursor() { + // Reset the cursor to the root of the tree + cursor = root.get(); + } + + std::monostate fillFinishedNode() { + if (dynamic_cast(cursor->node.get())) { + cursor->node = std::make_unique(); + } else { + assert(dynamic_cast(cursor->node.get()) != nullptr); + } + return std::monostate{}; + } + + std::monostate fillFailedNode() { + if (dynamic_cast(cursor->node.get())) { + cursor->node = std::make_unique(); + } else { + assert(dynamic_cast(cursor->node.get()) != nullptr); + } + return std::monostate{}; + } + std::monostate fillIfElseNode(SymVal cond) { - // fill the current node with an ifelse branch node - cursor->node = std::make_unique(cond); + // fill the current NodeBox with an ifelse branch node it's unexplored + if (dynamic_cast(cursor->node.get())) { + cursor->node = std::make_unique(cond); + } + assert(dynamic_cast(cursor->node.get()) != nullptr && + "Current node is not an IfElseNode, cannot fill it!"); return std::monostate(); } @@ -328,13 +406,60 @@ class ExploreTree_t { std::monostate dump_graphviz(std::string filepath) { std::ofstream ofs(filepath); if (!ofs.is_open()) { - throw std::runtime_error("Failed to open explore_tree.dot for writing"); + throw std::runtime_error("Failed to open " + filepath + " for writing"); } to_graphviz(ofs); return std::monostate(); } + std::optional> get_unexplored_conditions() { + // Get all unexplored conditions in the tree + std::vector result; + auto box = pick_unexplored(); + if (!box) { + return std::nullopt; + } + while (box->parent) { + auto parent = box->parent; + auto if_else_node = dynamic_cast(parent->node.get()); + if (if_else_node) { + if (if_else_node->true_branch.get() == box) { + // If the current box is the true branch, add the condition + result.push_back(if_else_node->cond); + } else if (if_else_node->false_branch.get() == box) { + // If the current box is the false branch, add the negated condition + result.push_back(if_else_node->cond.negate()); + } else { + throw std::runtime_error("Unexpected node structure in explore tree"); + } + } + // Move to parent + box = box->parent; + } + return result; + } + + NodeBox *pick_unexplored() { + // Pick an unexplored node from the tree + // For now, we just iterate through the tree and return the first unexplored + return pick_unexplored_of(root.get()); + } + private: + NodeBox *pick_unexplored_of(NodeBox *node) { + if (dynamic_cast(node->node.get()) != nullptr) { + return node; + } + auto if_else_node = dynamic_cast(node->node.get()); + if (if_else_node) { + NodeBox *result = pick_unexplored_of(if_else_node->true_branch.get()); + if (result) { + return result; + } + return pick_unexplored_of(if_else_node->false_branch.get()); + } + return nullptr; // No unexplored node found + } std::unique_ptr root; NodeBox *cursor; }; @@ -344,10 +469,28 @@ static ExploreTree_t ExploreTree; class SymEnv_t { public: Num read(SymVal sym) { - // Read a symbolic value from the symbolic environment - // For now, we just return a zero return Num(0); + auto symbol = dynamic_cast(sym.symptr.get()); + assert(symbol); + return map[symbol->get_id()]; } + + void update(std::vector new_env) { map = std::move(new_env); } + + std::string to_string() const { + std::string result; + result += "(\n"; + for (int i = 0; i < map.size(); ++i) { + const Num &num = map[i]; + result += + " (" + std::to_string(i) + "->" + std::to_string(num.value) + ")\n"; + } + result += ")"; + return result; + } + +private: + std::vector map; // The symbolic environment, a vector of Num }; static SymEnv_t SymEnv; diff --git a/headers/wasm/utils.hpp b/headers/wasm/utils.hpp new file mode 100644 index 00000000..8a86ac98 --- /dev/null +++ b/headers/wasm/utils.hpp @@ -0,0 +1,15 @@ +#ifndef UTILS_HPP +#define UTILS_HPP + +#ifndef GENSYM_ASSERT +#define GENSYM_ASSERT(condition) \ + do { \ + if (!(condition)) { \ + throw std::runtime_error(std::string("Assertion failed: ") + " (" + \ + __FILE__ + ":" + std::to_string(__LINE__) + \ + ")"); \ + } \ + } while (0) +#endif + +#endif // UTILS_HPP \ No newline at end of file diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 3c342177..71463bf5 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -416,7 +416,7 @@ trait StagedWasmEvaluator extends SAIOps { } } val (instrs, locals) = (funBody.body, funBody.locals) - Stack.initialize() + resetStacks() Frames.pushFrame(locals) eval(instrs, (_: Context) => forwardKont, mkont, ((_: Context) => forwardKont)::Nil)(Context(Nil, locals)) Frames.popFrame(locals.size) @@ -428,10 +428,7 @@ trait StagedWasmEvaluator extends SAIOps { if (printRes) { Stack.print() } - dumpTree match { - case Some(filePath) => ExploreTree.dumpGraphiviz(filePath) - case None => () - } + ExploreTree.fillWithFinished() "no-op".reflectCtrlWith[Unit]() } val temp: Rep[MCont[Unit]] = topFun(haltK) @@ -558,6 +555,7 @@ trait StagedWasmEvaluator extends SAIOps { object Memory { def storeInt(base: Rep[Int], offset: Int, value: Rep[Int]): Rep[Unit] = { "memory-store-int".reflectCtrlWith[Unit](base, offset, value) + // todo: store symbolic value to memory via extract/concat operation } def loadInt(base: Rep[Int], offset: Int): StagedNum = { @@ -570,6 +568,10 @@ trait StagedWasmEvaluator extends SAIOps { } } + def resetStacks(): Rep[Unit] = { + "reset-stacks".reflectCtrlWith[Unit]() + } + // call unreachable def unreachable(): Rep[Unit] = { "unreachable".reflectCtrlWith[Unit]() @@ -617,6 +619,10 @@ trait StagedWasmEvaluator extends SAIOps { "tree-fill-if-else".reflectCtrlWith[Unit](s) } + def fillWithFinished(): Rep[Unit] = { + "tree-fill-finished".reflectCtrlWith[Unit]() + } + def moveCursor(branch: Boolean): Rep[Unit] = { "tree-move-cursor".reflectCtrlWith[Unit](branch) } @@ -875,6 +881,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { } override def shallow(n: Node): Unit = n match { + case Node(_, "reset-stacks", _, _) => + emit("reset_stacks()") case Node(_, "frame-get", List(i), _) => emit("Frames.get("); shallow(i); emit(")") case Node(_, "sym-frame-get", List(i), _) => @@ -977,6 +985,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("assert("); shallow(cond); emit(")") case Node(_, "tree-fill-if-else", List(s), _) => emit("ExploreTree.fillIfElseNode("); shallow(s); emit(")") + case Node(_, "tree-fill-finished", List(), _) => + emit("ExploreTree.fillFinishedNode()") case Node(_, "tree-move-cursor", List(b), _) => emit("ExploreTree.moveCursor("); shallow(b); emit(")") case Node(_, "tree-print", List(), _) => From 9ab162fe0305be5c955154b6be9292befe8e4d70 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Fri, 18 Jul 2025 17:41:15 +0800 Subject: [PATCH 14/21] fix: add an unreachable node & use GENSYM_ASSERT --- headers/wasm/concolic_driver.hpp | 16 ++- headers/wasm/symbolic_rt.hpp | 128 ++++++++++++------ .../scala/wasm/StagedConcolicMiniWasm.scala | 2 +- 3 files changed, 97 insertions(+), 49 deletions(-) diff --git a/headers/wasm/concolic_driver.hpp b/headers/wasm/concolic_driver.hpp index 4307413b..9c35f161 100644 --- a/headers/wasm/concolic_driver.hpp +++ b/headers/wasm/concolic_driver.hpp @@ -37,18 +37,22 @@ class ManagedConcolicCleanup { inline void ConcolicDriver::run() { ManagedConcolicCleanup cleanup{*this}; while (true) { - auto cond = ExploreTree.get_unexplored_conditions(); ExploreTree.reset_cursor(); - if (!cond.has_value()) { - std::cout << "No unexplored conditions found, exiting..." << std::endl; + auto unexplored = ExploreTree.pick_unexplored(); + if (!unexplored) { + std::cout << "No unexplored nodes found, exiting..." << std::endl; return; } - auto new_env = solver.solve(cond.value()); + auto cond = unexplored->collect_path_conds(); + auto new_env = solver.solve(cond); if (!new_env.has_value()) { - std::cout << "All unexplored paths are unreachable, exiting..." + // TODO: current implementation is buggy, there could be other reachable + // unexplored paths + std::cout << "Found an unreachable path, marking it as unreachable..." << std::endl; - return; + unexplored->fillUnreachableNode(); + continue; } SymEnv.update(std::move(new_env.value())); try { diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index 920bcad4..95f292b5 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -193,6 +193,13 @@ struct NodeBox { explicit NodeBox(); std::unique_ptr node; NodeBox *parent; + + std::monostate fillIfElseNode(SymVal cond); + std::monostate fillFinishedNode(); + std::monostate fillFailedNode(); + std::monostate fillUnreachableNode(); + + std::vector collect_path_conds(); }; struct Node { @@ -334,13 +341,87 @@ struct Failed : Node { } }; -static UnExploredNode unexplored; +struct Unreachable : Node { + Unreachable() {} + std::string to_string() override { return "UnreachableNode"; } + +protected: + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "Unreachable", "box", "orange"); + + if (parent_dot_id != -1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } + } +}; inline NodeBox::NodeBox() : node(std::make_unique()), /* TODO: avoid allocation of unexplored node */ parent(nullptr) {} +inline std::monostate NodeBox::fillIfElseNode(SymVal cond) { + // fill the current NodeBox with an ifelse branch node it's unexplored + if (dynamic_cast(node.get())) { + node = std::make_unique(cond); + } + assert(dynamic_cast(node.get()) != nullptr && + "Current node is not an IfElseNode, cannot fill it!"); + return std::monostate(); +} + +inline std::monostate NodeBox::fillFinishedNode() { + if (dynamic_cast(node.get())) { + node = std::make_unique(); + } else { + assert(dynamic_cast(node.get()) != nullptr); + } + return std::monostate(); +} + +inline std::monostate NodeBox::fillFailedNode() { + if (dynamic_cast(node.get())) { + node = std::make_unique(); + } else { + assert(dynamic_cast(node.get()) != nullptr); + } + return std::monostate(); +} + +inline std::monostate NodeBox::fillUnreachableNode() { + if (dynamic_cast(node.get())) { + node = std::make_unique(); + } else { + assert(dynamic_cast(node.get()) != nullptr); + } + return std::monostate(); +} + +inline std::vector NodeBox::collect_path_conds() { + auto box = this; + auto result = std::vector(); + while (box->parent) { + auto parent = box->parent; + auto if_else_node = dynamic_cast(parent->node.get()); + if (if_else_node) { + if (if_else_node->true_branch.get() == box) { + // If the current box is the true branch, add the condition + result.push_back(if_else_node->cond); + } else if (if_else_node->false_branch.get() == box) { + // If the current box is the false branch, add the negated condition + result.push_back(if_else_node->cond.negate()); + } else { + throw std::runtime_error("Unexpected node structure in explore tree"); + } + } + // Move to parent + box = box->parent; + } + return result; +} + class ExploreTree_t { public: explicit ExploreTree_t() @@ -351,32 +432,12 @@ class ExploreTree_t { cursor = root.get(); } - std::monostate fillFinishedNode() { - if (dynamic_cast(cursor->node.get())) { - cursor->node = std::make_unique(); - } else { - assert(dynamic_cast(cursor->node.get()) != nullptr); - } - return std::monostate{}; - } + std::monostate fillFinishedNode() { return cursor->fillFinishedNode(); } - std::monostate fillFailedNode() { - if (dynamic_cast(cursor->node.get())) { - cursor->node = std::make_unique(); - } else { - assert(dynamic_cast(cursor->node.get()) != nullptr); - } - return std::monostate{}; - } + std::monostate fillFailedNode() { return cursor->fillFailedNode(); } std::monostate fillIfElseNode(SymVal cond) { - // fill the current NodeBox with an ifelse branch node it's unexplored - if (dynamic_cast(cursor->node.get())) { - cursor->node = std::make_unique(cond); - } - assert(dynamic_cast(cursor->node.get()) != nullptr && - "Current node is not an IfElseNode, cannot fill it!"); - return std::monostate(); + return cursor->fillIfElseNode(cond); } std::monostate moveCursor(bool branch) { @@ -419,24 +480,7 @@ class ExploreTree_t { if (!box) { return std::nullopt; } - while (box->parent) { - auto parent = box->parent; - auto if_else_node = dynamic_cast(parent->node.get()); - if (if_else_node) { - if (if_else_node->true_branch.get() == box) { - // If the current box is the true branch, add the condition - result.push_back(if_else_node->cond); - } else if (if_else_node->false_branch.get() == box) { - // If the current box is the false branch, add the negated condition - result.push_back(if_else_node->cond.negate()); - } else { - throw std::runtime_error("Unexpected node structure in explore tree"); - } - } - // Move to parent - box = box->parent; - } - return result; + return box->collect_path_conds(); } NodeBox *pick_unexplored() { diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 71463bf5..eb8d3aab 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -982,7 +982,7 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { case Node(_, "sym-env-read", List(sym), _) => emit("SymEnv.read("); shallow(sym); emit(")") case Node(_, "assert-true", List(cond), _) => - emit("assert("); shallow(cond); emit(")") + emit("GENSYM_ASSERT("); shallow(cond); emit(")") case Node(_, "tree-fill-if-else", List(s), _) => emit("ExploreTree.fillIfElseNode("); shallow(s); emit(")") case Node(_, "tree-fill-finished", List(), _) => From b75a627a59c88ac7738b7bf4c8a7944e74b5f6b1 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 19 Jul 2025 13:20:40 +0800 Subject: [PATCH 15/21] call z3 to solve constraints --- headers/wasm/concolic_driver.hpp | 13 ++-- headers/wasm/smt_solver.hpp | 120 ++++++++++++++++++++++++++++--- headers/wasm/symbolic_rt.hpp | 39 ++++++---- 3 files changed, 145 insertions(+), 27 deletions(-) diff --git a/headers/wasm/concolic_driver.hpp b/headers/wasm/concolic_driver.hpp index 9c35f161..ea28d082 100644 --- a/headers/wasm/concolic_driver.hpp +++ b/headers/wasm/concolic_driver.hpp @@ -1,11 +1,13 @@ #ifndef CONCOLIC_DRIVER_HPP #define CONCOLIC_DRIVER_HPP +#include "concrete_rt.hpp" #include "smt_solver.hpp" #include "symbolic_rt.hpp" #include #include #include +#include class ConcolicDriver { friend class ManagedConcolicCleanup; @@ -45,16 +47,19 @@ inline void ConcolicDriver::run() { return; } auto cond = unexplored->collect_path_conds(); - auto new_env = solver.solve(cond); - if (!new_env.has_value()) { + std::vector new_env; + std::set valid_ids; + auto result = solver.solve(cond); + if (!result.has_value()) { // TODO: current implementation is buggy, there could be other reachable // unexplored paths std::cout << "Found an unreachable path, marking it as unreachable..." << std::endl; unexplored->fillUnreachableNode(); - continue; + continue; } - SymEnv.update(std::move(new_env.value())); + std::tie(new_env, valid_ids) = std::move(result.value()); + SymEnv.update(std::move(new_env), std::move(valid_ids)); try { entrypoint(); std::cout << "Execution finished successfully with symbolic environment:" diff --git a/headers/wasm/smt_solver.hpp b/headers/wasm/smt_solver.hpp index a3bbf78d..8e3f82e4 100644 --- a/headers/wasm/smt_solver.hpp +++ b/headers/wasm/smt_solver.hpp @@ -3,26 +3,124 @@ #include "concrete_rt.hpp" #include "symbolic_rt.hpp" +#include "z3++.h" #include +#include +#include +#include #include class Solver { public: - Solver() : count(0) { - envs[0] = {Num(0), Num(0)}; - envs[1] = {Num(1), Num(2)}; - } - std::optional> solve(const std::vector &conditions) { - // here is just a placeholder implementation to simulate solving result - if (count >= envs.size()) { - return std::nullopt; // No more environments to return + Solver() {} + std::optional, std::set>> + solve(const std::vector &conditions) { + // make an conjunction of all conditions + z3::expr conjunction = z3_ctx.bool_val(true); + for (const auto &cond : conditions) { + auto z3_cond = build_z3_expr(cond); + conjunction = conjunction && z3_cond != z3_ctx.bv_val(0, 32); + } +#ifdef DEBUG + std::cout << "Symbolic conditions size: " << conditions.size() << std::endl; + std::cout << "Solving conditions: " << conjunction << std::endl; +#endif + // call z3 to solve the condition + z3::solver z3_solver(z3_ctx); + z3_solver.add(conjunction); + switch (z3_solver.check()) { + case z3::unsat: + return std::nullopt; // No solution found + case z3::sat: { + z3::model model = z3_solver.get_model(); + std::vector result(max_id + 1, Num(0)); + // Reference: + // https://github.com/Z3Prover/z3/blob/master/examples/c%2B%2B/example.cpp#L59 + + std::cout << "Solved Z3 model" << model << std::endl; + std::set seen_ids; + for (unsigned i = 0; i < model.size(); ++i) { + z3::func_decl var = model[i]; + z3::expr value = model.get_const_interp(var); + std::string name = var.name().str(); + if (name.starts_with("s_")) { + int id = std::stoi(name.substr(2)); + seen_ids.insert(id); + result[id] = Num(value.get_numeral_int()); + } else { + std::cout << "Find a variable that is not created by GenSym: " << name + << std::endl; + } + } + return std::make_tuple(result, seen_ids); } - return envs[count++ % envs.size()]; + case z3::unknown: + throw std::runtime_error("Z3 solver returned unknown status"); + } + return std::nullopt; // Should not reach here } private: - std::array, 5> envs; - int count; + z3::context z3_ctx; + z3::expr build_z3_expr(const SymVal &sym_val); }; +inline z3::expr Solver::build_z3_expr(const SymVal &sym_val) { + if (auto sym = std::dynamic_pointer_cast(sym_val.symptr)) { + return z3_ctx.bv_const(("s_" + std::to_string(sym->get_id())).c_str(), 32); + } else if (auto concrete = + std::dynamic_pointer_cast(sym_val.symptr)) { + return z3_ctx.bv_val(concrete->value.value, 32); + } else if (auto binary = + std::dynamic_pointer_cast(sym_val.symptr)) { + auto bit_width = 32; + z3::expr zero_bv = + z3_ctx.bv_val(0, bit_width); // Represents 0 as a 32-bit bitvector + z3::expr one_bv = + z3_ctx.bv_val(1, bit_width); // Represents 1 as a 32-bit bitvector + + z3::expr left = build_z3_expr(binary->lhs); + z3::expr right = build_z3_expr(binary->rhs); + // TODO: make sure the semantics of these operations are aligned with wasm + switch (binary->op) { + case EQ: { + auto temp_bool = left == right; + return z3::ite(temp_bool, one_bv, zero_bv); + } + case NEQ: { + auto temp_bool = left != right; + return z3::ite(temp_bool, one_bv, zero_bv); + } + case LT: { + auto temp_bool = left < right; + return z3::ite(temp_bool, one_bv, zero_bv); + } + case LEQ: { + auto temp_bool = left <= right; + return z3::ite(temp_bool, one_bv, zero_bv); + } + case GT: { + auto temp_bool = left > right; + return z3::ite(temp_bool, one_bv, zero_bv); + } + case GEQ: { + auto temp_bool = left >= right; + return z3::ite(temp_bool, one_bv, zero_bv); + } + case ADD: { + return left + right; + } + case SUB: { + return left - right; + } + case MUL: { + return left * right; + } + case DIV: { + return left / right; + } + } + } + throw std::runtime_error("Unsupported symbolic value type"); +} #endif // SMT_SOLVER_HPP \ No newline at end of file diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index 95f292b5..3e7ea2cc 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -19,9 +20,13 @@ class Symbolic { virtual ~Symbolic() = default; // Make Symbolic polymorphic }; +static int max_id = 0; + class Symbol : public Symbolic { public: - Symbol(int id) : id(id) {} + // TODO: add type information to determine the size of bitvector + // for now we just assume that only i32 will be used + Symbol(int id) : id(id) { max_id = std::max(max_id, id); } int get_id() const { return id; } private: @@ -190,7 +195,7 @@ static SymFrames_t SymFrames; struct Node; struct NodeBox { - explicit NodeBox(); + explicit NodeBox(NodeBox *parent); std::unique_ptr node; NodeBox *parent; @@ -247,9 +252,9 @@ struct IfElseNode : Node { std::unique_ptr true_branch; std::unique_ptr false_branch; - IfElseNode(SymVal cond) - : cond(cond), true_branch(std::make_unique()), - false_branch(std::make_unique()) {} + IfElseNode(SymVal cond, NodeBox *parent) + : cond(cond), true_branch(std::make_unique(parent)), + false_branch(std::make_unique(parent)) {} std::string to_string() override { std::string result = "IfElseNode {\n"; @@ -357,15 +362,15 @@ struct Unreachable : Node { } }; -inline NodeBox::NodeBox() +inline NodeBox::NodeBox(NodeBox *parent) : node(std::make_unique()), /* TODO: avoid allocation of unexplored node */ - parent(nullptr) {} + parent(parent) {} inline std::monostate NodeBox::fillIfElseNode(SymVal cond) { // fill the current NodeBox with an ifelse branch node it's unexplored if (dynamic_cast(node.get())) { - node = std::make_unique(cond); + node = std::make_unique(cond, this); } assert(dynamic_cast(node.get()) != nullptr && "Current node is not an IfElseNode, cannot fill it!"); @@ -425,7 +430,7 @@ inline std::vector NodeBox::collect_path_conds() { class ExploreTree_t { public: explicit ExploreTree_t() - : root(std::make_unique()), cursor(root.get()) {} + : root(std::make_unique(nullptr)), cursor(root.get()) {} void reset_cursor() { // Reset the cursor to the root of the tree @@ -513,13 +518,22 @@ static ExploreTree_t ExploreTree; class SymEnv_t { public: Num read(SymVal sym) { - return Num(0); auto symbol = dynamic_cast(sym.symptr.get()); assert(symbol); + if (symbol->get_id() >= map.size()) { + map.resize(symbol->get_id() + 1); + } +#if DEBUG + std::cout << "Read symbol: " << symbol->get_id() + << " from symbolic environment" << std::endl; + std::cout << "Current symbolic environment: " << to_string() << std::endl; +#endif return map[symbol->get_id()]; } - void update(std::vector new_env) { map = std::move(new_env); } + void update(std::vector new_env, std::set valid_ids) { + map = std::move(new_env); + } std::string to_string() const { std::string result; @@ -534,7 +548,8 @@ class SymEnv_t { } private: - std::vector map; // The symbolic environment, a vector of Num + std::vector map; // The symbolic environment, a vector of Num + std::set valid_ids; // The set of valid IDs in the symbolic environment }; static SymEnv_t SymEnv; From 26c9917fcb39e84f33318955fb3a9b3a3fd9fb32 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sat, 19 Jul 2025 17:48:20 +0800 Subject: [PATCH 16/21] remove unused & resize before update environment --- headers/wasm/concolic_driver.hpp | 6 ++---- headers/wasm/smt_solver.hpp | 14 +++++++------- headers/wasm/symbolic_rt.hpp | 3 +-- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/headers/wasm/concolic_driver.hpp b/headers/wasm/concolic_driver.hpp index ea28d082..8e8ca815 100644 --- a/headers/wasm/concolic_driver.hpp +++ b/headers/wasm/concolic_driver.hpp @@ -47,8 +47,6 @@ inline void ConcolicDriver::run() { return; } auto cond = unexplored->collect_path_conds(); - std::vector new_env; - std::set valid_ids; auto result = solver.solve(cond); if (!result.has_value()) { // TODO: current implementation is buggy, there could be other reachable @@ -58,8 +56,8 @@ inline void ConcolicDriver::run() { unexplored->fillUnreachableNode(); continue; } - std::tie(new_env, valid_ids) = std::move(result.value()); - SymEnv.update(std::move(new_env), std::move(valid_ids)); + auto new_env = result.value(); + SymEnv.update(std::move(new_env)); try { entrypoint(); std::cout << "Execution finished successfully with symbolic environment:" diff --git a/headers/wasm/smt_solver.hpp b/headers/wasm/smt_solver.hpp index 8e3f82e4..de5b80cb 100644 --- a/headers/wasm/smt_solver.hpp +++ b/headers/wasm/smt_solver.hpp @@ -13,8 +13,7 @@ class Solver { public: Solver() {} - std::optional, std::set>> - solve(const std::vector &conditions) { + std::optional> solve(const std::vector &conditions) { // make an conjunction of all conditions z3::expr conjunction = z3_ctx.bool_val(true); for (const auto &cond : conditions) { @@ -33,26 +32,27 @@ class Solver { return std::nullopt; // No solution found case z3::sat: { z3::model model = z3_solver.get_model(); - std::vector result(max_id + 1, Num(0)); + std::vector result; // Reference: // https://github.com/Z3Prover/z3/blob/master/examples/c%2B%2B/example.cpp#L59 - std::cout << "Solved Z3 model" << model << std::endl; - std::set seen_ids; + std::cout << "Solved Z3 model" << std::endl << model << std::endl; for (unsigned i = 0; i < model.size(); ++i) { z3::func_decl var = model[i]; z3::expr value = model.get_const_interp(var); std::string name = var.name().str(); if (name.starts_with("s_")) { int id = std::stoi(name.substr(2)); - seen_ids.insert(id); + if (id >= result.size()) { + result.resize(id + 1); + } result[id] = Num(value.get_numeral_int()); } else { std::cout << "Find a variable that is not created by GenSym: " << name << std::endl; } } - return std::make_tuple(result, seen_ids); + return result; } case z3::unknown: throw std::runtime_error("Z3 solver returned unknown status"); diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index 3e7ea2cc..18629c80 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -531,7 +531,7 @@ class SymEnv_t { return map[symbol->get_id()]; } - void update(std::vector new_env, std::set valid_ids) { + void update(std::vector new_env) { map = std::move(new_env); } @@ -549,7 +549,6 @@ class SymEnv_t { private: std::vector map; // The symbolic environment, a vector of Num - std::set valid_ids; // The set of valid IDs in the symbolic environment }; static SymEnv_t SymEnv; From 319cfd6576f0399bcfdd3668e0a19b9feeadcad2 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 23 Jul 2025 14:11:42 +0800 Subject: [PATCH 17/21] use c++20 --- src/main/scala/wasm/StagedConcolicMiniWasm.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index eb8d3aab..922d2113 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -1080,7 +1080,7 @@ object WasmToCppCompiler { } import sys.process._ - val command = s"g++ -std=c++17 $outputCpp -o $outputExe -O3 -g " + generated.headerFolders.map(f => s"-I$f").mkString(" ") + val command = s"g++ -std=c++20 $outputCpp -o $outputExe -O3 -g " + generated.headerFolders.map(f => s"-I$f").mkString(" ") if (command.! != 0) { throw new RuntimeException(s"Compilation failed for $outputCpp") } From 8f45912f6275c9dc5a5ee8bab34b02bcae3b5609 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 23 Jul 2025 19:28:50 +0800 Subject: [PATCH 18/21] branch in brtable --- benchmarks/wasm/staged/brtable_concolic.wat | 22 +++++++++++++++++++ headers/wasm/smt_solver.hpp | 2 +- .../scala/wasm/StagedConcolicMiniWasm.scala | 17 +++++++++++--- .../genwasym/TestStagedConcolicEval.scala | 4 ++++ 4 files changed, 41 insertions(+), 4 deletions(-) create mode 100644 benchmarks/wasm/staged/brtable_concolic.wat diff --git a/benchmarks/wasm/staged/brtable_concolic.wat b/benchmarks/wasm/staged/brtable_concolic.wat new file mode 100644 index 00000000..04429e90 --- /dev/null +++ b/benchmarks/wasm/staged/brtable_concolic.wat @@ -0,0 +1,22 @@ +(module $brtable + (global (;0;) (mut i32) (i32.const 1048576)) + (type (;0;) (func (param i32))) + (func (;0;) (type 1) (result i32) + i32.const 2 + (block + (block + (block + i32.const 0 + i32.symbolic + br_table 0 1 2 0 ;; br_table will consume an element from the stack + ) + i32.const 1 + call 1 + br 1 + ) + i32.const 0 + call 1 + ) + ) + (import "console" "assert" (func (type 0))) + (start 0)) diff --git a/headers/wasm/smt_solver.hpp b/headers/wasm/smt_solver.hpp index de5b80cb..f2450905 100644 --- a/headers/wasm/smt_solver.hpp +++ b/headers/wasm/smt_solver.hpp @@ -46,7 +46,7 @@ class Solver { if (id >= result.size()) { result.resize(id + 1); } - result[id] = Num(value.get_numeral_int()); + result[id] = Num(value.get_numeral_int64()); } else { std::cout << "Find a variable that is not created by GenSym: " << name << std::endl; diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 922d2113..6c8e45aa 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -17,6 +17,7 @@ import gensym.wasm.symbolic.{SymVal} import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase, CppSAIDriver, CppSAICodeGenBase} import gensym.wasm.symbolic.Concrete import gensym.wasm.symbolic.ExploreTree +import gensym.structure.freer.Explore @virtualize trait StagedWasmEvaluator extends SAIOps { @@ -270,12 +271,20 @@ trait StagedWasmEvaluator extends SAIOps { } () case BrTable(labels, default) => - val (cond, newCtx) = Stack.pop() + val (label, newCtx) = Stack.pop() def aux(choices: List[Int], idx: Int): Rep[Unit] = { if (choices.isEmpty) trail(default)(newCtx)(mkont) else { - if (cond.toInt == idx) trail(choices.head)(newCtx)(mkont) - else aux(choices.tail, idx + 1) + val cond = (label - toStagedNum(I32V(idx))).isZero() + ExploreTree.fillWithIfElse(cond.s) + if (cond.toInt != 0) { + ExploreTree.moveCursor(true) + trail(choices.head)(newCtx)(mkont) + } + else { + ExploreTree.moveCursor(false) + aux(choices.tail, idx + 1) + } } } aux(labels, 0) @@ -959,6 +968,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { shallow(lhs); emit(" >= "); shallow(rhs) case Node(_, "sym-binary-add", List(lhs, rhs), _) => shallow(lhs); emit(".add("); shallow(rhs); emit(")") + case Node(_, "sym-binary-sub", List(lhs, rhs), _) => + shallow(lhs); emit(".minus("); shallow(rhs); emit(")") case Node(_, "sym-binary-mul", List(lhs, rhs), _) => shallow(lhs); emit(".mul("); shallow(rhs); emit(")") case Node(_, "sym-binary-div", List(lhs, rhs), _) => diff --git a/src/test/scala/genwasym/TestStagedConcolicEval.scala b/src/test/scala/genwasym/TestStagedConcolicEval.scala index 77868e2c..fa7f704b 100644 --- a/src/test/scala/genwasym/TestStagedConcolicEval.scala +++ b/src/test/scala/genwasym/TestStagedConcolicEval.scala @@ -35,4 +35,8 @@ class TestStagedConcolicEval extends FunSuite { test("bug-finding") { testFileToCpp("./benchmarks/wasm/branch-strip-buggy.wat", Some("real_main")) } + + test("brtable-bug-finding") { + testFileToCpp("./benchmarks/wasm/staged/brtable_concolic.wat") + } } From 2e2259d0bbd18aa122d201d62bc03d245c6ec191 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 23 Jul 2025 19:32:49 +0800 Subject: [PATCH 19/21] use driver's entrypoint by default --- src/main/scala/wasm/StagedConcolicMiniWasm.scala | 4 ++-- src/test/scala/genwasym/TestStagedConcolicEval.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 6c8e45aa..833bbc9b 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -1047,7 +1047,7 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { |End of Generated Code |*******************************************/ |int main(int argc, char *argv[]) { - | Snippet(std::monostate{}); + | start_concolic_execution_with(Snippet); | return 0; |}""".stripMargin) } @@ -1091,7 +1091,7 @@ object WasmToCppCompiler { } import sys.process._ - val command = s"g++ -std=c++20 $outputCpp -o $outputExe -O3 -g " + generated.headerFolders.map(f => s"-I$f").mkString(" ") + val command = s"g++ -std=c++20 $outputCpp -o $outputExe -O3 -g -l z3 " + generated.headerFolders.map(f => s"-I$f").mkString(" ") if (command.! != 0) { throw new RuntimeException(s"Compilation failed for $outputCpp") } diff --git a/src/test/scala/genwasym/TestStagedConcolicEval.scala b/src/test/scala/genwasym/TestStagedConcolicEval.scala index fa7f704b..a65d0eda 100644 --- a/src/test/scala/genwasym/TestStagedConcolicEval.scala +++ b/src/test/scala/genwasym/TestStagedConcolicEval.scala @@ -30,7 +30,7 @@ class TestStagedConcolicEval extends FunSuite { }) } - test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main"), expect=Some(List(7))) } + test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main")) } test("bug-finding") { testFileToCpp("./benchmarks/wasm/branch-strip-buggy.wat", Some("real_main")) From 2b42b277cfa42e6d06ba49f70daba327c4c4abcf Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 23 Jul 2025 20:09:12 +0800 Subject: [PATCH 20/21] rename package name of staged miniwasm --- src/main/scala/wasm/StagedMiniWasm.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index bfa2082d..ea9dc9c6 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -1,4 +1,4 @@ -package gensym.wasm.miniwasm +package gensym.wasm.stagedminiwasm import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -12,6 +12,7 @@ import lms.core.Graph import gensym.wasm.ast._ import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} +import gensym.wasm.miniwasm.ModuleInstance import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase, CppSAIDriver, CppSAICodeGenBase} @virtualize From 619a8f022d015d67ccbc4919bc49801175dc9dda Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 23 Jul 2025 20:15:29 +0800 Subject: [PATCH 21/21] tweak --- src/test/scala/genwasym/TestStagedEval.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index d4d1e960..3769428f 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -6,6 +6,7 @@ import lms.core.stub.Adapter import gensym.wasm.parser._ import gensym.wasm.miniwasm._ +import gensym.wasm.stagedminiwasm._ class TestStagedEval extends FunSuite { def testFileToScala(filename: String, main: Option[String] = None, printRes: Boolean = false) = {