From 3af5cb0301277da95bb0fe3cc4a5d4806e61b552 Mon Sep 17 00:00:00 2001 From: Dominik Demuth Date: Sat, 16 Sep 2023 14:16:45 +0200 Subject: [PATCH] add todos --- src/nmreval/fit/minimizer.py | 15 +++++++++++++-- src/nmreval/fit/parameter.py | 11 +++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/nmreval/fit/minimizer.py b/src/nmreval/fit/minimizer.py index 50fa4e4..5594fd1 100644 --- a/src/nmreval/fit/minimizer.py +++ b/src/nmreval/fit/minimizer.py @@ -29,6 +29,7 @@ def _cost_scipy_glob(p: list[float], data: list[Data], varpars: list[str], used_ for keys, values in zip(varpars, p): for data_i in data: if keys in data_i.parameter.keys(): + # TODO move this to scaled_value setter data_i.parameter[keys].scaled_value = values data_i.parameter[keys].namespace[keys] = data_i.parameter[keys].value r = [] @@ -220,7 +221,7 @@ class FitRoutine(object): logger.info('Fit aborted by user') self._abort = True - def run(self, mode: str=None): + def run(self, mode: str = None): self._abort = False if mode is None: @@ -262,6 +263,16 @@ class FitRoutine(object): return self.result + def make_preview(self, x: np.ndarray) -> list[np.ndarray]: + y_pred = [] + fit_groups, linked_parameter = self.prepare_links() + for data_groups in fit_groups: + data = data_groups[0] + actual_parameters = [p.value for p in data.parameter.values()] + y_pred.append(data.func(actual_parameters, x)) + + return y_pred + def _prep_data(self, data): if data.get_model() is None: data._model = self.fit_model @@ -317,6 +328,7 @@ class FitRoutine(object): d._model = None self._no_own_model = [] + Parameters.reset() def _least_squares_single(self, data, p0, lb, ub, var): self.step = 0 @@ -345,7 +357,6 @@ class FitRoutine(object): with np.errstate(all='ignore'): res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0)) - err, corr, partial_corr = self._calc_error(res.jac, np.sum(res.fun**2), *res.jac.shape) for v, var_pars_k in zip(data, data_pars): self.make_results(v, res.x, var, var_pars_k, res.jac.shape, diff --git a/src/nmreval/fit/parameter.py b/src/nmreval/fit/parameter.py index e9bf2c1..6d5457e 100644 --- a/src/nmreval/fit/parameter.py +++ b/src/nmreval/fit/parameter.py @@ -88,6 +88,15 @@ class Parameters(dict): expression = re.sub(re.escape(n), k, expression) p._expr = expression + def fix(self): + for v in self.keys(): + v._value = v.value + v.namespace = {} + + @staticmethod + def reset(): + Parameters.namespace = {} + def get_key(self, name: str) -> str | None: for k, v in self.items(): if name == v.name: @@ -104,6 +113,7 @@ class Parameter: Container for one parameter """ + # TODO Parameter should know its own key def __init__(self, name: str, value: float | str, var: bool = True, lb: float = -np.inf, ub: float = np.inf): self._value: float | None = None self.var: bool = bool(var) if var is not None else True @@ -181,6 +191,7 @@ class Parameter: @property def value(self) -> float: + # TODO first _value, then _expr if self._expr is not None and self.eval_allowed: return eval(self._expr, {}, self.namespace) elif self._value is not None: