1
+ package tutorial
2
+
3
+ import lms .core ._
4
+ import lms .macros ._
5
+ import lms .core .stub ._
6
+ import lms .core .Backend ._
7
+
8
+ import gensym .lmsx ._
9
+
10
+ import scala .collection .immutable .{List => SList }
11
+
12
+ object ImpLang {
13
+ sealed trait Stmt
14
+ case class Skip () extends Stmt
15
+ case class Break () extends Stmt
16
+ case class Assign (x : String , e : Expr ) extends Stmt
17
+ case class Cond (e : Expr , thn : Stmt , els : Stmt ) extends Stmt
18
+ case class Seq (s1 : Stmt , s2 : Stmt ) extends Stmt
19
+ case class While (b : Expr , s : Stmt ) extends Stmt
20
+ case class Output (e : Expr ) extends Stmt
21
+ case class Assert (e : Expr ) extends Stmt
22
+
23
+ sealed trait Expr {
24
+ def toSExp : String
25
+ }
26
+ case class Input () extends Expr {
27
+ def toSExp = ???
28
+ }
29
+ case class Lit (x : Any ) extends Expr {
30
+ override def toString : String = s " Lit( ${x.toString}) "
31
+ def toSExp : String = x.toString
32
+ }
33
+ case class Var (x : String ) extends Expr {
34
+ override def toString : String = " Var(\" " + x.toString + " \" )"
35
+ def toSExp : String = x.toString
36
+ }
37
+ case class Op1 (op : String , e : Expr ) extends Expr {
38
+ override def toString : String = " Op1(\" " + op + " \" ," + s " ${e.toString}) "
39
+ def toSExp : String = s " ( $op ${e.toSExp}) "
40
+ }
41
+ case class Op2 (op : String , e1 : Expr , e2 : Expr ) extends Expr {
42
+ override def toString : String =
43
+ " Op2(\" " + op + " \" ," + s " ${e1.toString}, ${e2.toString}) "
44
+ def toSExp : String = s " ( $op ${e1.toSExp} ${e2.toSExp}) "
45
+ }
46
+
47
+ def let_ (x : String , rhs : Int )(body : Var => Stmt ): Stmt =
48
+ Seq (Assign (x, Lit (rhs)), body(Var (x)))
49
+ def let_ (x : String , rhs : Expr )(body : Var => Stmt ): Stmt =
50
+ Seq (Assign (x, rhs), body(Var (x)))
51
+
52
+ def set_ (x : String , rhs : Expr ): Stmt = Assign (x, rhs)
53
+
54
+ def while_ (e : Expr , s : Stmt ): Stmt = While (e, s)
55
+
56
+ object Examples {
57
+ val fact5 =
58
+ Seq (Assign (" i" , Lit (1 )),
59
+ Seq (Assign (" fact" , Lit (1 )),
60
+ While (Op2 (" <=" , Var (" i" ), Lit (5 )),
61
+ Seq (Assign (" fact" , Op2 (" *" , Var (" fact" ), Var (" i" ))),
62
+ Assign (" i" , Op2 (" +" , Var (" i" ), Lit (1 )))))))
63
+
64
+ val fact_n =
65
+ Seq (Assign (" i" , Lit (1 )),
66
+ Seq (Assign (" fact" , Lit (1 )),
67
+ While (Op2 (" <=" , Var (" i" ), Var (" n" )),
68
+ Seq (Assign (" fact" , Op2 (" *" , Var (" fact" ), Var (" i" ))),
69
+ Assign (" i" , Op2 (" +" , Var (" i" ), Lit (1 )))))))
70
+
71
+ val w2 =
72
+ While (Op2 (" <=" , Var (" i" ), Var (" x" )),
73
+ While (Op2 (" <=" , Var (" i" ), Var (" x" )),
74
+ Assign (" x" , Op2 (" -" , Var (" x" ), Lit (1 )))))
75
+
76
+ val w3 =
77
+ While (Op2 (" <=" , Var (" i" ), Var (" x" )),
78
+ While (Op2 (" <=" , Var (" i" ), Var (" x" )),
79
+ While (Op2 (" <=" , Var (" i" ), Var (" x" )),
80
+ Assign (" x" , Op2 (" -" , Var (" x" ), Lit (1 ))))))
81
+
82
+ val another_fact5 =
83
+ let_(" i" , 1 ){ i =>
84
+ let_(" fact" , 1 ){ fact =>
85
+ while_(Op2 (" <=" , i, Lit (5 )),
86
+ let_(" fact" , Op2 (" *" , fact, i)){ _ =>
87
+ set_(" i" , Op2 (" +" , i, Lit (1 )))
88
+ })}}
89
+
90
+
91
+ // println(another_fact5)
92
+ assert(fact5 == another_fact5)
93
+
94
+ val x = Var (" x" )
95
+ val y = Var (" y" )
96
+ val z = Var (" z" )
97
+ val a = Var (" a" )
98
+ val b = Var (" b" )
99
+ val i = Var (" i" )
100
+
101
+ val cond1 =
102
+ Cond (Op2 (" <=" , Lit (1 ), Lit (2 )),
103
+ Assign (" x" , Lit (3 )),
104
+ Assign (" x" , Lit (4 )))
105
+
106
+ /* if (x <= y) {
107
+ * z = x
108
+ * } else {
109
+ * z = y
110
+ * }
111
+ * z = z + 1
112
+ */
113
+ val cond2 =
114
+ Seq (Cond (Op2 (" <=" , Var (" x" ), Var (" y" )),
115
+ Assign (" z" , Var (" x" )),
116
+ Assign (" z" , Var (" y" ))),
117
+ Assign (" z" , Op2 (" +" , Var (" z" ), Lit (1 ))))
118
+
119
+ /* if (x <= y) {
120
+ * z = x
121
+ * } else {
122
+ * z = y
123
+ * }
124
+ * z = z - 1
125
+ * if (z >= y) {
126
+ * z = z * 2
127
+ * } else {
128
+ * z = z + 3
129
+ * }
130
+ */
131
+ val cond3 =
132
+ Seq (Cond (Op2 (" <=" , x, y),
133
+ Assign (" z" , x),
134
+ Assign (" z" , y)),
135
+ Seq (Assign (" z" , Op2 (" -" , z, Lit (1 ))),
136
+ Seq (Cond (Op2 (" >=" , z, y),
137
+ Assign (" z" , Op2 (" *" , z, Lit (2 ))),
138
+ Assign (" z" , Op2 (" +" , z, Lit (3 )))),
139
+ Skip ())))
140
+
141
+ val condInput =
142
+ Seq (Assign (" x" , Input ()),
143
+ Seq (Cond (Op2 (" <=" , x, y),
144
+ Assign (" z" , x),
145
+ Assign (" z" , y)),
146
+ Seq (Assign (" z" , Op2 (" +" , z, Lit (1 ))),
147
+ Seq (Cond (Op2 (" >=" , z, y),
148
+ Assign (" z" , Op2 (" +" , z, Lit (2 ))),
149
+ Assign (" z" , Op2 (" +" , z, Lit (3 )))),
150
+ Skip ()))))
151
+
152
+ val condAssert =
153
+ Seq (Assign (" x" , Input ()),
154
+ Seq (Assert (Op2 (" >=" , x, Lit (1 ))),
155
+ Seq (Cond (Op2 (" <=" , x, y),
156
+ Assign (" z" , x),
157
+ Assign (" z" , y)),
158
+ Seq (Assign (" z" , Op2 (" +" , z, Lit (1 ))),
159
+ Seq (Cond (Op2 (" >=" , z, y),
160
+ Assign (" z" , Op2 (" +" , z, Lit (2 ))),
161
+ Assign (" z" , Op2 (" +" , z, Lit (3 )))),
162
+ Skip ())))))
163
+
164
+ val unboundLoop =
165
+ Seq (Assign (" i" , Input ()),
166
+ While (Op2 (" <" , i, Lit (42 )),
167
+ Assign (" i" , Op2 (" +" , i, Lit (1 )))))
168
+ }
169
+ }
170
+
171
+ import ImpLang ._
172
+
173
+ @ virtualize
174
+ trait ImpureStagedImpSemantics extends SAIOps {
175
+ trait Value
176
+ def IntV (i : Rep [Int ]): Rep [Value ] = " IntV" .reflectWith[Value ](i)
177
+ def BoolV (b : Rep [Boolean ]): Rep [Value ] = " BoolV" .reflectWith[Value ](b)
178
+
179
+ implicit def repIntProj (i : Rep [Value ]): Rep [Int ] = Unwrap (i) match {
180
+ // case Adapter.g.Def("IntV", SList(v: Backend.Exp)) => Wrap[Int](v)
181
+ case _ => " IntV-proj" .reflectWith[Int ](i)
182
+ }
183
+ implicit def repBoolProj (b : Rep [Value ]): Rep [Boolean ] = Unwrap (b) match {
184
+ case Adapter .g.Def (" BoolV" , SList (v : Backend .Exp )) => Wrap [Boolean ](v)
185
+ case _ => " BoolV-proj" .reflectWith[Boolean ](b)
186
+ }
187
+
188
+ trait MutState
189
+ def newMutState (kvs : (String , Rep [Value ])* ): Rep [MutState ] =
190
+ " mutstate-new" .reflectMutableWith[MutState ](kvs.map({ case (k, v) => __liftTuple2RepLhs(k, v) }):_* )
191
+ implicit class MutStateOps (s : Rep [MutState ]) {
192
+ def apply (x : String ): Rep [Value ] = " mutstate-read" .reflectReadWith[Value ](s, x)(s)
193
+ def += (x : String , v : Rep [Value ]): Rep [Unit ] = " mutstate-update" .reflectWriteWith[Unit ](s, x, v)(s)
194
+ }
195
+ def dummyRead (s : Rep [MutState ]): Rep [Unit ] = " mutstate-dummyread" .reflectRWWith[Unit ]()(s)(Adapter .CTRL )
196
+
197
+ def eval (e : Expr , σ : Rep [MutState ]): Rep [Value ] = e match {
198
+ case Lit (i : Int ) => IntV (i)
199
+ case Lit (b : Boolean ) => BoolV (b)
200
+ case Var (x) => σ(x)
201
+ case Op1 (" -" , e) =>
202
+ val i : Rep [Int ] = eval(e, σ)
203
+ IntV (- i)
204
+ case Op2 (op, e1, e2) =>
205
+ val i1 : Rep [Int ] = eval(e1, σ)
206
+ val i2 : Rep [Int ] = eval(e2, σ)
207
+ op match {
208
+ case " +" => IntV (i1 + i2)
209
+ case " -" => IntV (i1 - i2)
210
+ case " *" => IntV (i1 * i2)
211
+ case " ==" => BoolV (i1 == i2)
212
+ case " <=" => BoolV (i1 <= i2)
213
+ case " <" => BoolV (i1 < i2)
214
+ case " >=" => BoolV (i1 >= i2)
215
+ case " >" => BoolV (i1 > i2)
216
+ }
217
+ }
218
+
219
+ def exec (s : Stmt , σ : Rep [MutState ]): Rep [Unit ] = s match {
220
+ case Skip () => ()
221
+ case Assign (x, e) => σ += (x, eval(e, σ))
222
+ case Cond (e, s1, s2) =>
223
+ if (eval(e, σ)) exec(s1, σ) else exec(s2, σ)
224
+ case Seq (s1, s2) => exec(s1, σ); exec(s2, σ)
225
+ case While (e, b) => while (eval(e, σ)) exec(b, σ)
226
+ }
227
+ }
228
+
229
+ trait ImpureStagedImpGen extends SAICodeGenBase {
230
+ override def traverse (n : Node ): Unit = n match {
231
+ case Node (s, " mutstate-new" , kvs, _) =>
232
+ es " val ${quote(s)} = Map[String, Value]( "
233
+ kvs.zipWithIndex.map { case (kv, i) =>
234
+ shallow(kv)
235
+ if (i != kvs.length- 1 ) emit(" , " )
236
+ }
237
+ esln " ) "
238
+ case Node (_, " mutstate-update" , List (s, x, v), _) => esln " $s( $x) = $v"
239
+ case Node (_, " mutstate-dummyread" , _, _) => es " "
240
+ case _ => super .traverse(n)
241
+ }
242
+ // shallow : code generation for pure node/expression
243
+ override def shallow (n : Node ): Unit = n match {
244
+ case Node (s, " IntV" , List (i), _) => es " IntV( $i) "
245
+ case Node (s, " BoolV" , List (b), _) => es " BoolV( $b) "
246
+ case Node (s, " IntV-proj" , List (i), _) => es " $i.I "
247
+ case Node (s, " BoolV-proj" , List (i), _) => es " $i.B "
248
+ case Node (_, " mutstate-read" , List (s, x), _) => es " $s( $x) "
249
+ case _ => super .shallow(n)
250
+ }
251
+ }
252
+
253
+ trait ImpureStagedImpDriver [A , B ] extends SAIDriver [A , B ] with ImpureStagedImpSemantics { q =>
254
+ override val codegen = new ScalaGenBase with ImpureStagedImpGen {
255
+ val IR : q.type = q
256
+ import IR ._
257
+ override def remap (m : Manifest [_]): String = {
258
+ if (m.toString.endsWith(" $Value" )) " Value"
259
+ else if (m.toString.endsWith(" $MutState" )) " MutState"
260
+ else super .remap(m)
261
+ }
262
+ }
263
+
264
+ override val prelude =
265
+ """
266
+ import scala.collection.mutable.Map
267
+ import sai.lang.ImpLang._
268
+ object Prelude {
269
+ trait Value
270
+ case class IntV(i: Int) extends Value
271
+ case class BoolV(b: Boolean) extends Value
272
+ implicit class ValueOps(v: Value) {
273
+ def I: Int = v.asInstanceOf[IntV].i
274
+ def B: Boolean = v.asInstanceOf[BoolV].b
275
+ }
276
+ }
277
+ import Prelude._
278
+ """
279
+ }
280
+
281
+ object ImpureStagedImpTest {
282
+ import ImpLang ._
283
+ import ImpLang .Examples ._
284
+ def main (args : Array [String ]): Unit = {
285
+ val code = new ImpureStagedImpDriver [Int , Unit ] {
286
+ @ virtualize
287
+ def snippet (u : Rep [Int ]) = {
288
+ // val st: Rep[MutState] = newMutState("x" -> IntV(3), "y" -> IntV(4))
289
+ // exec(cond3, st)
290
+ val st : Rep [MutState ] = newMutState()
291
+ // exec(Seq(Assign("x", Lit(3)), Assign("y", Lit(4))), st)
292
+ exec(fact5, st)
293
+ dummyRead(st)
294
+ println(st)
295
+ }
296
+ }
297
+ println(code.code)
298
+ // code.eval(0)
299
+ }
300
+ }
0 commit comments