1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package types2_test
6
7import (
8	"bytes"
9	"cmd/compile/internal/syntax"
10	"flag"
11	"fmt"
12	"io/ioutil"
13	"testing"
14
15	. "cmd/compile/internal/types2"
16)
17
// H is the dimension of the Hilbert matrix used by TestHilbert.
var H = flag.Int("H", 5, "Hilbert matrix size")

// out, when non-empty, names a file that TestHilbert writes the generated
// program to instead of type-checking it.
var out = flag.String("out", "", "write generated program to out")
22
23func TestHilbert(t *testing.T) {
24	// generate source
25	src := program(*H, *out)
26	if *out != "" {
27		ioutil.WriteFile(*out, src, 0666)
28		return
29	}
30
31	// parse source
32	f, err := syntax.Parse(syntax.NewFileBase("hilbert.go"), bytes.NewReader(src), nil, nil, 0)
33	if err != nil {
34		t.Fatal(err)
35	}
36
37	// type-check file
38	DefPredeclaredTestFuncs() // define assert built-in
39	conf := Config{Importer: defaultImporter()}
40	_, err = conf.Check(f.PkgName.Value, []*syntax.File{f}, nil)
41	if err != nil {
42		t.Fatal(err)
43	}
44}
45
46func program(n int, out string) []byte {
47	var g gen
48
49	g.p(`// Code generated by: go test -run=Hilbert -H=%d -out=%q. DO NOT EDIT.
50
51// +`+`build ignore
52
53// This program tests arbitrary precision constant arithmetic
54// by generating the constant elements of a Hilbert matrix H,
55// its inverse I, and the product P = H*I. The product should
56// be the identity matrix.
57package main
58
59func main() {
60	if !ok {
61		printProduct()
62		return
63	}
64	println("PASS")
65}
66
67`, n, out)
68	g.hilbert(n)
69	g.inverse(n)
70	g.product(n)
71	g.verify(n)
72	g.printProduct(n)
73	g.binomials(2*n - 1)
74	g.factorials(2*n - 1)
75
76	return g.Bytes()
77}
78
// gen accumulates generated Go source text. The embedded bytes.Buffer
// stores the text and supplies methods such as Bytes; the generator
// methods append to it through p.
type gen struct {
	bytes.Buffer
}
82
83func (g *gen) p(format string, args ...interface{}) {
84	fmt.Fprintf(&g.Buffer, format, args...)
85}
86
87func (g *gen) hilbert(n int) {
88	g.p(`// Hilbert matrix, n = %d
89const (
90`, n)
91	for i := 0; i < n; i++ {
92		g.p("\t")
93		for j := 0; j < n; j++ {
94			if j > 0 {
95				g.p(", ")
96			}
97			g.p("h%d_%d", i, j)
98		}
99		if i == 0 {
100			g.p(" = ")
101			for j := 0; j < n; j++ {
102				if j > 0 {
103					g.p(", ")
104				}
105				g.p("1.0/(iota + %d)", j+1)
106			}
107		}
108		g.p("\n")
109	}
110	g.p(")\n\n")
111}
112
113func (g *gen) inverse(n int) {
114	g.p(`// Inverse Hilbert matrix
115const (
116`)
117	for i := 0; i < n; i++ {
118		for j := 0; j < n; j++ {
119			s := "+"
120			if (i+j)&1 != 0 {
121				s = "-"
122			}
123			g.p("\ti%d_%d = %s%d * b%d_%d * b%d_%d * b%d_%d * b%d_%d\n",
124				i, j, s, i+j+1, n+i, n-j-1, n+j, n-i-1, i+j, i, i+j, i)
125		}
126		g.p("\n")
127	}
128	g.p(")\n\n")
129}
130
131func (g *gen) product(n int) {
132	g.p(`// Product matrix
133const (
134`)
135	for i := 0; i < n; i++ {
136		for j := 0; j < n; j++ {
137			g.p("\tp%d_%d = ", i, j)
138			for k := 0; k < n; k++ {
139				if k > 0 {
140					g.p(" + ")
141				}
142				g.p("h%d_%d*i%d_%d", i, k, k, j)
143			}
144			g.p("\n")
145		}
146		g.p("\n")
147	}
148	g.p(")\n\n")
149}
150
151func (g *gen) verify(n int) {
152	g.p(`// Verify that product is the identity matrix
153const ok =
154`)
155	for i := 0; i < n; i++ {
156		for j := 0; j < n; j++ {
157			if j == 0 {
158				g.p("\t")
159			} else {
160				g.p(" && ")
161			}
162			v := 0
163			if i == j {
164				v = 1
165			}
166			g.p("p%d_%d == %d", i, j, v)
167		}
168		g.p(" &&\n")
169	}
170	g.p("\ttrue\n\n")
171
172	// verify ok at type-check time
173	if *out == "" {
174		g.p("const _ = assert(ok)\n\n")
175	}
176}
177
178func (g *gen) printProduct(n int) {
179	g.p("func printProduct() {\n")
180	for i := 0; i < n; i++ {
181		g.p("\tprintln(")
182		for j := 0; j < n; j++ {
183			if j > 0 {
184				g.p(", ")
185			}
186			g.p("p%d_%d", i, j)
187		}
188		g.p(")\n")
189	}
190	g.p("}\n\n")
191}
192
193func (g *gen) binomials(n int) {
194	g.p(`// Binomials
195const (
196`)
197	for j := 0; j <= n; j++ {
198		if j > 0 {
199			g.p("\n")
200		}
201		for k := 0; k <= j; k++ {
202			g.p("\tb%d_%d = f%d / (f%d*f%d)\n", j, k, j, k, j-k)
203		}
204	}
205	g.p(")\n\n")
206}
207
208func (g *gen) factorials(n int) {
209	g.p(`// Factorials
210const (
211	f0 = 1
212	f1 = 1
213`)
214	for i := 2; i <= n; i++ {
215		g.p("\tf%d = f%d * %d\n", i, i-1, i)
216	}
217	g.p(")\n\n")
218}
219