1# This code supports verifying group implementations which have branches
2# or conditional statements (like cmovs), by allowing each execution path
3# to independently set assumptions on input or intermediary variables.
4#
5# The general approach is:
6# * A constraint is a tuple of two sets of symbolic expressions:
7#   the first of which are required to evaluate to zero, the second of which
8#   are required to evaluate to nonzero.
#   - A constraint is said to be conflicting if any of its nonzero expressions
#     is in the ideal generated by the zero expressions (in other words: when the
#     zero expressions imply that one of the nonzero expressions is zero).
12# * There is a list of laws that describe the intended behaviour, including
13#   laws for addition and doubling. Each law is called with the symbolic point
14#   coordinates as arguments, and returns:
15#   - A constraint describing the assumptions under which it is applicable,
16#     called "assumeLaw"
17#   - A constraint describing the requirements of the law, called "require"
# * Implementations are transliterated into functions that also operate on
#   symbolic (algebraic) input points, and are called once per combination of
#   branches executed. Each execution returns:
21#   - A constraint describing the assumptions this implementation requires
22#     (such as Z1=1), called "assumeFormula"
23#   - A constraint describing the assumptions this specific branch requires,
24#     but which is by construction guaranteed to cover the entire space by
25#     merging the results from all branches, called "assumeBranch"
26#   - The result of the computation
27# * All combinations of laws with implementation branches are tried, and:
28#   - If the combination of assumeLaw, assumeFormula, and assumeBranch results
29#     in a conflict, it means this law does not apply to this branch, and it is
30#     skipped.
31#   - For others, we try to prove the require constraints hold, assuming the
32#     information in assumeLaw + assumeFormula + assumeBranch, and if this does
33#     not succeed, we fail.
34#     + To prove an expression is zero, we check whether it belongs to the
35#       ideal with the assumed zero expressions as basis. This test is exact.
36#     + To prove an expression is nonzero, we check whether each of its
37#       factors is contained in the set of nonzero assumptions' factors.
38#       This test is not exact, so various combinations of original and
39#       reduced expressions' factors are tried.
40#   - If we succeed, we print out the assumptions from assumeFormula that
41#     weren't implied by assumeLaw already. Those from assumeBranch are skipped,
42#     as we assume that all constraints in it are complementary with each other.
43#
44# Based on the sage verification scripts used in the Explicit-Formulas Database
45# by Tanja Lange and others, see https://hyperelliptic.org/EFD
46
class fastfrac:
  """Fractions over rings.

  A fastfrac keeps a numerator (top) and a denominator (bot) as two separate
  elements of the ring R and never cancels common factors between them.
  Integers (parent ZZ) and other fastfracs can be combined with it using
  +, - and * on either side of the operator, and / on the right-hand side.
  """

  def __init__(self,R,top,bot=1):
    """Construct a fraction, given a ring, a numerator, and denominator.

    top may be an integer or an element of R (both converted into R), another
    fastfrac (whose denominator is multiplied by bot), or any other value,
    which is split with the numerator()/denominator() functions.
    """
    self.R = R
    if parent(top) == ZZ or parent(top) == R:
      self.top = R(top)
      self.bot = R(bot)
    elif isinstance(top, fastfrac):
      self.top = top.top
      self.bot = top.bot * bot
    else:
      self.top = R(numerator(top))
      self.bot = R(denominator(top)) * bot

  def iszero(self,I):
    """Return whether this fraction is zero given an ideal.

    The numerator must lie in I while the denominator must not.
    """
    return self.top in I and self.bot not in I

  def reduce(self,assumeZero):
    """Return this fraction with numerator and denominator each reduced modulo
    the ideal generated by the numerators of assumeZero."""
    zero = self.R.ideal(list(map(numerator, assumeZero)))
    return fastfrac(self.R, zero.reduce(self.top)) / fastfrac(self.R, zero.reduce(self.bot))

  def __add__(self,other):
    """Add two fractions."""
    if parent(other) == ZZ:
      return fastfrac(self.R,self.top + self.bot * other,self.bot)
    if isinstance(other, fastfrac):
      return fastfrac(self.R,self.top * other.bot + self.bot * other.top,self.bot * other.bot)
    return NotImplemented

  def __radd__(self,other):
    """Add a fraction to something else (addition is commutative)."""
    return self.__add__(other)

  def __sub__(self,other):
    """Subtract two fractions."""
    if parent(other) == ZZ:
      return fastfrac(self.R,self.top - self.bot * other,self.bot)
    if isinstance(other, fastfrac):
      return fastfrac(self.R,self.top * other.bot - self.bot * other.top,self.bot * other.bot)
    return NotImplemented

  def __rsub__(self,other):
    """Subtract a fraction from an integer, i.e. compute other - self."""
    if parent(other) == ZZ:
      return fastfrac(self.R,self.bot * other - self.top,self.bot)
    return NotImplemented

  def __neg__(self):
    """Return the negation of a fraction."""
    return fastfrac(self.R,-self.top,self.bot)

  def __mul__(self,other):
    """Multiply two fractions."""
    if parent(other) == ZZ:
      return fastfrac(self.R,self.top * other,self.bot)
    if isinstance(other, fastfrac):
      return fastfrac(self.R,self.top * other.top,self.bot * other.bot)
    return NotImplemented

  def __rmul__(self,other):
    """Multiply something else with a fraction."""
    return self.__mul__(other)

  def __truediv__(self,other):
    """Divide two fractions."""
    if parent(other) == ZZ:
      return fastfrac(self.R,self.top,self.bot * other)
    if isinstance(other, fastfrac):
      return fastfrac(self.R,self.top * other.bot,self.bot * other.top)
    return NotImplemented

  # Compatibility wrapper for Sage versions based on Python 2
  def __div__(self,other):
    """Divide two fractions."""
    return self.__truediv__(other)

  def __pow__(self,other):
    """Compute an integer power of a fraction."""
    if parent(other) == ZZ:
      if other < 0:
        # Negative powers require flipping top and bottom
        return fastfrac(self.R,self.bot ** (-other),self.top ** (-other))
      else:
        return fastfrac(self.R,self.top ** other,self.bot ** other)
    return NotImplemented

  def __str__(self):
    return "fastfrac((" + str(self.top) + ") / (" + str(self.bot) + "))"

  def __repr__(self):
    return "%s" % self

  def numerator(self):
    """Return the numerator.

    The global numerator() function is applied to fastfrac objects elsewhere
    in this file (e.g. map(numerator, assume.zero)); presumably it dispatches
    here.
    """
    return self.top
