Skip to content

Commit 0c589ec

Browse files
authored
Merge branch 'master' into update_docker_readme
2 parents f71d433 + 9f3a2ee commit 0c589ec

File tree

2 files changed

+113
-14
lines changed

2 files changed

+113
-14
lines changed

src/openfermion/ops/operators/symbolic_operator.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class SymbolicOperator(metaclass=abc.ABCMeta):
6666

6767
@staticmethod
6868
def _issmall(val, tol=EQ_TOLERANCE):
69-
'''Checks whether a value is near-zero
69+
'''Checks whether a value is near zero.
7070
7171
Parses the allowed coefficients above for near-zero tests.
7272
@@ -618,34 +618,64 @@ def __next__(self):
618618
term, coefficient = next(self._iter)
619619
return self.__class__(term=term, coefficient=coefficient)
620620

621-
def isclose(self, other, tol=EQ_TOLERANCE):
621+
def isclose(self, other, tol=None, rtol=EQ_TOLERANCE, atol=EQ_TOLERANCE):
622622
"""Check if other (SymbolicOperator) is close to self.
623623
624624
Comparison is done for each term individually. Return True
625625
if the difference between each term in self and other is
626-
less than EQ_TOLERANCE
626+
less than the specified tolerance.
627627
628628
Args:
629629
other(SymbolicOperator): SymbolicOperator to compare against.
630+
tol(float): This parameter is deprecated since version 1.8.0.
631+
Use `rtol` and/or `atol` instead. If `tol` is provided, it
632+
is used as the value of `atol`.
633+
rtol(float): Relative tolerance used in comparing each term in
634+
self and other.
635+
atol(float): Absolute tolerance used in comparing each term in
636+
self and other.
630637
"""
631638
if not isinstance(self, type(other)):
632639
return NotImplemented
633640

641+
if tol is not None:
642+
if rtol != EQ_TOLERANCE or atol != EQ_TOLERANCE:
643+
raise ValueError(
644+
'Parameters rtol and atol are mutually exclusive with the'
645+
' deprecated parameter tol; use either tol or the other two,'
646+
' not in combination.'
647+
)
648+
warnings.warn(
649+
'Parameter tol is deprecated. Use rtol and/or atol instead.',
650+
DeprecationWarning,
651+
stacklevel=2, # Identify the location of the warning.
652+
)
653+
atol = tol
654+
634655
# terms which are in both:
635656
for term in set(self.terms).intersection(set(other.terms)):
636657
a = self.terms[term]
637658
b = other.terms[term]
638-
if not (isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr)):
639-
tol *= max(1, abs(a), abs(b))
640-
if self._issmall(a - b, tol) is False:
659+
if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
660+
if not self._issmall(a - b, atol):
661+
return False
662+
elif not abs(a - b) <= atol + rtol * max(abs(a), abs(b)):
641663
return False
642-
# terms only in one (compare to 0.0 so only abs_tol)
664+
# terms only in one (compare to 0.0 so only atol)
643665
for term in set(self.terms).symmetric_difference(set(other.terms)):
644666
if term in self.terms:
645-
if self._issmall(self.terms[term], tol) is False:
667+
coeff = self.terms[term]
668+
if isinstance(coeff, sympy.Expr):
669+
if not self._issmall(coeff, atol):
670+
return False
671+
elif not abs(coeff) <= atol:
646672
return False
647673
else:
648-
if self._issmall(other.terms[term], tol) is False:
674+
coeff = other.terms[term]
675+
if isinstance(coeff, sympy.Expr):
676+
if not self._issmall(coeff, atol):
677+
return False
678+
elif not abs(coeff) <= atol:
649679
return False
650680
return True
651681

src/openfermion/ops/operators/symbolic_operator_test.py

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
"""Tests symbolic_operator.py."""
1414

1515
import copy
16-
import unittest
17-
import warnings
18-
1916
import numpy
2017
import sympy
18+
import unittest
19+
import warnings
2120

2221
from openfermion.config import EQ_TOLERANCE
22+
from openfermion.testing.testing_utils import EqualsTester
23+
2324
from openfermion.ops.operators.fermion_operator import FermionOperator
2425
from openfermion.ops.operators.symbolic_operator import SymbolicOperator
25-
from openfermion.testing.testing_utils import EqualsTester
2626

