Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions torchx/specs/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-strict

import abc
import ast
import copy
import importlib
import inspect
Expand Down Expand Up @@ -278,6 +279,22 @@ def _get_validation_errors(
linter_errors = validate(path, function_name, validators)
return [linter_error.description for linter_error in linter_errors]

def _get_function_decorators_count(
self, function: Callable[..., Any] # pyre-ignore[2]
) -> int:
"""
Returns the count of decorators for the given function.
"""
try:
source = inspect.getsource(function)
tree = ast.parse(source)
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
return len(node.decorator_list)
except (OSError, TypeError):
return 0
return 0

def _get_path_to_function_decl(
self, function: Callable[..., Any] # pyre-ignore[2]
) -> str:
Expand All @@ -287,9 +304,10 @@ def _get_path_to_function_decl(
my_component defined in some_file.py, imported in other_file.py
and the component is invoked as other_file.py:my_component
"""
# Unwrap decorated functions to get the original function
unwrapped_function = inspect.unwrap(function)
path_to_function_decl = inspect.getabsfile(unwrapped_function)
# unwrap the function if it has decorators
if self._get_function_decorators_count(function) > 0:
function = inspect.unwrap(function)
path_to_function_decl = inspect.getabsfile(function)
if path_to_function_decl is None or not os.path.isfile(path_to_function_decl):
return self._filepath
return path_to_function_decl
Expand Down
19 changes: 19 additions & 0 deletions torchx/specs/test/components/f/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import functools

from torchx import specs

from .g import Cls


@functools.wraps(Cls)
def comp_f(**kwargs) -> specs.AppDef: # pyre-ignore[2]
return Cls(**kwargs).build()
50 changes: 50 additions & 0 deletions torchx/specs/test/components/f/g.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from dataclasses import dataclass

import torchx
from torchx import specs

from .h import fake_decorator


@dataclass
class Args:
name: str


@dataclass
class Cls(Args):
def build(self) -> specs.AppDef:
return specs.AppDef(
name=self.name,
roles=[
specs.Role(
name=self.name,
image=torchx.IMAGE,
entrypoint="echo",
args=["hello world"],
)
],
)


@fake_decorator
def comp_g() -> specs.AppDef:
return specs.AppDef(
name="g",
roles=[
specs.Role(
name="g",
image=torchx.IMAGE,
entrypoint="echo",
args=["hello world"],
)
],
)
23 changes: 23 additions & 0 deletions torchx/specs/test/components/f/h.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


import functools
from typing import Any, Callable


def fake_decorator( # pyre-ignore[3]
func: Callable[..., Any], # pyre-ignore[2]
) -> Callable[..., Any]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]: # pyre-ignore[3]
# Fake decorator: just calls the original function
return func(*args, **kwargs)

return wrapper
10 changes: 10 additions & 0 deletions torchx/specs/test/finder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
ModuleComponentsFinder,
)
from torchx.specs.test.components.a import comp_a
from torchx.specs.test.components.f import comp_f
from torchx.specs.test.components.f.g import comp_g
from torchx.util.test.entrypoints_test import EntryPoint_from_text
from torchx.util.types import none_throws

Expand Down Expand Up @@ -243,6 +245,14 @@ def test_get_component_imported_from_other_file(self) -> None:
component = get_component(f"{current_file_path()}:comp_a")
self.assertListEqual([], component.validation_errors)

def test_get_component_from_dataclass(self) -> None:
component = get_component(f"{current_file_path()}:comp_f")
self.assertListEqual([], component.validation_errors)

def test_get_component_from_decorator(self) -> None:
component = get_component(f"{current_file_path()}:comp_g")
self.assertListEqual([], component.validation_errors)


class GetBuiltinSourceTest(unittest.TestCase):
def setUp(self) -> None:
Expand Down
Loading