1# This file is part of Patsy
2# Copyright (C) 2011-2012 Nathaniel Smith <njs@pobox.com>
3# See file LICENSE.txt for license information.
4
5# Interpreting linear constraints like "2*x1 + x2 = 0"
6
7from __future__ import print_function
8
9# These are made available in the patsy.* namespace
10__all__ = ["LinearConstraint"]
11
12import re
13try:
14    from collections.abc import Mapping
15except ImportError:
16    from collections import Mapping
17import six
18import numpy as np
19from patsy import PatsyError
20from patsy.origin import Origin
21from patsy.util import (atleast_2d_column_default,
22                        repr_pretty_delegate, repr_pretty_impl,
23                        no_pickling, assert_no_pickling)
24from patsy.infix_parser import Token, Operator, infix_parse
25from patsy.parse_formula import _parsing_error_test
26
27
class LinearConstraint(object):
    """A linear constraint in matrix form.

    This object represents a linear constraint of the form `Ax = b`.

    Usually you won't be constructing these by hand, but instead get them as
    the return value from :meth:`DesignInfo.linear_constraint`.

    :arg variable_names: An iterable of names, one per column of `coefs`.
      It is copied into a list, so one-shot iterators are accepted.
    :arg coefs: Array-like, converted to a float ndarray. A 1-dimensional
      input is treated as a single row.
    :arg constants: Array-like, converted to a float column. Defaults to
      all zeros, one per row of `coefs`.
    :raises ValueError: If `coefs` is not 2-dimensional with one column per
      variable name, if it has no rows, if `constants` is not (convertible
      to) a column matrix, or if the number of rows in `coefs` and
      `constants` differ.

    .. attribute:: coefs

       A 2-dimensional ndarray with float dtype, representing `A`.

    .. attribute:: constants

       A 2-dimensional single-column ndarray with float dtype, representing
       `b`.

    .. attribute:: variable_names

       A list of strings giving the names of the variables being
       constrained. (Used only for consistency checking.)
    """
    def __init__(self, variable_names, coefs, constants=None):
        # Materialize first: all later checks use the stored list, so a
        # generator or other one-shot iterable works too.
        self.variable_names = list(variable_names)
        self.coefs = np.atleast_2d(np.asarray(coefs, dtype=float))
        if constants is None:
            constants = np.zeros(self.coefs.shape[0], dtype=float)
        constants = np.asarray(constants, dtype=float)
        # NOTE(review): atleast_2d_column_default lives in patsy.util;
        # presumably it turns a 1-d input into a column -- the shape check
        # below does not rely on that assumption.
        self.constants = atleast_2d_column_default(constants)
        if self.constants.ndim != 2 or self.constants.shape[1] != 1:
            raise ValueError("constants is not (convertible to) a column matrix")
        # np.atleast_2d leaves inputs with more than 2 dimensions untouched,
        # so ndim still has to be checked here.
        if (self.coefs.ndim != 2
                or self.coefs.shape[1] != len(self.variable_names)):
            raise ValueError("wrong shape for coefs")
        if self.coefs.shape[0] == 0:
            raise ValueError("must have at least one row in constraint matrix")
        if self.coefs.shape[0] != self.constants.shape[0]:
            raise ValueError("shape mismatch between coefs and constants")

    __repr__ = repr_pretty_delegate
    def _repr_pretty_(self, p, cycle):
        # Pretty-printer hook; delegates to the shared patsy helper with the
        # three constructor arguments.
        assert not cycle
        return repr_pretty_impl(p, self,
                                [self.variable_names, self.coefs, self.constants])

    __getstate__ = no_pickling

    @classmethod
    def combine(cls, constraints):
        """Create a new LinearConstraint by ANDing together several existing
        LinearConstraints.

        :arg constraints: An iterable of LinearConstraint objects. Their
          :attr:`variable_names` attributes must all match.
        :returns: A new LinearConstraint object whose rows are the rows of
          the inputs, stacked in iteration order.
        :raises ValueError: If `constraints` is empty, or if the variable
          names of the constraints do not all match.
        """
        # Accept any iterable (as documented), not only sequences: we need
        # to test for emptiness, peek at the first item and iterate twice.
        constraints = list(constraints)
        if not constraints:
            raise ValueError("no constraints specified")
        variable_names = constraints[0].variable_names
        for constraint in constraints:
            if constraint.variable_names != variable_names:
                raise ValueError("variable names don't match")
        # np.vstack is the canonical name; np.row_stack is a deprecated
        # alias for the same function in NumPy 2.0.
        coefs = np.vstack([c.coefs for c in constraints])
        constants = np.vstack([c.constants for c in constraints])
        return cls(variable_names, coefs, constants)