133
class constraints:
  """A set of constraints, consisting of zero and nonzero expressions.

  Constraints can either be used to express knowledge or a requirement.

  Both the fields zero and nonzero are maps from expressions to description
  strings. The expressions that are the keys in zero are required to be zero,
  and the expressions that are the keys in nonzero are required to be nonzero.

  Note that (a != 0) and (b != 0) is the same as (a*b != 0), so all keys in
  nonzero could be multiplied into a single key. This is often much less
  efficient to work with though, so we keep them separate inside the
  constraints. This allows higher-level code to do fast checks on the individual
  nonzero elements, or combine them if needed for stronger checks.

  We can't multiply the different zero elements, as it would suffice for one of
  the factors to be zero, instead of all of them. Instead, the zero elements are
  typically combined into an ideal first.
  """

  def __init__(self, zero=None, nonzero=None):
    """Create constraints from optional zero/nonzero mappings.

    Both mappings are copied, so later changes to the caller's dicts do not
    leak in. Parameters are named explicitly so that a misspelled keyword
    raises TypeError instead of silently dropping constraints.
    """
    self.zero = dict(zero) if zero is not None else dict()
    self.nonzero = dict(nonzero) if nonzero is not None else dict()

  def negate(self):
    """Return new constraints with the zero and nonzero sets swapped."""
    return constraints(zero=self.nonzero, nonzero=self.zero)

  def __add__(self, other):
    """Return the union of two constraint sets.

    On duplicate keys, the description from other wins.
    """
    zero = self.zero.copy()
    zero.update(other.zero)
    nonzero = self.nonzero.copy()
    nonzero.update(other.nonzero)
    return constraints(zero=zero, nonzero=nonzero)

  def __str__(self):
    return "constraints(zero=%s,nonzero=%s)" % (self.zero, self.nonzero)

  def __repr__(self):
    return "%s" % self
