1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3"""
4This module defines base classes for all models.  The base class of all
5models is `~astropy.modeling.Model`. `~astropy.modeling.FittableModel` is
6the base class for all fittable models. Fittable models can be linear or
7nonlinear in a regression analysis sense.
8
9All models provide a `__call__` method which performs the transformation in
10a purely mathematical way, i.e. the models are unitless.  Model instances can
11represent either a single model, or a "model set" representing multiple copies
12of the same type of model, but with potentially different values of the
13parameters in each model making up the set.
14"""
15# pylint: disable=invalid-name, protected-access, redefined-outer-name
16import abc
17import copy
18import inspect
19import itertools
20import functools
21import operator
22import types
23
24from collections import defaultdict, deque
25from inspect import signature
26from itertools import chain
27
28import numpy as np
29
30from astropy.utils import indent, metadata
31from astropy.table import Table
32from astropy.units import Quantity, UnitsError, dimensionless_unscaled
33from astropy.units.utils import quantity_asanyarray
34from astropy.utils import (sharedmethod, find_current_module,
35                           check_broadcast, IncompatibleShapeError, isiterable)
36from astropy.utils.codegen import make_function_with_signature
37from astropy.nddata.utils import add_array, extract_array
38from .utils import (combine_labels, make_binary_operator_eval,
39                    get_inputs_and_params, _combine_equivalency_dict,
40                    _ConstraintsDict, _SpecialOperatorsDict)
41from .bounding_box import ModelBoundingBox, CompoundBoundingBox
42from .parameters import (Parameter, InputParameterError,
43                         param_repr_oneline, _tofloat)
44
45
# Names exported by ``from astropy.modeling.core import *``.
__all__ = [
    'Model',
    'FittableModel',
    'Fittable1DModel',
    'Fittable2DModel',
    'CompoundModel',
    'fix_inputs',
    'custom_model',
    'ModelDefinitionError',
    'bind_bounding_box',
    'bind_compound_bounding_box',
]
49
50
def _model_oper(oper, **kwargs):
    """
    Build a function that combines two models with a given operator.

    Parameters
    ----------
    oper : str
        The operator, given as a string such as ``'+'`` or ``'**'``.
    **kwargs
        Extra keyword arguments forwarded to `CompoundModel`.

    Returns
    -------
    callable
        A function ``f(left, right)`` that returns
        ``CompoundModel(oper, left, right, **kwargs)``.
    """
    def _combine(left, right):
        return CompoundModel(oper, left, right, **kwargs)

    return _combine
58
59
class ModelDefinitionError(TypeError):
    """
    Raised when a model class is defined incorrectly.

    Subclasses `TypeError`, so existing ``except TypeError`` handlers also
    catch it.
    """
62
63
class _ModelMeta(abc.ABCMeta):
    """
    Metaclass for Model.

    When a Model subclass is created this metaclass:

    * collects the `Parameter` attributes declared in the class body into
      ``_parameters_`` and builds the ``param_names`` tuple, including the
      parameters of Model base classes;
    * installs the arithmetic operators used to build compound models;
    * moves any ``inverse`` / ``bounding_box`` defined in the class body into
      the private ``_inverse`` / ``_bounding_box`` attributes used by the
      generic getter/setter interface on `Model`;
    * generates ``__call__`` and ``__init__`` methods with explicit
      signatures when the class does not define them itself.
    """

    _is_dynamic = False
    """
    This flag signifies whether this class was created in the "normal" way,
    with a class statement in the body of a module, as opposed to a call to
    `type` or some other metaclass constructor, such that the resulting class
    does not belong to a specific module.  This is important for pickling of
    dynamic classes.

    This flag is always forced to False for new classes, so code that creates
    dynamic classes should manually set it to True on those classes when
    creating them.
    """

    def __new__(mcls, name, bases, members):
        # See the docstring for _is_dynamic above
        if '_is_dynamic' not in members:
            members['_is_dynamic'] = mcls._is_dynamic
        opermethods = [
            ('__add__', _model_oper('+')),
            ('__sub__', _model_oper('-')),
            ('__mul__', _model_oper('*')),
            ('__truediv__', _model_oper('/')),
            ('__pow__', _model_oper('**')),
            ('__or__', _model_oper('|')),
            ('__and__', _model_oper('&')),
            ('_fix_inputs', _model_oper('fix_inputs'))
        ]

        # Parameters declared directly in this class body; an empty dict for
        # model classes that do not declare any Parameters.
        members['_parameters_'] = {k: v for k, v in members.items()
                                   if isinstance(v, Parameter)}

        for opermethod, opercall in opermethods:
            members[opermethod] = opercall
        cls = super().__new__(mcls, name, bases, members)

        param_names = list(members['_parameters_'])

        # Need to walk each base MRO to collect all parameter names
        for base in bases:
            for tbase in base.__mro__:
                if issubclass(tbase, Model):
                    # Preserve order of definitions: ancestors' parameters are
                    # prepended ahead of the ones collected so far.
                    param_names = list(tbase._parameters_) + param_names
        # Remove duplicates (arising from redefinition in subclass), keeping
        # the first occurrence of each name.
        param_names = list(dict.fromkeys(param_names))
        # param_names is only (re)assigned when this class declares its own
        # parameters; otherwise the inherited value is used.
        if cls._parameters_:
            if hasattr(cls, '_param_names'):
                # Slight kludge to support compound models, where
                # cls.param_names is a property; could be improved with a
                # little refactoring but fine for now
                cls._param_names = tuple(param_names)
            else:
                cls.param_names = tuple(param_names)

        return cls

    def __init__(cls, name, bases, members):
        super().__init__(name, bases, members)
        cls._create_inverse_property(members)
        cls._create_bounding_box_property(members)

        # Parameters used to build the generated ``__init__`` signature.
        # NOTE(review): this only ever holds the parameters declared directly
        # on ``cls`` (never inherited ones), and only when some base derives
        # from Model -- exactly what the previous nested loop produced, since
        # it read ``cls._parameters_`` rather than ``tbase._parameters_`` on
        # every iteration.  Confirm intent before changing, as it shapes the
        # generated signatures.
        derives_from_model = any(issubclass(tbase, Model)
                                 for base in bases for tbase in base.__mro__)
        pdict = dict(cls._parameters_) if derives_from_model else {}
        cls._handle_special_methods(members, pdict)

    def __repr__(cls):
        """
        Custom repr for Model subclasses.
        """

        return cls._format_cls_repr()

    def _repr_pretty_(cls, p, cycle):
        """
        Repr for IPython's pretty printer.

        By default IPython "pretty prints" classes, so we need to implement
        this so that IPython displays the custom repr for Models.
        """

        p.text(repr(cls))

    def __reduce__(cls):
        if not cls._is_dynamic:
            # Just return a string specifying where the class can be imported
            # from
            return cls.__name__
        members = dict(cls.__dict__)
        # Delete any ABC-related attributes--these will be restored when
        # the class is reconstructed:
        for key in list(members):
            if key.startswith('_abc_'):
                del members[key]

        # Delete custom __init__ and __call__ if they exist:
        for key in ('__init__', '__call__'):
            if key in members:
                del members[key]

        return (type(cls), (cls.__name__, cls.__bases__, members))

    @property
    def name(cls):
        """
        The name of this model class--equivalent to ``cls.__name__``.

        This attribute is provided for symmetry with the `Model.name` attribute
        of model instances.
        """

        return cls.__name__

    @property
    def _is_concrete(cls):
        """
        A class-level property that determines whether the class is a concrete
        implementation of a Model--i.e. it is not some abstract base class or
        internal implementation detail (i.e. begins with '_').
        """
        return not (cls.__name__.startswith('_') or inspect.isabstract(cls))

    def rename(cls, name=None, inputs=None, outputs=None):
        """
        Creates a copy of this model class with a new name, inputs or outputs.

        The new class is technically a subclass of the original class, so that
        instance and type checks will still work.  For example::

            >>> from astropy.modeling.models import Rotation2D
            >>> SkyRotation = Rotation2D.rename('SkyRotation')
            >>> SkyRotation
            <class 'astropy.modeling.core.SkyRotation'>
            Name: SkyRotation (Rotation2D)
            N_inputs: 2
            N_outputs: 2
            Fittable parameters: ('angle',)
            >>> issubclass(SkyRotation, Rotation2D)
            True
            >>> r = SkyRotation(90)
            >>> isinstance(r, Rotation2D)
            True
        """

        mod = find_current_module(2)
        if mod:
            modname = mod.__name__
        else:
            modname = '__main__'

        if name is None:
            name = cls.name

        if inputs is None:
            inputs = cls.inputs
        elif not isinstance(inputs, tuple):
            raise TypeError("Expected 'inputs' to be a tuple of strings.")
        elif len(inputs) != len(cls.inputs):
            raise ValueError(f'{cls.name} expects {len(cls.inputs)} inputs')

        if outputs is None:
            outputs = cls.outputs
        elif not isinstance(outputs, tuple):
            raise TypeError("Expected 'outputs' to be a tuple of strings.")
        elif len(outputs) != len(cls.outputs):
            raise ValueError(f'{cls.name} expects {len(cls.outputs)} outputs')

        new_cls = type(name, (cls,), {"inputs": inputs, "outputs": outputs})
        new_cls.__module__ = modname
        new_cls.__qualname__ = name

        return new_cls

    def _create_inverse_property(cls, members):
        inverse = members.get('inverse')
        if inverse is None or cls.__bases__[0] is object:
            # The latter clause is the prevent the below code from running on
            # the Model base class, which implements the default getter and
            # setter for .inverse
            return

        if isinstance(inverse, property):
            # We allow the @property decorator to be omitted entirely from
            # the class definition, though its use should be encouraged for
            # clarity
            inverse = inverse.fget

        # Store the inverse getter internally, then delete the given .inverse
        # attribute so that cls.inverse resolves to Model.inverse instead
        cls._inverse = inverse
        del cls.inverse

    def _create_bounding_box_property(cls, members):
        """
        Takes any bounding_box defined on a concrete Model subclass (either
        as a fixed tuple or a property or method) and wraps it in the generic
        getter/setter interface for the bounding_box attribute.

        Raises
        ------
        ModelDefinitionError
            If a hard-coded (non-callable) bounding box fails validation.
        """

        # TODO: Much of this is verbatim from _create_inverse_property--I feel
        # like there could be a way to generify properties that work this way,
        # but for the time being that would probably only confuse things more.
        bounding_box = members.get('bounding_box')
        if bounding_box is None or cls.__bases__[0] is object:
            return

        if isinstance(bounding_box, property):
            bounding_box = bounding_box.fget

        if not callable(bounding_box):
            # See if it's a hard-coded bounding_box (as a sequence) and
            # normalize it
            try:
                bounding_box = ModelBoundingBox.validate(cls, bounding_box)
            except ValueError as exc:
                # Chain the original error so the validation detail is kept.
                raise ModelDefinitionError(exc.args[0]) from exc
        else:
            sig = signature(bounding_box)
            # May be a method that only takes 'self' as an argument (like a
            # property, but the @property decorator was forgotten)
            #
            # However, if the method takes additional arguments then this is a
            # parameterized bounding box and should be callable
            if len(sig.parameters) > 1:
                bounding_box = \
                        cls._create_bounding_box_subclass(bounding_box, sig)

        # See the Model.bounding_box getter definition for how this attribute
        # is used
        cls._bounding_box = bounding_box
        del cls.bounding_box

    def _create_bounding_box_subclass(cls, func, sig):
        """
        For Models that take optional arguments for defining their bounding
        box, we create a subclass of ModelBoundingBox with a ``__call__`` method
        that supports those additional arguments.

        Takes the function's Signature as an argument since that is already
        computed in _create_bounding_box_property, so no need to duplicate that
        effort.

        Raises
        ------
        ModelDefinitionError
            If any argument after the first (presumed ``self``) has no
            default value.
        """

        # TODO: Might be convenient if calling the bounding box also
        # automatically sets the _user_bounding_box.  So that
        #
        #    >>> model.bounding_box(arg=1)
        #
        # in addition to returning the computed bbox, also sets it, so that
        # it's a shortcut for
        #
        #    >>> model.bounding_box = model.bounding_box(arg=1)
        #
        # Not sure if that would be non-obvious / confusing though...

        def __call__(self, **kwargs):
            return func(self._model, **kwargs)

        # The first parameter is presumed to be a 'self' argument; every other
        # one needs a default so a default bounding box can be computed.
        for param in list(sig.parameters.values())[1:]:
            if param.default is param.empty:
                raise ModelDefinitionError(
                    'The bounding_box method for {0} is not correctly '
                    'defined: If defined as a method all arguments to that '
                    'method (besides self) must be keyword arguments with '
                    'default values that can be used to compute a default '
                    'bounding box.'.format(cls.name))

        __call__.__signature__ = sig

        return type(f'{cls.name}ModelBoundingBox', (ModelBoundingBox,),
                    {'__call__': __call__})

    def _handle_special_methods(cls, members, pdict):
        """
        Generate ``__call__`` and ``__init__`` with explicit signatures when
        the class body does not define them.

        ``__call__`` is generated when the class body sets a positive integer
        ``n_inputs``; ``__init__`` is generated for non-abstract classes that
        declare parameters, using ``pdict`` to build its signature.
        """

        def update_wrapper(wrapper, cls):
            # Set up the new method's metadata attributes as though it were
            # manually defined in the class definition
            # A bit like functools.update_wrapper but uses the class instead of
            # the wrapped function
            wrapper.__module__ = cls.__module__
            wrapper.__doc__ = getattr(cls, wrapper.__name__).__doc__
            if hasattr(cls, '__qualname__'):
                wrapper.__qualname__ = f'{cls.__qualname__}.{wrapper.__name__}'

        if ('__call__' not in members and 'n_inputs' in members and
                isinstance(members['n_inputs'], int) and members['n_inputs'] > 0):

            # Don't create a custom __call__ for classes that already have one
            # explicitly defined (this includes the Model base class, and any
            # other classes that manually override __call__

            def __call__(self, *inputs, **kwargs):
                """Evaluate this model on the supplied inputs."""
                return super(cls, self).__call__(*inputs, **kwargs)

            # When called, models can take two optional keyword arguments:
            #
            # * model_set_axis, which indicates (for multi-dimensional input)
            #   which axis is used to indicate different models
            #
            # * equivalencies, a dictionary of equivalencies to be applied to
            #   the input values, where each key should correspond to one of
            #   the inputs.
            #
            # The following code creates the __call__ function with these
            # keyword arguments (plus with_bounding_box, fill_value and
            # inputs_map).

            args = ('self',)
            kwargs = {
                'model_set_axis': None,
                'with_bounding_box': False,
                'fill_value': np.nan,
                'equivalencies': None,
                'inputs_map': None,
            }

            new_call = make_function_with_signature(
                __call__, args, kwargs, varargs='inputs', varkwargs='new_inputs')

            # The following makes it look like __call__
            # was defined in the class
            update_wrapper(new_call, cls)

            cls.__call__ = new_call

        if ('__init__' not in members and not inspect.isabstract(cls) and
                cls._parameters_):
            # If *all* the parameters have default values we can make them
            # keyword arguments; otherwise they must all be positional
            # arguments
            if all(p.default is not None for p in pdict.values()):
                args = ('self',)
                kwargs = []
                for param_name, param_val in pdict.items():
                    default = param_val.default
                    unit = param_val.unit
                    # If the unit was specified in the parameter but the
                    # default is not a Quantity, attach the unit to the
                    # default.
                    if unit is not None:
                        default = Quantity(default, unit, copy=False)
                    kwargs.append((param_name, default))
            else:
                args = ('self',) + tuple(pdict.keys())
                kwargs = {}

            def __init__(self, *params, **kwargs):
                return super(cls, self).__init__(*params, **kwargs)

            new_init = make_function_with_signature(
                __init__, args, kwargs, varkwargs='kwargs')
            update_wrapper(new_init, cls)
            cls.__init__ = new_init

    # *** Arithmetic operators for creating compound models ***
    __add__ = _model_oper('+')
    __sub__ = _model_oper('-')
    __mul__ = _model_oper('*')
    __truediv__ = _model_oper('/')
    __pow__ = _model_oper('**')
    __or__ = _model_oper('|')
    __and__ = _model_oper('&')
    _fix_inputs = _model_oper('fix_inputs')

    # *** Other utilities ***

    def _format_cls_repr(cls, keywords=None):
        """
        Internal implementation of ``__repr__``.

        This is separated out for ease of use by subclasses that wish to
        override the default ``__repr__`` while keeping the same basic
        formatting.

        Parameters
        ----------
        keywords : sequence of (str, object) pairs, optional
            Extra ``(keyword, value)`` lines appended after the default ones.
            Pairs whose value is `None` are skipped.
        """

        # For the sake of familiarity start the output with the standard class
        # __repr__
        parts = [super().__repr__()]

        if not cls._is_concrete:
            return parts[0]

        def format_inheritance(cls):
            bases = []
            for base in cls.mro()[1:]:
                if not issubclass(base, Model):
                    continue
                elif (inspect.isabstract(base) or
                      base.__name__.startswith('_')):
                    break
                bases.append(base.name)
            if bases:
                return f"{cls.name} ({' -> '.join(bases)})"
            return cls.name

        try:
            # Copy into a fresh list: avoids a shared mutable default and
            # accepts any sequence (e.g. a tuple) from subclasses.
            extra_keywords = [] if keywords is None else list(keywords)

            default_keywords = [
                ('Name', format_inheritance(cls)),
                ('N_inputs', cls.n_inputs),
                ('N_outputs', cls.n_outputs),
            ]

            if cls.param_names:
                default_keywords.append(('Fittable parameters',
                                         cls.param_names))

            for keyword, value in default_keywords + extra_keywords:
                if value is not None:
                    parts.append(f'{keyword}: {value}')

            return '\n'.join(parts)
        except Exception:
            # If any of the above formatting fails fall back on the basic repr
            # (this is particularly useful in debugging)
            return parts[0]