92
def test_LinearConstraint():
    """Check LinearConstraint construction, dtype coercion and validation."""
    try:
        from numpy.testing import assert_equal
    except ImportError:
        from numpy.testing.utils import assert_equal

    # A flat coefficient list becomes one row; constants default to zero.
    single_row = LinearConstraint(["foo", "bar"], [1, 1])
    assert single_row.variable_names == ["foo", "bar"]
    assert_equal(single_row.coefs, [[1, 1]])
    assert_equal(single_row.constants, [[0]])

    # Explicit constants are stored as a column.
    two_rows = LinearConstraint(["foo", "bar"], [[1, 1], [2, 3]], [10, 20])
    assert_equal(two_rows.coefs, [[1, 1], [2, 3]])
    assert_equal(two_rows.constants, [[10], [20]])

    # Integer inputs are converted to float arrays.
    float_dtype = np.dtype(float)
    assert two_rows.coefs.dtype == float_dtype
    assert two_rows.constants.dtype == float_dtype

    # statsmodels wants to be able to create degenerate constraints like this,
    # see:
    #     https://github.com/pydata/patsy/issues/89
    # We used to forbid it, but I guess it's harmless, so why not.
    degenerate = LinearConstraint(["a"], [[0]])
    assert_equal(degenerate.coefs, [[0]])

    import pytest
    # Each entry is a positional-argument tuple that must be rejected.
    rejected_args = [
        (["a"], [[1, 2]]),
        (["a"], [[[1]]]),
        (["a"], [[1, 2]], [3, 4]),
        (["a", "b"], [[1, 2]], [3, 4]),
        (["a"], [[1]], [[]]),
        (["a", "b"], []),
        (["a", "b"], np.zeros((0, 2))),
    ]
    for args in rejected_args:
        pytest.raises(ValueError, LinearConstraint, *args)

    assert_no_pickling(degenerate)
129
def test_LinearConstraint_combine():
    """Check that combine() stacks rows and validates its inputs."""
    first = LinearConstraint(["a", "b"], [1, 0])
    second = LinearConstraint(["a", "b"], [0, 1], [1])
    comb = LinearConstraint.combine([first, second])
    assert comb.variable_names == ["a", "b"]

    try:
        from numpy.testing import assert_equal
    except ImportError:
        from numpy.testing.utils import assert_equal
    # Rows and constants appear in the order the constraints were given.
    assert_equal(comb.coefs, [[1, 0], [0, 1]])
    assert_equal(comb.constants, [[0], [1]])

    import pytest
    # Nothing to combine.
    pytest.raises(ValueError, LinearConstraint.combine, [])
    # Variable names disagree.
    mismatched = [LinearConstraint(["a"], [1]), LinearConstraint(["b"], [1])]
    pytest.raises(ValueError, LinearConstraint.combine, mismatched)
