1// Copyright (c) 2021 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
7import (
8	"fmt"
9
10	. "github.com/mmcloughlin/avo/build"
11	. "github.com/mmcloughlin/avo/gotypes"
12	. "github.com/mmcloughlin/avo/operand"
13	. "github.com/mmcloughlin/avo/reg"
14)
15
16//go:generate go run . -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field
17
18func main() {
19	Package("golang.org/x/crypto/curve25519/internal/field")
20	ConstraintExpr("amd64,gc,!purego")
21	feMul()
22	feSquare()
23	Generate()
24}
25
// namedComponent wraps an avo Component (a reference to a struct field of
// a function argument) together with a short human-readable name. The name
// is what appears in the comments of the generated assembly.
type namedComponent struct {
	Component
	name string
}

// String returns the display name, so a namedComponent can be used directly
// in fmt verbs inside Comment calls.
func (c namedComponent) String() string { return c.name }
32
// uint128 is a 128-bit accumulator held in a pair of virtual 64-bit
// general-purpose registers (hi:lo), plus a display name for generated
// assembly comments.
type uint128 struct {
	name   string
	hi, lo GPVirtual
}

// String returns the display name for use in fmt verbs inside Comment calls.
func (c uint128) String() string { return c.name }
39
// feSquare generates the amd64 assembly for feSquare, which computes
// out = a*a for a field element in the five-limb, 51-bit radix
// representation of GF(2^255 - 19).
//
// Each rX accumulator is a 128-bit value in two 64-bit registers; the 19×
// factors come from the reduction identity 2^255 ≡ 19 (mod p). The avo
// statement order below directly determines the emitted instruction order.
func feSquare() {
	TEXT("feSquare", NOSPLIT, "func(out, a *Element)")
	Doc("feSquare sets out = a * a. It works like feSquareGeneric.")
	Pragma("noescape")

	a := Dereference(Param("a"))
	l0 := namedComponent{a.Field("l0"), "l0"}
	l1 := namedComponent{a.Field("l1"), "l1"}
	l2 := namedComponent{a.Field("l2"), "l2"}
	l3 := namedComponent{a.Field("l3"), "l3"}
	l4 := namedComponent{a.Field("l4"), "l4"}

	// r0 = l0×l0 + 19×2×(l1×l4 + l2×l3)
	r0 := uint128{"r0", GP64(), GP64()}
	mul64(r0, 1, l0, l0)
	addMul64(r0, 38, l1, l4)
	addMul64(r0, 38, l2, l3)

	// r1 = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
	r1 := uint128{"r1", GP64(), GP64()}
	mul64(r1, 2, l0, l1)
	addMul64(r1, 38, l2, l4)
	addMul64(r1, 19, l3, l3)

	// r2 = 2×l0×l2 + l1×l1 + 19×2×l3×l4
	r2 := uint128{"r2", GP64(), GP64()}
	mul64(r2, 2, l0, l2)
	addMul64(r2, 1, l1, l1)
	addMul64(r2, 38, l3, l4)

	// r3 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
	r3 := uint128{"r3", GP64(), GP64()}
	mul64(r3, 2, l0, l3)
	addMul64(r3, 2, l1, l2)
	addMul64(r3, 19, l4, l4)

	// r4 = 2×l0×l4 + 2×l1×l3 + l2×l2
	r4 := uint128{"r4", GP64(), GP64()}
	mul64(r4, 2, l0, l4)
	addMul64(r4, 2, l1, l3)
	addMul64(r4, 1, l2, l2)

	Comment("First reduction chain")
	maskLow51Bits := GP64()
	MOVQ(Imm((1<<51)-1), maskLow51Bits)
	c0, r0lo := shiftRightBy51(&r0)
	c1, r1lo := shiftRightBy51(&r1)
	c2, r2lo := shiftRightBy51(&r2)
	c3, r3lo := shiftRightBy51(&r3)
	c4, r4lo := shiftRightBy51(&r4)
	// Fold each limb's carry into the next limb; the top carry wraps
	// around to the lowest limb scaled by 19 (2^255 ≡ 19 mod p).
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Second reduction chain (carryPropagate)")
	// c0 = r0 >> 51
	MOVQ(r0lo, c0)
	SHRQ(Imm(51), c0)
	// c1 = r1 >> 51
	MOVQ(r1lo, c1)
	SHRQ(Imm(51), c1)
	// c2 = r2 >> 51
	MOVQ(r2lo, c2)
	SHRQ(Imm(51), c2)
	// c3 = r3 >> 51
	MOVQ(r3lo, c3)
	SHRQ(Imm(51), c3)
	// c4 = r4 >> 51
	MOVQ(r4lo, c4)
	SHRQ(Imm(51), c4)
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Store output")
	out := Dereference(Param("out"))
	Store(r0lo, out.Field("l0"))
	Store(r1lo, out.Field("l1"))
	Store(r2lo, out.Field("l2"))
	Store(r3lo, out.Field("l3"))
	Store(r4lo, out.Field("l4"))

	RET()
}
128
// feMul generates the amd64 assembly for feMul, which computes out = a*b
// for field elements in the five-limb, 51-bit radix representation of
// GF(2^255 - 19).
//
// Each rX accumulator is a 128-bit value in two 64-bit registers; the 19×
// factors come from the reduction identity 2^255 ≡ 19 (mod p). The avo
// statement order below directly determines the emitted instruction order.
func feMul() {
	TEXT("feMul", NOSPLIT, "func(out, a, b *Element)")
	Doc("feMul sets out = a * b. It works like feMulGeneric.")
	Pragma("noescape")

	a := Dereference(Param("a"))
	a0 := namedComponent{a.Field("l0"), "a0"}
	a1 := namedComponent{a.Field("l1"), "a1"}
	a2 := namedComponent{a.Field("l2"), "a2"}
	a3 := namedComponent{a.Field("l3"), "a3"}
	a4 := namedComponent{a.Field("l4"), "a4"}

	b := Dereference(Param("b"))
	b0 := namedComponent{b.Field("l0"), "b0"}
	b1 := namedComponent{b.Field("l1"), "b1"}
	b2 := namedComponent{b.Field("l2"), "b2"}
	b3 := namedComponent{b.Field("l3"), "b3"}
	b4 := namedComponent{b.Field("l4"), "b4"}

	// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
	r0 := uint128{"r0", GP64(), GP64()}
	mul64(r0, 1, a0, b0)
	addMul64(r0, 19, a1, b4)
	addMul64(r0, 19, a2, b3)
	addMul64(r0, 19, a3, b2)
	addMul64(r0, 19, a4, b1)

	// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
	r1 := uint128{"r1", GP64(), GP64()}
	mul64(r1, 1, a0, b1)
	addMul64(r1, 1, a1, b0)
	addMul64(r1, 19, a2, b4)
	addMul64(r1, 19, a3, b3)
	addMul64(r1, 19, a4, b2)

	// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
	r2 := uint128{"r2", GP64(), GP64()}
	mul64(r2, 1, a0, b2)
	addMul64(r2, 1, a1, b1)
	addMul64(r2, 1, a2, b0)
	addMul64(r2, 19, a3, b4)
	addMul64(r2, 19, a4, b3)

	// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
	r3 := uint128{"r3", GP64(), GP64()}
	mul64(r3, 1, a0, b3)
	addMul64(r3, 1, a1, b2)
	addMul64(r3, 1, a2, b1)
	addMul64(r3, 1, a3, b0)
	addMul64(r3, 19, a4, b4)

	// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
	r4 := uint128{"r4", GP64(), GP64()}
	mul64(r4, 1, a0, b4)
	addMul64(r4, 1, a1, b3)
	addMul64(r4, 1, a2, b2)
	addMul64(r4, 1, a3, b1)
	addMul64(r4, 1, a4, b0)

	Comment("First reduction chain")
	maskLow51Bits := GP64()
	MOVQ(Imm((1<<51)-1), maskLow51Bits)
	c0, r0lo := shiftRightBy51(&r0)
	c1, r1lo := shiftRightBy51(&r1)
	c2, r2lo := shiftRightBy51(&r2)
	c3, r3lo := shiftRightBy51(&r3)
	c4, r4lo := shiftRightBy51(&r4)
	// Fold each limb's carry into the next limb; the top carry wraps
	// around to the lowest limb scaled by 19 (2^255 ≡ 19 mod p).
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Second reduction chain (carryPropagate)")
	// c0 = r0 >> 51
	MOVQ(r0lo, c0)
	SHRQ(Imm(51), c0)
	// c1 = r1 >> 51
	MOVQ(r1lo, c1)
	SHRQ(Imm(51), c1)
	// c2 = r2 >> 51
	MOVQ(r2lo, c2)
	SHRQ(Imm(51), c2)
	// c3 = r3 >> 51
	MOVQ(r3lo, c3)
	SHRQ(Imm(51), c3)
	// c4 = r4 >> 51
	MOVQ(r4lo, c4)
	SHRQ(Imm(51), c4)
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Store output")
	out := Dereference(Param("out"))
	Store(r0lo, out.Field("l0"))
	Store(r1lo, out.Field("l1"))
	Store(r2lo, out.Field("l2"))
	Store(r3lo, out.Field("l3"))
	Store(r4lo, out.Field("l4"))

	RET()
}
234
235// mul64 sets r to i * aX * bX.
236func mul64(r uint128, i int, aX, bX namedComponent) {
237	switch i {
238	case 1:
239		Comment(fmt.Sprintf("%s = %s×%s", r, aX, bX))
240		Load(aX, RAX)
241	case 2:
242		Comment(fmt.Sprintf("%s = 2×%s×%s", r, aX, bX))
243		Load(aX, RAX)
244		SHLQ(Imm(1), RAX)
245	default:
246		panic("unsupported i value")
247	}
248	MULQ(mustAddr(bX)) // RDX, RAX = RAX * bX
249	MOVQ(RAX, r.lo)
250	MOVQ(RDX, r.hi)
251}
252
253// addMul64 sets r to r + i * aX * bX.
254func addMul64(r uint128, i uint64, aX, bX namedComponent) {
255	switch i {
256	case 1:
257		Comment(fmt.Sprintf("%s += %s×%s", r, aX, bX))
258		Load(aX, RAX)
259	default:
260		Comment(fmt.Sprintf("%s += %d×%s×%s", r, i, aX, bX))
261		IMUL3Q(Imm(i), Load(aX, GP64()), RAX)
262	}
263	MULQ(mustAddr(bX)) // RDX, RAX = RAX * bX
264	ADDQ(RAX, r.lo)
265	ADCQ(RDX, r.hi)
266}
267
268// shiftRightBy51 returns r >> 51 and r.lo.
269//
270// After this function is called, the uint128 may not be used anymore.
271func shiftRightBy51(r *uint128) (out, lo GPVirtual) {
272	out = r.hi
273	lo = r.lo
274	SHLQ(Imm(64-51), r.lo, r.hi)
275	r.lo, r.hi = nil, nil // make sure the uint128 is unusable
276	return
277}
278
279// maskAndAdd sets r = r&mask + c*i.
280func maskAndAdd(r, mask, c GPVirtual, i uint64) {
281	ANDQ(mask, r)
282	if i != 1 {
283		IMUL3Q(Imm(i), c, c)
284	}
285	ADDQ(c, r)
286}
287
288func mustAddr(c Component) Op {
289	b, err := c.Resolve()
290	if err != nil {
291		panic(err)
292	}
293	return b.Addr
294}
295