499
500
501class Model(metaclass=_ModelMeta):
502    """
503    Base class for all models.
504
505    This is an abstract class and should not be instantiated directly.
506
507    The following initialization arguments apply to the majority of Model
508    subclasses by default (exceptions include specialized utility models
509    like `~astropy.modeling.mappings.Mapping`).  Parametric models take all
510    their parameters as arguments, followed by any of the following optional
511    keyword arguments:
512
513    Parameters
514    ----------
515    name : str, optional
516        A human-friendly name associated with this model instance
517        (particularly useful for identifying the individual components of a
518        compound model).
519
520    meta : dict, optional
521        An optional dict of user-defined metadata to attach to this model.
522        How this is used and interpreted is up to the user or individual use
523        case.
524
525    n_models : int, optional
526        If given an integer greater than 1, a *model set* is instantiated
527        instead of a single model.  This affects how the parameter arguments
528        are interpreted.  In this case each parameter must be given as a list
529        or array--elements of this array are taken along the first axis (or
530        ``model_set_axis`` if specified), such that the Nth element is the
531        value of that parameter for the Nth model in the set.
532
533        See the section on model sets in the documentation for more details.
534
535    model_set_axis : int, optional
536        This argument only applies when creating a model set (i.e. ``n_models >
537        1``).  It changes how parameter values are interpreted.  Normally the
538        first axis of each input parameter array (properly the 0th axis) is
539        taken as the axis corresponding to the model sets.  However, any axis
540        of an input array may be taken as this "model set axis".  This accepts
541        negative integers as well--for example use ``model_set_axis=-1`` if the
542        last (most rapidly changing) axis should be associated with the model
543        sets. Also, ``model_set_axis=False`` can be used to tell that a given
544        input should be used to evaluate all the models in the model set.
545
546    fixed : dict, optional
547        Dictionary ``{parameter_name: bool}`` setting the fixed constraint
548        for one or more parameters.  `True` means the parameter is held fixed
549        during fitting and is prevented from updates once an instance of the
550        model has been created.
551
552        Alternatively the `~astropy.modeling.Parameter.fixed` property of a
553        parameter may be used to lock or unlock individual parameters.
554
555    tied : dict, optional
556        Dictionary ``{parameter_name: callable}`` of parameters which are
557        linked to some other parameter. The dictionary values are callables
558        providing the linking relationship.
559
560        Alternatively the `~astropy.modeling.Parameter.tied` property of a
561        parameter may be used to set the ``tied`` constraint on individual
562        parameters.
563
564    bounds : dict, optional
565        A dictionary ``{parameter_name: value}`` of lower and upper bounds of
566        parameters. Keys are parameter names. Values are a list or a tuple
567        of length 2 giving the desired range for the parameter.
568
569        Alternatively the `~astropy.modeling.Parameter.min` and
570        `~astropy.modeling.Parameter.max` or
        `~astropy.modeling.Parameter.bounds` properties of a parameter may be
572        used to set bounds on individual parameters.
573
574    eqcons : list, optional
575        List of functions of length n such that ``eqcons[j](x0, *args) == 0.0``
576        in a successfully optimized problem.
577
    ineqcons : list, optional
        List of functions of length n such that ``ineqcons[j](x0, *args) >=
        0.0`` in a successfully optimized problem.
581
582    Examples
583    --------
584    >>> from astropy.modeling import models
585    >>> def tie_center(model):
586    ...         mean = 50 * model.stddev
587    ...         return mean
588    >>> tied_parameters = {'mean': tie_center}
589
590    Specify that ``'mean'`` is a tied parameter in one of two ways:
591
592    >>> g1 = models.Gaussian1D(amplitude=10, mean=5, stddev=.3,
593    ...                        tied=tied_parameters)
594
595    or
596
597    >>> g1 = models.Gaussian1D(amplitude=10, mean=5, stddev=.3)
598    >>> g1.mean.tied
599    False
600    >>> g1.mean.tied = tie_center
601    >>> g1.mean.tied
602    <function tie_center at 0x...>
603
604    Fixed parameters:
605
606    >>> g1 = models.Gaussian1D(amplitude=10, mean=5, stddev=.3,
607    ...                        fixed={'stddev': True})
608    >>> g1.stddev.fixed
609    True
610
611    or
612
613    >>> g1 = models.Gaussian1D(amplitude=10, mean=5, stddev=.3)
614    >>> g1.stddev.fixed
615    False
616    >>> g1.stddev.fixed = True
617    >>> g1.stddev.fixed
618    True
619    """
620
621    parameter_constraints = Parameter.constraints
622    """
623    Primarily for informational purposes, these are the types of constraints
624    that can be set on a model's parameters.
625    """
626
627    model_constraints = ('eqcons', 'ineqcons')
628    """
629    Primarily for informational purposes, these are the types of constraints
630    that constrain model evaluation.
631    """
632
633    param_names = ()
634    """
635    Names of the parameters that describe models of this type.
636
637    The parameters in this tuple are in the same order they should be passed in
638    when initializing a model of a specific type.  Some types of models, such
639    as polynomial models, have a different number of parameters depending on
640    some other property of the model, such as the degree.
641
642    When defining a custom model class the value of this attribute is
643    automatically set by the `~astropy.modeling.Parameter` attributes defined
644    in the class body.
645    """
646
647    n_inputs = 0
648    """The number of inputs."""
649    n_outputs = 0
650    """ The number of outputs."""
651
652    standard_broadcasting = True
653    fittable = False
654    linear = True
655    _separable = None
656    """ A boolean flag to indicate whether a model is separable."""
657    meta = metadata.MetaData()
658    """A dict-like object to store optional information."""
659
660    # By default models either use their own inverse property or have no
661    # inverse at all, but users may also assign a custom inverse to a model,
662    # optionally; in that case it is of course up to the user to determine
663    # whether their inverse is *actually* an inverse to the model they assign
664    # it to.
665    _inverse = None
666    _user_inverse = None
667
668    _bounding_box = None
669    _user_bounding_box = None
670
671    _has_inverse_bounding_box = False
672
673    # Default n_models attribute, so that __len__ is still defined even when a
674    # model hasn't completed initialization yet
675    _n_models = 1
676
677    # New classes can set this as a boolean value.
678    # It is converted to a dictionary mapping input name to a boolean value.
679    _input_units_strict = False
680
681    # Allow dimensionless input (and corresponding output). If this is True,
682    # input values to evaluate will gain the units specified in input_units. If
683    # this is a dictionary then it should map input name to a bool to allow
684    # dimensionless numbers for that input.
685    # Only has an effect if input_units is defined.
686    _input_units_allow_dimensionless = False
687
688    # Default equivalencies to apply to input values. If set, this should be a
689    # dictionary where each key is a string that corresponds to one of the
690    # model inputs. Only has an effect if input_units is defined.
691    input_units_equivalencies = None
692
693    # Covariance matrix can be set by fitter if available.
    # If cov_matrix is available, then std will be set as well
