1// Copyright 2020 ConsenSys Software Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package amd64
16
17import (
18	"fmt"
19
20	"github.com/consensys/bavard/amd64"
21)
22
23// MulADX uses AX, DX and BP
24// sets x * y into t, without modular reduction
25// x() will have more accesses than y()
26// (caller should store x in registers, if possible)
27// if no (tmp) register is available, this uses one PUSH/POP on the stack in the hot loop.
28func (f *FFAmd64) MulADX(registers *amd64.Registers, x, y func(int) string, t []amd64.Register) []amd64.Register {
29	// registers
30	var tr amd64.Register // temporary register
31	A := amd64.BP
32
33	hasFreeRegister := registers.Available() > 0
34	if hasFreeRegister {
35		tr = registers.Pop()
36	} else {
37		tr = A
38	}
39
40	f.LabelRegisters("A", A)
41	f.LabelRegisters("t", t...)
42
43	for i := 0; i < f.NbWords; i++ {
44		f.Comment("clear the flags")
45		f.XORQ(amd64.AX, amd64.AX)
46
47		f.MOVQ(y(i), amd64.DX)
48
49		// for j=0 to N-1
50		//    (A,t[j])  := t[j] + x[j]*y[i] + A
51		if i == 0 {
52			for j := 0; j < f.NbWords; j++ {
53				f.Comment(fmt.Sprintf("(A,t[%[1]d])  := x[%[1]d]*y[%[2]d] + A", j, i))
54
55				if j == 0 {
56					f.MULXQ(x(j), t[j], t[j+1])
57				} else {
58					highBits := A
59					if j != f.NbWordsLastIndex {
60						highBits = t[j+1]
61					}
62					f.MULXQ(x(j), amd64.AX, highBits)
63					f.ADOXQ(amd64.AX, t[j])
64				}
65			}
66		} else {
67			for j := 0; j < f.NbWords; j++ {
68				f.Comment(fmt.Sprintf("(A,t[%[1]d])  := t[%[1]d] + x[%[1]d]*y[%[2]d] + A", j, i))
69
70				if j != 0 {
71					f.ADCXQ(A, t[j])
72				}
73				f.MULXQ(x(j), amd64.AX, A)
74				f.ADOXQ(amd64.AX, t[j])
75			}
76		}
77
78		f.Comment("A += carries from ADCXQ and ADOXQ")
79		f.MOVQ(0, amd64.AX)
80		if i != 0 {
81			f.ADCXQ(amd64.AX, A)
82		}
83		f.ADOXQ(amd64.AX, A)
84
85		if !hasFreeRegister {
86			f.PUSHQ(A)
87		}
88
89		// m := t[0]*q'[0] mod W
90		f.Comment("m := t[0]*q'[0] mod W")
91		m := amd64.DX
92		// f.MOVQ(t[0], m)
93		// f.MULXQ(f.qInv0(), m, amd64.AX)
94		f.MOVQ(f.qInv0(), m)
95		f.IMULQ(t[0], m)
96
97		// clear the carry flags
98		f.Comment("clear the flags")
99		f.XORQ(amd64.AX, amd64.AX)
100
101		// C,_ := t[0] + m*q[0]
102		f.Comment("C,_ := t[0] + m*q[0]")
103
104		f.MULXQ(f.qAt(0), amd64.AX, tr)
105		f.ADCXQ(t[0], amd64.AX)
106		f.MOVQ(tr, t[0])
107
108		if !hasFreeRegister {
109			f.POPQ(A)
110		}
111		// for j=1 to N-1
112		//    (C,t[j-1]) := t[j] + m*q[j] + C
113		for j := 1; j < f.NbWords; j++ {
114			f.Comment(fmt.Sprintf("(C,t[%[1]d]) := t[%[2]d] + m*q[%[2]d] + C", j-1, j))
115			f.ADCXQ(t[j], t[j-1])
116			f.MULXQ(f.qAt(j), amd64.AX, t[j])
117			f.ADOXQ(amd64.AX, t[j-1])
118		}
119
120		f.Comment(fmt.Sprintf("t[%d] = C + A", f.NbWordsLastIndex))
121		f.MOVQ(0, amd64.AX)
122		f.ADCXQ(amd64.AX, t[f.NbWordsLastIndex])
123		f.ADOXQ(A, t[f.NbWordsLastIndex])
124
125	}
126
127	if hasFreeRegister {
128		registers.Push(tr)
129	}
130
131	return t
132}
133
134func (f *FFAmd64) generateMul(forceADX bool) {
135	f.Comment("mul(res, x, y *Element)")
136
137	const argSize = 3 * 8
138	minStackSize := argSize
139	if forceADX {
140		minStackSize = 0
141	}
142	stackSize := f.StackSize(f.NbWords*2, 2, minStackSize)
143	registers := f.FnHeader("mul", stackSize, argSize, amd64.DX, amd64.AX)
144	defer f.AssertCleanStack(stackSize, minStackSize)
145
146	f.WriteLn(`
147	// the algorithm is described here
148	// https://hackmd.io/@zkteam/modular_multiplication
149	// however, to benefit from the ADCX and ADOX carry chains
150	// we split the inner loops in 2:
151	// for i=0 to N-1
152	// 		for j=0 to N-1
153	// 		    (A,t[j])  := t[j] + x[j]*y[i] + A
154	// 		m := t[0]*q'[0] mod W
155	// 		C,_ := t[0] + m*q[0]
156	// 		for j=1 to N-1
157	// 		    (C,t[j-1]) := t[j] + m*q[j] + C
158	// 		t[N-1] = C + A
159	`)
160	if stackSize > 0 {
161		f.WriteLn("NO_LOCAL_POINTERS")
162	}
163
164	noAdx := f.NewLabel()
165
166	if !forceADX {
167		// check ADX instruction support
168		f.CMPB("·supportAdx(SB)", 1)
169		f.JNE(noAdx)
170	}
171
172	{
173		// we need to access x and y words, per index
174		var xat, yat func(int) string
175		var gc func()
176
177		// we need NbWords registers for t, plus optionally one for tmp register in mulADX if we want to avoid PUSH/POP
178		nbRegisters := registers.Available()
179		if nbRegisters < f.NbWords {
180			panic("not enough registers, not supported.")
181		}
182
183		t := registers.PopN(f.NbWords)
184		nbRegisters = registers.Available()
185		switch nbRegisters {
186		case 0:
187			// y is access through use of AX/DX
188			yat = func(i int) string {
189				y := amd64.AX
190				f.MOVQ("y+16(FP)", y)
191				return y.At(i)
192			}
193
194			// we move x on the stack.
195			f.MOVQ("x+8(FP)", amd64.AX)
196			_x := f.PopN(&registers, true)
197			f.LabelRegisters("x", _x...)
198			f.Mov(amd64.AX, t)
199			f.Mov(t, _x)
200			xat = func(i int) string {
201				return string(_x[i])
202			}
203			gc = func() {
204				f.Push(&registers, _x...)
205			}
206		case 1:
207			// y is access through use of AX/DX
208			yat = func(i int) string {
209				y := amd64.AX
210				f.MOVQ("y+16(FP)", y)
211				return y.At(i)
212			}
213			// x uses the register
214			x := registers.Pop()
215			f.MOVQ("x+8(FP)", x)
216			xat = func(i int) string {
217				return x.At(i)
218			}
219			gc = func() {
220				registers.Push(x)
221			}
222		case 2, 3:
223			// x, y uses registers
224			x := registers.Pop()
225			y := registers.Pop()
226
227			f.MOVQ("x+8(FP)", x)
228			f.MOVQ("y+16(FP)", y)
229
230			xat = func(i int) string {
231				return x.At(i)
232			}
233
234			yat = func(i int) string {
235				return y.At(i)
236			}
237			gc = func() {
238				registers.Push(x, y)
239			}
240		default:
241			// we have a least 4 registers.
242			// 1 for tmp.
243			nbRegisters--
244			// 1 for y
245			nbRegisters--
246			var y amd64.Register
247
248			if nbRegisters >= f.NbWords {
249				// we store x fully in registers
250				x := registers.Pop()
251				f.MOVQ("x+8(FP)", x)
252				_x := registers.PopN(f.NbWords)
253				f.LabelRegisters("x", _x...)
254				f.Mov(x, _x)
255
256				xat = func(i int) string {
257					return string(_x[i])
258				}
259				registers.Push(x)
260				gc = func() {
261					registers.Push(y)
262					registers.Push(_x...)
263				}
264			} else {
265				// we take at least 1 register for x addr
266				nbRegisters--
267				x := registers.Pop()
268				y = registers.Pop() // temporary lock 1 for y
269				f.MOVQ("x+8(FP)", x)
270
271				// and use the rest for x0...xn
272				_x := registers.PopN(nbRegisters)
273				f.LabelRegisters("x", _x...)
274				for i := 0; i < len(_x); i++ {
275					f.MOVQ(x.At(i), _x[i])
276				}
277				xat = func(i int) string {
278					if i < len(_x) {
279						return string(_x[i])
280					}
281					return x.At(i)
282				}
283				registers.Push(y)
284
285				gc = func() {
286					registers.Push(x, y)
287					registers.Push(_x...)
288				}
289
290			}
291			y = registers.Pop()
292
293			f.MOVQ("y+16(FP)", y)
294			yat = func(i int) string {
295				return y.At(i)
296			}
297
298		}
299
300		f.MulADX(&registers, xat, yat, t)
301		gc()
302
303		// ---------------------------------------------------------------------------------------------
304		// reduce
305		f.Reduce(&registers, t)
306
307		f.MOVQ("res+0(FP)", amd64.AX)
308		f.Mov(t, amd64.AX)
309		f.RET()
310	}
311
312	// ---------------------------------------------------------------------------------------------
313	// no MULX, ADX instructions
314	if !forceADX {
315		f.LABEL(noAdx)
316
317		f.MOVQ("res+0(FP)", amd64.AX)
318		f.MOVQ(amd64.AX, "(SP)")
319		f.MOVQ("x+8(FP)", amd64.AX)
320		f.MOVQ(amd64.AX, "8(SP)")
321		f.MOVQ("y+16(FP)", amd64.AX)
322		f.MOVQ(amd64.AX, "16(SP)")
323		f.WriteLn("CALL ·_mulGeneric(SB)")
324		f.RET()
325
326	}
327}
328