145
146
# Operator table handed to infix_parse. Each spec is the positional argument
# tuple for Operator: (token, arity, precedence). "+" and "-" appear twice,
# once as unary and once as binary operators.
_ops = list(
    Operator(*spec)
    for spec in (
        (",", 2, -100),

        ("=", 2, 0),

        ("+", 1, 100),
        ("-", 1, 100),
        ("+", 2, 100),
        ("-", 2, 100),

        ("*", 2, 200),
        ("/", 2, 200),
    )
)

# Token types that are leaves of the parse tree.
_atomic = ["NUMBER", "VARIABLE"]
162
163def _token_maker(type, string):
164    def make_token(scanner, token_string):
165        if type == "__OP__":
166            actual_type = token_string
167        else:
168            actual_type = type
169        return Token(actual_type,
170                     Origin(string, *scanner.match.span()),
171                     token_string)
172    return make_token
173
def _tokenize_constraint(string, variable_names):
    """Split a constraint string into a list of Tokens.

    :arg string: The constraint code, e.g. ``"2 * (a + b) = q"``.
    :arg variable_names: The names that should be recognized as VARIABLE
      tokens. Names may contain operators, digits or whitespace; they are
      matched literally, and longer names win over their prefixes.
    :returns: A list of Token objects. Whitespace produces no token.
    :raises PatsyError: If some part of `string` cannot be tokenized. The
      error's origin points at the first unrecognized character.
    """
    lparen_re = r"\("
    rparen_re = r"\)"
    op_re = "|".join([re.escape(op.token_type) for op in _ops])
    # The exponent group is non-capturing so that the only groups in the
    # combined scanner pattern are the ones re.Scanner creates itself.
    num_re = r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?"
    whitespace_re = r"\s+"

    # Prefer long matches. Empty names are dropped: they could never
    # produce a token anyway (see below).
    variable_names = sorted((name for name in variable_names if name),
                            key=len, reverse=True)
    variable_re = "|".join([re.escape(n) for n in variable_names])

    # Order matters: re.Scanner takes the first alternative that matches,
    # so variables are tried before numbers and whitespace (which lets a
    # name like "a + 1" match as a whole).
    lexicon = [
        (lparen_re, _token_maker(Token.LPAREN, string)),
        (rparen_re, _token_maker(Token.RPAREN, string)),
        (op_re, _token_maker("__OP__", string)),
        ]
    if variable_re:
        # An empty variable pattern would match the empty string, and
        # re.Scanner stops scanning on an empty match -- numbers and
        # whitespace would then be reported as unrecognized tokens.
        lexicon.append((variable_re, _token_maker("VARIABLE", string)))
    lexicon.extend([
        (num_re, _token_maker("NUMBER", string)),
        (whitespace_re, None),
        ])

    scanner = re.Scanner(lexicon)
    tokens, leftover = scanner.scan(string)
    if leftover:
        # The scanner stopped at the first character it could not match.
        offset = len(string) - len(leftover)
        raise PatsyError("unrecognized token in constraint",
                            Origin(string, offset, offset + 1))

    return tokens
202
def test__tokenize_constraint():
    """Check token types, origins and text produced by the tokenizer."""
    code = "2 * (a + b) = q"
    tokens = _tokenize_constraint(code, ["a", "b", "q"])
    # (token type, origin start, origin end, matched text)
    expecteds = [("NUMBER", 0, 1, "2"),
                 ("*", 2, 3, "*"),
                 (Token.LPAREN, 4, 5, "("),
                 ("VARIABLE", 5, 6, "a"),
                 ("+", 7, 8, "+"),
                 ("VARIABLE", 9, 10, "b"),
                 (Token.RPAREN, 10, 11, ")"),
                 ("=", 12, 13, "="),
                 ("VARIABLE", 14, 15, "q")]
    for got, (want_type, start, end, text) in zip(tokens, expecteds):
        assert isinstance(got, Token)
        assert got.type == want_type
        assert got.origin == Origin(code, start, end)
        assert got.extra == text

    import pytest
    # "@" is only acceptable as part of a declared variable name.
    pytest.raises(PatsyError, _tokenize_constraint, "1 + @b", ["b"])
    # Shouldn't raise an error:
    _tokenize_constraint("1 + @b", ["@b"])

    # Check we aren't confused by names which are proper prefixes of other
    # names:
    for names in (["a", "aa"], ["aa", "a"]):
        prefix_tokens = _tokenize_constraint("a aa a", names)
        assert len(prefix_tokens) == 3
        assert [t.extra for t in prefix_tokens] == ["a", "aa", "a"]

    # Check that embedding ops and numbers inside a variable name works
    embedded = _tokenize_constraint("2 * a[1,1],", ["a[1,1]"])
    assert len(embedded) == 4
    assert [(t.type, t.extra) for t in embedded] == [("NUMBER", "2"),
                                                     ("*", "*"),
                                                     ("VARIABLE", "a[1,1]"),
                                                     (",", ",")]
