
torch issue with AUC metric #21770

@eric

Description


I've run into a problem with the PyTorch backend: training fails with the AUC metric when validation_data includes sample weights. The error occurs during the symbolic build phase.

I've found a workaround, but I don't fully understand the code involved, so I wanted to raise the issue before trying to fix it properly.

Environment

  • Keras version: 3.5.0
  • Backend: torch (PyTorch 2.9.0)
  • Python version: 3.12
  • Operating System: macOS / Linux
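For anyone adding their own data points to this report, the details above can be gathered with a small stdlib-only sketch. collect_environment is a hypothetical helper, not part of Keras; it reads installed package versions without importing them. Once Keras is imported, keras.backend.backend() reports the backend actually in use.

```python
import os
import platform
from importlib import metadata


def collect_environment(packages=("keras", "torch")):
    """Collect the version details listed in the Environment section (hypothetical helper)."""
    info = {
        "python": platform.python_version(),
        "os": f"{platform.system()} {platform.release()}",
        # Keras 3 picks its backend from this variable (or from its config file).
        "KERAS_BACKEND": os.environ.get("KERAS_BACKEND", "unset"),
    }
    for name in packages:
        try:
            # Version of the installed distribution, read without importing it.
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_environment().items():
        print(f"{key}: {value}")
```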

Error Message

RuntimeError: cannot repeat_interleave a meta tensor without output_size
NotImplementedError: Cannot copy out of meta tensor; no data!

The stack trace shows that the error originates from:

File "keras/src/backend/torch/math.py", line 21, in _segment_reduction_fn
    segment_ids.repeat_interleave(num_repeats)

Minimal Reproducible Example

import os
os.environ['KERAS_BACKEND'] = 'torch'

import numpy as np
import keras

# Create simple model
inputs = keras.Input(shape=(10,))
x = keras.layers.Dense(16, activation='relu')(inputs)
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Compile with AUC metric
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[
        keras.metrics.BinaryAccuracy(name='accuracy'),
        keras.metrics.AUC(name='auc'),
    ]
)

# Generate data
X_train = np.random.randn(100, 10).astype(np.float32)
y_train = np.random.randint(0, 2, (100, 1)).astype(np.float32)
sample_weights_train = np.random.rand(100, 1).astype(np.float32)

X_val = np.random.randn(20, 10).astype(np.float32)
y_val = np.random.randint(0, 2, (20, 1)).astype(np.float32)
sample_weights_val = np.random.rand(20, 1).astype(np.float32)

# This will fail with the error above
history = model.fit(
    X_train, y_train,
    epochs=1,
    batch_size=32,
    validation_data=(X_val, y_val, sample_weights_val),
    sample_weight=sample_weights_train
)
