1
0
forked from IPKM/nmreval
nmreval/nmreval/fit/_meta.py
dominik 5120c9d57c delete graphs should be less leaky;
BDS: added CD/CC+HF
2022-08-10 20:00:10 +02:00

189 lines
6.1 KiB
Python

from typing import Union, Callable, Any
import operator
from inspect import signature, Parameter
class ModelFactory:
    """Build a (possibly nested) MultiModel from a tree of function entries."""

    @staticmethod
    def create_from_list(funcs: list, left=None, func_order=None, param_len=None, left_cnt=None):
        """Combine the active entries of *funcs* left-to-right into one model.

        Each entry of *funcs* is a dict read with the keys ``'active'``,
        ``'cnt'``, ``'func'``, ``'children'``, ``'pos'`` and ``'op'``.
        Inactive entries (and their children) are skipped. An entry with
        children is first combined with its children -- recursively, with the
        entry's own function as the left-most operand -- and that result is
        then joined to the running model using the entry's ``'op'``.

        Args:
            funcs: list of entry dicts as described above.
            left: model accumulated so far; ``None`` to start a new chain.
            func_order: list collecting ``entry['cnt']`` of every active entry
                in visiting order; a new list is created when ``None``.
            param_len: list collecting ``len(entry['func'].params)`` for every
                active entry, parallel to *func_order*; a new list is created
                when ``None``.
            left_cnt: index passed as ``left_idx`` to MultiModel for *left*.

        Returns:
            tuple ``(model, func_order, param_len)``. *model* is ``None`` when
            *left* was ``None`` and no entry is active.
        """
        if func_order is None:
            func_order = []
        if param_len is None:
            param_len = []

        for entry in funcs:
            if not entry['active']:
                continue

            func_order.append(entry['cnt'])
            param_len.append(len(entry['func'].params))

            if entry['children']:
                # The entry's own function is the left-most operand of its
                # subtree; func_order/param_len are shared with the recursion
                # so children are appended right after their parent.
                right, _, _ = ModelFactory.create_from_list(
                    entry['children'], left=entry['func'], left_cnt=entry['pos'],
                    func_order=func_order, param_len=param_len,
                )
                # If every child is inactive the recursion hands back the
                # entry's plain function, which still needs its own position
                # index; a combined model ignores the index.
                right_cnt = entry['pos'] if right is entry['func'] else None
            else:
                right = entry['func']
                right_cnt = entry['pos']

            if left is None:
                left = right
                left_cnt = right_cnt
            else:
                left = MultiModel(left, right, entry['op'],
                                  left_idx=left_cnt, right_idx=right_cnt)

        return left, func_order, param_len