238
def parse_constraint(string, variable_names):
    """Tokenize `string` and parse it into an expression tree.

    The tokens are parsed by infix_parse using the operator table `_ops`,
    with `_atomic` giving the leaf token types.
    """
    tokens = _tokenize_constraint(string, variable_names)
    return infix_parse(tokens, _ops, _atomic)
242
243class _EvalConstraint(object):
244    def __init__(self, variable_names):
245        self._variable_names = variable_names
246        self._N = len(variable_names)
247
248        self._dispatch = {
249            ("VARIABLE", 0): self._eval_variable,
250            ("NUMBER", 0): self._eval_number,
251            ("+", 1): self._eval_unary_plus,
252            ("-", 1): self._eval_unary_minus,
253            ("+", 2): self._eval_binary_plus,
254            ("-", 2): self._eval_binary_minus,
255            ("*", 2): self._eval_binary_multiply,
256            ("/", 2): self._eval_binary_div,
257            ("=", 2): self._eval_binary_eq,
258            (",", 2): self._eval_binary_comma,
259            }
260
261    # General scheme: there are 2 types we deal with:
262    #   - linear combinations ("lincomb"s) of variables and constants,
263    #     represented as ndarrays with size N+1
264    #     The last entry is the constant, so [10, 20, 30] means 10x + 20y +
265    #     30.
266    #   - LinearConstraint objects
267
268    def is_constant(self, coefs):
269        return np.all(coefs[:self._N] == 0)
270
271    def _eval_variable(self, tree):
272        var = tree.token.extra
273        coefs = np.zeros((self._N + 1,), dtype=float)
274        coefs[self._variable_names.index(var)] = 1
275        return coefs
276
277    def _eval_number(self, tree):
278        coefs = np.zeros((self._N + 1,), dtype=float)
279        coefs[-1] = float(tree.token.extra)
280        return coefs
281
282    def _eval_unary_plus(self, tree):
283        return self.eval(tree.args[0])
284
285    def _eval_unary_minus(self, tree):
286        return -1 * self.eval(tree.args[0])
287
288    def _eval_binary_plus(self, tree):
289        return self.eval(tree.args[0]) + self.eval(tree.args[1])
290
291    def _eval_binary_minus(self, tree):
292        return self.eval(tree.args[0]) - self.eval(tree.args[1])
293
294    def _eval_binary_div(self, tree):
295        left = self.eval(tree.args[0])
296        right = self.eval(tree.args[1])
297        if not self.is_constant(right):
298            raise PatsyError("Can't divide by a variable in a linear "
299                                "constraint", tree.args[1])
300        return left / right[-1]
301
302    def _eval_binary_multiply(self, tree):
303        left = self.eval(tree.args[0])
304        right = self.eval(tree.args[1])
305        if self.is_constant(left):
306            return left[-1] * right
307        elif self.is_constant(right):
308            return left * right[-1]
309        else:
310            raise PatsyError("Can't multiply one variable by another "
311                                "in a linear constraint", tree)
312
313    def _eval_binary_eq(self, tree):
314        # Handle "a1 = a2 = a3", which is parsed as "(a1 = a2) = a3"
315        args = list(tree.args)
316        constraints = []
317        for i, arg in enumerate(args):
318            if arg.type == "=":
319                constraints.append(self.eval(arg, constraint=True))
320                # make our left argument be their right argument, or
321                # vice-versa
322                args[i] = arg.args[1 - i]
323        left = self.eval(args[0])
324        right = self.eval(args[1])
325        coefs = left[:self._N] - right[:self._N]
326        if np.all(coefs == 0):
327            raise PatsyError("no variables appear in constraint", tree)
328        constant = -left[-1] + right[-1]
329        constraint = LinearConstraint(self._variable_names, coefs, constant)
330        constraints.append(constraint)
331        return LinearConstraint.combine(constraints)
332
333    def _eval_binary_comma(self, tree):
334        left = self.eval(tree.args[0], constraint=True)
335        right = self.eval(tree.args[1], constraint=True)
336        return LinearConstraint.combine([left, right])
337
338    def eval(self, tree, constraint=False):
339        key = (tree.type, len(tree.args))
340        assert key in self._dispatch
341        val = self._dispatch[key](tree)
342        if constraint:
343            # Force it to be a constraint
344            if isinstance(val, LinearConstraint):
345                return val
346            else:
347                assert val.size == self._N + 1
348                if np.all(val[:self._N] == 0):
349                    raise PatsyError("term is constant, with no variables",
350                                        tree)
351                return LinearConstraint(self._variable_names,
352                                        val[:self._N],
353                                        -val[-1])
354        else:
355            # Force it to *not* be a constraint
356            if isinstance(val, LinearConstraint):
357                raise PatsyError("unexpected constraint object", tree)
358            return val
359
def linear_constraint(constraint_like, variable_names):
    """This is the internal interface implementing
    DesignInfo.linear_constraint, see there for docs.

    :arg constraint_like: One of:

      * a LinearConstraint, returned unchanged after checking its names;
      * a mapping from variable name or integer column index to the value
        that variable is constrained to equal;
      * a constraint string, or a non-empty list of such strings;
      * a tuple ``(coefs, constants)``;
      * anything else, which is converted to a float array of coefficients
        with all-zero constants.
    :arg variable_names: The names of the variables, in column order.
    :returns: A LinearConstraint.
    :raises ValueError: For mismatched variable names, unknown or
      duplicated mapping keys, out-of-range indices, non-string entries in
      a list of strings, wrong-length tuples, or badly shaped arrays.
    """
    # Work on a real list so that equality with
    # LinearConstraint.variable_names and .index() behave the same for
    # tuples and other sequences.
    variable_names = list(variable_names)

    if isinstance(constraint_like, LinearConstraint):
        if constraint_like.variable_names != variable_names:
            raise ValueError("LinearConstraint has wrong variable_names "
                             "(got %r, expected %r)"
                             % (constraint_like.variable_names,
                                variable_names))
        return constraint_like

    if isinstance(constraint_like, Mapping):
        # Simple conjunction-of-equality constraints can be specified as
        # dicts. {"x": 1, "y": 2} -> tests x = 1 and y = 2. Keys can be
        # either variable names, or variable indices.
        num_variables = len(variable_names)
        coefs = np.zeros((len(constraint_like), num_variables),
                         dtype=float)
        constants = np.zeros(len(constraint_like))
        used = set()
        for i, (name, value) in enumerate(six.iteritems(constraint_like)):
            if name in variable_names:
                idx = variable_names.index(name)
            elif isinstance(name, six.integer_types):
                if not -num_variables <= name < num_variables:
                    raise ValueError("unrecognized variable name/index %r"
                                     % (name,))
                # Normalize negative indices so that e.g. 1 and -1 (with
                # two variables) are recognized as the same column by the
                # duplicate check below.
                idx = name + num_variables if name < 0 else name
            else:
                raise ValueError("unrecognized variable name/index %r"
                                 % (name,))
            if idx in used:
                raise ValueError("duplicated constraint on %r"
                                 % (variable_names[idx],))
            used.add(idx)
            coefs[i, idx] = 1
            constants[i] = value
        return LinearConstraint(variable_names, coefs, constants)

    if isinstance(constraint_like, six.string_types):
        constraint_like = [constraint_like]
        # fall-through

    if (isinstance(constraint_like, list)
        and constraint_like
        and isinstance(constraint_like[0], six.string_types)):
        # The evaluator only depends on the variable names, so one
        # instance serves every string.
        evaluator = _EvalConstraint(variable_names)
        constraints = []
        for code in constraint_like:
            if not isinstance(code, six.string_types):
                raise ValueError("expected a string, not %r" % (code,))
            tree = parse_constraint(code, variable_names)
            constraints.append(evaluator.eval(tree, constraint=True))
        return LinearConstraint.combine(constraints)

    if isinstance(constraint_like, tuple):
        if len(constraint_like) != 2:
            raise ValueError("constraint tuple must have length 2")
        coef, constants = constraint_like
        return LinearConstraint(variable_names, coef, constants)

    # assume a raw ndarray
    coefs = np.asarray(constraint_like, dtype=float)
    return LinearConstraint(variable_names, coefs)