2727

2828
class DummyOperator1(SymbolicOperator):
@@ -868,7 +868,76 @@ def test_pow_high_term(self):
868868
term = DummyOperator1(ops, coeff)
869869
high = term**10
870870
expected = DummyOperator1(ops * 10, coeff**10)
871-
self.assertTrue(expected == high)
871+
self.assertTrue(high.isclose(expected, rtol=1e-12, atol=1e-12))
872+
873+
def test_isclose_parameter_deprecation(self):
874+
op1 = DummyOperator1('0^ 1', 1.0)
875+
op2 = DummyOperator1('0^ 1', 1.001)
876+
877+
with self.assertWarns(DeprecationWarning):
878+
op1.isclose(op2, tol=0.01)
879+
880+
with warnings.catch_warnings():
881+
warnings.simplefilter("ignore", category=DeprecationWarning)
882+
self.assertTrue(op1.isclose(op2, tol=0.001))
883+
self.assertFalse(op1.isclose(op2, tol=0.0001))
884+
885+
def test_isclose_parameter_combos(self):
886+
op1 = DummyOperator1('0^ 1', 1.0)
887+
op2 = DummyOperator1('0^ 1', 1.001)
888+
889+
with self.assertRaises(ValueError):
890+
op1.isclose(op2, tol=0.01, rtol=1e-5)
891+
892+
with self.assertRaises(ValueError):
893+
op1.isclose(op2, tol=0.01, atol=1e-5)
894+
895+
def test_isclose_atol_rtol(self):
896+
op1 = DummyOperator1('0^ 1', 1.0)
897+
op2 = DummyOperator1('0^ 1', 1.001)
898+
899+
op_a = DummyOperator1('0^ 1', 1.0)
900+
op_b = DummyOperator1('0^ 1', 1.001)
901+
self.assertTrue(op_a.isclose(op_b, atol=0.001))
902+
self.assertFalse(op_a.isclose(op_b, atol=0.0001))
903+
904+
op_c = DummyOperator1('0^ 1', 1000)
905+
op_d = DummyOperator1('0^ 1', 1001)
906+
self.assertTrue(op_c.isclose(op_d, rtol=0.001))
907+
self.assertFalse(op_c.isclose(op_d, rtol=0.0001))
908+
909+
op_e = DummyOperator1('0^ 1', 1.0)
910+
op_f = DummyOperator1('0^ 1', 1.001)
911+
self.assertTrue(op_e.isclose(op_f, rtol=1e-4, atol=1e-3))
912+
self.assertFalse(op_e.isclose(op_f, rtol=1e-4, atol=1e-5))
913+
914+
def test_isclose(self):
915+
op1 = DummyOperator1()
916+
op2 = DummyOperator1()
917+
op1 += DummyOperator1('0^ 1', 1000000)
918+
op1 += DummyOperator1('2^ 3', 1)
919+
op2 += DummyOperator1('0^ 1', 1000000)
920+
op2 += DummyOperator1('2^ 3', 1.001)
921+
self.assertFalse(op1.isclose(op2, atol=1e-4))
922+
self.assertTrue(op1.isclose(op2, atol=1e-2))
923+
924+
# Case from https://github.com/quantumlib/OpenFermion/issues/764
925+
x = FermionOperator("0^ 0")
926+
y = FermionOperator("0^ 0")
927+
928+
# construct two identical operators up to some number of terms
929+
num_terms_before_ineq = 30
930+
for i in range(num_terms_before_ineq):
931+
x += FermionOperator(f" (10+0j) [0^ {i}]")
932+
y += FermionOperator(f" (10+0j) [0^ {i}]")
933+
934+
xfinal = FermionOperator(f" (1+0j) [0^ {num_terms_before_ineq + 1}]")
935+
yfinal = FermionOperator(f" (2+0j) [0^ {num_terms_before_ineq + 1}]")
936+
assert xfinal != yfinal
937+
938+
x += xfinal
939+
y += yfinal
940+
assert x != y
872941

873942
def test_pow_neg_error(self):
874943
with self.assertRaises(ValueError):

0 commit comments

Comments
 (0)