1"""
2Copyright 2013 Steven Diamond
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15"""
16
17import abc
18
19import numpy as np
20
21import cvxpy.lin_ops.lin_utils as lu
22import cvxpy.utilities as u
23from cvxpy.expressions import cvxtypes
24
25
class Constraint(u.Canonical):
    """The base class for constraints.

    A constraint is an equality, inequality, or more generally a generalized
    inequality that is imposed upon a mathematical expression or a list
    thereof.

    Parameters
    ----------
    args : list
        A list of expression trees.
    constr_id : int
        A unique id for the constraint. When omitted, a fresh id is
        requested from ``lu.get_id()``.
    """

    # NOTE(review): this is the Python 2 spelling. Python 3 treats a
    # ``__metaclass__`` class attribute as an ordinary attribute, so unless a
    # base class already uses ``abc.ABCMeta`` the abstract markers below are
    # not enforced at instantiation time. Left untouched because switching to
    # ``metaclass=abc.ABCMeta`` could break subclasses that omit an abstract
    # member -- confirm before changing.
    __metaclass__ = abc.ABCMeta

    def __init__(self, args, constr_id=None) -> None:
        """Stores the arguments, assigns an id and creates dual variables."""
        # TODO cast constants.
        # self.args = [cvxtypes.expression().cast_to_const(arg) for arg in args]
        self.args = args
        if constr_id is None:
            self.constr_id = lu.get_id()
        else:
            self.constr_id = constr_id
        # Dual variables are built before the base-class initializer runs;
        # this ordering is preserved from the original implementation.
        self._construct_dual_variables(args)
        super(Constraint, self).__init__()

    def __str__(self):
        """Returns a string showing the mathematical constraint.

        Delegates to ``self.name()``, which is not defined in this class and
        is presumably supplied by subclasses or a base class.
        """
        return self.name()

    def __repr__(self) -> str:
        """Returns a string with information about the constraint.

        Only the first argument is included in the representation.
        """
        return "%s(%s)" % (self.__class__.__name__,
                           repr(self.args[0]))

    def _construct_dual_variables(self, args) -> None:
        """Creates one dual variable per argument, matching its shape."""
        self.dual_variables = [cvxtypes.variable()(arg.shape) for arg in args]

    @property
    def shape(self):
        """The shape of the first constrained expression (``args[0]``)."""
        return self.args[0].shape

    @property
    def size(self):
        """The size of the first constrained expression (``args[0]``)."""
        return self.args[0].size

    def is_real(self) -> bool:
        """Is the constraint real valued?

        True exactly when ``is_complex()`` is False.
        """
        return not self.is_complex()

    def is_imag(self) -> bool:
        """Is the constraint imaginary?

        True when every argument reports ``is_imag()``.
        """
        return all(arg.is_imag() for arg in self.args)

    def is_complex(self) -> bool:
        """Is the constraint complex valued?

        True when at least one argument reports ``is_complex()``.
        """
        return any(arg.is_complex() for arg in self.args)

    @abc.abstractmethod
    def is_dcp(self, dpp: bool = False) -> bool:
        """Checks whether the constraint is DCP.

        Parameters
        ----------
        dpp : bool
            Flag forwarded by ``is_dpp``; its handling is left to the
            implementing subclass.

        Returns
        -------
        bool
            True if the constraint is DCP, False otherwise.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def is_dgp(self, dpp: bool = False) -> bool:
        """Checks whether the constraint is DGP.

        Parameters
        ----------
        dpp : bool
            Flag forwarded by ``is_dpp``; its handling is left to the
            implementing subclass.

        Returns
        -------
        bool
            True if the constraint is DGP, False otherwise.
        """
        raise NotImplementedError()

    def is_dpp(self, context='dcp') -> bool:
        """Checks whether the constraint is DPP in the given context.

        Parameters
        ----------
        context : str
            Either ``'dcp'`` or ``'dgp'`` (compared case-insensitively).

        Returns
        -------
        bool
            The result of ``is_dcp(dpp=True)`` or ``is_dgp(dpp=True)``,
            depending on ``context``.

        Raises
        ------
        ValueError
            If ``context`` is neither ``'dcp'`` nor ``'dgp'``.
        """
        normalized = context.lower()
        if normalized == 'dcp':
            return self.is_dcp(dpp=True)
        if normalized == 'dgp':
            return self.is_dgp(dpp=True)
        # Build one message string: passing the context as a second
        # positional argument makes the error render as a tuple.
        raise ValueError("Unsupported context %s" % context)

    @property
    @abc.abstractmethod
    def residual(self):
        """The residual of the constraint.

        Returns
        -------
        NumPy.ndarray
            The residual, or None if the constrained expression does not have
            a value.
        """
        raise NotImplementedError()

    def violation(self):
        """The numeric residual of the constraint.

        The violation is defined as the distance between the constrained
        expression's value and its projection onto the domain of the
        constraint:

        .. math::

            ||\\Pi(v) - v||_2^2

        where :math:`v` is the value of the constrained expression and
        :math:`\\Pi` is the projection operator onto the constraint's domain.

        This base implementation returns ``self.residual`` unchanged.

        Returns
        -------
        NumPy.ndarray
            The residual value.

        Raises
        ------
        ValueError
            If the constrained expression does not have a value associated
            with it.
        """
        residual = self.residual
        if residual is None:
            raise ValueError("Cannot compute the violation of an constraint "
                             "whose expression is None-valued.")
        return residual

    def value(self, tolerance: float = 1e-8):
        """Checks whether the constraint violation is less than a tolerance.

        Parameters
        ----------
        tolerance : float
            The absolute tolerance to impose on the violation.

        Returns
        -------
        bool
            True if every entry of the residual is at most ``tolerance``,
            False otherwise.

        Raises
        ------
        ValueError
            If the constrained expression does not have a value associated
            with it.
        """
        residual = self.residual
        if residual is None:
            raise ValueError("Cannot compute the value of an constraint "
                             "whose expression is None-valued.")
        return np.all(residual <= tolerance)

    @property
    def id(self):
        """Wrapper for compatibility with variables.

        Returns ``self.constr_id``.
        """
        return self.constr_id

    def get_data(self):
        """Data needed to copy.

        Returns a one-element list holding the constraint id.
        """
        return [self.id]

    def __nonzero__(self):
        """Raises an exception when called.

        Python 2 version.

        Called when evaluating the truth value of the constraint.
        Raising an error here prevents writing chained constraints.
        """
        return self._chain_constraints()

    def _chain_constraints(self):
        """Raises an error due to chained constraints.
        """
        raise Exception(
            ("Cannot evaluate the truth value of a constraint or "
             "chain constraints, e.g., 1 >= x >= 0.")
        )

    def __bool__(self):
        """Raises an exception when called.

        Python 3 version.

        Called when evaluating the truth value of the constraint.
        Raising an error here prevents writing chained constraints.
        """
        return self._chain_constraints()

    # TODO(rileyjmurray): add a function to compute dual-variable violation.

    @property
    def dual_value(self):
        """NumPy.ndarray : The value of the dual variable.

        When the constraint has exactly one dual variable its value is
        returned directly; otherwise a list with one value per dual variable
        is returned.
        """
        dual_vals = [dv.value for dv in self.dual_variables]
        if len(dual_vals) == 1:
            return dual_vals[0]
        else:
            return dual_vals

    def save_dual_value(self, value) -> None:
        """Save the value of the dual variable for the constraint's parent.

        Only the first dual variable is updated.

        Parameters
        ----------
        value
            The value of the dual variable.
        """
        self.dual_variables[0].save_value(value)