1# This file is part of Patsy 2# Copyright (C) 2011-2012 Nathaniel Smith <njs@pobox.com> 3# See file LICENSE.txt for license information. 4 5# Interpreting linear constraints like "2*x1 + x2 = 0" 6 7from __future__ import print_function 8 9# These are made available in the patsy.* namespace 10__all__ = ["LinearConstraint"] 11 12import re 13try: 14 from collections.abc import Mapping 15except ImportError: 16 from collections import Mapping 17import six 18import numpy as np 19from patsy import PatsyError 20from patsy.origin import Origin 21from patsy.util import (atleast_2d_column_default, 22 repr_pretty_delegate, repr_pretty_impl, 23 no_pickling, assert_no_pickling) 24from patsy.infix_parser import Token, Operator, infix_parse 25from patsy.parse_formula import _parsing_error_test 26 27 28class LinearConstraint(object): 29 """A linear constraint in matrix form. 30 31 This object represents a linear constraint of the form `Ax = b`. 32 33 Usually you won't be constructing these by hand, but instead get them as 34 the return value from :meth:`DesignInfo.linear_constraint`. 35 36 .. attribute:: coefs 37 38 A 2-dimensional ndarray with float dtype, representing `A`. 39 40 .. attribute:: constants 41 42 A 2-dimensional single-column ndarray with float dtype, representing 43 `b`. 44 45 .. attribute:: variable_names 46 47 A list of strings giving the names of the variables being 48 constrained. (Used only for consistency checking.) 49 """ 50 def __init__(self, variable_names, coefs, constants=None): 51 self.variable_names = list(variable_names) 52 self.coefs = np.atleast_2d(np.asarray(coefs, dtype=float)) 53 if constants is None: 54 constants = np.zeros(self.coefs.shape[0], dtype=float) 55 constants = np.asarray(constants, dtype=float) 56 self.constants = atleast_2d_column_default(constants) 57 if self.constants.ndim != 2 or self.constants.shape[1] != 1: 58 raise ValueError("constants is not (convertible to) a column matrix") 59 if self.coefs.ndim != 2 or self.coefs.shape[1] != len(variable_names): 60 raise ValueError("wrong shape for coefs") 61 if self.coefs.shape[0] == 0: 62 raise ValueError("must have at least one row in constraint matrix") 63 if self.coefs.shape[0] != self.constants.shape[0]: 64 raise ValueError("shape mismatch between coefs and constants") 65 66 __repr__ = repr_pretty_delegate 67 def _repr_pretty_(self, p, cycle): 68 assert not cycle 69 return repr_pretty_impl(p, self, 70 [self.variable_names, self.coefs, self.constants]) 71 72 __getstate__ = no_pickling 73 74 @classmethod 75 def combine(cls, constraints): 76 """Create a new LinearConstraint by ANDing together several existing 77 LinearConstraints. 78 79 :arg constraints: An iterable of LinearConstraint objects. Their 80 :attr:`variable_names` attributes must all match. 81 :returns: A new LinearConstraint object. 82 """ 83 if not constraints: 84 raise ValueError("no constraints specified") 85 variable_names = constraints[0].variable_names 86 for constraint in constraints: 87 if constraint.variable_names != variable_names: 88 raise ValueError("variable names don't match") 89 coefs = np.row_stack([c.coefs for c in constraints]) 90 constants = np.row_stack([c.constants for c in constraints]) 91 return cls(variable_names, coefs, constants) 92 93def test_LinearConstraint(): 94 try: 95 from numpy.testing import assert_equal 96 except ImportError: 97 from numpy.testing.utils import assert_equal 98 lc = LinearConstraint(["foo", "bar"], [1, 1]) 99 assert lc.variable_names == ["foo", "bar"] 100 assert_equal(lc.coefs, [[1, 1]]) 101 assert_equal(lc.constants, [[0]]) 102 103 lc = LinearConstraint(["foo", "bar"], [[1, 1], [2, 3]], [10, 20]) 104 assert_equal(lc.coefs, [[1, 1], [2, 3]]) 105 assert_equal(lc.constants, [[10], [20]]) 106 107 assert lc.coefs.dtype == np.dtype(float) 108 assert lc.constants.dtype == np.dtype(float) 109 110 111 # statsmodels wants to be able to create degenerate constraints like this, 112 # see: 113 # https://github.com/pydata/patsy/issues/89 114 # We used to forbid it, but I guess it's harmless, so why not. 115 lc = LinearConstraint(["a"], [[0]]) 116 assert_equal(lc.coefs, [[0]]) 117 118 import pytest 119 pytest.raises(ValueError, LinearConstraint, ["a"], [[1, 2]]) 120 pytest.raises(ValueError, LinearConstraint, ["a"], [[[1]]]) 121 pytest.raises(ValueError, LinearConstraint, ["a"], [[1, 2]], [3, 4]) 122 pytest.raises(ValueError, LinearConstraint, ["a", "b"], [[1, 2]], [3, 4]) 123 pytest.raises(ValueError, LinearConstraint, ["a"], [[1]], [[]]) 124 pytest.raises(ValueError, LinearConstraint, ["a", "b"], []) 125 pytest.raises(ValueError, LinearConstraint, ["a", "b"], 126 np.zeros((0, 2))) 127 128 assert_no_pickling(lc) 129 130def test_LinearConstraint_combine(): 131 comb = LinearConstraint.combine([LinearConstraint(["a", "b"], [1, 0]), 132 LinearConstraint(["a", "b"], [0, 1], [1])]) 133 assert comb.variable_names == ["a", "b"] 134 try: 135 from numpy.testing import assert_equal 136 except ImportError: 137 from numpy.testing.utils import assert_equal 138 assert_equal(comb.coefs, [[1, 0], [0, 1]]) 139 assert_equal(comb.constants, [[0], [1]]) 140 141 import pytest 142 pytest.raises(ValueError, LinearConstraint.combine, []) 143 pytest.raises(ValueError, LinearConstraint.combine, 144 [LinearConstraint(["a"], [1]), LinearConstraint(["b"], [1])]) 145 146 147_ops = [ 148 Operator(",", 2, -100), 149 150 Operator("=", 2, 0), 151 152 Operator("+", 1, 100), 153 Operator("-", 1, 100), 154 Operator("+", 2, 100), 155 Operator("-", 2, 100), 156 157 Operator("*", 2, 200), 158 Operator("/", 2, 200), 159 ] 160 161_atomic = ["NUMBER", "VARIABLE"] 162 163def _token_maker(type, string): 164 def make_token(scanner, token_string): 165 if type == "__OP__": 166 actual_type = token_string 167 else: 168 actual_type = type 169 return Token(actual_type, 170 Origin(string, *scanner.match.span()), 171 token_string) 172 return make_token 173 174def _tokenize_constraint(string, variable_names): 175 lparen_re = r"\(" 176 rparen_re = r"\)" 177 op_re = "|".join([re.escape(op.token_type) for op in _ops]) 178 num_re = r"[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?" 179 whitespace_re = r"\s+" 180 181 # Prefer long matches: 182 variable_names = sorted(variable_names, key=len, reverse=True) 183 variable_re = "|".join([re.escape(n) for n in variable_names]) 184 185 lexicon = [ 186 (lparen_re, _token_maker(Token.LPAREN, string)), 187 (rparen_re, _token_maker(Token.RPAREN, string)), 188 (op_re, _token_maker("__OP__", string)), 189 (variable_re, _token_maker("VARIABLE", string)), 190 (num_re, _token_maker("NUMBER", string)), 191 (whitespace_re, None), 192 ] 193 194 scanner = re.Scanner(lexicon) 195 tokens, leftover = scanner.scan(string) 196 if leftover: 197 offset = len(string) - len(leftover) 198 raise PatsyError("unrecognized token in constraint", 199 Origin(string, offset, offset + 1)) 200 201 return tokens 202 203def test__tokenize_constraint(): 204 code = "2 * (a + b) = q" 205 tokens = _tokenize_constraint(code, ["a", "b", "q"]) 206 expecteds = [("NUMBER", 0, 1, "2"), 207 ("*", 2, 3, "*"), 208 (Token.LPAREN, 4, 5, "("), 209 ("VARIABLE", 5, 6, "a"), 210 ("+", 7, 8, "+"), 211 ("VARIABLE", 9, 10, "b"), 212 (Token.RPAREN, 10, 11, ")"), 213 ("=", 12, 13, "="), 214 ("VARIABLE", 14, 15, "q")] 215 for got, expected in zip(tokens, expecteds): 216 assert isinstance(got, Token) 217 assert got.type == expected[0] 218 assert got.origin == Origin(code, expected[1], expected[2]) 219 assert got.extra == expected[3] 220 221 import pytest 222 pytest.raises(PatsyError, _tokenize_constraint, "1 + @b", ["b"]) 223 # Shouldn't raise an error: 224 _tokenize_constraint("1 + @b", ["@b"]) 225 226 # Check we aren't confused by names which are proper prefixes of other 227 # names: 228 for names in (["a", "aa"], ["aa", "a"]): 229 tokens = _tokenize_constraint("a aa a", names) 230 assert len(tokens) == 3 231 assert [t.extra for t in tokens] == ["a", "aa", "a"] 232 233 # Check that embedding ops and numbers inside a variable name works 234 tokens = _tokenize_constraint("2 * a[1,1],", ["a[1,1]"]) 235 assert len(tokens) == 4 236 assert [t.type for t in tokens] == ["NUMBER", "*", "VARIABLE", ","] 237 assert [t.extra for t in tokens] == ["2", "*", "a[1,1]", ","] 238 239def parse_constraint(string, variable_names): 240 return infix_parse(_tokenize_constraint(string, variable_names), 241 _ops, _atomic) 242 243class _EvalConstraint(object): 244 def __init__(self, variable_names): 245 self._variable_names = variable_names 246 self._N = len(variable_names) 247 248 self._dispatch = { 249 ("VARIABLE", 0): self._eval_variable, 250 ("NUMBER", 0): self._eval_number, 251 ("+", 1): self._eval_unary_plus, 252 ("-", 1): self._eval_unary_minus, 253 ("+", 2): self._eval_binary_plus, 254 ("-", 2): self._eval_binary_minus, 255 ("*", 2): self._eval_binary_multiply, 256 ("/", 2): self._eval_binary_div, 257 ("=", 2): self._eval_binary_eq, 258 (",", 2): self._eval_binary_comma, 259 } 260 261 # General scheme: there are 2 types we deal with: 262 # - linear combinations ("lincomb"s) of variables and constants, 263 # represented as ndarrays with size N+1 264 # The last entry is the constant, so [10, 20, 30] means 10x + 20y + 265 # 30. 266 # - LinearConstraint objects 267 268 def is_constant(self, coefs): 269 return np.all(coefs[:self._N] == 0) 270 271 def _eval_variable(self, tree): 272 var = tree.token.extra 273 coefs = np.zeros((self._N + 1,), dtype=float) 274 coefs[self._variable_names.index(var)] = 1 275 return coefs 276 277 def _eval_number(self, tree): 278 coefs = np.zeros((self._N + 1,), dtype=float) 279 coefs[-1] = float(tree.token.extra) 280 return coefs 281 282 def _eval_unary_plus(self, tree): 283 return self.eval(tree.args[0]) 284 285 def _eval_unary_minus(self, tree): 286 return -1 * self.eval(tree.args[0]) 287 288 def _eval_binary_plus(self, tree): 289 return self.eval(tree.args[0]) + self.eval(tree.args[1]) 290 291 def _eval_binary_minus(self, tree): 292 return self.eval(tree.args[0]) - self.eval(tree.args[1]) 293 294 def _eval_binary_div(self, tree): 295 left = self.eval(tree.args[0]) 296 right = self.eval(tree.args[1]) 297 if not self.is_constant(right): 298 raise PatsyError("Can't divide by a variable in a linear " 299 "constraint", tree.args[1]) 300 return left / right[-1] 301 302 def _eval_binary_multiply(self, tree): 303 left = self.eval(tree.args[0]) 304 right = self.eval(tree.args[1]) 305 if self.is_constant(left): 306 return left[-1] * right 307 elif self.is_constant(right): 308 return left * right[-1] 309 else: 310 raise PatsyError("Can't multiply one variable by another " 311 "in a linear constraint", tree) 312 313 def _eval_binary_eq(self, tree): 314 # Handle "a1 = a2 = a3", which is parsed as "(a1 = a2) = a3" 315 args = list(tree.args) 316 constraints = [] 317 for i, arg in enumerate(args): 318 if arg.type == "=": 319 constraints.append(self.eval(arg, constraint=True)) 320 # make our left argument be their right argument, or 321 # vice-versa 322 args[i] = arg.args[1 - i] 323 left = self.eval(args[0]) 324 right = self.eval(args[1]) 325 coefs = left[:self._N] - right[:self._N] 326 if np.all(coefs == 0): 327 raise PatsyError("no variables appear in constraint", tree) 328 constant = -left[-1] + right[-1] 329 constraint = LinearConstraint(self._variable_names, coefs, constant) 330 constraints.append(constraint) 331 return LinearConstraint.combine(constraints) 332 333 def _eval_binary_comma(self, tree): 334 left = self.eval(tree.args[0], constraint=True) 335 right = self.eval(tree.args[1], constraint=True) 336 return LinearConstraint.combine([left, right]) 337 338 def eval(self, tree, constraint=False): 339 key = (tree.type, len(tree.args)) 340 assert key in self._dispatch 341 val = self._dispatch[key](tree) 342 if constraint: 343 # Force it to be a constraint 344 if isinstance(val, LinearConstraint): 345 return val 346 else: 347 assert val.size == self._N + 1 348 if np.all(val[:self._N] == 0): 349 raise PatsyError("term is constant, with no variables", 350 tree) 351 return LinearConstraint(self._variable_names, 352 val[:self._N], 353 -val[-1]) 354 else: 355 # Force it to *not* be a constraint 356 if isinstance(val, LinearConstraint): 357 raise PatsyError("unexpected constraint object", tree) 358 return val 359 360def linear_constraint(constraint_like, variable_names): 361 """This is the internal interface implementing 362 DesignInfo.linear_constraint, see there for docs.""" 363 if isinstance(constraint_like, LinearConstraint): 364 if constraint_like.variable_names != variable_names: 365 raise ValueError("LinearConstraint has wrong variable_names " 366 "(got %r, expected %r)" 367 % (constraint_like.variable_names, 368 variable_names)) 369 return constraint_like 370 371 if isinstance(constraint_like, Mapping): 372 # Simple conjunction-of-equality constraints can be specified as 373 # dicts. {"x": 1, "y": 2} -> tests x = 1 and y = 2. Keys can be 374 # either variable names, or variable indices. 375 coefs = np.zeros((len(constraint_like), len(variable_names)), 376 dtype=float) 377 constants = np.zeros(len(constraint_like)) 378 used = set() 379 for i, (name, value) in enumerate(six.iteritems(constraint_like)): 380 if name in variable_names: 381 idx = variable_names.index(name) 382 elif isinstance(name, six.integer_types): 383 idx = name 384 else: 385 raise ValueError("unrecognized variable name/index %r" 386 % (name,)) 387 if idx in used: 388 raise ValueError("duplicated constraint on %r" 389 % (variable_names[idx],)) 390 used.add(idx) 391 coefs[i, idx] = 1 392 constants[i] = value 393 return LinearConstraint(variable_names, coefs, constants) 394 395 if isinstance(constraint_like, str): 396 constraint_like = [constraint_like] 397 # fall-through 398 399 if (isinstance(constraint_like, list) 400 and constraint_like 401 and isinstance(constraint_like[0], str)): 402 constraints = [] 403 for code in constraint_like: 404 if not isinstance(code, str): 405 raise ValueError("expected a string, not %r" % (code,)) 406 tree = parse_constraint(code, variable_names) 407 evaluator = _EvalConstraint(variable_names) 408 constraints.append(evaluator.eval(tree, constraint=True)) 409 return LinearConstraint.combine(constraints) 410 411 if isinstance(constraint_like, tuple): 412 if len(constraint_like) != 2: 413 raise ValueError("constraint tuple must have length 2") 414 coef, constants = constraint_like 415 return LinearConstraint(variable_names, coef, constants) 416 417 # assume a raw ndarray 418 coefs = np.asarray(constraint_like, dtype=float) 419 return LinearConstraint(variable_names, coefs) 420 421 422def _check_lincon(input, varnames, coefs, constants): 423 try: 424 from numpy.testing import assert_equal 425 except ImportError: 426 from numpy.testing.utils import assert_equal 427 got = linear_constraint(input, varnames) 428 print("got", got) 429 expected = LinearConstraint(varnames, coefs, constants) 430 print("expected", expected) 431 assert_equal(got.variable_names, expected.variable_names) 432 assert_equal(got.coefs, expected.coefs) 433 assert_equal(got.constants, expected.constants) 434 assert_equal(got.coefs.dtype, np.dtype(float)) 435 assert_equal(got.constants.dtype, np.dtype(float)) 436 437 438def test_linear_constraint(): 439 import pytest 440 from patsy.compat import OrderedDict 441 t = _check_lincon 442 443 t(LinearConstraint(["a", "b"], [2, 3]), ["a", "b"], [[2, 3]], [[0]]) 444 pytest.raises(ValueError, linear_constraint, 445 LinearConstraint(["b", "a"], [2, 3]), 446 ["a", "b"]) 447 448 t({"a": 2}, ["a", "b"], [[1, 0]], [[2]]) 449 t(OrderedDict([("a", 2), ("b", 3)]), 450 ["a", "b"], [[1, 0], [0, 1]], [[2], [3]]) 451 t(OrderedDict([("a", 2), ("b", 3)]), 452 ["b", "a"], [[0, 1], [1, 0]], [[2], [3]]) 453 454 t({0: 2}, ["a", "b"], [[1, 0]], [[2]]) 455 t(OrderedDict([(0, 2), (1, 3)]), ["a", "b"], [[1, 0], [0, 1]], [[2], [3]]) 456 457 t(OrderedDict([("a", 2), (1, 3)]), 458 ["a", "b"], [[1, 0], [0, 1]], [[2], [3]]) 459 460 pytest.raises(ValueError, linear_constraint, {"q": 1}, ["a", "b"]) 461 pytest.raises(ValueError, linear_constraint, {"a": 1, 0: 2}, ["a", "b"]) 462 463 t(np.array([2, 3]), ["a", "b"], [[2, 3]], [[0]]) 464 t(np.array([[2, 3], [4, 5]]), ["a", "b"], [[2, 3], [4, 5]], [[0], [0]]) 465 466 t("a = 2", ["a", "b"], [[1, 0]], [[2]]) 467 t("a - 2", ["a", "b"], [[1, 0]], [[2]]) 468 t("a + 1 = 3", ["a", "b"], [[1, 0]], [[2]]) 469 t("a + b = 3", ["a", "b"], [[1, 1]], [[3]]) 470 t("a = 2, b = 3", ["a", "b"], [[1, 0], [0, 1]], [[2], [3]]) 471 t("b = 3, a = 2", ["a", "b"], [[0, 1], [1, 0]], [[3], [2]]) 472 473 t(["a = 2", "b = 3"], ["a", "b"], [[1, 0], [0, 1]], [[2], [3]]) 474 475 pytest.raises(ValueError, linear_constraint, ["a", {"b": 0}], ["a", "b"]) 476 477 # Actual evaluator tests 478 t("2 * (a + b/3) + b + 2*3/4 = 1 + 2*3", ["a", "b"], 479 [[2, 2.0/3 + 1]], [[7 - 6.0/4]]) 480 t("+2 * -a", ["a", "b"], [[-2, 0]], [[0]]) 481 t("a - b, a + b = 2", ["a", "b"], [[1, -1], [1, 1]], [[0], [2]]) 482 t("a = 1, a = 2, a = 3", ["a", "b"], 483 [[1, 0], [1, 0], [1, 0]], [[1], [2], [3]]) 484 t("a * 2", ["a", "b"], [[2, 0]], [[0]]) 485 t("-a = 1", ["a", "b"], [[-1, 0]], [[1]]) 486 t("(2 + a - a) * b", ["a", "b"], [[0, 2]], [[0]]) 487 488 t("a = 1 = b", ["a", "b"], [[1, 0], [0, -1]], [[1], [-1]]) 489 t("a = (1 = b)", ["a", "b"], [[0, -1], [1, 0]], [[-1], [1]]) 490 t("a = 1, a = b = c", ["a", "b", "c"], 491 [[1, 0, 0], [1, -1, 0], [0, 1, -1]], [[1], [0], [0]]) 492 493 # One should never do this of course, but test that it works anyway... 494 t("a + 1 = 2", ["a", "a + 1"], [[0, 1]], [[2]]) 495 496 t(([10, 20], [30]), ["a", "b"], [[10, 20]], [[30]]) 497 t(([[10, 20], [20, 40]], [[30], [35]]), ["a", "b"], 498 [[10, 20], [20, 40]], [[30], [35]]) 499 # wrong-length tuple 500 pytest.raises(ValueError, linear_constraint, 501 ([1, 0], [0], [0]), ["a", "b"]) 502 pytest.raises(ValueError, linear_constraint, ([1, 0],), ["a", "b"]) 503 504 t([10, 20], ["a", "b"], [[10, 20]], [[0]]) 505 t([[10, 20], [20, 40]], ["a", "b"], [[10, 20], [20, 40]], [[0], [0]]) 506 t(np.array([10, 20]), ["a", "b"], [[10, 20]], [[0]]) 507 t(np.array([[10, 20], [20, 40]]), ["a", "b"], 508 [[10, 20], [20, 40]], [[0], [0]]) 509 510 # unknown object type 511 pytest.raises(ValueError, linear_constraint, None, ["a", "b"]) 512 513 514_parse_eval_error_tests = [ 515 # Bad token 516 "a + <f>oo", 517 # No pure constant equalities 518 "a = 1, <1 = 1>, b = 1", 519 "a = 1, <b * 2 - b + (-2/2 * b)>", 520 "a = 1, <1>, b = 2", 521 "a = 1, <2 * b = b + b>, c", 522 # No non-linearities 523 "a + <a * b> + c", 524 "a + 2 / <b> + c", 525 # Constraints are not numbers 526 "a = 1, 2 * <(a = b)>, c", 527 "a = 1, a + <(a = b)>, c", 528 "a = 1, <(a, b)> + 2, c", 529] 530 531 532def test_eval_errors(): 533 def doit(bad_code): 534 return linear_constraint(bad_code, ["a", "b", "c"]) 535 _parsing_error_test(doit, _parse_eval_error_tests) 536