diff --git a/src/nmreval/models/transitions.py b/src/nmreval/models/transitions.py index 217903f..ed80f59 100644 --- a/src/nmreval/models/transitions.py +++ b/src/nmreval/models/transitions.py @@ -4,16 +4,26 @@ from scipy import special as special from ..utils import kB -class Weight2Phase: +class Weight: type = 'Line shape' name = 'Weighting factor' - equation = r'A*[0.5 + 0.5 erf[(x-T_{0})/\DeltaT]] + A_{0}' + equation = r'A * [0.5 \pm 0.5 erf[(x-T_{0})/\DeltaT]] + A_{0}' params = ['T_{0}', r'\DeltaT', 'A', 'A_{0}'] + choices = [('Direction', 'sign', {'increase': '+', 'decrease': '-'})] bounds = [(0, None), (0, None), (None, None), (None, None)] @staticmethod - def func(x, t0, dt, amp, off): - return amp*(0.5 + 0.5*special.erf((x-t0)/dt)) + off + def func(x: np.ndarray | float, t0: float, dt: float, amp: float, off: float, sign: str = '+') -> np.ndarray | float: + if sign not in '+-': + raise ValueError(f"`value` is `+` or `-`, not {sign}") + + error_func = 1 + if sign == '+': + error_func += special.erf((x-t0)/dt) + else: + error_func -= special.erf((x - t0) / dt) + + return amp * error_func / 2. + off class HendricksonBray: @@ -24,5 +34,5 @@ class HendricksonBray: bounds = [(0, None)] * 4 @staticmethod - def func(x, a, b, e, w0): + def func(x: np.ndarray | float, a: float, b: float, e: float, w0: float) -> np.ndarray | float: return a*b / (b + (a-b)*np.exp(-e/kB/x)) + w0