1# -*- coding: utf-8 -*-
2#  ___________________________________________________________________________
3#
4#  Pyomo: Python Optimization Modeling Objects
5#  Copyright 2017 National Technology and Engineering Solutions of Sandia, LLC
6#  Under the terms of Contract DE-NA0003525 with National Technology and
7#  Engineering Solutions of Sandia, LLC, the U.S. Government retains certain
8#  rights in this software.
9#  This software is distributed under the 3-clause BSD License.
10#  ___________________________________________________________________________
11#
12#
13
14import pyomo.common.unittest as unittest
15from pyomo.environ import (
16    ConcreteModel, Var, Param, Set, Constraint, Objective, Expression,
17    Suffix, RangeSet, ExternalFunction, units, maximize, sin, cos, sqrt,
18)
19from pyomo.network import Port, Arc
20from pyomo.dae import ContinuousSet, DerivativeVar
21from pyomo.gdp import Disjunct, Disjunction
22from pyomo.core.base.units_container import (
23    pint_available, UnitsError,
24)
25from pyomo.util.check_units import assert_units_consistent, assert_units_equivalent, check_units_equivalent
26
def python_callback_function(arg1, arg2):
    """Stand-in Python callback used to build an ``ExternalFunction``.

    Both positional arguments are accepted only so the callback matches a
    two-argument call; their values are ignored and a constant float is
    always returned.
    """
    constant_result = 42.0
    return constant_result
