1/*
2 * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24
25/*
 * This file is used to generate optimized finite field implementations.
27 * Required settings are included in the file. To generate, use jshell:
28 * jshell < FieldGen.jsh
29 */
30
31import java.io.*;
32import java.math.BigInteger;
33import java.nio.file.Files;
34import java.nio.file.Paths;
35import java.util.*;
36
37public class FieldGen {
38
39    static FieldParams Curve25519 = new FieldParams("IntegerPolynomial25519", 26, 10, 1, 255,
40    Arrays.asList(
41    new Term(0, -19)
42    ),
43    Curve25519CrSequence(), simpleSmallCrSequence(10)
44    );
45
46    private static List<CarryReduce> Curve25519CrSequence() {
47        List<CarryReduce> result = new ArrayList<CarryReduce>();
48
49        // reduce(7,2)
50        result.add(new Reduce(17));
51        result.add(new Reduce(18));
52
53        // carry(8,2)
54        result.add(new Carry(8));
55        result.add(new Carry(9));
56
57        // reduce(0,7)
58        for (int i = 10; i < 17; i++) {
59            result.add(new Reduce(i));
60        }
61
62        // carry(0,9)
63        result.addAll(fullCarry(10));
64
65        return result;
66    }
67
68    static FieldParams Curve448 = new FieldParams("IntegerPolynomial448", 28, 16, 1, 448,
69    Arrays.asList(
70    new Term(224, -1),
71    new Term(0, -1)
72    ),
73    Curve448CrSequence(), simpleSmallCrSequence(16)
74    );
75
76    private static List<CarryReduce> Curve448CrSequence() {
77        List<CarryReduce> result = new ArrayList<CarryReduce>();
78
79        // reduce(8, 7)
80        for (int i = 24; i < 31; i++) {
81            result.add(new Reduce(i));
82        }
83        // reduce(4, 4)
84        for (int i = 20; i < 24; i++) {
85            result.add(new Reduce(i));
86        }
87
88        //carry(14, 2)
89        result.add(new Carry(14));
90        result.add(new Carry(15));
91
92        // reduce(0, 4)
93        for (int i = 16; i < 20; i++) {
94            result.add(new Reduce(i));
95        }
96
97        // carry(0, 15)
98        result.addAll(fullCarry(16));
99
100        return result;
101    }
102
103    static FieldParams P256 = new FieldParams("IntegerPolynomialP256", 26, 10, 2, 256,
104    Arrays.asList(
105    new Term(224, -1),
106    new Term(192, 1),
107    new Term(96, 1),
108    new Term(0, -1)
109    ),
110    P256CrSequence(), simpleSmallCrSequence(10)
111    );
112
113    private static List<CarryReduce> P256CrSequence() {
114        List<CarryReduce> result = new ArrayList<CarryReduce>();
115        result.addAll(fullReduce(10));
116        result.addAll(simpleSmallCrSequence(10));
117        return result;
118    }
119
120    static FieldParams P384 = new FieldParams("IntegerPolynomialP384", 28, 14, 2, 384,
121    Arrays.asList(
122    new Term(128, -1),
123    new Term(96, -1),
124    new Term(32, 1),
125    new Term(0, -1)
126    ),
127    P384CrSequence(), simpleSmallCrSequence(14)
128    );
129
130    private static List<CarryReduce> P384CrSequence() {
131        List<CarryReduce> result = new ArrayList<CarryReduce>();
132        result.addAll(fullReduce(14));
133        result.addAll(simpleSmallCrSequence(14));
134        return result;
135    }
136
137    static FieldParams P521 = new FieldParams("IntegerPolynomialP521", 28, 19, 2, 521,
138    Arrays.asList(new Term(0, -1)), P521CrSequence(), simpleSmallCrSequence(19)
139    );
140
141    private static List<CarryReduce> P521CrSequence() {
142        List<CarryReduce> result = new ArrayList<CarryReduce>();
143        result.addAll(fullReduce(19));
144        result.addAll(simpleSmallCrSequence(19));
145        return result;
146    }
147
148    static FieldParams O256 = new FieldParams("P256OrderField", 26, 10, 1, 256,
149    new BigInteger("26959946660873538059280334323273029441504803697035324946844617595567"),
150    orderFieldCrSequence(10), orderFieldSmallCrSequence(10)
151    );
152
153    static FieldParams O384 = new FieldParams("P384OrderField", 28, 14, 1, 384,
154    new BigInteger("1388124618062372383947042015309946732620727252194336364173"),
155    orderFieldCrSequence(14), orderFieldSmallCrSequence(14)
156    );
157
158    static FieldParams O521 = new FieldParams("P521OrderField", 28, 19, 1, 521,
159    new BigInteger("657877501894328237357444332315020117536923257219387276263472201219398408051703"),
160    o521crSequence(19), orderFieldSmallCrSequence(19)
161    );
162
163    private static List<CarryReduce> o521crSequence(int numLimbs) {
164
165        // split the full reduce in half, with a carry in between
166        List<CarryReduce> result = new ArrayList<CarryReduce>();
167        result.addAll(fullCarry(2 * numLimbs));
168        for (int i = 2 * numLimbs - 1; i >= numLimbs + numLimbs/2; i--) {
169            result.add(new Reduce(i));
170        }
171        // carry
172        for (int i = numLimbs; i < numLimbs + numLimbs / 2 - 1; i++) {
173            result.add(new Carry(i));
174        }
175        // rest of reduce
176        for (int i = numLimbs + numLimbs/2 - 1; i >= numLimbs; i--) {
177            result.add(new Reduce(i));
178        }
179        result.addAll(orderFieldSmallCrSequence(numLimbs));
180
181        return result;
182    }
183
184    private static List<CarryReduce> orderFieldCrSequence(int numLimbs) {
185        List<CarryReduce> result = new ArrayList<CarryReduce>();
186        result.addAll(fullCarry(2 * numLimbs));
187        result.add(new Reduce(2 * numLimbs - 1));
188        result.addAll(fullReduce(numLimbs));
189        result.addAll(fullCarry(numLimbs + 1));
190        result.add(new Reduce(numLimbs));
191        result.addAll(fullCarry(numLimbs));
192
193        return result;
194    }
195   private static List<CarryReduce> orderFieldSmallCrSequence(int numLimbs) {
196        List<CarryReduce> result = new ArrayList<CarryReduce>();
197        result.addAll(fullCarry(numLimbs + 1));
198        result.add(new Reduce(numLimbs));
199        result.addAll(fullCarry(numLimbs));
200        return result;
201    }
202
203    static final FieldParams[] ALL_FIELDS = {P256, P384, P521, O256, O384, O521};
204
205    public static class Term {
206        private final int power;
207        private final int coefficient;
208
209        public Term(int power, int coefficient) {
210            this.power = power;
211            this.coefficient = coefficient;
212        }
213
214        public int getPower() {
215            return power;
216        }
217
218        public int getCoefficient() {
219            return coefficient;
220        }
221
222        public BigInteger getValue() {
223            return BigInteger.valueOf(2).pow(power).multiply(BigInteger.valueOf(coefficient));
224        }
225
226    }
227
228    static abstract class CarryReduce {
229        private final int index;
230
231        protected CarryReduce(int index) {
232            this.index = index;
233        }
234
235        public int getIndex() {
236            return index;
237        }
238
239        public abstract void write(CodeBuffer out, FieldParams params, String prefix, Iterable<CarryReduce> remaining);
240    }
241
242    static class Carry extends CarryReduce {
243        public Carry(int index) {
244            super(index);
245        }
246
247        public void write(CodeBuffer out, FieldParams params, String prefix, Iterable<CarryReduce> remaining) {
248            carry(out, params, prefix, getIndex());
249        }
250    }
251
252    static class Reduce extends CarryReduce {
253        public Reduce(int index) {
254            super(index);
255        }
256
257        public void write(CodeBuffer out, FieldParams params, String prefix, Iterable<CarryReduce> remaining) {
258            reduce(out, params, prefix, getIndex(), remaining);
259        }
260    }
261
262    static class FieldParams {
263        private final String className;
264        private final int bitsPerLimb;
265        private final int numLimbs;
266        private final int maxAdds;
267        private final int power;
268        private final Iterable<Term> terms;
269        private final List<CarryReduce> crSequence;
270        private final List<CarryReduce> smallCrSequence;
271
272        public FieldParams(String className, int bitsPerLimb, int numLimbs, int maxAdds, int power,
273                           Iterable<Term> terms, List<CarryReduce> crSequence, List<CarryReduce> smallCrSequence) {
274            this.className = className;
275            this.bitsPerLimb = bitsPerLimb;
276            this.numLimbs = numLimbs;
277            this.maxAdds = maxAdds;
278            this.power = power;
279            this.terms = terms;
280            this.crSequence = crSequence;
281            this.smallCrSequence = smallCrSequence;
282        }
283
284        public FieldParams(String className, int bitsPerLimb, int numLimbs, int maxAdds, int power,
285                           BigInteger term, List<CarryReduce> crSequence, List<CarryReduce> smallCrSequence) {
286            this.className = className;
287            this.bitsPerLimb = bitsPerLimb;
288            this.numLimbs = numLimbs;
289            this.maxAdds = maxAdds;
290            this.power = power;
291            this.crSequence = crSequence;
292            this.smallCrSequence = smallCrSequence;
293
294            terms = buildTerms(term);
295        }
296
297        private Iterable<Term> buildTerms(BigInteger sub) {
298            // split a large subtrahend into smaller terms that are aligned with limbs
299            List<Term> result = new ArrayList<Term>();
300            BigInteger mod = BigInteger.valueOf(1 << bitsPerLimb);
301            int termIndex = 0;
302            while (!sub.equals(BigInteger.ZERO)) {
303                int coef = sub.mod(mod).intValue();
304                boolean plusOne = false;
305                if (coef > (1 << (bitsPerLimb - 1))) {
306                    coef = coef - (1 << bitsPerLimb);
307                    plusOne = true;
308                }
309                if (coef != 0) {
310                    int pow = termIndex * bitsPerLimb;
311                    result.add(new Term(pow, -coef));
312                }
313                sub = sub.shiftRight(bitsPerLimb);
314                if (plusOne) {
315                   sub = sub.add(BigInteger.ONE);
316                }
317                ++termIndex;
318            }
319            return result;
320        }
321
322        public String getClassName() {
323            return className;
324        }
325
326        public int getBitsPerLimb() {
327            return bitsPerLimb;
328        }
329
330        public int getNumLimbs() {
331            return numLimbs;
332        }
333
334        public int getMaxAdds() {
335            return maxAdds;
336        }
337
338        public int getPower() {
339            return power;
340        }
341
342        public Iterable<Term> getTerms() {
343            return terms;
344        }
345
346        public List<CarryReduce> getCrSequence() {
347            return crSequence;
348        }
349
350        public List<CarryReduce> getSmallCrSequence() {
351            return smallCrSequence;
352        }
353    }
354
355    static Collection<Carry> fullCarry(int numLimbs) {
356        List<Carry> result = new ArrayList<Carry>();
357        for (int i = 0; i < numLimbs - 1; i++) {
358            result.add(new Carry(i));
359        }
360        return result;
361    }
362
363    static Collection<Reduce> fullReduce(int numLimbs) {
364        List<Reduce> result = new ArrayList<Reduce>();
365        for (int i = numLimbs - 2; i >= 0; i--) {
366            result.add(new Reduce(i + numLimbs));
367        }
368        return result;
369    }
370
371    static List<CarryReduce> simpleCrSequence(int numLimbs) {
372        List<CarryReduce> result = new ArrayList<CarryReduce>();
373        for(int i = 0; i < 4; i++) {
374            result.addAll(fullCarry(2 * numLimbs - 1));
375            result.addAll(fullReduce(numLimbs));
376        }
377
378        return result;
379    }
380
381    static List<CarryReduce> simpleSmallCrSequence(int numLimbs) {
382        List<CarryReduce> result = new ArrayList<CarryReduce>();
383        // carry a few positions at the end
384        for (int i = numLimbs - 2; i < numLimbs; i++) {
385            result.add(new Carry(i));
386        }
387        // this carries out a single value that must be reduced back in
388        result.add(new Reduce(numLimbs));
389        // finish with a full carry
390        result.addAll(fullCarry(numLimbs));
391        return result;
392    }
393
394    private final String packageName;
395    private final String parentName;
396
397    public FieldGen(String packageName, String parentName) {
398        this.packageName = packageName;
399        this.parentName = parentName;
400    }
401
402    public static void main(String[] args) throws Exception {
403
404        FieldGen gen = new FieldGen("sun.security.util.math.intpoly", "IntegerPolynomial");
405        for(FieldParams p : ALL_FIELDS) {
406            gen.generateFile(p);
407        }
408    }
409
410    private void generateFile(FieldParams params) throws IOException {
411        String text = generate(params);
412        String fileName = params.getClassName() + ".java";
413        PrintWriter out = new PrintWriter(new FileWriter(fileName));
414        out.println(text);
415        out.close();
416    }
417
418    static class CodeBuffer {
419
420        private int nextTemporary = 0;
421        private Set<String> temporaries = new HashSet<String>();
422        private StringBuffer buffer = new StringBuffer();
423        private int indent = 0;
424        private Class lastCR;
425        private int lastCrCount = 0;
426        private int crMethodBreakCount = 0;
427        private int crNumLimbs = 0;
428
429        public void incrIndent() {
430            indent++;
431        }
432
433        public void decrIndent() {
434            indent--;
435        }
436
437        public void newTempScope() {
438            nextTemporary = 0;
439            temporaries.clear();
440        }
441
442        public void appendLine(String s) {
443            appendIndent();
444            buffer.append(s + "\n");
445        }
446
447        public void appendLine() {
448            buffer.append("\n");
449        }
450
451        public String toString() {
452            return buffer.toString();
453        }
454
455        public void startCrSequence(int numLimbs) {
456            this.crNumLimbs = numLimbs;
457            lastCrCount = 0;
458            crMethodBreakCount = 0;
459            lastCR = null;
460        }
461        /*
462         * Record a carry/reduce of the specified type. This method is used to
463         * break up large carry/reduce sequences into multiple methods to make
464         * JIT/optimization easier
465         */
466        public void record(Class type) {
467            if (type == lastCR) {
468                lastCrCount++;
469            } else {
470
471                if (lastCrCount >= 8) {
472                    insertCrMethodBreak();
473                }
474
475                lastCR = type;
476                lastCrCount = 0;
477            }
478        }
479
480        private void insertCrMethodBreak() {
481
482            appendLine();
483
484            // call the new method
485            appendIndent();
486            append("carryReduce" + crMethodBreakCount + "(r");
487            for(int i = 0; i < crNumLimbs; i++) {
488                append(", c" + i);
489            }
490            // temporaries are not live between operations, no need to send
491            append(");\n");
492
493            decrIndent();
494            appendLine("}");
495
496            // make the method
497            appendIndent();
498            append("void carryReduce" + crMethodBreakCount + "(long[] r");
499            for(int i = 0; i < crNumLimbs; i++) {
500                append (", long c" + i);
501            }
502            append(") {\n");
503            incrIndent();
504            // declare temporaries
505            for(String temp : temporaries) {
506                appendLine("long " + temp + ";");
507            }
508            append("\n");
509
510            crMethodBreakCount++;
511        }
512
513        public String getTemporary(String type, String value) {
514            Iterator<String> iter = temporaries.iterator();
515            if(iter.hasNext()) {
516                String result = iter.next();
517                iter.remove();
518                appendLine(result + " = " + value + ";");
519                return result;
520            } else {
521                String result = "t" + (nextTemporary++);
522                appendLine(type + " " + result + " = " + value + ";");
523                return result;
524            }
525        }
526
527        public void freeTemporary(String temp) {
528            temporaries.add(temp);
529        }
530
531        public void appendIndent() {
532            for(int i = 0; i < indent; i++) {
533                buffer.append("    ");
534            }
535        }
536
537        public void append(String s) {
538            buffer.append(s);
539        }
540    }
541
542    private String generate(FieldParams params) throws IOException {
543        CodeBuffer result = new CodeBuffer();
544        String header = readHeader();
545        result.appendLine(header);
546
547        if (packageName != null) {
548            result.appendLine("package " + packageName + ";");
549            result.appendLine();
550        }
551        result.appendLine("import java.math.BigInteger;");
552
553        result.appendLine("public class " + params.getClassName() + " extends " + this.parentName + " {");
554        result.incrIndent();
555
556        result.appendLine("private static final int BITS_PER_LIMB = " + params.getBitsPerLimb() + ";");
557        result.appendLine("private static final int NUM_LIMBS = " + params.getNumLimbs() + ";");
558        result.appendLine("private static final int MAX_ADDS = " + params.getMaxAdds() + ";");
559        result.appendLine("public static final BigInteger MODULUS = evaluateModulus();");
560        result.appendLine("private static final long CARRY_ADD = 1 << " + (params.getBitsPerLimb() - 1) + ";");
561        if (params.getBitsPerLimb() * params.getNumLimbs() != params.getPower()) {
562            result.appendLine("private static final int LIMB_MASK = -1 >>> (64 - BITS_PER_LIMB);");
563        }
564        int termIndex = 0;
565
566        result.appendLine("public " + params.getClassName() + "() {");
567        result.appendLine();
568        result.appendLine("    super(BITS_PER_LIMB, NUM_LIMBS, MAX_ADDS, MODULUS);");
569        result.appendLine();
570        result.appendLine("}");
571
572        result.appendLine("private static BigInteger evaluateModulus() {");
573        result.incrIndent();
574        result.appendLine("BigInteger result = BigInteger.valueOf(2).pow(" + params.getPower() + ");");
575        for(Term t : params.getTerms()) {
576            boolean subtract = false;
577            int coefValue = t.getCoefficient();
578            if (coefValue < 0) {
579                coefValue = 0 - coefValue;
580                subtract = true;
581            }
582            String coefExpr = "BigInteger.valueOf(" + coefValue + ")";
583            String powExpr = "BigInteger.valueOf(2).pow(" + t.getPower() + ")";
584            String termExpr = "ERROR";
585            if (t.getPower() == 0) {
586                termExpr = coefExpr;
587            } else if (coefValue == 1) {
588                termExpr = powExpr;
589            } else {
590                termExpr = powExpr + ".multiply(" + coefExpr + ")";
591            }
592            if (subtract) {
593                result.appendLine("result = result.subtract(" + termExpr + ");");
594            } else {
595                result.appendLine("result = result.add(" + termExpr + ");");
596            }
597        }
598        result.appendLine("return result;");
599        result.decrIndent();
600        result.appendLine("}");
601
602        result.appendLine("@Override");
603        result.appendLine("protected void finalCarryReduceLast(long[] limbs) {");
604        result.incrIndent();
605        int extraBits = params.getBitsPerLimb() * params.getNumLimbs() - params.getPower();
606        int highBits = params.getBitsPerLimb() - extraBits;
607        result.appendLine("long c = limbs[" + (params.getNumLimbs() - 1) + "] >> " + highBits + ";");
608        result.appendLine("limbs[" + (params.getNumLimbs() - 1) + "] -= c << " + highBits + ";");
609        for (Term t : params.getTerms()) {
610            int reduceBits = params.getPower() + extraBits - t.getPower();
611            int negatedCoefficient = -1 * t.getCoefficient();
612            modReduceInBits(result, params, true, "limbs", params.getNumLimbs(), reduceBits, negatedCoefficient, "c");
613        }
614        result.decrIndent();
615        result.appendLine("}");
616
617        // full carry/reduce sequence
618        result.appendIndent();
619        result.append("private void carryReduce(long[] r, ");
620        for(int i = 0; i < 2 * params.getNumLimbs() - 1; i++) {
621            result.append ("long c" + i);
622            if (i < 2 * params.getNumLimbs() - 2) {
623                result.append(", ");
624            }
625        }
626        result.append(") {\n");
627        result.newTempScope();
628        result.incrIndent();
629        result.appendLine("long c" + (2 * params.getNumLimbs() - 1) + " = 0;");
630        write(result, params.getCrSequence(), params, "c", 2 * params.getNumLimbs());
631        result.appendLine();
632        for (int i = 0; i < params.getNumLimbs(); i++) {
633            result.appendLine("r[" + i + "] = c" + i + ";");
634        }
635        result.decrIndent();
636        result.appendLine("}");
637
638        // small carry/reduce sequence
639        result.appendIndent();
640        result.append("private void carryReduce(long[] r, ");
641        for(int i = 0; i < params.getNumLimbs(); i++) {
642            result.append ("long c" + i);
643            if (i < params.getNumLimbs() - 1) {
644                result.append(", ");
645            }
646        }
647        result.append(") {\n");
648        result.newTempScope();
649        result.incrIndent();
650        result.appendLine("long c" + params.getNumLimbs() + " = 0;");
651        write(result, params.getSmallCrSequence(), params, "c", params.getNumLimbs() + 1);
652        result.appendLine();
653        for (int i = 0; i < params.getNumLimbs(); i++) {
654            result.appendLine("r[" + i + "] = c" + i + ";");
655        }
656        result.decrIndent();
657        result.appendLine("}");
658
659        result.appendLine("@Override");
660        result.appendLine("protected void mult(long[] a, long[] b, long[] r) {");
661        result.incrIndent();
662        for (int i = 0; i < 2 * params.getNumLimbs() - 1; i++) {
663            result.appendIndent();
664            result.append("long c" + i + " = ");
665            int startJ = Math.max(i + 1 - params.getNumLimbs(), 0);
666            int endJ = Math.min(params.getNumLimbs(), i + 1);
667            for (int j = startJ; j < endJ; j++) {
668                int bIndex = i - j;
669                result.append("(a[" + j + "] * b[" + bIndex + "])");
670                if (j < endJ - 1) {
671                    result.append(" + ");
672                }
673            }
674            result.append(";\n");
675        }
676        result.appendLine();
677        result.appendIndent();
678        result.append("carryReduce(r, ");
679        for (int i = 0; i < 2 * params.getNumLimbs() - 1; i++) {
680            result.append("c" + i);
681            if (i < 2 * params.getNumLimbs() - 2) {
682                result.append(", ");
683            }
684        }
685        result.append(");\n");
686        result.decrIndent();
687        result.appendLine("}");
688
689        result.appendLine("@Override");
690        result.appendLine("protected void reduce(long[] a) {");
691        result.incrIndent();
692        result.appendIndent();
693        result.append("carryReduce(a, ");
694        for (int i = 0; i < params.getNumLimbs(); i++) {
695            result.append("a[" + i + "]");
696            if (i < params.getNumLimbs() - 1) {
697                result.append(", ");
698            }
699        }
700        result.append(");\n");
701        result.decrIndent();
702        result.appendLine("}");
703
704        result.appendLine("@Override");
705        result.appendLine("protected void square(long[] a, long[] r) {");
706        result.incrIndent();
707        for (int i = 0; i < 2 * params.getNumLimbs() - 1; i++) {
708            result.appendIndent();
709            result.append("long c" + i + " = ");
710            int startJ = Math.max(i + 1 - params.getNumLimbs(), 0);
711            int endJ = Math.min(params.getNumLimbs(), i + 1);
712            int jDiff = endJ - startJ;
713            if (jDiff > 1) {
714                result.append("2 * (");
715            }
716            for (int j = 0; j < jDiff / 2; j++) {
717                int aIndex = j + startJ;
718                int bIndex = i - aIndex;
719                result.append("(a[" + aIndex + "] * a[" + bIndex + "])");
720                if (j < (jDiff / 2) - 1) {
721                    result.append(" + ");
722                }
723            }
724            if (jDiff > 1) {
725                result.append(")");
726            }
727            if (jDiff % 2 == 1) {
728                int aIndex = i / 2;
729                if (jDiff > 1) {
730                    result.append (" + ");
731                }
732                result.append("(a[" + aIndex + "] * a[" + aIndex + "])");
733            }
734            result.append(";\n");
735        }
736        result.appendLine();
737        result.appendIndent();
738        result.append("carryReduce(r, ");
739        for (int i = 0; i < 2 * params.getNumLimbs() - 1; i++) {
740            result.append("c" + i);
741            if (i < 2 * params.getNumLimbs() - 2) {
742                result.append(", ");
743            }
744        }
745        result.append(");\n");
746        result.decrIndent();
747        result.appendLine("}");
748
749        result.decrIndent();
750        result.appendLine("}"); // end class
751
752        return result.toString();
753    }
754
755    private static void write(CodeBuffer out, List<CarryReduce> sequence, FieldParams params, String prefix, int numLimbs) {
756
757        out.startCrSequence(numLimbs);
758        for (int i = 0; i < sequence.size(); i++) {
759            CarryReduce cr = sequence.get(i);
760            Iterator<CarryReduce> remainingIter = sequence.listIterator(i + 1);
761            List<CarryReduce> remaining = new ArrayList<CarryReduce>();
762            remainingIter.forEachRemaining(remaining::add);
763            cr.write(out, params, prefix, remaining);
764        }
765    }
766
767    private static void reduce(CodeBuffer out, FieldParams params, String prefix, int index, Iterable<CarryReduce> remaining) {
768
769        out.record(Reduce.class);
770
771        out.appendLine("//reduce from position " + index);
772        String reduceFrom = indexedExpr(false, prefix, index);
773        boolean referenced = false;
774        for (CarryReduce cr : remaining) {
775            if(cr.index == index) {
776                referenced = true;
777            }
778        }
779        for (Term t : params.getTerms()) {
780            int reduceBits = params.getPower() - t.getPower();
781            int negatedCoefficient = -1 * t.getCoefficient();
782            modReduceInBits(out, params, false, prefix, index, reduceBits, negatedCoefficient, reduceFrom);
783        }
784        if (referenced) {
785            out.appendLine(reduceFrom + " = 0;");
786        }
787    }
788
789    private static void carry(CodeBuffer out, FieldParams params, String prefix, int index) {
790
791        out.record(Carry.class);
792
793        out.appendLine("//carry from position " + index);
794        String carryFrom = prefix + index;
795        String carryTo = prefix + (index + 1);
796        String carry = "(" + carryFrom + " + CARRY_ADD) >> " + params.getBitsPerLimb();
797        String temp = out.getTemporary("long", carry);
798        out.appendLine(carryFrom + " -= (" + temp + " << " + params.getBitsPerLimb() + ");");
799        out.appendLine(carryTo + " += " + temp + ";");
800        out.freeTemporary(temp);
801    }
802
803    private static String indexedExpr(boolean isArray, String prefix, int index) {
804        String result = prefix + index;
805        if (isArray) {
806            result = prefix + "[" + index + "]";
807        }
808        return result;
809    }
810
811    private static void modReduceInBits(CodeBuffer result, FieldParams params, boolean isArray, String prefix, int index, int reduceBits, int coefficient, String c) {
812
813        String x = coefficient + " * " + c;
814        String accOp = "+=";
815        String temp = null;
816        if (coefficient == 1) {
817            x = c;
818        } else if (coefficient == -1) {
819            x = c;
820            accOp = "-=";
821        } else {
822            temp = result.getTemporary("long", x);
823            x = temp;
824        }
825
826        if (reduceBits % params.getBitsPerLimb() == 0) {
827            int pos = reduceBits / params.getBitsPerLimb();
828            result.appendLine(indexedExpr(isArray, prefix, (index - pos)) + " " + accOp + " " + x + ";");
829        } else {
830            int secondPos = reduceBits / params.getBitsPerLimb();
831            int bitOffset = (secondPos + 1) * params.getBitsPerLimb() - reduceBits;
832            int rightBitOffset = params.getBitsPerLimb() - bitOffset;
833            result.appendLine(indexedExpr(isArray, prefix, (index - (secondPos + 1))) + " " + accOp + " (" + x + " << " + bitOffset + ") & LIMB_MASK;");
834            result.appendLine(indexedExpr(isArray, prefix, (index - secondPos)) + " " + accOp + " " + x + " >> " + rightBitOffset + ";");
835        }
836
837        if (temp != null) {
838            result.freeTemporary(temp);
839        }
840    }
841
842    private String readHeader() throws IOException {
843        BufferedReader reader = Files.newBufferedReader(Paths.get("header.txt"));
844        StringBuffer result = new StringBuffer();
845        reader.lines().forEach(s -> result.append(s + "\n"));
846        return result.toString();
847    }
848}
849
// Top-level jshell statement: runs the generator when this script is fed to jshell.
FieldGen.main(null);
851
852