179
180
def conflicts(R, con):
  """Return True when the constraints in con contradict each other.

  The zero expressions of con generate an ideal over R. A conflict exists when
  that ideal contains 1, when some single nonzero expression is zero modulo it
  (per fastfrac.iszero), or when the product of all nonzero expressions is.
  """
  ideal = R.ideal(list(map(numerator, con.zero)))
  # The zero assumptions on their own are already inconsistent.
  if 1 in ideal:
    return True
  # Cheap test first: one nonzero expression forced to zero by itself.
  if any(expr.iszero(ideal) for expr in con.nonzero):
    return True
  # Entries can be individually consistent yet jointly impossible: knowing
  # x*y = 0 together with x != 0 and y != 0 is only caught by the product.
  product = reduce(lambda acc, expr: acc * expr, con.nonzero, fastfrac(R, 1))
  return product.iszero(ideal)
199
200
def get_nonzero_set(R, assume):
  """Calculate a simple set of nonzero expressions.

  Returns the set of irreducible factors of the numerator of every nonzero
  assumption in `assume`. It collects factors of the numerator as given and
  of the numerator reduced modulo the ideal generated by the zero assumptions.
  """
  zero = R.ideal(list(map(numerator, assume.zero)))
  nonzero = set()
  for nz in map(numerator, assume.nonzero):
    for poly in (nz, zero.reduce(nz)):
      # Factoring the zero polynomial is undefined (Sage raises), and a zero
      # polynomial contributes no usable nonzero factors anyway.
      if poly != 0:
        nonzero.update(f for (f, _) in poly.factor())
  return nonzero
212
213
def prove_nonzero(R, exprs, assume):
  """Check whether all passed expressions are provably nonzero, given assumptions.

  exprs maps fastfrac expressions to description strings; assume is a
  constraints object. Returns (True, None) on success, or (False, list of
  descriptions) naming the expressions that could not be proven nonzero.

  A polynomial counts as provably nonzero when each of its irreducible factors
  appears in get_nonzero_set(R, assume). That test is sufficient but not exact,
  so several variants (product vs. individual numerators, original vs.
  reduced modulo the zero ideal) are tried in turn.
  """
  zero = R.ideal(list(map(numerator, assume.zero)))
  nonzero = get_nonzero_set(R, assume)

  def all_factors_known(poly):
    # True when every irreducible factor of poly is a known-nonzero factor.
    return all(f in nonzero for (f, _) in poly.factor())

  # An expression lying in the zero ideal is zero under the assumptions.
  for expr in exprs:
    if numerator(expr) in zero:
      return (False, [exprs[expr]])

  # 1. Product of all numerators, as given.
  allexprs = reduce(lambda a,b: numerator(a)*numerator(b), exprs, 1)
  if all_factors_known(allexprs):
    return (True, None)

  # 2. The same product reduced modulo the zero ideal. If it reduces to 0 the
  #    product vanishes wherever the assumptions hold, so not all expressions
  #    can be nonzero (and 0 cannot be factored).
  rallexprs = zero.reduce(numerator(allexprs))
  if rallexprs == 0:
    return (False, list(set(exprs.values())))
  if all_factors_known(rallexprs):
    return (True, None)

  # 3. Each numerator individually, as given.
  if all(all_factors_known(numerator(expr)) for expr in exprs):
    return (True, None)

  # 4. Each numerator individually, reduced. These reductions are nonzero
  #    because membership in the zero ideal was excluded above.
  expl = set()
  for expr in exprs:
    if not all_factors_known(zero.reduce(numerator(expr))):
      expl.add(exprs[expr])
  if expl:
    return (False, list(expl))
  return (True, None)
