// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// This program generates the amd64 assembly implementations of the
// curve25519 field element multiplication and squaring (fe_amd64.s) and
// their Go stub declarations (fe_amd64.go), using the avo code-generation
// library. It is run via "go generate" (see the directive below).
//
// Field elements are represented as five 51-bit limbs l0..l4; the mask
// (1<<51)-1 and the shift-by-51 reduction steps below depend on that
// representation.
package main

import (
	"fmt"

	. "github.com/mmcloughlin/avo/build"
	. "github.com/mmcloughlin/avo/gotypes"
	. "github.com/mmcloughlin/avo/operand"
	. "github.com/mmcloughlin/avo/reg"
)

//go:generate go run . -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field

func main() {
	Package("golang.org/x/crypto/curve25519/internal/field")
	// Only build the generated assembly on amd64 with the gc toolchain,
	// and allow opting out via the purego tag.
	ConstraintExpr("amd64,gc,!purego")
	feMul()
	feSquare()
	Generate()
}

// namedComponent pairs a dereferenced struct-field reference (a limb of
// an Element) with a short name used in generated assembly comments.
type namedComponent struct {
	Component
	name string
}

func (c namedComponent) String() string { return c.name }

// uint128 is a 128-bit accumulator held in two 64-bit virtual registers
// (hi:lo). It collects the unreduced limb products before the shift-by-51
// reduction.
type uint128 struct {
	name   string
	hi, lo GPVirtual
}

func (c uint128) String() string { return c.name }

// feSquare generates the feSquare assembly function: out = a * a over the
// five-limb representation, with the same algorithm as feSquareGeneric
// (schoolbook limb products followed by two reduction chains).
func feSquare() {
	TEXT("feSquare", NOSPLIT, "func(out, a *Element)")
	Doc("feSquare sets out = a * a. It works like feSquareGeneric.")
	Pragma("noescape")

	a := Dereference(Param("a"))
	l0 := namedComponent{a.Field("l0"), "l0"}
	l1 := namedComponent{a.Field("l1"), "l1"}
	l2 := namedComponent{a.Field("l2"), "l2"}
	l3 := namedComponent{a.Field("l3"), "l3"}
	l4 := namedComponent{a.Field("l4"), "l4"}

	// Cross terms that wrap past the top limb pick up a factor of 19
	// (the field modulus is 2^255 - 19); squaring also doubles each
	// symmetric cross term, hence the 2 and 38 = 19×2 multipliers.

	// r0 = l0×l0 + 19×2×(l1×l4 + l2×l3)
	r0 := uint128{"r0", GP64(), GP64()}
	mul64(r0, 1, l0, l0)
	addMul64(r0, 38, l1, l4)
	addMul64(r0, 38, l2, l3)

	// r1 = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
	r1 := uint128{"r1", GP64(), GP64()}
	mul64(r1, 2, l0, l1)
	addMul64(r1, 38, l2, l4)
	addMul64(r1, 19, l3, l3)

	// r2 = 2×l0×l2 + l1×l1 + 19×2×l3×l4
	r2 := uint128{"r2", GP64(), GP64()}
	mul64(r2, 2, l0, l2)
	addMul64(r2, 1, l1, l1)
	addMul64(r2, 38, l3, l4)

	// r3 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
	r3 := uint128{"r3", GP64(), GP64()}
	mul64(r3, 2, l0, l3)
	addMul64(r3, 2, l1, l2)
	addMul64(r3, 19, l4, l4)

	// r4 = 2×l0×l4 + 2×l1×l3 + l2×l2
	r4 := uint128{"r4", GP64(), GP64()}
	mul64(r4, 2, l0, l4)
	addMul64(r4, 2, l1, l3)
	addMul64(r4, 1, l2, l2)

	Comment("First reduction chain")
	maskLow51Bits := GP64()
	MOVQ(Imm((1<<51)-1), maskLow51Bits)
	c0, r0lo := shiftRightBy51(&r0)
	c1, r1lo := shiftRightBy51(&r1)
	c2, r2lo := shiftRightBy51(&r2)
	c3, r3lo := shiftRightBy51(&r3)
	c4, r4lo := shiftRightBy51(&r4)
	// Fold each limb's carry into the next limb; the carry out of the top
	// limb wraps back to r0 multiplied by 19.
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Second reduction chain (carryPropagate)")
	// c0 = r0 >> 51
	MOVQ(r0lo, c0)
	SHRQ(Imm(51), c0)
	// c1 = r1 >> 51
	MOVQ(r1lo, c1)
	SHRQ(Imm(51), c1)
	// c2 = r2 >> 51
	MOVQ(r2lo, c2)
	SHRQ(Imm(51), c2)
	// c3 = r3 >> 51
	MOVQ(r3lo, c3)
	SHRQ(Imm(51), c3)
	// c4 = r4 >> 51
	MOVQ(r4lo, c4)
	SHRQ(Imm(51), c4)
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Store output")
	out := Dereference(Param("out"))
	Store(r0lo, out.Field("l0"))
	Store(r1lo, out.Field("l1"))
	Store(r2lo, out.Field("l2"))
	Store(r3lo, out.Field("l3"))
	Store(r4lo, out.Field("l4"))

	RET()
}