420
421
def _check_lincon(input, varnames, coefs, constants):
    """Assert that linear_constraint(input, varnames) gives the expected result.

    The expected result is ``LinearConstraint(varnames, coefs, constants)``;
    names, coefficients and constants are compared, and both arrays of the
    result must have float dtype.
    """
    try:
        from numpy.testing import assert_equal
    except ImportError:
        from numpy.testing.utils import assert_equal
    got = linear_constraint(input, varnames)
    print("got", got)
    expected = LinearConstraint(varnames, coefs, constants)
    print("expected", expected)
    for attr in ("variable_names", "coefs", "constants"):
        assert_equal(getattr(got, attr), getattr(expected, attr))
    float_dtype = np.dtype(float)
    for array in (got.coefs, got.constants):
        assert_equal(array.dtype, float_dtype)
436
437
def test_linear_constraint():
    """Exercise every input form accepted by linear_constraint."""
    import pytest
    from patsy.compat import OrderedDict
    check = _check_lincon
    ab = ["a", "b"]

    def rejects(bad_input):
        # Every rejection case in this test uses the variables a, b.
        pytest.raises(ValueError, linear_constraint, bad_input, ab)

    # LinearConstraint objects are passed through if their names match.
    check(LinearConstraint(["a", "b"], [2, 3]), ab, [[2, 3]], [[0]])
    rejects(LinearConstraint(["b", "a"], [2, 3]))

    # Mappings keyed by variable name.
    check({"a": 2}, ab, [[1, 0]], [[2]])
    check(OrderedDict([("a", 2), ("b", 3)]),
          ab, [[1, 0], [0, 1]], [[2], [3]])
    check(OrderedDict([("a", 2), ("b", 3)]),
          ["b", "a"], [[0, 1], [1, 0]], [[2], [3]])

    # Mappings keyed by column index.
    check({0: 2}, ab, [[1, 0]], [[2]])
    check(OrderedDict([(0, 2), (1, 3)]), ab, [[1, 0], [0, 1]], [[2], [3]])

    # Names and indices may be mixed.
    check(OrderedDict([("a", 2), (1, 3)]),
          ab, [[1, 0], [0, 1]], [[2], [3]])

    # Unknown key, and the same variable given by name and by index.
    rejects({"q": 1})
    rejects({"a": 1, 0: 2})

    check(np.array([2, 3]), ab, [[2, 3]], [[0]])
    check(np.array([[2, 3], [4, 5]]), ab, [[2, 3], [4, 5]], [[0], [0]])

    # (code, expected coefs, expected constants), all over variables a, b.
    string_cases = [
        ("a = 2", [[1, 0]], [[2]]),
        ("a - 2", [[1, 0]], [[2]]),
        ("a + 1 = 3", [[1, 0]], [[2]]),
        ("a + b = 3", [[1, 1]], [[3]]),
        ("a = 2, b = 3", [[1, 0], [0, 1]], [[2], [3]]),
        ("b = 3, a = 2", [[0, 1], [1, 0]], [[3], [2]]),
    ]
    for code, coefs, constants in string_cases:
        check(code, ab, coefs, constants)

    check(["a = 2", "b = 3"], ab, [[1, 0], [0, 1]], [[2], [3]])

    rejects(["a", {"b": 0}])

    # Actual evaluator tests
    evaluator_cases = [
        ("2 * (a + b/3) + b + 2*3/4 = 1 + 2*3",
         [[2, 2.0/3 + 1]], [[7 - 6.0/4]]),
        ("+2 * -a", [[-2, 0]], [[0]]),
        ("a - b, a + b = 2", [[1, -1], [1, 1]], [[0], [2]]),
        ("a = 1, a = 2, a = 3",
         [[1, 0], [1, 0], [1, 0]], [[1], [2], [3]]),
        ("a * 2", [[2, 0]], [[0]]),
        ("-a = 1", [[-1, 0]], [[1]]),
        ("(2 + a - a) * b", [[0, 2]], [[0]]),
        ("a = 1 = b", [[1, 0], [0, -1]], [[1], [-1]]),
        ("a = (1 = b)", [[0, -1], [1, 0]], [[-1], [1]]),
    ]
    for code, coefs, constants in evaluator_cases:
        check(code, ab, coefs, constants)
    check("a = 1, a = b = c", ["a", "b", "c"],
          [[1, 0, 0], [1, -1, 0], [0, 1, -1]], [[1], [0], [0]])

    # One should never do this of course, but test that it works anyway...
    check("a + 1 = 2", ["a", "a + 1"], [[0, 1]], [[2]])

    # (coefs, constants) tuples.
    check(([10, 20], [30]), ab, [[10, 20]], [[30]])
    check(([[10, 20], [20, 40]], [[30], [35]]), ab,
          [[10, 20], [20, 40]], [[30], [35]])
    # wrong-length tuple
    rejects(([1, 0], [0], [0]))
    rejects(([1, 0],))

    # Plain lists and arrays of coefficients.
    check([10, 20], ab, [[10, 20]], [[0]])
    check([[10, 20], [20, 40]], ab, [[10, 20], [20, 40]], [[0], [0]])
    check(np.array([10, 20]), ab, [[10, 20]], [[0]])
    check(np.array([[10, 20], [20, 40]]), ab,
          [[10, 20], [20, 40]], [[0], [0]])

    # unknown object type
    rejects(None)
512
513
514_parse_eval_error_tests = [
515    # Bad token
516    "a + <f>oo",
517    # No pure constant equalities
518    "a = 1, <1 = 1>, b = 1",
519    "a = 1, <b * 2 - b + (-2/2 * b)>",
520    "a = 1, <1>, b = 2",
521    "a = 1, <2 * b = b + b>, c",
522    # No non-linearities
523    "a + <a * b> + c",
524    "a + 2 / <b> + c",
525    # Constraints are not numbers
526    "a = 1, 2 * <(a = b)>, c",
527    "a = 1, a + <(a = b)>, c",
528    "a = 1, <(a, b)> + 2, c",
529]
530
531
def test_eval_errors():
    """Run every malformed constraint through the shared error checker."""
    names = ["a", "b", "c"]

    def constrain(bad_code):
        return linear_constraint(bad_code, names)

    _parsing_error_test(constrain, _parse_eval_error_tests)
536