695    _cov_matrix = None
696    _stds = None
697
698    def __init__(self, *args, meta=None, name=None, **kwargs):
699        super().__init__()
700        self._default_inputs_outputs()
701        if meta is not None:
702            self.meta = meta
703        self._name = name
704        # add parameters to instance level by walking MRO list
705        mro = self.__class__.__mro__
706        for cls in mro:
707            if issubclass(cls, Model):
708                for parname, val in cls._parameters_.items():
709                    newpar = copy.deepcopy(val)
710                    newpar.model = self
711                    if parname not in self.__dict__:
712                        self.__dict__[parname] = newpar
713
714        self._initialize_constraints(kwargs)
715        kwargs = self._initialize_setters(kwargs)
716        # Remaining keyword args are either parameter values or invalid
717        # Parameter values must be passed in as keyword arguments in order to
718        # distinguish them
719        self._initialize_parameters(args, kwargs)
720        self._initialize_slices()
721        self._initialize_unit_support()
722
723    def _default_inputs_outputs(self):
724        if self.n_inputs == 1 and self.n_outputs == 1:
725            self._inputs = ("x",)
726            self._outputs = ("y",)
727        elif self.n_inputs == 2 and self.n_outputs == 1:
728            self._inputs = ("x", "y")
729            self._outputs = ("z",)
730        else:
731            try:
732                self._inputs = tuple("x" + str(idx) for idx in range(self.n_inputs))
733                self._outputs = tuple("x" + str(idx) for idx in range(self.n_outputs))
734            except TypeError:
735                # self.n_inputs and self.n_outputs are properties
736                # This is the case when subclasses of Model do not define
737                # ``n_inputs``, ``n_outputs``, ``inputs`` or ``outputs``.
738                self._inputs = ()
739                self._outputs = ()
740
741    def _initialize_setters(self, kwargs):
742        """
743        This exists to inject defaults for settable properties for models
744        originating from `custom_model`.
745        """
746        if hasattr(self, '_settable_properties'):
747            setters = {name: kwargs.pop(name, default)
748                       for name, default in self._settable_properties.items()}
749            for name, value in setters.items():
750                setattr(self, name, value)
751
752        return kwargs
753
754    @property
755    def inputs(self):
756        return self._inputs
757
758    @inputs.setter
759    def inputs(self, val):
760        if len(val) != self.n_inputs:
761            raise ValueError(f"Expected {self.n_inputs} number of inputs, got {len(val)}.")
762        self._inputs = val
763        self._initialize_unit_support()
764
765    @property
766    def outputs(self):
767        return self._outputs
768
769    @outputs.setter
770    def outputs(self, val):
771        if len(val) != self.n_outputs:
772            raise ValueError(f"Expected {self.n_outputs} number of outputs, got {len(val)}.")
773        self._outputs = val
774
775    @property
776    def n_inputs(self):
777        # TODO: remove the code in the ``if`` block when support
778        # for models with ``inputs`` as class variables is removed.
779        if hasattr(self.__class__, 'n_inputs') and isinstance(self.__class__.n_inputs, property):
780            try:
781                return len(self.__class__.inputs)
782            except TypeError:
783                try:
784                    return len(self.inputs)
785                except AttributeError:
786                    return 0
787
788        return self.__class__.n_inputs
789
790    @property
791    def n_outputs(self):
792        # TODO: remove the code in the ``if`` block when support
793        # for models with ``outputs`` as class variables is removed.
794        if hasattr(self.__class__, 'n_outputs') and isinstance(self.__class__.n_outputs, property):
795            try:
796                return len(self.__class__.outputs)
797            except TypeError:
798                try:
799                    return len(self.outputs)
800                except AttributeError:
801                    return 0
802
803        return self.__class__.n_outputs
804
805    def _initialize_unit_support(self):
806        """
807        Convert self._input_units_strict and
808        self.input_units_allow_dimensionless to dictionaries
809        mapping input name to a boolean value.
810        """
811        if isinstance(self._input_units_strict, bool):
812            self._input_units_strict = {key: self._input_units_strict for
813                                        key in self.inputs}
814
815        if isinstance(self._input_units_allow_dimensionless, bool):
816            self._input_units_allow_dimensionless = {key: self._input_units_allow_dimensionless
817                                                     for key in self.inputs}
818
819    @property
820    def input_units_strict(self):
821        """
822        Enforce strict units on inputs to evaluate. If this is set to True,
823        input values to evaluate will be in the exact units specified by
824        input_units. If the input quantities are convertible to input_units,
825        they are converted. If this is a dictionary then it should map input
826        name to a bool to set strict input units for that parameter.
827        """
828        val = self._input_units_strict
829        if isinstance(val, bool):
830            return {key: val for key in self.inputs}
831        return dict(zip(self.inputs, val.values()))
832
833    @property
834    def input_units_allow_dimensionless(self):
835        """
836        Allow dimensionless input (and corresponding output). If this is True,
837        input values to evaluate will gain the units specified in input_units. If
838        this is a dictionary then it should map input name to a bool to allow
839        dimensionless numbers for that input.
840        Only has an effect if input_units is defined.
841        """
842
843        val = self._input_units_allow_dimensionless
844        if isinstance(val, bool):
845            return {key: val for key in self.inputs}
846        return dict(zip(self.inputs, val.values()))
847
848    @property
849    def uses_quantity(self):
850        """
851        True if this model has been created with `~astropy.units.Quantity`
852        objects or if there are no parameters.
853
854        This can be used to determine if this model should be evaluated with
855        `~astropy.units.Quantity` or regular floats.
856        """
857        pisq = [isinstance(p, Quantity) for p in self._param_sets(units=True)]
858        return (len(pisq) == 0) or any(pisq)
859
860    def __repr__(self):
861        return self._format_repr()
862
863    def __str__(self):
864        return self._format_str()
865
866    def __len__(self):
867        return self._n_models
868
869    @staticmethod
870    def _strip_ones(intup):
871        return tuple(item for item in intup if item != 1)
872
873    def __setattr__(self, attr, value):
874        if isinstance(self, CompoundModel):
875            param_names = self._param_names
876        param_names = self.param_names
877
878        if param_names is not None and attr in self.param_names:
879            param = self.__dict__[attr]
880            value = _tofloat(value)
881            if param._validator is not None:
882                param._validator(self, value)
883            # check consistency with previous shape and size
884            eshape = self._param_metrics[attr]['shape']
885            if eshape == ():
886                eshape = (1,)
887            vshape = np.array(value).shape
888            if vshape == ():
889                vshape = (1,)
890            esize = self._param_metrics[attr]['size']
891            if (np.size(value) != esize or
892                    self._strip_ones(vshape) != self._strip_ones(eshape)):
893                raise InputParameterError(
894                    "Value for parameter {0} does not match shape or size\n"
895                    "expected by model ({1}, {2}) vs ({3}, {4})".format(
896                        attr, vshape, np.size(value), eshape, esize))
897            if param.unit is None:
898                if isinstance(value, Quantity):
899                    param._unit = value.unit
900                    param.value = value.value
901                else:
902                    param.value = value
903            else:
904                if not isinstance(value, Quantity):
905                    raise UnitsError(f"The '{param.name}' parameter should be given as a"
906                                     " Quantity because it was originally "
907                                     "initialized as a Quantity")
908                param._unit = value.unit
909                param.value = value.value
910        else:
911            if attr in ['fittable', 'linear']:
912                self.__dict__[attr] = value
913            else:
914                super().__setattr__(attr, value)
915
916    def _pre_evaluate(self, *args, **kwargs):
917        """
918        Model specific input setup that needs to occur prior to model evaluation
919        """
920
921        # Broadcast inputs into common size
922        inputs, broadcasted_shapes = self.prepare_inputs(*args, **kwargs)
923
924        # Setup actual model evaluation method
925        parameters = self._param_sets(raw=True, units=True)
926
927        def evaluate(_inputs):
928            return self.evaluate(*chain(_inputs, parameters))
929
930        return evaluate, inputs, broadcasted_shapes, kwargs
931
932    def get_bounding_box(self, with_bbox=True):
933        """
934        Return the ``bounding_box`` of a model if it exists or ``None``
935        otherwise.
936
937        Parameters
938        ----------
939        with_bbox :
940            The value of the ``with_bounding_box`` keyword argument
941            when calling the model. Default is `True` for usage when
942            looking up the model's ``bounding_box`` without risk of error.
943        """
944        bbox = None
945
946        if not isinstance(with_bbox, bool) or with_bbox:
947            try:
948                bbox = self.bounding_box
949            except NotImplementedError:
950                pass
951
952            if isinstance(bbox, CompoundBoundingBox) and not isinstance(with_bbox, bool):
953                bbox = bbox[with_bbox]
954
955        return bbox
956
957    @property
958    def _argnames(self):
959        """The inputs used to determine input_shape for bounding_box evaluation"""
960        return self.inputs
961
962    def _validate_input_shape(self, _input, idx, argnames, model_set_axis, check_model_set_axis):
963        """
964        Perform basic validation of a single model input's shape
965            -- it has the minimum dimensions for the given model_set_axis
966
967        Returns the shape of the input if validation succeeds.
968        """
969        input_shape = np.shape(_input)
970        # Ensure that the input's model_set_axis matches the model's
971        # n_models
972        if input_shape and check_model_set_axis:
973            # Note: Scalar inputs *only* get a pass on this
974            if len(input_shape) < model_set_axis + 1:
975                raise ValueError(
976                    f"For model_set_axis={model_set_axis}, all inputs must be at "
977                    f"least {model_set_axis + 1}-dimensional.")
978            if input_shape[model_set_axis] != self._n_models:
979                try:
980                    argname = argnames[idx]
981                except IndexError:
982                    # the case of model.inputs = ()
983                    argname = str(idx)
984
985                raise ValueError(
986                    f"Input argument '{argname}' does not have the correct "
987                    f"dimensions in model_set_axis={model_set_axis} for a model set with "
988                    f"n_models={self._n_models}.")
989
990        return input_shape
991
992    def _validate_input_shapes(self, inputs, argnames, model_set_axis):
993        """
994        Perform basic validation of model inputs
995            --that they are mutually broadcastable and that they have
996            the minimum dimensions for the given model_set_axis.
997
998        If validation succeeds, returns the total shape that will result from
999        broadcasting the input arrays with each other.
1000        """
1001
1002        check_model_set_axis = self._n_models > 1 and model_set_axis is not False
1003
1004        all_shapes = []
1005        for idx, _input in enumerate(inputs):
1006            all_shapes.append(self._validate_input_shape(_input, idx, argnames,
1007                                                         model_set_axis, check_model_set_axis))
1008
1009        input_shape = check_broadcast(*all_shapes)
1010        if input_shape is None:
1011            raise ValueError(
1012                "All inputs must have identical shapes or must be scalars.")
1013
1014        return input_shape
1015
1016    def input_shape(self, inputs):
1017        """Get input shape for bounding_box evaluation"""
1018        return self._validate_input_shapes(inputs, self._argnames, self.model_set_axis)
1019
1020    def _generic_evaluate(self, evaluate, _inputs, fill_value, with_bbox):
1021        """
1022        Generic model evaluation routine
1023            Selects and evaluates model with or without bounding_box enforcement
1024        """
1025
1026        # Evaluate the model using the prepared evaluation method either
1027        #   enforcing the bounding_box or not.
1028        bbox = self.get_bounding_box(with_bbox)
1029        if (not isinstance(with_bbox, bool) or with_bbox) and bbox is not None:
1030            outputs = bbox.evaluate(evaluate, _inputs, fill_value)
1031        else:
1032            outputs = evaluate(_inputs)
1033        return outputs
1034
1035    def _post_evaluate(self, inputs, outputs, broadcasted_shapes, with_bbox, **kwargs):
1036        """
1037        Model specific post evaluation processing of outputs
1038        """
1039        if self.get_bounding_box(with_bbox) is None and self.n_outputs == 1:
1040            outputs = (outputs,)
1041
1042        outputs = self.prepare_outputs(broadcasted_shapes, *outputs, **kwargs)
1043        outputs = self._process_output_units(inputs, outputs)
1044
1045        if self.n_outputs == 1:
1046            return outputs[0]
1047        return outputs
1048
1049    @property
1050    def bbox_with_units(self):
1051        return (not isinstance(self, CompoundModel))
1052
1053    def __call__(self, *args, **kwargs):
1054        """
1055        Evaluate this model using the given input(s) and the parameter values
1056        that were specified when the model was instantiated.
1057        """
1058        # Turn any keyword arguments into positional arguments.
1059        args, kwargs = self._get_renamed_inputs_as_positional(*args, **kwargs)
1060
1061        # Read model evaluation related parameters
1062        with_bbox = kwargs.pop('with_bounding_box', False)
1063        fill_value = kwargs.pop('fill_value', np.nan)
1064
1065        # prepare for model evaluation (overridden in CompoundModel)
1066        evaluate, inputs, broadcasted_shapes, kwargs = self._pre_evaluate(*args, **kwargs)
1067
1068        outputs = self._generic_evaluate(evaluate, inputs,
1069                                         fill_value, with_bbox)
1070
1071        # post-process evaluation results (overridden in CompoundModel)
1072        return self._post_evaluate(inputs, outputs, broadcasted_shapes, with_bbox, **kwargs)
1073
1074    def _get_renamed_inputs_as_positional(self, *args, **kwargs):
1075        def _keyword2positional(kwargs):
1076            # Inputs were passed as keyword (not positional) arguments.
1077            # Because the signature of the ``__call__`` is defined at
1078            # the class level, the name of the inputs cannot be changed at
1079            # the instance level and the old names are always present in the
1080            # signature of the method. In order to use the new names of the
1081            # inputs, the old names are taken out of ``kwargs``, the input
1082            # values are sorted in the order of self.inputs and passed as
1083            # positional arguments to ``__call__``.
1084
1085            # These are the keys that are always present as keyword arguments.
1086            keys = ['model_set_axis', 'with_bounding_box', 'fill_value',
1087                    'equivalencies', 'inputs_map']
1088
1089            new_inputs = {}
1090            # kwargs contain the names of the new inputs + ``keys``
1091            allkeys = list(kwargs.keys())
1092            # Remove the names of the new inputs from kwargs and save them
1093            # to a dict ``new_inputs``.
1094            for key in allkeys:
1095                if key not in keys:
1096                    new_inputs[key] = kwargs[key]
1097                    del kwargs[key]
1098            return new_inputs, kwargs
1099        n_args = len(args)
1100
1101        new_inputs, kwargs = _keyword2positional(kwargs)
1102        n_all_args = n_args + len(new_inputs)
1103
1104        if n_all_args < self.n_inputs:
1105            raise ValueError(f"Missing input arguments - expected {self.n_inputs}, got {n_all_args}")
1106        elif n_all_args > self.n_inputs:
1107            raise ValueError(f"Too many input arguments - expected {self.n_inputs}, got {n_all_args}")
1108        if n_args == 0:
1109            # Create positional arguments from the keyword arguments in ``new_inputs``.
1110            new_args = []
1111            for k in self.inputs:
1112                new_args.append(new_inputs[k])
1113        elif n_args != self.n_inputs:
1114            # Some inputs are passed as positional, others as keyword arguments.
1115            args = list(args)
1116
1117            # Create positional arguments from the keyword arguments in ``new_inputs``.
1118            new_args = []
1119            for k in self.inputs:
1120                if k in new_inputs:
1121                    new_args.append(new_inputs[k])
1122                else:
1123                    new_args.append(args[0])
1124                    del args[0]
1125        else:
1126            new_args = args
1127        return new_args, kwargs
1128
1129    # *** Properties ***
1130    @property
1131    def name(self):
1132        """User-provided name for this model instance."""
1133
1134        return self._name
1135
1136    @name.setter
1137    def name(self, val):
1138        """Assign a (new) name to this model."""
1139
1140        self._name = val
1141
1142    @property
1143    def model_set_axis(self):
1144        """
1145        The index of the model set axis--that is the axis of a parameter array
1146        that pertains to which model a parameter value pertains to--as
1147        specified when the model was initialized.
1148
1149        See the documentation on :ref:`astropy:modeling-model-sets`
1150        for more details.
1151        """
1152
1153        return self._model_set_axis
1154
1155    @property
1156    def param_sets(self):
1157        """
1158        Return parameters as a pset.
1159
1160        This is a list with one item per parameter set, which is an array of
1161        that parameter's values across all parameter sets, with the last axis
1162        associated with the parameter set.
1163        """
1164
1165        return self._param_sets()
1166
1167    @property
1168    def parameters(self):
1169        """
1170        A flattened array of all parameter values in all parameter sets.
1171
1172        Fittable parameters maintain this list and fitters modify it.
1173        """
1174
1175        # Currently the sequence of a model's parameters must be contiguous
1176        # within the _parameters array (which may be a view of a larger array,
1177        # for example when taking a sub-expression of a compound model), so
1178        # the assumption here is reliable:
1179        if not self.param_names:
1180            # Trivial, but not unheard of
1181            return self._parameters
1182
1183        self._parameters_to_array()
1184        start = self._param_metrics[self.param_names[0]]['slice'].start
1185        stop = self._param_metrics[self.param_names[-1]]['slice'].stop
1186
1187        return self._parameters[start:stop]
1188
1189    @parameters.setter
1190    def parameters(self, value):
1191        """
1192        Assigning to this attribute updates the parameters array rather than
1193        replacing it.
1194        """
1195
1196        if not self.param_names:
1197            return
1198
1199        start = self._param_metrics[self.param_names[0]]['slice'].start
1200        stop = self._param_metrics[self.param_names[-1]]['slice'].stop
1201
1202        try:
1203            value = np.array(value).flatten()
1204            self._parameters[start:stop] = value
1205        except ValueError as e:
1206            raise InputParameterError(
1207                "Input parameter values not compatible with the model "
1208                "parameters array: {0}".format(e))
1209        self._array_to_parameters()
1210
1211    @property
1212    def sync_constraints(self):
1213        '''
1214        This is a boolean property that indicates whether or not accessing constraints
1215        automatically check the constituent models current values. It defaults to True
1216        on creation of a model, but for fitting purposes it should be set to False
1217        for performance reasons.
1218        '''
1219        if not hasattr(self, '_sync_constraints'):
1220            self._sync_constraints = True
1221        return self._sync_constraints
1222
1223    @sync_constraints.setter
1224    def sync_constraints(self, value):
1225        if not isinstance(value, bool):
1226            raise ValueError('sync_constraints only accepts True or False as values')
1227        self._sync_constraints = value
1228
1229    @property
1230    def fixed(self):
1231        """
1232        A ``dict`` mapping parameter names to their fixed constraint.
1233        """
1234        if not hasattr(self, '_fixed') or self.sync_constraints:
1235            self._fixed = _ConstraintsDict(self, 'fixed')
1236        return self._fixed
1237
1238    @property
1239    def bounds(self):
1240        """
1241        A ``dict`` mapping parameter names to their upper and lower bounds as
1242        ``(min, max)`` tuples or ``[min, max]`` lists.
1243        """
1244        if not hasattr(self, '_bounds') or self.sync_constraints:
1245            self._bounds = _ConstraintsDict(self, 'bounds')
1246        return self._bounds
1247
1248    @property
1249    def tied(self):
1250        """
1251        A ``dict`` mapping parameter names to their tied constraint.
1252        """
1253        if not hasattr(self, '_tied') or self.sync_constraints:
1254            self._tied = _ConstraintsDict(self, 'tied')
1255        return self._tied
1256
1257    @property
1258    def eqcons(self):
1259        """List of parameter equality constraints."""
1260
1261        return self._mconstraints['eqcons']
1262
1263    @property
1264    def ineqcons(self):
1265        """List of parameter inequality constraints."""
1266
1267        return self._mconstraints['ineqcons']
1268
1269    def has_inverse(self):
1270        """
1271        Returns True if the model has an analytic or user
1272        inverse defined.
1273        """
1274        try:
1275            self.inverse
1276        except NotImplementedError:
1277            return False
1278
1279        return True
1280
1281    @property
1282    def inverse(self):
1283        """
1284        Returns a new `~astropy.modeling.Model` instance which performs the
1285        inverse transform, if an analytic inverse is defined for this model.
1286
1287        Even on models that don't have an inverse defined, this property can be
1288        set with a manually-defined inverse, such a pre-computed or
1289        experimentally determined inverse (often given as a
1290        `~astropy.modeling.polynomial.PolynomialModel`, but not by
1291        requirement).
1292
1293        A custom inverse can be deleted with ``del model.inverse``.  In this
1294        case the model's inverse is reset to its default, if a default exists
1295        (otherwise the default is to raise `NotImplementedError`).
1296
1297        Note to authors of `~astropy.modeling.Model` subclasses:  To define an
1298        inverse for a model simply override this property to return the
1299        appropriate model representing the inverse.  The machinery that will
1300        make the inverse manually-overridable is added automatically by the
1301        base class.
1302        """
1303        if self._user_inverse is not None:
1304            return self._user_inverse
1305        elif self._inverse is not None:
1306            result = self._inverse()
1307            if result is not NotImplemented:
1308                if not self._has_inverse_bounding_box:
1309                    result.bounding_box = None
1310                return result
1311
1312        raise NotImplementedError("No analytical or user-supplied inverse transform "
1313                                  "has been implemented for this model.")
1314
1315    @inverse.setter
1316    def inverse(self, value):
1317        if not isinstance(value, (Model, type(None))):
1318            raise ValueError(
1319                "The ``inverse`` attribute may be assigned a `Model` "
1320                "instance or `None` (where `None` explicitly forces the "
1321                "model to have no inverse.")
1322
1323        self._user_inverse = value
1324
1325    @inverse.deleter
1326    def inverse(self):
1327        """
1328        Resets the model's inverse to its default (if one exists, otherwise
1329        the model will have no inverse).
1330        """
1331
1332        try:
1333            del self._user_inverse
1334        except AttributeError:
1335            pass
1336
1337    @property
1338    def has_user_inverse(self):
1339        """
1340        A flag indicating whether or not a custom inverse model has been
1341        assigned to this model by a user, via assignment to ``model.inverse``.
1342        """
1343        return self._user_inverse is not None
1344
1345    @property
1346    def bounding_box(self):
1347        r"""
1348        A `tuple` of length `n_inputs` defining the bounding box limits, or
1349        raise `NotImplementedError` for no bounding_box.
1350
1351        The default limits are given by a ``bounding_box`` property or method
1352        defined in the class body of a specific model.  If not defined then
1353        this property just raises `NotImplementedError` by default (but may be
1354        assigned a custom value by a user).  ``bounding_box`` can be set
1355        manually to an array-like object of shape ``(model.n_inputs, 2)``. For
1356        further usage, see :ref:`astropy:bounding-boxes`
1357
1358        The limits are ordered according to the `numpy` ``'C'`` indexing
1359        convention, and are the reverse of the model input order,
1360        e.g. for inputs ``('x', 'y', 'z')``, ``bounding_box`` is defined:
1361
1362        * for 1D: ``(x_low, x_high)``
1363        * for 2D: ``((y_low, y_high), (x_low, x_high))``
1364        * for 3D: ``((z_low, z_high), (y_low, y_high), (x_low, x_high))``
1365
1366        Examples
1367        --------
1368
1369        Setting the ``bounding_box`` limits for a 1D and 2D model:
1370
1371        >>> from astropy.modeling.models import Gaussian1D, Gaussian2D
1372        >>> model_1d = Gaussian1D()
1373        >>> model_2d = Gaussian2D(x_stddev=1, y_stddev=1)
1374        >>> model_1d.bounding_box = (-5, 5)
1375        >>> model_2d.bounding_box = ((-6, 6), (-5, 5))
1376
1377        Setting the bounding_box limits for a user-defined 3D `custom_model`:
1378
1379        >>> from astropy.modeling.models import custom_model
1380        >>> def const3d(x, y, z, amp=1):
1381        ...    return amp
1382        ...
1383        >>> Const3D = custom_model(const3d)
1384        >>> model_3d = Const3D()
1385        >>> model_3d.bounding_box = ((-6, 6), (-5, 5), (-4, 4))
1386
1387        To reset ``bounding_box`` to its default limits just delete the
1388        user-defined value--this will reset it back to the default defined
1389        on the class:
1390
1391        >>> del model_1d.bounding_box
1392
1393        To disable the bounding box entirely (including the default),
1394        set ``bounding_box`` to `None`:
1395
1396        >>> model_1d.bounding_box = None
1397        >>> model_1d.bounding_box  # doctest: +IGNORE_EXCEPTION_DETAIL
1398        Traceback (most recent call last):
1399        NotImplementedError: No bounding box is defined for this model
1400        (note: the bounding box was explicitly disabled for this model;
1401        use `del model.bounding_box` to restore the default bounding box,
1402        if one is defined for this model).
1403        """
1404
1405        if self._user_bounding_box is not None:
1406            if self._user_bounding_box is NotImplemented:
1407                raise NotImplementedError(
1408                    "No bounding box is defined for this model (note: the "
1409                    "bounding box was explicitly disabled for this model; "
1410                    "use `del model.bounding_box` to restore the default "
1411                    "bounding box, if one is defined for this model).")
1412            return self._user_bounding_box
1413        elif self._bounding_box is None:
1414            raise NotImplementedError(
1415                "No bounding box is defined for this model.")
1416        elif isinstance(self._bounding_box, ModelBoundingBox):
1417            # This typically implies a hard-coded bounding box.  This will
1418            # probably be rare, but it is an option
1419            return self._bounding_box
1420        elif isinstance(self._bounding_box, types.MethodType):
1421            return ModelBoundingBox.validate(self, self._bounding_box())
1422        else:
1423            # The only other allowed possibility is that it's a ModelBoundingBox
1424            # subclass, so we call it with its default arguments and return an
1425            # instance of it (that can be called to recompute the bounding box
1426            # with any optional parameters)
1427            # (In other words, in this case self._bounding_box is a *class*)
1428            bounding_box = self._bounding_box((), model=self)()
1429            return self._bounding_box(bounding_box, model=self)
1430
1431    @bounding_box.setter
1432    def bounding_box(self, bounding_box):
1433        """
1434        Assigns the bounding box limits.
1435        """
1436
1437        if bounding_box is None:
1438            cls = None
1439            # We use this to explicitly set an unimplemented bounding box (as
1440            # opposed to no user bounding box defined)
1441            bounding_box = NotImplemented
1442        elif (isinstance(bounding_box, CompoundBoundingBox) or
1443              isinstance(bounding_box, dict)):
1444            cls = CompoundBoundingBox
1445        elif (isinstance(self._bounding_box, type) and
1446              issubclass(self._bounding_box, ModelBoundingBox)):
1447            cls = self._bounding_box
1448        else:
1449            cls = ModelBoundingBox
1450
1451        if cls is not None:
1452            try:
1453                bounding_box = cls.validate(self, bounding_box)
1454            except ValueError as exc:
1455                raise ValueError(exc.args[0])
1456
1457        self._user_bounding_box = bounding_box
1458
1459    def set_slice_args(self, *args):
1460        if isinstance(self._user_bounding_box, CompoundBoundingBox):
1461            self._user_bounding_box.slice_args = args
1462        else:
1463            raise RuntimeError('The bounding_box for this model is not compound')
1464
1465    @bounding_box.deleter
1466    def bounding_box(self):
1467        self._user_bounding_box = None
1468
1469    @property
1470    def has_user_bounding_box(self):
1471        """
1472        A flag indicating whether or not a custom bounding_box has been
1473        assigned to this model by a user, via assignment to
1474        ``model.bounding_box``.
1475        """
1476
1477        return self._user_bounding_box is not None
1478
1479    @property
1480    def cov_matrix(self):
1481        """
1482        Fitter should set covariance matrix, if available.
1483        """
1484        return self._cov_matrix
1485
1486    @cov_matrix.setter
1487    def cov_matrix(self, cov):
1488
1489        self._cov_matrix = cov
1490
1491        unfix_untied_params = [p for p in self.param_names if (self.fixed[p] is False)
1492                               and (self.tied[p] is False)]
1493        if type(cov) == list:  # model set
1494            param_stds = []
1495            for c in cov:
1496                param_stds.append([np.sqrt(x) if x > 0 else None for x in np.diag(c.cov_matrix)])
1497            for p, param_name in enumerate(unfix_untied_params):
1498                par = getattr(self, param_name)
1499                par.std = [item[p] for item in param_stds]
1500                setattr(self, param_name, par)
1501        else:
1502            param_stds = [np.sqrt(x) if x > 0 else None for x in np.diag(cov.cov_matrix)]
1503            for param_name in unfix_untied_params:
1504                par = getattr(self, param_name)
1505                par.std = param_stds.pop(0)
1506                setattr(self, param_name, par)
1507
1508    @property
1509    def stds(self):
1510        """
1511        Standard deviation of parameters, if covariance matrix is available.
1512        """
1513        return self._stds
1514
1515    @stds.setter
1516    def stds(self, stds):
1517        self._stds = stds
1518
1519    @property
1520    def separable(self):
1521        """ A flag indicating whether a model is separable."""
1522
1523        if self._separable is not None:
1524            return self._separable
1525        raise NotImplementedError(
1526            'The "separable" property is not defined for '
1527            'model {}'.format(self.__class__.__name__))
1528
1529    # *** Public methods ***
1530
1531    def without_units_for_data(self, **kwargs):
1532        """
1533        Return an instance of the model for which the parameter values have
1534        been converted to the right units for the data, then the units have
1535        been stripped away.
1536
1537        The input and output Quantity objects should be given as keyword
1538        arguments.
1539
1540        Notes
1541        -----
1542
1543        This method is needed in order to be able to fit models with units in
1544        the parameters, since we need to temporarily strip away the units from
1545        the model during the fitting (which might be done by e.g. scipy
1546        functions).
1547
1548        The units that the parameters should be converted to are not
1549        necessarily the units of the input data, but are derived from them.
1550        Model subclasses that want fitting to work in the presence of
1551        quantities need to define a ``_parameter_units_for_data_units`` method
1552        that takes the input and output units (as two dictionaries) and
1553        returns a dictionary giving the target units for each parameter.
1554
1555        """
1556        model = self.copy()
1557
1558        inputs_unit = {inp: getattr(kwargs[inp], 'unit', dimensionless_unscaled)
1559                       for inp in self.inputs if kwargs[inp] is not None}
1560
1561        outputs_unit = {out: getattr(kwargs[out], 'unit', dimensionless_unscaled)
1562                        for out in self.outputs if kwargs[out] is not None}
1563        parameter_units = self._parameter_units_for_data_units(inputs_unit,
1564                                                               outputs_unit)
1565        for name, unit in parameter_units.items():
1566            parameter = getattr(model, name)
1567            if parameter.unit is not None:
1568                parameter.value = parameter.quantity.to(unit).value
1569                parameter._set_unit(None, force=True)
1570
1571        if isinstance(model, CompoundModel):
1572            model.strip_units_from_tree()
1573
1574        return model
1575
1576    def strip_units_from_tree(self):
1577        for item in self._leaflist:
1578            for parname in item.param_names:
1579                par = getattr(item, parname)
1580                par._set_unit(None, force=True)
1581
1582    def with_units_from_data(self, **kwargs):
1583        """
1584        Return an instance of the model which has units for which the parameter
1585        values are compatible with the data units specified.
1586
1587        The input and output Quantity objects should be given as keyword
1588        arguments.
1589
1590        Notes
1591        -----
1592
1593        This method is needed in order to be able to fit models with units in
1594        the parameters, since we need to temporarily strip away the units from
1595        the model during the fitting (which might be done by e.g. scipy
1596        functions).
1597
1598        The units that the parameters will gain are not necessarily the units
1599        of the input data, but are derived from them. Model subclasses that
1600        want fitting to work in the presence of quantities need to define a
1601        ``_parameter_units_for_data_units`` method that takes the input and output
1602        units (as two dictionaries) and returns a dictionary giving the target
1603        units for each parameter.
1604        """
1605
1606        model = self.copy()
1607        inputs_unit = {inp: getattr(kwargs[inp], 'unit', dimensionless_unscaled)
1608                       for inp in self.inputs if kwargs[inp] is not None}
1609
1610        outputs_unit = {out: getattr(kwargs[out], 'unit', dimensionless_unscaled)
1611                        for out in self.outputs if kwargs[out] is not None}
1612
1613        parameter_units = self._parameter_units_for_data_units(inputs_unit,
1614                                                               outputs_unit)
1615
1616        # We are adding units to parameters that already have a value, but we
1617        # don't want to convert the parameter, just add the unit directly,
1618        # hence the call to ``_set_unit``.
1619        for name, unit in parameter_units.items():
1620            parameter = getattr(model, name)
1621            parameter._set_unit(unit, force=True)
1622
1623        return model
1624
1625    @property
1626    def _has_units(self):
1627        # Returns True if any of the parameters have units
1628        for param in self.param_names:
1629            if getattr(self, param).unit is not None:
1630                return True
1631        else:
1632            return False
1633
1634    @property
1635    def _supports_unit_fitting(self):
1636        # If the model has a ``_parameter_units_for_data_units`` method, this
1637        # indicates that we have enough information to strip the units away
1638        # and add them back after fitting, when fitting quantities
1639        return hasattr(self, '_parameter_units_for_data_units')
1640
1641    @abc.abstractmethod
1642    def evaluate(self, *args, **kwargs):
1643        """Evaluate the model on some input variables."""
1644
1645    def sum_of_implicit_terms(self, *args, **kwargs):
1646        """
1647        Evaluate the sum of any implicit model terms on some input variables.
1648        This includes any fixed terms used in evaluating a linear model that
1649        do not have corresponding parameters exposed to the user. The
1650        prototypical case is `astropy.modeling.functional_models.Shift`, which
1651        corresponds to a function y = a + bx, where b=1 is intrinsically fixed
1652        by the type of model, such that sum_of_implicit_terms(x) == x. This
1653        method is needed by linear fitters to correct the dependent variable
1654        for the implicit term(s) when solving for the remaining terms
1655        (ie. a = y - bx).
1656        """
1657
1658    def render(self, out=None, coords=None):
1659        """
1660        Evaluate a model at fixed positions, respecting the ``bounding_box``.
1661
1662        The key difference relative to evaluating the model directly is that
1663        this method is limited to a bounding box if the `Model.bounding_box`
1664        attribute is set.
1665
1666        Parameters
1667        ----------
1668        out : `numpy.ndarray`, optional
1669            An array that the evaluated model will be added to.  If this is not
1670            given (or given as ``None``), a new array will be created.
1671        coords : array-like, optional
1672            An array to be used to translate from the model's input coordinates
1673            to the ``out`` array. It should have the property that
1674            ``self(coords)`` yields the same shape as ``out``.  If ``out`` is
1675            not specified, ``coords`` will be used to determine the shape of
1676            the returned array. If this is not provided (or None), the model
1677            will be evaluated on a grid determined by `Model.bounding_box`.
1678
1679        Returns
1680        -------
1681        out : `numpy.ndarray`
1682            The model added to ``out`` if  ``out`` is not ``None``, or else a
1683            new array from evaluating the model over ``coords``.
1684            If ``out`` and ``coords`` are both `None`, the returned array is
1685            limited to the `Model.bounding_box` limits. If
1686            `Model.bounding_box` is `None`, ``arr`` or ``coords`` must be
1687            passed.
1688
1689        Raises
1690        ------
1691        ValueError
1692            If ``coords`` are not given and the the `Model.bounding_box` of
1693            this model is not set.
1694
1695        Examples
1696        --------
1697        :ref:`astropy:bounding-boxes`
1698        """
1699
1700        try:
1701            bbox = self.bounding_box
1702        except NotImplementedError:
1703            bbox = None
1704
1705        if isinstance(bbox, ModelBoundingBox):
1706            bbox = bbox.bounding_box()
1707
1708        ndim = self.n_inputs
1709
1710        if (coords is None) and (out is None) and (bbox is None):
1711            raise ValueError('If no bounding_box is set, '
1712                             'coords or out must be input.')
1713
1714        # for consistent indexing
1715        if ndim == 1:
1716            if coords is not None:
1717                coords = [coords]
1718            if bbox is not None:
1719                bbox = [bbox]
1720
1721        if coords is not None:
1722            coords = np.asanyarray(coords, dtype=float)
1723            # Check dimensions match out and model
1724            assert len(coords) == ndim
1725            if out is not None:
1726                if coords[0].shape != out.shape:
1727                    raise ValueError('inconsistent shape of the output.')
1728            else:
1729                out = np.zeros(coords[0].shape)
1730
1731        if out is not None:
1732            out = np.asanyarray(out)
1733            if out.ndim != ndim:
1734                raise ValueError('the array and model must have the same '
1735                                 'number of dimensions.')
1736
1737        if bbox is not None:
1738            # Assures position is at center pixel,
1739            # important when using add_array.
1740            pd = np.array([(np.mean(bb), np.ceil((bb[1] - bb[0]) / 2))
1741                           for bb in bbox]).astype(int).T
1742            pos, delta = pd
1743
1744            if coords is not None:
1745                sub_shape = tuple(delta * 2 + 1)
1746                sub_coords = np.array([extract_array(c, sub_shape, pos)
1747                                       for c in coords])
1748            else:
1749                limits = [slice(p - d, p + d + 1, 1) for p, d in pd.T]
1750                sub_coords = np.mgrid[limits]
1751
1752            sub_coords = sub_coords[::-1]
1753
1754            if out is None:
1755                out = self(*sub_coords)
1756            else:
1757                try:
1758                    out = add_array(out, self(*sub_coords), pos)
1759                except ValueError:
1760                    raise ValueError(
1761                        'The `bounding_box` is larger than the input out in '
1762                        'one or more dimensions. Set '
1763                        '`model.bounding_box = None`.')
1764        else:
1765            if coords is None:
1766                im_shape = out.shape
1767                limits = [slice(i) for i in im_shape]
1768                coords = np.mgrid[limits]
1769
1770            coords = coords[::-1]
1771
1772            out += self(*coords)
1773
1774        return out
1775
1776    @property
1777    def input_units(self):
1778        """
1779        This property is used to indicate what units or sets of units the
1780        evaluate method expects, and returns a dictionary mapping inputs to
1781        units (or `None` if any units are accepted).
1782
1783        Model sub-classes can also use function annotations in evaluate to
1784        indicate valid input units, in which case this property should
1785        not be overridden since it will return the input units based on the
1786        annotations.
1787        """
1788        if hasattr(self, '_input_units'):
1789            return self._input_units
1790        elif hasattr(self.evaluate, '__annotations__'):
1791            annotations = self.evaluate.__annotations__.copy()
1792            annotations.pop('return', None)
1793            if annotations:
1794                # If there are not annotations for all inputs this will error.
1795                return dict((name, annotations[name]) for name in self.inputs)
1796        else:
1797            # None means any unit is accepted
1798            return None
1799
1800    @property
1801    def return_units(self):
1802        """
1803        This property is used to indicate what units or sets of units the
1804        output of evaluate should be in, and returns a dictionary mapping
1805        outputs to units (or `None` if any units are accepted).
1806
1807        Model sub-classes can also use function annotations in evaluate to
1808        indicate valid output units, in which case this property should not be
1809        overridden since it will return the return units based on the
1810        annotations.
1811        """
1812        if hasattr(self, '_return_units'):
1813            return self._return_units
1814        elif hasattr(self.evaluate, '__annotations__'):
1815            return self.evaluate.__annotations__.get('return', None)
1816        else:
1817            # None means any unit is accepted
1818            return None
1819
1820    def _prepare_inputs_single_model(self, params, inputs, **kwargs):
1821        broadcasts = []
1822        for idx, _input in enumerate(inputs):
1823            input_shape = _input.shape
1824
1825            # Ensure that array scalars are always upgrade to 1-D arrays for the
1826            # sake of consistency with how parameters work.  They will be cast back
1827            # to scalars at the end
1828            if not input_shape:
1829                inputs[idx] = _input.reshape((1,))
1830
1831            if not params:
1832                max_broadcast = input_shape
1833            else:
1834                max_broadcast = ()
1835
1836            for param in params:
1837                try:
1838                    if self.standard_broadcasting:
1839                        broadcast = check_broadcast(input_shape, param.shape)
1840                    else:
1841                        broadcast = input_shape
1842                except IncompatibleShapeError:
1843                    raise ValueError(
1844                        "self input argument {0!r} of shape {1!r} cannot be "
1845                        "broadcast with parameter {2!r} of shape "
1846                        "{3!r}.".format(self.inputs[idx], input_shape,
1847                                        param.name, param.shape))
1848
1849                if len(broadcast) > len(max_broadcast):
1850                    max_broadcast = broadcast
1851                elif len(broadcast) == len(max_broadcast):
1852                    max_broadcast = max(max_broadcast, broadcast)
1853
1854            broadcasts.append(max_broadcast)
1855
1856        if self.n_outputs > self.n_inputs:
1857            extra_outputs = self.n_outputs - self.n_inputs
1858            if not broadcasts:
1859                # If there were no inputs then the broadcasts list is empty
1860                # just add a None since there is no broadcasting of outputs and
1861                # inputs necessary (see _prepare_outputs_single_self)
1862                broadcasts.append(None)
1863            broadcasts.extend([broadcasts[0]] * extra_outputs)
1864
1865        return inputs, (broadcasts,)
1866
1867    @staticmethod
1868    def _remove_axes_from_shape(shape, axis):
1869        """
1870        Given a shape tuple as the first input, construct a new one by  removing
1871        that particular axis from the shape and all preceeding axes. Negative axis
1872        numbers are permittted, where the axis is relative to the last axis.
1873        """
1874        if len(shape) == 0:
1875            return shape
1876        if axis < 0:
1877            axis = len(shape) + axis
1878            return shape[:axis] + shape[axis+1:]
1879        if axis >= len(shape):
1880            axis = len(shape)-1
1881        shape = shape[axis+1:]
1882        return shape
1883
    def _prepare_inputs_model_set(self, params, inputs, model_set_axis_input,
                                  **kwargs):
        """
        Check and reshape the inputs of a model set so that they broadcast
        with the parameters.

        Parameters
        ----------
        params : list
            The model's parameter objects; each must provide ``shape`` and
            ``name``.
        inputs : list of ndarray
            The input arrays.
        model_set_axis_input : int or False
            The axis of the inputs that indexes the model set.  `False` if
            the inputs have no model-set axis, in which case a length-1 axis
            is inserted so that they broadcast across the set.
        **kwargs
            Unused.

        Returns
        -------
        reshaped : list of ndarray
            The reshaped inputs.
        pivots : tuple
            A 1-tuple holding the list of pivot axes, one per output.  This
            is presumably the ``broadcasted_shapes`` later handed to
            ``_prepare_outputs_model_set``, which reads it as pivots.

        Raises
        ------
        ValueError
            If an input (excluding its model-set axis) cannot be broadcast
            with a parameter (excluding the parameter's model-set axis).
        """
        reshaped = []
        pivots = []

        model_set_axis_param = self.model_set_axis  # needed to reshape param
        for idx, _input in enumerate(inputs):
            # Largest parameter shape seen so far, excluding the model-set axis.
            max_param_shape = ()
            if self._n_models > 1 and model_set_axis_input is not False:
                # Use the shape of the input *excluding* the model axis
                input_shape = (_input.shape[:model_set_axis_input] +
                               _input.shape[model_set_axis_input + 1:])
            else:
                input_shape = _input.shape

            for param in params:
                try:
                    check_broadcast(input_shape,
                                    self._remove_axes_from_shape(param.shape,
                                                                 model_set_axis_param))
                except IncompatibleShapeError:
                    raise ValueError(
                        "Model input argument {0!r} of shape {1!r} cannot be "
                        "broadcast with parameter {2!r} of shape "
                        "{3!r}.".format(self.inputs[idx], input_shape,
                                        param.name,
                                        self._remove_axes_from_shape(param.shape,
                                                                     model_set_axis_param)))

                # NOTE(review): this compares len(param.shape) - 1, i.e. it
                # assumes exactly one axis is removed, whereas
                # _remove_axes_from_shape drops axis + 1 leading axes for a
                # non-negative model_set_axis; confirm this is intended when
                # model_set_axis > 0.
                if len(param.shape) - 1 > len(max_param_shape):
                    max_param_shape = self._remove_axes_from_shape(param.shape,
                                                                   model_set_axis_param)

            # We've now determined that, excluding the model_set_axis, the
            # input can broadcast with all the parameters
            input_ndim = len(input_shape)
            if model_set_axis_input is False:
                # The input has no model-set axis: insert length-1 axes so it
                # broadcasts across every model in the set.
                if len(max_param_shape) > input_ndim:
                    # Just needs to prepend new axes to the input
                    n_new_axes = 1 + len(max_param_shape) - input_ndim
                    new_axes = (1,) * n_new_axes
                    new_shape = new_axes + _input.shape
                    pivot = model_set_axis_param
                else:
                    pivot = input_ndim - len(max_param_shape)
                    new_shape = (_input.shape[:pivot] + (1,) +
                                 _input.shape[pivot:])
                new_input = _input.reshape(new_shape)
            else:
                if len(max_param_shape) >= input_ndim:
                    # Parameters have at least as many (non-model) dimensions
                    # as the input: pad with length-1 axes after the pivot.
                    n_new_axes = len(max_param_shape) - input_ndim
                    pivot = self.model_set_axis
                    new_axes = (1,) * n_new_axes
                    new_shape = (_input.shape[:pivot + 1] + new_axes +
                                 _input.shape[pivot + 1:])
                    new_input = _input.reshape(new_shape)
                else:
                    # The input has more dimensions than the parameters: move
                    # the model-set axis just in front of the trailing axes
                    # that line up with the parameter shape.
                    pivot = _input.ndim - len(max_param_shape) - 1
                    new_input = np.rollaxis(_input, model_set_axis_input,
                                            pivot + 1)
            pivots.append(pivot)
            reshaped.append(new_input)

        # Outputs beyond the number of inputs get the input model-set axis as
        # their pivot.
        if self.n_inputs < self.n_outputs:
            pivots.extend([model_set_axis_input] * (self.n_outputs - self.n_inputs))

        return reshaped, (pivots,)
