Merge branch 'main' into fit_constraints

# Conflicts:
#	src/gui_qt/fit/fit_forms.py
#	src/gui_qt/main/management.py
#	src/nmreval/fit/minimizer.py
This commit is contained in:
Dominik Demuth
2023-09-11 18:18:30 +02:00
46 changed files with 428 additions and 216 deletions

View File

@ -211,7 +211,6 @@ class NMRMainWindow(QtWidgets.QMainWindow, Ui_BaseWindow):
self.ptsselectwidget.points_selected.connect(self.management.extract_points)
self.t1tauwidget.newData.connect(self.management.add_new_data)
self.t1tauwidget.newData.connect(self.management.add_new_data)
self.editsignalwidget.do_something.connect(self.management.apply)
@ -917,10 +916,12 @@ class NMRMainWindow(QtWidgets.QMainWindow, Ui_BaseWindow):
self.action_odr_fit: 'odr'
}[self.ac_group.checkedAction()]
self.fit_dialog.fit_button.setEnabled(False)
self.management.start_fit(parameter, links, fit_options)
self.status.setText('Fit running...'.format(self.management.fitter.step))
self.fit_timer.start(500)
fit_is_ready = self.management.prepare_fit(parameter, links, fit_options)
if fit_is_ready:
self.management.start_fit()
self.fit_dialog.fit_button.setEnabled(False)
self.status.setText('Fit running...'.format(self.management.fitter.step))
self.fit_timer.start(500)
@QtCore.pyqtSlot(dict, int, bool)
def show_fit_preview(self, funcs: dict, num: int, show: bool):

View File