class MultiModel:
    """Binary combination ``left <op> right`` of two fit models.

    An operand is either another MultiModel (so models nest into a tree) or an
    object providing ``params``, ``name`` and ``func(x, *params, **kwargs)``
    and, optionally, ``bounds``.

    The combined model's parameters are the left operand's parameters followed
    by the right operand's. Keyword arguments that have a default in a plain
    operand's ``func`` signature are exposed in ``fun_kwargs`` as
    ``<name>_<idx>``; ``complex_mode`` keeps its name and is shared by both
    operands.
    """

    # printable form of the supported operators, used to build ``name``
    op_repr = {operator.add: ' + ', operator.mul: ' * ', operator.sub: ' - ', operator.truediv: ' / '}
    # accepted string / integer spellings of ``op`` in the constructor
    str_op = {'+': operator.add, '*': operator.mul, '-': operator.sub, '/': operator.truediv}
    int_op = {0: operator.add, 1: operator.mul, 2: operator.sub, 3: operator.truediv}

    def __init__(self, left: Any, right: Any, op: Union[str, Callable, int] = '+', left_idx=0, right_idx=1):
        """
        Args:
            left: left operand.
            right: right operand.
            op: ``'+'``, ``'*'``, ``'-'``, ``'/'``, an integer 0-3 selecting
                the same operators, or any callable that receives the two
                operand results.
            left_idx: suffix for the left operand's parameter and keyword
                names (ignored when *left* is a MultiModel).
            right_idx: same for the right operand.

        Raises:
            ValueError: if *op* is an unknown string/integer or not callable.
        """
        self._left = left
        self._right = right

        self._op = None
        if isinstance(op, str):
            self._op = MultiModel.str_op.get(op, None)
        elif isinstance(op, int):
            self._op = MultiModel.int_op.get(op, None)
        elif callable(op):
            self._op = op
        if self._op is None:
            raise ValueError('Invalid binary operator.')

        self.name = '('
        self.params = []
        self.bounds = []
        self._kwargs_right = {}
        self._kwargs_left = {}
        self.fun_kwargs = {}
        # mapping of external keyword names (keys of fun_kwargs) to the names
        # expected by the underlying operand
        self._ext_int_kw = {}

        self._get_parameter(left, 'l', left_idx)
        # number of leading positional parameters that belong to the left side
        self._param_left = len(left.params)

        try:
            self.name += MultiModel.op_repr[self._op]
        except KeyError:
            # custom callable: fall back to its string representation
            self.name += str(op)

        self._get_parameter(right, 'r', right_idx)
        self.name += ')'

        self._param_len = len(self.params)

    def __str__(self):
        return self.name

    def _get_parameter(self, func, pos, idx):
        """Register parameters, bounds, name and keyword defaults of one operand.

        *pos* (``'l'`` or ``'r'``) selects which side's keyword dict is filled.
        For a plain operand, *idx* is appended to its parameter names
        (``'p(idx)'``) and keyword names (``'kw_idx'``); for a nested
        MultiModel its names are taken over as they are and *idx* is unused.
        """
        kw_dict = {'l': self._kwargs_left, 'r': self._kwargs_right}[pos]

        if isinstance(func, MultiModel):
            suffix = ''
            # nested model: its external keyword names pass through unchanged
            kw_dict.update(func.fun_kwargs)
            self.fun_kwargs.update(kw_dict)
            self._ext_int_kw.update({k: k for k in kw_dict})
        else:
            suffix = f'({idx})'
            for k, param in signature(func.func).parameters.items():
                if param.default is Parameter.empty:
                    continue
                # complex_mode is shared by both operands; every other keyword
                # is made unique by the operand index
                key_ = k if k == 'complex_mode' else f'{k}_{idx}'
                kw_dict[key_] = param.default
                self.fun_kwargs[key_] = param.default
                self._ext_int_kw[key_] = k

        self.params += [pp + suffix for pp in func.params]
        self.name += func.name + suffix

        try:
            self.bounds.extend(func.bounds)
        except AttributeError:
            # operand without bounds: leave its parameters unbounded
            self.bounds.extend([(None, None)] * len(func.params))

    def _select_kwargs(self, side_kwargs, kwargs):
        """Return the entries of *kwargs* meant for one operand, renamed for it.

        Only keys registered in *side_kwargs* are considered. Keys absent from
        *kwargs* are left out, so the operand falls back to its own default
        instead of raising KeyError.
        """
        return {self._ext_int_kw[k_ext]: kwargs[k_ext]
                for k_ext in side_kwargs if k_ext in kwargs}

    def _left_arguments(self, *args, **kwargs):
        """Return ``(positional, keyword)`` arguments for the left operand."""
        return args[:self._param_left], self._select_kwargs(self._kwargs_left, kwargs)

    def _right_arguments(self, *args, **kwargs):
        """Return ``(positional, keyword)`` arguments for the right operand."""
        return args[self._param_left:self._param_len], self._select_kwargs(self._kwargs_right, kwargs)

    def func(self, x, *args, **kwargs):
        """Evaluate ``left.func(x, ...) <op> right.func(x, ...)``.

        *args* are the combined parameters (left operand first). *kwargs* use
        the external names of ``fun_kwargs``. Unknown keys are ignored, and
        missing keys fall back to the operands' defaults.
        """
        pl, kw_left = self._left_arguments(*args, **kwargs)
        l_func = self._left.func(x, *pl, **kw_left)
        pr, kw_right = self._right_arguments(*args, **kwargs)
        r_func = self._right.func(x, *pr, **kw_right)
        return self._op(l_func, r_func)

    def left_func(self, x, *args, **kwargs):
        """Evaluate the left operand alone; arguments are passed through unchanged."""
        return self._left.func(x, *args, **kwargs)

    def right_func(self, x, *args, **kwargs):
        """Evaluate the right operand alone; arguments are passed through unchanged."""
        return self._right.func(x, *args, **kwargs)

    def subs(self, x, *args, **kwargs):
        """ Iterator over all sub-functions (depth-first and left-to-right) """
        pl, kw_left = self._left_arguments(*args, **kwargs)
        if isinstance(self._left, MultiModel):
            yield from self._left.subs(x, *pl, **kw_left)
        else:
            yield self._left.func(x, *pl, **kw_left)

        pr, kw_right = self._right_arguments(*args, **kwargs)
        if isinstance(self._right, MultiModel):
            yield from self._right.subs(x, *pr, **kw_right)
        else:
            yield self._right.func(x, *pr, **kw_right)

    def sub_name(self):
        """Yield the names of all leaf operands, depth-first and left-to-right.

        An operand without a ``name`` attribute is reported as this model's
        name followed by ``'(lhs)'`` or ``'(rhs)'``.
        """
        if isinstance(self._left, MultiModel):
            yield from self._left.sub_name()
        elif hasattr(self._left, 'name'):
            yield self._left.name
        else:
            yield self.name + '(lhs)'

        if isinstance(self._right, MultiModel):
            yield from self._right.sub_name()
        elif hasattr(self._right, 'name'):
            yield self._right.name
        else:
            yield self.name + '(rhs)'