1951
1952    def prepare_inputs(self, *inputs, model_set_axis=None, equivalencies=None,
1953                       **kwargs):
1954        """
1955        This method is used in `~astropy.modeling.Model.__call__` to ensure
1956        that all the inputs to the model can be broadcast into compatible
1957        shapes (if one or both of them are input as arrays), particularly if
1958        there are more than one parameter sets. This also makes sure that (if
1959        applicable) the units of the input will be compatible with the evaluate
1960        method.
1961        """
1962        # When we instantiate the model class, we make sure that __call__ can
1963        # take the following two keyword arguments: model_set_axis and
1964        # equivalencies.
1965        if model_set_axis is None:
1966            # By default the model_set_axis for the input is assumed to be the
1967            # same as that for the parameters the model was defined with
1968            # TODO: Ensure that negative model_set_axis arguments are respected
1969            model_set_axis = self.model_set_axis
1970
1971        params = [getattr(self, name) for name in self.param_names]
1972        inputs = [np.asanyarray(_input, dtype=float) for _input in inputs]
1973
1974        self._validate_input_shapes(inputs, self.inputs, model_set_axis)
1975
1976        inputs_map = kwargs.get('inputs_map', None)
1977
1978        inputs = self._validate_input_units(inputs, equivalencies, inputs_map)
1979
1980        # The input formatting required for single models versus a multiple
1981        # model set are different enough that they've been split into separate
1982        # subroutines
1983        if self._n_models == 1:
1984            return self._prepare_inputs_single_model(params, inputs, **kwargs)
1985        else:
1986            return self._prepare_inputs_model_set(params, inputs,
1987                                                  model_set_axis, **kwargs)
1988
1989    def _validate_input_units(self, inputs, equivalencies=None, inputs_map=None):
1990        inputs = list(inputs)
1991        name = self.name or self.__class__.__name__
1992        # Check that the units are correct, if applicable
1993
1994        if self.input_units is not None:
1995            # If a leaflist is provided that means this is in the context of
1996            # a compound model and it is necessary to create the appropriate
1997            # alias for the input coordinate name for the equivalencies dict
1998            if inputs_map:
1999                edict = {}
2000                for mod, mapping in inputs_map:
2001                    if self is mod:
2002                        edict[mapping[0]] = equivalencies[mapping[1]]
2003            else:
2004                edict = equivalencies
2005            # We combine any instance-level input equivalencies with user
2006            # specified ones at call-time.
2007            input_units_equivalencies = _combine_equivalency_dict(self.inputs,
2008                                                                  edict,
2009                                                                  self.input_units_equivalencies)
2010
2011            # We now iterate over the different inputs and make sure that their
2012            # units are consistent with those specified in input_units.
2013            for i in range(len(inputs)):
2014
2015                input_name = self.inputs[i]
2016                input_unit = self.input_units.get(input_name, None)
2017
2018                if input_unit is None:
2019                    continue
2020
2021                if isinstance(inputs[i], Quantity):
2022
2023                    # We check for consistency of the units with input_units,
2024                    # taking into account any equivalencies
2025
2026                    if inputs[i].unit.is_equivalent(
2027                            input_unit,
2028                            equivalencies=input_units_equivalencies[input_name]):
2029
2030                        # If equivalencies have been specified, we need to
2031                        # convert the input to the input units - this is
2032                        # because some equivalencies are non-linear, and
2033                        # we need to be sure that we evaluate the model in
2034                        # its own frame of reference. If input_units_strict
2035                        # is set, we also need to convert to the input units.
2036                        if len(input_units_equivalencies) > 0 or self.input_units_strict[input_name]:
2037                            inputs[i] = inputs[i].to(input_unit,
2038                                                     equivalencies=input_units_equivalencies[input_name])
2039
2040                    else:
2041
2042                        # We consider the following two cases separately so as
2043                        # to be able to raise more appropriate/nicer exceptions
2044
2045                        if input_unit is dimensionless_unscaled:
2046                            raise UnitsError("{0}: Units of input '{1}', {2} ({3}),"
2047                                             "could not be converted to "
2048                                             "required dimensionless "
2049                                             "input".format(name,
2050                                                            self.inputs[i],
2051                                                            inputs[i].unit,
2052                                                            inputs[i].unit.physical_type))
2053                        else:
2054                            raise UnitsError("{0}: Units of input '{1}', {2} ({3}),"
2055                                             " could not be "
2056                                             "converted to required input"
2057                                             " units of {4} ({5})".format(
2058                                                 name,
2059                                                 self.inputs[i],
2060                                                 inputs[i].unit,
2061                                                 inputs[i].unit.physical_type,
2062                                                 input_unit,
2063                                                 input_unit.physical_type))
2064                else:
2065
2066                    # If we allow dimensionless input, we add the units to the
2067                    # input values without conversion, otherwise we raise an
2068                    # exception.
2069
2070                    if (not self.input_units_allow_dimensionless[input_name] and
2071                        input_unit is not dimensionless_unscaled and
2072                        input_unit is not None):
2073                        if np.any(inputs[i] != 0):
2074                            raise UnitsError("{0}: Units of input '{1}', (dimensionless), could not be "
2075                                             "converted to required input units of "
2076                                             "{2} ({3})".format(name, self.inputs[i], input_unit,
2077                                                                input_unit.physical_type))
2078        return inputs
2079
2080    def _process_output_units(self, inputs, outputs):
2081        inputs_are_quantity = any([isinstance(i, Quantity) for i in inputs])
2082        if self.return_units and inputs_are_quantity:
2083            # We allow a non-iterable unit only if there is one output
2084            if self.n_outputs == 1 and not isiterable(self.return_units):
2085                return_units = {self.outputs[0]: self.return_units}
2086            else:
2087                return_units = self.return_units
2088
2089            outputs = tuple([Quantity(out, return_units.get(out_name, None), subok=True)
2090                             for out, out_name in zip(outputs, self.outputs)])
2091        return outputs
2092
2093    @staticmethod
2094    def _prepare_output_single_model(output, broadcast_shape):
2095        if broadcast_shape is not None:
2096            if not broadcast_shape:
2097                return output.item()
2098            else:
2099                try:
2100                    return output.reshape(broadcast_shape)
2101                except ValueError:
2102                    try:
2103                        return output.item()
2104                    except ValueError:
2105                        return output
2106
2107        return output
2108
2109    def _prepare_outputs_single_model(self, outputs, broadcasted_shapes):
2110        outputs = list(outputs)
2111        for idx, output in enumerate(outputs):
2112            try:
2113                broadcast_shape = check_broadcast(*broadcasted_shapes[0])
2114            except (IndexError, TypeError):
2115                broadcast_shape = broadcasted_shapes[0][idx]
2116
2117            outputs[idx] = self._prepare_output_single_model(output, broadcast_shape)
2118
2119        return tuple(outputs)
2120
2121    def _prepare_outputs_model_set(self, outputs, broadcasted_shapes, model_set_axis):
2122        pivots = broadcasted_shapes[0]
2123        # If model_set_axis = False was passed then use
2124        # self._model_set_axis to format the output.
2125        if model_set_axis is None or model_set_axis is False:
2126            model_set_axis = self.model_set_axis
2127        outputs = list(outputs)
2128        for idx, output in enumerate(outputs):
2129            pivot = pivots[idx]
2130            if pivot < output.ndim and pivot != model_set_axis:
2131                outputs[idx] = np.rollaxis(output, pivot,
2132                                           model_set_axis)
2133        return tuple(outputs)
2134
2135    def prepare_outputs(self, broadcasted_shapes, *outputs, **kwargs):
2136        model_set_axis = kwargs.get('model_set_axis', None)
2137
2138        if len(self) == 1:
2139            return self._prepare_outputs_single_model(outputs, broadcasted_shapes)
2140        else:
2141            return self._prepare_outputs_model_set(outputs, broadcasted_shapes, model_set_axis)
2142
2143    def copy(self):
2144        """
2145        Return a copy of this model.
2146
2147        Uses a deep copy so that all model attributes, including parameter
2148        values, are copied as well.
2149        """
2150
2151        return copy.deepcopy(self)
2152
2153    def deepcopy(self):
2154        """
2155        Return a deep copy of this model.
2156
2157        """
2158
2159        return self.copy()
2160
2161    @sharedmethod
2162    def rename(self, name):
2163        """
2164        Return a copy of this model with a new name.
2165        """
2166        new_model = self.copy()
2167        new_model._name = name
2168        return new_model
2169
2170    def coerce_units(
2171        self,
2172        input_units=None,
2173        return_units=None,
2174        input_units_equivalencies=None,
2175        input_units_allow_dimensionless=False
2176    ):
2177        """
2178        Attach units to this (unitless) model.
2179
2180        Parameters
2181        ----------
2182        input_units : dict or tuple, optional
2183            Input units to attach.  If dict, each key is the name of a model input,
2184            and the value is the unit to attach.  If tuple, the elements are units
2185            to attach in order corresponding to `Model.inputs`.
2186        return_units : dict or tuple, optional
2187            Output units to attach.  If dict, each key is the name of a model output,
2188            and the value is the unit to attach.  If tuple, the elements are units
2189            to attach in order corresponding to `Model.outputs`.
2190        input_units_equivalencies : dict, optional
2191            Default equivalencies to apply to input values.  If set, this should be a
2192            dictionary where each key is a string that corresponds to one of the
2193            model inputs.
2194        input_units_allow_dimensionless : bool or dict, optional
2195            Allow dimensionless input. If this is True, input values to evaluate will
2196            gain the units specified in input_units. If this is a dictionary then it
2197            should map input name to a bool to allow dimensionless numbers for that
2198            input.
2199
2200        Returns
2201        -------
2202        `CompoundModel`
2203            A `CompoundModel` composed of the current model plus
2204            `~astropy.modeling.mappings.UnitsMapping` model(s) that attach the units.
2205
2206        Raises
2207        ------
2208        ValueError
2209            If the current model already has units.
2210
2211        Examples
2212        --------
2213
2214        Wrapping a unitless model to require and convert units:
2215
2216        >>> from astropy.modeling.models import Polynomial1D
2217        >>> from astropy import units as u
2218        >>> poly = Polynomial1D(1, c0=1, c1=2)
2219        >>> model = poly.coerce_units((u.m,), (u.s,))
2220        >>> model(u.Quantity(10, u.m))  # doctest: +FLOAT_CMP
2221        <Quantity 21. s>
2222        >>> model(u.Quantity(1000, u.cm))  # doctest: +FLOAT_CMP
2223        <Quantity 21. s>
2224        >>> model(u.Quantity(10, u.cm))  # doctest: +FLOAT_CMP
2225        <Quantity 1.2 s>
2226
2227        Wrapping a unitless model but still permitting unitless input:
2228
2229        >>> from astropy.modeling.models import Polynomial1D
2230        >>> from astropy import units as u
2231        >>> poly = Polynomial1D(1, c0=1, c1=2)
2232        >>> model = poly.coerce_units((u.m,), (u.s,), input_units_allow_dimensionless=True)
2233        >>> model(u.Quantity(10, u.m))  # doctest: +FLOAT_CMP
2234        <Quantity 21. s>
2235        >>> model(10)  # doctest: +FLOAT_CMP
2236        <Quantity 21. s>
2237        """
2238        from .mappings import UnitsMapping
2239
2240        result = self
2241
2242        if input_units is not None:
2243            if self.input_units is not None:
2244                model_units = self.input_units
2245            else:
2246                model_units = {}
2247
2248            for unit in [model_units.get(i) for i in self.inputs]:
2249                if unit is not None and unit != dimensionless_unscaled:
2250                    raise ValueError("Cannot specify input_units for model with existing input units")
2251
2252            if isinstance(input_units, dict):
2253                if input_units.keys() != set(self.inputs):
2254                    message = (
2255                        f"""input_units keys ({", ".join(input_units.keys())}) """
2256                        f"""do not match model inputs ({", ".join(self.inputs)})"""
2257                    )
2258                    raise ValueError(message)
2259                input_units = [input_units[i] for i in self.inputs]
2260
2261            if len(input_units) != self.n_inputs:
2262                message = (
2263                    "input_units length does not match n_inputs: "
2264                    f"expected {self.n_inputs}, received {len(input_units)}"
2265                )
2266                raise ValueError(message)
2267
2268            mapping = tuple((unit, model_units.get(i)) for i, unit in zip(self.inputs, input_units))
2269            input_mapping = UnitsMapping(
2270                mapping,
2271                input_units_equivalencies=input_units_equivalencies,
2272                input_units_allow_dimensionless=input_units_allow_dimensionless
2273            )
2274            input_mapping.inputs = self.inputs
2275            input_mapping.outputs = self.inputs
2276            result = input_mapping | result
2277
2278        if return_units is not None:
2279            if self.return_units is not None:
2280                model_units = self.return_units
2281            else:
2282                model_units = {}
2283
2284            for unit in [model_units.get(i) for i in self.outputs]:
2285                if unit is not None and unit != dimensionless_unscaled:
2286                    raise ValueError("Cannot specify return_units for model with existing output units")
2287
2288            if isinstance(return_units, dict):
2289                if return_units.keys() != set(self.outputs):
2290                    message = (
2291                        f"""return_units keys ({", ".join(return_units.keys())}) """
2292                        f"""do not match model outputs ({", ".join(self.outputs)})"""
2293                    )
2294                    raise ValueError(message)
2295                return_units = [return_units[i] for i in self.outputs]
2296
2297            if len(return_units) != self.n_outputs:
2298                message = (
2299                    "return_units length does not match n_outputs: "
2300                    f"expected {self.n_outputs}, received {len(return_units)}"
2301                )
2302                raise ValueError(message)
2303
2304            mapping = tuple((model_units.get(i), unit) for i, unit in zip(self.outputs, return_units))
2305            return_mapping = UnitsMapping(mapping)
2306            return_mapping.inputs = self.outputs
2307            return_mapping.outputs = self.outputs
2308            result = result | return_mapping
2309
2310        return result
2311
2312    @property
2313    def n_submodels(self):
2314        """
2315        Return the number of components in a single model, which is
2316        obviously 1.
2317        """
2318        return 1
2319
2320    def _initialize_constraints(self, kwargs):
2321        """
2322        Pop parameter constraint values off the keyword arguments passed to
2323        `Model.__init__` and store them in private instance attributes.
2324        """
2325
2326        # Pop any constraints off the keyword arguments
2327        for constraint in self.parameter_constraints:
2328            values = kwargs.pop(constraint, {})
2329            for ckey, cvalue in values.items():
2330                param = getattr(self, ckey)
2331                setattr(param, constraint, cvalue)
2332        self._mconstraints = {}
2333        for constraint in self.model_constraints:
2334            values = kwargs.pop(constraint, [])
2335            self._mconstraints[constraint] = values
2336
    def _initialize_parameters(self, args, kwargs):
        """
        Assign the initial parameter values passed to ``__init__`` and
        determine the model set layout.

        Positional ``args`` are matched, in order, to ``self.param_names``;
        ``kwargs`` may supply parameter values by name as well as the special
        ``n_models`` and ``model_set_axis`` options, which are popped from
        ``kwargs`` here.  A parameter given as None (or not given at all)
        falls back to its default via ``_initialize_parameter_value``.

        On return ``self._model_set_axis``, ``self._param_metrics`` and
        ``self._n_models`` are set, and the validator of every explicitly
        supplied parameter has been run.  This method does not itself touch
        the flat ``_parameters`` array.

        Parameters
        ----------
        args : tuple
            Positional parameter values passed to ``__init__``.
        kwargs : dict
            Keyword arguments passed to ``__init__``; modified in place (the
            recognized entries are popped).

        Raises
        ------
        ValueError
            If ``n_models`` is neither None nor a positive integer, or if
            ``model_set_axis`` is neither False nor an integer.
        TypeError
            If too many positional arguments are given, a parameter is given
            both positionally and by keyword, a parameter without a default
            receives no value, or an unrecognized keyword argument remains.
        InputParameterError
            If parameter values have too few dimensions for
            ``model_set_axis``, disagree on the number of model sets, or
            cannot be broadcast together.
        """
        n_models = kwargs.pop('n_models', None)

        if not (n_models is None or
                (isinstance(n_models, (int, np.integer)) and n_models >= 1)):
            raise ValueError(
                "n_models must be either None (in which case it is "
                "determined from the model_set_axis of the parameter initial "
                "values) or it must be a positive integer "
                "(got {0!r})".format(n_models))

        # Normalize model_set_axis: None becomes 0 for multi-model sets and
        # False (no model set axis) otherwise.
        model_set_axis = kwargs.pop('model_set_axis', None)
        if model_set_axis is None:
            if n_models is not None and n_models > 1:
                # Default to zero
                model_set_axis = 0
            else:
                # Otherwise disable
                model_set_axis = False
        else:
            if not (model_set_axis is False or
                    np.issubdtype(type(model_set_axis), np.integer)):
                raise ValueError(
                    "model_set_axis must be either False or an integer "
                    "specifying the parameter array axis to map to each "
                    "model in a set of models (got {0!r}).".format(
                        model_set_axis))

        # Process positional arguments by matching them up with the
        # corresponding parameters in self.param_names--if any also appear as
        # keyword arguments this presents a conflict
        params = set()  # names of parameters explicitly given a non-None value
        if len(args) > len(self.param_names):
            raise TypeError(
                "{0}.__init__() takes at most {1} positional arguments ({2} "
                "given)".format(self.__class__.__name__, len(self.param_names),
                                len(args)))

        self._model_set_axis = model_set_axis
        self._param_metrics = defaultdict(dict)

        for idx, arg in enumerate(args):
            if arg is None:
                # A value of None implies using the default value, if exists
                continue
            # We use quantity_asanyarray here instead of np.asanyarray because
            # if any of the arguments are quantities, we need to return a
            # Quantity object not a plain Numpy array.
            param_name = self.param_names[idx]
            params.add(param_name)
            if not isinstance(arg, Parameter):
                value = quantity_asanyarray(arg, dtype=float)
            else:
                value = arg
            self._initialize_parameter_value(param_name, value)

        # At this point the only remaining keyword arguments should be
        # parameter names; any others are in error.
        for param_name in self.param_names:
            if param_name in kwargs:
                if param_name in params:
                    raise TypeError(
                        "{0}.__init__() got multiple values for parameter "
                        "{1!r}".format(self.__class__.__name__, param_name))
                value = kwargs.pop(param_name)
                if value is None:
                    continue
                # We use quantity_asanyarray here instead of np.asanyarray
                # because if any of the arguments are quantities, we need
                # to return a Quantity object not a plain Numpy array.
                # NOTE(review): unlike the positional path above, a Parameter
                # instance passed by keyword is converted here rather than
                # passed through unchanged -- confirm this asymmetry is intended.
                value = quantity_asanyarray(value, dtype=float)
                params.add(param_name)
                self._initialize_parameter_value(param_name, value)
        # Now deal with case where param_name is not supplied by args or kwargs
        for param_name in self.param_names:
            if param_name not in params:
                self._initialize_parameter_value(param_name, None)

        if kwargs:
            # If any keyword arguments were left over at this point they are
            # invalid: parameter values, n_models and model_set_axis have all
            # been popped above (constraints are presumably popped beforehand,
            # e.g. by _initialize_constraints -- confirm at the call site).
            for kwarg in kwargs:
                # Just raise an error on the first unrecognized argument
                raise TypeError(
                    '{0}.__init__() got an unrecognized parameter '
                    '{1!r}'.format(self.__class__.__name__, kwarg))

        # Determine the number of model sets.  When model_set_axis is False
        # (None was normalized to False or 0 above), n_models == 1, or no
        # parameter value was given explicitly, the per-axis checks are
        # skipped and n_models defaults to 1 if unspecified.  Otherwise
        # n_models is taken (if not given) from the size of model_set_axis on
        # the first parameter, and every parameter must have enough
        # dimensions and a matching size along that axis, else an error is
        # raised.
        if model_set_axis is not False and n_models != 1 and params:
            max_ndim = 0
            # A negative axis -k needs at least k dimensions; a non-negative
            # axis k needs at least k + 1.
            if model_set_axis < 0:
                min_ndim = abs(model_set_axis)
            else:
                min_ndim = model_set_axis + 1

            for name in self.param_names:
                value = getattr(self, name)
                param_ndim = np.ndim(value)
                if param_ndim < min_ndim:
                    raise InputParameterError(
                        "All parameter values must be arrays of dimension "
                        "at least {0} for model_set_axis={1} (the value "
                        "given for {2!r} is only {3}-dimensional)".format(
                            min_ndim, model_set_axis, name, param_ndim))

                max_ndim = max(max_ndim, param_ndim)

                if n_models is None:
                    # Use the dimensions of the first parameter to determine
                    # the number of model sets
                    n_models = value.shape[model_set_axis]
                elif value.shape[model_set_axis] != n_models:
                    raise InputParameterError(
                        "Inconsistent dimensions for parameter {0!r} for "
                        "{1} model sets.  The length of axis {2} must be the "
                        "same for all input parameter values".format(
                            name, n_models, model_set_axis))

            self._check_param_broadcast(max_ndim)
        else:
            if n_models is None:
                n_models = 1

            self._check_param_broadcast(None)

        self._n_models = n_models
        # now validate parameters (only those explicitly supplied; validators
        # run last so they can see every other parameter's value)
        for name in params:
            param = getattr(self, name)
            if param._validator is not None:
                param._validator(self, param.value)
