Issue Description
I am facing an issue when trying to save and load JIT-compiled functions using `tensorcircuit.keras.save_func()` and `tensorcircuit.keras.load_func()`. Specifically, I am trying to save and load the `qpred` function (or the `qlayer` layer built from it) in my hybrid model, but I get the following error when loading it:
```
File "/home/.miniconda3/envs/qml/lib/python3.10/site-packages/tensorcircuit/keras.py", line 284, in wrapper *
    return m.f(*args, **kws)
AttributeError: '_UserObject' object has no attribute 'f'
```
Here is the code that I am working with:
```python
import tensorcircuit as tc
import tensorflow as tf
import torch


class HybridModel(torch.nn.Module):
    def __init__(self, trunk_size, n_layers=2, n_hidden_layers=4, n_wires=2):
        super().__init__()
        K = tc.set_backend("tensorflow")
        tf_device = "/gpu"

        @tf.function
        def qpred(inputs, weights):
            with tf.device(tf_device):
                # `circuit` is a helper defined elsewhere (not shown here)
                c = circuit(inputs, weights, trunk_size)
                # <Z_i> expectation value on each wire
                observables = K.stack(
                    [K.real(c.expectation_ps(z=[i])) for i in range(n_wires)]
                )
            return observables

        self.qpred = qpred
        self.qlayer = tc.TorchLayer(
            self.qpred,
            weights_shape=[2 * n_layers, n_hidden_layers, n_wires, 2],
            use_jit=True,
            enable_dlpack=True,
        )
        # Defined here but not used in forward() below
        self.clayer = torch.nn.Linear(n_wires, 1)

    def forward(self, inputs):
        outputs = self.qlayer(inputs)
        outputs = torch.mean(outputs, dim=1)
        return outputs
```
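For reference, a standard PyTorch pattern that avoids serializing compiled functions at all is to save only the module's parameters with `state_dict()` and rebuild the module on load, which re-runs `__init__` and therefore re-creates (and re-traces) `qpred`. The sketch below uses a hypothetical pure-PyTorch `TinyModel` as a stand-in for `HybridModel`; it assumes that `tc.TorchLayer` registers its weights as ordinary torch parameters and has not been tested with TensorCircuit.

```python
import io

import torch


class TinyModel(torch.nn.Module):
    # Hypothetical stand-in for HybridModel: one trainable weight tensor
    # (playing the role of qlayer's weights) plus a classical head.
    def __init__(self, n_wires=2):
        super().__init__()
        self.q_weights = torch.nn.Parameter(torch.randn(n_wires))
        self.clayer = torch.nn.Linear(n_wires, 1)

    def forward(self, inputs):
        return self.clayer(torch.cos(inputs * self.q_weights))


# Save only the parameters, not any compiled function objects.
model = TinyModel()
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Later: rebuild the module (in the real model this would re-create qpred)
# and copy the saved parameters into it.
buffer.seek(0)
restored = TinyModel()
restored.load_state_dict(torch.load(buffer, weights_only=True))

x = torch.ones(3, 2)
print(torch.allclose(model(x), restored(x)))
```

The trade-off is that loading requires the model's source code to be importable, but nothing TensorFlow-specific has to survive serialization.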
What I have tried:
- I have used `tensorcircuit.keras.save_func()` and `tensorcircuit.keras.load_func()` to save and load `qpred` or `qlayer`, but this results in the error above.
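For comparison, the sketch below saves and restores a plain `tf.function` through `tf.Module` and `tf.saved_model`, the TensorFlow mechanism that `save_func`/`load_func` appear to rely on (the traceback refers to a loaded `_UserObject` and to `m.f`). `toy_pred` and `save_and_load` are hypothetical stand-ins, not TensorCircuit code. Giving `toy_pred` an explicit `input_signature` is an assumption made so that a concrete function is traced before export, since a `tf.function` that has never been traced has no graph for SavedModel to store.

```python
import tempfile

import tensorflow as tf


# Hypothetical stand-in for qpred: a tf.function of (inputs, weights).
# The explicit input_signature forces tracing of a concrete function,
# which is what tf.saved_model serializes.
@tf.function(
    input_signature=[
        tf.TensorSpec(shape=[None, 2], dtype=tf.float32),
        tf.TensorSpec(shape=[2], dtype=tf.float32),
    ]
)
def toy_pred(inputs, weights):
    # Cosine of the weighted inputs, summed per sample.
    return tf.reduce_sum(tf.cos(inputs * weights), axis=1)


def save_and_load(fn, path):
    # Attach the function to a trackable module so it is exported as `m.f`.
    m = tf.Module()
    m.f = fn
    tf.saved_model.save(m, path)
    # The restored object exposes the saved function under the same name.
    restored = tf.saved_model.load(path)
    return restored.f


export_dir = tempfile.mkdtemp()
restored_pred = save_and_load(toy_pred, export_dir)

x = tf.constant([[0.0, 1.0], [2.0, 3.0]])
w = tf.constant([1.0, 0.5])
print(restored_pred(x, w).numpy())
```

This only illustrates the round trip for a simple traced function; whether the same approach works for a closure like `qpred` inside a `TorchLayer` is exactly the open question here.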
I am wondering if there is a different approach to saving/loading the JIT-compiled function, or if there is a potential issue with the way TensorCircuit handles saved functions in this context.
Would you be able to provide guidance or suggest an alternative solution for saving/loading the function, especially one that involves JIT compilation?
Thank you very much for your time and assistance. I appreciate any help you can provide!
Environment Context
OS info: Linux-5.4.0-150-generic-x86_64-with-glibc2.27
Python version: 3.10.14
Numpy version: 1.26.4
Scipy version: 1.12.0
Pandas version: 2.2.2
TensorNetwork version: 0.5.1
Cotengra version: 0.6.2
TensorFlow version: 2.18.0
TensorFlow GPU: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:3', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:4', device_type='GPU')]
TensorFlow CUDA infos: {'cpu_compiler': '/usr/lib/llvm-18/bin/clang', 'cuda_compute_capabilities': ['sm_60', 'sm_70', 'sm_80', 'sm_89', 'compute_90'], 'cuda_version': '12.5.1', 'cudnn_version': '9', 'is_cuda_build': True, 'is_rocm_build': False, 'is_tensorrt_build': False}
Jax version: 0.4.23
Jax GPU: [cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3), cuda(id=4)]
JaxLib version: 0.4.23
PyTorch version: 2.5.1+cu124
PyTorch GPU support: True
PyTorch GPUs: [<torch.cuda.device object at 0x7fec328bd8d0>, <torch.cuda.device object at 0x7fec328bd900>, <torch.cuda.device object at 0x7fec328bd8a0>, <torch.cuda.device object at 0x7fec328bdc30>, <torch.cuda.device object at 0x7fec328bdc90>]
Pytorch cuda version: 12.4
Cupy is not installed
Qiskit version: 1.3.1
Cirq version: 1.4.1
TensorCircuit version 0.12.0