// feMul generates the feMul assembly function: out = a * b over the
// five-limb representation, with the same algorithm as feMulGeneric.
func feMul() {
	TEXT("feMul", NOSPLIT, "func(out, a, b *Element)")
	Doc("feMul sets out = a * b. It works like feMulGeneric.")
	Pragma("noescape")

	a := Dereference(Param("a"))
	a0 := namedComponent{a.Field("l0"), "a0"}
	a1 := namedComponent{a.Field("l1"), "a1"}
	a2 := namedComponent{a.Field("l2"), "a2"}
	a3 := namedComponent{a.Field("l3"), "a3"}
	a4 := namedComponent{a.Field("l4"), "a4"}

	b := Dereference(Param("b"))
	b0 := namedComponent{b.Field("l0"), "b0"}
	b1 := namedComponent{b.Field("l1"), "b1"}
	b2 := namedComponent{b.Field("l2"), "b2"}
	b3 := namedComponent{b.Field("l3"), "b3"}
	b4 := namedComponent{b.Field("l4"), "b4"}

	// Schoolbook multiplication: products whose limb indexes sum past 4
	// wrap around with a factor of 19 (modulus 2^255 - 19).

	// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
	r0 := uint128{"r0", GP64(), GP64()}
	mul64(r0, 1, a0, b0)
	addMul64(r0, 19, a1, b4)
	addMul64(r0, 19, a2, b3)
	addMul64(r0, 19, a3, b2)
	addMul64(r0, 19, a4, b1)

	// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
	r1 := uint128{"r1", GP64(), GP64()}
	mul64(r1, 1, a0, b1)
	addMul64(r1, 1, a1, b0)
	addMul64(r1, 19, a2, b4)
	addMul64(r1, 19, a3, b3)
	addMul64(r1, 19, a4, b2)

	// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
	r2 := uint128{"r2", GP64(), GP64()}
	mul64(r2, 1, a0, b2)
	addMul64(r2, 1, a1, b1)
	addMul64(r2, 1, a2, b0)
	addMul64(r2, 19, a3, b4)
	addMul64(r2, 19, a4, b3)

	// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
	r3 := uint128{"r3", GP64(), GP64()}
	mul64(r3, 1, a0, b3)
	addMul64(r3, 1, a1, b2)
	addMul64(r3, 1, a2, b1)
	addMul64(r3, 1, a3, b0)
	addMul64(r3, 19, a4, b4)

	// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
	r4 := uint128{"r4", GP64(), GP64()}
	mul64(r4, 1, a0, b4)
	addMul64(r4, 1, a1, b3)
	addMul64(r4, 1, a2, b2)
	addMul64(r4, 1, a3, b1)
	addMul64(r4, 1, a4, b0)

	Comment("First reduction chain")
	maskLow51Bits := GP64()
	MOVQ(Imm((1<<51)-1), maskLow51Bits)
	c0, r0lo := shiftRightBy51(&r0)
	c1, r1lo := shiftRightBy51(&r1)
	c2, r2lo := shiftRightBy51(&r2)
	c3, r3lo := shiftRightBy51(&r3)
	c4, r4lo := shiftRightBy51(&r4)
	// Fold each limb's carry into the next limb; the carry out of the top
	// limb wraps back to r0 multiplied by 19.
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Second reduction chain (carryPropagate)")
	// c0 = r0 >> 51
	MOVQ(r0lo, c0)
	SHRQ(Imm(51), c0)
	// c1 = r1 >> 51
	MOVQ(r1lo, c1)
	SHRQ(Imm(51), c1)
	// c2 = r2 >> 51
	MOVQ(r2lo, c2)
	SHRQ(Imm(51), c2)
	// c3 = r3 >> 51
	MOVQ(r3lo, c3)
	SHRQ(Imm(51), c3)
	// c4 = r4 >> 51
	MOVQ(r4lo, c4)
	SHRQ(Imm(51), c4)
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Store output")
	out := Dereference(Param("out"))
	Store(r0lo, out.Field("l0"))
	Store(r1lo, out.Field("l1"))
	Store(r2lo, out.Field("l2"))
	Store(r3lo, out.Field("l3"))
	Store(r4lo, out.Field("l4"))

	RET()
}