2479
2480    def _initialize_parameter_value(self, param_name, value):
2481        """Mostly deals with consistency checks and determining unit issues."""
2482        if isinstance(value, Parameter):
2483            self.__dict__[param_name] = value
2484            return
2485        param = getattr(self, param_name)
2486        # Use default if value is not provided
2487        if value is None:
2488            default = param.default
2489            if default is None:
2490                # No value was supplied for the parameter and the
2491                # parameter does not have a default, therefore the model
2492                # is underspecified
2493                raise TypeError("{0}.__init__() requires a value for parameter "
2494                                "{1!r}".format(self.__class__.__name__, param_name))
2495            value = default
2496            unit = param.unit
2497        else:
2498            if isinstance(value, Quantity):
2499                unit = value.unit
2500                value = value.value
2501            else:
2502                unit = None
2503        if unit is None and param.unit is not None:
2504            raise InputParameterError(
2505                "{0}.__init__() requires a Quantity for parameter "
2506                "{1!r}".format(self.__class__.__name__, param_name))
2507        param._unit = unit
2508        param.internal_unit = None
2509        if param._setter is not None:
2510            if unit is not None:
2511                _val = param._setter(value * unit)
2512            else:
2513                _val = param._setter(value)
2514            if isinstance(_val, Quantity):
2515                param.internal_unit = _val.unit
2516                param._internal_value = np.array(_val.value)
2517            else:
2518                param.internal_unit = None
2519                param._internal_value = np.array(_val)
2520        else:
2521            param._value = np.array(value)
2522
2523    def _initialize_slices(self):
2524
2525        param_metrics = self._param_metrics
2526        total_size = 0
2527
2528        for name in self.param_names:
2529            param = getattr(self, name)
2530            value = param.value
2531            param_size = np.size(value)
2532            param_shape = np.shape(value)
2533            param_slice = slice(total_size, total_size + param_size)
2534            param_metrics[name]['slice'] = param_slice
2535            param_metrics[name]['shape'] = param_shape
2536            param_metrics[name]['size'] = param_size
2537            total_size += param_size
2538        self._parameters = np.empty(total_size, dtype=np.float64)
2539
2540    def _parameters_to_array(self):
2541        # Now set the parameter values (this will also fill
2542        # self._parameters)
2543        param_metrics = self._param_metrics
2544        for name in self.param_names:
2545            param = getattr(self, name)
2546            value = param.value
2547            if not isinstance(value, np.ndarray):
2548                value = np.array([value])
2549            self._parameters[param_metrics[name]['slice']] = value.ravel()
2550
2551        # Finally validate all the parameters; we do this last so that
2552        # validators that depend on one of the other parameters' values will
2553        # work
2554
2555    def _array_to_parameters(self):
2556        param_metrics = self._param_metrics
2557        for name in self.param_names:
2558            param = getattr(self, name)
2559            value = self._parameters[param_metrics[name]['slice']]
2560            value.shape = param_metrics[name]['shape']
2561            param.value = value
2562
2563    def _check_param_broadcast(self, max_ndim):
2564        """
2565        This subroutine checks that all parameter arrays can be broadcast
2566        against each other, and determines the shapes parameters must have in
2567        order to broadcast correctly.
2568
2569        If model_set_axis is None this merely checks that the parameters
2570        broadcast and returns an empty dict if so.  This mode is only used for
2571        single model sets.
2572        """
2573        all_shapes = []
2574        model_set_axis = self._model_set_axis
2575
2576        for name in self.param_names:
2577            param = getattr(self, name)
2578            value = param.value
2579            param_shape = np.shape(value)
2580            param_ndim = len(param_shape)
2581            if max_ndim is not None and param_ndim < max_ndim:
2582                # All arrays have the same number of dimensions up to the
2583                # model_set_axis dimension, but after that they may have a
2584                # different number of trailing axes.  The number of trailing
2585                # axes must be extended for mutual compatibility.  For example
2586                # if max_ndim = 3 and model_set_axis = 0, an array with the
2587                # shape (2, 2) must be extended to (2, 1, 2).  However, an
2588                # array with shape (2,) is extended to (2, 1).
2589                new_axes = (1,) * (max_ndim - param_ndim)
2590
2591                if model_set_axis < 0:
2592                    # Just need to prepend axes to make up the difference
2593                    broadcast_shape = new_axes + param_shape
2594                else:
2595                    broadcast_shape = (param_shape[:model_set_axis + 1] +
2596                                       new_axes +
2597                                       param_shape[model_set_axis + 1:])
2598                self._param_metrics[name]['broadcast_shape'] = broadcast_shape
2599                all_shapes.append(broadcast_shape)
2600            else:
2601                all_shapes.append(param_shape)
2602
2603        # Now check mutual broadcastability of all shapes
2604        try:
2605            check_broadcast(*all_shapes)
2606        except IncompatibleShapeError as exc:
2607            shape_a, shape_a_idx, shape_b, shape_b_idx = exc.args
2608            param_a = self.param_names[shape_a_idx]
2609            param_b = self.param_names[shape_b_idx]
2610
2611            raise InputParameterError(
2612                "Parameter {0!r} of shape {1!r} cannot be broadcast with "
2613                "parameter {2!r} of shape {3!r}.  All parameter arrays "
2614                "must have shapes that are mutually compatible according "
2615                "to the broadcasting rules.".format(param_a, shape_a,
2616                                                    param_b, shape_b))
2617
2618    def _param_sets(self, raw=False, units=False):
2619        """
2620        Implementation of the Model.param_sets property.
2621
2622        This internal implementation has a ``raw`` argument which controls
2623        whether or not to return the raw parameter values (i.e. the values that
2624        are actually stored in the ._parameters array, as opposed to the values
2625        displayed to users.  In most cases these are one in the same but there
2626        are currently a few exceptions.
2627
2628        Note: This is notably an overcomplicated device and may be removed
2629        entirely in the near future.
2630        """
2631
2632        values = []
2633        shapes = []
2634        for name in self.param_names:
2635            param = getattr(self, name)
2636
2637            if raw and param._setter:
2638                value = param._internal_value
2639            else:
2640                value = param.value
2641
2642            broadcast_shape = self._param_metrics[name].get('broadcast_shape')
2643            if broadcast_shape is not None:
2644                value = value.reshape(broadcast_shape)
2645
2646            shapes.append(np.shape(value))
2647
2648            if len(self) == 1:
2649                # Add a single param set axis to the parameter's value (thus
2650                # converting scalars to shape (1,) array values) for
2651                # consistency
2652                value = np.array([value])
2653
2654            if units:
2655                if raw and param.internal_unit is not None:
2656                    unit = param.internal_unit
2657                else:
2658                    unit = param.unit
2659                if unit is not None:
2660                    value = Quantity(value, unit)
2661
2662            values.append(value)
2663
2664        if len(set(shapes)) != 1 or units:
2665            # If the parameters are not all the same shape, converting to an
2666            # array is going to produce an object array
2667            # However the way Numpy creates object arrays is tricky in that it
2668            # will recurse into array objects in the list and break them up
2669            # into separate objects.  Doing things this way ensures a 1-D
2670            # object array the elements of which are the individual parameter
2671            # arrays.  There's not much reason to do this over returning a list
2672            # except for consistency
2673            psets = np.empty(len(values), dtype=object)
2674            psets[:] = values
2675            return psets
2676
2677        return np.array(values)
2678
2679    def _format_repr(self, args=[], kwargs={}, defaults={}):
2680        """
2681        Internal implementation of ``__repr__``.
2682
2683        This is separated out for ease of use by subclasses that wish to
2684        override the default ``__repr__`` while keeping the same basic
2685        formatting.
2686        """
2687
2688        parts = [repr(a) for a in args]
2689
2690        parts.extend(
2691            f"{name}={param_repr_oneline(getattr(self, name))}"
2692            for name in self.param_names)
2693
2694        if self.name is not None:
2695            parts.append(f'name={self.name!r}')
2696
2697        for kwarg, value in kwargs.items():
2698            if kwarg in defaults and defaults[kwarg] == value:
2699                continue
2700            parts.append(f'{kwarg}={value!r}')
2701
2702        if len(self) > 1:
2703            parts.append(f"n_models={len(self)}")
2704
2705        return f"<{self.__class__.__name__}({', '.join(parts)})>"
2706
2707    def _format_str(self, keywords=[], defaults={}):
2708        """
2709        Internal implementation of ``__str__``.
2710
2711        This is separated out for ease of use by subclasses that wish to
2712        override the default ``__str__`` while keeping the same basic
2713        formatting.
2714        """
2715
2716        default_keywords = [
2717            ('Model', self.__class__.__name__),
2718            ('Name', self.name),
2719            ('Inputs', self.inputs),
2720            ('Outputs', self.outputs),
2721            ('Model set size', len(self))
2722        ]
2723
2724        parts = [f'{keyword}: {value}'
2725                 for keyword, value in default_keywords
2726                 if value is not None]
2727
2728        for keyword, value in keywords:
2729            if keyword.lower() in defaults and defaults[keyword.lower()] == value:
2730                continue
2731            parts.append(f'{keyword}: {value}')
2732        parts.append('Parameters:')
2733
2734        if len(self) == 1:
2735            columns = [[getattr(self, name).value]
2736                       for name in self.param_names]
2737        else:
2738            columns = [getattr(self, name).value
2739                       for name in self.param_names]
2740
2741        if columns:
2742            param_table = Table(columns, names=self.param_names)
2743            # Set units on the columns
2744            for name in self.param_names:
2745                param_table[name].unit = getattr(self, name).unit
2746            parts.append(indent(str(param_table), width=4))
2747
2748        return '\n'.join(parts)
2749
2750
class FittableModel(Model):
    """
    Base class for models that can be fitted using the built-in fitting
    algorithms.

    Subclasses may set ``linear`` and supply an analytic ``fit_deriv``
    (with ``col_fit_deriv`` describing its layout) for use by the fitters.
    """

    # Whether the model is linear in its parameters (in the regression
    # sense); nonlinear unless a subclass says otherwise.
    linear = False
    # derivative with respect to parameters
    fit_deriv = None
    """
    Function (similar to the model's `~Model.evaluate`) to compute the
    derivatives of the model with respect to its parameters, for use by fitting
    algorithms.  In other words, this computes the Jacobian matrix with respect
    to the model's parameters.
    """
    # Flag that indicates if the model derivatives with respect to parameters
    # are given in columns or rows
    col_fit_deriv = True
    # Marks models of this class as fittable.
    fittable = True
