added a _seen set to avoid infinite recursions due to function arguments; also, applied pbc_diff to neighbors in tetrahedral order
This commit is contained in:
@ -72,7 +72,7 @@ def strip_comments(source: str) -> str:
|
||||
return "\n".join([line for line in code_no_comments.splitlines() if line.strip() != ""])
|
||||
|
||||
|
||||
def checksum(*args, csum=None):
|
||||
def checksum(*args, csum=None, _seen=None):
|
||||
"""
|
||||
Calculate a checksum of any object, by sha1 hash.
|
||||
|
||||
@ -92,7 +92,15 @@ def checksum(*args, csum=None):
|
||||
csum = hashlib.sha1()
|
||||
csum.update(str(SALT).encode())
|
||||
|
||||
if _seen is None:
|
||||
_seen = set()
|
||||
|
||||
for arg in args:
|
||||
obj_id = id(arg)
|
||||
if obj_id in _seen:
|
||||
continue
|
||||
_seen.add(obj_id)
|
||||
|
||||
if hasattr(arg, "__checksum__"):
|
||||
logger.debug("Checksum via __checksum__: %s", str(arg))
|
||||
csum.update(str(arg.__checksum__()).encode())
|
||||
@ -109,15 +117,15 @@ def checksum(*args, csum=None):
|
||||
for key in sorted(merged): # deterministic ordering
|
||||
v = merged[key]
|
||||
if v is not arg:
|
||||
checksum(v, csum=csum)
|
||||
checksum(v, csum=csum, _seen=_seen)
|
||||
elif isinstance(arg, functools.partial):
|
||||
logger.debug("Checksum via partial for %s", str(arg))
|
||||
checksum(arg.func, csum=csum)
|
||||
checksum(arg.func, csum=csum, _seen=_seen)
|
||||
for x in arg.args:
|
||||
checksum(x, csum=csum)
|
||||
checksum(x, csum=csum, _seen=_seen)
|
||||
for k in sorted(arg.keywords.keys()):
|
||||
csum.update(k.encode())
|
||||
checksum(arg.keywords[k], csum=csum)
|
||||
checksum(arg.keywords[k], csum=csum, _seen=_seen)
|
||||
elif isinstance(arg, np.ndarray):
|
||||
csum.update(arg.tobytes())
|
||||
else:
|
||||
|
@ -182,10 +182,10 @@ def tetrahedral_order(
|
||||
)
|
||||
|
||||
# Connection vectors
|
||||
neighbors_1 -= atoms
|
||||
neighbors_2 -= atoms
|
||||
neighbors_3 -= atoms
|
||||
neighbors_4 -= atoms
|
||||
neighbors_1 = pbc_diff(neighbors_1, atoms, box=atoms.box)
|
||||
neighbors_2 = pbc_diff(neighbors_2, atoms, box=atoms.box)
|
||||
neighbors_3 = pbc_diff(neighbors_3, atoms, box=atoms.box)
|
||||
neighbors_4 = pbc_diff(neighbors_4, atoms, box=atoms.box)
|
||||
|
||||
# Normed Connection vectors
|
||||
neighbors_1 /= np.linalg.norm(neighbors_1, axis=-1).reshape(-1, 1)
|
||||
|
Reference in New Issue
Block a user