251
252
def prove_zero(R, exprs, assume):
  """Check whether all of the passed expressions are provably zero, given assumptions.

  exprs maps fastfrac expressions to description strings; assume is a
  constraints object. Returns (True, None) on success, or (False, list of
  strings) describing what could not be proven.
  """
  # A fraction is only well-defined if its denominator is provably nonzero.
  denominators = dict((fastfrac(R, x.bot, 1), desc) for x, desc in exprs.items())
  r, e = prove_nonzero(R, denominators, assume)
  if not r:
    # Build a real list: a lazy map object would print as "<map object ...>"
    # when the caller formats it into a message.
    return (False, ["Possibly zero denominator: %s" % x for x in e])
  zero = R.ideal(list(map(numerator, assume.zero)))
  expl = [desc for expr, desc in exprs.items() if not expr.iszero(zero)]
  if not expl:
    return (True, None)
  return (False, expl)
267
268
def describe_extra(R, assume, assumeExtra):
  """Describe what assumptions are added, given existing assumptions.

  assume and assumeExtra are constraints objects. Returns a comma-separated,
  sorted string of the zero/nonzero conditions from assumeExtra that are not
  already implied by assume (empty string if none).
  """
  zerox = assume.zero.copy()
  zerox.update(assumeExtra.zero)
  zero = R.ideal(list(map(numerator, assume.zero)))
  zeroextra = R.ideal(list(map(numerator, zerox)))
  nonzero = get_nonzero_set(R, assume)
  ret = set()
  # Iterate over the extra zero expressions
  for base, desc in assumeExtra.zero.items():
    # Test the numerator polynomial: the fastfrac object itself is not an
    # element of R, so testing it directly against the ideal does not check
    # polynomial membership. This also skips zero numerators, which cannot
    # be factored.
    top = numerator(base)
    if top in zero:
      continue
    add = ["%s" % f for (f, _) in top.factor() if f not in nonzero]
    if add:
      ret.add(" * ".join(add) + " = 0 [%s]" % desc)
  # Iterate over the extra nonzero expressions
  for nz in assumeExtra.nonzero:
    nzr = zeroextra.reduce(numerator(nz))
    if nzr not in zeroextra:
      for (f, _) in nzr.factor():
        fr = zeroextra.reduce(f)
        if fr not in nonzero:
          ret.add("%s != 0" % fr)
  # Sort so the report is deterministic across runs (set order is not).
  return ", ".join(sorted(ret))
294
295
def check_symbolic(R, assumeLaw, assumeAssert, assumeBranch, require):
  """Verify one law against one implementation branch.

  All assumptions are merged. If they conflict, the law does not apply and
  None is returned. Otherwise the zero requirements are checked first, then
  the nonzero ones. The result is a string starting with "OK" or "FAIL",
  annotated with the assumeAssert conditions not already implied by
  assumeLaw + assumeBranch.
  """
  assume = assumeLaw + assumeAssert + assumeBranch

  # Contradictory assumptions: this formula does not apply.
  if conflicts(R, assume):
    return None

  describe = describe_extra(R, assumeLaw + assumeBranch, assumeAssert)
  failure = "FAIL, %s fails (assuming %s)"

  zero_ok, zero_msg = prove_zero(R, require.zero, assume)
  if not zero_ok:
    return failure % (str(zero_msg), describe)

  nonzero_ok, nonzero_msg = prove_nonzero(R, require.nonzero, assume)
  if not nonzero_ok:
    return failure % (str(nonzero_msg), describe)

  return ("OK (assuming %s)" % describe) if describe else "OK"
318
319
def concrete_verify(c):
  """Check a constraints object whose keys are concrete values.

  Returns (True, None) when every key of c.zero equals 0 and every key of
  c.nonzero differs from 0. Otherwise returns (False, description) for the
  first violated entry; zero entries are examined before nonzero ones.
  """
  for value, description in c.zero.items():
    if value != 0:
      return (False, description)
  for value, description in c.nonzero.items():
    if value == 0:
      return (False, description)
  return (True, None)
328