1# Licensed under a 3-clause BSD style license - see LICENSE.rst 2 3""" 4This module defines base classes for all models. The base class of all 5models is `~astropy.modeling.Model`. `~astropy.modeling.FittableModel` is 6the base class for all fittable models. Fittable models can be linear or 7nonlinear in a regression analysis sense. 8 9All models provide a `__call__` method which performs the transformation in 10a purely mathematical way, i.e. the models are unitless. Model instances can 11represent either a single model, or a "model set" representing multiple copies 12of the same type of model, but with potentially different values of the 13parameters in each model making up the set. 14""" 15# pylint: disable=invalid-name, protected-access, redefined-outer-name 16import abc 17import copy 18import inspect 19import itertools 20import functools 21import operator 22import types 23 24from collections import defaultdict, deque 25from inspect import signature 26from itertools import chain 27 28import numpy as np 29 30from astropy.utils import indent, metadata 31from astropy.table import Table 32from astropy.units import Quantity, UnitsError, dimensionless_unscaled 33from astropy.units.utils import quantity_asanyarray 34from astropy.utils import (sharedmethod, find_current_module, 35 check_broadcast, IncompatibleShapeError, isiterable) 36from astropy.utils.codegen import make_function_with_signature 37from astropy.nddata.utils import add_array, extract_array 38from .utils import (combine_labels, make_binary_operator_eval, 39 get_inputs_and_params, _combine_equivalency_dict, 40 _ConstraintsDict, _SpecialOperatorsDict) 41from .bounding_box import ModelBoundingBox, CompoundBoundingBox 42from .parameters import (Parameter, InputParameterError, 43 param_repr_oneline, _tofloat) 44 45 46__all__ = ['Model', 'FittableModel', 'Fittable1DModel', 'Fittable2DModel', 47 'CompoundModel', 'fix_inputs', 'custom_model', 'ModelDefinitionError', 48 'bind_bounding_box', 
'bind_compound_bounding_box'] 49 50 51def _model_oper(oper, **kwargs): 52 """ 53 Returns a function that evaluates a given Python arithmetic operator 54 between two models. The operator should be given as a string, like ``'+'`` 55 or ``'**'``. 56 """ 57 return lambda left, right: CompoundModel(oper, left, right, **kwargs) 58 59 60class ModelDefinitionError(TypeError): 61 """Used for incorrect models definitions.""" 62 63 64class _ModelMeta(abc.ABCMeta): 65 """ 66 Metaclass for Model. 67 68 Currently just handles auto-generating the param_names list based on 69 Parameter descriptors declared at the class-level of Model subclasses. 70 """ 71 72 _is_dynamic = False 73 """ 74 This flag signifies whether this class was created in the "normal" way, 75 with a class statement in the body of a module, as opposed to a call to 76 `type` or some other metaclass constructor, such that the resulting class 77 does not belong to a specific module. This is important for pickling of 78 dynamic classes. 79 80 This flag is always forced to False for new classes, so code that creates 81 dynamic classes should manually set it to True on those classes when 82 creating them. 
83 """ 84 85 # Default empty dict for _parameters_, which will be empty on model 86 # classes that don't have any Parameters 87 88 def __new__(mcls, name, bases, members): 89 # See the docstring for _is_dynamic above 90 if '_is_dynamic' not in members: 91 members['_is_dynamic'] = mcls._is_dynamic 92 opermethods = [ 93 ('__add__', _model_oper('+')), 94 ('__sub__', _model_oper('-')), 95 ('__mul__', _model_oper('*')), 96 ('__truediv__', _model_oper('/')), 97 ('__pow__', _model_oper('**')), 98 ('__or__', _model_oper('|')), 99 ('__and__', _model_oper('&')), 100 ('_fix_inputs', _model_oper('fix_inputs')) 101 ] 102 103 members['_parameters_'] = {k: v for k, v in members.items() 104 if isinstance(v, Parameter)} 105 106 for opermethod, opercall in opermethods: 107 members[opermethod] = opercall 108 cls = super().__new__(mcls, name, bases, members) 109 110 param_names = list(members['_parameters_']) 111 112 # Need to walk each base MRO to collect all parameter names 113 for base in bases: 114 for tbase in base.__mro__: 115 if issubclass(tbase, Model): 116 # Preserve order of definitions 117 param_names = list(tbase._parameters_) + param_names 118 # Remove duplicates (arising from redefinition in subclass). 
119 param_names = list(dict.fromkeys(param_names)) 120 if cls._parameters_: 121 if hasattr(cls, '_param_names'): 122 # Slight kludge to support compound models, where 123 # cls.param_names is a property; could be improved with a 124 # little refactoring but fine for now 125 cls._param_names = tuple(param_names) 126 else: 127 cls.param_names = tuple(param_names) 128 129 return cls 130 131 def __init__(cls, name, bases, members): 132 super(_ModelMeta, cls).__init__(name, bases, members) 133 cls._create_inverse_property(members) 134 cls._create_bounding_box_property(members) 135 pdict = {} 136 for base in bases: 137 for tbase in base.__mro__: 138 if issubclass(tbase, Model): 139 for parname, val in cls._parameters_.items(): 140 pdict[parname] = val 141 cls._handle_special_methods(members, pdict) 142 143 def __repr__(cls): 144 """ 145 Custom repr for Model subclasses. 146 """ 147 148 return cls._format_cls_repr() 149 150 def _repr_pretty_(cls, p, cycle): 151 """ 152 Repr for IPython's pretty printer. 153 154 By default IPython "pretty prints" classes, so we need to implement 155 this so that IPython displays the custom repr for Models. 156 """ 157 158 p.text(repr(cls)) 159 160 def __reduce__(cls): 161 if not cls._is_dynamic: 162 # Just return a string specifying where the class can be imported 163 # from 164 return cls.__name__ 165 members = dict(cls.__dict__) 166 # Delete any ABC-related attributes--these will be restored when 167 # the class is reconstructed: 168 for key in list(members): 169 if key.startswith('_abc_'): 170 del members[key] 171 172 # Delete custom __init__ and __call__ if they exist: 173 for key in ('__init__', '__call__'): 174 if key in members: 175 del members[key] 176 177 return (type(cls), (cls.__name__, cls.__bases__, members)) 178 179 @property 180 def name(cls): 181 """ 182 The name of this model class--equivalent to ``cls.__name__``. 183 184 This attribute is provided for symmetry with the `Model.name` attribute 185 of model instances. 
186 """ 187 188 return cls.__name__ 189 190 @property 191 def _is_concrete(cls): 192 """ 193 A class-level property that determines whether the class is a concrete 194 implementation of a Model--i.e. it is not some abstract base class or 195 internal implementation detail (i.e. begins with '_'). 196 """ 197 return not (cls.__name__.startswith('_') or inspect.isabstract(cls)) 198 199 def rename(cls, name=None, inputs=None, outputs=None): 200 """ 201 Creates a copy of this model class with a new name, inputs or outputs. 202 203 The new class is technically a subclass of the original class, so that 204 instance and type checks will still work. For example:: 205 206 >>> from astropy.modeling.models import Rotation2D 207 >>> SkyRotation = Rotation2D.rename('SkyRotation') 208 >>> SkyRotation 209 <class 'astropy.modeling.core.SkyRotation'> 210 Name: SkyRotation (Rotation2D) 211 N_inputs: 2 212 N_outputs: 2 213 Fittable parameters: ('angle',) 214 >>> issubclass(SkyRotation, Rotation2D) 215 True 216 >>> r = SkyRotation(90) 217 >>> isinstance(r, Rotation2D) 218 True 219 """ 220 221 mod = find_current_module(2) 222 if mod: 223 modname = mod.__name__ 224 else: 225 modname = '__main__' 226 227 if name is None: 228 name = cls.name 229 if inputs is None: 230 inputs = cls.inputs 231 else: 232 if not isinstance(inputs, tuple): 233 raise TypeError("Expected 'inputs' to be a tuple of strings.") 234 elif len(inputs) != len(cls.inputs): 235 raise ValueError(f'{cls.name} expects {len(cls.inputs)} inputs') 236 if outputs is None: 237 outputs = cls.outputs 238 else: 239 if not isinstance(outputs, tuple): 240 raise TypeError("Expected 'outputs' to be a tuple of strings.") 241 elif len(outputs) != len(cls.outputs): 242 raise ValueError(f'{cls.name} expects {len(cls.outputs)} outputs') 243 new_cls = type(name, (cls,), {"inputs": inputs, "outputs": outputs}) 244 new_cls.__module__ = modname 245 new_cls.__qualname__ = name 246 247 return new_cls 248 249 def _create_inverse_property(cls, 
members): 250 inverse = members.get('inverse') 251 if inverse is None or cls.__bases__[0] is object: 252 # The latter clause is the prevent the below code from running on 253 # the Model base class, which implements the default getter and 254 # setter for .inverse 255 return 256 257 if isinstance(inverse, property): 258 # We allow the @property decorator to be omitted entirely from 259 # the class definition, though its use should be encouraged for 260 # clarity 261 inverse = inverse.fget 262 263 # Store the inverse getter internally, then delete the given .inverse 264 # attribute so that cls.inverse resolves to Model.inverse instead 265 cls._inverse = inverse 266 del cls.inverse 267 268 def _create_bounding_box_property(cls, members): 269 """ 270 Takes any bounding_box defined on a concrete Model subclass (either 271 as a fixed tuple or a property or method) and wraps it in the generic 272 getter/setter interface for the bounding_box attribute. 273 """ 274 275 # TODO: Much of this is verbatim from _create_inverse_property--I feel 276 # like there could be a way to generify properties that work this way, 277 # but for the time being that would probably only confuse things more. 
278 bounding_box = members.get('bounding_box') 279 if bounding_box is None or cls.__bases__[0] is object: 280 return 281 282 if isinstance(bounding_box, property): 283 bounding_box = bounding_box.fget 284 285 if not callable(bounding_box): 286 # See if it's a hard-coded bounding_box (as a sequence) and 287 # normalize it 288 try: 289 bounding_box = ModelBoundingBox.validate(cls, bounding_box) 290 except ValueError as exc: 291 raise ModelDefinitionError(exc.args[0]) 292 else: 293 sig = signature(bounding_box) 294 # May be a method that only takes 'self' as an argument (like a 295 # property, but the @property decorator was forgotten) 296 # 297 # However, if the method takes additional arguments then this is a 298 # parameterized bounding box and should be callable 299 if len(sig.parameters) > 1: 300 bounding_box = \ 301 cls._create_bounding_box_subclass(bounding_box, sig) 302 303 # See the Model.bounding_box getter definition for how this attribute 304 # is used 305 cls._bounding_box = bounding_box 306 del cls.bounding_box 307 308 def _create_bounding_box_subclass(cls, func, sig): 309 """ 310 For Models that take optional arguments for defining their bounding 311 box, we create a subclass of ModelBoundingBox with a ``__call__`` method 312 that supports those additional arguments. 313 314 Takes the function's Signature as an argument since that is already 315 computed in _create_bounding_box_property, so no need to duplicate that 316 effort. 317 """ 318 319 # TODO: Might be convenient if calling the bounding box also 320 # automatically sets the _user_bounding_box. So that 321 # 322 # >>> model.bounding_box(arg=1) 323 # 324 # in addition to returning the computed bbox, also sets it, so that 325 # it's a shortcut for 326 # 327 # >>> model.bounding_box = model.bounding_box(arg=1) 328 # 329 # Not sure if that would be non-obvious / confusing though... 
330 331 def __call__(self, **kwargs): 332 return func(self._model, **kwargs) 333 334 kwargs = [] 335 for idx, param in enumerate(sig.parameters.values()): 336 if idx == 0: 337 # Presumed to be a 'self' argument 338 continue 339 340 if param.default is param.empty: 341 raise ModelDefinitionError( 342 'The bounding_box method for {0} is not correctly ' 343 'defined: If defined as a method all arguments to that ' 344 'method (besides self) must be keyword arguments with ' 345 'default values that can be used to compute a default ' 346 'bounding box.'.format(cls.name)) 347 348 kwargs.append((param.name, param.default)) 349 350 __call__.__signature__ = sig 351 352 return type(f'{cls.name}ModelBoundingBox', (ModelBoundingBox,), 353 {'__call__': __call__}) 354 355 def _handle_special_methods(cls, members, pdict): 356 357 # Handle init creation from inputs 358 def update_wrapper(wrapper, cls): 359 # Set up the new __call__'s metadata attributes as though it were 360 # manually defined in the class definition 361 # A bit like functools.update_wrapper but uses the class instead of 362 # the wrapped function 363 wrapper.__module__ = cls.__module__ 364 wrapper.__doc__ = getattr(cls, wrapper.__name__).__doc__ 365 if hasattr(cls, '__qualname__'): 366 wrapper.__qualname__ = f'{cls.__qualname__}.{wrapper.__name__}' 367 368 if ('__call__' not in members and 'n_inputs' in members and 369 isinstance(members['n_inputs'], int) and members['n_inputs'] > 0): 370 371 # Don't create a custom __call__ for classes that already have one 372 # explicitly defined (this includes the Model base class, and any 373 # other classes that manually override __call__ 374 375 def __call__(self, *inputs, **kwargs): 376 """Evaluate this model on the supplied inputs.""" 377 return super(cls, self).__call__(*inputs, **kwargs) 378 379 # When called, models can take two optional keyword arguments: 380 # 381 # * model_set_axis, which indicates (for multi-dimensional input) 382 # which axis is used to indicate 
different models 383 # 384 # * equivalencies, a dictionary of equivalencies to be applied to 385 # the input values, where each key should correspond to one of 386 # the inputs. 387 # 388 # The following code creates the __call__ function with these 389 # two keyword arguments. 390 391 args = ('self',) 392 kwargs = dict([('model_set_axis', None), 393 ('with_bounding_box', False), 394 ('fill_value', np.nan), 395 ('equivalencies', None), 396 ('inputs_map', None)]) 397 398 new_call = make_function_with_signature( 399 __call__, args, kwargs, varargs='inputs', varkwargs='new_inputs') 400 401 # The following makes it look like __call__ 402 # was defined in the class 403 update_wrapper(new_call, cls) 404 405 cls.__call__ = new_call 406 407 if ('__init__' not in members and not inspect.isabstract(cls) and 408 cls._parameters_): 409 # Build list of all parameters including inherited ones 410 411 # If *all* the parameters have default values we can make them 412 # keyword arguments; otherwise they must all be positional 413 # arguments 414 if all(p.default is not None for p in pdict.values()): 415 args = ('self',) 416 kwargs = [] 417 for param_name, param_val in pdict.items(): 418 default = param_val.default 419 unit = param_val.unit 420 # If the unit was specified in the parameter but the 421 # default is not a Quantity, attach the unit to the 422 # default. 
423 if unit is not None: 424 default = Quantity(default, unit, copy=False) 425 kwargs.append((param_name, default)) 426 else: 427 args = ('self',) + tuple(pdict.keys()) 428 kwargs = {} 429 430 def __init__(self, *params, **kwargs): 431 return super(cls, self).__init__(*params, **kwargs) 432 433 new_init = make_function_with_signature( 434 __init__, args, kwargs, varkwargs='kwargs') 435 update_wrapper(new_init, cls) 436 cls.__init__ = new_init 437 438 # *** Arithmetic operators for creating compound models *** 439 __add__ = _model_oper('+') 440 __sub__ = _model_oper('-') 441 __mul__ = _model_oper('*') 442 __truediv__ = _model_oper('/') 443 __pow__ = _model_oper('**') 444 __or__ = _model_oper('|') 445 __and__ = _model_oper('&') 446 _fix_inputs = _model_oper('fix_inputs') 447 448 # *** Other utilities *** 449 450 def _format_cls_repr(cls, keywords=[]): 451 """ 452 Internal implementation of ``__repr__``. 453 454 This is separated out for ease of use by subclasses that wish to 455 override the default ``__repr__`` while keeping the same basic 456 formatting. 
457 """ 458 459 # For the sake of familiarity start the output with the standard class 460 # __repr__ 461 parts = [super().__repr__()] 462 463 if not cls._is_concrete: 464 return parts[0] 465 466 def format_inheritance(cls): 467 bases = [] 468 for base in cls.mro()[1:]: 469 if not issubclass(base, Model): 470 continue 471 elif (inspect.isabstract(base) or 472 base.__name__.startswith('_')): 473 break 474 bases.append(base.name) 475 if bases: 476 return f"{cls.name} ({' -> '.join(bases)})" 477 return cls.name 478 479 try: 480 default_keywords = [ 481 ('Name', format_inheritance(cls)), 482 ('N_inputs', cls.n_inputs), 483 ('N_outputs', cls.n_outputs), 484 ] 485 486 if cls.param_names: 487 default_keywords.append(('Fittable parameters', 488 cls.param_names)) 489 490 for keyword, value in default_keywords + keywords: 491 if value is not None: 492 parts.append(f'{keyword}: {value}') 493 494 return '\n'.join(parts) 495 except Exception: 496 # If any of the above formatting fails fall back on the basic repr 497 # (this is particularly useful in debugging) 498 return parts[0] 499 500 501class Model(metaclass=_ModelMeta): 502 """ 503 Base class for all models. 504 505 This is an abstract class and should not be instantiated directly. 506 507 The following initialization arguments apply to the majority of Model 508 subclasses by default (exceptions include specialized utility models 509 like `~astropy.modeling.mappings.Mapping`). Parametric models take all 510 their parameters as arguments, followed by any of the following optional 511 keyword arguments: 512 513 Parameters 514 ---------- 515 name : str, optional 516 A human-friendly name associated with this model instance 517 (particularly useful for identifying the individual components of a 518 compound model). 519 520 meta : dict, optional 521 An optional dict of user-defined metadata to attach to this model. 522 How this is used and interpreted is up to the user or individual use 523 case. 
524 525 n_models : int, optional 526 If given an integer greater than 1, a *model set* is instantiated 527 instead of a single model. This affects how the parameter arguments 528 are interpreted. In this case each parameter must be given as a list 529 or array--elements of this array are taken along the first axis (or 530 ``model_set_axis`` if specified), such that the Nth element is the 531 value of that parameter for the Nth model in the set. 532 533 See the section on model sets in the documentation for more details. 534 535 model_set_axis : int, optional 536 This argument only applies when creating a model set (i.e. ``n_models > 537 1``). It changes how parameter values are interpreted. Normally the 538 first axis of each input parameter array (properly the 0th axis) is 539 taken as the axis corresponding to the model sets. However, any axis 540 of an input array may be taken as this "model set axis". This accepts 541 negative integers as well--for example use ``model_set_axis=-1`` if the 542 last (most rapidly changing) axis should be associated with the model 543 sets. Also, ``model_set_axis=False`` can be used to tell that a given 544 input should be used to evaluate all the models in the model set. 545 546 fixed : dict, optional 547 Dictionary ``{parameter_name: bool}`` setting the fixed constraint 548 for one or more parameters. `True` means the parameter is held fixed 549 during fitting and is prevented from updates once an instance of the 550 model has been created. 551 552 Alternatively the `~astropy.modeling.Parameter.fixed` property of a 553 parameter may be used to lock or unlock individual parameters. 554 555 tied : dict, optional 556 Dictionary ``{parameter_name: callable}`` of parameters which are 557 linked to some other parameter. The dictionary values are callables 558 providing the linking relationship. 
559 560 Alternatively the `~astropy.modeling.Parameter.tied` property of a 561 parameter may be used to set the ``tied`` constraint on individual 562 parameters. 563 564 bounds : dict, optional 565 A dictionary ``{parameter_name: value}`` of lower and upper bounds of 566 parameters. Keys are parameter names. Values are a list or a tuple 567 of length 2 giving the desired range for the parameter. 568 569 Alternatively the `~astropy.modeling.Parameter.min` and 570 `~astropy.modeling.Parameter.max` or 571 ~astropy.modeling.Parameter.bounds` properties of a parameter may be 572 used to set bounds on individual parameters. 573 574 eqcons : list, optional 575 List of functions of length n such that ``eqcons[j](x0, *args) == 0.0`` 576 in a successfully optimized problem. 577 578 ineqcons : list, optional 579 List of functions of length n such that ``ieqcons[j](x0, *args) >= 580 0.0`` is a successfully optimized problem. 581 582 Examples 583 -------- 584 >>> from astropy.modeling import models 585 >>> def tie_center(model): 586 ... mean = 50 * model.stddev 587 ... return mean 588 >>> tied_parameters = {'mean': tie_center} 589 590 Specify that ``'mean'`` is a tied parameter in one of two ways: 591 592 >>> g1 = models.Gaussian1D(amplitude=10, mean=5, stddev=.3, 593 ... tied=tied_parameters) 594 595 or 596 597 >>> g1 = models.Gaussian1D(amplitude=10, mean=5, stddev=.3) 598 >>> g1.mean.tied 599 False 600 >>> g1.mean.tied = tie_center 601 >>> g1.mean.tied 602 <function tie_center at 0x...> 603 604 Fixed parameters: 605 606 >>> g1 = models.Gaussian1D(amplitude=10, mean=5, stddev=.3, 607 ... 
fixed={'stddev': True}) 608 >>> g1.stddev.fixed 609 True 610 611 or 612 613 >>> g1 = models.Gaussian1D(amplitude=10, mean=5, stddev=.3) 614 >>> g1.stddev.fixed 615 False 616 >>> g1.stddev.fixed = True 617 >>> g1.stddev.fixed 618 True 619 """ 620 621 parameter_constraints = Parameter.constraints 622 """ 623 Primarily for informational purposes, these are the types of constraints 624 that can be set on a model's parameters. 625 """ 626 627 model_constraints = ('eqcons', 'ineqcons') 628 """ 629 Primarily for informational purposes, these are the types of constraints 630 that constrain model evaluation. 631 """ 632 633 param_names = () 634 """ 635 Names of the parameters that describe models of this type. 636 637 The parameters in this tuple are in the same order they should be passed in 638 when initializing a model of a specific type. Some types of models, such 639 as polynomial models, have a different number of parameters depending on 640 some other property of the model, such as the degree. 641 642 When defining a custom model class the value of this attribute is 643 automatically set by the `~astropy.modeling.Parameter` attributes defined 644 in the class body. 645 """ 646 647 n_inputs = 0 648 """The number of inputs.""" 649 n_outputs = 0 650 """ The number of outputs.""" 651 652 standard_broadcasting = True 653 fittable = False 654 linear = True 655 _separable = None 656 """ A boolean flag to indicate whether a model is separable.""" 657 meta = metadata.MetaData() 658 """A dict-like object to store optional information.""" 659 660 # By default models either use their own inverse property or have no 661 # inverse at all, but users may also assign a custom inverse to a model, 662 # optionally; in that case it is of course up to the user to determine 663 # whether their inverse is *actually* an inverse to the model they assign 664 # it to. 
665 _inverse = None 666 _user_inverse = None 667 668 _bounding_box = None 669 _user_bounding_box = None 670 671 _has_inverse_bounding_box = False 672 673 # Default n_models attribute, so that __len__ is still defined even when a 674 # model hasn't completed initialization yet 675 _n_models = 1 676 677 # New classes can set this as a boolean value. 678 # It is converted to a dictionary mapping input name to a boolean value. 679 _input_units_strict = False 680 681 # Allow dimensionless input (and corresponding output). If this is True, 682 # input values to evaluate will gain the units specified in input_units. If 683 # this is a dictionary then it should map input name to a bool to allow 684 # dimensionless numbers for that input. 685 # Only has an effect if input_units is defined. 686 _input_units_allow_dimensionless = False 687 688 # Default equivalencies to apply to input values. If set, this should be a 689 # dictionary where each key is a string that corresponds to one of the 690 # model inputs. Only has an effect if input_units is defined. 691 input_units_equivalencies = None 692 693 # Covariance matrix can be set by fitter if available. 
694 # If cov_matrix is available, then std will set as well 695 _cov_matrix = None 696 _stds = None 697 698 def __init__(self, *args, meta=None, name=None, **kwargs): 699 super().__init__() 700 self._default_inputs_outputs() 701 if meta is not None: 702 self.meta = meta 703 self._name = name 704 # add parameters to instance level by walking MRO list 705 mro = self.__class__.__mro__ 706 for cls in mro: 707 if issubclass(cls, Model): 708 for parname, val in cls._parameters_.items(): 709 newpar = copy.deepcopy(val) 710 newpar.model = self 711 if parname not in self.__dict__: 712 self.__dict__[parname] = newpar 713 714 self._initialize_constraints(kwargs) 715 kwargs = self._initialize_setters(kwargs) 716 # Remaining keyword args are either parameter values or invalid 717 # Parameter values must be passed in as keyword arguments in order to 718 # distinguish them 719 self._initialize_parameters(args, kwargs) 720 self._initialize_slices() 721 self._initialize_unit_support() 722 723 def _default_inputs_outputs(self): 724 if self.n_inputs == 1 and self.n_outputs == 1: 725 self._inputs = ("x",) 726 self._outputs = ("y",) 727 elif self.n_inputs == 2 and self.n_outputs == 1: 728 self._inputs = ("x", "y") 729 self._outputs = ("z",) 730 else: 731 try: 732 self._inputs = tuple("x" + str(idx) for idx in range(self.n_inputs)) 733 self._outputs = tuple("x" + str(idx) for idx in range(self.n_outputs)) 734 except TypeError: 735 # self.n_inputs and self.n_outputs are properties 736 # This is the case when subclasses of Model do not define 737 # ``n_inputs``, ``n_outputs``, ``inputs`` or ``outputs``. 738 self._inputs = () 739 self._outputs = () 740 741 def _initialize_setters(self, kwargs): 742 """ 743 This exists to inject defaults for settable properties for models 744 originating from `custom_model`. 
745 """ 746 if hasattr(self, '_settable_properties'): 747 setters = {name: kwargs.pop(name, default) 748 for name, default in self._settable_properties.items()} 749 for name, value in setters.items(): 750 setattr(self, name, value) 751 752 return kwargs 753 754 @property 755 def inputs(self): 756 return self._inputs 757 758 @inputs.setter 759 def inputs(self, val): 760 if len(val) != self.n_inputs: 761 raise ValueError(f"Expected {self.n_inputs} number of inputs, got {len(val)}.") 762 self._inputs = val 763 self._initialize_unit_support() 764 765 @property 766 def outputs(self): 767 return self._outputs 768 769 @outputs.setter 770 def outputs(self, val): 771 if len(val) != self.n_outputs: 772 raise ValueError(f"Expected {self.n_outputs} number of outputs, got {len(val)}.") 773 self._outputs = val 774 775 @property 776 def n_inputs(self): 777 # TODO: remove the code in the ``if`` block when support 778 # for models with ``inputs`` as class variables is removed. 779 if hasattr(self.__class__, 'n_inputs') and isinstance(self.__class__.n_inputs, property): 780 try: 781 return len(self.__class__.inputs) 782 except TypeError: 783 try: 784 return len(self.inputs) 785 except AttributeError: 786 return 0 787 788 return self.__class__.n_inputs 789 790 @property 791 def n_outputs(self): 792 # TODO: remove the code in the ``if`` block when support 793 # for models with ``outputs`` as class variables is removed. 794 if hasattr(self.__class__, 'n_outputs') and isinstance(self.__class__.n_outputs, property): 795 try: 796 return len(self.__class__.outputs) 797 except TypeError: 798 try: 799 return len(self.outputs) 800 except AttributeError: 801 return 0 802 803 return self.__class__.n_outputs 804 805 def _initialize_unit_support(self): 806 """ 807 Convert self._input_units_strict and 808 self.input_units_allow_dimensionless to dictionaries 809 mapping input name to a boolean value. 
810 """ 811 if isinstance(self._input_units_strict, bool): 812 self._input_units_strict = {key: self._input_units_strict for 813 key in self.inputs} 814 815 if isinstance(self._input_units_allow_dimensionless, bool): 816 self._input_units_allow_dimensionless = {key: self._input_units_allow_dimensionless 817 for key in self.inputs} 818 819 @property 820 def input_units_strict(self): 821 """ 822 Enforce strict units on inputs to evaluate. If this is set to True, 823 input values to evaluate will be in the exact units specified by 824 input_units. If the input quantities are convertible to input_units, 825 they are converted. If this is a dictionary then it should map input 826 name to a bool to set strict input units for that parameter. 827 """ 828 val = self._input_units_strict 829 if isinstance(val, bool): 830 return {key: val for key in self.inputs} 831 return dict(zip(self.inputs, val.values())) 832 833 @property 834 def input_units_allow_dimensionless(self): 835 """ 836 Allow dimensionless input (and corresponding output). If this is True, 837 input values to evaluate will gain the units specified in input_units. If 838 this is a dictionary then it should map input name to a bool to allow 839 dimensionless numbers for that input. 840 Only has an effect if input_units is defined. 841 """ 842 843 val = self._input_units_allow_dimensionless 844 if isinstance(val, bool): 845 return {key: val for key in self.inputs} 846 return dict(zip(self.inputs, val.values())) 847 848 @property 849 def uses_quantity(self): 850 """ 851 True if this model has been created with `~astropy.units.Quantity` 852 objects or if there are no parameters. 853 854 This can be used to determine if this model should be evaluated with 855 `~astropy.units.Quantity` or regular floats. 
856 """ 857 pisq = [isinstance(p, Quantity) for p in self._param_sets(units=True)] 858 return (len(pisq) == 0) or any(pisq) 859 860 def __repr__(self): 861 return self._format_repr() 862 863 def __str__(self): 864 return self._format_str() 865 866 def __len__(self): 867 return self._n_models 868 869 @staticmethod 870 def _strip_ones(intup): 871 return tuple(item for item in intup if item != 1) 872 873 def __setattr__(self, attr, value): 874 if isinstance(self, CompoundModel): 875 param_names = self._param_names 876 param_names = self.param_names 877 878 if param_names is not None and attr in self.param_names: 879 param = self.__dict__[attr] 880 value = _tofloat(value) 881 if param._validator is not None: 882 param._validator(self, value) 883 # check consistency with previous shape and size 884 eshape = self._param_metrics[attr]['shape'] 885 if eshape == (): 886 eshape = (1,) 887 vshape = np.array(value).shape 888 if vshape == (): 889 vshape = (1,) 890 esize = self._param_metrics[attr]['size'] 891 if (np.size(value) != esize or 892 self._strip_ones(vshape) != self._strip_ones(eshape)): 893 raise InputParameterError( 894 "Value for parameter {0} does not match shape or size\n" 895 "expected by model ({1}, {2}) vs ({3}, {4})".format( 896 attr, vshape, np.size(value), eshape, esize)) 897 if param.unit is None: 898 if isinstance(value, Quantity): 899 param._unit = value.unit 900 param.value = value.value 901 else: 902 param.value = value 903 else: 904 if not isinstance(value, Quantity): 905 raise UnitsError(f"The '{param.name}' parameter should be given as a" 906 " Quantity because it was originally " 907 "initialized as a Quantity") 908 param._unit = value.unit 909 param.value = value.value 910 else: 911 if attr in ['fittable', 'linear']: 912 self.__dict__[attr] = value 913 else: 914 super().__setattr__(attr, value) 915 916 def _pre_evaluate(self, *args, **kwargs): 917 """ 918 Model specific input setup that needs to occur prior to model evaluation 919 """ 920 921 # 
Broadcast inputs into common size 922 inputs, broadcasted_shapes = self.prepare_inputs(*args, **kwargs) 923 924 # Setup actual model evaluation method 925 parameters = self._param_sets(raw=True, units=True) 926 927 def evaluate(_inputs): 928 return self.evaluate(*chain(_inputs, parameters)) 929 930 return evaluate, inputs, broadcasted_shapes, kwargs 931 932 def get_bounding_box(self, with_bbox=True): 933 """ 934 Return the ``bounding_box`` of a model if it exists or ``None`` 935 otherwise. 936 937 Parameters 938 ---------- 939 with_bbox : 940 The value of the ``with_bounding_box`` keyword argument 941 when calling the model. Default is `True` for usage when 942 looking up the model's ``bounding_box`` without risk of error. 943 """ 944 bbox = None 945 946 if not isinstance(with_bbox, bool) or with_bbox: 947 try: 948 bbox = self.bounding_box 949 except NotImplementedError: 950 pass 951 952 if isinstance(bbox, CompoundBoundingBox) and not isinstance(with_bbox, bool): 953 bbox = bbox[with_bbox] 954 955 return bbox 956 957 @property 958 def _argnames(self): 959 """The inputs used to determine input_shape for bounding_box evaluation""" 960 return self.inputs 961 962 def _validate_input_shape(self, _input, idx, argnames, model_set_axis, check_model_set_axis): 963 """ 964 Perform basic validation of a single model input's shape 965 -- it has the minimum dimensions for the given model_set_axis 966 967 Returns the shape of the input if validation succeeds. 
968 """ 969 input_shape = np.shape(_input) 970 # Ensure that the input's model_set_axis matches the model's 971 # n_models 972 if input_shape and check_model_set_axis: 973 # Note: Scalar inputs *only* get a pass on this 974 if len(input_shape) < model_set_axis + 1: 975 raise ValueError( 976 f"For model_set_axis={model_set_axis}, all inputs must be at " 977 f"least {model_set_axis + 1}-dimensional.") 978 if input_shape[model_set_axis] != self._n_models: 979 try: 980 argname = argnames[idx] 981 except IndexError: 982 # the case of model.inputs = () 983 argname = str(idx) 984 985 raise ValueError( 986 f"Input argument '{argname}' does not have the correct " 987 f"dimensions in model_set_axis={model_set_axis} for a model set with " 988 f"n_models={self._n_models}.") 989 990 return input_shape 991 992 def _validate_input_shapes(self, inputs, argnames, model_set_axis): 993 """ 994 Perform basic validation of model inputs 995 --that they are mutually broadcastable and that they have 996 the minimum dimensions for the given model_set_axis. 997 998 If validation succeeds, returns the total shape that will result from 999 broadcasting the input arrays with each other. 
1000 """ 1001 1002 check_model_set_axis = self._n_models > 1 and model_set_axis is not False 1003 1004 all_shapes = [] 1005 for idx, _input in enumerate(inputs): 1006 all_shapes.append(self._validate_input_shape(_input, idx, argnames, 1007 model_set_axis, check_model_set_axis)) 1008 1009 input_shape = check_broadcast(*all_shapes) 1010 if input_shape is None: 1011 raise ValueError( 1012 "All inputs must have identical shapes or must be scalars.") 1013 1014 return input_shape 1015 1016 def input_shape(self, inputs): 1017 """Get input shape for bounding_box evaluation""" 1018 return self._validate_input_shapes(inputs, self._argnames, self.model_set_axis) 1019 1020 def _generic_evaluate(self, evaluate, _inputs, fill_value, with_bbox): 1021 """ 1022 Generic model evaluation routine 1023 Selects and evaluates model with or without bounding_box enforcement 1024 """ 1025 1026 # Evaluate the model using the prepared evaluation method either 1027 # enforcing the bounding_box or not. 1028 bbox = self.get_bounding_box(with_bbox) 1029 if (not isinstance(with_bbox, bool) or with_bbox) and bbox is not None: 1030 outputs = bbox.evaluate(evaluate, _inputs, fill_value) 1031 else: 1032 outputs = evaluate(_inputs) 1033 return outputs 1034 1035 def _post_evaluate(self, inputs, outputs, broadcasted_shapes, with_bbox, **kwargs): 1036 """ 1037 Model specific post evaluation processing of outputs 1038 """ 1039 if self.get_bounding_box(with_bbox) is None and self.n_outputs == 1: 1040 outputs = (outputs,) 1041 1042 outputs = self.prepare_outputs(broadcasted_shapes, *outputs, **kwargs) 1043 outputs = self._process_output_units(inputs, outputs) 1044 1045 if self.n_outputs == 1: 1046 return outputs[0] 1047 return outputs 1048 1049 @property 1050 def bbox_with_units(self): 1051 return (not isinstance(self, CompoundModel)) 1052 1053 def __call__(self, *args, **kwargs): 1054 """ 1055 Evaluate this model using the given input(s) and the parameter values 1056 that were specified when the model was 
instantiated. 1057 """ 1058 # Turn any keyword arguments into positional arguments. 1059 args, kwargs = self._get_renamed_inputs_as_positional(*args, **kwargs) 1060 1061 # Read model evaluation related parameters 1062 with_bbox = kwargs.pop('with_bounding_box', False) 1063 fill_value = kwargs.pop('fill_value', np.nan) 1064 1065 # prepare for model evaluation (overridden in CompoundModel) 1066 evaluate, inputs, broadcasted_shapes, kwargs = self._pre_evaluate(*args, **kwargs) 1067 1068 outputs = self._generic_evaluate(evaluate, inputs, 1069 fill_value, with_bbox) 1070 1071 # post-process evaluation results (overridden in CompoundModel) 1072 return self._post_evaluate(inputs, outputs, broadcasted_shapes, with_bbox, **kwargs) 1073 1074 def _get_renamed_inputs_as_positional(self, *args, **kwargs): 1075 def _keyword2positional(kwargs): 1076 # Inputs were passed as keyword (not positional) arguments. 1077 # Because the signature of the ``__call__`` is defined at 1078 # the class level, the name of the inputs cannot be changed at 1079 # the instance level and the old names are always present in the 1080 # signature of the method. In order to use the new names of the 1081 # inputs, the old names are taken out of ``kwargs``, the input 1082 # values are sorted in the order of self.inputs and passed as 1083 # positional arguments to ``__call__``. 1084 1085 # These are the keys that are always present as keyword arguments. 1086 keys = ['model_set_axis', 'with_bounding_box', 'fill_value', 1087 'equivalencies', 'inputs_map'] 1088 1089 new_inputs = {} 1090 # kwargs contain the names of the new inputs + ``keys`` 1091 allkeys = list(kwargs.keys()) 1092 # Remove the names of the new inputs from kwargs and save them 1093 # to a dict ``new_inputs``. 
1094 for key in allkeys: 1095 if key not in keys: 1096 new_inputs[key] = kwargs[key] 1097 del kwargs[key] 1098 return new_inputs, kwargs 1099 n_args = len(args) 1100 1101 new_inputs, kwargs = _keyword2positional(kwargs) 1102 n_all_args = n_args + len(new_inputs) 1103 1104 if n_all_args < self.n_inputs: 1105 raise ValueError(f"Missing input arguments - expected {self.n_inputs}, got {n_all_args}") 1106 elif n_all_args > self.n_inputs: 1107 raise ValueError(f"Too many input arguments - expected {self.n_inputs}, got {n_all_args}") 1108 if n_args == 0: 1109 # Create positional arguments from the keyword arguments in ``new_inputs``. 1110 new_args = [] 1111 for k in self.inputs: 1112 new_args.append(new_inputs[k]) 1113 elif n_args != self.n_inputs: 1114 # Some inputs are passed as positional, others as keyword arguments. 1115 args = list(args) 1116 1117 # Create positional arguments from the keyword arguments in ``new_inputs``. 1118 new_args = [] 1119 for k in self.inputs: 1120 if k in new_inputs: 1121 new_args.append(new_inputs[k]) 1122 else: 1123 new_args.append(args[0]) 1124 del args[0] 1125 else: 1126 new_args = args 1127 return new_args, kwargs 1128 1129 # *** Properties *** 1130 @property 1131 def name(self): 1132 """User-provided name for this model instance.""" 1133 1134 return self._name 1135 1136 @name.setter 1137 def name(self, val): 1138 """Assign a (new) name to this model.""" 1139 1140 self._name = val 1141 1142 @property 1143 def model_set_axis(self): 1144 """ 1145 The index of the model set axis--that is the axis of a parameter array 1146 that pertains to which model a parameter value pertains to--as 1147 specified when the model was initialized. 1148 1149 See the documentation on :ref:`astropy:modeling-model-sets` 1150 for more details. 1151 """ 1152 1153 return self._model_set_axis 1154 1155 @property 1156 def param_sets(self): 1157 """ 1158 Return parameters as a pset. 
1159 1160 This is a list with one item per parameter set, which is an array of 1161 that parameter's values across all parameter sets, with the last axis 1162 associated with the parameter set. 1163 """ 1164 1165 return self._param_sets() 1166 1167 @property 1168 def parameters(self): 1169 """ 1170 A flattened array of all parameter values in all parameter sets. 1171 1172 Fittable parameters maintain this list and fitters modify it. 1173 """ 1174 1175 # Currently the sequence of a model's parameters must be contiguous 1176 # within the _parameters array (which may be a view of a larger array, 1177 # for example when taking a sub-expression of a compound model), so 1178 # the assumption here is reliable: 1179 if not self.param_names: 1180 # Trivial, but not unheard of 1181 return self._parameters 1182 1183 self._parameters_to_array() 1184 start = self._param_metrics[self.param_names[0]]['slice'].start 1185 stop = self._param_metrics[self.param_names[-1]]['slice'].stop 1186 1187 return self._parameters[start:stop] 1188 1189 @parameters.setter 1190 def parameters(self, value): 1191 """ 1192 Assigning to this attribute updates the parameters array rather than 1193 replacing it. 1194 """ 1195 1196 if not self.param_names: 1197 return 1198 1199 start = self._param_metrics[self.param_names[0]]['slice'].start 1200 stop = self._param_metrics[self.param_names[-1]]['slice'].stop 1201 1202 try: 1203 value = np.array(value).flatten() 1204 self._parameters[start:stop] = value 1205 except ValueError as e: 1206 raise InputParameterError( 1207 "Input parameter values not compatible with the model " 1208 "parameters array: {0}".format(e)) 1209 self._array_to_parameters() 1210 1211 @property 1212 def sync_constraints(self): 1213 ''' 1214 This is a boolean property that indicates whether or not accessing constraints 1215 automatically check the constituent models current values. 
It defaults to True 1216 on creation of a model, but for fitting purposes it should be set to False 1217 for performance reasons. 1218 ''' 1219 if not hasattr(self, '_sync_constraints'): 1220 self._sync_constraints = True 1221 return self._sync_constraints 1222 1223 @sync_constraints.setter 1224 def sync_constraints(self, value): 1225 if not isinstance(value, bool): 1226 raise ValueError('sync_constraints only accepts True or False as values') 1227 self._sync_constraints = value 1228 1229 @property 1230 def fixed(self): 1231 """ 1232 A ``dict`` mapping parameter names to their fixed constraint. 1233 """ 1234 if not hasattr(self, '_fixed') or self.sync_constraints: 1235 self._fixed = _ConstraintsDict(self, 'fixed') 1236 return self._fixed 1237 1238 @property 1239 def bounds(self): 1240 """ 1241 A ``dict`` mapping parameter names to their upper and lower bounds as 1242 ``(min, max)`` tuples or ``[min, max]`` lists. 1243 """ 1244 if not hasattr(self, '_bounds') or self.sync_constraints: 1245 self._bounds = _ConstraintsDict(self, 'bounds') 1246 return self._bounds 1247 1248 @property 1249 def tied(self): 1250 """ 1251 A ``dict`` mapping parameter names to their tied constraint. 1252 """ 1253 if not hasattr(self, '_tied') or self.sync_constraints: 1254 self._tied = _ConstraintsDict(self, 'tied') 1255 return self._tied 1256 1257 @property 1258 def eqcons(self): 1259 """List of parameter equality constraints.""" 1260 1261 return self._mconstraints['eqcons'] 1262 1263 @property 1264 def ineqcons(self): 1265 """List of parameter inequality constraints.""" 1266 1267 return self._mconstraints['ineqcons'] 1268 1269 def has_inverse(self): 1270 """ 1271 Returns True if the model has an analytic or user 1272 inverse defined. 
1273 """ 1274 try: 1275 self.inverse 1276 except NotImplementedError: 1277 return False 1278 1279 return True 1280 1281 @property 1282 def inverse(self): 1283 """ 1284 Returns a new `~astropy.modeling.Model` instance which performs the 1285 inverse transform, if an analytic inverse is defined for this model. 1286 1287 Even on models that don't have an inverse defined, this property can be 1288 set with a manually-defined inverse, such a pre-computed or 1289 experimentally determined inverse (often given as a 1290 `~astropy.modeling.polynomial.PolynomialModel`, but not by 1291 requirement). 1292 1293 A custom inverse can be deleted with ``del model.inverse``. In this 1294 case the model's inverse is reset to its default, if a default exists 1295 (otherwise the default is to raise `NotImplementedError`). 1296 1297 Note to authors of `~astropy.modeling.Model` subclasses: To define an 1298 inverse for a model simply override this property to return the 1299 appropriate model representing the inverse. The machinery that will 1300 make the inverse manually-overridable is added automatically by the 1301 base class. 
1302 """ 1303 if self._user_inverse is not None: 1304 return self._user_inverse 1305 elif self._inverse is not None: 1306 result = self._inverse() 1307 if result is not NotImplemented: 1308 if not self._has_inverse_bounding_box: 1309 result.bounding_box = None 1310 return result 1311 1312 raise NotImplementedError("No analytical or user-supplied inverse transform " 1313 "has been implemented for this model.") 1314 1315 @inverse.setter 1316 def inverse(self, value): 1317 if not isinstance(value, (Model, type(None))): 1318 raise ValueError( 1319 "The ``inverse`` attribute may be assigned a `Model` " 1320 "instance or `None` (where `None` explicitly forces the " 1321 "model to have no inverse.") 1322 1323 self._user_inverse = value 1324 1325 @inverse.deleter 1326 def inverse(self): 1327 """ 1328 Resets the model's inverse to its default (if one exists, otherwise 1329 the model will have no inverse). 1330 """ 1331 1332 try: 1333 del self._user_inverse 1334 except AttributeError: 1335 pass 1336 1337 @property 1338 def has_user_inverse(self): 1339 """ 1340 A flag indicating whether or not a custom inverse model has been 1341 assigned to this model by a user, via assignment to ``model.inverse``. 1342 """ 1343 return self._user_inverse is not None 1344 1345 @property 1346 def bounding_box(self): 1347 r""" 1348 A `tuple` of length `n_inputs` defining the bounding box limits, or 1349 raise `NotImplementedError` for no bounding_box. 1350 1351 The default limits are given by a ``bounding_box`` property or method 1352 defined in the class body of a specific model. If not defined then 1353 this property just raises `NotImplementedError` by default (but may be 1354 assigned a custom value by a user). ``bounding_box`` can be set 1355 manually to an array-like object of shape ``(model.n_inputs, 2)``. 
For 1356 further usage, see :ref:`astropy:bounding-boxes` 1357 1358 The limits are ordered according to the `numpy` ``'C'`` indexing 1359 convention, and are the reverse of the model input order, 1360 e.g. for inputs ``('x', 'y', 'z')``, ``bounding_box`` is defined: 1361 1362 * for 1D: ``(x_low, x_high)`` 1363 * for 2D: ``((y_low, y_high), (x_low, x_high))`` 1364 * for 3D: ``((z_low, z_high), (y_low, y_high), (x_low, x_high))`` 1365 1366 Examples 1367 -------- 1368 1369 Setting the ``bounding_box`` limits for a 1D and 2D model: 1370 1371 >>> from astropy.modeling.models import Gaussian1D, Gaussian2D 1372 >>> model_1d = Gaussian1D() 1373 >>> model_2d = Gaussian2D(x_stddev=1, y_stddev=1) 1374 >>> model_1d.bounding_box = (-5, 5) 1375 >>> model_2d.bounding_box = ((-6, 6), (-5, 5)) 1376 1377 Setting the bounding_box limits for a user-defined 3D `custom_model`: 1378 1379 >>> from astropy.modeling.models import custom_model 1380 >>> def const3d(x, y, z, amp=1): 1381 ... return amp 1382 ... 1383 >>> Const3D = custom_model(const3d) 1384 >>> model_3d = Const3D() 1385 >>> model_3d.bounding_box = ((-6, 6), (-5, 5), (-4, 4)) 1386 1387 To reset ``bounding_box`` to its default limits just delete the 1388 user-defined value--this will reset it back to the default defined 1389 on the class: 1390 1391 >>> del model_1d.bounding_box 1392 1393 To disable the bounding box entirely (including the default), 1394 set ``bounding_box`` to `None`: 1395 1396 >>> model_1d.bounding_box = None 1397 >>> model_1d.bounding_box # doctest: +IGNORE_EXCEPTION_DETAIL 1398 Traceback (most recent call last): 1399 NotImplementedError: No bounding box is defined for this model 1400 (note: the bounding box was explicitly disabled for this model; 1401 use `del model.bounding_box` to restore the default bounding box, 1402 if one is defined for this model). 
1403 """ 1404 1405 if self._user_bounding_box is not None: 1406 if self._user_bounding_box is NotImplemented: 1407 raise NotImplementedError( 1408 "No bounding box is defined for this model (note: the " 1409 "bounding box was explicitly disabled for this model; " 1410 "use `del model.bounding_box` to restore the default " 1411 "bounding box, if one is defined for this model).") 1412 return self._user_bounding_box 1413 elif self._bounding_box is None: 1414 raise NotImplementedError( 1415 "No bounding box is defined for this model.") 1416 elif isinstance(self._bounding_box, ModelBoundingBox): 1417 # This typically implies a hard-coded bounding box. This will 1418 # probably be rare, but it is an option 1419 return self._bounding_box 1420 elif isinstance(self._bounding_box, types.MethodType): 1421 return ModelBoundingBox.validate(self, self._bounding_box()) 1422 else: 1423 # The only other allowed possibility is that it's a ModelBoundingBox 1424 # subclass, so we call it with its default arguments and return an 1425 # instance of it (that can be called to recompute the bounding box 1426 # with any optional parameters) 1427 # (In other words, in this case self._bounding_box is a *class*) 1428 bounding_box = self._bounding_box((), model=self)() 1429 return self._bounding_box(bounding_box, model=self) 1430 1431 @bounding_box.setter 1432 def bounding_box(self, bounding_box): 1433 """ 1434 Assigns the bounding box limits. 
1435 """ 1436 1437 if bounding_box is None: 1438 cls = None 1439 # We use this to explicitly set an unimplemented bounding box (as 1440 # opposed to no user bounding box defined) 1441 bounding_box = NotImplemented 1442 elif (isinstance(bounding_box, CompoundBoundingBox) or 1443 isinstance(bounding_box, dict)): 1444 cls = CompoundBoundingBox 1445 elif (isinstance(self._bounding_box, type) and 1446 issubclass(self._bounding_box, ModelBoundingBox)): 1447 cls = self._bounding_box 1448 else: 1449 cls = ModelBoundingBox 1450 1451 if cls is not None: 1452 try: 1453 bounding_box = cls.validate(self, bounding_box) 1454 except ValueError as exc: 1455 raise ValueError(exc.args[0]) 1456 1457 self._user_bounding_box = bounding_box 1458 1459 def set_slice_args(self, *args): 1460 if isinstance(self._user_bounding_box, CompoundBoundingBox): 1461 self._user_bounding_box.slice_args = args 1462 else: 1463 raise RuntimeError('The bounding_box for this model is not compound') 1464 1465 @bounding_box.deleter 1466 def bounding_box(self): 1467 self._user_bounding_box = None 1468 1469 @property 1470 def has_user_bounding_box(self): 1471 """ 1472 A flag indicating whether or not a custom bounding_box has been 1473 assigned to this model by a user, via assignment to 1474 ``model.bounding_box``. 1475 """ 1476 1477 return self._user_bounding_box is not None 1478 1479 @property 1480 def cov_matrix(self): 1481 """ 1482 Fitter should set covariance matrix, if available. 
1483 """ 1484 return self._cov_matrix 1485 1486 @cov_matrix.setter 1487 def cov_matrix(self, cov): 1488 1489 self._cov_matrix = cov 1490 1491 unfix_untied_params = [p for p in self.param_names if (self.fixed[p] is False) 1492 and (self.tied[p] is False)] 1493 if type(cov) == list: # model set 1494 param_stds = [] 1495 for c in cov: 1496 param_stds.append([np.sqrt(x) if x > 0 else None for x in np.diag(c.cov_matrix)]) 1497 for p, param_name in enumerate(unfix_untied_params): 1498 par = getattr(self, param_name) 1499 par.std = [item[p] for item in param_stds] 1500 setattr(self, param_name, par) 1501 else: 1502 param_stds = [np.sqrt(x) if x > 0 else None for x in np.diag(cov.cov_matrix)] 1503 for param_name in unfix_untied_params: 1504 par = getattr(self, param_name) 1505 par.std = param_stds.pop(0) 1506 setattr(self, param_name, par) 1507 1508 @property 1509 def stds(self): 1510 """ 1511 Standard deviation of parameters, if covariance matrix is available. 1512 """ 1513 return self._stds 1514 1515 @stds.setter 1516 def stds(self, stds): 1517 self._stds = stds 1518 1519 @property 1520 def separable(self): 1521 """ A flag indicating whether a model is separable.""" 1522 1523 if self._separable is not None: 1524 return self._separable 1525 raise NotImplementedError( 1526 'The "separable" property is not defined for ' 1527 'model {}'.format(self.__class__.__name__)) 1528 1529 # *** Public methods *** 1530 1531 def without_units_for_data(self, **kwargs): 1532 """ 1533 Return an instance of the model for which the parameter values have 1534 been converted to the right units for the data, then the units have 1535 been stripped away. 1536 1537 The input and output Quantity objects should be given as keyword 1538 arguments. 1539 1540 Notes 1541 ----- 1542 1543 This method is needed in order to be able to fit models with units in 1544 the parameters, since we need to temporarily strip away the units from 1545 the model during the fitting (which might be done by e.g. 
scipy 1546 functions). 1547 1548 The units that the parameters should be converted to are not 1549 necessarily the units of the input data, but are derived from them. 1550 Model subclasses that want fitting to work in the presence of 1551 quantities need to define a ``_parameter_units_for_data_units`` method 1552 that takes the input and output units (as two dictionaries) and 1553 returns a dictionary giving the target units for each parameter. 1554 1555 """ 1556 model = self.copy() 1557 1558 inputs_unit = {inp: getattr(kwargs[inp], 'unit', dimensionless_unscaled) 1559 for inp in self.inputs if kwargs[inp] is not None} 1560 1561 outputs_unit = {out: getattr(kwargs[out], 'unit', dimensionless_unscaled) 1562 for out in self.outputs if kwargs[out] is not None} 1563 parameter_units = self._parameter_units_for_data_units(inputs_unit, 1564 outputs_unit) 1565 for name, unit in parameter_units.items(): 1566 parameter = getattr(model, name) 1567 if parameter.unit is not None: 1568 parameter.value = parameter.quantity.to(unit).value 1569 parameter._set_unit(None, force=True) 1570 1571 if isinstance(model, CompoundModel): 1572 model.strip_units_from_tree() 1573 1574 return model 1575 1576 def strip_units_from_tree(self): 1577 for item in self._leaflist: 1578 for parname in item.param_names: 1579 par = getattr(item, parname) 1580 par._set_unit(None, force=True) 1581 1582 def with_units_from_data(self, **kwargs): 1583 """ 1584 Return an instance of the model which has units for which the parameter 1585 values are compatible with the data units specified. 1586 1587 The input and output Quantity objects should be given as keyword 1588 arguments. 1589 1590 Notes 1591 ----- 1592 1593 This method is needed in order to be able to fit models with units in 1594 the parameters, since we need to temporarily strip away the units from 1595 the model during the fitting (which might be done by e.g. scipy 1596 functions). 
1597 1598 The units that the parameters will gain are not necessarily the units 1599 of the input data, but are derived from them. Model subclasses that 1600 want fitting to work in the presence of quantities need to define a 1601 ``_parameter_units_for_data_units`` method that takes the input and output 1602 units (as two dictionaries) and returns a dictionary giving the target 1603 units for each parameter. 1604 """ 1605 1606 model = self.copy() 1607 inputs_unit = {inp: getattr(kwargs[inp], 'unit', dimensionless_unscaled) 1608 for inp in self.inputs if kwargs[inp] is not None} 1609 1610 outputs_unit = {out: getattr(kwargs[out], 'unit', dimensionless_unscaled) 1611 for out in self.outputs if kwargs[out] is not None} 1612 1613 parameter_units = self._parameter_units_for_data_units(inputs_unit, 1614 outputs_unit) 1615 1616 # We are adding units to parameters that already have a value, but we 1617 # don't want to convert the parameter, just add the unit directly, 1618 # hence the call to ``_set_unit``. 
1619 for name, unit in parameter_units.items(): 1620 parameter = getattr(model, name) 1621 parameter._set_unit(unit, force=True) 1622 1623 return model 1624 1625 @property 1626 def _has_units(self): 1627 # Returns True if any of the parameters have units 1628 for param in self.param_names: 1629 if getattr(self, param).unit is not None: 1630 return True 1631 else: 1632 return False 1633 1634 @property 1635 def _supports_unit_fitting(self): 1636 # If the model has a ``_parameter_units_for_data_units`` method, this 1637 # indicates that we have enough information to strip the units away 1638 # and add them back after fitting, when fitting quantities 1639 return hasattr(self, '_parameter_units_for_data_units') 1640 1641 @abc.abstractmethod 1642 def evaluate(self, *args, **kwargs): 1643 """Evaluate the model on some input variables.""" 1644 1645 def sum_of_implicit_terms(self, *args, **kwargs): 1646 """ 1647 Evaluate the sum of any implicit model terms on some input variables. 1648 This includes any fixed terms used in evaluating a linear model that 1649 do not have corresponding parameters exposed to the user. The 1650 prototypical case is `astropy.modeling.functional_models.Shift`, which 1651 corresponds to a function y = a + bx, where b=1 is intrinsically fixed 1652 by the type of model, such that sum_of_implicit_terms(x) == x. This 1653 method is needed by linear fitters to correct the dependent variable 1654 for the implicit term(s) when solving for the remaining terms 1655 (ie. a = y - bx). 1656 """ 1657 1658 def render(self, out=None, coords=None): 1659 """ 1660 Evaluate a model at fixed positions, respecting the ``bounding_box``. 1661 1662 The key difference relative to evaluating the model directly is that 1663 this method is limited to a bounding box if the `Model.bounding_box` 1664 attribute is set. 1665 1666 Parameters 1667 ---------- 1668 out : `numpy.ndarray`, optional 1669 An array that the evaluated model will be added to. 
If this is not 1670 given (or given as ``None``), a new array will be created. 1671 coords : array-like, optional 1672 An array to be used to translate from the model's input coordinates 1673 to the ``out`` array. It should have the property that 1674 ``self(coords)`` yields the same shape as ``out``. If ``out`` is 1675 not specified, ``coords`` will be used to determine the shape of 1676 the returned array. If this is not provided (or None), the model 1677 will be evaluated on a grid determined by `Model.bounding_box`. 1678 1679 Returns 1680 ------- 1681 out : `numpy.ndarray` 1682 The model added to ``out`` if ``out`` is not ``None``, or else a 1683 new array from evaluating the model over ``coords``. 1684 If ``out`` and ``coords`` are both `None`, the returned array is 1685 limited to the `Model.bounding_box` limits. If 1686 `Model.bounding_box` is `None`, ``arr`` or ``coords`` must be 1687 passed. 1688 1689 Raises 1690 ------ 1691 ValueError 1692 If ``coords`` are not given and the the `Model.bounding_box` of 1693 this model is not set. 
1694 1695 Examples 1696 -------- 1697 :ref:`astropy:bounding-boxes` 1698 """ 1699 1700 try: 1701 bbox = self.bounding_box 1702 except NotImplementedError: 1703 bbox = None 1704 1705 if isinstance(bbox, ModelBoundingBox): 1706 bbox = bbox.bounding_box() 1707 1708 ndim = self.n_inputs 1709 1710 if (coords is None) and (out is None) and (bbox is None): 1711 raise ValueError('If no bounding_box is set, ' 1712 'coords or out must be input.') 1713 1714 # for consistent indexing 1715 if ndim == 1: 1716 if coords is not None: 1717 coords = [coords] 1718 if bbox is not None: 1719 bbox = [bbox] 1720 1721 if coords is not None: 1722 coords = np.asanyarray(coords, dtype=float) 1723 # Check dimensions match out and model 1724 assert len(coords) == ndim 1725 if out is not None: 1726 if coords[0].shape != out.shape: 1727 raise ValueError('inconsistent shape of the output.') 1728 else: 1729 out = np.zeros(coords[0].shape) 1730 1731 if out is not None: 1732 out = np.asanyarray(out) 1733 if out.ndim != ndim: 1734 raise ValueError('the array and model must have the same ' 1735 'number of dimensions.') 1736 1737 if bbox is not None: 1738 # Assures position is at center pixel, 1739 # important when using add_array. 1740 pd = np.array([(np.mean(bb), np.ceil((bb[1] - bb[0]) / 2)) 1741 for bb in bbox]).astype(int).T 1742 pos, delta = pd 1743 1744 if coords is not None: 1745 sub_shape = tuple(delta * 2 + 1) 1746 sub_coords = np.array([extract_array(c, sub_shape, pos) 1747 for c in coords]) 1748 else: 1749 limits = [slice(p - d, p + d + 1, 1) for p, d in pd.T] 1750 sub_coords = np.mgrid[limits] 1751 1752 sub_coords = sub_coords[::-1] 1753 1754 if out is None: 1755 out = self(*sub_coords) 1756 else: 1757 try: 1758 out = add_array(out, self(*sub_coords), pos) 1759 except ValueError: 1760 raise ValueError( 1761 'The `bounding_box` is larger than the input out in ' 1762 'one or more dimensions. 
Set ' 1763 '`model.bounding_box = None`.') 1764 else: 1765 if coords is None: 1766 im_shape = out.shape 1767 limits = [slice(i) for i in im_shape] 1768 coords = np.mgrid[limits] 1769 1770 coords = coords[::-1] 1771 1772 out += self(*coords) 1773 1774 return out 1775 1776 @property 1777 def input_units(self): 1778 """ 1779 This property is used to indicate what units or sets of units the 1780 evaluate method expects, and returns a dictionary mapping inputs to 1781 units (or `None` if any units are accepted). 1782 1783 Model sub-classes can also use function annotations in evaluate to 1784 indicate valid input units, in which case this property should 1785 not be overridden since it will return the input units based on the 1786 annotations. 1787 """ 1788 if hasattr(self, '_input_units'): 1789 return self._input_units 1790 elif hasattr(self.evaluate, '__annotations__'): 1791 annotations = self.evaluate.__annotations__.copy() 1792 annotations.pop('return', None) 1793 if annotations: 1794 # If there are not annotations for all inputs this will error. 1795 return dict((name, annotations[name]) for name in self.inputs) 1796 else: 1797 # None means any unit is accepted 1798 return None 1799 1800 @property 1801 def return_units(self): 1802 """ 1803 This property is used to indicate what units or sets of units the 1804 output of evaluate should be in, and returns a dictionary mapping 1805 outputs to units (or `None` if any units are accepted). 1806 1807 Model sub-classes can also use function annotations in evaluate to 1808 indicate valid output units, in which case this property should not be 1809 overridden since it will return the return units based on the 1810 annotations. 
1811 """ 1812 if hasattr(self, '_return_units'): 1813 return self._return_units 1814 elif hasattr(self.evaluate, '__annotations__'): 1815 return self.evaluate.__annotations__.get('return', None) 1816 else: 1817 # None means any unit is accepted 1818 return None 1819 1820 def _prepare_inputs_single_model(self, params, inputs, **kwargs): 1821 broadcasts = [] 1822 for idx, _input in enumerate(inputs): 1823 input_shape = _input.shape 1824 1825 # Ensure that array scalars are always upgrade to 1-D arrays for the 1826 # sake of consistency with how parameters work. They will be cast back 1827 # to scalars at the end 1828 if not input_shape: 1829 inputs[idx] = _input.reshape((1,)) 1830 1831 if not params: 1832 max_broadcast = input_shape 1833 else: 1834 max_broadcast = () 1835 1836 for param in params: 1837 try: 1838 if self.standard_broadcasting: 1839 broadcast = check_broadcast(input_shape, param.shape) 1840 else: 1841 broadcast = input_shape 1842 except IncompatibleShapeError: 1843 raise ValueError( 1844 "self input argument {0!r} of shape {1!r} cannot be " 1845 "broadcast with parameter {2!r} of shape " 1846 "{3!r}.".format(self.inputs[idx], input_shape, 1847 param.name, param.shape)) 1848 1849 if len(broadcast) > len(max_broadcast): 1850 max_broadcast = broadcast 1851 elif len(broadcast) == len(max_broadcast): 1852 max_broadcast = max(max_broadcast, broadcast) 1853 1854 broadcasts.append(max_broadcast) 1855 1856 if self.n_outputs > self.n_inputs: 1857 extra_outputs = self.n_outputs - self.n_inputs 1858 if not broadcasts: 1859 # If there were no inputs then the broadcasts list is empty 1860 # just add a None since there is no broadcasting of outputs and 1861 # inputs necessary (see _prepare_outputs_single_self) 1862 broadcasts.append(None) 1863 broadcasts.extend([broadcasts[0]] * extra_outputs) 1864 1865 return inputs, (broadcasts,) 1866 1867 @staticmethod 1868 def _remove_axes_from_shape(shape, axis): 1869 """ 1870 Given a shape tuple as the first input, construct a 
new one by removing 1871 that particular axis from the shape and all preceeding axes. Negative axis 1872 numbers are permittted, where the axis is relative to the last axis. 1873 """ 1874 if len(shape) == 0: 1875 return shape 1876 if axis < 0: 1877 axis = len(shape) + axis 1878 return shape[:axis] + shape[axis+1:] 1879 if axis >= len(shape): 1880 axis = len(shape)-1 1881 shape = shape[axis+1:] 1882 return shape 1883 1884 def _prepare_inputs_model_set(self, params, inputs, model_set_axis_input, 1885 **kwargs): 1886 reshaped = [] 1887 pivots = [] 1888 1889 model_set_axis_param = self.model_set_axis # needed to reshape param 1890 for idx, _input in enumerate(inputs): 1891 max_param_shape = () 1892 if self._n_models > 1 and model_set_axis_input is not False: 1893 # Use the shape of the input *excluding* the model axis 1894 input_shape = (_input.shape[:model_set_axis_input] + 1895 _input.shape[model_set_axis_input + 1:]) 1896 else: 1897 input_shape = _input.shape 1898 1899 for param in params: 1900 try: 1901 check_broadcast(input_shape, 1902 self._remove_axes_from_shape(param.shape, 1903 model_set_axis_param)) 1904 except IncompatibleShapeError: 1905 raise ValueError( 1906 "Model input argument {0!r} of shape {1!r} cannot be " 1907 "broadcast with parameter {2!r} of shape " 1908 "{3!r}.".format(self.inputs[idx], input_shape, 1909 param.name, 1910 self._remove_axes_from_shape(param.shape, 1911 model_set_axis_param))) 1912 1913 if len(param.shape) - 1 > len(max_param_shape): 1914 max_param_shape = self._remove_axes_from_shape(param.shape, 1915 model_set_axis_param) 1916 1917 # We've now determined that, excluding the model_set_axis, the 1918 # input can broadcast with all the parameters 1919 input_ndim = len(input_shape) 1920 if model_set_axis_input is False: 1921 if len(max_param_shape) > input_ndim: 1922 # Just needs to prepend new axes to the input 1923 n_new_axes = 1 + len(max_param_shape) - input_ndim 1924 new_axes = (1,) * n_new_axes 1925 new_shape = new_axes + 
_input.shape 1926 pivot = model_set_axis_param 1927 else: 1928 pivot = input_ndim - len(max_param_shape) 1929 new_shape = (_input.shape[:pivot] + (1,) + 1930 _input.shape[pivot:]) 1931 new_input = _input.reshape(new_shape) 1932 else: 1933 if len(max_param_shape) >= input_ndim: 1934 n_new_axes = len(max_param_shape) - input_ndim 1935 pivot = self.model_set_axis 1936 new_axes = (1,) * n_new_axes 1937 new_shape = (_input.shape[:pivot + 1] + new_axes + 1938 _input.shape[pivot + 1:]) 1939 new_input = _input.reshape(new_shape) 1940 else: 1941 pivot = _input.ndim - len(max_param_shape) - 1 1942 new_input = np.rollaxis(_input, model_set_axis_input, 1943 pivot + 1) 1944 pivots.append(pivot) 1945 reshaped.append(new_input) 1946 1947 if self.n_inputs < self.n_outputs: 1948 pivots.extend([model_set_axis_input] * (self.n_outputs - self.n_inputs)) 1949 1950 return reshaped, (pivots,) 1951 1952 def prepare_inputs(self, *inputs, model_set_axis=None, equivalencies=None, 1953 **kwargs): 1954 """ 1955 This method is used in `~astropy.modeling.Model.__call__` to ensure 1956 that all the inputs to the model can be broadcast into compatible 1957 shapes (if one or both of them are input as arrays), particularly if 1958 there are more than one parameter sets. This also makes sure that (if 1959 applicable) the units of the input will be compatible with the evaluate 1960 method. 1961 """ 1962 # When we instantiate the model class, we make sure that __call__ can 1963 # take the following two keyword arguments: model_set_axis and 1964 # equivalencies. 
def _validate_input_units(self, inputs, equivalencies=None, inputs_map=None):
    """
    Check ``inputs`` against the units declared in ``self.input_units``.

    Parameters
    ----------
    inputs : iterable
        Input values, positionally matching ``self.inputs``.
    equivalencies : dict, optional
        Call-time equivalencies; combined with
        ``self.input_units_equivalencies`` before checking.
    inputs_map : iterable of ``(model, mapping)`` pairs, optional
        When given, only pairs whose model is ``self`` are used:
        ``equivalencies[mapping[1]]`` is stored under the alias
        ``mapping[0]`` before the equivalencies are combined.

    Returns
    -------
    list
        A new list of the inputs.  A Quantity input whose unit is
        equivalent to the required unit is converted to that unit when the
        combined equivalency dict is non-empty or
        ``self.input_units_strict`` is set for that input.  All other
        entries are returned as given.

    Raises
    ------
    UnitsError
        If a Quantity input is not convertible to the required unit, or if
        a non-Quantity input containing a non-zero value is given where a
        non-dimensionless unit is required and dimensionless input is not
        allowed for that input.
    """
    inputs = list(inputs)
    name = self.name or self.__class__.__name__

    # Read the property once instead of on every loop iteration.
    input_units = self.input_units
    if input_units is None:
        # No unit requirements declared: nothing to check.
        return inputs

    # If an inputs_map is provided this is the context of a compound model
    # and the equivalencies must be re-keyed to this model's input names.
    if inputs_map:
        edict = {}
        for mod, mapping in inputs_map:
            if self is mod:
                edict[mapping[0]] = equivalencies[mapping[1]]
    else:
        edict = equivalencies

    # Combine instance-level equivalencies with the call-time ones.
    input_units_equivalencies = _combine_equivalency_dict(
        self.inputs, edict, self.input_units_equivalencies)

    for i, value in enumerate(inputs):
        input_name = self.inputs[i]
        input_unit = input_units.get(input_name, None)

        if input_unit is None:
            continue

        if isinstance(value, Quantity):
            equivs = input_units_equivalencies[input_name]

            if not value.unit.is_equivalent(input_unit, equivalencies=equivs):
                # Two cases are separated only to give a nicer message.
                if input_unit is dimensionless_unscaled:
                    raise UnitsError(
                        f"{name}: Units of input '{input_name}', "
                        f"{value.unit} ({value.unit.physical_type}), "
                        "could not be converted to required "
                        "dimensionless input")
                raise UnitsError(
                    f"{name}: Units of input '{input_name}', "
                    f"{value.unit} ({value.unit.physical_type}), "
                    "could not be converted to required input units of "
                    f"{input_unit} ({input_unit.physical_type})")

            # Some equivalencies are non-linear, so the model must be
            # evaluated in its own units; strict inputs are always
            # converted as well.
            if (len(input_units_equivalencies) > 0
                    or self.input_units_strict[input_name]):
                inputs[i] = value.to(input_unit, equivalencies=equivs)

        elif (not self.input_units_allow_dimensionless[input_name]
              and input_unit is not dimensionless_unscaled
              and np.any(value != 0)):
            # Plain numbers are only accepted when they are all zero or
            # dimensionless input is explicitly allowed.
            raise UnitsError(
                f"{name}: Units of input '{input_name}', (dimensionless), "
                "could not be converted to required input units of "
                f"{input_unit} ({input_unit.physical_type})")

    return inputs
def prepare_outputs(self, broadcasted_shapes, *outputs, **kwargs):
    """
    Post-process raw ``outputs`` via the single-model or model-set helper.

    A model of length 1 is handled by ``_prepare_outputs_single_model``;
    anything else goes to ``_prepare_outputs_model_set`` together with the
    optional ``model_set_axis`` keyword (``None`` when not supplied).
    """
    if len(self) != 1:
        axis = kwargs.get('model_set_axis', None)
        return self._prepare_outputs_model_set(outputs, broadcasted_shapes, axis)

    return self._prepare_outputs_single_model(outputs, broadcasted_shapes)
If tuple, the elements are units 2189 to attach in order corresponding to `Model.outputs`. 2190 input_units_equivalencies : dict, optional 2191 Default equivalencies to apply to input values. If set, this should be a 2192 dictionary where each key is a string that corresponds to one of the 2193 model inputs. 2194 input_units_allow_dimensionless : bool or dict, optional 2195 Allow dimensionless input. If this is True, input values to evaluate will 2196 gain the units specified in input_units. If this is a dictionary then it 2197 should map input name to a bool to allow dimensionless numbers for that 2198 input. 2199 2200 Returns 2201 ------- 2202 `CompoundModel` 2203 A `CompoundModel` composed of the current model plus 2204 `~astropy.modeling.mappings.UnitsMapping` model(s) that attach the units. 2205 2206 Raises 2207 ------ 2208 ValueError 2209 If the current model already has units. 2210 2211 Examples 2212 -------- 2213 2214 Wrapping a unitless model to require and convert units: 2215 2216 >>> from astropy.modeling.models import Polynomial1D 2217 >>> from astropy import units as u 2218 >>> poly = Polynomial1D(1, c0=1, c1=2) 2219 >>> model = poly.coerce_units((u.m,), (u.s,)) 2220 >>> model(u.Quantity(10, u.m)) # doctest: +FLOAT_CMP 2221 <Quantity 21. s> 2222 >>> model(u.Quantity(1000, u.cm)) # doctest: +FLOAT_CMP 2223 <Quantity 21. s> 2224 >>> model(u.Quantity(10, u.cm)) # doctest: +FLOAT_CMP 2225 <Quantity 1.2 s> 2226 2227 Wrapping a unitless model but still permitting unitless input: 2228 2229 >>> from astropy.modeling.models import Polynomial1D 2230 >>> from astropy import units as u 2231 >>> poly = Polynomial1D(1, c0=1, c1=2) 2232 >>> model = poly.coerce_units((u.m,), (u.s,), input_units_allow_dimensionless=True) 2233 >>> model(u.Quantity(10, u.m)) # doctest: +FLOAT_CMP 2234 <Quantity 21. s> 2235 >>> model(10) # doctest: +FLOAT_CMP 2236 <Quantity 21. 
s> 2237 """ 2238 from .mappings import UnitsMapping 2239 2240 result = self 2241 2242 if input_units is not None: 2243 if self.input_units is not None: 2244 model_units = self.input_units 2245 else: 2246 model_units = {} 2247 2248 for unit in [model_units.get(i) for i in self.inputs]: 2249 if unit is not None and unit != dimensionless_unscaled: 2250 raise ValueError("Cannot specify input_units for model with existing input units") 2251 2252 if isinstance(input_units, dict): 2253 if input_units.keys() != set(self.inputs): 2254 message = ( 2255 f"""input_units keys ({", ".join(input_units.keys())}) """ 2256 f"""do not match model inputs ({", ".join(self.inputs)})""" 2257 ) 2258 raise ValueError(message) 2259 input_units = [input_units[i] for i in self.inputs] 2260 2261 if len(input_units) != self.n_inputs: 2262 message = ( 2263 "input_units length does not match n_inputs: " 2264 f"expected {self.n_inputs}, received {len(input_units)}" 2265 ) 2266 raise ValueError(message) 2267 2268 mapping = tuple((unit, model_units.get(i)) for i, unit in zip(self.inputs, input_units)) 2269 input_mapping = UnitsMapping( 2270 mapping, 2271 input_units_equivalencies=input_units_equivalencies, 2272 input_units_allow_dimensionless=input_units_allow_dimensionless 2273 ) 2274 input_mapping.inputs = self.inputs 2275 input_mapping.outputs = self.inputs 2276 result = input_mapping | result 2277 2278 if return_units is not None: 2279 if self.return_units is not None: 2280 model_units = self.return_units 2281 else: 2282 model_units = {} 2283 2284 for unit in [model_units.get(i) for i in self.outputs]: 2285 if unit is not None and unit != dimensionless_unscaled: 2286 raise ValueError("Cannot specify return_units for model with existing output units") 2287 2288 if isinstance(return_units, dict): 2289 if return_units.keys() != set(self.outputs): 2290 message = ( 2291 f"""return_units keys ({", ".join(return_units.keys())}) """ 2292 f"""do not match model outputs ({", ".join(self.outputs)})""" 2293 
def _initialize_constraints(self, kwargs):
    """
    Consume constraint entries from ``kwargs`` and attach them.

    Each name in ``self.parameter_constraints`` is popped from ``kwargs``
    (default: empty dict); the popped mapping goes from parameter name to
    constraint value, and each value is set as an attribute of that name
    on the corresponding parameter object.  Each name in
    ``self.model_constraints`` is popped (default: empty list) and stored
    in ``self._mconstraints``.
    """
    # Per-parameter constraints are written onto the parameter objects.
    for kind in self.parameter_constraints:
        per_param = kwargs.pop(kind, {})
        for param_name, setting in per_param.items():
            setattr(getattr(self, param_name), kind, setting)

    # Model-level constraints are kept in a private dict on the model.
    self._mconstraints = {
        kind: kwargs.pop(kind, []) for kind in self.model_constraints
    }
2343 """ 2344 n_models = kwargs.pop('n_models', None) 2345 2346 if not (n_models is None or 2347 (isinstance(n_models, (int, np.integer)) and n_models >= 1)): 2348 raise ValueError( 2349 "n_models must be either None (in which case it is " 2350 "determined from the model_set_axis of the parameter initial " 2351 "values) or it must be a positive integer " 2352 "(got {0!r})".format(n_models)) 2353 2354 model_set_axis = kwargs.pop('model_set_axis', None) 2355 if model_set_axis is None: 2356 if n_models is not None and n_models > 1: 2357 # Default to zero 2358 model_set_axis = 0 2359 else: 2360 # Otherwise disable 2361 model_set_axis = False 2362 else: 2363 if not (model_set_axis is False or 2364 np.issubdtype(type(model_set_axis), np.integer)): 2365 raise ValueError( 2366 "model_set_axis must be either False or an integer " 2367 "specifying the parameter array axis to map to each " 2368 "model in a set of models (got {0!r}).".format( 2369 model_set_axis)) 2370 2371 # Process positional arguments by matching them up with the 2372 # corresponding parameters in self.param_names--if any also appear as 2373 # keyword arguments this presents a conflict 2374 params = set() 2375 if len(args) > len(self.param_names): 2376 raise TypeError( 2377 "{0}.__init__() takes at most {1} positional arguments ({2} " 2378 "given)".format(self.__class__.__name__, len(self.param_names), 2379 len(args))) 2380 2381 self._model_set_axis = model_set_axis 2382 self._param_metrics = defaultdict(dict) 2383 2384 for idx, arg in enumerate(args): 2385 if arg is None: 2386 # A value of None implies using the default value, if exists 2387 continue 2388 # We use quantity_asanyarray here instead of np.asanyarray because 2389 # if any of the arguments are quantities, we need to return a 2390 # Quantity object not a plain Numpy array. 
2391 param_name = self.param_names[idx] 2392 params.add(param_name) 2393 if not isinstance(arg, Parameter): 2394 value = quantity_asanyarray(arg, dtype=float) 2395 else: 2396 value = arg 2397 self._initialize_parameter_value(param_name, value) 2398 2399 # At this point the only remaining keyword arguments should be 2400 # parameter names; any others are in error. 2401 for param_name in self.param_names: 2402 if param_name in kwargs: 2403 if param_name in params: 2404 raise TypeError( 2405 "{0}.__init__() got multiple values for parameter " 2406 "{1!r}".format(self.__class__.__name__, param_name)) 2407 value = kwargs.pop(param_name) 2408 if value is None: 2409 continue 2410 # We use quantity_asanyarray here instead of np.asanyarray 2411 # because if any of the arguments are quantities, we need 2412 # to return a Quantity object not a plain Numpy array. 2413 value = quantity_asanyarray(value, dtype=float) 2414 params.add(param_name) 2415 self._initialize_parameter_value(param_name, value) 2416 # Now deal with case where param_name is not supplied by args or kwargs 2417 for param_name in self.param_names: 2418 if param_name not in params: 2419 self._initialize_parameter_value(param_name, None) 2420 2421 if kwargs: 2422 # If any keyword arguments were left over at this point they are 2423 # invalid--the base class should only be passed the parameter 2424 # values, constraints, and param_dim 2425 for kwarg in kwargs: 2426 # Just raise an error on the first unrecognized argument 2427 raise TypeError( 2428 '{0}.__init__() got an unrecognized parameter ' 2429 '{1!r}'.format(self.__class__.__name__, kwarg)) 2430 2431 # Determine the number of model sets: If the model_set_axis is 2432 # None then there is just one parameter set; otherwise it is determined 2433 # by the size of that axis on the first parameter--if the other 2434 # parameters don't have the right number of axes or the sizes of their 2435 # model_set_axis don't match an error is raised 2436 if model_set_axis is 
not False and n_models != 1 and params: 2437 max_ndim = 0 2438 if model_set_axis < 0: 2439 min_ndim = abs(model_set_axis) 2440 else: 2441 min_ndim = model_set_axis + 1 2442 2443 for name in self.param_names: 2444 value = getattr(self, name) 2445 param_ndim = np.ndim(value) 2446 if param_ndim < min_ndim: 2447 raise InputParameterError( 2448 "All parameter values must be arrays of dimension " 2449 "at least {0} for model_set_axis={1} (the value " 2450 "given for {2!r} is only {3}-dimensional)".format( 2451 min_ndim, model_set_axis, name, param_ndim)) 2452 2453 max_ndim = max(max_ndim, param_ndim) 2454 2455 if n_models is None: 2456 # Use the dimensions of the first parameter to determine 2457 # the number of model sets 2458 n_models = value.shape[model_set_axis] 2459 elif value.shape[model_set_axis] != n_models: 2460 raise InputParameterError( 2461 "Inconsistent dimensions for parameter {0!r} for " 2462 "{1} model sets. The length of axis {2} must be the " 2463 "same for all input parameter values".format( 2464 name, n_models, model_set_axis)) 2465 2466 self._check_param_broadcast(max_ndim) 2467 else: 2468 if n_models is None: 2469 n_models = 1 2470 2471 self._check_param_broadcast(None) 2472 2473 self._n_models = n_models 2474 # now validate parameters 2475 for name in params: 2476 param = getattr(self, name) 2477 if param._validator is not None: 2478 param._validator(self, param.value) 2479 2480 def _initialize_parameter_value(self, param_name, value): 2481 """Mostly deals with consistency checks and determining unit issues.""" 2482 if isinstance(value, Parameter): 2483 self.__dict__[param_name] = value 2484 return 2485 param = getattr(self, param_name) 2486 # Use default if value is not provided 2487 if value is None: 2488 default = param.default 2489 if default is None: 2490 # No value was supplied for the parameter and the 2491 # parameter does not have a default, therefore the model 2492 # is underspecified 2493 raise TypeError("{0}.__init__() requires a value 
def _initialize_slices(self):
    """
    Record where each parameter lives in the flat parameter array.

    For every name in ``self.param_names`` the entry
    ``self._param_metrics[name]`` receives the keys ``'slice'``,
    ``'shape'`` and ``'size'``; slices are laid out back to back in
    ``param_names`` order.  ``self._parameters`` is then allocated as an
    uninitialised float64 array of the total size.
    """
    metrics = self._param_metrics
    offset = 0

    for name in self.param_names:
        value = getattr(self, name).value
        size = np.size(value)
        entry = metrics[name]
        entry['slice'] = slice(offset, offset + size)
        entry['shape'] = np.shape(value)
        entry['size'] = size
        offset += size

    self._parameters = np.empty(offset, dtype=np.float64)
values will 2553 # work 2554 2555 def _array_to_parameters(self): 2556 param_metrics = self._param_metrics 2557 for name in self.param_names: 2558 param = getattr(self, name) 2559 value = self._parameters[param_metrics[name]['slice']] 2560 value.shape = param_metrics[name]['shape'] 2561 param.value = value 2562 2563 def _check_param_broadcast(self, max_ndim): 2564 """ 2565 This subroutine checks that all parameter arrays can be broadcast 2566 against each other, and determines the shapes parameters must have in 2567 order to broadcast correctly. 2568 2569 If model_set_axis is None this merely checks that the parameters 2570 broadcast and returns an empty dict if so. This mode is only used for 2571 single model sets. 2572 """ 2573 all_shapes = [] 2574 model_set_axis = self._model_set_axis 2575 2576 for name in self.param_names: 2577 param = getattr(self, name) 2578 value = param.value 2579 param_shape = np.shape(value) 2580 param_ndim = len(param_shape) 2581 if max_ndim is not None and param_ndim < max_ndim: 2582 # All arrays have the same number of dimensions up to the 2583 # model_set_axis dimension, but after that they may have a 2584 # different number of trailing axes. The number of trailing 2585 # axes must be extended for mutual compatibility. For example 2586 # if max_ndim = 3 and model_set_axis = 0, an array with the 2587 # shape (2, 2) must be extended to (2, 1, 2). However, an 2588 # array with shape (2,) is extended to (2, 1). 
def _param_sets(self, raw=False, units=False):
    """
    Implementation of the Model.param_sets property.

    Parameters
    ----------
    raw : bool, optional
        If true, a parameter that has a setter contributes its
        ``_internal_value`` instead of its ``value`` (in most cases these
        are one and the same).  With ``units`` also set, a non-None
        ``internal_unit`` takes precedence over ``unit``.
    units : bool, optional
        If true, each value whose selected unit is not None is wrapped in
        a `~astropy.units.Quantity`.

    Returns
    -------
    numpy.ndarray
        A regular array stacking all parameter values when they all share
        one shape and ``units`` is false; otherwise a 1-D object array
        with one element per parameter.

    Notes
    -----
    This is notably an overcomplicated device and may be removed entirely
    in the near future.
    """
    collected = []
    seen_shapes = []

    for name in self.param_names:
        param = getattr(self, name)

        use_internal = raw and param._setter
        value = param._internal_value if use_internal else param.value

        target_shape = self._param_metrics[name].get('broadcast_shape')
        if target_shape is not None:
            value = value.reshape(target_shape)

        # The shape is recorded before the extra single-model axis is added.
        seen_shapes.append(np.shape(value))

        if len(self) == 1:
            # Add a single param set axis (scalars become shape (1,)
            # arrays) for consistency with model sets.
            value = np.array([value])

        if units:
            unit = (param.internal_unit
                    if raw and param.internal_unit is not None
                    else param.unit)
            if unit is not None:
                value = Quantity(value, unit)

        collected.append(value)

    if not units and len(set(seen_shapes)) == 1:
        return np.array(collected)

    # Mixed shapes (or units) would make np.array recurse into the
    # individual arrays; filling a pre-allocated object array keeps one
    # element per parameter instead.
    psets = np.empty(len(collected), dtype=object)
    psets[:] = collected
    return psets
def _format_str(self, keywords=None, defaults=None):
    """
    Internal implementation of ``__str__``.

    This is separated out for ease of use by subclasses that wish to
    override the default ``__str__`` while keeping the same basic
    formatting.

    Parameters
    ----------
    keywords : iterable of (str, object) pairs, optional
        Extra ``keyword: value`` lines added after the standard header
        lines and before the parameter table.
    defaults : dict, optional
        Maps lower-cased keyword names to default values; a keyword whose
        value equals its default is left out.

    Returns
    -------
    str
        Header lines (entries whose value is None are skipped), the extra
        keyword lines, a ``Parameters:`` line and, when the model has
        parameters, an indented table of their values.
    """
    # None stands in for "nothing given" so that no mutable object is
    # shared between calls through the signature defaults.
    keywords = () if keywords is None else keywords
    defaults = {} if defaults is None else defaults

    default_keywords = [
        ('Model', self.__class__.__name__),
        ('Name', self.name),
        ('Inputs', self.inputs),
        ('Outputs', self.outputs),
        ('Model set size', len(self))
    ]

    parts = [f'{keyword}: {value}'
             for keyword, value in default_keywords
             if value is not None]

    for keyword, value in keywords:
        key = keyword.lower()
        if key in defaults and defaults[key] == value:
            continue
        parts.append(f'{keyword}: {value}')
    parts.append('Parameters:')

    if len(self) == 1:
        # Wrap each value so that a single model still gives one table row.
        columns = [[getattr(self, name).value]
                   for name in self.param_names]
    else:
        columns = [getattr(self, name).value
                   for name in self.param_names]

    if columns:
        param_table = Table(columns, names=self.param_names)
        # Set units on the columns
        for name in self.param_names:
            param_table[name].unit = getattr(self, name).unit
        parts.append(indent(str(param_table), width=4))

    return '\n'.join(parts)
2755 """ 2756 2757 linear = False 2758 # derivative with respect to parameters 2759 fit_deriv = None 2760 """ 2761 Function (similar to the model's `~Model.evaluate`) to compute the 2762 derivatives of the model with respect to its parameters, for use by fitting 2763 algorithms. In other words, this computes the Jacobian matrix with respect 2764 to the model's parameters. 2765 """ 2766 # Flag that indicates if the model derivatives with respect to parameters 2767 # are given in columns or rows 2768 col_fit_deriv = True 2769 fittable = True 2770 2771 2772class Fittable1DModel(FittableModel): 2773 """ 2774 Base class for one-dimensional fittable models. 2775 2776 This class provides an easier interface to defining new models. 2777 Examples can be found in `astropy.modeling.functional_models`. 2778 """ 2779 n_inputs = 1 2780 n_outputs = 1 2781 _separable = True 2782 2783 2784class Fittable2DModel(FittableModel): 2785 """ 2786 Base class for two-dimensional fittable models. 2787 2788 This class provides an easier interface to defining new models. 2789 Examples can be found in `astropy.modeling.functional_models`. 
def _join_operator(f, g):
    """
    Build the evaluator triple for joining ``f`` and ``g`` side by side.

    Both operands are indexable as ``(f_eval, f_n_inputs, f_n_outputs)``.
    The returned evaluator passes the first ``f[1]`` inputs to ``f[0]`` and
    the remaining inputs to ``g[0]``, then concatenates the two results
    with ``+``.  The returned input and output counts are the sums of the
    operands' counts.
    """
    def joined(inputs, params):
        # Operands are indexed at call time, exactly as a closure over the
        # un-unpacked tuples would do.
        split = f[1]
        left = f[0](inputs[:split], params)
        right = g[0](inputs[split:], params)
        return left + right

    return joined, f[1] + g[1], f[2] + g[2]
2855 ''' 2856 2857 def __init__(self, op, left, right, name=None): 2858 self.__dict__['_param_names'] = None 2859 self._n_submodels = None 2860 self.op = op 2861 self.left = left 2862 self.right = right 2863 self._bounding_box = None 2864 self._user_bounding_box = None 2865 self._leaflist = None 2866 self._tdict = None 2867 self._parameters = None 2868 self._parameters_ = None 2869 self._param_metrics = None 2870 2871 if op != 'fix_inputs' and len(left) != len(right): 2872 raise ValueError( 2873 'Both operands must have equal values for n_models') 2874 self._n_models = len(left) 2875 2876 if op != 'fix_inputs' and ((left.model_set_axis != right.model_set_axis) 2877 or left.model_set_axis): # not False and not 0 2878 raise ValueError("model_set_axis must be False or 0 and consistent for operands") 2879 self._model_set_axis = left.model_set_axis 2880 2881 if op in ['+', '-', '*', '/', '**'] or op in SPECIAL_OPERATORS: 2882 if (left.n_inputs != right.n_inputs) or \ 2883 (left.n_outputs != right.n_outputs): 2884 raise ModelDefinitionError( 2885 'Both operands must match numbers of inputs and outputs') 2886 self.n_inputs = left.n_inputs 2887 self.n_outputs = left.n_outputs 2888 self.inputs = left.inputs 2889 self.outputs = left.outputs 2890 elif op == '&': 2891 self.n_inputs = left.n_inputs + right.n_inputs 2892 self.n_outputs = left.n_outputs + right.n_outputs 2893 self.inputs = combine_labels(left.inputs, right.inputs) 2894 self.outputs = combine_labels(left.outputs, right.outputs) 2895 elif op == '|': 2896 if left.n_outputs != right.n_inputs: 2897 raise ModelDefinitionError( 2898 "Unsupported operands for |: {0} (n_inputs={1}, " 2899 "n_outputs={2}) and {3} (n_inputs={4}, n_outputs={5}); " 2900 "n_outputs for the left-hand model must match n_inputs " 2901 "for the right-hand model.".format( 2902 left.name, left.n_inputs, left.n_outputs, right.name, 2903 right.n_inputs, right.n_outputs)) 2904 2905 self.n_inputs = left.n_inputs 2906 self.n_outputs = right.n_outputs 
2907 self.inputs = left.inputs 2908 self.outputs = right.outputs 2909 elif op == 'fix_inputs': 2910 if not isinstance(left, Model): 2911 raise ValueError('First argument to "fix_inputs" must be an instance of an astropy Model.') 2912 if not isinstance(right, dict): 2913 raise ValueError('Expected a dictionary for second argument of "fix_inputs".') 2914 2915 # Dict keys must match either possible indices 2916 # for model on left side, or names for inputs. 2917 self.n_inputs = left.n_inputs - len(right) 2918 # Assign directly to the private attribute (instead of using the setter) 2919 # to avoid asserting the new number of outputs matches the old one. 2920 self._outputs = left.outputs 2921 self.n_outputs = left.n_outputs 2922 newinputs = list(left.inputs) 2923 keys = right.keys() 2924 input_ind = [] 2925 for key in keys: 2926 if np.issubdtype(type(key), np.integer): 2927 if key >= left.n_inputs or key < 0: 2928 raise ValueError( 2929 'Substitution key integer value ' 2930 'not among possible input choices.') 2931 if key in input_ind: 2932 raise ValueError("Duplicate specification of " 2933 "same input (index/name).") 2934 input_ind.append(key) 2935 elif isinstance(key, str): 2936 if key not in left.inputs: 2937 raise ValueError( 2938 'Substitution key string not among possible ' 2939 'input choices.') 2940 # Check to see it doesn't match positional 2941 # specification. 2942 ind = left.inputs.index(key) 2943 if ind in input_ind: 2944 raise ValueError("Duplicate specification of " 2945 "same input (index/name).") 2946 input_ind.append(ind) 2947 # Remove substituted inputs 2948 input_ind.sort() 2949 input_ind.reverse() 2950 for ind in input_ind: 2951 del newinputs[ind] 2952 self.inputs = tuple(newinputs) 2953 # Now check to see if the input model has bounding_box defined. 2954 # If so, remove the appropriate dimensions and set it for this 2955 # instance. 
2956 try: 2957 self.bounding_box = \ 2958 self.left.bounding_box.fix_inputs(self, right) 2959 except NotImplementedError: 2960 pass 2961 2962 else: 2963 raise ModelDefinitionError('Illegal operator: ', self.op) 2964 self.name = name 2965 self._fittable = None 2966 self.fit_deriv = None 2967 self.col_fit_deriv = None 2968 if op in ('|', '+', '-'): 2969 self.linear = left.linear and right.linear 2970 else: 2971 self.linear = False 2972 self.eqcons = [] 2973 self.ineqcons = [] 2974 self.n_left_params = len(self.left.parameters) 2975 self._map_parameters() 2976 2977 def _get_left_inputs_from_args(self, args): 2978 return args[:self.left.n_inputs] 2979 2980 def _get_right_inputs_from_args(self, args): 2981 op = self.op 2982 if op == '&': 2983 # Args expected to look like (*left inputs, *right inputs, *left params, *right params) 2984 return args[self.left.n_inputs: self.left.n_inputs + self.right.n_inputs] 2985 elif op == '|' or op == 'fix_inputs': 2986 return None 2987 else: 2988 return args[:self.left.n_inputs] 2989 2990 def _get_left_params_from_args(self, args): 2991 op = self.op 2992 if op == '&': 2993 # Args expected to look like (*left inputs, *right inputs, *left params, *right params) 2994 n_inputs = self.left.n_inputs + self.right.n_inputs 2995 return args[n_inputs: n_inputs + self.n_left_params] 2996 else: 2997 return args[self.left.n_inputs: self.left.n_inputs + self.n_left_params] 2998 2999 def _get_right_params_from_args(self, args): 3000 op = self.op 3001 if op == 'fix_inputs': 3002 return None 3003 if op == '&': 3004 # Args expected to look like (*left inputs, *right inputs, *left params, *right params) 3005 return args[self.left.n_inputs + self.right.n_inputs + self.n_left_params:] 3006 else: 3007 return args[self.left.n_inputs + self.n_left_params:] 3008 3009 def _get_kwarg_model_parameters_as_positional(self, args, kwargs): 3010 # could do it with inserts but rebuilding seems like simpilist way 3011 3012 #TODO: Check if any param names are in kwargs 
maybe as an intersection of sets? 3013 if self.op == "&": 3014 new_args = list(args[:self.left.n_inputs + self.right.n_inputs]) 3015 args_pos = self.left.n_inputs + self.right.n_inputs 3016 else: 3017 new_args = list(args[:self.left.n_inputs]) 3018 args_pos = self.left.n_inputs 3019 3020 for param_name in self.param_names: 3021 kw_value = kwargs.pop(param_name, None) 3022 if kw_value is not None: 3023 value = kw_value 3024 else: 3025 try: 3026 value = args[args_pos] 3027 except IndexError: 3028 raise IndexError("Missing parameter or input") 3029 3030 args_pos += 1 3031 new_args.append(value) 3032 3033 return new_args, kwargs 3034 3035 def _apply_operators_to_value_lists(self, leftval, rightval, **kw): 3036 op = self.op 3037 if op == '+': 3038 return binary_operation(operator.add, leftval, rightval) 3039 elif op == '-': 3040 return binary_operation(operator.sub, leftval, rightval) 3041 elif op == '*': 3042 return binary_operation(operator.mul, leftval, rightval) 3043 elif op == '/': 3044 return binary_operation(operator.truediv, leftval, rightval) 3045 elif op == '**': 3046 return binary_operation(operator.pow, leftval, rightval) 3047 elif op == '&': 3048 if not isinstance(leftval, tuple): 3049 leftval = (leftval,) 3050 if not isinstance(rightval, tuple): 3051 rightval = (rightval,) 3052 return leftval + rightval 3053 elif op in SPECIAL_OPERATORS: 3054 return binary_operation(SPECIAL_OPERATORS[op], leftval, rightval) 3055 else: 3056 raise ModelDefinitionError('Unrecognized operator {op}') 3057 3058 def evaluate(self, *args, **kw): 3059 op = self.op 3060 args, kw = self._get_kwarg_model_parameters_as_positional(args, kw) 3061 left_inputs = self._get_left_inputs_from_args(args) 3062 left_params = self._get_left_params_from_args(args) 3063 3064 if op == 'fix_inputs': 3065 pos_index = dict(zip(self.left.inputs, range(self.left.n_inputs))) 3066 fixed_inputs = { 3067 key if np.issubdtype(type(key), np.integer) else pos_index[key]: value 3068 for key, value in 
self.right.items() 3069 } 3070 left_inputs = [ 3071 fixed_inputs[ind] if ind in fixed_inputs.keys() else inp 3072 for ind, inp in enumerate(left_inputs) 3073 ] 3074 3075 leftval = self.left.evaluate(*itertools.chain(left_inputs, left_params)) 3076 3077 if op == 'fix_inputs': 3078 return leftval 3079 3080 right_inputs = self._get_right_inputs_from_args(args) 3081 right_params = self._get_right_params_from_args(args) 3082 3083 if op == "|": 3084 if isinstance(leftval, tuple): 3085 return self.right.evaluate(*itertools.chain(leftval, right_params)) 3086 else: 3087 return self.right.evaluate(leftval, *right_params) 3088 else: 3089 rightval = self.right.evaluate(*itertools.chain(right_inputs, right_params)) 3090 3091 return self._apply_operators_to_value_lists(leftval, rightval, **kw) 3092 3093 @property 3094 def n_submodels(self): 3095 if self._leaflist is None: 3096 self._make_leaflist() 3097 return len(self._leaflist) 3098 3099 @property 3100 def submodel_names(self): 3101 """ Return the names of submodels in a ``CompoundModel``.""" 3102 if self._leaflist is None: 3103 self._make_leaflist() 3104 names = [item.name for item in self._leaflist] 3105 nonecount = 0 3106 newnames = [] 3107 for item in names: 3108 if item is None: 3109 newnames.append(f'None_{nonecount}') 3110 nonecount += 1 3111 else: 3112 newnames.append(item) 3113 return tuple(newnames) 3114 3115 def both_inverses_exist(self): 3116 ''' 3117 if both members of this compound model have inverses return True 3118 ''' 3119 warnings.warn( 3120 "CompoundModel.both_inverses_exist is deprecated. " 3121 "Use has_inverse instead.", 3122 AstropyDeprecationWarning 3123 ) 3124 3125 try: 3126 linv = self.left.inverse 3127 rinv = self.right.inverse 3128 except NotImplementedError: 3129 return False 3130 3131 return True 3132 3133 def _pre_evaluate(self, *args, **kwargs): 3134 """ 3135 CompoundModel specific input setup that needs to occur prior to 3136 model evaluation. 
3137 3138 Note 3139 ---- 3140 All of the _pre_evaluate for each component model will be 3141 performed at the time that the individual model is evaluated. 3142 """ 3143 3144 # If equivalencies are provided, necessary to map parameters and pass 3145 # the leaflist as a keyword input for use by model evaluation so that 3146 # the compound model input names can be matched to the model input 3147 # names. 3148 if 'equivalencies' in kwargs: 3149 # Restructure to be useful for the individual model lookup 3150 kwargs['inputs_map'] = [(value[0], (value[1], key)) for 3151 key, value in self.inputs_map().items()] 3152 3153 # Setup actual model evaluation method 3154 def evaluate(_inputs): 3155 return self._evaluate(*_inputs, **kwargs) 3156 3157 return evaluate, args, None, kwargs 3158 3159 @property 3160 def _argnames(self): 3161 """No inputs should be used to determine input_shape when handling compound models""" 3162 return () 3163 3164 def _post_evaluate(self, inputs, outputs, broadcasted_shapes, with_bbox, **kwargs): 3165 """ 3166 CompoundModel specific post evaluation processing of outputs 3167 3168 Note 3169 ---- 3170 All of the _post_evaluate for each component model will be 3171 performed at the time that the individual model is evaluated. 
3172 """ 3173 if self.get_bounding_box(with_bbox) is not None and self.n_outputs == 1: 3174 return outputs[0] 3175 return outputs 3176 3177 def _evaluate(self, *args, **kw): 3178 op = self.op 3179 if op != 'fix_inputs': 3180 if op != '&': 3181 leftval = self.left(*args, **kw) 3182 if op != '|': 3183 rightval = self.right(*args, **kw) 3184 else: 3185 rightval = None 3186 3187 else: 3188 leftval = self.left(*(args[:self.left.n_inputs]), **kw) 3189 rightval = self.right(*(args[self.left.n_inputs:]), **kw) 3190 3191 if op != "|": 3192 return self._apply_operators_to_value_lists(leftval, rightval, **kw) 3193 3194 elif op == '|': 3195 if isinstance(leftval, tuple): 3196 return self.right(*leftval, **kw) 3197 else: 3198 return self.right(leftval, **kw) 3199 3200 else: 3201 subs = self.right 3202 newargs = list(args) 3203 subinds = [] 3204 subvals = [] 3205 for key in subs.keys(): 3206 if np.issubdtype(type(key), np.integer): 3207 subinds.append(key) 3208 elif isinstance(key, str): 3209 ind = self.left.inputs.index(key) 3210 subinds.append(ind) 3211 subvals.append(subs[key]) 3212 # Turn inputs specified in kw into positional indices. 3213 # Names for compound inputs do not propagate to sub models. 
3214 kwind = [] 3215 kwval = [] 3216 for kwkey in list(kw.keys()): 3217 if kwkey in self.inputs: 3218 ind = self.inputs.index(kwkey) 3219 if ind < len(args): 3220 raise ValueError("Keyword argument duplicates " 3221 "positional value supplied.") 3222 kwind.append(ind) 3223 kwval.append(kw[kwkey]) 3224 del kw[kwkey] 3225 # Build new argument list 3226 # Append keyword specified args first 3227 if kwind: 3228 kwargs = list(zip(kwind, kwval)) 3229 kwargs.sort() 3230 kwindsorted, kwvalsorted = list(zip(*kwargs)) 3231 newargs = newargs + list(kwvalsorted) 3232 if subinds: 3233 subargs = list(zip(subinds, subvals)) 3234 subargs.sort() 3235 # subindsorted, subvalsorted = list(zip(*subargs)) 3236 # The substitutions must be inserted in order 3237 for ind, val in subargs: 3238 newargs.insert(ind, val) 3239 return self.left(*newargs, **kw) 3240 3241 @property 3242 def param_names(self): 3243 """ An ordered list of parameter names.""" 3244 return self._param_names 3245 3246 def _make_leaflist(self): 3247 tdict = {} 3248 leaflist = [] 3249 make_subtree_dict(self, '', tdict, leaflist) 3250 self._leaflist = leaflist 3251 self._tdict = tdict 3252 3253 def __getattr__(self, name): 3254 """ 3255 If someone accesses an attribute not already defined, map the 3256 parameters, and then see if the requested attribute is one of 3257 the parameters 3258 """ 3259 # The following test is needed to avoid infinite recursion 3260 # caused by deepcopy. There may be other such cases discovered. 
3261 if name == '__setstate__': 3262 raise AttributeError 3263 if name in self._param_names: 3264 return self.__dict__[name] 3265 else: 3266 raise AttributeError(f'Attribute "{name}" not found') 3267 3268 def __getitem__(self, index): 3269 if self._leaflist is None: 3270 self._make_leaflist() 3271 leaflist = self._leaflist 3272 tdict = self._tdict 3273 if isinstance(index, slice): 3274 if index.step: 3275 raise ValueError('Steps in slices not supported ' 3276 'for compound models') 3277 if index.start is not None: 3278 if isinstance(index.start, str): 3279 start = self._str_index_to_int(index.start) 3280 else: 3281 start = index.start 3282 else: 3283 start = 0 3284 if index.stop is not None: 3285 if isinstance(index.stop, str): 3286 stop = self._str_index_to_int(index.stop) 3287 else: 3288 stop = index.stop - 1 3289 else: 3290 stop = len(leaflist) - 1 3291 if index.stop == 0: 3292 raise ValueError("Slice endpoint cannot be 0") 3293 if start < 0: 3294 start = len(leaflist) + start 3295 if stop < 0: 3296 stop = len(leaflist) + stop 3297 # now search for matching node: 3298 if stop == start: # only single value, get leaf instead in code below 3299 index = start 3300 else: 3301 for key in tdict: 3302 node, leftind, rightind = tdict[key] 3303 if leftind == start and rightind == stop: 3304 return node 3305 raise IndexError("No appropriate subtree matches slice") 3306 if isinstance(index, type(0)): 3307 return leaflist[index] 3308 elif isinstance(index, type('')): 3309 return leaflist[self._str_index_to_int(index)] 3310 else: 3311 raise TypeError('index must be integer, slice, or model name string') 3312 3313 def _str_index_to_int(self, str_index): 3314 # Search through leaflist for item with that name 3315 found = [] 3316 for nleaf, leaf in enumerate(self._leaflist): 3317 if getattr(leaf, 'name', None) == str_index: 3318 found.append(nleaf) 3319 if len(found) == 0: 3320 raise IndexError(f"No component with name '{str_index}' found") 3321 if len(found) > 1: 3322 raise 
IndexError("Multiple components found using '{}' as name\n" 3323 "at indices {}".format(str_index, found)) 3324 return found[0] 3325 3326 @property 3327 def n_inputs(self): 3328 """ The number of inputs of a model.""" 3329 return self._n_inputs 3330 3331 @n_inputs.setter 3332 def n_inputs(self, value): 3333 self._n_inputs = value 3334 3335 @property 3336 def n_outputs(self): 3337 """ The number of outputs of a model.""" 3338 return self._n_outputs 3339 3340 @n_outputs.setter 3341 def n_outputs(self, value): 3342 self._n_outputs = value 3343 3344 @property 3345 def eqcons(self): 3346 return self._eqcons 3347 3348 @eqcons.setter 3349 def eqcons(self, value): 3350 self._eqcons = value 3351 3352 @property 3353 def ineqcons(self): 3354 return self._eqcons 3355 3356 @ineqcons.setter 3357 def ineqcons(self, value): 3358 self._eqcons = value 3359 3360 def traverse_postorder(self, include_operator=False): 3361 """ Postorder traversal of the CompoundModel tree.""" 3362 res = [] 3363 if isinstance(self.left, CompoundModel): 3364 res = res + self.left.traverse_postorder(include_operator) 3365 else: 3366 res = res + [self.left] 3367 if isinstance(self.right, CompoundModel): 3368 res = res + self.right.traverse_postorder(include_operator) 3369 else: 3370 res = res + [self.right] 3371 if include_operator: 3372 res.append(self.op) 3373 else: 3374 res.append(self) 3375 return res 3376 3377 def _format_expression(self, format_leaf=None): 3378 leaf_idx = 0 3379 operands = deque() 3380 3381 if format_leaf is None: 3382 format_leaf = lambda i, l: f'[{i}]' 3383 3384 for node in self.traverse_postorder(): 3385 if not isinstance(node, CompoundModel): 3386 operands.append(format_leaf(leaf_idx, node)) 3387 leaf_idx += 1 3388 continue 3389 3390 right = operands.pop() 3391 left = operands.pop() 3392 if node.op in OPERATOR_PRECEDENCE: 3393 oper_order = OPERATOR_PRECEDENCE[node.op] 3394 3395 if isinstance(node, CompoundModel): 3396 if (isinstance(node.left, CompoundModel) and 3397 
OPERATOR_PRECEDENCE[node.left.op] < oper_order): 3398 left = f'({left})' 3399 if (isinstance(node.right, CompoundModel) and 3400 OPERATOR_PRECEDENCE[node.right.op] < oper_order): 3401 right = f'({right})' 3402 3403 operands.append(' '.join((left, node.op, right))) 3404 else: 3405 left = f'(({left}),' 3406 right = f'({right}))' 3407 operands.append(' '.join((node.op[0], left, right))) 3408 3409 return ''.join(operands) 3410 3411 def _format_components(self): 3412 if self._parameters_ is None: 3413 self._map_parameters() 3414 return '\n\n'.join('[{0}]: {1!r}'.format(idx, m) 3415 for idx, m in enumerate(self._leaflist)) 3416 3417 def __str__(self): 3418 expression = self._format_expression() 3419 components = self._format_components() 3420 keywords = [ 3421 ('Expression', expression), 3422 ('Components', '\n' + indent(components)) 3423 ] 3424 return super()._format_str(keywords=keywords) 3425 3426 def rename(self, name): 3427 self.name = name 3428 return self 3429 3430 @property 3431 def isleaf(self): 3432 return False 3433 3434 @property 3435 def inverse(self): 3436 if self.op == '|': 3437 return self.right.inverse | self.left.inverse 3438 elif self.op == '&': 3439 return self.left.inverse & self.right.inverse 3440 else: 3441 return NotImplemented 3442 3443 @property 3444 def fittable(self): 3445 """ Set the fittable attribute on a compound model.""" 3446 if self._fittable is None: 3447 if self._leaflist is None: 3448 self._map_parameters() 3449 self._fittable = all(m.fittable for m in self._leaflist) 3450 return self._fittable 3451 3452 __add__ = _model_oper('+') 3453 __sub__ = _model_oper('-') 3454 __mul__ = _model_oper('*') 3455 __truediv__ = _model_oper('/') 3456 __pow__ = _model_oper('**') 3457 __or__ = _model_oper('|') 3458 __and__ = _model_oper('&') 3459 3460 def _map_parameters(self): 3461 """ 3462 Map all the constituent model parameters to the compound object, 3463 renaming as necessary by appending a suffix number. 
3464 3465 This can be an expensive operation, particularly for a complex 3466 expression tree. 3467 3468 All the corresponding parameter attributes are created that one 3469 expects for the Model class. 3470 3471 The parameter objects that the attributes point to are the same 3472 objects as in the constiutent models. Changes made to parameter 3473 values to either are seen by both. 3474 3475 Prior to calling this, none of the associated attributes will 3476 exist. This method must be called to make the model usable by 3477 fitting engines. 3478 3479 If oldnames=True, then parameters are named as in the original 3480 implementation of compound models. 3481 """ 3482 if self._parameters is not None: 3483 # do nothing 3484 return 3485 if self._leaflist is None: 3486 self._make_leaflist() 3487 self._parameters_ = {} 3488 param_map = {} 3489 self._param_names = [] 3490 for lindex, leaf in enumerate(self._leaflist): 3491 if not isinstance(leaf, dict): 3492 for param_name in leaf.param_names: 3493 param = getattr(leaf, param_name) 3494 new_param_name = f"{param_name}_{lindex}" 3495 self.__dict__[new_param_name] = param 3496 self._parameters_[new_param_name] = param 3497 self._param_names.append(new_param_name) 3498 param_map[new_param_name] = (lindex, param_name) 3499 self._param_metrics = {} 3500 self._param_map = param_map 3501 self._param_map_inverse = dict((v, k) for k, v in param_map.items()) 3502 self._initialize_slices() 3503 self._param_names = tuple(self._param_names) 3504 3505 def _initialize_slices(self): 3506 param_metrics = self._param_metrics 3507 total_size = 0 3508 3509 for name in self.param_names: 3510 param = getattr(self, name) 3511 value = param.value 3512 param_size = np.size(value) 3513 param_shape = np.shape(value) 3514 param_slice = slice(total_size, total_size + param_size) 3515 param_metrics[name] = {} 3516 param_metrics[name]['slice'] = param_slice 3517 param_metrics[name]['shape'] = param_shape 3518 param_metrics[name]['size'] = param_size 
3519 total_size += param_size 3520 self._parameters = np.empty(total_size, dtype=np.float64) 3521 3522 @staticmethod 3523 def _recursive_lookup(branch, adict, key): 3524 if isinstance(branch, CompoundModel): 3525 return adict[key] 3526 return branch, key 3527 3528 def inputs_map(self): 3529 """ 3530 Map the names of the inputs to this ExpressionTree to the inputs to the leaf models. 3531 """ 3532 inputs_map = {} 3533 if not isinstance(self.op, str): # If we don't have an operator the mapping is trivial 3534 return {inp: (self, inp) for inp in self.inputs} 3535 3536 elif self.op == '|': 3537 if isinstance(self.left, CompoundModel): 3538 l_inputs_map = self.left.inputs_map() 3539 for inp in self.inputs: 3540 if isinstance(self.left, CompoundModel): 3541 inputs_map[inp] = l_inputs_map[inp] 3542 else: 3543 inputs_map[inp] = self.left, inp 3544 elif self.op == '&': 3545 if isinstance(self.left, CompoundModel): 3546 l_inputs_map = self.left.inputs_map() 3547 if isinstance(self.right, CompoundModel): 3548 r_inputs_map = self.right.inputs_map() 3549 for i, inp in enumerate(self.inputs): 3550 if i < len(self.left.inputs): # Get from left 3551 if isinstance(self.left, CompoundModel): 3552 inputs_map[inp] = l_inputs_map[self.left.inputs[i]] 3553 else: 3554 inputs_map[inp] = self.left, self.left.inputs[i] 3555 else: # Get from right 3556 if isinstance(self.right, CompoundModel): 3557 inputs_map[inp] = r_inputs_map[self.right.inputs[i - len(self.left.inputs)]] 3558 else: 3559 inputs_map[inp] = self.right, self.right.inputs[i - len(self.left.inputs)] 3560 elif self.op == 'fix_inputs': 3561 fixed_ind = list(self.right.keys()) 3562 ind = [list(self.left.inputs).index(i) if isinstance(i, str) else i for i in fixed_ind] 3563 inp_ind = list(range(self.left.n_inputs)) 3564 for i in ind: 3565 inp_ind.remove(i) 3566 for i in inp_ind: 3567 inputs_map[self.left.inputs[i]] = self.left, self.left.inputs[i] 3568 else: 3569 if isinstance(self.left, CompoundModel): 3570 l_inputs_map = 
self.left.inputs_map() 3571 for inp in self.left.inputs: 3572 if isinstance(self.left, CompoundModel): 3573 inputs_map[inp] = l_inputs_map[inp] 3574 else: 3575 inputs_map[inp] = self.left, inp 3576 return inputs_map 3577 3578 def _parameter_units_for_data_units(self, input_units, output_units): 3579 if self._leaflist is None: 3580 self._map_parameters() 3581 units_for_data = {} 3582 for imodel, model in enumerate(self._leaflist): 3583 units_for_data_leaf = model._parameter_units_for_data_units(input_units, output_units) 3584 for param_leaf in units_for_data_leaf: 3585 param = self._param_map_inverse[(imodel, param_leaf)] 3586 units_for_data[param] = units_for_data_leaf[param_leaf] 3587 return units_for_data 3588 3589 @property 3590 def input_units(self): 3591 inputs_map = self.inputs_map() 3592 input_units_dict = {key: inputs_map[key][0].input_units[orig_key] 3593 for key, (mod, orig_key) in inputs_map.items() 3594 if inputs_map[key][0].input_units is not None} 3595 if input_units_dict: 3596 return input_units_dict 3597 return None 3598 3599 @property 3600 def input_units_equivalencies(self): 3601 inputs_map = self.inputs_map() 3602 return {key: inputs_map[key][0].input_units_equivalencies[orig_key] 3603 for key, (mod, orig_key) in inputs_map.items() 3604 if inputs_map[key][0].input_units_equivalencies is not None} 3605 3606 @property 3607 def input_units_allow_dimensionless(self): 3608 inputs_map = self.inputs_map() 3609 return {key: inputs_map[key][0].input_units_allow_dimensionless[orig_key] 3610 for key, (mod, orig_key) in inputs_map.items()} 3611 3612 @property 3613 def input_units_strict(self): 3614 inputs_map = self.inputs_map() 3615 return {key: inputs_map[key][0].input_units_strict[orig_key] 3616 for key, (mod, orig_key) in inputs_map.items()} 3617 3618 @property 3619 def return_units(self): 3620 outputs_map = self.outputs_map() 3621 return {key: outputs_map[key][0].return_units[orig_key] 3622 for key, (mod, orig_key) in outputs_map.items() 3623 if 
outputs_map[key][0].return_units is not None} 3624 3625 def outputs_map(self): 3626 """ 3627 Map the names of the outputs to this ExpressionTree to the outputs to the leaf models. 3628 """ 3629 outputs_map = {} 3630 if not isinstance(self.op, str): # If we don't have an operator the mapping is trivial 3631 return {out: (self, out) for out in self.outputs} 3632 3633 elif self.op == '|': 3634 if isinstance(self.right, CompoundModel): 3635 r_outputs_map = self.right.outputs_map() 3636 for out in self.outputs: 3637 if isinstance(self.right, CompoundModel): 3638 outputs_map[out] = r_outputs_map[out] 3639 else: 3640 outputs_map[out] = self.right, out 3641 3642 elif self.op == '&': 3643 if isinstance(self.left, CompoundModel): 3644 l_outputs_map = self.left.outputs_map() 3645 if isinstance(self.right, CompoundModel): 3646 r_outputs_map = self.right.outputs_map() 3647 for i, out in enumerate(self.outputs): 3648 if i < len(self.left.outputs): # Get from left 3649 if isinstance(self.left, CompoundModel): 3650 outputs_map[out] = l_outputs_map[self.left.outputs[i]] 3651 else: 3652 outputs_map[out] = self.left, self.left.outputs[i] 3653 else: # Get from right 3654 if isinstance(self.right, CompoundModel): 3655 outputs_map[out] = r_outputs_map[self.right.outputs[i - len(self.left.outputs)]] 3656 else: 3657 outputs_map[out] = self.right, self.right.outputs[i - len(self.left.outputs)] 3658 elif self.op == 'fix_inputs': 3659 return self.left.outputs_map() 3660 else: 3661 if isinstance(self.left, CompoundModel): 3662 l_outputs_map = self.left.outputs_map() 3663 for out in self.left.outputs: 3664 if isinstance(self.left, CompoundModel): 3665 outputs_map[out] = l_outputs_map()[out] 3666 else: 3667 outputs_map[out] = self.left, out 3668 return outputs_map 3669 3670 @property 3671 def has_user_bounding_box(self): 3672 """ 3673 A flag indicating whether or not a custom bounding_box has been 3674 assigned to this model by a user, via assignment to 3675 ``model.bounding_box``. 
3676 """ 3677 3678 return self._user_bounding_box is not None 3679 3680 def render(self, out=None, coords=None): 3681 """ 3682 Evaluate a model at fixed positions, respecting the ``bounding_box``. 3683 3684 The key difference relative to evaluating the model directly is that 3685 this method is limited to a bounding box if the `Model.bounding_box` 3686 attribute is set. 3687 3688 Parameters 3689 ---------- 3690 out : `numpy.ndarray`, optional 3691 An array that the evaluated model will be added to. If this is not 3692 given (or given as ``None``), a new array will be created. 3693 coords : array-like, optional 3694 An array to be used to translate from the model's input coordinates 3695 to the ``out`` array. It should have the property that 3696 ``self(coords)`` yields the same shape as ``out``. If ``out`` is 3697 not specified, ``coords`` will be used to determine the shape of 3698 the returned array. If this is not provided (or None), the model 3699 will be evaluated on a grid determined by `Model.bounding_box`. 3700 3701 Returns 3702 ------- 3703 out : `numpy.ndarray` 3704 The model added to ``out`` if ``out`` is not ``None``, or else a 3705 new array from evaluating the model over ``coords``. 3706 If ``out`` and ``coords`` are both `None`, the returned array is 3707 limited to the `Model.bounding_box` limits. If 3708 `Model.bounding_box` is `None`, ``arr`` or ``coords`` must be 3709 passed. 3710 3711 Raises 3712 ------ 3713 ValueError 3714 If ``coords`` are not given and the the `Model.bounding_box` of 3715 this model is not set. 
3716 3717 Examples 3718 -------- 3719 :ref:`astropy:bounding-boxes` 3720 """ 3721 3722 bbox = self.get_bounding_box() 3723 3724 ndim = self.n_inputs 3725 3726 if (coords is None) and (out is None) and (bbox is None): 3727 raise ValueError('If no bounding_box is set, ' 3728 'coords or out must be input.') 3729 3730 # for consistent indexing 3731 if ndim == 1: 3732 if coords is not None: 3733 coords = [coords] 3734 if bbox is not None: 3735 bbox = [bbox] 3736 3737 if coords is not None: 3738 coords = np.asanyarray(coords, dtype=float) 3739 # Check dimensions match out and model 3740 assert len(coords) == ndim 3741 if out is not None: 3742 if coords[0].shape != out.shape: 3743 raise ValueError('inconsistent shape of the output.') 3744 else: 3745 out = np.zeros(coords[0].shape) 3746 3747 if out is not None: 3748 out = np.asanyarray(out) 3749 if out.ndim != ndim: 3750 raise ValueError('the array and model must have the same ' 3751 'number of dimensions.') 3752 3753 if bbox is not None: 3754 # Assures position is at center pixel, important when using 3755 # add_array. 3756 pd = np.array([(np.mean(bb), np.ceil((bb[1] - bb[0]) / 2)) 3757 for bb in bbox]).astype(int).T 3758 pos, delta = pd 3759 3760 if coords is not None: 3761 sub_shape = tuple(delta * 2 + 1) 3762 sub_coords = np.array([extract_array(c, sub_shape, pos) 3763 for c in coords]) 3764 else: 3765 limits = [slice(p - d, p + d + 1, 1) for p, d in pd.T] 3766 sub_coords = np.mgrid[limits] 3767 3768 sub_coords = sub_coords[::-1] 3769 3770 if out is None: 3771 out = self(*sub_coords) 3772 else: 3773 try: 3774 out = add_array(out, self(*sub_coords), pos) 3775 except ValueError: 3776 raise ValueError( 3777 'The `bounding_box` is larger than the input out in ' 3778 'one or more dimensions. 
Set ' 3779 '`model.bounding_box = None`.') 3780 else: 3781 if coords is None: 3782 im_shape = out.shape 3783 limits = [slice(i) for i in im_shape] 3784 coords = np.mgrid[limits] 3785 3786 coords = coords[::-1] 3787 3788 out += self(*coords) 3789 3790 return out 3791 3792 def replace_submodel(self, name, model): 3793 """ 3794 Construct a new `~astropy.modeling.CompoundModel` instance from an 3795 existing CompoundModel, replacing the named submodel with a new model. 3796 3797 In order to ensure that inverses and names are kept/reconstructed, it's 3798 necessary to rebuild the CompoundModel from the replaced node all the 3799 way back to the base. The original CompoundModel is left untouched. 3800 3801 Parameters 3802 ---------- 3803 name : str 3804 name of submodel to be replaced 3805 model : `~astropy.modeling.Model` 3806 replacement model 3807 """ 3808 submodels = [m for m in self.traverse_postorder() 3809 if getattr(m, 'name', None) == name] 3810 if submodels: 3811 if len(submodels) > 1: 3812 raise ValueError(f"More than one submodel named {name}") 3813 3814 old_model = submodels.pop() 3815 if len(old_model) != len(model): 3816 raise ValueError("New and old models must have equal values " 3817 "for n_models") 3818 3819 # Do this check first in order to raise a more helpful Exception, 3820 # although it would fail trying to construct the new CompoundModel 3821 if (old_model.n_inputs != model.n_inputs or 3822 old_model.n_outputs != model.n_outputs): 3823 raise ValueError("New model must match numbers of inputs and " 3824 "outputs of existing model") 3825 3826 tree = _get_submodel_path(self, name) 3827 while tree: 3828 branch = self.copy() 3829 for node in tree[:-1]: 3830 branch = getattr(branch, node) 3831 setattr(branch, tree[-1], model) 3832 model = CompoundModel(branch.op, branch.left, branch.right, 3833 name=branch.name) 3834 tree = tree[:-1] 3835 return model 3836 3837 else: 3838 raise ValueError(f"No submodels found named {name}") 3839 3840 3841def 
def _get_submodel_path(model, name):
    """
    Find the route down a CompoundModel's tree to the model with the
    specified name (whether it's a leaf or not).

    Parameters
    ----------
    model : object
        Root of the (sub)tree to search. Nodes are inspected through their
        ``name``, ``left`` and ``right`` attributes; objects lacking
        ``left``/``right`` are treated as leaves.
    name : str
        Name to look for.

    Returns
    -------
    path : list of str or None
        Sequence of ``'left'``/``'right'`` steps leading from ``model`` to
        the named node (an empty list when ``model`` itself matches), or
        `None` when no node below ``model`` has that name. The left
        subtree is searched before the right one.
    """
    if getattr(model, 'name', None) == name:
        return []

    # A private sentinel distinguishes "attribute absent" (a leaf) from a
    # child that happens to be None.
    missing = object()
    for side in ('left', 'right'):
        child = getattr(model, side, missing)
        if child is missing:
            continue
        subpath = _get_submodel_path(child, name)
        # "Not found" is reported explicitly instead of being inferred from
        # a TypeError raised by ``list + None``.
        if subpath is not None:
            return [side] + subpath
    return None


def binary_operation(binoperator, left, right):
    """
    Perform binary operation. Operands may be matching tuples of operands.

    Parameters
    ----------
    binoperator : callable
        Function of two arguments.
    left, right : object or tuple
        When both are tuples the operator is applied pairwise (pairing
        stops at the shorter tuple, as with `zip`) and a tuple of results
        is returned; otherwise the operator is applied to ``left`` and
        ``right`` directly.
    """
    if isinstance(left, tuple) and isinstance(right, tuple):
        return tuple(binoperator(lhs, rhs) for lhs, rhs in zip(left, right))
    return binoperator(left, right)


def get_ops(tree, opset):
    """
    Recursive function to collect operators used.

    Parameters
    ----------
    tree : object
        Node to inspect; anything that is not a `CompoundModel` is ignored.
    opset : set
        Set updated in place with the ``op`` of every `CompoundModel`
        encountered in ``tree``.
    """
    if not isinstance(tree, CompoundModel):
        return
    opset.add(tree.op)
    get_ops(tree.left, opset)
    get_ops(tree.right, opset)
def make_subtree_dict(tree, nodepath, tdict, leaflist):
    """
    Walk a tree, recording every non-leaf node under a key that spells out
    the left/right choices (``'l'``/``'r'``) needed to reach it.

    Parameters
    ----------
    tree : object
        Node to process. A node without an ``isleaf`` attribute is treated
        as a leaf; any other node must provide ``left`` and ``right``.
    nodepath : str
        Key of ``tree`` itself (the empty string for the root).
    tdict : dict
        Updated in place. Each key references a tuple that contains:

        - reference to the compound model for that node.
        - left most index contained within that subtree
          (relative to all indices for the whole tree)
        - right most index contained within that subtree
    leaflist : list
        Updated in place with the leaves in left-to-right order.
    """
    if hasattr(tree, 'isleaf'):
        first_leaf = len(leaflist)
        for suffix, side in (('l', 'left'), ('r', 'right')):
            make_subtree_dict(getattr(tree, side), nodepath + suffix,
                              tdict, leaflist)
        # Children are registered before their parent, so the leaf list is
        # complete for this subtree by the time the entry is written.
        tdict[nodepath] = (tree, first_leaf, len(leaflist) - 1)
    else:
        # a leaf: just append it to the leaflist
        leaflist.append(tree)


# Operator groups ordered by their position in this list; operators that
# share a tuple share a precedence value.
_ORDER_OF_OPERATORS = [('fix_inputs',), ('|',), ('&',), ('+', '-'), ('*', '/'), ('**',)]
OPERATOR_PRECEDENCE = {
    symbol: rank
    for rank, group in enumerate(_ORDER_OF_OPERATORS)
    for symbol in group
}
def fix_inputs(modelinstance, values, bounding_boxes=None, selector_args=None):
    """
    Create a compound model in which one or more inputs of the given model
    are pinned to fixed values (scalar or array).

    Parameters
    ----------
    modelinstance : `~astropy.modeling.Model` instance
        This is the model that one or more of the
        model input values will be fixed to some constant value.
    values : dict
        A dictionary where the key identifies which input to fix
        and its value is the value to fix it at. The key may either be the
        name of the input or a number reflecting its order in the inputs.
    bounding_boxes : optional
        When given, it is validated into a compound bounding box for
        ``modelinstance`` and the entry selected by the fixed ``values`` is
        assigned as the bounding box of the returned model. Default is
        `None`, meaning no bounding box is assigned.
    selector_args : optional
        Selector arguments used when validating ``bounding_boxes``. When
        `None`, a tuple pairing every key of ``values`` with `True` is used.
        Ignored when ``bounding_boxes`` is `None`.

    Returns
    -------
    model : `~astropy.modeling.CompoundModel`
        The ``'fix_inputs'`` compound model.

    Examples
    --------

    >>> from astropy.modeling.models import Gaussian2D
    >>> g = Gaussian2D(1, 2, 3, 4, 5)
    >>> gv = fix_inputs(g, {0: 2.5})

    Results in a 1D function equivalent to Gaussian2D(1, 2, 3, 4, 5)(x=2.5, y)
    """
    fixed_model = CompoundModel('fix_inputs', modelinstance, values)
    if bounding_boxes is None:
        return fixed_model

    if selector_args is None:
        selector_args = tuple((key, True) for key in values)
    compound_bbox = CompoundBoundingBox.validate(modelinstance, bounding_boxes,
                                                 selector_args)
    chosen = compound_bbox.selector_args.get_fixed_values(modelinstance, values)

    fixed_model.bounding_box = compound_bbox[chosen]
    return fixed_model


def bind_bounding_box(modelinstance, bounding_box, order='C'):
    """
    Validate a bounding box and assign it to a model instance.

    Parameters
    ----------
    modelinstance : `~astropy.modeling.Model` instance
        This is the model that the validated bounding box will be set on.
    bounding_box : tuple
        A bounding box tuple, see :ref:`astropy:bounding-boxes` for details
    order : str, optional
        The ordering of the bounding box tuple, can be either ``'C'`` or
        ``'F'``.
    """
    validated = ModelBoundingBox.validate(modelinstance, bounding_box,
                                          order=order)
    modelinstance.bounding_box = validated
def bind_compound_bounding_box(modelinstance, bounding_boxes, selector_args,
                               create_selector=None, order='C'):
    """
    Validate a compound bounding box and assign it to a model instance.

    Parameters
    ----------
    modelinstance : `~astropy.modeling.Model` instance
        This is the model that the validated compound bounding box will be
        set on.
    bounding_boxes : dict
        A dictionary of bounding box tuples, see :ref:`astropy:bounding-boxes`
        for details.
    selector_args : list
        List of selector argument tuples to define selection for compound
        bounding box, see :ref:`astropy:bounding-boxes` for details.
    create_selector : callable, optional
        An optional callable with interface (selector_value, model) which
        can generate a bounding box based on a selector value and model if
        there is no bounding box in the compound bounding box listed under
        that selector value. Default is ``None``, meaning new bounding
        box entries will not be automatically generated.
    order : str, optional
        The ordering of the bounding box tuple, can be either ``'C'`` or
        ``'F'``.
    """
    validated = CompoundBoundingBox.validate(modelinstance, bounding_boxes,
                                             selector_args, create_selector,
                                             order=order)
    modelinstance.bounding_box = validated
def custom_model(*args, fit_deriv=None):
    """
    Create a model from a user defined function. The inputs and parameters of
    the model will be inferred from the arguments of the function.

    This can be used either as a function or as a decorator.  See below for
    examples of both usages.

    The model is separable only if there is a single input.

    .. note::

        All model parameters have to be defined as keyword arguments with
        default values in the model function.  Use `None` as a default argument
        value if you do not want to have a default value for that parameter.

        The standard settable model properties can be configured by default
        using keyword arguments matching the name of the property; however,
        these values are not set as model "parameters". Moreover, users
        cannot use keyword arguments matching non-settable model properties,
        with the exception of ``n_outputs`` which should be set to the number of
        outputs of your function.

    Parameters
    ----------
    func : function
        Function which defines the model.  It should take N positional
        arguments where ``N`` is dimensions of the model (the number of
        independent variable in the model), and any number of keyword arguments
        (the parameters).  It must return the value of the model (typically as
        an array, but can also be a scalar for scalar inputs).  This
        corresponds to the `~astropy.modeling.Model.evaluate` method.
    fit_deriv : function, optional
        Function which defines the Jacobian derivative of the model. I.e., the
        derivative with respect to the *parameters* of the model.  It should
        have the same argument signature as ``func``, but should return a
        sequence where each element of the sequence is the derivative
        with respect to the corresponding argument. This corresponds to the
        :meth:`~astropy.modeling.FittableModel.fit_deriv` method.

    Raises
    ------
    TypeError
        If more than one positional argument is given, or the single
        positional argument is not callable.

    Examples
    --------
    Define a sinusoidal model function as a custom 1D model::

        >>> from astropy.modeling.models import custom_model
        >>> import numpy as np
        >>> def sine_model(x, amplitude=1., frequency=1.):
        ...     return amplitude * np.sin(2 * np.pi * frequency * x)
        >>> def sine_deriv(x, amplitude=1., frequency=1.):
        ...     d_amplitude = np.sin(2 * np.pi * frequency * x)
        ...     d_frequency = (2 * np.pi * x * amplitude *
        ...                    np.cos(2 * np.pi * frequency * x))
        ...     return [d_amplitude, d_frequency]
        >>> SineModel = custom_model(sine_model, fit_deriv=sine_deriv)

    Create an instance of the custom model and evaluate it::

        >>> model = SineModel()
        >>> model(0.25)
        1.0

    This model instance can now be used like a usual astropy model.

    The next example demonstrates a 2D Moffat function model, and also
    demonstrates the support for docstrings (this example could also include
    a derivative, but it has been omitted for simplicity)::

        >>> @custom_model
        ... def Moffat2D(x, y, amplitude=1.0, x_0=0.0, y_0=0.0, gamma=1.0,
        ...            alpha=1.0):
        ...     \"\"\"Two dimensional Moffat function.\"\"\"
        ...     rr_gg = ((x - x_0) ** 2 + (y - y_0) ** 2) / gamma ** 2
        ...     return amplitude * (1 + rr_gg) ** (-alpha)
        ...
        >>> print(Moffat2D.__doc__)
        Two dimensional Moffat function.
        >>> model = Moffat2D()
        >>> model(1, 1)  # doctest: +FLOAT_CMP
        0.3333333333333333
    """

    if not args:
        # Decorator-with-keywords usage: defer until the function arrives.
        return functools.partial(_custom_model_wrapper, fit_deriv=fit_deriv)

    if len(args) == 1 and callable(args[0]):
        return _custom_model_wrapper(args[0], fit_deriv=fit_deriv)

    # The message names this function explicitly; the module's ``__name__``
    # would report the module path instead of the callable being misused.
    raise TypeError(
        "custom_model takes at most one positional argument (the callable/"
        "function to be turned into a model). When used as a decorator "
        "it should be passed keyword arguments only (if any).")


def _custom_model_inputs(func):
    """
    Processes the inputs to the `custom_model`'s function into the appropriate
    categories.

    Parameters
    ----------
    func : callable

    Returns
    -------
    inputs : list
        list of evaluation inputs
    special_params : dict
        dictionary of model properties which require special treatment
    settable_params : dict
        dictionary of defaults for settable model properties
    params : dict
        dictionary of model parameters set by `custom_model`'s function

    Raises
    ------
    ValueError
        If a keyword argument of ``func`` is named after a read-only
        property of `Model` (other than ``n_outputs``).
    """
    inputs, parameters = get_inputs_and_params(func)

    special = ['n_outputs']

    # Properties declared directly on Model, split by whether they define a
    # setter; declaration order is kept because it shows up in the error text.
    model_properties = [(attr, value) for attr, value in vars(Model).items()
                        if isinstance(value, property)]
    settable = [attr for attr, prop in model_properties
                if prop.fset is not None]
    properties = [attr for attr, prop in model_properties
                  if prop.fset is None and attr not in special]

    special_params = {}
    settable_params = {}
    params = {}
    for param in parameters:
        if param.name in special:
            special_params[param.name] = param.default
        elif param.name in settable:
            settable_params[param.name] = param.default
        elif param.name in properties:
            raise ValueError(f"Parameter '{param.name}' cannot be a model property: {properties}.")
        else:
            params[param.name] = param.default

    return inputs, special_params, settable_params, params
def _custom_model_wrapper(func, fit_deriv=None):
    """
    Internal implementation `custom_model`.

    When `custom_model` is called as a function its arguments are passed to
    this function, and the result of this function is returned.

    When `custom_model` is used as a decorator a partial evaluation of this
    function is returned by `custom_model`.

    Parameters
    ----------
    func : callable
        Function defining the model; its ``__name__`` and ``__doc__`` become
        the name and docstring of the generated class.
    fit_deriv : callable, optional
        Derivative function; it must declare as many defaulted arguments as
        ``func`` declares model parameters.

    Returns
    -------
    cls : type
        A new `FittableModel` subclass. It is flagged as separable only
        when ``func`` has exactly one input.

    Raises
    ------
    ModelDefinitionError
        If ``func`` or ``fit_deriv`` is not callable, or if ``fit_deriv``
        does not declare the same number of defaulted arguments as ``func``
        has parameters.
    """

    if not callable(func):
        raise ModelDefinitionError(
            "func is not callable; it must be a function or other callable "
            "object")

    if fit_deriv is not None and not callable(fit_deriv):
        raise ModelDefinitionError(
            "fit_deriv not callable; it must be a function or other "
            "callable object")

    model_name = func.__name__

    inputs, special_params, settable_params, params = _custom_model_inputs(func)

    if fit_deriv is not None:
        # ``__defaults__`` is None (not an empty tuple) when the function
        # declares no defaulted arguments; treat that as zero so a mismatch
        # is reported as a ModelDefinitionError instead of a bare
        # "object of type 'NoneType' has no len()".
        deriv_defaults = fit_deriv.__defaults__ or ()
        if len(deriv_defaults) != len(params):
            raise ModelDefinitionError("derivative function should accept "
                                       "same number of parameters as func.")

    params = {param: Parameter(param, default=default)
              for param, default in params.items()}

    # NOTE: the depth passed to find_current_module is relative to this
    # frame, so this call must stay directly in this function's body.
    mod = find_current_module(2)
    if mod:
        modname = mod.__name__
    else:
        modname = '__main__'

    members = {
        '__module__': str(modname),
        '__doc__': func.__doc__,
        'n_inputs': len(inputs),
        'n_outputs': special_params.pop('n_outputs', 1),
        'evaluate': staticmethod(func),
        '_settable_properties': settable_params
    }

    if fit_deriv is not None:
        members['fit_deriv'] = staticmethod(fit_deriv)

    members.update(params)

    cls = type(model_name, (FittableModel,), members)
    cls._separable = len(inputs) == 1
    return cls
def render_model(model, arr=None, coords=None):
    """
    Evaluates a model on an input array. Evaluation is limited to
    a bounding box if the `Model.bounding_box` attribute is set.

    Parameters
    ----------
    model : `Model`
        Model to be evaluated.
    arr : `numpy.ndarray`, optional
        Array on which the model is evaluated. It is copied, never
        modified in place.
    coords : array-like, optional
        Coordinate arrays mapping to ``arr``, such that
        ``arr[coords] == arr``.

    Returns
    -------
    array : `numpy.ndarray`
        The model evaluated on the input ``arr`` or a new array from
        ``coords``.
        If ``arr`` and ``coords`` are both `None`, the returned array is
        limited to the `Model.bounding_box` limits. If
        `Model.bounding_box` is `None`, ``arr`` or ``coords`` must be passed.

    Raises
    ------
    ValueError
        If neither a bounding box, ``arr`` nor ``coords`` is available; if
        the dimensions of ``arr`` or ``coords`` do not match the model or
        each other; or if the bounding box does not fit inside ``arr``.

    Examples
    --------
    :ref:`astropy:bounding-boxes`
    """

    # NOTE(review): ``Model.bounding_box`` is understood to raise
    # NotImplementedError when no bounding box is defined -- confirm. Such a
    # model is handled like one whose bounding box is None, which is what
    # the docstring above promises.
    try:
        bbox = model.bounding_box
    except NotImplementedError:
        bbox = None

    if coords is None and arr is None and bbox is None:
        raise ValueError('If no bounding_box is set, '
                         'coords or arr must be input.')

    # for consistent indexing
    if model.n_inputs == 1:
        if coords is not None:
            coords = [coords]
        if bbox is not None:
            bbox = [bbox]

    if arr is not None:
        arr = arr.copy()
        # Check dimensions match model
        if arr.ndim != model.n_inputs:
            raise ValueError('number of array dimensions inconsistent with '
                             'number of model inputs.')
    if coords is not None:
        # Check dimensions match arr and model
        coords = np.array(coords)
        if len(coords) != model.n_inputs:
            raise ValueError('coordinate length inconsistent with the number '
                             'of model inputs.')
        if arr is not None:
            if coords[0].shape != arr.shape:
                raise ValueError('coordinate shape inconsistent with the '
                                 'array shape.')
        else:
            arr = np.zeros(coords[0].shape)

    if bbox is not None:
        # assures position is at center pixel, important when using add_array
        pd = np.array([(np.mean(bb), np.ceil((bb[1] - bb[0]) / 2))
                       for bb in bbox]).astype(int).T
        pos, delta = pd

        if coords is not None:
            sub_shape = tuple(delta * 2 + 1)
            sub_coords = np.array([extract_array(c, sub_shape, pos)
                                   for c in coords])
        else:
            limits = tuple(slice(p - d, p + d + 1, 1) for p, d in pd.T)
            sub_coords = np.mgrid[limits]

        # grid axes come out in array order; the model is called with them
        # reversed
        sub_coords = sub_coords[::-1]

        if arr is None:
            arr = model(*sub_coords)
        else:
            try:
                arr = add_array(arr, model(*sub_coords), pos)
            except ValueError as exc:
                raise ValueError('The `bounding_box` is larger than the input'
                                 ' arr in one or more dimensions. Set '
                                 '`model.bounding_box = None`.') from exc
    else:

        if coords is None:
            limits = tuple(slice(size) for size in arr.shape)
            coords = np.mgrid[limits]

        arr += model(*coords[::-1])

    return arr


def hide_inverse(model):
    """
    This is a convenience function intended to disable automatic generation
    of the inverse in compound models by disabling one of the constituent
    model's inverse. This is to handle cases where user provided inverse
    functions are not compatible within an expression.

    Example:
        compound_model.inverse = hide_inverse(m1) + m2 + m3

    This will ensure that the defined inverse itself won't attempt to
    build its own inverse, which would otherwise fail in this example
    (e.g., m = m1 + m2 + m3 happens to raise an exception for this
    reason.)

    Note that this permanently disables it. To prevent that either copy
    the model or restore the inverse later.

    Parameters
    ----------
    model : object
        Model whose ``inverse`` attribute is deleted.

    Returns
    -------
    model
        The same object that was passed in, to allow use inside expressions.
    """
    del model.inverse
    return model