From 88b16f3154ce580e7c2882e5c683e23238c08216 Mon Sep 17 00:00:00 2001 From: Arseniy Obolenskiy Date: Thu, 23 Oct 2025 13:45:58 +0200 Subject: [PATCH] Add tests for NumPy language context and fix import path - Fix langctx module import in thunder.numpy - Add basic tests for thunder.numpy module --- thunder/numpy/__init__.py | 2 +- thunder/tests/test_numpy_langctx.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 thunder/tests/test_numpy_langctx.py diff --git a/thunder/numpy/__init__.py b/thunder/numpy/__init__.py index dd7905756d..8e31f957e4 100644 --- a/thunder/numpy/__init__.py +++ b/thunder/numpy/__init__.py @@ -1,7 +1,7 @@ from numbers import Number from collections.abc import Callable -from thunder.core.langctx import langctx, Languages +from thunder.core.langctxs import langctx, Languages from thunder.numpy.langctx import register_method from thunder.core.proxies import TensorProxy diff --git a/thunder/tests/test_numpy_langctx.py b/thunder/tests/test_numpy_langctx.py new file mode 100644 index 0000000000..a44cc9815e --- /dev/null +++ b/thunder/tests/test_numpy_langctx.py @@ -0,0 +1,24 @@ +from thunder.numpy import size as np_size +from thunder.core.langctxs import langctx, Languages, resolve_language +from thunder.core.proxies import TensorProxy +from thunder.core.trace import detached_trace +from thunder.core.devices import cpu +from thunder.core.dtypes import float32 + + +def test_numpy_langctx_registration_and_len_size(): + with detached_trace(): + t = TensorProxy(shape=(2, 3), device=cpu, dtype=float32) + + with langctx(Languages.NUMPY): + assert len(t) == 2 # axis 0 length + assert t.size() == 6 # total elements + assert np_size(t) == 6 + + +def test_numpy_langctx_resolve_language(): + numpy_ctx_by_enum = resolve_language(Languages.NUMPY) + numpy_ctx_by_name = resolve_language("numpy") + + assert numpy_ctx_by_enum is numpy_ctx_by_name + assert numpy_ctx_by_enum.name == "numpy"