// Copyright 2020 ConsenSys Software Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package amd64

import (
	"fmt"

	"github.com/consensys/bavard/amd64"
)

// MulADX uses AX, DX and BP
// sets x * y into t, without modular reduction
// x() will have more accesses than y()
// (caller should store x in registers, if possible)
// if no (tmp) register is available, this uses one PUSH/POP on the stack in the hot loop.
//
// It emits an interleaved Montgomery multiplication (CIOS-style) built on the
// MULX/ADCX/ADOX dual carry chains: ADCX consumes/produces CF while ADOX
// consumes/produces OF, letting the two addition chains run without
// serializing on a single carry flag.
//
// Parameters:
//   - registers: pool used to (optionally) grab one scratch register for the
//     reduction step; returned to the pool before this function exits.
//   - x, y: accessors returning the operand word at index i (register name or
//     memory operand); x is read NbWords times per outer iteration, y once.
//   - t: NbWords registers receiving the (unreduced) product accumulator.
//
// Returns t unchanged (same registers, now holding the result).
func (f *FFAmd64) MulADX(registers *amd64.Registers, x, y func(int) string, t []amd64.Register) []amd64.Register {
	// registers
	var tr amd64.Register // temporary register used in the reduction step
	A := amd64.BP         // A holds the running high-word carry across the inner loop

	// If the pool has a spare register, use it as the scratch; otherwise we
	// reuse A (BP) and protect its value with a PUSH/POP around the reduction.
	hasFreeRegister := registers.Available() > 0
	if hasFreeRegister {
		tr = registers.Pop()
	} else {
		tr = A
	}

	f.LabelRegisters("A", A)
	f.LabelRegisters("t", t...)

	for i := 0; i < f.NbWords; i++ {
		// XORQ clears both CF and OF, resetting the two carry chains.
		f.Comment("clear the flags")
		f.XORQ(amd64.AX, amd64.AX)

		// MULX reads its implicit multiplicand from DX.
		f.MOVQ(y(i), amd64.DX)

		// for j=0 to N-1
		//    (A,t[j])  := t[j] + x[j]*y[i] + A
		if i == 0 {
			// First outer iteration: t is uninitialized, so the products are
			// written directly instead of accumulated.
			for j := 0; j < f.NbWords; j++ {
				f.Comment(fmt.Sprintf("(A,t[%[1]d]) := x[%[1]d]*y[%[2]d] + A", j, i))

				if j == 0 {
					// low → t[0], high → t[1]; no addition needed yet.
					f.MULXQ(x(j), t[j], t[j+1])
				} else {
					// High half goes to t[j+1], except for the last word,
					// whose high half is the carry accumulator A.
					highBits := A
					if j != f.NbWordsLastIndex {
						highBits = t[j+1]
					}
					f.MULXQ(x(j), amd64.AX, highBits)
					f.ADOXQ(amd64.AX, t[j]) // OF chain: add low half into t[j]
				}
			}
		} else {
			// Subsequent iterations: accumulate into the existing t.
			// CF chain (ADCX) propagates the previous high halves; OF chain
			// (ADOX) adds the fresh low halves — the two chains interleave
			// without flag conflicts.
			for j := 0; j < f.NbWords; j++ {
				f.Comment(fmt.Sprintf("(A,t[%[1]d]) := t[%[1]d] + x[%[1]d]*y[%[2]d] + A", j, i))

				if j != 0 {
					f.ADCXQ(A, t[j])
				}
				f.MULXQ(x(j), amd64.AX, A)
				f.ADOXQ(amd64.AX, t[j])
			}
		}

		// Fold the terminal carries of both chains into A.
		// MOVQ (not XORQ) of 0 is used so the flags are preserved.
		f.Comment("A += carries from ADCXQ and ADOXQ")
		f.MOVQ(0, amd64.AX)
		if i != 0 {
			f.ADCXQ(amd64.AX, A)
		}
		f.ADOXQ(amd64.AX, A)

		// No spare register: save A, since tr aliases it during the reduction.
		if !hasFreeRegister {
			f.PUSHQ(A)
		}

		// m := t[0]*q'[0] mod W
		f.Comment("m := t[0]*q'[0] mod W")
		m := amd64.DX
		// f.MOVQ(t[0], m)
		// f.MULXQ(f.qInv0(), m, amd64.AX)
		// IMULQ is used instead of MULX here: only the low word of the
		// product is needed, and DX already holds m for the MULXQs below.
		f.MOVQ(f.qInv0(), m)
		f.IMULQ(t[0], m)

		// clear the carry flags (reset CF and OF for the reduction chains)
		f.Comment("clear the flags")
		f.XORQ(amd64.AX, amd64.AX)

		// C,_ := t[0] + m*q[0]
		// By construction of m, t[0] + m*q[0] ≡ 0 mod W: the low word is
		// discarded and only its carry (CF) matters.
		f.Comment("C,_ := t[0] + m*q[0]")

		f.MULXQ(f.qAt(0), amd64.AX, tr)
		f.ADCXQ(t[0], amd64.AX)
		f.MOVQ(tr, t[0])

		if !hasFreeRegister {
			f.POPQ(A)
		}
		// for j=1 to N-1
		//    (C,t[j-1]) := t[j] + m*q[j] + C
		// This shifts the accumulator down one word (division by W) while
		// adding m*q, again interleaving the CF and OF chains.
		for j := 1; j < f.NbWords; j++ {
			f.Comment(fmt.Sprintf("(C,t[%[1]d]) := t[%[2]d] + m*q[%[2]d] + C", j-1, j))
			f.ADCXQ(t[j], t[j-1])
			f.MULXQ(f.qAt(j), amd64.AX, t[j])
			f.ADOXQ(amd64.AX, t[j-1])
		}

		// Top word: combine the final carries of both chains with A.
		f.Comment(fmt.Sprintf("t[%d] = C + A", f.NbWordsLastIndex))
		f.MOVQ(0, amd64.AX)
		f.ADCXQ(amd64.AX, t[f.NbWordsLastIndex])
		f.ADOXQ(A, t[f.NbWordsLastIndex])

	}

	// Return the scratch register to the pool (only if we actually took one).
	if hasFreeRegister {
		registers.Push(tr)
	}

	return t
}
// generateMul emits the assembly for mul(res, x, y *Element): a Montgomery
// multiplication using the MULX/ADCX/ADOX carry chains (see MulADX), followed
// by a conditional subtraction (Reduce).
//
// If forceADX is false, a runtime check on ·supportAdx branches to a fallback
// that calls the generic Go implementation; if true, the ADX path is emitted
// unconditionally and no fallback (or stack space for it) is generated.
//
// The bulk of the function decides, from the number of registers left after
// reserving NbWords for the accumulator t, how to access the x and y words:
// fully in registers, partially in registers, via a pointer register, or via
// repeated reloads through AX (worst case).
func (f *FFAmd64) generateMul(forceADX bool) {
	f.Comment("mul(res, x, y *Element)")

	// 3 pointer arguments: res, x, y.
	const argSize = 3 * 8
	minStackSize := argSize
	if forceADX {
		// No fallback call → no outgoing call frame needed.
		minStackSize = 0
	}
	stackSize := f.StackSize(f.NbWords*2, 2, minStackSize)
	registers := f.FnHeader("mul", stackSize, argSize, amd64.DX, amd64.AX)
	defer f.AssertCleanStack(stackSize, minStackSize)

	f.WriteLn(`
	// the algorithm is described here
	// https://hackmd.io/@zkteam/modular_multiplication
	// however, to benefit from the ADCX and ADOX carry chains
	// we split the inner loops in 2:
	// for i=0 to N-1
	// 		for j=0 to N-1
	// 		    (A,t[j])  := t[j] + x[j]*y[i] + A
	// 		m := t[0]*q'[0] mod W
	// 		C,_ := t[0] + m*q[0]
	// 		for j=1 to N-1
	// 		    (C,t[j-1]) := t[j] + m*q[j] + C
	// 		t[N-1] = C + A
	`)
	if stackSize > 0 {
		f.WriteLn("NO_LOCAL_POINTERS")
	}

	noAdx := f.NewLabel()

	if !forceADX {
		// check ADX instruction support
		f.CMPB("·supportAdx(SB)", 1)
		f.JNE(noAdx)
	}

	{
		// we need to access x and y words, per index
		var xat, yat func(int) string
		var gc func() // deferred cleanup: returns borrowed registers to the pool

		// we need NbWords registers for t, plus optionally one for tmp register in mulADX if we want to avoid PUSH/POP
		nbRegisters := registers.Available()
		if nbRegisters < f.NbWords {
			panic("not enough registers, not supported.")
		}

		t := registers.PopN(f.NbWords)
		nbRegisters = registers.Available()
		switch nbRegisters {
		case 0:
			// Worst case: no registers left at all.
			// y is accessed through AX/DX: reload the y pointer into AX on
			// every access (AX is clobbered inside MulADX's hot loop).
			yat = func(i int) string {
				y := amd64.AX
				f.MOVQ("y+16(FP)", y)
				return y.At(i)
			}

			// we move x on the stack (copied via the t registers, which are
			// not yet live at this point).
			f.MOVQ("x+8(FP)", amd64.AX)
			_x := f.PopN(&registers, true)
			f.LabelRegisters("x", _x...)
			f.Mov(amd64.AX, t)
			f.Mov(t, _x)
			xat = func(i int) string {
				return string(_x[i])
			}
			gc = func() {
				f.Push(&registers, _x...)
			}
		case 1:
			// y is accessed through AX/DX (reloaded per access, as above)
			yat = func(i int) string {
				y := amd64.AX
				f.MOVQ("y+16(FP)", y)
				return y.At(i)
			}
			// x pointer gets the one available register
			x := registers.Pop()
			f.MOVQ("x+8(FP)", x)
			xat = func(i int) string {
				return x.At(i)
			}
			gc = func() {
				registers.Push(x)
			}
		case 2, 3:
			// x and y pointers each get a register; words are accessed as
			// memory operands off those pointers.
			x := registers.Pop()
			y := registers.Pop()

			f.MOVQ("x+8(FP)", x)
			f.MOVQ("y+16(FP)", y)

			xat = func(i int) string {
				return x.At(i)
			}

			yat = func(i int) string {
				return y.At(i)
			}
			gc = func() {
				registers.Push(x, y)
			}
		default:
			// we have at least 4 registers.
			// 1 for tmp (left in the pool for MulADX to grab).
			nbRegisters--
			// 1 for the y pointer.
			nbRegisters--
			var y amd64.Register

			if nbRegisters >= f.NbWords {
				// we can store x fully in registers: load each word once,
				// then the pointer register is released again.
				x := registers.Pop()
				f.MOVQ("x+8(FP)", x)
				_x := registers.PopN(f.NbWords)
				f.LabelRegisters("x", _x...)
				f.Mov(x, _x)

				xat = func(i int) string {
					return string(_x[i])
				}
				registers.Push(x)
				gc = func() {
					registers.Push(y)
					registers.Push(_x...)
				}
			} else {
				// partial caching: we take at least 1 register for the x
				// pointer (kept live for the words that don't fit).
				nbRegisters--
				x := registers.Pop()
				y = registers.Pop() // temporary lock 1 for y, so PopN below can't take it
				f.MOVQ("x+8(FP)", x)

				// and use the rest for x0...xn
				_x := registers.PopN(nbRegisters)
				f.LabelRegisters("x", _x...)
				for i := 0; i < len(_x); i++ {
					f.MOVQ(x.At(i), _x[i])
				}
				// cached words come from registers; the tail is read through
				// the x pointer.
				xat = func(i int) string {
					if i < len(_x) {
						return string(_x[i])
					}
					return x.At(i)
				}
				registers.Push(y) // release the temporary lock

				gc = func() {
					registers.Push(x, y)
					registers.Push(_x...)
				}

			}
			// now grab the real y pointer register (both branches above)
			y = registers.Pop()

			f.MOVQ("y+16(FP)", y)
			yat = func(i int) string {
				return y.At(i)
			}

		}

		f.MulADX(&registers, xat, yat, t)
		gc()

		// ---------------------------------------------------------------------------------------------
		// reduce: conditional subtraction of q so the result is < q
		f.Reduce(&registers, t)

		f.MOVQ("res+0(FP)", amd64.AX)
		f.Mov(t, amd64.AX)
		f.RET()
	}

	// ---------------------------------------------------------------------------------------------
	// no MULX, ADX instructions: forward the three pointer arguments on the
	// stack and call the generic Go implementation.
	if !forceADX {
		f.LABEL(noAdx)

		f.MOVQ("res+0(FP)", amd64.AX)
		f.MOVQ(amd64.AX, "(SP)")
		f.MOVQ("x+8(FP)", amd64.AX)
		f.MOVQ(amd64.AX, "8(SP)")
		f.MOVQ("y+16(FP)", amd64.AX)
		f.MOVQ(amd64.AX, "16(SP)")
		f.WriteLn("CALL ·_mulGeneric(SB)")
		f.RET()

	}
}