Skip to content

Commit 96561f3

Browse files
author
Guannan Wei
committed
add demo
1 parent 29cde18 commit 96561f3

File tree

1 file changed

+300
-0
lines changed

1 file changed

+300
-0
lines changed
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
package tutorial
2+
3+
import lms.core._
4+
import lms.macros._
5+
import lms.core.stub._
6+
import lms.core.Backend._
7+
8+
import gensym.lmsx._
9+
10+
import scala.collection.immutable.{List => SList}
11+
12+
object ImpLang {
13+
sealed trait Stmt
14+
case class Skip() extends Stmt
15+
case class Break() extends Stmt
16+
case class Assign(x: String, e: Expr) extends Stmt
17+
case class Cond(e: Expr, thn: Stmt, els: Stmt) extends Stmt
18+
case class Seq(s1: Stmt, s2: Stmt) extends Stmt
19+
case class While(b: Expr, s: Stmt) extends Stmt
20+
case class Output(e: Expr) extends Stmt
21+
case class Assert(e: Expr) extends Stmt
22+
23+
sealed trait Expr {
24+
def toSExp: String
25+
}
26+
case class Input() extends Expr {
27+
def toSExp = ???
28+
}
29+
case class Lit(x: Any) extends Expr {
30+
override def toString: String = s"Lit(${x.toString})"
31+
def toSExp: String = x.toString
32+
}
33+
case class Var(x: String) extends Expr {
34+
override def toString: String = "Var(\"" + x.toString + "\")"
35+
def toSExp: String = x.toString
36+
}
37+
case class Op1(op: String, e: Expr) extends Expr {
38+
override def toString: String = "Op1(\"" + op + "\"," + s"${e.toString})"
39+
def toSExp: String = s"($op ${e.toSExp})"
40+
}
41+
case class Op2(op: String, e1: Expr, e2: Expr) extends Expr {
42+
override def toString: String =
43+
"Op2(\"" + op + "\"," + s"${e1.toString}, ${e2.toString})"
44+
def toSExp: String = s"($op ${e1.toSExp} ${e2.toSExp})"
45+
}
46+
47+
def let_(x: String, rhs: Int)(body: Var => Stmt): Stmt =
48+
Seq(Assign(x, Lit(rhs)), body(Var(x)))
49+
def let_(x: String, rhs: Expr)(body: Var => Stmt): Stmt =
50+
Seq(Assign(x, rhs), body(Var(x)))
51+
52+
def set_(x: String, rhs: Expr): Stmt = Assign(x, rhs)
53+
54+
def while_(e: Expr, s: Stmt): Stmt = While(e, s)
55+
56+
object Examples {
57+
val fact5 =
58+
Seq(Assign("i", Lit(1)),
59+
Seq(Assign("fact", Lit(1)),
60+
While(Op2("<=", Var("i"), Lit(5)),
61+
Seq(Assign("fact", Op2("*", Var("fact"), Var("i"))),
62+
Assign("i", Op2("+", Var("i"), Lit(1)))))))
63+
64+
val fact_n =
65+
Seq(Assign("i", Lit(1)),
66+
Seq(Assign("fact", Lit(1)),
67+
While(Op2("<=", Var("i"), Var("n")),
68+
Seq(Assign("fact", Op2("*", Var("fact"), Var("i"))),
69+
Assign("i", Op2("+", Var("i"), Lit(1)))))))
70+
71+
val w2 =
72+
While(Op2("<=", Var("i"), Var("x")),
73+
While(Op2("<=", Var("i"), Var("x")),
74+
Assign("x", Op2("-", Var("x"), Lit(1)))))
75+
76+
val w3 =
77+
While(Op2("<=", Var("i"), Var("x")),
78+
While(Op2("<=", Var("i"), Var("x")),
79+
While(Op2("<=", Var("i"), Var("x")),
80+
Assign("x", Op2("-", Var("x"), Lit(1))))))
81+
82+
val another_fact5 =
83+
let_("i", 1){ i =>
84+
let_("fact", 1){ fact =>
85+
while_(Op2("<=", i, Lit(5)),
86+
let_("fact", Op2("*", fact, i)){ _ =>
87+
set_("i", Op2("+", i, Lit(1)))
88+
})}}
89+
90+
91+
//println(another_fact5)
92+
assert(fact5 == another_fact5)
93+
94+
val x = Var("x")
95+
val y = Var("y")
96+
val z = Var("z")
97+
val a = Var("a")
98+
val b = Var("b")
99+
val i = Var("i")
100+
101+
val cond1 =
102+
Cond(Op2("<=", Lit(1), Lit(2)),
103+
Assign("x", Lit(3)),
104+
Assign("x", Lit(4)))
105+
106+
/* if (x <= y) {
107+
* z = x
108+
* } else {
109+
* z = y
110+
* }
111+
* z = z + 1
112+
*/
113+
val cond2 =
114+
Seq(Cond(Op2("<=", Var("x"), Var("y")),
115+
Assign("z", Var("x")),
116+
Assign("z", Var("y"))),
117+
Assign("z", Op2("+", Var("z"), Lit(1))))
118+
119+
/* if (x <= y) {
120+
* z = x
121+
* } else {
122+
* z = y
123+
* }
124+
* z = z - 1
125+
* if (z >= y) {
126+
* z = z * 2
127+
* } else {
128+
* z = z + 3
129+
* }
130+
*/
131+
val cond3 =
132+
Seq(Cond(Op2("<=", x, y),
133+
Assign("z", x),
134+
Assign("z", y)),
135+
Seq(Assign("z", Op2("-", z, Lit(1))),
136+
Seq(Cond(Op2(">=", z, y),
137+
Assign("z", Op2("*", z, Lit(2))),
138+
Assign("z", Op2("+", z, Lit(3)))),
139+
Skip())))
140+
141+
val condInput =
142+
Seq(Assign("x", Input()),
143+
Seq(Cond(Op2("<=", x, y),
144+
Assign("z", x),
145+
Assign("z", y)),
146+
Seq(Assign("z", Op2("+", z, Lit(1))),
147+
Seq(Cond(Op2(">=", z, y),
148+
Assign("z", Op2("+", z, Lit(2))),
149+
Assign("z", Op2("+", z, Lit(3)))),
150+
Skip()))))
151+
152+
val condAssert =
153+
Seq(Assign("x", Input()),
154+
Seq(Assert(Op2(">=", x, Lit(1))),
155+
Seq(Cond(Op2("<=", x, y),
156+
Assign("z", x),
157+
Assign("z", y)),
158+
Seq(Assign("z", Op2("+", z, Lit(1))),
159+
Seq(Cond(Op2(">=", z, y),
160+
Assign("z", Op2("+", z, Lit(2))),
161+
Assign("z", Op2("+", z, Lit(3)))),
162+
Skip())))))
163+
164+
val unboundLoop =
165+
Seq(Assign("i", Input()),
166+
While(Op2("<", i, Lit(42)),
167+
Assign("i", Op2("+", i, Lit(1)))))
168+
}
169+
}
170+
171+
import ImpLang._
172+
173+
@virtualize
174+
trait ImpureStagedImpSemantics extends SAIOps {
175+
trait Value
176+
def IntV(i: Rep[Int]): Rep[Value] = "IntV".reflectWith[Value](i)
177+
def BoolV(b: Rep[Boolean]): Rep[Value] = "BoolV".reflectWith[Value](b)
178+
179+
implicit def repIntProj(i: Rep[Value]): Rep[Int] = Unwrap(i) match {
180+
//case Adapter.g.Def("IntV", SList(v: Backend.Exp)) => Wrap[Int](v)
181+
case _ => "IntV-proj".reflectWith[Int](i)
182+
}
183+
implicit def repBoolProj(b: Rep[Value]): Rep[Boolean] = Unwrap(b) match {
184+
case Adapter.g.Def("BoolV", SList(v: Backend.Exp)) => Wrap[Boolean](v)
185+
case _ => "BoolV-proj".reflectWith[Boolean](b)
186+
}
187+
188+
trait MutState
189+
def newMutState(kvs: (String, Rep[Value])*): Rep[MutState] =
190+
"mutstate-new".reflectMutableWith[MutState](kvs.map({ case (k, v) => __liftTuple2RepLhs(k, v) }):_*)
191+
implicit class MutStateOps(s: Rep[MutState]) {
192+
def apply(x: String): Rep[Value] = "mutstate-read".reflectReadWith[Value](s, x)(s)
193+
def +=(x: String, v: Rep[Value]): Rep[Unit] = "mutstate-update".reflectWriteWith[Unit](s, x, v)(s)
194+
}
195+
def dummyRead(s: Rep[MutState]): Rep[Unit] = "mutstate-dummyread".reflectRWWith[Unit]()(s)(Adapter.CTRL)
196+
197+
def eval(e: Expr, σ: Rep[MutState]): Rep[Value] = e match {
198+
case Lit(i: Int) => IntV(i)
199+
case Lit(b: Boolean) => BoolV(b)
200+
case Var(x) => σ(x)
201+
case Op1("-", e) =>
202+
val i: Rep[Int] = eval(e, σ)
203+
IntV(-i)
204+
case Op2(op, e1, e2) =>
205+
val i1: Rep[Int] = eval(e1, σ)
206+
val i2: Rep[Int] = eval(e2, σ)
207+
op match {
208+
case "+" => IntV(i1 + i2)
209+
case "-" => IntV(i1 - i2)
210+
case "*" => IntV(i1 * i2)
211+
case "==" => BoolV(i1 == i2)
212+
case "<=" => BoolV(i1 <= i2)
213+
case "<" => BoolV(i1 < i2)
214+
case ">=" => BoolV(i1 >= i2)
215+
case ">" => BoolV(i1 > i2)
216+
}
217+
}
218+
219+
def exec(s: Stmt, σ: Rep[MutState]): Rep[Unit] = s match {
220+
case Skip() => ()
221+
case Assign(x, e) => σ += (x, eval(e, σ))
222+
case Cond(e, s1, s2) =>
223+
if (eval(e, σ)) exec(s1, σ) else exec(s2, σ)
224+
case Seq(s1, s2) => exec(s1, σ); exec(s2, σ)
225+
case While(e, b) => while (eval(e, σ)) exec(b, σ)
226+
}
227+
}
228+
229+
trait ImpureStagedImpGen extends SAICodeGenBase {
230+
override def traverse(n: Node): Unit = n match {
231+
case Node(s, "mutstate-new", kvs, _) =>
232+
es"val ${quote(s)} = Map[String, Value]("
233+
kvs.zipWithIndex.map { case (kv, i) =>
234+
shallow(kv)
235+
if (i != kvs.length-1) emit(", ")
236+
}
237+
esln")"
238+
case Node(_, "mutstate-update", List(s, x, v), _) => esln"$s($x) = $v"
239+
case Node(_, "mutstate-dummyread", _, _) => es""
240+
case _ => super.traverse(n)
241+
}
242+
// shallow : code generation for pure node/expression
243+
override def shallow(n: Node): Unit = n match {
244+
case Node(s, "IntV", List(i), _) => es"IntV($i)"
245+
case Node(s, "BoolV", List(b), _) => es"BoolV($b)"
246+
case Node(s, "IntV-proj", List(i), _) => es"$i.I"
247+
case Node(s, "BoolV-proj", List(i), _) => es"$i.B"
248+
case Node(_, "mutstate-read", List(s, x), _) => es"$s($x)"
249+
case _ => super.shallow(n)
250+
}
251+
}
252+
253+
trait ImpureStagedImpDriver[A, B] extends SAIDriver[A, B] with ImpureStagedImpSemantics { q =>
254+
override val codegen = new ScalaGenBase with ImpureStagedImpGen {
255+
val IR: q.type = q
256+
import IR._
257+
override def remap(m: Manifest[_]): String = {
258+
if (m.toString.endsWith("$Value")) "Value"
259+
else if (m.toString.endsWith("$MutState")) "MutState"
260+
else super.remap(m)
261+
}
262+
}
263+
264+
override val prelude =
265+
"""
266+
import scala.collection.mutable.Map
267+
import sai.lang.ImpLang._
268+
object Prelude {
269+
trait Value
270+
case class IntV(i: Int) extends Value
271+
case class BoolV(b: Boolean) extends Value
272+
implicit class ValueOps(v: Value) {
273+
def I: Int = v.asInstanceOf[IntV].i
274+
def B: Boolean = v.asInstanceOf[BoolV].b
275+
}
276+
}
277+
import Prelude._
278+
"""
279+
}
280+
281+
object ImpureStagedImpTest {
282+
import ImpLang._
283+
import ImpLang.Examples._
284+
def main(args: Array[String]): Unit = {
285+
val code = new ImpureStagedImpDriver[Int, Unit] {
286+
@virtualize
287+
def snippet(u: Rep[Int]) = {
288+
//val st: Rep[MutState] = newMutState("x" -> IntV(3), "y" -> IntV(4))
289+
//exec(cond3, st)
290+
val st: Rep[MutState] = newMutState()
291+
//exec(Seq(Assign("x", Lit(3)), Assign("y", Lit(4))), st)
292+
exec(fact5, st)
293+
dummyRead(st)
294+
println(st)
295+
}
296+
}
297+
println(code.code)
298+
//code.eval(0)
299+
}
300+
}

0 commit comments

Comments
 (0)