2770
2771
class Fittable1DModel(FittableModel):
    """
    Base class for one-dimensional fittable models.

    This class provides an easier interface to defining new models.
    Examples can be found in `astropy.modeling.functional_models`.
    """
    # A single input mapped to a single output.
    n_inputs = 1
    n_outputs = 1
    # NOTE(review): presumably declares the model separable (its one output
    # depends only on its one input); confirm against Model._separable.
    _separable = True
2782
2783
class Fittable2DModel(FittableModel):
    """
    Base class for two-dimensional fittable models.

    This class provides an easier interface to defining new models.
    Examples can be found in `astropy.modeling.functional_models`.
    """

    # Two inputs mapped to a single output (e.g. a surface z = f(x, y)).
    n_inputs = 2
    n_outputs = 1
2794
2795
def _make_arithmetic_operator(oper):
    """
    Build a combiner for the arithmetic binary operator ``oper``.

    The returned function takes two ``(evaluate, n_inputs, n_outputs)``
    triples ``f`` and ``g`` and returns a new triple.  Its evaluate function
    comes from ``make_binary_operator_eval(oper, f_eval, g_eval)``; its
    input and output counts are copied from ``f``.
    """
    def op(f, g):
        # Index rather than unpack, for efficiency.
        left_eval, right_eval = f[0], g[0]
        combined = make_binary_operator_eval(oper, left_eval, right_eval)
        return combined, f[1], f[2]

    return op
2807
2808
2809def _composition_operator(f, g):
2810    # We don't bother with tuple unpacking here for efficiency's sake, but for
2811    # documentation purposes:
2812    #
2813    #     f_eval, f_n_inputs, f_n_outputs = f
2814    #
2815    # and similarly for g
2816    return (lambda inputs, params: g[0](f[0](inputs, params), params),
2817            f[1], g[2])
2818
2819
2820def _join_operator(f, g):
2821    # We don't bother with tuple unpacking here for efficiency's sake, but for
2822    # documentation purposes:
2823    #
2824    #     f_eval, f_n_inputs, f_n_outputs = f
2825    #
2826    # and similarly for g
2827    return (lambda inputs, params: (f[0](inputs[:f[1]], params) +
2828                                    g[0](inputs[f[1]:], params)),
2829            f[1] + g[1], f[2] + g[2])
2830
2831
# Map each compound-model operator symbol to a function combining two
# ``(evaluate, n_inputs, n_outputs)`` triples into the triple for the combined
# expression.  Arithmetic operators keep the left operand's input/output
# counts, '|' feeds the left operand's outputs into the right operand, and
# '&' evaluates the operands side by side on split inputs.
BINARY_OPERATORS = {
    '+': _make_arithmetic_operator(operator.add),
    '-': _make_arithmetic_operator(operator.sub),
    '*': _make_arithmetic_operator(operator.mul),
    '/': _make_arithmetic_operator(operator.truediv),
    '**': _make_arithmetic_operator(operator.pow),
    '|': _composition_operator,
    '&': _join_operator
}

# Registry of special operators; entries are added through
# _add_special_operator.
SPECIAL_OPERATORS = _SpecialOperatorsDict()
2843
2844
def _add_special_operator(sop_name, sop):
    """
    Register ``sop`` in ``SPECIAL_OPERATORS`` under ``sop_name``.

    Returns whatever ``SPECIAL_OPERATORS.add`` returns for the registration.
    """
    registry = SPECIAL_OPERATORS
    result = registry.add(sop_name, sop)
    return result
2847
2848
2849class CompoundModel(Model):
2850    '''
2851    Base class for compound models.
2852
2853    While it can be used directly, the recommended way
2854    to combine models is through the model operators.
2855    '''
2856
2857    def __init__(self, op, left, right, name=None):
2858        self.__dict__['_param_names'] = None
2859        self._n_submodels = None
2860        self.op = op
2861        self.left = left
2862        self.right = right
2863        self._bounding_box = None
2864        self._user_bounding_box = None
2865        self._leaflist = None
2866        self._tdict = None
2867        self._parameters = None
2868        self._parameters_ = None
2869        self._param_metrics = None
2870
2871        if op != 'fix_inputs' and len(left) != len(right):
2872            raise ValueError(
2873                'Both operands must have equal values for n_models')
2874        self._n_models = len(left)
2875
2876        if op != 'fix_inputs' and ((left.model_set_axis != right.model_set_axis)
2877                                   or left.model_set_axis):  # not False and not 0
2878            raise ValueError("model_set_axis must be False or 0 and consistent for operands")
2879        self._model_set_axis = left.model_set_axis
2880
2881        if op in ['+', '-', '*', '/', '**'] or op in SPECIAL_OPERATORS:
2882            if (left.n_inputs != right.n_inputs) or \
2883               (left.n_outputs != right.n_outputs):
2884                raise ModelDefinitionError(
2885                    'Both operands must match numbers of inputs and outputs')
2886            self.n_inputs = left.n_inputs
2887            self.n_outputs = left.n_outputs
2888            self.inputs = left.inputs
2889            self.outputs = left.outputs
2890        elif op == '&':
2891            self.n_inputs = left.n_inputs + right.n_inputs
2892            self.n_outputs = left.n_outputs + right.n_outputs
2893            self.inputs = combine_labels(left.inputs, right.inputs)
2894            self.outputs = combine_labels(left.outputs, right.outputs)
2895        elif op == '|':
2896            if left.n_outputs != right.n_inputs:
2897                raise ModelDefinitionError(
2898                    "Unsupported operands for |: {0} (n_inputs={1}, "
2899                    "n_outputs={2}) and {3} (n_inputs={4}, n_outputs={5}); "
2900                    "n_outputs for the left-hand model must match n_inputs "
2901                    "for the right-hand model.".format(
2902                        left.name, left.n_inputs, left.n_outputs, right.name,
2903                        right.n_inputs, right.n_outputs))
2904
2905            self.n_inputs = left.n_inputs
2906            self.n_outputs = right.n_outputs
2907            self.inputs = left.inputs
2908            self.outputs = right.outputs
2909        elif op == 'fix_inputs':
2910            if not isinstance(left, Model):
2911                raise ValueError('First argument to "fix_inputs" must be an instance of an astropy Model.')
2912            if not isinstance(right, dict):
2913                raise ValueError('Expected a dictionary for second argument of "fix_inputs".')
2914
2915            # Dict keys must match either possible indices
2916            # for model on left side, or names for inputs.
2917            self.n_inputs = left.n_inputs - len(right)
2918            # Assign directly to the private attribute (instead of using the setter)
2919            # to avoid asserting the new number of outputs matches the old one.
2920            self._outputs = left.outputs
2921            self.n_outputs = left.n_outputs
2922            newinputs = list(left.inputs)
2923            keys = right.keys()
2924            input_ind = []
2925            for key in keys:
2926                if np.issubdtype(type(key), np.integer):
2927                    if key >= left.n_inputs or key < 0:
2928                        raise ValueError(
2929                            'Substitution key integer value '
2930                            'not among possible input choices.')
2931                    if key in input_ind:
2932                        raise ValueError("Duplicate specification of "
2933                                         "same input (index/name).")
2934                    input_ind.append(key)
2935                elif isinstance(key, str):
2936                    if key not in left.inputs:
2937                        raise ValueError(
2938                            'Substitution key string not among possible '
2939                            'input choices.')
2940                    # Check to see it doesn't match positional
2941                    # specification.
2942                    ind = left.inputs.index(key)
2943                    if ind in input_ind:
2944                        raise ValueError("Duplicate specification of "
2945                                         "same input (index/name).")
2946                    input_ind.append(ind)
2947            # Remove substituted inputs
2948            input_ind.sort()
2949            input_ind.reverse()
2950            for ind in input_ind:
2951                del newinputs[ind]
2952            self.inputs = tuple(newinputs)
2953            # Now check to see if the input model has bounding_box defined.
2954            # If so, remove the appropriate dimensions and set it for this
2955            # instance.
2956            try:
2957                self.bounding_box = \
2958                    self.left.bounding_box.fix_inputs(self, right)
2959            except NotImplementedError:
2960                pass
2961
2962        else:
2963            raise ModelDefinitionError('Illegal operator: ', self.op)
2964        self.name = name
2965        self._fittable = None
2966        self.fit_deriv = None
2967        self.col_fit_deriv = None
2968        if op in ('|', '+', '-'):
2969            self.linear = left.linear and right.linear
2970        else:
2971            self.linear = False
2972        self.eqcons = []
2973        self.ineqcons = []
2974        self.n_left_params = len(self.left.parameters)
2975        self._map_parameters()
2976
2977    def _get_left_inputs_from_args(self, args):
2978        return args[:self.left.n_inputs]
2979
2980    def _get_right_inputs_from_args(self, args):
2981        op = self.op
2982        if op == '&':
2983            # Args expected to look like (*left inputs, *right inputs, *left params, *right params)
2984            return args[self.left.n_inputs: self.left.n_inputs + self.right.n_inputs]
2985        elif op == '|' or  op == 'fix_inputs':
2986            return None
2987        else:
2988            return args[:self.left.n_inputs]
2989
2990    def _get_left_params_from_args(self, args):
2991        op = self.op
2992        if op == '&':
2993            # Args expected to look like (*left inputs, *right inputs, *left params, *right params)
2994            n_inputs = self.left.n_inputs + self.right.n_inputs
2995            return args[n_inputs: n_inputs + self.n_left_params]
2996        else:
2997            return args[self.left.n_inputs: self.left.n_inputs + self.n_left_params]
2998
2999    def _get_right_params_from_args(self, args):
3000        op = self.op
3001        if op == 'fix_inputs':
3002            return None
3003        if op == '&':
3004            # Args expected to look like (*left inputs, *right inputs, *left params, *right params)
3005            return args[self.left.n_inputs + self.right.n_inputs + self.n_left_params:]
3006        else:
3007            return args[self.left.n_inputs + self.n_left_params:]
3008
3009    def _get_kwarg_model_parameters_as_positional(self, args, kwargs):
3010        # could do it with inserts but rebuilding seems like simpilist way
3011
3012        #TODO: Check if any param names are in kwargs maybe as an intersection of sets?
3013        if self.op == "&":
3014            new_args = list(args[:self.left.n_inputs + self.right.n_inputs])
3015            args_pos = self.left.n_inputs + self.right.n_inputs
3016        else:
3017            new_args = list(args[:self.left.n_inputs])
3018            args_pos = self.left.n_inputs
3019
3020        for param_name in self.param_names:
3021            kw_value = kwargs.pop(param_name, None)
3022            if kw_value is not None:
3023                value = kw_value
3024            else:
3025                try:
3026                    value = args[args_pos]
3027                except IndexError:
3028                    raise IndexError("Missing parameter or input")
3029
3030                args_pos += 1
3031            new_args.append(value)
3032
3033        return new_args, kwargs
3034
3035    def _apply_operators_to_value_lists(self, leftval, rightval, **kw):
3036        op = self.op
3037        if op == '+':
3038            return binary_operation(operator.add, leftval, rightval)
3039        elif op == '-':
3040            return binary_operation(operator.sub, leftval, rightval)
3041        elif op == '*':
3042            return binary_operation(operator.mul, leftval, rightval)
3043        elif op == '/':
3044            return binary_operation(operator.truediv, leftval, rightval)
3045        elif op == '**':
3046            return binary_operation(operator.pow, leftval, rightval)
3047        elif op == '&':
3048            if not isinstance(leftval, tuple):
3049                leftval = (leftval,)
3050            if not isinstance(rightval, tuple):
3051                rightval = (rightval,)
3052            return leftval + rightval
3053        elif op in SPECIAL_OPERATORS:
3054            return binary_operation(SPECIAL_OPERATORS[op], leftval, rightval)
3055        else:
3056            raise ModelDefinitionError('Unrecognized operator {op}')
3057
    def evaluate(self, *args, **kw):
        """
        Evaluate the expression tree with explicitly supplied parameters.

        ``args`` holds the inputs followed by parameter values. Any parameter
        given by name in ``kw`` is first moved into positional order (see
        ``_get_kwarg_model_parameters_as_positional``). Each operand's
        ``evaluate`` is then called with its share of the inputs and
        parameters, and the two results are combined according to
        ``self.op``.
        """
        op = self.op
        # Make every model parameter positional so the slicing helpers below
        # can split ``args`` purely by position.
        args, kw = self._get_kwarg_model_parameters_as_positional(args, kw)
        left_inputs = self._get_left_inputs_from_args(args)
        left_params = self._get_left_params_from_args(args)

        if op == 'fix_inputs':
            # Translate the substitution dict's keys (input names or integer
            # indices) into positional indices of the left model's inputs...
            pos_index = dict(zip(self.left.inputs, range(self.left.n_inputs)))
            fixed_inputs = {
                key if np.issubdtype(type(key), np.integer) else pos_index[key]: value
                for key, value in self.right.items()
            }
            # ...and overwrite those positions with the fixed values.
            # NOTE(review): this replaces (rather than inserts) entries of
            # args[:self.left.n_inputs], so it assumes ``args`` already has a
            # slot for every left-model input, fixed ones included — confirm
            # against callers.
            left_inputs = [
                fixed_inputs[ind] if ind in fixed_inputs.keys() else inp
                for ind, inp in enumerate(left_inputs)
            ]

        leftval = self.left.evaluate(*itertools.chain(left_inputs, left_params))

        if op == 'fix_inputs':
            # The right operand is a substitution dict, not a model; there is
            # nothing more to evaluate.
            return leftval

        right_inputs = self._get_right_inputs_from_args(args)
        right_params = self._get_right_params_from_args(args)

        if op == "|":
            # The left result is fed to the right model as its inputs
            # (``right_inputs`` is None for this operator).
            if isinstance(leftval, tuple):
                return self.right.evaluate(*itertools.chain(leftval, right_params))
            else:
                return self.right.evaluate(leftval, *right_params)
        else:
            rightval = self.right.evaluate(*itertools.chain(right_inputs, right_params))

        return self._apply_operators_to_value_lists(leftval, rightval, **kw)
3092
3093    @property
3094    def n_submodels(self):
3095        if self._leaflist is None:
3096            self._make_leaflist()
3097        return len(self._leaflist)
3098
3099    @property
3100    def submodel_names(self):
3101        """ Return the names of submodels in a ``CompoundModel``."""
3102        if self._leaflist is None:
3103            self._make_leaflist()
3104        names = [item.name for item in self._leaflist]
3105        nonecount = 0
3106        newnames = []
3107        for item in names:
3108            if item is None:
3109                newnames.append(f'None_{nonecount}')
3110                nonecount += 1
3111            else:
3112                newnames.append(item)
3113        return tuple(newnames)
3114
3115    def both_inverses_exist(self):
3116        '''
3117        if both members of this compound model have inverses return True
3118        '''
3119        warnings.warn(
3120            "CompoundModel.both_inverses_exist is deprecated. "
3121            "Use has_inverse instead.",
3122            AstropyDeprecationWarning
3123        )
3124
3125        try:
3126            linv = self.left.inverse
3127            rinv = self.right.inverse
3128        except NotImplementedError:
3129            return False
3130
3131        return True
3132
3133    def _pre_evaluate(self, *args, **kwargs):
3134        """
3135        CompoundModel specific input setup that needs to occur prior to
3136            model evaluation.
3137
3138        Note
3139        ----
3140            All of the _pre_evaluate for each component model will be
3141            performed at the time that the individual model is evaluated.
3142        """
3143
3144        # If equivalencies are provided, necessary to map parameters and pass
3145        # the leaflist as a keyword input for use by model evaluation so that
3146        # the compound model input names can be matched to the model input
3147        # names.
3148        if 'equivalencies' in kwargs:
3149            # Restructure to be useful for the individual model lookup
3150            kwargs['inputs_map'] = [(value[0], (value[1], key)) for
3151                                    key, value in self.inputs_map().items()]
3152
3153        # Setup actual model evaluation method
3154        def evaluate(_inputs):
3155            return self._evaluate(*_inputs, **kwargs)
3156
3157        return evaluate, args, None, kwargs
3158
3159    @property
3160    def _argnames(self):
3161        """No inputs should be used to determine input_shape when handling compound models"""
3162        return ()
3163
3164    def _post_evaluate(self, inputs, outputs, broadcasted_shapes, with_bbox, **kwargs):
3165        """
3166        CompoundModel specific post evaluation processing of outputs
3167
3168        Note
3169        ----
3170            All of the _post_evaluate for each component model will be
3171            performed at the time that the individual model is evaluated.
3172        """
3173        if self.get_bounding_box(with_bbox) is not None and self.n_outputs == 1:
3174            return outputs[0]
3175        return outputs
3176
    def _evaluate(self, *args, **kw):
        """
        Evaluate the compound model by calling its operand models on ``args``.

        Unlike ``evaluate``, this calls the operands themselves
        (``self.left(...)`` / ``self.right(...)``) with inputs only, so each
        operand uses its own current parameter values. ``kw`` is passed
        through to those calls.
        """
        op = self.op
        if op != 'fix_inputs':
            if op != '&':
                # Arithmetic operators: both operands receive the same inputs.
                # For '|' only the left side is evaluated here; the right side
                # is applied to the left result further below.
                leftval = self.left(*args, **kw)
                if op != '|':
                    rightval = self.right(*args, **kw)
                else:
                    rightval = None

            else:
                # '&': the first left.n_inputs inputs go to the left model,
                # the remainder to the right model.
                leftval = self.left(*(args[:self.left.n_inputs]), **kw)
                rightval = self.right(*(args[self.left.n_inputs:]), **kw)

            if op != "|":
                return self._apply_operators_to_value_lists(leftval, rightval, **kw)

            elif op == '|':
                # Composition: left output(s) become the right model's inputs.
                if isinstance(leftval, tuple):
                    return self.right(*leftval, **kw)
                else:
                    return self.right(leftval, **kw)

        else:
            # fix_inputs: rebuild the full argument list for the left model
            # from the caller's positional args, inputs passed by keyword, and
            # the fixed values in the substitution dict.
            subs = self.right
            newargs = list(args)
            subinds = []
            subvals = []
            # Resolve each substitution key (index or input name) to a
            # positional index of the left model's inputs.
            for key in subs.keys():
                if np.issubdtype(type(key), np.integer):
                    subinds.append(key)
                elif isinstance(key, str):
                    ind = self.left.inputs.index(key)
                    subinds.append(ind)
                subvals.append(subs[key])
            # Turn inputs specified in kw into positional indices.
            # Names for compound inputs do not propagate to sub models.
            kwind = []
            kwval = []
            for kwkey in list(kw.keys()):
                if kwkey in self.inputs:
                    ind = self.inputs.index(kwkey)
                    if ind < len(args):
                        raise ValueError("Keyword argument duplicates "
                                         "positional value supplied.")
                    kwind.append(ind)
                    kwval.append(kw[kwkey])
                    del kw[kwkey]
            # Build new argument list
            # Append keyword specified args first
            # NOTE(review): keyword inputs are appended in index order after
            # the positional ones, which assumes they fill the trailing input
            # positions with no gaps — confirm.
            if kwind:
                kwargs = list(zip(kwind, kwval))
                kwargs.sort()
                kwindsorted, kwvalsorted = list(zip(*kwargs))
                newargs = newargs + list(kwvalsorted)
            if subinds:
                subargs = list(zip(subinds, subvals))
                subargs.sort()
                # subindsorted, subvalsorted = list(zip(*subargs))
                # The substitutions must be inserted in order
                # (ascending, so earlier inserts shift later ones into place).
                for ind, val in subargs:
                    newargs.insert(ind, val)
            return self.left(*newargs, **kw)
