1
0
forked from IPKM/nmreval
nmreval/nmreval/fit/_meta.py
2022-03-08 10:27:40 +01:00

162 lines
5.2 KiB
Python

from typing import Union, Callable, Any
import operator
from inspect import signature, Parameter
class ModelFactory:
    """Builds one (possibly nested) model out of a list of function descriptions."""

    @staticmethod
    def create_from_list(funcs: list, left=None, func_order=None, param_len=None, left_cnt=None):
        """Fold the active entries of *funcs* into a single model.

        Each entry is a dict; the keys read here are ``'active'``, ``'cnt'``,
        ``'func'``, ``'children'``, ``'pos'`` and ``'op'``. Entries are combined
        left-to-right: every active entry (together with its active children,
        combined recursively with the entry's function as left operand) becomes
        the right operand of a ``MultiModel`` whose left operand is everything
        accumulated so far, joined with the entry's ``'op'``.

        Args:
            funcs: list of function-description dicts.
            left: model accumulated so far; ``None`` starts a new chain.
            func_order: list receiving the ``'cnt'`` of each active entry in
                visiting order; a new list is created when ``None``.
            param_len: list receiving ``len(entry['func'].params)`` for each
                active entry; a new list is created when ``None``.
            left_cnt: index passed to ``MultiModel`` as ``left_idx`` for *left*.

        Returns:
            Tuple ``(model, func_order, param_len)``. ``model`` is *left*
            unchanged when no entry is active, the bare ``'func'`` object when
            only one function is involved, and a ``MultiModel`` otherwise.
        """
        if func_order is None:
            func_order = []
        if param_len is None:
            param_len = []

        for func in funcs:
            if not func['active']:
                continue

            func_order.append(func['cnt'])
            param_len.append(len(func['func'].params))

            if func['children']:
                # The entry's own function is the left operand of its children's chain.
                right, _, _ = ModelFactory.create_from_list(func['children'], left=func['func'],
                                                            left_cnt=func['pos'],
                                                            func_order=func_order, param_len=param_len)
            else:
                right = func['func']

            # Always label with the entry's position. If every child is inactive the
            # recursion hands back the bare parent function, which still needs an
            # integer index (a None index would break '%d' formatting in MultiModel).
            # For a MultiModel operand the index is ignored, so this is harmless.
            right_cnt = func['pos']

            if left is None:
                left = right
                left_cnt = right_cnt
            else:
                left = MultiModel(left, right, func['op'],
                                  left_idx=left_cnt, right_idx=right_cnt)

        return left, func_order, param_len