29
@unittest.skipIf(not pint_available, 'Testing units requires pint')
class TestUnitsChecking(unittest.TestCase):
    """Tests for ``assert_units_consistent``, ``assert_units_equivalent`` and
    ``check_units_equivalent`` from ``pyomo.util.check_units``."""

    def _create_model_and_vars(self, t_units=None):
        """Return a ConcreteModel holding unit-annotated Vars and a Param.

        The model contains Vars ``dx``, ``dy`` (m), ``vx``, ``vy`` (m/s),
        ``t`` (time, see ``t_units``), ``theta`` (radians), the unitless Var
        ``x_unitless`` and the Param ``a`` (ft/s**2). No constraints or
        objectives are added.

        Args:
            t_units: units for ``m.t``. ``None`` (the default) means seconds.
                The default is resolved inside the body rather than in the
                signature so that no unit object is evaluated while the class
                body is being defined.
        """
        u = units
        if t_units is None:
            t_units = u.s
        m = ConcreteModel()
        m.dx = Var(units=u.m, initialize=0.10188943773836046)
        m.dy = Var(units=u.m, initialize=0.0)
        m.vx = Var(units=u.m/u.s, initialize=0.7071067769802851)
        m.vy = Var(units=u.m/u.s, initialize=0.7071067769802851)
        m.t = Var(units=t_units, bounds=(1e-5,10.0),
                  initialize=0.0024015570927624456)
        m.theta = Var(bounds=(0, 0.49*3.14), initialize=0.7853981693583533,
                      units=u.radians)
        m.a = Param(initialize=-32.2, units=u.ft/u.s**2)
        m.x_unitless = Var()
        return m

    def test_assert_units_consistent_equivalent(self):
        """Consistency/equivalence checks on a model, its Vars and Constraints.

        ``m.t`` is declared in minutes here, so every constraint mixing it
        with the m/s velocities goes through ``units.convert``.
        """
        u = units
        m = self._create_model_and_vars(t_units=u.min)

        m.obj = Objective(expr = m.dx, sense=maximize)
        m.vx_con = Constraint(expr = m.vx == 1.0*u.m/u.s*cos(m.theta))
        m.vy_con = Constraint(expr = m.vy == 1.0*u.m/u.s*sin(m.theta))
        m.dx_con = Constraint(expr = m.dx == m.vx*u.convert(m.t, to_units=u.s))
        m.dy_con = Constraint(expr = m.dy == m.vy*u.convert(m.t, to_units=u.s)
                              + 0.5*(u.convert(m.a, to_units=u.m/u.s**2))*(u.convert(m.t, to_units=u.s))**2)
        m.ground = Constraint(expr = m.dy == 0)
        m.unitless_con = Constraint(expr = m.x_unitless == 5.0)

        assert_units_consistent(m) # check model
        assert_units_consistent(m.dx) # check var - this should never fail
        assert_units_consistent(m.x_unitless) # check unitless var - this should never fail
        assert_units_consistent(m.vx_con) # check constraint
        assert_units_consistent(m.unitless_con) # check unitless constraint

        assert_units_equivalent(m.dx, m.dy) # check var
        assert_units_equivalent(m.x_unitless, u.dimensionless) # check unitless var
        assert_units_equivalent(m.x_unitless, None) # check unitless var
        assert_units_equivalent(m.vx_con.body, u.m/u.s) # check constraint
        assert_units_equivalent(m.unitless_con.body, u.dimensionless) # check unitless constraint

        # Comparing a length to a mass breaks the model and that constraint,
        # but components that do not reference it must still pass.
        m.broken = Constraint(expr = m.dy == 42.0*u.kg)
        with self.assertRaises(UnitsError):
            assert_units_consistent(m)
        assert_units_consistent(m.dx)
        assert_units_consistent(m.vx_con)
        with self.assertRaises(UnitsError):
            assert_units_consistent(m.broken)

        self.assertTrue(check_units_equivalent(m.dx, m.dy))
        self.assertFalse(check_units_equivalent(m.dx, m.vx))

    def test_assert_units_consistent_on_datas(self):
        """The checks accept indexed components and their individual data."""
        u = units
        m = ConcreteModel()
        m.S = Set(initialize=[1,2,3])
        m.x = Var(m.S, units=u.m)
        m.t = Var(m.S, units=u.s)
        m.v = Var(m.S, units=u.m/u.s)
        m.unitless = Var(m.S)

        @m.Constraint(m.S)
        def vel_con(m,i):
            return m.v[i] == m.x[i]/m.t[i]
        @m.Constraint(m.S)
        def unitless_con(m,i):
            return m.unitless[i] == 42.0
        @m.Constraint(m.S)
        def sqrt_con(m,i):
            return sqrt(m.v[i]) == sqrt(m.x[i]/m.t[i])

        assert_units_consistent(m)  # check model
        assert_units_consistent(m.x)  # check var
        assert_units_consistent(m.t)  # check var
        assert_units_consistent(m.v)  # check var
        assert_units_consistent(m.unitless)  # check var
        assert_units_consistent(m.vel_con) # check constraint
        assert_units_consistent(m.unitless_con) # check unitless constraint
        assert_units_consistent(m.sqrt_con) # check constraint with sqrt

        assert_units_consistent(m.x[2])  # check var data
        assert_units_consistent(m.t[2])  # check var data
        assert_units_consistent(m.v[2])  # check var data
        assert_units_consistent(m.unitless[2])  # check var
        assert_units_consistent(m.vel_con[2]) # check constraint data
        assert_units_consistent(m.unitless_con[2]) # check unitless constraint data
        assert_units_consistent(m.sqrt_con[2]) # check sqrt constraint data

        assert_units_equivalent(m.x[2], m.x[1])  # check var data
        assert_units_equivalent(m.t[2], u.s)  # check var data
        assert_units_equivalent(m.v[2], u.m/u.s)  # check var data
        assert_units_equivalent(m.unitless[2], u.dimensionless)  # check var data unitless
        assert_units_equivalent(m.unitless[2], None)  # check var
        assert_units_equivalent(m.vel_con[2].body, u.m/u.s) # check constraint data
        assert_units_equivalent(m.unitless_con[2].body, u.dimensionless) # check unitless constraint data

        # m (length) vs. dimensionless*m/s (velocity): inconsistent in every index.
        @m.Constraint(m.S)
        def broken(m,i):
            return m.x[i] == 42.0*m.v[i]
        with self.assertRaises(UnitsError):
            assert_units_consistent(m)
        with self.assertRaises(UnitsError):
            assert_units_consistent(m.broken)
        with self.assertRaises(UnitsError):
            assert_units_consistent(m.broken[1])

        # all of these should still work
        assert_units_consistent(m.x)  # check var
        assert_units_consistent(m.t)  # check var
        assert_units_consistent(m.v)  # check var
        assert_units_consistent(m.unitless)  # check var
        assert_units_consistent(m.vel_con) # check constraint
        assert_units_consistent(m.unitless_con) # check unitless constraint
        assert_units_consistent(m.sqrt_con) # check constraint with sqrt

        assert_units_consistent(m.x[2])  # check var data
        assert_units_consistent(m.t[2])  # check var data
        assert_units_consistent(m.v[2])  # check var data
        assert_units_consistent(m.unitless[2])  # check var
        assert_units_consistent(m.vel_con[2]) # check constraint data
        assert_units_consistent(m.unitless_con[2]) # check unitless constraint data
        assert_units_consistent(m.sqrt_con[2]) # check sqrt constraint data

    def test_assert_units_consistent_all_components(self):
        """A model using many component types passes ``assert_units_consistent``."""
        # test all scalar components consistent
        u = units
        m = self._create_model_and_vars()
        m.obj = Objective(expr=m.dx/m.t - m.vx)
        m.con = Constraint(expr=m.dx/m.t == m.vx)
        # vars already added
        m.exp = Expression(expr=m.dx/m.t - m.vx)
        m.suff = Suffix(direction=Suffix.LOCAL)
        # params already added
        # sets already added
        m.rs = RangeSet(5)
        m.disj1 = Disjunct()
        m.disj1.constraint = Constraint(expr=m.dx/m.t <= m.vx)
        m.disj2 = Disjunct()
        m.disj2.constraint = Constraint(expr=m.dx/m.t <= m.vx)
        m.disjn = Disjunction(expr=[m.disj1, m.disj2])
        # block tested as part of model
        m.extfn = ExternalFunction(python_callback_function, units=u.m/u.s, arg_units=[u.m, u.s])
        m.conext = Constraint(expr=m.extfn(m.dx, m.t) - m.vx==0)
        m.cset = ContinuousSet(bounds=(0,1))
        m.svar = Var(m.cset, units=u.m)
        m.dvar = DerivativeVar(sVar=m.svar, units=u.m/u.s)
        def prt1_rule(m):
            return {'avar': m.dx}
        def prt2_rule(m):
            return {'avar': m.dy}
        m.prt1 = Port(rule=prt1_rule)
        m.prt2 = Port(rule=prt2_rule)
        def arcrule(m):
            return dict(source=m.prt1, destination=m.prt2)
        m.arc = Arc(rule=arcrule)

        # TODO: complementarities are not covered yet. The expression system
        # drops the u.m because it is multiplied by zero; the units_container
        # must be changed to allow 0 when comparing units before this works:
        # m.compl = Complementarity(expr=complements(m.dx/m.t >= m.vx, m.dx == 0*u.m))

        assert_units_consistent(m)
199
if __name__ == "__main__":
    # Executing this module directly runs its test cases.
    unittest.main()
202