added a _seen set to avoid infinite recursions due to function arguments; also, applied pbc_diff to neighbors in tetrahedral order

This commit is contained in:
robrobo
2025-07-11 20:54:27 +02:00
parent 7585e598dc
commit 00043637e9
2 changed files with 17 additions and 9 deletions

View File

@ -72,7 +72,7 @@ def strip_comments(source: str) -> str:
return "\n".join([line for line in code_no_comments.splitlines() if line.strip() != ""])
def checksum(*args, csum=None):
def checksum(*args, csum=None, _seen=None):
"""
Calculate a checksum of any object, by sha1 hash.
@ -92,7 +92,15 @@ def checksum(*args, csum=None):
csum = hashlib.sha1()
csum.update(str(SALT).encode())
if _seen is None:
_seen = set()
for arg in args:
obj_id = id(arg)
if obj_id in _seen:
continue
_seen.add(obj_id)
if hasattr(arg, "__checksum__"):
logger.debug("Checksum via __checksum__: %s", str(arg))
csum.update(str(arg.__checksum__()).encode())
@ -109,15 +117,15 @@ def checksum(*args, csum=None):
for key in sorted(merged): # deterministic ordering
v = merged[key]
if v is not arg:
checksum(v, csum=csum)
checksum(v, csum=csum, _seen=_seen)
elif isinstance(arg, functools.partial):
logger.debug("Checksum via partial for %s", str(arg))
checksum(arg.func, csum=csum)
checksum(arg.func, csum=csum, _seen=_seen)
for x in arg.args:
checksum(x, csum=csum)
checksum(x, csum=csum, _seen=_seen)
for k in sorted(arg.keywords.keys()):
csum.update(k.encode())
checksum(arg.keywords[k], csum=csum)
checksum(arg.keywords[k], csum=csum, _seen=_seen)
elif isinstance(arg, np.ndarray):
csum.update(arg.tobytes())
else:

View File

@ -182,10 +182,10 @@ def tetrahedral_order(
)
# Connection vectors
neighbors_1 -= atoms
neighbors_2 -= atoms
neighbors_3 -= atoms
neighbors_4 -= atoms
neighbors_1 = pbc_diff(neighbors_1, atoms, box=atoms.box)
neighbors_2 = pbc_diff(neighbors_2, atoms, box=atoms.box)
neighbors_3 = pbc_diff(neighbors_3, atoms, box=atoms.box)
neighbors_4 = pbc_diff(neighbors_4, atoms, box=atoms.box)
# Normed Connection vectors
neighbors_1 /= np.linalg.norm(neighbors_1, axis=-1).reshape(-1, 1)