from typing import Union, Callable, Any
import operator
from inspect import signature, Parameter


class ModelFactory:
    """Helper that assembles a combined model from a list of function descriptions."""

    @staticmethod
    def create_from_list(funcs: list, left=None, func_order=None, param_len=None, left_cnt=None):
        """
        Combine all active entries of ``funcs`` into one model.

        Every entry is a mapping that is read with the keys ``'active'``,
        ``'cnt'``, ``'pos'``, ``'func'``, ``'children'`` and ``'op'``.
        Inactive entries are skipped. An entry with children is first combined
        with its children (recursively, the entry itself being the left-most
        operand); the results are then chained left to right with the ``'op'``
        of the respective right-hand entry.

        Args:
            funcs: List of entry mappings as described above.
            left: Model that is used as left-most operand (used by the recursion).
            func_order: List that collects ``'cnt'`` of every active entry.
                A new list is created if omitted.
            param_len: List that collects the number of parameters of every
                active entry. A new list is created if omitted.
            left_cnt: Index that is appended to the names of ``left`` if it is
                not a MultiModel.

        Returns:
            Tuple ``(model, func_order, param_len)``. ``model`` is ``left`` if
            no entry is active, the single function if there is only one
            operand, and a MultiModel otherwise.
        """
        if func_order is None:
            func_order = []
        if param_len is None:
            param_len = []

        for entry in funcs:
            if not entry['active']:
                continue

            func_order.append(entry['cnt'])
            param_len.append(len(entry['func'].params))

            if entry['children']:
                right, _, _ = ModelFactory.create_from_list(
                    entry['children'],
                    left=entry['func'],
                    left_cnt=entry['pos'],
                    func_order=func_order,
                    param_len=param_len,
                )
                # If every child is inactive the recursion hands back the plain
                # function, which still needs its own index for naming.
                right_cnt = None if isinstance(right, MultiModel) else entry['pos']
            else:
                right = entry['func']
                right_cnt = entry['pos']

            if left is None:
                left = right
                left_cnt = right_cnt
            else:
                left = MultiModel(left, right, entry['op'], left_idx=left_cnt, right_idx=right_cnt)

        return left, func_order, param_len


