1//+build generate
2
3//go:generate go run gen.go -out galois_gen_amd64.s -stubs galois_gen_amd64.go
4//go:generate gofmt -w galois_gen_switch_amd64.go
5
6package main
7
8import (
9	"bufio"
10	"fmt"
11	"os"
12
13	. "github.com/mmcloughlin/avo/build"
14	"github.com/mmcloughlin/avo/buildtags"
15	. "github.com/mmcloughlin/avo/operand"
16	"github.com/mmcloughlin/avo/reg"
17)
18
// Limits of the generated kernel shapes.
// Technically we can do slightly bigger, but we stay reasonable.
const (
	// inputMax is the largest number of input slices a kernel is generated for.
	inputMax = 10
	// outputMax is the largest number of output slices a kernel is generated for.
	outputMax = 8
)

// Bodies of the "case" arms of the generated dispatch switch, indexed by
// [inputs-1][outputs-1].
var (
	// switchDefs is filled in by genMulAvx2 and written out by main.
	switchDefs [inputMax][outputMax]string
	// switchDefsX is not referenced by the visible generator code; presumably
	// reserved for the xor kernel variants - TODO confirm.
	switchDefsX [inputMax][outputMax]string
)

// Amount of data handled by one iteration of a generated loop.
const (
	// perLoopBits is log2 of perLoop; it is used as a shift count to turn a
	// byte count into a loop count.
	perLoopBits = 5
	// perLoop is the number of bytes processed per loop iteration.
	perLoop = 1 << perLoopBits
)
28
29func main() {
30	Constraint(buildtags.Not("appengine").ToConstraint())
31	Constraint(buildtags.Not("noasm").ToConstraint())
32	Constraint(buildtags.Not("nogen").ToConstraint())
33	Constraint(buildtags.Term("gc").ToConstraint())
34
35	for i := 1; i <= inputMax; i++ {
36		for j := 1; j <= outputMax; j++ {
37			//genMulAvx2(fmt.Sprintf("mulAvxTwoXor_%dx%d", i, j), i, j, true)
38			genMulAvx2(fmt.Sprintf("mulAvxTwo_%dx%d", i, j), i, j, false)
39		}
40	}
41	f, err := os.Create("galois_gen_switch_amd64.go")
42	if err != nil {
43		panic(err)
44	}
45	defer f.Close()
46	w := bufio.NewWriter(f)
47	defer w.Flush()
48	w.WriteString(`// Code generated by command: go generate ` + os.Getenv("GOFILE") + `. DO NOT EDIT.
49
50// +build !appengine
51// +build !noasm
52// +build gc
53// +build !nogen
54
55package reedsolomon
56
57import "fmt"
58
59`)
60
61	w.WriteString("const avx2CodeGen = true\n")
62	w.WriteString(fmt.Sprintf("const maxAvx2Inputs = %d\nconst maxAvx2Outputs = %d\n", inputMax, outputMax))
63	w.WriteString(`
64
65func galMulSlicesAvx2(matrix []byte, in, out [][]byte, start, stop int) int {
66	n := stop-start
67`)
68
69	w.WriteString(fmt.Sprintf("n = (n>>%d)<<%d\n\n", perLoopBits, perLoopBits))
70	w.WriteString(`switch len(in) {
71`)
72	for in, defs := range switchDefs[:] {
73		w.WriteString(fmt.Sprintf("		case %d:\n			switch len(out) {\n", in+1))
74		for out, def := range defs[:] {
75			w.WriteString(fmt.Sprintf("				case %d:\n", out+1))
76			w.WriteString(def)
77		}
78		w.WriteString("}\n")
79	}
80	w.WriteString(`}
81	panic(fmt.Sprintf("unhandled size: %dx%d", len(in), len(out)))
82}
83`)
84	Generate()
85}
86
87func genMulAvx2(name string, inputs int, outputs int, xor bool) {
88	total := inputs * outputs
89
90	doc := []string{
91		fmt.Sprintf("%s takes %d inputs and produces %d outputs.", name, inputs, outputs),
92	}
93	if !xor {
94		doc = append(doc, "The output is initialized to 0.")
95	}
96
97	// Load shuffle masks on every use.
98	var loadNone bool
99	// Use registers for destination registers.
100	var regDst = true
101
102	// lo, hi, 1 in, 1 out, 2 tmp, 1 mask
103	est := total*2 + outputs + 5
104	if outputs == 1 {
105		// We don't need to keep a copy of the input if only 1 output.
106		est -= 2
107	}
108
109	if est > 16 {
110		loadNone = true
111		// We run out of GP registers first, now.
112		if inputs+outputs > 12 {
113			regDst = false
114		}
115	}
116
117	TEXT(name, 0, fmt.Sprintf("func(matrix []byte, in [][]byte, out [][]byte, start, n int)"))
118
119	// SWITCH DEFINITION:
120	s := fmt.Sprintf("			mulAvxTwo_%dx%d(matrix, in, out, start, n)\n", inputs, outputs)
121	s += fmt.Sprintf("\t\t\t\treturn n\n")
122	switchDefs[inputs-1][outputs-1] = s
123
124	if loadNone {
125		Comment("Loading no tables to registers")
126	} else {
127		// loadNone == false
128		Comment("Loading all tables to registers")
129	}
130
131	Doc(doc...)
132	Pragma("noescape")
133	Commentf("Full registers estimated %d YMM used", est)
134
135	length := Load(Param("n"), GP64())
136	matrixBase := GP64()
137	MOVQ(Param("matrix").Base().MustAddr(), matrixBase)
138	SHRQ(U8(perLoopBits), length)
139	TESTQ(length, length)
140	JZ(LabelRef(name + "_end"))
141
142	dst := make([]reg.VecVirtual, outputs)
143	dstPtr := make([]reg.GPVirtual, outputs)
144	outBase := Param("out").Base().MustAddr()
145	outSlicePtr := GP64()
146	MOVQ(outBase, outSlicePtr)
147	for i := range dst {
148		dst[i] = YMM()
149		if !regDst {
150			continue
151		}
152		ptr := GP64()
153		MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
154		dstPtr[i] = ptr
155	}
156
157	inLo := make([]reg.VecVirtual, total)
158	inHi := make([]reg.VecVirtual, total)
159
160	for i := range inLo {
161		if loadNone {
162			break
163		}
164		tableLo := YMM()
165		tableHi := YMM()
166		VMOVDQU(Mem{Base: matrixBase, Disp: i * 64}, tableLo)
167		VMOVDQU(Mem{Base: matrixBase, Disp: i*64 + 32}, tableHi)
168		inLo[i] = tableLo
169		inHi[i] = tableHi
170	}
171
172	inPtrs := make([]reg.GPVirtual, inputs)
173	inSlicePtr := GP64()
174	MOVQ(Param("in").Base().MustAddr(), inSlicePtr)
175	for i := range inPtrs {
176		ptr := GP64()
177		MOVQ(Mem{Base: inSlicePtr, Disp: i * 24}, ptr)
178		inPtrs[i] = ptr
179	}
180
181	tmpMask := GP64()
182	MOVQ(U32(15), tmpMask)
183	lowMask := YMM()
184	MOVQ(tmpMask, lowMask.AsX())
185	VPBROADCASTB(lowMask.AsX(), lowMask)
186
187	offset := GP64()
188	MOVQ(Param("start").MustAddr(), offset)
189	Label(name + "_loop")
190	if xor {
191		Commentf("Load %d outputs", outputs)
192	} else {
193		Commentf("Clear %d outputs", outputs)
194	}
195	for i := range dst {
196		if xor {
197			if regDst {
198				VMOVDQU(Mem{Base: dstPtr[i], Index: offset, Scale: 1}, dst[i])
199				continue
200			}
201			ptr := GP64()
202			MOVQ(outBase, ptr)
203			VMOVDQU(Mem{Base: ptr, Index: offset, Scale: 1}, dst[i])
204		} else {
205			VPXOR(dst[i], dst[i], dst[i])
206		}
207	}
208
209	lookLow, lookHigh := YMM(), YMM()
210	inLow, inHigh := YMM(), YMM()
211	for i := range inPtrs {
212		Commentf("Load and process 32 bytes from input %d to %d outputs", i, outputs)
213		VMOVDQU(Mem{Base: inPtrs[i], Index: offset, Scale: 1}, inLow)
214		VPSRLQ(U8(4), inLow, inHigh)
215		VPAND(lowMask, inLow, inLow)
216		VPAND(lowMask, inHigh, inHigh)
217		for j := range dst {
218			if loadNone {
219				VMOVDQU(Mem{Base: matrixBase, Disp: 64 * (i*outputs + j)}, lookLow)
220				VMOVDQU(Mem{Base: matrixBase, Disp: 32 + 64*(i*outputs+j)}, lookHigh)
221				VPSHUFB(inLow, lookLow, lookLow)
222				VPSHUFB(inHigh, lookHigh, lookHigh)
223			} else {
224				VPSHUFB(inLow, inLo[i*outputs+j], lookLow)
225				VPSHUFB(inHigh, inHi[i*outputs+j], lookHigh)
226			}
227			VPXOR(lookLow, lookHigh, lookLow)
228			VPXOR(lookLow, dst[j], dst[j])
229		}
230	}
231	Commentf("Store %d outputs", outputs)
232	for i := range dst {
233		if regDst {
234			VMOVDQU(dst[i], Mem{Base: dstPtr[i], Index: offset, Scale: 1})
235			continue
236		}
237		ptr := GP64()
238		MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
239		VMOVDQU(dst[i], Mem{Base: ptr, Index: offset, Scale: 1})
240	}
241	Comment("Prepare for next loop")
242	ADDQ(U8(perLoop), offset)
243	DECQ(length)
244	JNZ(LabelRef(name + "_loop"))
245	VZEROUPPER()
246
247	Label(name + "_end")
248	RET()
249}
250