
Convolutions with Float64: Unsupported DNN data type: tf.float64 #352

@PhilipVinc

Description

Running the following example crashes the process. The error occurs only when double precision (float64) is used; single precision works fine (see the control snippet after the repro).

import jax
import jax.numpy as jnp
from flax import linen as nn
from jax import random

# Enable float64 globally (required; otherwise JAX silently uses float32)
jax.config.update("jax_enable_x64", True)

# Define a simple CNN model with float64 dtype
class SimpleConv(nn.Module):
    features: int
    kernel_size: tuple = (3, 3)
    
    @nn.compact
    def __call__(self, x):
        conv = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            dtype=jnp.float64,        # computation/output dtype
            param_dtype=jnp.float64   # parameter (weight) dtype
        )
        return conv(x)

# Initialize PRNG
key = random.PRNGKey(0)

# Dummy float64 input (e.g. a 1-channel 8x8 image)
x = jnp.ones((1, 8, 8, 1), dtype=jnp.float64)

# Initialize model
model = SimpleConv(features=4)
params = model.init(key, x)  # initializes weights with float64

# Apply the model
output = model.apply(params, x)

print("Output dtype:", output.dtype)
print("Conv kernel dtype:", params['params']['Conv_0']['kernel'].dtype)

The error is:

F0409 21:57:27.528072 1020082 rocm_dnn.cc:1944] Unsupported DNN data type: tf.float64 (dnn::DataType::kDouble)
*** Check failure stack trace: ***
    @     0x1553e4f4a114  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @     0x1553e4f49ab4  absl::lts_20230802::log_internal::LogMessage::Flush()
    @     0x1553e4f4a579  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x1553e4cfd3f5  stream_executor::gpu::(anonymous namespace)::ToMIOpenDataType()
    @     0x1553e4d05a09  stream_executor::gpu::MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode()
    @     0x1553e4d03977  stream_executor::gpu::MIOpenSupport::GetMIOpenConvolveAlgorithms()
    @     0x1553e4d031b4  stream_executor::gpu::MIOpenSupport::GetConvolveRunners()
    @     0x1553e018980a  xla::gpu::GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm()
    @     0x1553e01889c5  xla::gpu::GpuConvAlgorithmPicker::PickBestAlgorithmNoCache()
    @     0x1553e018df1b  std::_Function_handler<>::_M_invoke()
    @     0x1553e01be229  xla::gpu::AutotunerUtil::Autotune()
    @     0x1553e018bd2b  xla::gpu::GpuConvAlgorithmPicker::RunOnInstruction()
    @     0x1553e018d39e  xla::gpu::GpuConvAlgorithmPicker::RunOnComputation()
    @     0x1553e018d5de  xla::gpu::GpuConvAlgorithmPicker::Run()
    @     0x1553e01c5b6f  xla::HloPassPipeline::RunHelper()
    @     0x1553e01c2dc5  xla::HloPassPipeline::RunPassesInternal<>()
    @     0x1553e01c271a  xla::HloPassPipeline::Run()
    @     0x1553dca5b38e  xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment()
    @     0x1553dca4d8b1  xla::gpu::AMDGPUCompiler::OptimizeHloPostLayoutAssignment()
    @     0x1553dca564e4  xla::gpu::GpuCompiler::OptimizeHloModule()
    @     0x1553dca5f062  xla::gpu::GpuCompiler::RunHloPasses()
    @     0x1553dca3694f  xla::Service::BuildExecutable()
    @     0x1553dc9af067  xla::LocalService::CompileExecutables()
    @     0x1553dc9a9ad4  xla::LocalClient::Compile()
    @     0x1553dc94afdb  xla::PjRtStreamExecutorClient::CompileInternal()
    @     0x1553dc94c2b0  xla::PjRtStreamExecutorClient::Compile()
    @     0x1553dc8b832a  std::__detail::__variant::__gen_vtable_impl<>::__visit_invoke()
    @     0x1553dc8a7cbc  pjrt::PJRT_Client_Compile()
    @     0x1554c8e0f3fd  xla::InitializeArgsAndCompile()
    @     0x1554c8e0fb66  xla::PjRtCApiClient::Compile()
    @     0x1554cfdba31c  xla::ifrt::PjRtLoadedExecutable::Create()
    @     0x1554cfdb55e7  xla::ifrt::PjRtCompiler::Compile()
    @     0x1554cef07afd  xla::PyClient::CompileIfrtProgram()
    @     0x1554cef0890e  xla::PyClient::Compile()
    @     0x1554cef0fbaf  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x1554cfd8d101  nanobind::detail::nb_func_vectorcall_complex()
    @     0x1554e438b84c  nanobind::detail::nb_bound_method_vectorcall()
    @     0x1554e6cd4a39  _PyEval_EvalFrameDefault
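
The trace points at ToMIOpenDataType inside stream_executor's MIOpen support, reached via xla::gpu::GpuConvAlgorithmPicker, so the failure is presumably independent of Flax. A minimal reduction at the jax.lax level (my sketch, untested on ROCm) would be:

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# NCHW input and OIHW kernel, both float64; compiling this convolution
# should hit the same MIOpen algorithm-picking path as the Flax repro.
lhs = jnp.ones((1, 1, 8, 8), dtype=jnp.float64)
rhs = jnp.ones((4, 1, 3, 3), dtype=jnp.float64)
out = jax.lax.conv(lhs, rhs, window_strides=(1, 1), padding="SAME")
print(out.dtype)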

System info (python version, jaxlib version, accelerator, etc.)

I'm running ROCm 6.3.3 with jax[rocm]==0.5.0, installed on bare metal via uv add "jax[rocm]>=0.5" from PyPI (the wheels that were just released).

[cad14908] fvicentini@a1007:~/rep4$ uv tree | grep jax
Resolved 61 packages in 0.81ms
├── jax[rocm] v0.5.0
│   ├── jaxlib v0.5.0
│   ├── jax-rocm60-plugin v0.5.0 (extra: rocm)
│   │   └── jax-rocm60-pjrt v0.5.0
│   └── jaxlib v0.5.0 (extra: rocm) (*)
    │   ├── jax v0.5.0 (*)
    │   ├── jaxtyping v0.3.1
    │   ├── jax v0.5.0 (*)
    │   │   │   ├── jax v0.5.0 (*)
    │   │   │   ├── jaxlib v0.5.0 (*)
    │   │   ├── jax v0.5.0 (*)
    │   │   ├── jaxlib v0.5.0 (*)
    │   │   ├── jax v0.5.0 (*)
    ├── jax v0.5.0 (*)
In [1]: import jax; jax.print_environment_info()
   ...:
jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.4
python: 3.12.8 (main, Dec  6 2024, 19:59:28) [Clang 18.1.8 ]
device info: AMD Instinct MI300A-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='a1007', release='5.14.0-427.26.1.el9_4.x86_64', version='#1 SMP PREEMPT_DYNAMIC Fri Jul 5 11:34:54 EDT 2024', machine='x86_64')
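
Until the ROCm plugin supports double-precision convolutions, a possible workaround is to run just the convolution in float32 and cast at the boundaries. This is only a sketch, assuming the surrounding computation tolerates the precision loss; Float32Conv is a hypothetical wrapper, not something from JAX or Flax:

import jax
import jax.numpy as jnp
from flax import linen as nn

jax.config.update("jax_enable_x64", True)

class Float32Conv(nn.Module):
    features: int
    kernel_size: tuple = (3, 3)

    @nn.compact
    def __call__(self, x):
        # Compute the conv in float32 to avoid the unsupported MIOpen path
        y = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            dtype=jnp.float32,
            param_dtype=jnp.float32,
        )(x.astype(jnp.float32))
        # Cast back so the rest of the network stays in float64
        return y.astype(jnp.float64)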
