forked from jax-ml/jax
Open
Description
Running the following example crashes the process with a fatal error.
The crash occurs only in double precision (float64); single precision (float32) works fine.
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core import freeze, unfreeze
from jax import random
import numpy as np

# Enable float64 globally (required: JAX otherwise truncates float64 arrays to float32)
jax.config.update("jax_enable_x64", True)

# Define a simple CNN model with float64 dtype
class SimpleConv(nn.Module):
    features: int
    kernel_size: tuple = (3, 3)

    @nn.compact
    def __call__(self, x):
        conv = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            dtype=jnp.float64,       # computation (and output) dtype
            param_dtype=jnp.float64  # dtype of the initialized parameters
        )
        return conv(x)

# Initialize PRNG
key = random.PRNGKey(0)

# Dummy float64 input (e.g. a 1-channel 8x8 image)
x = jnp.ones((1, 8, 8, 1), dtype=jnp.float64)

# Initialize model
model = SimpleConv(features=4)
params = model.init(key, x)  # initializes weights with float64

# Apply the model
output = model.apply(params, x)
print("Output dtype:", output.dtype)
print("Conv kernel dtype:", params['params']['Conv_0']['kernel'].dtype)
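To help separate Flax from the underlying problem, here is a minimal sketch of the same convolution written directly with jax.lax. The helper name conv_nhwc and the all-ones kernel are illustrative assumptions, not part of the original report. The sketch only lowers the function to StableHLO text and does not compile or run it, so it should be safe to execute on any backend; it simply confirms that a float64 convolution is what gets handed to the compiler.

```python
import jax

# Enable float64 before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import lax


def conv_nhwc(x, kernel):
    # Hypothetical helper: stride-1, SAME-padded 2D convolution in NHWC layout,
    # mirroring the layout that flax.linen.Conv uses.
    return lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


x = jnp.ones((1, 8, 8, 1), dtype=jnp.float64)
kernel = jnp.ones((3, 3, 1, 4), dtype=jnp.float64)

# Lowering only traces the function and emits StableHLO; it does not invoke
# the backend compiler, so no convolution algorithm is selected here.
lowered = jax.jit(conv_nhwc).lower(x, kernel)
hlo_text = lowered.as_text()

print("convolution present:", "convolution" in hlo_text)
print("float64 operands present:", "f64" in hlo_text)
```

If calling lowered.compile() on the affected machine produces the same fatal error, that would suggest the problem is independent of Flax; this has not been verified here.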
The error is:
F0409 21:57:27.528072 1020082 rocm_dnn.cc:1944] Unsupported DNN data type: tf.float64 (dnn::DataType::kDouble)
*** Check failure stack trace: ***
@ 0x1553e4f4a114 absl::lts_20230802::log_internal::LogMessage::SendToLog()
@ 0x1553e4f49ab4 absl::lts_20230802::log_internal::LogMessage::Flush()
@ 0x1553e4f4a579 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x1553e4cfd3f5 stream_executor::gpu::(anonymous namespace)::ToMIOpenDataType()
@ 0x1553e4d05a09 stream_executor::gpu::MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode()
@ 0x1553e4d03977 stream_executor::gpu::MIOpenSupport::GetMIOpenConvolveAlgorithms()
@ 0x1553e4d031b4 stream_executor::gpu::MIOpenSupport::GetConvolveRunners()
@ 0x1553e018980a xla::gpu::GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm()
@ 0x1553e01889c5 xla::gpu::GpuConvAlgorithmPicker::PickBestAlgorithmNoCache()
@ 0x1553e018df1b std::_Function_handler<>::_M_invoke()
@ 0x1553e01be229 xla::gpu::AutotunerUtil::Autotune()
@ 0x1553e018bd2b xla::gpu::GpuConvAlgorithmPicker::RunOnInstruction()
@ 0x1553e018d39e xla::gpu::GpuConvAlgorithmPicker::RunOnComputation()
@ 0x1553e018d5de xla::gpu::GpuConvAlgorithmPicker::Run()
@ 0x1553e01c5b6f xla::HloPassPipeline::RunHelper()
@ 0x1553e01c2dc5 xla::HloPassPipeline::RunPassesInternal<>()
@ 0x1553e01c271a xla::HloPassPipeline::Run()
@ 0x1553dca5b38e xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment()
@ 0x1553dca4d8b1 xla::gpu::AMDGPUCompiler::OptimizeHloPostLayoutAssignment()
@ 0x1553dca564e4 xla::gpu::GpuCompiler::OptimizeHloModule()
@ 0x1553dca5f062 xla::gpu::GpuCompiler::RunHloPasses()
@ 0x1553dca3694f xla::Service::BuildExecutable()
@ 0x1553dc9af067 xla::LocalService::CompileExecutables()
@ 0x1553dc9a9ad4 xla::LocalClient::Compile()
@ 0x1553dc94afdb xla::PjRtStreamExecutorClient::CompileInternal()
@ 0x1553dc94c2b0 xla::PjRtStreamExecutorClient::Compile()
@ 0x1553dc8b832a std::__detail::__variant::__gen_vtable_impl<>::__visit_invoke()
@ 0x1553dc8a7cbc pjrt::PJRT_Client_Compile()
@ 0x1554c8e0f3fd xla::InitializeArgsAndCompile()
@ 0x1554c8e0fb66 xla::PjRtCApiClient::Compile()
@ 0x1554cfdba31c xla::ifrt::PjRtLoadedExecutable::Create()
@ 0x1554cfdb55e7 xla::ifrt::PjRtCompiler::Compile()
@ 0x1554cef07afd xla::PyClient::CompileIfrtProgram()
@ 0x1554cef0890e xla::PyClient::Compile()
@ 0x1554cef0fbaf nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
@ 0x1554cfd8d101 nanobind::detail::nb_func_vectorcall_complex()
@ 0x1554e438b84c nanobind::detail::nb_bound_method_vectorcall()
@ 0x1554e6cd4a39 _PyEval_EvalFrameDefault
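The trace above shows the failure inside the convolution algorithm picker, when the data type is converted for MIOpen. One possible workaround, sketched below, is to avoid emitting a convolution at all and compute the same result from shifted slices and an einsum. The function name conv2d_same_via_einsum is hypothetical, the sketch handles only stride 1, SAME padding, and odd kernel sizes, and it assumes the backend lowers the einsum to a matrix multiplication and not a convolution. It has not been tested on ROCm and will likely be slower than a native convolution.

```python
import jax

# Enable float64 before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def conv2d_same_via_einsum(x, kernel):
    """Stride-1, SAME-padded 2D cross-correlation without a convolution op.

    x:      (N, H, W, C_in)
    kernel: (kh, kw, C_in, C_out), with odd kh and kw
    """
    kh, kw, _, _ = kernel.shape
    _, h, w, _ = x.shape
    ph, pw = kh // 2, kw // 2

    # Zero-pad the spatial dimensions so every output pixel has a full window.
    xp = jnp.pad(x, ((0, 0), (ph, ph), (pw, pw), (0, 0)))

    # Stack the kh * kw shifted views: shape (kh, kw, N, H, W, C_in).
    patches = jnp.stack(
        [
            jnp.stack([xp[:, i:i + h, j:j + w, :] for j in range(kw)], axis=0)
            for i in range(kh)
        ],
        axis=0,
    )

    # Contract over kernel offsets and input channels.
    return jnp.einsum("ijnhwc,ijco->nhwo", patches, kernel)


x = jnp.ones((1, 8, 8, 1), dtype=jnp.float64)
kernel = jnp.ones((3, 3, 1, 4), dtype=jnp.float64)
out = conv2d_same_via_einsum(x, kernel)

print(out.shape, out.dtype)
```

With an all-ones input and kernel, each output value equals the number of in-bounds pixels under the 3x3 window, which gives a quick sanity check: 9 in the interior, 6 on an edge, and 4 in a corner.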
System info (python version, jaxlib version, accelerator, etc.)
I'm running ROCm 6.3.3 with jax[rocm]==0.5.0,
installed on bare metal with uv add "jax[rocm]>=0.5"
from PyPI (the wheels that were just released).
[cad14908] fvicentini@a1007:~/rep4$ uv tree | grep jax
Resolved 61 packages in 0.81ms
├── jax[rocm] v0.5.0
│ ├── jaxlib v0.5.0
│ ├── jax-rocm60-plugin v0.5.0 (extra: rocm)
│ │ └── jax-rocm60-pjrt v0.5.0
│ └── jaxlib v0.5.0 (extra: rocm) (*)
│ ├── jax v0.5.0 (*)
│ ├── jaxtyping v0.3.1
│ ├── jax v0.5.0 (*)
│ │ │ ├── jax v0.5.0 (*)
│ │ │ ├── jaxlib v0.5.0 (*)
│ │ ├── jax v0.5.0 (*)
│ │ ├── jaxlib v0.5.0 (*)
│ │ ├── jax v0.5.0 (*)
├── jax v0.5.0 (*)
In [1]: import jax; jax.print_environment_info()
...:
jax: 0.5.0
jaxlib: 0.5.0
numpy: 2.2.4
python: 3.12.8 (main, Dec 6 2024, 19:59:28) [Clang 18.1.8 ]
device info: AMD Instinct MI300A-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='a1007', release='5.14.0-427.26.1.el9_4.x86_64', version='#1 SMP PREEMPT_DYNAMIC Fri Jul 5 11:34:54 EDT 2024', machine='x86_64')