@ -58,11 +58,18 @@ class GraphDict(OrderedDict):
def list(self):
return [(k, v.title) for k, v in self.items()]
def active(self, key: str):
if key:
return [(self._data[i].id, self._data[i].name) for i in self[key]]
else:
def active(self, key: str, return_val: str = 'both'):
if not key:
return []
else:
if return_val == 'both':
return [(self._data[i].id, self._data[i].name) for i in self[key]]
elif return_val == 'id':
return [self._data[i].id for i in self[key]]
elif return_val == 'name':
return [self._data[i].name for i in self[key]]
else:
raise ValueError(f'return_val got wrong value {return_val!r}')
def current_sets(self, key: str):
if key:
@ -148,6 +155,10 @@ class UpperManagement(QtCore.QObject):
def active_sets(self):
return self.graphs.active(self.current_graph)
@property
def active_id(self):
return self.graphs.active(self.current_graph, return_val='id')
def get_attributes(self, graph_id: str, attr: str) -> dict[str, Any]:
return {self.data[i].id: getattr(self.data[i], attr) for i in self.graphs[graph_id].sets}
@ -413,9 +424,9 @@ class UpperManagement(QtCore.QObject):
for d in self.data.values():
d.mask = np.ones_like(d.mask, dtype=bool)
def start_fit(self, parameter: dict, links: list, fit_options: dict):
def prepare_fit(self, parameter: dict, links: list, fit_options: dict) -> bool:
if self._fit_active:
return
return False
self.__fit_options = (parameter, links, fit_options)
@ -423,60 +434,84 @@ class UpperManagement(QtCore.QObject):
models = {}
fit_limits = fit_options['limits']
fit_mode = fit_options['fit_mode']
we = fit_options['we']
we_option = fit_options['we']
for model_id, model_p in parameter.items():
m = Model(model_p['func'])
models[model_id] = m
self.fitter.fitmethod = fit_mode
m_complex = model_p['complex']
# all-encompassing error catch
try:
for model_id, model_p in parameter.items():
m = Model(model_p['func'])
models[model_id] = m
for set_id, set_params in model_p['parameter'].items():
data_i = self.data[set_id]
if we.lower() == 'deltay':
we = data_i.y_err**2
m_complex = model_p['complex']
if m_complex is None or m_complex == 1:
_y = data_i.y.real
elif m_complex == 2 and np.iscomplexobj(data_i.y):
_y = data_i.y.imag
else:
_y = data_i.y
# sets are not in active order but in order they first appeared in fit dialog
# iterate over order of set id in active order and access parameter inside loop
# instead of directly looping
list_ids = list(model_p['parameter'].keys())
set_order = [self.active_id.index(i) for i in list_ids]
for pos in set_order:
set_id = list_ids[pos]
_x = data_i.x
data_i = self.data[set_id]
set_params = model_p['parameter'][set_id]
if fit_limits == 'none':
inside = slice(None)
elif fit_limits == 'x':
x_lim, _ = self.graphs[self.current_graph].ranges
inside = np.where((_x >= x_lim[0]) & (_x <= x_lim[1]))
else:
inside = np.where((_x >= fit_limits[0]) & (_x <= fit_limits[1]))
if we_option.lower() == 'deltay':
we = data_i.y_err**2
else:
we = we_option
if isinstance(we, str):
d = fit_d.Data(_x[inside], _y[inside], we=we, idx=set_id)
else:
d = fit_d.Data(_x[inside], _y[inside], we=we[inside], idx=set_id)
if m_complex is None or m_complex == 1:
_y = data_i.y.real
elif m_complex == 2 and np.iscomplexobj(data_i.y):
_y = data_i.y.imag
else:
_y = data_i.y
d.set_model(m)
d.set_parameter(set_params[0], var=model_p['var'],
lb=model_p['lb'], ub=model_p['ub'],
fun_kwargs=set_params[1])
_x = data_i.x
self.fitter.add_data(d)
if fit_limits == 'none':
inside = slice(None)
elif fit_limits == 'x':
x_lim, _ = self.graphs[self.current_graph].ranges
inside = np.where((_x >= x_lim[0]) & (_x <= x_lim[1]))
else:
inside = np.where((_x >= fit_limits[0]) & (_x <= fit_limits[1]))
model_globs = model_p['glob']
if model_globs:
for parameter_args in zip(*model_globs.values()):
m.set_global_parameter(**{k: v for k, v in zip(model_globs.keys(), parameter_args)})
# m.set_global_parameter(**model_p['glob'])
if isinstance(we, str):
d = fit_d.Data(_x[inside], _y[inside], we=we, idx=set_id)
else:
d = fit_d.Data(_x[inside], _y[inside], we=we[inside], idx=set_id)
for links_i in links:
self.fitter.set_link_parameter((models[links_i[0]], links_i[1]),
(models[links_i[2]], links_i[3]))
d.set_model(m)
d.set_parameter(set_params[0], var=model_p['var'],
lb=model_p['lb'], ub=model_p['ub'],
fun_kwargs=set_params[1])
self.fitter.add_data(d)
model_globs = model_p['glob']
if model_globs:
for parameter_args in zip(*model_globs.values()):
m.set_global_parameter(**{k: v for k, v in zip(model_globs.keys(), parameter_args)})
# m.set_global_parameter(**model_p['glob'])
for links_i in links:
self.fitter.set_link_parameter((models[links_i[0]], links_i[1]),
(models[links_i[2]], links_i[3]))
return True
except Exception as e:
logger.error('Fit preparation failed', *e.args)
QtWidgets.QMessageBox.warning(QtWidgets.QWidget(),
'Fit prep failed',
f'Fit preparation failed with message\n{e.args}')
return False
def start_fit(self):
with busy_cursor():
self.fit_worker = FitWorker(self.fitter, fit_mode)
self.fit_worker = FitWorker(self.fitter)
self.fit_thread = QtCore.QThread()
self.fit_worker.moveToThread(self.fit_thread)
@ -512,7 +547,8 @@ class UpperManagement(QtCore.QObject):
for set_id, set_parameter in parameter.items():
new_values = [v.value for v in res[set_id].parameter.values()]
parameter[set_id] = (new_values, set_parameter[1])
self.start_fit(*self.__fit_options)
if self.prepare_fit(*self.__fit_options):
self.start_fit()
def make_fits(self, res: dict, opts: list, param_graph: str, show_fit: bool, parts: bool, extrapolate: list) -> None:
"""
@ -637,7 +673,7 @@ class UpperManagement(QtCore.QObject):
def save_fit_parameter(self, fname: str | pathlib.Path, fit_sets: list[str] = None):
if fit_sets is None:
fit_sets = [s for (s, _) in self.active_sets]
fit_sets = [s for s in self.active_id]
for set_id in fit_sets:
data = self.data[set_id]
@ -843,13 +879,10 @@ class UpperManagement(QtCore.QObject):
d_k = self.data[k]
if copy_data is None:
d_k.x = d_k.x*v[1][0] + v[0][0]
d_k.y = d_k.y*v[1][1] + v[0][1]
d_k.shift_scale(v[0], v[1])
else:
new_data = d_k.copy(full=True)
new_data.update({'shift': v[0], 'scale': v[1]})
new_data.data.x = new_data.x*v[1][0] + v[0][0]
new_data.y = new_data.y*v[1][1] + v[0][1]
new_data.shift_scale(v[0], v[1])
sid = self.add(new_data)
sid_list.append(sid)
@ -1009,7 +1042,7 @@ class UpperManagement(QtCore.QObject):
def show_statistics(self, mode):
x, y, = [], []
for i, _ in self.active_sets:
for i in self.active_id:
_temp = self.data[i]
try:
x.append(float(_temp.name))
@ -1020,7 +1053,7 @@ class UpperManagement(QtCore.QObject):
@QtCore.pyqtSlot()
def calc_magn(self):
new_id = []
for k, _ in self.active_sets:
for k in self.active_id:
dataset = self.data[k]
if isinstance(dataset, SignalContainer):
new_value = dataset.copy(full=True)
@ -1032,7 +1065,7 @@ class UpperManagement(QtCore.QObject):
@QtCore.pyqtSlot()
def center(self):
new_id = []
for k, _ in self.active_sets:
for k in self.active_id:
new_value = self.data[k].copy(full=True)
new_value.x -= new_value.x[np.argmax(new_value.y.real)]
new_id.append(self.add(new_value))
@ -1071,7 +1104,7 @@ class UpperManagement(QtCore.QObject):
def bds_deriv(self):
new_sets = []
for (set_id, _) in self.active_sets:
for set_id in self.active_id:
data_i = self.data[set_id]
diff = data_i.data.diff(log=True)
new_data = Points(x=diff.x, y=-np.pi/2*diff.y.real)
@ -1098,7 +1131,7 @@ class UpperManagement(QtCore.QObject):
self.newData.emit(new_sets, kwargs['graph'])
def skip_points(self, offset: int, step: int, invert: bool = False, copy: bool = False):
for k, _ in self.active_sets:
for k in self.active_id:
src = self.data[k]
if invert:
mask = np.mod(np.arange(offset, src.x.size+offset), step) != 0
@ -1253,16 +1286,15 @@ class UpperManagement(QtCore.QObject):
class FitWorker(QtCore.QObject):
finished = QtCore.pyqtSignal(list, bool)
def __init__(self, fitter, mode):
def __init__(self, fitter):
super().__init__()
self.fitter = fitter
self.mode = mode
@QtCore.pyqtSlot()
def run(self):
try:
res = self.fitter.run(mode=self.mode)
res = self.fitter.run()
success = True
except Exception as e:
res = [e]