forked from IPKM/nmreval
162 lines
5.2 KiB
Python
162 lines
5.2 KiB
Python
from typing import Union, Callable, Any
|
|
import operator
|
|
|
|
from inspect import signature, Parameter
|
|
|
|
|
|
class ModelFactory:
    """Build a single (possibly nested) model from a tree of function entries."""

    @staticmethod
    def create_from_list(funcs: list, left=None, func_order=None, param_len=None, left_cnt=None):
        """Combine the active entries of ``funcs`` into one model.

        Each entry of ``funcs`` is a dict; the keys read here are ``'active'``,
        ``'cnt'``, ``'pos'``, ``'func'``, ``'children'`` and ``'op'``.
        Inactive entries (together with their children) are skipped.

        If an entry has active children, it is first combined with them
        (recursively), and the result is used as a single operand. Every entry
        after the first is then joined to the accumulated model with
        ``MultiModel`` and the entry's ``'op'``.

        Args:
            funcs: list of function entries (see above).
            left: model accumulated so far; ``None`` to start a new chain.
            func_order: list that receives ``'cnt'`` of every active entry in
                depth-first order; a new list is created when ``None``.
            param_len: list that receives ``len(entry['func'].params)`` of
                every active entry, in the same order as ``func_order``.
            left_cnt: index (``'pos'``) used for ``left`` when it is a plain
                function rather than a ``MultiModel``.

        Returns:
            tuple ``(model, func_order, param_len)``. ``model`` is ``left``
            unchanged (``None`` by default) if no entry is active.
        """
        if func_order is None:
            func_order = []
        if param_len is None:
            param_len = []

        for func in funcs:
            if not func['active']:
                continue

            func_order.append(func['cnt'])
            param_len.append(len(func['func'].params))

            right = func['func']
            right_cnt = func['pos']
            if func['children']:
                # func_order / param_len are shared lists mutated in place by
                # the recursive call, so its returned copies can be ignored.
                right, _, _ = ModelFactory.create_from_list(func['children'], left=func['func'],
                                                            left_cnt=func['pos'],
                                                            func_order=func_order,
                                                            param_len=param_len)
                if right is not func['func']:
                    # Children were merged into a MultiModel, which carries
                    # its own indices. If every child was inactive, the plain
                    # function comes back and must keep its 'pos' index.
                    right_cnt = None

            if left is None:
                left = right
                left_cnt = right_cnt
            else:
                left = MultiModel(left, right, func['op'],
                                  left_idx=left_cnt, right_idx=right_cnt)

        return left, func_order, param_len
