@@ -61,7 +61,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
6161 S <: Shape
6262 ](
6363 inputs : Tuple ,
64- input_node_names : IO [ List [String ] ],
64+ input_node_names : List [String ],
6565 opName : String ,
6666 attrs : Map [String , Any ]
6767 )(using
@@ -82,7 +82,7 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
8282 */
8383
8484 // TODO: more outputs
85- val output_node_names = input_node_names.map(x => { List (x .toString) } )
85+ val output_node_names = List ( input_node_names .toString)
8686
8787 // Spurious warning here, see: https://github.com/lampepfl/dotty/issues/10318
8888 // TODO: don't mix up Options and Tensors here
@@ -95,30 +95,46 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
9595 case opt : Option [Tensor [T , Tuple3 [Tt , Td , S ]]] =>
9696 opt match {
9797 case Some (x) =>
98- Some (x.data.flatMap { y =>
99- x.shape.map { z =>
100- getOnnxTensor(y, z)
98+ Some (x.map { y =>
99+ getOnnxTensor(y._1, y._2._3.toSeq.toArray)
101100 }
102- } )
101+ )
103102 case None => None
104103 }
105104 case tens : Tensor [T , Tuple3 [Tt , Td , S ]] =>
106- Some (tens.data.flatMap { x =>
107- tens.shape.map { y =>
108- getOnnxTensor(x, y)
105+ Some (tens.map { x =>
106+ getOnnxTensor(x._1, x._2._3.toSeq.toArray)
109107 }
110- } )
108+ )
111109 }
112110 }
113111 .toList
114112 .sequence
115113 .map(_.toArray)
116114 }
117115
118- val opModel = for {
116+ def res (opModelBytes : Array [Byte ], inputTensorss : IO [Array [OnnxTensor [T ]]]) : Tensor [T , Tuple3 [Tt , Td , S ]] = {
117+ cats.effect.Resource .make(inputTensorss)(inTens => IO {}).
118+ use(x =>
119+ cats.effect.Resource
120+ .make(IO .blocking(getSession(opModelBytes)))(sess => IO {})
121+ .use(sess =>
122+ runModel(
123+ sess,
124+ x,
125+ input_node_names,
126+ output_node_names
127+ )
128+ )
129+ // }
130+ )
131+
132+ }
133+
134+ val finalRes = for {
119135 tens <- inputTensors.memoize
120136 t <- tens
121- } yield opToModelProto(
137+ } yield res( opToModelProto(
122138 opName,
123139 (t.map(x => x.asInstanceOf [tensorMod.Tensor ].`type`.valueOf.toString match {
124140 // Can't access the enum int values here
@@ -135,25 +151,11 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
135151 )
136152 zip t.map(_.dims.map(_.toInt).toArray)),
137153 attrs
138- ).toByteArray
139-
140- val res : Tensor [T , Tuple3 [Tt , Td , S ]] = {
141- inputTensors.flatMap { x =>
142- cats.effect.Resource
143- .make(opModel.map(getSession(_)))(sess => IO {})
144- .use(sess =>
145- runModel(
146- sess,
147- x,
148- input_node_names,
149- output_node_names
150- )
151- )
152- // }
153- }
154+ ).toByteArray,
155+ tens
156+ )
154157
155- }
156- res.flatMap(IO .println(" opNAme = " + opName).as(_))
158+ finalRes.flatten
157159 // res
158160 }
159161
@@ -172,11 +174,11 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
172174 val result : Tensor [T , Tuple3 [Tt , Td , S ]] =
173175 callByteArrayOp(
174176 inputs,
175- IO { inputNodeNames} ,
177+ inputNodeNames,
176178 opName,
177179 attrs
178180 )
179- result
181+ result // .flatMap(x => IO.println("opName = " + opName).as(x))
180182 }
181183
182184 def runModel [
@@ -189,34 +191,34 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
189191 org.emergentorder.onnx.onnxruntimeCommon.inferenceSessionMod.InferenceSession
190192 ],
191193 input_tensor_values : Array [OnnxTensor [T ]],
192- inputNames : IO [ List [String ] ],
193- outputNames : IO [ List [String ] ]
194+ inputNames : List [String ],
195+ outputNames : List [String ]
194196 )(using
195197 tt : ValueOf [Tt ],
196198 td : TensorShapeDenotationOf [Td ],
197199 s : ShapeOf [S ]
198200 ): Tensor [T , Tuple3 [Tt , Td , S ]] = {
199201
200- val feeds : IO [ js.Dictionary [OnnxTensor [T ]]] = inputNames.map(x => {
201- val zipped = x .toArray zip input_tensor_values
202+ val feeds : js.Dictionary [OnnxTensor [T ]] = {
203+ val zipped = inputNames .toArray zip input_tensor_values
202204 js.Dictionary (zipped.map(z => z._1 -> z._2): _* )
203- })
205+ }
204206
205207 val output_tensors : IO [org.emergentorder.onnx.onnxruntimeCommon.tensorMod.Tensor ] =
206208 IO .fromFuture {
207209 sess
208210 .flatMap { realSess =>
209- feeds.flatMap { realFeeds =>
211+ // feeds.flatMap { realFeeds =>
210212 val res = IO .eval(cats.Eval .later {
211213 realSess
212214 .run(
213- realFeeds .asInstanceOf [
215+ feeds .asInstanceOf [
214216 org.emergentorder.onnx.onnxruntimeCommon.inferenceSessionMod.InferenceSession .FeedsType
215217 ]
216218 )
217219 .toFuture
218220 })
219- outputNames.flatMap { names =>
221+ // outputNames.flatMap { names =>
220222 res.map { result =>
221223 result.map { rr =>
222224// println(realSess.outputNames.toList)
@@ -228,8 +230,8 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
228230 .get
229231 }
230232 }
231- }
232- }
233+ // }
234+ // }
233235 }
234236 }
235237
@@ -359,8 +361,8 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
359361 ](
360362 session,
361363 inputs,
362- IO .pure { List (" data_0" ) } ,
363- IO .pure { List (" squeezenet0_flatten0_reshape0" ) }
364+ List (" data_0" ) ,
365+ List (" squeezenet0_flatten0_reshape0" )
364366 )
365367
366368 // res.foreach(tens => tens.data.foreach(println))
0 commit comments