class MultiModel:
    """Binary combination ``op(left, right)`` of two fit models.

    An operand is either another ``MultiModel`` (so trees of models can be
    built) or an object exposing ``name``, ``params``, ``func`` and optionally
    ``bounds``. Parameters of a plain operand are labelled with its index,
    e.g. ``'a(0)'``; keyword arguments of its ``func`` that have defaults are
    exposed externally as ``'<kw>_<index>'``.
    """

    op_repr = {operator.add: ' + ', operator.mul: ' * ', operator.sub: ' - ', operator.truediv: ' / '}
    str_op = {'+': operator.add, '*': operator.mul, '-': operator.sub, '/': operator.truediv}
    int_op = {0: operator.add, 1: operator.mul, 2: operator.sub, 3: operator.truediv}

    def __init__(self, left: Any, right: Any, op: Union[str, Callable, int] = '+', left_idx=0, right_idx=1):
        """Combine *left* and *right* with the binary operator *op*.

        Args:
            left: left operand (plain model object or MultiModel).
            right: right operand (plain model object or MultiModel).
            op: one of ``'+'``, ``'*'``, ``'-'``, ``'/'``; an int 0-3 meaning
                add, mul, sub, truediv; or any callable taking two arguments.
            left_idx: label for the parameters/keywords of a plain *left*;
                ignored when *left* is a MultiModel.
            right_idx: same as *left_idx*, for *right*.

        Raises:
            ValueError: if *op* is not a known string/int or a callable.
        """
        self._left = left
        self._right = right

        self._op = None
        if isinstance(op, str):
            self._op = MultiModel.str_op.get(op, None)
        elif isinstance(op, int):
            self._op = MultiModel.int_op.get(op, None)
        elif callable(op):
            self._op = op

        if self._op is None:
            raise ValueError('Invalid binary operator.')

        self.name = '('
        self.params = []
        self.bounds = []
        self._kwargs_right = {}
        self._kwargs_left = {}
        self._fun_kwargs = {}
        # maps external keyword names to the names the operand's func expects
        self._ext_int_kw = {}

        self._get_parameter(left, 'l', left_idx)
        # number of leading positional parameters that belong to the left operand
        self._param_left = len(left.params)

        try:
            self.name += MultiModel.op_repr[self._op]
        except (KeyError, TypeError):
            # custom callable without a symbol (TypeError: callable is unhashable)
            self.name += str(op)

        self._get_parameter(right, 'r', right_idx)
        self.name += ')'

        self._param_len = len(self.params)

    def __str__(self):
        return self.name

    def _get_parameter(self, func, pos, idx):
        """Register name, parameters, bounds and keyword arguments of one operand.

        *pos* is ``'l'`` or ``'r'`` and selects which side's keyword dict is filled.
        """
        kw_dict = {'l': self._kwargs_left, 'r': self._kwargs_right}[pos]

        if isinstance(func, MultiModel):
            # A nested model already exposes external keyword names; forward them as-is.
            strcnt = ''
            kw_dict.update(func.fun_kwargs)
            self._fun_kwargs.update(kw_dict)
            self._ext_int_kw.update({k: k for k in kw_dict})
        else:
            # Only keywords with a default value are exposed, suffixed with the operand index.
            temp_dic = {k: v.default for k, v in signature(func.func).parameters.items()
                        if v.default is not Parameter.empty}

            for k, v in temp_dic.items():
                key_ = '%s_%d' % (k, idx)
                kw_dict[key_] = v
                self._fun_kwargs[key_] = v
                self._ext_int_kw[key_] = k

            strcnt = '(%d)' % idx

        self.params += [pp + strcnt for pp in func.params]
        self.name += func.name + strcnt

        try:
            self.bounds.extend(func.bounds)
        except AttributeError:
            self.bounds.extend([(None, None)] * len(func.params))

    def _select_kwargs(self, side_kwargs, kwargs):
        """Translate the external keywords of one side to the operand's own names.

        Only keywords present in *kwargs* are forwarded; omitted ones are left to
        the operand function's own default instead of raising KeyError.
        """
        return {k_int: kwargs[k_ext] for k_ext, k_int in self._ext_int_kw.items()
                if k_ext in side_kwargs and k_ext in kwargs}

    def _left_arguments(self, *args, **kwargs):
        """Return the positional args and translated kwargs for the left operand."""
        return args[:self._param_left], self._select_kwargs(self._kwargs_left, kwargs)

    def _right_arguments(self, *args, **kwargs):
        """Return the positional args and translated kwargs for the right operand."""
        return args[self._param_left:self._param_len], self._select_kwargs(self._kwargs_right, kwargs)

    def func(self, x, *args, **kwargs):
        """Evaluate ``op(left.func(x, ...), right.func(x, ...))``.

        *args* are the parameters in the order of ``self.params``; *kwargs* use
        the external names listed in ``fun_kwargs`` (omitted ones use defaults).
        """
        pl, kw_left = self._left_arguments(*args, **kwargs)
        l_func = self._left.func(x, *pl, **kw_left)

        pr, kw_right = self._right_arguments(*args, **kwargs)
        r_func = self._right.func(x, *pr, **kw_right)

        return self._op(l_func, r_func)

    def left_func(self, x, *args, **kwargs):
        """Evaluate the left operand directly; *args*/*kwargs* are its own arguments."""
        return self._left.func(x, *args, **kwargs)

    def right_func(self, x, *args, **kwargs):
        """Evaluate the right operand directly; *args*/*kwargs* are its own arguments."""
        return self._right.func(x, *args, **kwargs)

    @property
    def fun_kwargs(self):
        """Dict of exposed external keyword names and their default values."""
        return self._fun_kwargs

    def subs(self, x, *args, **kwargs):
        """ Iterator over all sub-functions (depth-first and left-to-right) """
        pl, kw_left = self._left_arguments(*args, **kwargs)
        if isinstance(self._left, MultiModel):
            yield from self._left.subs(x, *pl, **kw_left)
        else:
            yield self._left.func(x, *pl, **kw_left)

        pr, kw_right = self._right_arguments(*args, **kwargs)
        if isinstance(self._right, MultiModel):
            yield from self._right.subs(x, *pr, **kw_right)
        else:
            yield self._right.func(x, *pr, **kw_right)