From 3664285246133460d9b50ac22e2b9748bae91e5a Mon Sep 17 00:00:00 2001 From: Hui Qin Ng Date: Mon, 15 Sep 2025 12:41:08 -0700 Subject: [PATCH] Improve _get_path_to_function_decl to handle function wrapper with class (#1116) Summary: We added file decorator support in https://github.com/pytorch/torchx/pull/1111 **Problem:** This will fail when the function wrapper with dataclass object **Fix:** Determine if decorators found in function before unwrap. Add two test cases to cover: * comp_f using dataclass in g.py => should return __init__.py * comp_g using decorator in h.py => should return g.py Reviewed By: ethanbwaite Differential Revision: D82346696 --- torchx/specs/finder.py | 24 +++++++++-- torchx/specs/test/components/f/__init__.py | 19 ++++++++ torchx/specs/test/components/f/g.py | 50 ++++++++++++++++++++++ torchx/specs/test/components/f/h.py | 23 ++++++++++ torchx/specs/test/finder_test.py | 10 +++++ 5 files changed, 123 insertions(+), 3 deletions(-) create mode 100644 torchx/specs/test/components/f/__init__.py create mode 100644 torchx/specs/test/components/f/g.py create mode 100644 torchx/specs/test/components/f/h.py diff --git a/torchx/specs/finder.py b/torchx/specs/finder.py index 1e92baf25..b90f2a88f 100644 --- a/torchx/specs/finder.py +++ b/torchx/specs/finder.py @@ -7,6 +7,7 @@ # pyre-strict import abc +import ast import copy import importlib import inspect @@ -278,6 +279,22 @@ def _get_validation_errors( linter_errors = validate(path, function_name, validators) return [linter_error.description for linter_error in linter_errors] + def _get_function_decorators_count( + self, function: Callable[..., Any] # pyre-ignore[2] + ) -> int: + """ + Returns the count of decorators for the given function. + """ + try: + source = inspect.getsource(function) + tree = ast.parse(source) + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + return len(node.decorator_list) + except (OSError, TypeError): + return 0 + return 0 + def _get_path_to_function_decl( self, function: Callable[..., Any] # pyre-ignore[2] ) -> str: @@ -287,9 +304,10 @@ def _get_path_to_function_decl( my_component defined in some_file.py, imported in other_file.py and the component is invoked as other_file.py:my_component """ - # Unwrap decorated functions to get the original function - unwrapped_function = inspect.unwrap(function) - path_to_function_decl = inspect.getabsfile(unwrapped_function) + # unwrap the function if it has decorators + if self._get_function_decorators_count(function) > 0: + function = inspect.unwrap(function) + path_to_function_decl = inspect.getabsfile(function) if path_to_function_decl is None or not os.path.isfile(path_to_function_decl): return self._filepath return path_to_function_decl diff --git a/torchx/specs/test/components/f/__init__.py b/torchx/specs/test/components/f/__init__.py new file mode 100644 index 000000000..01bc90f7d --- /dev/null +++ b/torchx/specs/test/components/f/__init__.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import functools + +from torchx import specs + +from .g import Cls + + +@functools.wraps(Cls) +def comp_f(**kwargs) -> specs.AppDef: # pyre-ignore[2] + return Cls(**kwargs).build() diff --git a/torchx/specs/test/components/f/g.py b/torchx/specs/test/components/f/g.py new file mode 100644 index 000000000..07442b527 --- /dev/null +++ b/torchx/specs/test/components/f/g.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +from dataclasses import dataclass + +import torchx +from torchx import specs + +from .h import fake_decorator + + +@dataclass +class Args: + name: str + + +@dataclass +class Cls(Args): + def build(self) -> specs.AppDef: + return specs.AppDef( + name=self.name, + roles=[ + specs.Role( + name=self.name, + image=torchx.IMAGE, + entrypoint="echo", + args=["hello world"], + ) + ], + ) + + +@fake_decorator +def comp_g() -> specs.AppDef: + return specs.AppDef( + name="g", + roles=[ + specs.Role( + name="g", + image=torchx.IMAGE, + entrypoint="echo", + args=["hello world"], + ) + ], + ) diff --git a/torchx/specs/test/components/f/h.py b/torchx/specs/test/components/f/h.py new file mode 100644 index 000000000..8daaa256f --- /dev/null +++ b/torchx/specs/test/components/f/h.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +import functools +from typing import Any, Callable + + +def fake_decorator( # pyre-ignore[3] + func: Callable[..., Any], # pyre-ignore[2] +) -> Callable[..., Any]: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]: # pyre-ignore[3] + # Fake decorator: just calls the original function + return func(*args, **kwargs) + + return wrapper diff --git a/torchx/specs/test/finder_test.py b/torchx/specs/test/finder_test.py index 18f01b4c5..7ac585886 100644 --- a/torchx/specs/test/finder_test.py +++ b/torchx/specs/test/finder_test.py @@ -30,6 +30,8 @@ ModuleComponentsFinder, ) from torchx.specs.test.components.a import comp_a +from torchx.specs.test.components.f import comp_f +from torchx.specs.test.components.f.g import comp_g from torchx.util.test.entrypoints_test import EntryPoint_from_text from torchx.util.types import none_throws @@ -243,6 +245,14 @@ def test_get_component_imported_from_other_file(self) -> None: component = get_component(f"{current_file_path()}:comp_a") self.assertListEqual([], component.validation_errors) + def test_get_component_from_dataclass(self) -> None: + component = get_component(f"{current_file_path()}:comp_f") + self.assertListEqual([], component.validation_errors) + + def test_get_component_from_decorator(self) -> None: + component = get_component(f"{current_file_path()}:comp_g") + self.assertListEqual([], component.validation_errors) + class GetBuiltinSourceTest(unittest.TestCase): def setUp(self) -> None: