File tree Expand file tree Collapse file tree 3 files changed +32
-7
lines changed
Expand file tree Collapse file tree 3 files changed +32
-7
lines changed Original file line number Diff line number Diff line change @@ -120,7 +120,19 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter {
120120 t <- tens
121121 } yield opToModelProto(
122122 opName,
123- (t.map(_.asInstanceOf [tensorMod.Tensor ].`type`.valueOf.asInstanceOf [Float ].round)
123+ (t.map(x => x.asInstanceOf [tensorMod.Tensor ].`type`.valueOf.toString match {
124+ // Can't access the enum int values here
125+ // But it's fine, doesn't match the ONNX spec anyway
126+ case " int8" => 3
127+ case " int16" => 5
128+ case " float64" => 11
129+ case " float32" => 1
130+ case " int32" => 6
131+ case " int64" => 7
132+ case " bool" => 9
133+ case y => y.toInt
134+ }
135+ )
124136 zip t.map(_.dims.map(_.toInt).toArray)),
125137 attrs
126138 ).toByteArray
Original file line number Diff line number Diff line change @@ -155,7 +155,20 @@ trait ORTOperatorBackend extends OpToONNXBytesConverter with AutoCloseable {
155155 } yield res(
156156 opToModelProto(
157157 opName,
158- (t.map(_.getInfo.onnxType.value) zip { t.map(_.getInfo.getShape.map(_.toInt) match {
158+ (t.map(_.getInfo.onnxType.value match {
159+ // ORT has two different enums for this for the Java and C APIs
160+ // Neither matches the ONNX spec
161+ case 2 => 3
162+ case 4 => 5
163+ case 10 => 1
164+ case 8 => 7
165+ case 13 => 9
166+ case n => n
167+ }
168+ )
169+
170+ zip
171+ { t.map(_.getInfo.getShape.map(_.toInt) match {
159172 // ORT shape inference diverges from the ONNX spec in requiring a scalar here instead of a tensor with shape,
160173 // causing a crash without this fix
161174 case Array (1 ) => if (opName.equals(" Dropout" )) Array [Int ]() else Array (1 )
Original file line number Diff line number Diff line change @@ -181,13 +181,13 @@ trait OpToONNXBytesConverter {
181181 {
182182
183183 val elemType = elemTypeIn match {
184- case 2 => INT8 .index
185- case 4 => INT16 .index
184+ case 3 => INT8 .index
185+ case 5 => INT16 .index
186186 case 11 => DOUBLE .index
187- case 10 => FLOAT .index
187+ case 1 => FLOAT .index
188188 case 6 => INT32 .index
189- case 8 => INT64 .index
190- case 13 => BOOL .index
189+ case 7 => INT64 .index
190+ case 9 => BOOL .index
191191 case _ => INT64 .index // In case of Scala.js BigInt
192192 }
193193// tens.shape.map { y =>
You can’t perform that action at this time.
0 commit comments