forked from jax-ml/jax
Open
Description
I am running jax/jaxlib 0.4.35 on bare metal, installed from the releases in this repository.
I had to work around the missing libsuitesparseconfig.so.4 by installing it manually.
When running a relatively straightforward script that uses sharding in a single process addressing 4 local GPUs, I get the HIP and rocBLAS errors shown in the output below.
The script and pyproject file can be found at this gist.
export NETKET_ENABLE_X64=1
export NETKET_EXPERIMENTAL_SHARDING=1
export NETKET_MPI=0
[cad14908] fvicentini@a1004:~/rep2$ uv run simple_nocnn.py
Config:
Global configurations for NetKet
- NETKET_DEBUG = False
- NETKET_EXPERIMENTAL = False
- NETKET_MPI_WARNING = True
- NETKET_MPI = False
- NETKET_USE_PLAIN_RHAT = False
- NETKET_EXPERIMENTAL_FFT_AUTOCORRELATION = False
- NETKET_EXPERIMENTAL_DISABLE_ODE_JIT = True
- NETKET_EXPERIMENTAL_SHARDING_CPU = 0
- NETKET_ENABLE_X64 = True
- NETKET_SPHINX_BUILD = False
- NETKET_EXPERIMENTAL_SHARDING = True
- NETKET_MPI_AUTODETECT_LOCAL_GPU = False
- NETKET_RANDOM_STATE_FALLBACK_WARNING = True
- NETKET_EXPERIMENTAL_SHARDING_FAST_SERIALIZATION = False
- NETKET_SPIN_ORDERING_WARNING = True
jax global devices [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
jax local devices [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
0%| | 0/500 [00:00<?, ?it/s]
Hip error: 'operation would make the legacy stream depend on a capturing blocking stream'(906) at /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/hipBLASLt/library/src/amd_detail/hipblaslt.cpp:135
Hip error: 'operation would make the legacy stream depend on a capturing blocking stream'(906) at /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/hipBLASLt/library/src/amd_detail/hipblaslt.cpp:135
rocBLAS error: Could not initialize Tensile host:
_Map_base::at
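To check whether the failure depends on NetKet at all, a smaller test could shard a plain matrix multiplication across the local devices. The following is only a sketch under assumptions, not the script from the gist: the function name `sharded_matmul_check`, the mesh axis name `"batch"`, and the array sizes are all made up for illustration, and it is an untested guess that a sharded matmul goes through the same BLAS path that produces the errors above.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def sharded_matmul_check(rows_per_device=8, features=16):
    """Shard a matmul over all local devices (hypothetical helper)."""
    # One-dimensional mesh over every device visible to this process.
    devices = np.array(jax.local_devices())
    mesh = Mesh(devices, axis_names=("batch",))

    # The leading dimension must be divisible by the number of devices.
    n = devices.size * rows_per_device
    x = jnp.arange(n * features, dtype=jnp.float32).reshape(n, features)
    x = x / (n * features)
    w = jnp.ones((features, features), dtype=jnp.float32)

    # Split x row-wise across devices; replicate w on every device.
    x_sharded = jax.device_put(x, NamedSharding(mesh, P("batch", None)))
    w_replicated = jax.device_put(w, NamedSharding(mesh, P()))

    # The jitted matmul runs on each device's shard of x.
    y = jax.jit(lambda a, b: a @ b)(x_sharded, w_replicated)
    y.block_until_ready()
    return x, w, y


if __name__ == "__main__":
    x, w, y = sharded_matmul_check()
    print("devices:", jax.local_devices())
    print("output shape:", y.shape)
    print("output sharding:", y.sharding)
```

If this small case runs cleanly on all four GPUs, the problem is more likely tied to something specific in the full script; if it fails the same way, it would be a much shorter reproducer.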
System info (python version, jaxlib version, accelerator, etc.)
[cad14908] fvicentini@a1002:~/rep2/lib$ uv run python -c 'import jax; jax.print_environment_info()'
jax: 0.4.33
jaxlib: 0.4.33
numpy: 2.1.3
python: 3.12.1 (main, Sep 18 2024, 23:46:30) [GCC 12.2.1 20221121 (Red Hat 12.2.1-7)]
jax.devices (4 total, 4 local): [RocmDevice(id=0) RocmDevice(id=1) RocmDevice(id=2) RocmDevice(id=3)]
process_count: 1
platform: uname_result(system='Linux', node='a1002', release='5.14.0-427.26.1.el9_4.x86_64', version='#1 SMP PREEMPT_DYNAMIC Fri Jul 5 11:34:54 EDT 2024', machine='x86_64')
============================================= ROCm System Management Interface =============================================
======================================================= Concise Info =======================================================
Device  Node  IDs              Temp        Power     Partitions          SCLK    MCLK     Fan  Perf    PwrCap  VRAM%  GPU%
              (DID,     GUID)  (Junction)  (Socket)  (Mem, Compute, ID)
============================================================================================================================
0       4     0x74a0,   16722  45.0°C      110.0W    NPS1, SPX, 0        102Mhz  1200Mhz  0%   manual  550.0W  0%     0%
1       5     0x74a0,   8346   46.0°C      70.0W     NPS1, SPX, 0        94Mhz   900Mhz   0%   manual  550.0W  0%     0%
2       6     0x74a0,   33475  45.0°C      106.0W    NPS1, SPX, 0        94Mhz   1200Mhz  0%   manual  550.0W  0%     0%
3       7     0x74a0,   25611  47.0°C      72.0W     NPS1, SPX, 0        95Mhz   900Mhz   0%   manual  550.0W  0%     0%
============================================================================================================================
=================================================== End of ROCm SMI Log ====================================================