// mul64 sets r to i * aX * bX. i must be 1 or 2 (any other value panics).
// It emits a MULQ, so it clobbers RAX and RDX.
func mul64(r uint128, i int, aX, bX namedComponent) {
	switch i {
	case 1:
		Comment(fmt.Sprintf("%s = %s×%s", r, aX, bX))
		Load(aX, RAX)
	case 2:
		Comment(fmt.Sprintf("%s = 2×%s×%s", r, aX, bX))
		Load(aX, RAX)
		// Doubling by shift is cheaper than a multiply here.
		SHLQ(Imm(1), RAX)
	default:
		panic("unsupported i value")
	}
	MULQ(mustAddr(bX)) // RDX, RAX = RAX * bX
	MOVQ(RAX, r.lo)
	MOVQ(RDX, r.hi)
}

// addMul64 sets r to r + i * aX * bX. It emits a MULQ, so it clobbers
// RAX and RDX. The small constant i (19, 38, ...) is applied with IMUL3Q
// before the widening multiply.
func addMul64(r uint128, i uint64, aX, bX namedComponent) {
	switch i {
	case 1:
		Comment(fmt.Sprintf("%s += %s×%s", r, aX, bX))
		Load(aX, RAX)
	default:
		Comment(fmt.Sprintf("%s += %d×%s×%s", r, i, aX, bX))
		IMUL3Q(Imm(i), Load(aX, GP64()), RAX)
	}
	MULQ(mustAddr(bX)) // RDX, RAX = RAX * bX
	ADDQ(RAX, r.lo)
	ADCQ(RDX, r.hi)
}

// shiftRightBy51 returns r >> 51 and r.lo.
//
// The three-operand SHLQ is a double-width shift (SHLD): it shifts r.hi
// left by 13 while filling from the top bits of r.lo, leaving the low 64
// bits of r >> 51 in r.hi. r.lo keeps its original value; the caller
// masks it to 51 bits afterwards.
//
// After this function is called, the uint128 may not be used anymore.
func shiftRightBy51(r *uint128) (out, lo GPVirtual) {
	out = r.hi
	lo = r.lo
	SHLQ(Imm(64-51), r.lo, r.hi)
	r.lo, r.hi = nil, nil // make sure the uint128 is unusable
	return
}

// maskAndAdd sets r = r&mask + c*i. It clobbers c when i != 1.
func maskAndAdd(r, mask, c GPVirtual, i uint64) {
	ANDQ(mask, r)
	if i != 1 {
		IMUL3Q(Imm(i), c, c)
	}
	ADDQ(c, r)
}

// mustAddr resolves c to a memory operand, panicking if resolution fails.
// (A panic here is a bug in the generator itself, caught at generate time.)
func mustAddr(c Component) Op {
	b, err := c.Resolve()
	if err != nil {
		panic(err)
	}
	return b.Addr
}