|
|
|
|
|
|
class MultiModel:
    """Binary combination ``left <op> right`` of two fit models.

    ``left`` and ``right`` are either leaf models or other ``MultiModel``
    instances, so expression trees of any depth can be built. A leaf model
    must provide ``name``, ``params`` and ``func(x, *params, **kwargs)``;
    ``bounds`` is optional and defaults to ``(None, None)`` per parameter.

    Leaf parameter names get a ``(idx)`` suffix. Leaf keyword defaults are
    exposed as ``<name>_<idx>``. This keeps names from different sub-models
    apart.
    """

    op_repr = {operator.add: ' + ', operator.mul: ' * ', operator.sub: ' - ', operator.truediv: ' / '}

    str_op = {'+': operator.add, '*': operator.mul, '-': operator.sub, '/': operator.truediv}

    int_op = {0: operator.add, 1: operator.mul, 2: operator.sub, 3: operator.truediv}

    def __init__(self, left: Any, right: Any, op: Union[str, Callable, int] = '+', left_idx=0, right_idx=1):
        """Create the combined model.

        Args:
            left: left operand (leaf model or ``MultiModel``).
            right: right operand (leaf model or ``MultiModel``).
            op: binary operator. It may be a symbol from ``str_op``, an
                integer code from ``int_op``, or any callable taking two
                arguments.
            left_idx: suffix index for the left operand if it is a leaf model.
            right_idx: suffix index for the right operand if it is a leaf
                model.

        Raises:
            ValueError: if ``op`` is not a known symbol/code or a callable.
        """
        self._left = left
        self._right = right

        self._op = None
        if isinstance(op, str):
            self._op = MultiModel.str_op.get(op, None)
        elif isinstance(op, int):
            self._op = MultiModel.int_op.get(op, None)
        elif callable(op):
            self._op = op

        if self._op is None:
            raise ValueError('Invalid binary operator.')

        # _get_parameter appends each operand's name, so build the display
        # name incrementally: '(' left op right ')'.
        self.name = '('
        self.params = []
        self.bounds = []
        self._kwargs_right = {}
        self._kwargs_left = {}
        self._fun_kwargs = {}

        # mapping kwargs to kwargs of underlying functions
        self._ext_int_kw = {}

        self._get_parameter(left, 'l', left_idx)
        # Number of leading positional parameters that belong to `left`.
        self._param_left = len(left.params)

        try:
            self.name += MultiModel.op_repr[self._op]
        except KeyError:
            # Custom callables have no symbol; fall back to their str().
            self.name += str(op)

        self._get_parameter(right, 'r', right_idx)
        self.name += ')'

        self._param_len = len(self.params)

    def __str__(self):
        return self.name

    def _get_parameter(self, func, pos, idx):
        """Register parameters, bounds, keyword defaults and name of one operand.

        ``pos`` is ``'l'`` or ``'r'``. For a leaf model, the keyword defaults
        come from the signature of ``func.func`` and are renamed to
        ``<name>_<idx>``. A nested ``MultiModel`` already uses external names,
        so its keywords are passed through unchanged.
        """
        kw_dict = {'l': self._kwargs_left, 'r': self._kwargs_right}[pos]

        if isinstance(func, MultiModel):
            strcnt = ''
            kw_dict.update(func.fun_kwargs)
            self._fun_kwargs.update({k: v for k, v in kw_dict.items()})
            self._ext_int_kw.update({k: k for k in kw_dict.keys()})

        else:
            temp_dic = {k: v.default for k, v in signature(func.func).parameters.items()
                        if v.default is not Parameter.empty}

            for k, v in temp_dic.items():
                key_ = '%s_%d' % (k, idx)
                kw_dict[key_] = v
                self._fun_kwargs[key_] = v
                self._ext_int_kw[key_] = k

            strcnt = '(%d)' % idx

        self.params += [pp+strcnt for pp in func.params]
        self.name += func.name + strcnt

        try:
            self.bounds.extend(func.bounds)
        except AttributeError:
            self.bounds.extend([(None, None)]*len(func.params))

    def _select_kwargs(self, side_kwargs, kwargs):
        """Translate the caller's external keywords for one operand.

        Only keywords that belong to ``side_kwargs`` and were actually
        supplied are forwarded. For omitted keywords the operand applies its
        own default instead of the lookup raising ``KeyError``.
        """
        return {k_int: kwargs[k_ext] for k_ext, k_int in self._ext_int_kw.items()
                if k_ext in side_kwargs and k_ext in kwargs}

    def _left_arguments(self, *args, **kwargs):
        """Return ``(positional params, keyword args)`` for the left operand."""
        kw_left = self._select_kwargs(self._kwargs_left, kwargs)
        pl = args[:self._param_left]

        return pl, kw_left

    def _right_arguments(self, *args, **kwargs):
        """Return ``(positional params, keyword args)`` for the right operand."""
        kw_right = self._select_kwargs(self._kwargs_right, kwargs)
        pr = args[self._param_left:self._param_len]

        return pr, kw_right

    def func(self, x, *args, **kwargs):
        """Evaluate ``left.func(x, ...) <op> right.func(x, ...)``.

        ``args`` are the combined parameters in the order of ``self.params``.
        ``kwargs`` use the external names listed in ``fun_kwargs``.
        """
        pl, kw_left = self._left_arguments(*args, **kwargs)
        l_func = self._left.func(x, *pl, **kw_left)

        pr, kw_right = self._right_arguments(*args, **kwargs)
        r_func = self._right.func(x, *pr, **kw_right)

        return self._op(l_func, r_func)

    def left_func(self, x, *args, **kwargs):
        """Call the left operand's ``func`` directly with the given arguments."""
        return self._left.func(x, *args, **kwargs)

    def right_func(self, x, *args, **kwargs):
        """Call the right operand's ``func`` directly with the given arguments."""
        return self._right.func(x, *args, **kwargs)

    @property
    def fun_kwargs(self):
        """Keyword defaults of all leaf models, keyed by their external names."""
        return self._fun_kwargs

    def subs(self, x, *args, **kwargs):
        """ Iterator over all sub-functions (depth-first and left-to-right) """
        pl, kw_left = self._left_arguments(*args, **kwargs)
        if isinstance(self._left, MultiModel):
            yield from self._left.subs(x, *pl, **kw_left)
        else:
            yield self._left.func(x, *pl, **kw_left)

        pr, kw_right = self._right_arguments(*args, **kwargs)
        if isinstance(self._right, MultiModel):
            yield from self._right.subs(x, *pr, **kw_right)
        else:
            yield self._right.func(x, *pr, **kw_right)