From 1266e547c52bd44aa0f586cf9732858693d863a3 Mon Sep 17 00:00:00 2001
From: Min Guo
Date: Mon, 24 Feb 2025 12:35:34 -0800
Subject: [PATCH 1/2] Add argmin support

---
 .../converters/mil/frontend/torch/ops.py | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 4e53a5765..38c5fb25c 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -6313,6 +6313,40 @@ def _parse_keyword_args(context, node, dim, keepdim) -> Tuple[Var]:
     context.add(res)
 
 
+@register_torch_op
+def argmin(context, node):
+    def _parse_positional_args(context, node) -> Tuple[Var]:
+        inputs = _get_inputs(context, node, expected=(1, 2, 3, 4))
+        nargs = len(inputs)
+
+        x = inputs[0]
+
+        dim = inputs[1] if nargs > 1 else None
+        keepdim = inputs[2] if nargs > 2 else False
+
+        # When node.kind == argmin.out, there can be 1 more arg `Tensor(a!) out`,
+        # which is for in-place mutation, so we ignore it since Core ML is functional.
+        return x, dim, keepdim
+
+    def _parse_keyword_args(context, node, dim, keepdim) -> Tuple[Var]:
+        dim = _get_kwinputs(context, node, "dim", default=[dim])[0]
+        keepdim = _get_kwinputs(context, node, "keepdim", default=[keepdim])[0]
+        return dim, keepdim
+
+    x, dim, keepdim = _parse_positional_args(context, node)
+    dim, keepdim = _parse_keyword_args(context, node, dim, keepdim)
+    if isinstance(dim, Var):
+        dim = dim.val
+    if isinstance(keepdim, Var):
+        keepdim = keepdim.val
+
+    if types.is_int(x.dtype) and x.dtype._width == 64:
+        # MIL reduce_argmin doesn't support int64.
+        x = mb.cast(x=x, dtype="int32")
+    res = mb.reduce_argmin(x=x, axis=dim, keep_dims=keepdim, name=node.name)
+    context.add(res)
+
+
 @register_torch_op(torch_alias=["empty_like"])
 def zeros_like(context, node):
     inputs = _get_inputs(

From f11efcb7d15bdfa69c5f403db711e0bce5454fe4 Mon Sep 17 00:00:00 2001
From: Min Guo
Date: Mon, 24 Feb 2025 16:16:18 -0800
Subject: [PATCH 2/2] Fix aten.pow.Tensor_Tensor lowering issue

---
 coremltools/converters/mil/frontend/torch/utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/coremltools/converters/mil/frontend/torch/utils.py b/coremltools/converters/mil/frontend/torch/utils.py
index ffda8623d..a1efa7b18 100644
--- a/coremltools/converters/mil/frontend/torch/utils.py
+++ b/coremltools/converters/mil/frontend/torch/utils.py
@@ -145,6 +145,7 @@ def skip_default_prefix_and_suffix_with_deliminator(
                 "tensor_mode",
                 "scalar",
                 "tensor_scalar",
+                "tensor_tensor",
             }
             and len(split) - start > 1
             else len(split)
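
A quick way to exercise patch 1/2 end to end (a reviewer sketch, not part of
the series): trace a small model that calls torch.argmin and convert it
through the standard ct.convert path. The module name, shapes, and the
torch.jit.trace route below are illustrative assumptions, not taken from the
patch itself.

    import torch
    import coremltools as ct

    class ArgminModel(torch.nn.Module):
        def forward(self, x):
            # dim/keepdim exercise the positional/keyword parsing in the new op.
            return torch.argmin(x, dim=1, keepdim=True)

    example_input = torch.rand(2, 3, 4)
    traced = torch.jit.trace(ArgminModel().eval(), example_input)

    # The torch frontend should now lower torch.argmin to MIL reduce_argmin.
    mlmodel = ct.convert(
        traced,
        inputs=[ct.TensorType(name="x", shape=example_input.shape)],
    )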
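
Similarly for patch 2/2 (again a sketch with assumed names, and assuming the
torch.export frontend in recent coremltools): tensor ** tensor dispatches to
aten.pow.Tensor_Tensor under torch.export, and with "tensor_tensor" in the
suffix set, the suffix-stripping helper shown in the diff
(skip_default_prefix_and_suffix_with_deliminator) reduces the op kind to the
registered pow converter instead of failing the lookup.

    import torch
    import coremltools as ct

    class PowModel(torch.nn.Module):
        def forward(self, x, y):
            # Tensor ** tensor lowers to aten.pow.Tensor_Tensor.
            return x ** y

    example_inputs = (torch.rand(3, 4), torch.rand(3, 4))
    exported = torch.export.export(PowModel().eval(), example_inputs)

    # Convert the ExportedProgram; previously the unstripped "tensor_tensor"
    # variant suffix left an op kind with no registered converter.
    mlmodel = ct.convert(exported)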