3240
3241    @property
3242    def param_names(self):
3243        """ An ordered list of parameter names."""
3244        return self._param_names
3245
3246    def _make_leaflist(self):
3247        tdict = {}
3248        leaflist = []
3249        make_subtree_dict(self, '', tdict, leaflist)
3250        self._leaflist = leaflist
3251        self._tdict = tdict
3252
3253    def __getattr__(self, name):
3254        """
3255        If someone accesses an attribute not already defined, map the
3256        parameters, and then see if the requested attribute is one of
3257        the parameters
3258        """
3259        # The following test is needed to avoid infinite recursion
3260        # caused by deepcopy. There may be other such cases discovered.
3261        if name == '__setstate__':
3262            raise AttributeError
3263        if name in self._param_names:
3264            return self.__dict__[name]
3265        else:
3266            raise AttributeError(f'Attribute "{name}" not found')
3267
    def __getitem__(self, index):
        """
        Return a submodel selected by leaf index, leaf name, or leaf slice.

        An ``int`` or ``str`` index returns the corresponding entry of the
        leaf list. A slice returns the subtree node whose recorded leaf range
        ``(leftind, rightind)`` in ``self._tdict`` equals the slice's
        ``(start, stop)``. An integer ``stop`` is exclusive, while a string
        ``stop`` (a component name) is inclusive. A slice that resolves to a
        single leaf returns that leaf.

        Raises
        ------
        ValueError
            For a slice with a step, or with ``stop == 0``.
        IndexError
            If no subtree matches the slice, or a name is missing or ambiguous.
        TypeError
            For any other index type.
        """
        if self._leaflist is None:
            self._make_leaflist()
        leaflist = self._leaflist
        tdict = self._tdict
        if isinstance(index, slice):
            if index.step:
                raise ValueError('Steps in slices not supported '
                                 'for compound models')
            if index.start is not None:
                if isinstance(index.start, str):
                    start = self._str_index_to_int(index.start)
                else:
                    start = index.start
            else:
                start = 0
            # Convert stop to an inclusive leaf index (names are already inclusive).
            if index.stop is not None:
                if isinstance(index.stop, str):
                    stop = self._str_index_to_int(index.stop)
                else:
                    stop = index.stop - 1
            else:
                stop = len(leaflist) - 1
            if index.stop == 0:
                raise ValueError("Slice endpoint cannot be 0")
            if start < 0:
                start = len(leaflist) + start
            if stop < 0:
                stop = len(leaflist) + stop
            # now search for matching node:
            if stop == start:  # only single value, get leaf instead in code below
                index = start
            else:
                for key in tdict:
                    node, leftind, rightind = tdict[key]
                    if leftind == start and rightind == stop:
                        return node
                raise IndexError("No appropriate subtree matches slice")
        if isinstance(index, type(0)):
            return leaflist[index]
        elif isinstance(index, type('')):
            return leaflist[self._str_index_to_int(index)]
        else:
            raise TypeError('index must be integer, slice, or model name string')
3312
3313    def _str_index_to_int(self, str_index):
3314        # Search through leaflist for item with that name
3315        found = []
3316        for nleaf, leaf in enumerate(self._leaflist):
3317            if getattr(leaf, 'name', None) == str_index:
3318                found.append(nleaf)
3319        if len(found) == 0:
3320            raise IndexError(f"No component with name '{str_index}' found")
3321        if len(found) > 1:
3322            raise IndexError("Multiple components found using '{}' as name\n"
3323                             "at indices {}".format(str_index, found))
3324        return found[0]
3325
3326    @property
3327    def n_inputs(self):
3328        """ The number of inputs of a model."""
3329        return self._n_inputs
3330
3331    @n_inputs.setter
3332    def n_inputs(self, value):
3333        self._n_inputs = value
3334
3335    @property
3336    def n_outputs(self):
3337        """ The number of outputs of a model."""
3338        return self._n_outputs
3339
3340    @n_outputs.setter
3341    def n_outputs(self, value):
3342        self._n_outputs = value
3343
3344    @property
3345    def eqcons(self):
3346        return self._eqcons
3347
3348    @eqcons.setter
3349    def eqcons(self, value):
3350        self._eqcons = value
3351
3352    @property
3353    def ineqcons(self):
3354        return self._eqcons
3355
3356    @ineqcons.setter
3357    def ineqcons(self, value):
3358        self._eqcons = value
3359
3360    def traverse_postorder(self, include_operator=False):
3361        """ Postorder traversal of the CompoundModel tree."""
3362        res = []
3363        if isinstance(self.left, CompoundModel):
3364            res = res + self.left.traverse_postorder(include_operator)
3365        else:
3366            res = res + [self.left]
3367        if isinstance(self.right, CompoundModel):
3368            res = res + self.right.traverse_postorder(include_operator)
3369        else:
3370            res = res + [self.right]
3371        if include_operator:
3372            res.append(self.op)
3373        else:
3374            res.append(self)
3375        return res
3376
3377    def _format_expression(self, format_leaf=None):
3378        leaf_idx = 0
3379        operands = deque()
3380
3381        if format_leaf is None:
3382            format_leaf = lambda i, l: f'[{i}]'
3383
3384        for node in self.traverse_postorder():
3385            if not isinstance(node, CompoundModel):
3386                operands.append(format_leaf(leaf_idx, node))
3387                leaf_idx += 1
3388                continue
3389
3390            right = operands.pop()
3391            left = operands.pop()
3392            if node.op in OPERATOR_PRECEDENCE:
3393                oper_order = OPERATOR_PRECEDENCE[node.op]
3394
3395                if isinstance(node, CompoundModel):
3396                    if (isinstance(node.left, CompoundModel) and
3397                            OPERATOR_PRECEDENCE[node.left.op] < oper_order):
3398                        left = f'({left})'
3399                    if (isinstance(node.right, CompoundModel) and
3400                            OPERATOR_PRECEDENCE[node.right.op] < oper_order):
3401                        right = f'({right})'
3402
3403                operands.append(' '.join((left, node.op, right)))
3404            else:
3405                left = f'(({left}),'
3406                right = f'({right}))'
3407                operands.append(' '.join((node.op[0], left, right)))
3408
3409        return ''.join(operands)
3410
3411    def _format_components(self):
3412        if self._parameters_ is None:
3413            self._map_parameters()
3414        return '\n\n'.join('[{0}]: {1!r}'.format(idx, m)
3415                           for idx, m in enumerate(self._leaflist))
3416
3417    def __str__(self):
3418        expression = self._format_expression()
3419        components = self._format_components()
3420        keywords = [
3421            ('Expression', expression),
3422            ('Components', '\n' + indent(components))
3423        ]
3424        return super()._format_str(keywords=keywords)
3425
3426    def rename(self, name):
3427        self.name = name
3428        return self
3429
3430    @property
3431    def isleaf(self):
3432        return False
3433
3434    @property
3435    def inverse(self):
3436        if self.op == '|':
3437            return self.right.inverse | self.left.inverse
3438        elif self.op == '&':
3439            return self.left.inverse & self.right.inverse
3440        else:
3441            return NotImplemented
3442
3443    @property
3444    def fittable(self):
3445        """ Set the fittable attribute on a compound model."""
3446        if self._fittable is None:
3447            if self._leaflist is None:
3448                self._map_parameters()
3449            self._fittable = all(m.fittable for m in self._leaflist)
3450        return self._fittable
3451
    # Operator overloads for combining a compound model with other models.
    # Each method is produced by ``_model_oper`` (defined elsewhere in this
    # module) for the given operator symbol.
    __add__ = _model_oper('+')
    __sub__ = _model_oper('-')
    __mul__ = _model_oper('*')
    __truediv__ = _model_oper('/')
    __pow__ = _model_oper('**')
    __or__ = _model_oper('|')
    __and__ = _model_oper('&')
3459
3460    def _map_parameters(self):
3461        """
3462        Map all the constituent model parameters to the compound object,
3463        renaming as necessary by appending a suffix number.
3464
3465        This can be an expensive operation, particularly for a complex
3466        expression tree.
3467
3468        All the corresponding parameter attributes are created that one
3469        expects for the Model class.
3470
3471        The parameter objects that the attributes point to are the same
3472        objects as in the constiutent models. Changes made to parameter
3473        values to either are seen by both.
3474
3475        Prior to calling this, none of the associated attributes will
3476        exist. This method must be called to make the model usable by
3477        fitting engines.
3478
3479        If oldnames=True, then parameters are named as in the original
3480        implementation of compound models.
3481        """
3482        if self._parameters is not None:
3483            # do nothing
3484            return
3485        if self._leaflist is None:
3486            self._make_leaflist()
3487        self._parameters_ = {}
3488        param_map = {}
3489        self._param_names = []
3490        for lindex, leaf in enumerate(self._leaflist):
3491            if not isinstance(leaf, dict):
3492                for param_name in leaf.param_names:
3493                    param = getattr(leaf, param_name)
3494                    new_param_name = f"{param_name}_{lindex}"
3495                    self.__dict__[new_param_name] = param
3496                    self._parameters_[new_param_name] = param
3497                    self._param_names.append(new_param_name)
3498                    param_map[new_param_name] = (lindex, param_name)
3499        self._param_metrics = {}
3500        self._param_map = param_map
3501        self._param_map_inverse = dict((v, k) for k, v in param_map.items())
3502        self._initialize_slices()
3503        self._param_names = tuple(self._param_names)
3504
3505    def _initialize_slices(self):
3506        param_metrics = self._param_metrics
3507        total_size = 0
3508
3509        for name in self.param_names:
3510            param = getattr(self, name)
3511            value = param.value
3512            param_size = np.size(value)
3513            param_shape = np.shape(value)
3514            param_slice = slice(total_size, total_size + param_size)
3515            param_metrics[name] = {}
3516            param_metrics[name]['slice'] = param_slice
3517            param_metrics[name]['shape'] = param_shape
3518            param_metrics[name]['size'] = param_size
3519            total_size += param_size
3520        self._parameters = np.empty(total_size, dtype=np.float64)
3521
3522    @staticmethod
3523    def _recursive_lookup(branch, adict, key):
3524        if isinstance(branch, CompoundModel):
3525            return adict[key]
3526        return branch, key
3527
    def inputs_map(self):
        """
        Map the names of the inputs to this ExpressionTree to the inputs to the leaf models.

        Returns
        -------
        inputs_map : dict
            Maps each input name to a ``(model, input_name)`` pair. Where an
            operand is itself a `CompoundModel`, that operand's own
            ``inputs_map`` is consulted so the pair refers further down the
            tree. For ``fix_inputs`` the pairs point at ``self.left``
            directly and only the non-fixed inputs are included.
        """
        inputs_map = {}
        if not isinstance(self.op, str):  # If we don't have an operator the mapping is trivial
            return {inp: (self, inp) for inp in self.inputs}

        elif self.op == '|':
            # Only the left operand sees this model's inputs.
            if isinstance(self.left, CompoundModel):
                l_inputs_map = self.left.inputs_map()
            for inp in self.inputs:
                if isinstance(self.left, CompoundModel):
                    inputs_map[inp] = l_inputs_map[inp]
                else:
                    inputs_map[inp] = self.left, inp
        elif self.op == '&':
            # The first len(left.inputs) inputs belong to the left operand;
            # the rest belong to the right operand, offset by that length.
            if isinstance(self.left, CompoundModel):
                l_inputs_map = self.left.inputs_map()
            if isinstance(self.right, CompoundModel):
                r_inputs_map = self.right.inputs_map()
            for i, inp in enumerate(self.inputs):
                if i < len(self.left.inputs):  # Get from left
                    if isinstance(self.left, CompoundModel):
                        inputs_map[inp] = l_inputs_map[self.left.inputs[i]]
                    else:
                        inputs_map[inp] = self.left, self.left.inputs[i]
                else:  # Get from right
                    if isinstance(self.right, CompoundModel):
                        inputs_map[inp] = r_inputs_map[self.right.inputs[i - len(self.left.inputs)]]
                    else:
                        inputs_map[inp] = self.right, self.right.inputs[i - len(self.left.inputs)]
        elif self.op == 'fix_inputs':
            # Drop the fixed inputs (keys may be names or indices) and map the
            # remaining ones straight to the left operand by its own names.
            fixed_ind = list(self.right.keys())
            ind = [list(self.left.inputs).index(i) if isinstance(i, str) else i for i in fixed_ind]
            inp_ind = list(range(self.left.n_inputs))
            for i in ind:
                inp_ind.remove(i)
            for i in inp_ind:
                inputs_map[self.left.inputs[i]] = self.left, self.left.inputs[i]
        else:
            # Arithmetic operators: both operands share the inputs, so the
            # map is taken from the left operand.
            if isinstance(self.left, CompoundModel):
                l_inputs_map = self.left.inputs_map()
            for inp in self.left.inputs:
                if isinstance(self.left, CompoundModel):
                    inputs_map[inp] = l_inputs_map[inp]
                else:
                    inputs_map[inp] = self.left, inp
        return inputs_map
3577
3578    def _parameter_units_for_data_units(self, input_units, output_units):
3579        if self._leaflist is None:
3580            self._map_parameters()
3581        units_for_data = {}
3582        for imodel, model in enumerate(self._leaflist):
3583            units_for_data_leaf = model._parameter_units_for_data_units(input_units, output_units)
3584            for param_leaf in units_for_data_leaf:
3585                param = self._param_map_inverse[(imodel, param_leaf)]
3586                units_for_data[param] = units_for_data_leaf[param_leaf]
3587        return units_for_data
3588
3589    @property
3590    def input_units(self):
3591        inputs_map = self.inputs_map()
3592        input_units_dict = {key: inputs_map[key][0].input_units[orig_key]
3593                            for key, (mod, orig_key) in inputs_map.items()
3594                            if inputs_map[key][0].input_units is not None}
3595        if input_units_dict:
3596            return input_units_dict
3597        return None
3598
3599    @property
3600    def input_units_equivalencies(self):
3601        inputs_map = self.inputs_map()
3602        return {key: inputs_map[key][0].input_units_equivalencies[orig_key]
3603                for key, (mod, orig_key) in inputs_map.items()
3604                if inputs_map[key][0].input_units_equivalencies is not None}
3605
3606    @property
3607    def input_units_allow_dimensionless(self):
3608        inputs_map = self.inputs_map()
3609        return {key: inputs_map[key][0].input_units_allow_dimensionless[orig_key]
3610                for key, (mod, orig_key) in inputs_map.items()}
3611
3612    @property
3613    def input_units_strict(self):
3614        inputs_map = self.inputs_map()
3615        return {key: inputs_map[key][0].input_units_strict[orig_key]
3616                for key, (mod, orig_key) in inputs_map.items()}
3617
3618    @property
3619    def return_units(self):
3620        outputs_map = self.outputs_map()
3621        return {key: outputs_map[key][0].return_units[orig_key]
3622                for key, (mod, orig_key) in outputs_map.items()
3623                if outputs_map[key][0].return_units is not None}
3624
3625    def outputs_map(self):
3626        """
3627        Map the names of the outputs to this ExpressionTree to the outputs to the leaf models.
3628        """
3629        outputs_map = {}
3630        if not isinstance(self.op, str):  # If we don't have an operator the mapping is trivial
3631            return {out: (self, out) for out in self.outputs}
3632
3633        elif self.op == '|':
3634            if isinstance(self.right, CompoundModel):
3635                r_outputs_map = self.right.outputs_map()
3636            for out in self.outputs:
3637                if isinstance(self.right, CompoundModel):
3638                    outputs_map[out] = r_outputs_map[out]
3639                else:
3640                    outputs_map[out] = self.right, out
3641
3642        elif self.op == '&':
3643            if isinstance(self.left, CompoundModel):
3644                l_outputs_map = self.left.outputs_map()
3645            if isinstance(self.right, CompoundModel):
3646                r_outputs_map = self.right.outputs_map()
3647            for i, out in enumerate(self.outputs):
3648                if i < len(self.left.outputs):  # Get from left
3649                    if isinstance(self.left, CompoundModel):
3650                        outputs_map[out] = l_outputs_map[self.left.outputs[i]]
3651                    else:
3652                        outputs_map[out] = self.left, self.left.outputs[i]
3653                else:  # Get from right
3654                    if isinstance(self.right, CompoundModel):
3655                        outputs_map[out] = r_outputs_map[self.right.outputs[i - len(self.left.outputs)]]
3656                    else:
3657                        outputs_map[out] = self.right, self.right.outputs[i - len(self.left.outputs)]
3658        elif self.op == 'fix_inputs':
3659            return self.left.outputs_map()
3660        else:
3661            if isinstance(self.left, CompoundModel):
3662                l_outputs_map = self.left.outputs_map()
3663            for out in self.left.outputs:
3664                if isinstance(self.left, CompoundModel):
3665                    outputs_map[out] = l_outputs_map()[out]
3666                else:
3667                    outputs_map[out] = self.left, out
3668        return outputs_map
3669
3670    @property
3671    def has_user_bounding_box(self):
3672        """
3673        A flag indicating whether or not a custom bounding_box has been
3674        assigned to this model by a user, via assignment to
3675        ``model.bounding_box``.
3676        """
3677
3678        return self._user_bounding_box is not None
3679
3680    def render(self, out=None, coords=None):
3681        """
3682        Evaluate a model at fixed positions, respecting the ``bounding_box``.
3683
3684        The key difference relative to evaluating the model directly is that
3685        this method is limited to a bounding box if the `Model.bounding_box`
3686        attribute is set.
3687
3688        Parameters
3689        ----------
3690        out : `numpy.ndarray`, optional
3691            An array that the evaluated model will be added to.  If this is not
3692            given (or given as ``None``), a new array will be created.
3693        coords : array-like, optional
3694            An array to be used to translate from the model's input coordinates
3695            to the ``out`` array. It should have the property that
3696            ``self(coords)`` yields the same shape as ``out``.  If ``out`` is
3697            not specified, ``coords`` will be used to determine the shape of
3698            the returned array. If this is not provided (or None), the model
3699            will be evaluated on a grid determined by `Model.bounding_box`.
3700
3701        Returns
3702        -------
3703        out : `numpy.ndarray`
3704            The model added to ``out`` if  ``out`` is not ``None``, or else a
3705            new array from evaluating the model over ``coords``.
3706            If ``out`` and ``coords`` are both `None`, the returned array is
3707            limited to the `Model.bounding_box` limits. If
3708            `Model.bounding_box` is `None`, ``arr`` or ``coords`` must be
3709            passed.
3710
3711        Raises
3712        ------
3713        ValueError
3714            If ``coords`` are not given and the the `Model.bounding_box` of
3715            this model is not set.
3716
3717        Examples
3718        --------
3719        :ref:`astropy:bounding-boxes`
3720        """
3721
3722        bbox = self.get_bounding_box()
3723
3724        ndim = self.n_inputs
3725
3726        if (coords is None) and (out is None) and (bbox is None):
3727            raise ValueError('If no bounding_box is set, '
3728                             'coords or out must be input.')
3729
3730        # for consistent indexing
3731        if ndim == 1:
3732            if coords is not None:
3733                coords = [coords]
3734            if bbox is not None:
3735                bbox = [bbox]
3736
3737        if coords is not None:
3738            coords = np.asanyarray(coords, dtype=float)
3739            # Check dimensions match out and model
3740            assert len(coords) == ndim
3741            if out is not None:
3742                if coords[0].shape != out.shape:
3743                    raise ValueError('inconsistent shape of the output.')
3744            else:
3745                out = np.zeros(coords[0].shape)
3746
3747        if out is not None:
3748            out = np.asanyarray(out)
3749            if out.ndim != ndim:
3750                raise ValueError('the array and model must have the same '
3751                                 'number of dimensions.')
3752
3753        if bbox is not None:
3754            # Assures position is at center pixel, important when using
3755            # add_array.
3756            pd = np.array([(np.mean(bb), np.ceil((bb[1] - bb[0]) / 2))
3757                           for bb in bbox]).astype(int).T
3758            pos, delta = pd
3759
3760            if coords is not None:
3761                sub_shape = tuple(delta * 2 + 1)
3762                sub_coords = np.array([extract_array(c, sub_shape, pos)
3763                                       for c in coords])
3764            else:
3765                limits = [slice(p - d, p + d + 1, 1) for p, d in pd.T]
3766                sub_coords = np.mgrid[limits]
3767
3768            sub_coords = sub_coords[::-1]
3769
3770            if out is None:
3771                out = self(*sub_coords)
3772            else:
3773                try:
3774                    out = add_array(out, self(*sub_coords), pos)
3775                except ValueError:
3776                    raise ValueError(
3777                        'The `bounding_box` is larger than the input out in '
3778                        'one or more dimensions. Set '
3779                        '`model.bounding_box = None`.')
3780        else:
3781            if coords is None:
3782                im_shape = out.shape
3783                limits = [slice(i) for i in im_shape]
3784                coords = np.mgrid[limits]
3785
3786            coords = coords[::-1]
3787
3788            out += self(*coords)
3789
3790        return out
3791
3792    def replace_submodel(self, name, model):
3793        """
3794        Construct a new `~astropy.modeling.CompoundModel` instance from an
3795        existing CompoundModel, replacing the named submodel with a new model.
3796
3797        In order to ensure that inverses and names are kept/reconstructed, it's
3798        necessary to rebuild the CompoundModel from the replaced node all the
3799        way back to the base. The original CompoundModel is left untouched.
3800
3801        Parameters
3802        ----------
3803        name : str
3804            name of submodel to be replaced
3805        model : `~astropy.modeling.Model`
3806            replacement model
3807        """
3808        submodels = [m for m in self.traverse_postorder()
3809                     if getattr(m, 'name', None) == name]
3810        if submodels:
3811            if len(submodels) > 1:
3812                raise ValueError(f"More than one submodel named {name}")
3813
3814            old_model = submodels.pop()
3815            if len(old_model) != len(model):
3816                raise ValueError("New and old models must have equal values "
3817                                 "for n_models")
3818
3819            # Do this check first in order to raise a more helpful Exception,
3820            # although it would fail trying to construct the new CompoundModel
3821            if (old_model.n_inputs != model.n_inputs or
3822                        old_model.n_outputs != model.n_outputs):
3823                raise ValueError("New model must match numbers of inputs and "
3824                                 "outputs of existing model")
3825
3826            tree = _get_submodel_path(self, name)
3827            while tree:
3828                branch = self.copy()
3829                for node in tree[:-1]:
3830                    branch = getattr(branch, node)
3831                setattr(branch, tree[-1], model)
3832                model = CompoundModel(branch.op, branch.left, branch.right,
3833                                      name=branch.name)
3834                tree = tree[:-1]
3835            return model
3836
3837        else:
3838            raise ValueError(f"No submodels found named {name}")
3839
3840
3841def _get_submodel_path(model, name):
3842    """Find the route down a CompoundModel's tree to the model with the
3843    specified name (whether it's a leaf or not)"""
3844    if getattr(model, 'name', None) == name:
3845        return []
3846    try:
3847        return ['left'] + _get_submodel_path(model.left, name)
3848    except (AttributeError, TypeError):
3849        pass
3850    try:
3851        return ['right'] + _get_submodel_path(model.right, name)
3852    except (AttributeError, TypeError):
3853        pass
3854
3855
def binary_operation(binoperator, left, right):
    '''
    Apply ``binoperator`` to ``left`` and ``right``.

    When both operands are tuples the operator is applied pairwise to
    their elements (pairing via ``zip``, so the shorter tuple bounds the
    result) and a tuple of the results is returned; otherwise the operator
    is applied to the operands directly.
    '''
    both_tuples = isinstance(left, tuple) and isinstance(right, tuple)
    if not both_tuples:
        return binoperator(left, right)
    return tuple(binoperator(lhs, rhs) for lhs, rhs in zip(left, right))
