Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions cirq-core/cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,20 @@ def __init__(self, param_dict: cirq.ParamResolverOrSimilarType = None) -> None:

self._param_hash: int | None = None
self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
self._param_dict_with_str_keys = self._param_dict
generate_str_keys = False
for key in self._param_dict:
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
if isinstance(key, sympy.Expr):
if isinstance(key, sympy.Symbol):
generate_str_keys = True
else:
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
if generate_str_keys:
# Remake dictionary with string keys for faster access
self._param_dict_with_str_keys = {
(key.name if isinstance(key, sympy.Symbol) else key): value
for key, value in self._param_dict.items()
}
self._deep_eval_map: ParamDictType = {}

@property
Expand Down Expand Up @@ -119,22 +130,23 @@ def value_of(
"""

# Handle string or symbol
if isinstance(value, (str, sympy.Symbol)):
string = value if isinstance(value, str) else value.name
param_value = self._param_dict.get(string, _NOT_FOUND)
original_value = value
if isinstance(value, sympy.Symbol):
value = value.name
if isinstance(value, str):
param_value = self._param_dict_with_str_keys.get(value, _NOT_FOUND)
if isinstance(param_value, float):
return param_value
if param_value is _NOT_FOUND:
symbol = value if isinstance(value, sympy.Symbol) else sympy.Symbol(value)
param_value = self._param_dict.get(symbol, _NOT_FOUND)
if param_value is _NOT_FOUND:
# Symbol or string cannot be resolved if not in param dict; return as symbol.
return symbol
# Symbol or string cannot be resolved if not in param dict; return as symbol.
return sympy.Symbol(value)
v = _resolve_value(param_value)
if v is not NotImplemented:
return v
if isinstance(param_value, str):
param_value = sympy.Symbol(param_value)
elif not isinstance(param_value, sympy.Basic):
return value
return original_value
if recursive:
param_value = self._value_of_recursive(value)
return param_value
Expand Down Expand Up @@ -210,7 +222,7 @@ def _value_of_recursive(self, value: cirq.TParamKey) -> cirq.TParamValComplex:
self._deep_eval_map[value] = _RECURSION_FLAG

v = self.value_of(value, recursive=False)
if v == value:
if v == value or (isinstance(v, sympy.Symbol) and v.name == value):
self._deep_eval_map[value] = v
else:
self._deep_eval_map[value] = self.value_of(v, recursive=True)
Expand Down Expand Up @@ -278,7 +290,7 @@ def _from_json_dict_(cls, param_dict, **kwargs):


def _resolve_value(val: Any) -> Any:
if val is None or isinstance(val, float):
if isinstance(val, float) or val is None:
return val
if isinstance(val, numbers.Number) and not isinstance(val, sympy.Basic):
return val
Expand Down