1//+build generate 2 3//go:generate go run gen.go -out galois_gen_amd64.s -stubs galois_gen_amd64.go 4//go:generate gofmt -w galois_gen_switch_amd64.go 5 6package main 7 8import ( 9 "bufio" 10 "fmt" 11 "os" 12 13 . "github.com/mmcloughlin/avo/build" 14 "github.com/mmcloughlin/avo/buildtags" 15 . "github.com/mmcloughlin/avo/operand" 16 "github.com/mmcloughlin/avo/reg" 17) 18 19// Technically we can do slightly bigger, but we stay reasonable. 20const inputMax = 10 21const outputMax = 8 22 23var switchDefs [inputMax][outputMax]string 24var switchDefsX [inputMax][outputMax]string 25 26const perLoopBits = 5 27const perLoop = 1 << perLoopBits 28 29func main() { 30 Constraint(buildtags.Not("appengine").ToConstraint()) 31 Constraint(buildtags.Not("noasm").ToConstraint()) 32 Constraint(buildtags.Not("nogen").ToConstraint()) 33 Constraint(buildtags.Term("gc").ToConstraint()) 34 35 for i := 1; i <= inputMax; i++ { 36 for j := 1; j <= outputMax; j++ { 37 //genMulAvx2(fmt.Sprintf("mulAvxTwoXor_%dx%d", i, j), i, j, true) 38 genMulAvx2(fmt.Sprintf("mulAvxTwo_%dx%d", i, j), i, j, false) 39 } 40 } 41 f, err := os.Create("galois_gen_switch_amd64.go") 42 if err != nil { 43 panic(err) 44 } 45 defer f.Close() 46 w := bufio.NewWriter(f) 47 defer w.Flush() 48 w.WriteString(`// Code generated by command: go generate ` + os.Getenv("GOFILE") + `. DO NOT EDIT. 49 50// +build !appengine 51// +build !noasm 52// +build gc 53// +build !nogen 54 55package reedsolomon 56 57import "fmt" 58 59`) 60 61 w.WriteString("const avx2CodeGen = true\n") 62 w.WriteString(fmt.Sprintf("const maxAvx2Inputs = %d\nconst maxAvx2Outputs = %d\n", inputMax, outputMax)) 63 w.WriteString(` 64 65func galMulSlicesAvx2(matrix []byte, in, out [][]byte, start, stop int) int { 66 n := stop-start 67`) 68 69 w.WriteString(fmt.Sprintf("n = (n>>%d)<<%d\n\n", perLoopBits, perLoopBits)) 70 w.WriteString(`switch len(in) { 71`) 72 for in, defs := range switchDefs[:] { 73 w.WriteString(fmt.Sprintf(" case %d:\n switch len(out) {\n", in+1)) 74 for out, def := range defs[:] { 75 w.WriteString(fmt.Sprintf(" case %d:\n", out+1)) 76 w.WriteString(def) 77 } 78 w.WriteString("}\n") 79 } 80 w.WriteString(`} 81 panic(fmt.Sprintf("unhandled size: %dx%d", len(in), len(out))) 82} 83`) 84 Generate() 85} 86 87func genMulAvx2(name string, inputs int, outputs int, xor bool) { 88 total := inputs * outputs 89 90 doc := []string{ 91 fmt.Sprintf("%s takes %d inputs and produces %d outputs.", name, inputs, outputs), 92 } 93 if !xor { 94 doc = append(doc, "The output is initialized to 0.") 95 } 96 97 // Load shuffle masks on every use. 98 var loadNone bool 99 // Use registers for destination registers. 100 var regDst = true 101 102 // lo, hi, 1 in, 1 out, 2 tmp, 1 mask 103 est := total*2 + outputs + 5 104 if outputs == 1 { 105 // We don't need to keep a copy of the input if only 1 output. 106 est -= 2 107 } 108 109 if est > 16 { 110 loadNone = true 111 // We run out of GP registers first, now. 112 if inputs+outputs > 12 { 113 regDst = false 114 } 115 } 116 117 TEXT(name, 0, fmt.Sprintf("func(matrix []byte, in [][]byte, out [][]byte, start, n int)")) 118 119 // SWITCH DEFINITION: 120 s := fmt.Sprintf(" mulAvxTwo_%dx%d(matrix, in, out, start, n)\n", inputs, outputs) 121 s += fmt.Sprintf("\t\t\t\treturn n\n") 122 switchDefs[inputs-1][outputs-1] = s 123 124 if loadNone { 125 Comment("Loading no tables to registers") 126 } else { 127 // loadNone == false 128 Comment("Loading all tables to registers") 129 } 130 131 Doc(doc...) 132 Pragma("noescape") 133 Commentf("Full registers estimated %d YMM used", est) 134 135 length := Load(Param("n"), GP64()) 136 matrixBase := GP64() 137 MOVQ(Param("matrix").Base().MustAddr(), matrixBase) 138 SHRQ(U8(perLoopBits), length) 139 TESTQ(length, length) 140 JZ(LabelRef(name + "_end")) 141 142 dst := make([]reg.VecVirtual, outputs) 143 dstPtr := make([]reg.GPVirtual, outputs) 144 outBase := Param("out").Base().MustAddr() 145 outSlicePtr := GP64() 146 MOVQ(outBase, outSlicePtr) 147 for i := range dst { 148 dst[i] = YMM() 149 if !regDst { 150 continue 151 } 152 ptr := GP64() 153 MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr) 154 dstPtr[i] = ptr 155 } 156 157 inLo := make([]reg.VecVirtual, total) 158 inHi := make([]reg.VecVirtual, total) 159 160 for i := range inLo { 161 if loadNone { 162 break 163 } 164 tableLo := YMM() 165 tableHi := YMM() 166 VMOVDQU(Mem{Base: matrixBase, Disp: i * 64}, tableLo) 167 VMOVDQU(Mem{Base: matrixBase, Disp: i*64 + 32}, tableHi) 168 inLo[i] = tableLo 169 inHi[i] = tableHi 170 } 171 172 inPtrs := make([]reg.GPVirtual, inputs) 173 inSlicePtr := GP64() 174 MOVQ(Param("in").Base().MustAddr(), inSlicePtr) 175 for i := range inPtrs { 176 ptr := GP64() 177 MOVQ(Mem{Base: inSlicePtr, Disp: i * 24}, ptr) 178 inPtrs[i] = ptr 179 } 180 181 tmpMask := GP64() 182 MOVQ(U32(15), tmpMask) 183 lowMask := YMM() 184 MOVQ(tmpMask, lowMask.AsX()) 185 VPBROADCASTB(lowMask.AsX(), lowMask) 186 187 offset := GP64() 188 MOVQ(Param("start").MustAddr(), offset) 189 Label(name + "_loop") 190 if xor { 191 Commentf("Load %d outputs", outputs) 192 } else { 193 Commentf("Clear %d outputs", outputs) 194 } 195 for i := range dst { 196 if xor { 197 if regDst { 198 VMOVDQU(Mem{Base: dstPtr[i], Index: offset, Scale: 1}, dst[i]) 199 continue 200 } 201 ptr := GP64() 202 MOVQ(outBase, ptr) 203 VMOVDQU(Mem{Base: ptr, Index: offset, Scale: 1}, dst[i]) 204 } else { 205 VPXOR(dst[i], dst[i], dst[i]) 206 } 207 } 208 209 lookLow, lookHigh := YMM(), YMM() 210 inLow, inHigh := YMM(), YMM() 211 for i := range inPtrs { 212 Commentf("Load and process 32 bytes from input %d to %d outputs", i, outputs) 213 VMOVDQU(Mem{Base: inPtrs[i], Index: offset, Scale: 1}, inLow) 214 VPSRLQ(U8(4), inLow, inHigh) 215 VPAND(lowMask, inLow, inLow) 216 VPAND(lowMask, inHigh, inHigh) 217 for j := range dst { 218 if loadNone { 219 VMOVDQU(Mem{Base: matrixBase, Disp: 64 * (i*outputs + j)}, lookLow) 220 VMOVDQU(Mem{Base: matrixBase, Disp: 32 + 64*(i*outputs+j)}, lookHigh) 221 VPSHUFB(inLow, lookLow, lookLow) 222 VPSHUFB(inHigh, lookHigh, lookHigh) 223 } else { 224 VPSHUFB(inLow, inLo[i*outputs+j], lookLow) 225 VPSHUFB(inHigh, inHi[i*outputs+j], lookHigh) 226 } 227 VPXOR(lookLow, lookHigh, lookLow) 228 VPXOR(lookLow, dst[j], dst[j]) 229 } 230 } 231 Commentf("Store %d outputs", outputs) 232 for i := range dst { 233 if regDst { 234 VMOVDQU(dst[i], Mem{Base: dstPtr[i], Index: offset, Scale: 1}) 235 continue 236 } 237 ptr := GP64() 238 MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr) 239 VMOVDQU(dst[i], Mem{Base: ptr, Index: offset, Scale: 1}) 240 } 241 Comment("Prepare for next loop") 242 ADDQ(U8(perLoop), offset) 243 DECQ(length) 244 JNZ(LabelRef(name + "_loop")) 245 VZEROUPPER() 246 247 Label(name + "_end") 248 RET() 249} 250