3864
3865
def get_ops(tree, opset):
    """
    Recursively add the ``op`` of every `CompoundModel` node in ``tree`` to
    the set ``opset`` (modified in place).

    Non-compound nodes contribute nothing.  Always returns `None`.
    """
    if not isinstance(tree, CompoundModel):
        return
    opset.add(tree.op)
    for side in ('left', 'right'):
        get_ops(getattr(tree, side), opset)
3876
3877
def make_subtree_dict(tree, nodepath, tdict, leaflist):
    '''
    Walk ``tree`` depth-first (left before right), recording every internal
    node in ``tdict`` and every leaf in ``leaflist`` (both modified in place).

    A node counts as internal when it has an ``isleaf`` attribute; anything
    else is appended to ``leaflist`` as a leaf.  Each internal node is keyed
    by ``nodepath`` extended with one ``'l'``/``'r'`` character per branch
    taken to reach it, and maps to a tuple of:

    - reference to the compound model for that node.
    - left most index contained within that subtree
       (relative to all indices for the whole tree)
    - right most index contained within that subtree
    '''
    if not hasattr(tree, 'isleaf'):
        leaflist.append(tree)
        return

    first_leaf = len(leaflist)
    for suffix, side in (('l', 'left'), ('r', 'right')):
        make_subtree_dict(getattr(tree, side), nodepath + suffix, tdict,
                          leaflist)
    tdict[nodepath] = (tree, first_leaf, len(leaflist) - 1)
3898
3899
_ORDER_OF_OPERATORS = [('fix_inputs',), ('|',), ('&',), ('+', '-'), ('*', '/'), ('**',)]
# Map every operator to the index of its group in _ORDER_OF_OPERATORS;
# operators in later groups get larger values, and operators sharing a
# group share a value.
OPERATOR_PRECEDENCE = {
    operator_symbol: level
    for level, group in enumerate(_ORDER_OF_OPERATORS)
    for operator_symbol in group
}
3906
3907
def fix_inputs(modelinstance, values, bounding_boxes=None, selector_args=None):
    """
    This function creates a compound model with one or more of the input
    values of the input model assigned fixed values (scalar or array).

    Parameters
    ----------
    modelinstance : `~astropy.modeling.Model` instance
        This is the model that one or more of the
        model input values will be fixed to some constant value.
    values : dict
        A dictionary where the key identifies which input to fix
        and its value is the value to fix it at. The key may either be the
        name of the input or a number reflecting its order in the inputs.
    bounding_boxes : dict, optional
        Bounding boxes used to build a compound bounding box for
        ``modelinstance`` (see :ref:`astropy:bounding-boxes`).  The entry
        selected by the fixed ``values`` is assigned as the bounding box of
        the returned model.  If `None` (the default), no bounding box is set.
    selector_args : tuple, optional
        Selector arguments for the compound bounding box; only used when
        ``bounding_boxes`` is given.  If `None`, a tuple with one
        ``(key, True)`` entry per key of ``values`` is used.

    Returns
    -------
    model : `~astropy.modeling.CompoundModel`
        A ``'fix_inputs'`` compound model wrapping ``modelinstance``.

    Examples
    --------

    >>> from astropy.modeling.models import Gaussian2D
    >>> g = Gaussian2D(1, 2, 3, 4, 5)
    >>> gv = fix_inputs(g, {0: 2.5})

    Results in a 1D function equivalent to Gaussian2D(1, 2, 3, 4, 5)(x=2.5, y)
    """
    model = CompoundModel('fix_inputs', modelinstance, values)
    if bounding_boxes is not None:
        if selector_args is None:
            # Default: select on every fixed input.
            selector_args = tuple((key, True) for key in values)
        bbox = CompoundBoundingBox.validate(modelinstance, bounding_boxes, selector_args)
        # Look up which bounding box applies for the fixed input values.
        _selector = bbox.selector_args.get_fixed_values(modelinstance, values)

        model.bounding_box = bbox[_selector]
    return model
3941
3942
def bind_bounding_box(modelinstance, bounding_box, order='C'):
    """
    Validate ``bounding_box`` against ``modelinstance`` and attach it.

    Parameters
    ----------
    modelinstance : `~astropy.modeling.Model` instance
        The model whose ``bounding_box`` attribute is assigned.
    bounding_box : tuple
        A bounding box tuple, see :ref:`astropy:bounding-boxes` for details
    order : str, optional
        The ordering of the bounding box tuple, can be either ``'C'`` or
        ``'F'``.
    """
    validated = ModelBoundingBox.validate(modelinstance, bounding_box,
                                          order=order)
    modelinstance.bounding_box = validated
3960
3961
def bind_compound_bounding_box(modelinstance, bounding_boxes, selector_args,
                               create_selector=None, order='C'):
    """
    Validate a compound bounding box against ``modelinstance`` and attach it.

    Parameters
    ----------
    modelinstance : `~astropy.modeling.Model` instance
        The model whose ``bounding_box`` attribute is assigned.
    bounding_boxes : dict
        A dictionary of bounding box tuples, see :ref:`astropy:bounding-boxes`
        for details.
    selector_args : list
        List of selector argument tuples to define selection for compound
        bounding box, see :ref:`astropy:bounding-boxes` for details.
    create_selector : callable, optional
        An optional callable with interface (selector_value, model) which
        can generate a bounding box based on a selector value and model if
        there is no bounding box in the compound bounding box listed under
        that selector value. Default is ``None``, meaning new bounding
        box entries will not be automatically generated.
    order : str, optional
        The ordering of the bounding box tuple, can be either ``'C'`` or
        ``'F'``.
    """
    validated = CompoundBoundingBox.validate(
        modelinstance, bounding_boxes, selector_args, create_selector,
        order=order)
    modelinstance.bounding_box = validated
3992
3993
def custom_model(*args, fit_deriv=None):
    """
    Create a model from a user defined function. The inputs and parameters of
    the model will be inferred from the arguments of the function.

    This can be used either as a function or as a decorator.  See below for
    examples of both usages.

    The model is separable only if there is a single input.

    .. note::

        All model parameters have to be defined as keyword arguments with
        default values in the model function.  Use `None` as a default argument
        value if you do not want to have a default value for that parameter.

        The standard settable model properties can be configured by default
        using keyword arguments matching the name of the property; however,
        these values are not set as model "parameters". Moreover, users
        cannot use keyword arguments matching non-settable model properties,
        with the exception of ``n_outputs`` which should be set to the number of
        outputs of your function.

    Parameters
    ----------
    func : function
        Function which defines the model.  It should take N positional
        arguments where ``N`` is dimensions of the model (the number of
        independent variables in the model), and any number of keyword
        arguments (the parameters).  It must return the value of the model
        (typically as an array, but can also be a scalar for scalar inputs).
        This corresponds to the `~astropy.modeling.Model.evaluate` method.
    fit_deriv : function, optional
        Function which defines the Jacobian derivative of the model. I.e., the
        derivative with respect to the *parameters* of the model.  It should
        have the same argument signature as ``func``, but should return a
        sequence where each element of the sequence is the derivative
        with respect to the corresponding argument. This corresponds to the
        :meth:`~astropy.modeling.FittableModel.fit_deriv` method.

    Raises
    ------
    TypeError
        If more than one positional argument, or a single non-callable
        positional argument, is given.

    Examples
    --------
    Define a sinusoidal model function as a custom 1D model::

        >>> from astropy.modeling.models import custom_model
        >>> import numpy as np
        >>> def sine_model(x, amplitude=1., frequency=1.):
        ...     return amplitude * np.sin(2 * np.pi * frequency * x)
        >>> def sine_deriv(x, amplitude=1., frequency=1.):
        ...     d_amplitude = np.sin(2 * np.pi * frequency * x)
        ...     d_frequency = (2 * np.pi * x * amplitude *
        ...                    np.cos(2 * np.pi * frequency * x))
        ...     return [d_amplitude, d_frequency]
        >>> SineModel = custom_model(sine_model, fit_deriv=sine_deriv)

    Create an instance of the custom model and evaluate it::

        >>> model = SineModel()
        >>> model(0.25)
        1.0

    This model instance can now be used like a usual astropy model.

    The next example demonstrates a 2D Moffat function model, and also
    demonstrates the support for docstrings (this example could also include
    a derivative, but it has been omitted for simplicity)::

        >>> @custom_model
        ... def Moffat2D(x, y, amplitude=1.0, x_0=0.0, y_0=0.0, gamma=1.0,
        ...            alpha=1.0):
        ...     \"\"\"Two dimensional Moffat function.\"\"\"
        ...     rr_gg = ((x - x_0) ** 2 + (y - y_0) ** 2) / gamma ** 2
        ...     return amplitude * (1 + rr_gg) ** (-alpha)
        ...
        >>> print(Moffat2D.__doc__)
        Two dimensional Moffat function.
        >>> model = Moffat2D()
        >>> model(1, 1)  # doctest: +FLOAT_CMP
        0.3333333333333333
    """

    if len(args) == 1 and callable(args[0]):
        # Called as ``custom_model(func, ...)`` or as a bare ``@custom_model``.
        return _custom_model_wrapper(args[0], fit_deriv=fit_deriv)
    elif not args:
        # Called as ``@custom_model(fit_deriv=...)``: return the decorator.
        return functools.partial(_custom_model_wrapper, fit_deriv=fit_deriv)
    else:
        # The message names this function explicitly; the module-level
        # ``__name__`` would name the module, not ``custom_model``.
        raise TypeError(
            "custom_model takes at most one positional argument (the "
            "callable/function to be turned into a model).  When used as a "
            "decorator it should be passed keyword arguments only (if any).")
4082
4083
def _custom_model_inputs(func):
    """
    Sort the arguments of a `custom_model` function into categories.

    Parameters
    ----------
    func : callable
        The user function handed to `custom_model`.

    Returns
    -------
    inputs : list
        list of evaluation inputs
    special_params : dict
        defaults of arguments naming model properties that need special
        treatment (currently only ``n_outputs``)
    settable_params : dict
        defaults of arguments naming settable (has a setter) `Model`
        properties
    params : dict
        defaults of the remaining keyword arguments, i.e. the model
        parameters

    Raises
    ------
    ValueError
        If an argument shares its name with a read-only `Model` property
        other than ``n_outputs``.
    """
    inputs, parameters = get_inputs_and_params(func)

    special = ['n_outputs']
    model_props = [(attr, value) for attr, value in vars(Model).items()
                   if isinstance(value, property)]
    settable = [attr for attr, prop in model_props if prop.fset is not None]
    properties = [attr for attr, prop in model_props
                  if prop.fset is None and attr not in special]

    special_params, settable_params, params = {}, {}, {}
    for param in parameters:
        pname = param.name
        if pname in special:
            bucket = special_params
        elif pname in settable:
            bucket = settable_params
        elif pname in properties:
            raise ValueError(f"Parameter '{pname}' cannot be a model property: {properties}.")
        else:
            bucket = params
        bucket[pname] = param.default

    return inputs, special_params, settable_params, params
4126
4127
def _custom_model_wrapper(func, fit_deriv=None):
    """
    Internal implementation of `custom_model`.

    When `custom_model` is called as a function its arguments are passed to
    this function, and the result of this function is returned.

    When `custom_model` is used as a decorator a partial evaluation of this
    function is returned by `custom_model`.

    Parameters
    ----------
    func : callable
        Function defining the model; its arguments are split into inputs,
        parameters and property defaults by `_custom_model_inputs`.
    fit_deriv : callable, optional
        Derivative function; it must have as many default arguments as
        ``func`` has model parameters.

    Returns
    -------
    cls : type
        A new `FittableModel` subclass named after ``func``, with
        ``evaluate`` (and ``fit_deriv`` when given) bound as static methods
        and one `Parameter` per model parameter of ``func``.

    Raises
    ------
    ModelDefinitionError
        If ``func`` or ``fit_deriv`` is not callable, or if ``fit_deriv``
        does not have one default argument per model parameter.
    """

    if not callable(func):
        raise ModelDefinitionError(
            "func is not callable; it must be a function or other callable "
            "object")

    if fit_deriv is not None and not callable(fit_deriv):
        raise ModelDefinitionError(
            "fit_deriv not callable; it must be a function or other "
            "callable object")

    model_name = func.__name__

    inputs, special_params, settable_params, params = _custom_model_inputs(func)

    if fit_deriv is not None:
        # A function without default arguments has ``__defaults__ is None``
        # and other callables may lack the attribute entirely; count both as
        # "no defaults" so a mismatch is reported as ModelDefinitionError
        # rather than an unrelated TypeError/AttributeError.
        deriv_defaults = getattr(fit_deriv, '__defaults__', None) or ()
        if len(deriv_defaults) != len(params):
            raise ModelDefinitionError("derivative function should accept "
                                       "same number of parameters as func.")

    params = {param: Parameter(param, default=default)
              for param, default in params.items()}

    # Attribute the generated class to the caller's module when it can be
    # determined, falling back to '__main__'.
    mod = find_current_module(2)
    if mod:
        modname = mod.__name__
    else:
        modname = '__main__'

    members = {
        '__module__': str(modname),
        '__doc__': func.__doc__,
        'n_inputs': len(inputs),
        'n_outputs': special_params.pop('n_outputs', 1),
        'evaluate': staticmethod(func),
        '_settable_properties': settable_params
    }

    if fit_deriv is not None:
        members['fit_deriv'] = staticmethod(fit_deriv)

    members.update(params)

    cls = type(model_name, (FittableModel,), members)
    cls._separable = len(inputs) == 1
    return cls
4184
4185
def render_model(model, arr=None, coords=None):
    """
    Evaluates a model on an input array. Evaluation is limited to
    a bounding box if the `Model.bounding_box` attribute is set.

    Parameters
    ----------
    model : `Model`
        Model to be evaluated.
    arr : `numpy.ndarray`, optional
        Array on which the model is evaluated.  It is copied, never modified;
        its number of dimensions must equal ``model.n_inputs``.
    coords : array-like, optional
        Coordinate arrays, one per model input, at which the model is
        evaluated.  When ``arr`` is given, each coordinate array must have
        the same shape as ``arr``.

    Returns
    -------
    array : `numpy.ndarray`
        A copy of ``arr`` with the evaluated model added, or a new array from
        ``coords``.
        If ``arr`` and ``coords`` are both `None`, the returned array is
        limited to the `Model.bounding_box` limits. If
        `Model.bounding_box` is `None`, ``arr`` or ``coords`` must be passed.

    Raises
    ------
    ValueError
        If neither ``arr``, ``coords`` nor a bounding box is available; if
        the dimensions/shapes of ``arr`` and ``coords`` are inconsistent with
        each other or with the number of model inputs; or if the bounding
        box is larger than ``arr``.

    Examples
    --------
    :ref:`astropy:bounding-boxes`
    """

    # ``Model.bounding_box`` may raise NotImplementedError instead of
    # returning None when no bounding box is defined; treat that the same as
    # "no bounding box" so ``arr``/``coords`` can still be used.
    try:
        bbox = model.bounding_box
    except NotImplementedError:
        bbox = None

    if coords is None and arr is None and bbox is None:
        raise ValueError('If no bounding_box is set, '
                         'coords or arr must be input.')

    # for consistent indexing: wrap single-input values so that
    # coords[i] / bbox[i] always refer to the i-th model input
    if model.n_inputs == 1:
        if coords is not None:
            coords = [coords]
        if bbox is not None:
            bbox = [bbox]

    if arr is not None:
        arr = arr.copy()
        # Check dimensions match model
        if arr.ndim != model.n_inputs:
            raise ValueError('number of array dimensions inconsistent with '
                             'number of model inputs.')
    if coords is not None:
        # Check dimensions match arr and model
        coords = np.array(coords)
        if len(coords) != model.n_inputs:
            raise ValueError('coordinate length inconsistent with the number '
                             'of model inputs.')
        if arr is not None:
            if coords[0].shape != arr.shape:
                raise ValueError('coordinate shape inconsistent with the '
                                 'array shape.')
        else:
            arr = np.zeros(coords[0].shape)

    if bbox is not None:
        # assures position is at center pixel, important when using add_array
        pd = np.array([(np.mean(bb), np.ceil((bb[1] - bb[0]) / 2))
                       for bb in bbox]).astype(int).T
        pos, delta = pd

        if coords is not None:
            sub_shape = tuple(delta * 2 + 1)
            sub_coords = np.array([extract_array(c, sub_shape, pos)
                                   for c in coords])
        else:
            limits = [slice(p - d, p + d + 1, 1) for p, d in pd.T]
            sub_coords = np.mgrid[limits]

        sub_coords = sub_coords[::-1]

        if arr is None:
            arr = model(*sub_coords)
        else:
            # Evaluate outside the ``try`` so a ValueError raised by the
            # model itself is not mislabeled as a bounding-box problem.
            evaluated = model(*sub_coords)
            try:
                arr = add_array(arr, evaluated, pos)
            except ValueError as err:
                raise ValueError('The `bounding_box` is larger than the input'
                                 ' arr in one or more dimensions. Set '
                                 '`model.bounding_box = None`.') from err
    else:

        if coords is None:
            im_shape = arr.shape
            limits = [slice(i) for i in im_shape]
            coords = np.mgrid[limits]

        arr += model(*coords[::-1])

    return arr
4281
4282
def hide_inverse(model):
    """
    Delete the inverse of ``model`` and return the same model object.

    This is a convenience for disabling automatic inverse generation in
    compound models by removing one constituent model's inverse.  It is
    meant for cases where user-provided inverse functions are not
    compatible within an expression, for example::

        compound_model.inverse = hide_inverse(m1) + m2 + m3

    This ensures that the inverse defined above does not itself try to
    build its own inverse, which would otherwise fail in this example
    (e.g., ``m = m1 + m2 + m3`` happens to raise an exception for this
    reason).

    Note that the inverse is removed permanently (``del model.inverse``).
    To avoid that, copy the model first or restore the inverse afterwards.
    """
    delattr(model, 'inverse')
    return model
4303