
Jitted Function Save/Load #229

@xiazhuo

Description


Issue Description

I am having trouble saving and loading JIT-compiled functions with tensorcircuit.keras.save_func() and tensorcircuit.keras.load_func(). Specifically, I am trying to save and load the qpred function (or the qlayer built from it) in my hybrid model, and loading fails with the following error:

  File "/home/.miniconda3/envs/qml/lib/python3.10/site-packages/tensorcircuit/keras.py", line 284, in wrapper  *
    return m.f(*args, **kws)
AttributeError: '_UserObject' object has no attribute 'f'

Here is the code that I am working with:

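import tensorflow as tf
import tensorcircuit as tc
import torch

# circuit(inputs, weights, trunk_size) is my own circuit builder (not shown here);
# it returns a tc.Circuit.


class HybridModel(torch.nn.Module):
    def __init__(self, trunk_size, n_layers=2, n_hidden_layers=4, n_wires=2):
        super().__init__()
        K = tc.set_backend("tensorflow")
        tf_device = "/gpu"

        @tf.function
        def qpred(inputs, weights):
            with tf.device(tf_device):
                c = circuit(inputs, weights, trunk_size)
                # <Z_i> on every wire, stacked into a vector of length n_wires
                observables = K.stack([K.real(c.expectation_ps(z=[i]))
                                       for i in range(n_wires)])
                return observables

        self.qpred = qpred
        # Wrap the TensorFlow function as a torch module with trainable weights
        self.qlayer = tc.TorchLayer(
            self.qpred,
            weights_shape=[2 * n_layers, n_hidden_layers, n_wires, 2],
            use_jit=True,
            enable_dlpack=True,
        )
        self.clayer = torch.nn.Linear(n_wires, 1)

    def forward(self, inputs):
        outputs = self.qlayer(inputs)  # one row of n_wires expectations per sample
        outputs = torch.mean(outputs, dim=1)
        return outputs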

What I have tried:

  • I have tried saving qpred and qlayer with tensorcircuit.keras.save_func() and loading them back with tensorcircuit.keras.load_func(), but loading fails with the error above.
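For reference, the m.f(*args, **kws) line in the traceback suggests that load_func reloads a tf.Module and calls its f attribute, so the error means nothing was stored under f at save time. Below is a minimal sketch of that pattern in plain TensorFlow. It assumes (not confirmed from the TensorCircuit source) that save_func attaches the function to a tf.Module as f and exports it with tf.saved_model.save. SavedModel only serializes the traces (concrete functions) a tf.function already has; here a fixed input_signature guarantees one exists. The names toy_pred, save_tf_function and load_tf_function are hypothetical, not TensorCircuit APIs.

```python
import tempfile

import tensorflow as tf


# Hypothetical stand-in for qpred: any tf.function taking (inputs, weights).
# The fixed input_signature means a concrete graph exists before saving,
# so SavedModel has something to serialize under the attribute name "f".
@tf.function(
    input_signature=[
        tf.TensorSpec(shape=[None, 2], dtype=tf.float32),  # batch of inputs
        tf.TensorSpec(shape=[3], dtype=tf.float32),  # flat weight vector
    ]
)
def toy_pred(inputs, weights):
    return tf.reduce_sum(inputs, axis=1) * tf.reduce_sum(weights)


def save_tf_function(f, path):
    """Attach f to a tf.Module and export it (assumed to mirror save_func)."""
    module = tf.Module()
    module.f = f
    tf.saved_model.save(module, path)


def load_tf_function(path):
    """Reload the module; .f exists only if a traced function was saved."""
    return tf.saved_model.load(path).f


export_dir = tempfile.mkdtemp()
save_tf_function(toy_pred, export_dir)
restored = load_tf_function(export_dir)
result = restored(tf.ones([4, 2]), tf.ones([3]))
print(result.numpy())  # each row: (1 + 1) * 3 = 6 -> [6. 6. 6. 6.]
```

If the same pattern is applied to a torch.nn.Module such as qlayer, TensorFlow presumably cannot track it, which may be another way to end up with no f attribute after loading.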

I am wondering if there is a different approach to saving/loading the JIT-compiled function, or if there is a potential issue with the way TensorCircuit handles saved functions in this context.

Could you provide guidance or suggest an alternative way to save and load the function, ideally one that keeps JIT compilation working?
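One workaround that avoids serializing the compiled function altogether is to save only the trained parameters with torch.save(model.state_dict(), ...) and rebuild HybridModel at load time, letting the JIT compile again on the first forward call. This sketch assumes that tc.TorchLayer registers its weights as ordinary torch parameters, so they appear in the parent model's state_dict. ToyHybrid is a hypothetical stand-in that keeps only the parameter bookkeeping.

```python
import io

import torch


class ToyHybrid(torch.nn.Module):
    """Hypothetical stand-in for HybridModel: only the trainable tensors matter here."""

    def __init__(self, n_wires=2):
        super().__init__()
        # Plays the role of the quantum weights a TorchLayer would hold (assumption).
        self.q_weights = torch.nn.Parameter(torch.randn(4, n_wires, 2))
        self.clayer = torch.nn.Linear(n_wires, 1)

    def forward(self, x):
        return self.clayer(x)


model = ToyHybrid()
buffer = io.BytesIO()  # a file path such as "hybrid.pt" works the same way
torch.save(model.state_dict(), buffer)

buffer.seek(0)
# Rebuilding re-creates the layers; for the real model this is where the
# jitted function would be traced again rather than loaded from disk.
rebuilt = ToyHybrid()
rebuilt.load_state_dict(torch.load(buffer))
```

The trade-off is that the first call after loading pays the tracing/compilation cost again, but nothing backend-specific has to survive serialization.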

Thank you very much for your time and assistance. I appreciate any help you can provide!

Environment Context

OS info: Linux-5.4.0-150-generic-x86_64-with-glibc2.27
Python version: 3.10.14
Numpy version: 1.26.4
Scipy version: 1.12.0
Pandas version: 2.2.2
TensorNetwork version: 0.5.1
Cotengra version: 0.6.2
TensorFlow version: 2.18.0
TensorFlow GPU: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:3', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:4', device_type='GPU')]
TensorFlow CUDA infos: {'cpu_compiler': '/usr/lib/llvm-18/bin/clang', 'cuda_compute_capabilities': ['sm_60', 'sm_70', 'sm_80', 'sm_89', 'compute_90'], 'cuda_version': '12.5.1', 'cudnn_version': '9', 'is_cuda_build': True, 'is_rocm_build': False, 'is_tensorrt_build': False}
Jax version: 0.4.23
Jax GPU: [cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3), cuda(id=4)]
JaxLib version: 0.4.23
PyTorch version: 2.5.1+cu124
PyTorch GPU support: True
PyTorch GPUs: [<torch.cuda.device object at 0x7fec328bd8d0>, <torch.cuda.device object at 0x7fec328bd900>, <torch.cuda.device object at 0x7fec328bd8a0>, <torch.cuda.device object at 0x7fec328bdc30>, <torch.cuda.device object at 0x7fec328bdc90>]
Pytorch cuda version: 12.4
Cupy is not installed
Qiskit version: 1.3.1
Cirq version: 1.4.1
TensorCircuit version: 0.12.0

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working)