class MultiModel:
    """
    Binary combination ``left <op> right`` of two models.

    Both operands must provide ``name``, ``params`` and a callable
    ``func(x, *params, **kwargs)``; ``bounds`` is optional. An operand may
    itself be a MultiModel. Keyword arguments with default values of plain
    operands are exposed as ``'<keyword>_<idx>'``.
    """

    op_repr = {operator.add: ' + ', operator.mul: ' * ', operator.sub: ' - ', operator.truediv: ' / '}
    str_op = {'+': operator.add, '*': operator.mul, '-': operator.sub, '/': operator.truediv}
    int_op = {0: operator.add, 1: operator.mul, 2: operator.sub, 3: operator.truediv}

    def __init__(self, left: Any, right: Any, op: Union[str, Callable, int] = '+', left_idx=0, right_idx=1):
        """
        Args:
            left: Left operand.
            right: Right operand.
            op: Binary operator, given as key of ``str_op``, key of ``int_op``
                or as callable taking two arguments.
            left_idx: Index appended to parameter and keyword names of ``left``
                (ignored if ``left`` is a MultiModel).
            right_idx: Same as ``left_idx`` for ``right``.

        Raises:
            ValueError: If ``op`` cannot be resolved to an operator.
        """
        self._left = left
        self._right = right

        self._op = None
        if isinstance(op, str):
            self._op = MultiModel.str_op.get(op)
        elif isinstance(op, int):
            self._op = MultiModel.int_op.get(op)
        elif callable(op):
            self._op = op

        if self._op is None:
            raise ValueError('Invalid binary operator.')

        self.name = '('
        self.params = []
        self.bounds = []
        self._kwargs_right = {}
        self._kwargs_left = {}
        self._fun_kwargs = {}

        # mapping kwargs to kwargs of underlying functions
        self._ext_int_kw = {}

        self._get_parameter(left, 'l', left_idx)
        self._param_left = len(left.params)

        try:
            self.name += MultiModel.op_repr[self._op]
        except KeyError:
            # operator without a known symbol
            self.name += str(op)

        self._get_parameter(right, 'r', right_idx)
        self.name += ')'

        self._param_len = len(self.params)

    def __str__(self):
        return self.name

    def _get_parameter(self, func, pos, idx):
        """
        Register name, parameters, bounds and keyword arguments of one operand.

        Args:
            func: The operand.
            pos: ``'l'`` for the left, ``'r'`` for the right operand.
            idx: Index used as suffix for names of a plain operand.
        """
        kw_dict = {'l': self._kwargs_left, 'r': self._kwargs_right}[pos]

        if isinstance(func, MultiModel):
            # Names of a nested model already carry their indices.
            suffix = ''
            kw_dict.update(func.fun_kwargs)
            self._fun_kwargs.update(kw_dict)
            self._ext_int_kw.update({key: key for key in kw_dict})
        else:
            suffix = '(%d)' % idx
            for kw_name, par in signature(func.func).parameters.items():
                if par.default is Parameter.empty:
                    continue
                ext_name = '%s_%d' % (kw_name, idx)
                kw_dict[ext_name] = par.default
                self._fun_kwargs[ext_name] = par.default
                self._ext_int_kw[ext_name] = kw_name

        self.params += [p_name + suffix for p_name in func.params]
        self.name += func.name + suffix

        try:
            self.bounds.extend(func.bounds)
        except AttributeError:
            # operand without bounds: every parameter is unbounded
            self.bounds.extend([(None, None)] * len(func.params))

    def _pick_kwargs(self, owned, kwargs):
        """
        Translate the given keyword arguments that belong to one operand.

        Keywords that were not passed are left out, so the underlying function
        uses its own default value instead of raising a KeyError here.
        """
        return {k_int: kwargs[k_ext]
                for k_ext, k_int in self._ext_int_kw.items()
                if k_ext in owned and k_ext in kwargs}

    def _left_arguments(self, *args, **kwargs):
        """Return positional and keyword arguments of the left operand."""
        return args[:self._param_left], self._pick_kwargs(self._kwargs_left, kwargs)

    def _right_arguments(self, *args, **kwargs):
        """Return positional and keyword arguments of the right operand."""
        return args[self._param_left:self._param_len], self._pick_kwargs(self._kwargs_right, kwargs)

    def func(self, x, *args, **kwargs):
        """
        Evaluate both operands at ``x`` and combine them with the operator.

        ``args`` are the parameters in the order of ``params``; ``kwargs`` use
        the names of ``fun_kwargs`` and may be given partially.
        """
        pl, kw_left = self._left_arguments(*args, **kwargs)
        pr, kw_right = self._right_arguments(*args, **kwargs)

        left_value = self._left.func(x, *pl, **kw_left)
        right_value = self._right.func(x, *pr, **kw_right)

        return self._op(left_value, right_value)

    def left_func(self, x, *args, **kwargs):
        """Evaluate the left operand; arguments are passed on unchanged."""
        return self._left.func(x, *args, **kwargs)

    def right_func(self, x, *args, **kwargs):
        """Evaluate the right operand; arguments are passed on unchanged."""
        return self._right.func(x, *args, **kwargs)

    @property
    def fun_kwargs(self):
        """Keyword arguments of all underlying functions with their default values."""
        return self._fun_kwargs

    def subs(self, x, *args, **kwargs):
        """ Iterator over all sub-functions (depth-first and left-to-right) """
        operands = (
            (self._left, self._left_arguments),
            (self._right, self._right_arguments),
        )
        for operand, split_arguments in operands:
            pars, kws = split_arguments(*args, **kwargs)
            if isinstance(operand, MultiModel):
                yield from operand.subs(x, *pars, **kws)
            else:
                yield operand.func(x, *pars, **kws)