diff --git a/abcd/model.py b/abcd/model.py index 4b9dabf..406c77d 100644 --- a/abcd/model.py +++ b/abcd/model.py @@ -12,8 +12,8 @@ class Hasher(object): - def __init__(self, method=md5()): - self.method = method + def __init__(self, method=md5): + self.method = method() def update(self, value): @@ -273,26 +273,25 @@ def pre_save(self): self["username"] = getpass.getuser() if not self.get("uploaded"): - self["uploaded"] = datetime.datetime.utcnow() + self["uploaded"] = datetime.datetime.now(datetime.timezone.utc) - self["modified"] = datetime.datetime.utcnow() + self["modified"] = datetime.datetime.now(datetime.timezone.utc) - m = Hasher() + hasher = Hasher() for key in ("numbers", "positions", "cell", "pbc"): - m.update(self[key]) + hasher.update(self[key]) self.derived_keys.append("hash_structure") - self["hash_structure"] = m() + self["hash_structure"] = hasher() - m = Hasher() for key in self.arrays_keys: - m.update(self[key]) + hasher.update(self[key]) for key in self.info_keys: - m.update(self[key]) + hasher.update(self[key]) self.derived_keys.append("hash") - self["hash"] = m() + self["hash"] = hasher() if __name__ == "__main__": diff --git a/tests/test_abstract_model.py b/tests/test_abstract_model.py index f9e820e..6960950 100644 --- a/tests/test_abstract_model.py +++ b/tests/test_abstract_model.py @@ -1,3 +1,4 @@ +import datetime import io import ase @@ -8,7 +9,7 @@ from ase.io import read, write import numpy as np -from abcd.model import AbstractModel +from abcd.model import AbstractModel, Hasher from ase.calculators.lj import LennardJones @@ -238,7 +239,6 @@ def test_write_and_read(store_calc): "hash", "modified", "uploaded", - "hash_structure", # see issue #118 }: assert ( abcd_data[key] == abcd_data_after_read[key] @@ -247,10 +247,75 @@ def test_write_and_read(store_calc): # expected differences - n.b. order of calls above assert abcd_data_after_read["modified"] > abcd_data["modified"] assert abcd_data_after_read["uploaded"] > abcd_data["uploaded"] - assert abcd_data_after_read["hash"] != abcd_data["hash"] # expect results to match within fp precision for key in set(abcd_data.results_keys): assert abcd_data[key] == approx( np.array(abcd_data_after_read[key]) ), f"{key}'s value does not match" + + +def test_hash_update(): + """Test hash can be updated after initialisation.""" + hasher_1 = Hasher() + + init_hash = hasher_1() + hasher_1.update("Test value") + assert hasher_1() != init_hash + + +@pytest.mark.parametrize( + "data", + [ + 1296, + 3.14, + [1, 2, 3], + (4, 5, 6), + {"a": "value"}, + datetime.datetime.now(datetime.timezone.utc), + b"test", + ], +) +def test_hash_data_types(data): + """Test updating hash for different data types.""" + hasher_1 = Hasher() + hasher_1.update("Test value") + updated_hash = hasher_1() + + hasher_1.update(data) + assert updated_hash != hasher_1() + + +def test_second_hash_init(): + """Test second hash is initialised correctly.""" + hasher_1 = Hasher() + + init_hash = hasher_1() + hasher_1.update("Test value") + + hasher_2 = Hasher() + assert hasher_2() == init_hash + + +@pytest.mark.parametrize( + "data", + [ + 1296, + 3.14, + [1, 2, 3], + (4, 5, 6), + {"a": "value"}, + datetime.datetime.now(datetime.timezone.utc), + b"test", + ], +) +def test_consistent_hash(data): + """Test two hashers agree with same data.""" + hasher_1 = Hasher() + hasher_1.update("Test value") + hasher_1.update(data) + + hasher_2 = Hasher() + hasher_2.update("Test value") + hasher_2.update(data) + assert hasher_1() == hasher_2()