1// Copyright 2013 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package types2_test 6 7import ( 8 "bytes" 9 "cmd/compile/internal/syntax" 10 "flag" 11 "fmt" 12 "io/ioutil" 13 "testing" 14 15 . "cmd/compile/internal/types2" 16) 17 18var ( 19 H = flag.Int("H", 5, "Hilbert matrix size") 20 out = flag.String("out", "", "write generated program to out") 21) 22 23func TestHilbert(t *testing.T) { 24 // generate source 25 src := program(*H, *out) 26 if *out != "" { 27 ioutil.WriteFile(*out, src, 0666) 28 return 29 } 30 31 // parse source 32 f, err := syntax.Parse(syntax.NewFileBase("hilbert.go"), bytes.NewReader(src), nil, nil, 0) 33 if err != nil { 34 t.Fatal(err) 35 } 36 37 // type-check file 38 DefPredeclaredTestFuncs() // define assert built-in 39 conf := Config{Importer: defaultImporter()} 40 _, err = conf.Check(f.PkgName.Value, []*syntax.File{f}, nil) 41 if err != nil { 42 t.Fatal(err) 43 } 44} 45 46func program(n int, out string) []byte { 47 var g gen 48 49 g.p(`// Code generated by: go test -run=Hilbert -H=%d -out=%q. DO NOT EDIT. 50 51// +`+`build ignore 52 53// This program tests arbitrary precision constant arithmetic 54// by generating the constant elements of a Hilbert matrix H, 55// its inverse I, and the product P = H*I. The product should 56// be the identity matrix. 57package main 58 59func main() { 60 if !ok { 61 printProduct() 62 return 63 } 64 println("PASS") 65} 66 67`, n, out) 68 g.hilbert(n) 69 g.inverse(n) 70 g.product(n) 71 g.verify(n) 72 g.printProduct(n) 73 g.binomials(2*n - 1) 74 g.factorials(2*n - 1) 75 76 return g.Bytes() 77} 78 79type gen struct { 80 bytes.Buffer 81} 82 83func (g *gen) p(format string, args ...interface{}) { 84 fmt.Fprintf(&g.Buffer, format, args...) 85} 86 87func (g *gen) hilbert(n int) { 88 g.p(`// Hilbert matrix, n = %d 89const ( 90`, n) 91 for i := 0; i < n; i++ { 92 g.p("\t") 93 for j := 0; j < n; j++ { 94 if j > 0 { 95 g.p(", ") 96 } 97 g.p("h%d_%d", i, j) 98 } 99 if i == 0 { 100 g.p(" = ") 101 for j := 0; j < n; j++ { 102 if j > 0 { 103 g.p(", ") 104 } 105 g.p("1.0/(iota + %d)", j+1) 106 } 107 } 108 g.p("\n") 109 } 110 g.p(")\n\n") 111} 112 113func (g *gen) inverse(n int) { 114 g.p(`// Inverse Hilbert matrix 115const ( 116`) 117 for i := 0; i < n; i++ { 118 for j := 0; j < n; j++ { 119 s := "+" 120 if (i+j)&1 != 0 { 121 s = "-" 122 } 123 g.p("\ti%d_%d = %s%d * b%d_%d * b%d_%d * b%d_%d * b%d_%d\n", 124 i, j, s, i+j+1, n+i, n-j-1, n+j, n-i-1, i+j, i, i+j, i) 125 } 126 g.p("\n") 127 } 128 g.p(")\n\n") 129} 130 131func (g *gen) product(n int) { 132 g.p(`// Product matrix 133const ( 134`) 135 for i := 0; i < n; i++ { 136 for j := 0; j < n; j++ { 137 g.p("\tp%d_%d = ", i, j) 138 for k := 0; k < n; k++ { 139 if k > 0 { 140 g.p(" + ") 141 } 142 g.p("h%d_%d*i%d_%d", i, k, k, j) 143 } 144 g.p("\n") 145 } 146 g.p("\n") 147 } 148 g.p(")\n\n") 149} 150 151func (g *gen) verify(n int) { 152 g.p(`// Verify that product is the identity matrix 153const ok = 154`) 155 for i := 0; i < n; i++ { 156 for j := 0; j < n; j++ { 157 if j == 0 { 158 g.p("\t") 159 } else { 160 g.p(" && ") 161 } 162 v := 0 163 if i == j { 164 v = 1 165 } 166 g.p("p%d_%d == %d", i, j, v) 167 } 168 g.p(" &&\n") 169 } 170 g.p("\ttrue\n\n") 171 172 // verify ok at type-check time 173 if *out == "" { 174 g.p("const _ = assert(ok)\n\n") 175 } 176} 177 178func (g *gen) printProduct(n int) { 179 g.p("func printProduct() {\n") 180 for i := 0; i < n; i++ { 181 g.p("\tprintln(") 182 for j := 0; j < n; j++ { 183 if j > 0 { 184 g.p(", ") 185 } 186 g.p("p%d_%d", i, j) 187 } 188 g.p(")\n") 189 } 190 g.p("}\n\n") 191} 192 193func (g *gen) binomials(n int) { 194 g.p(`// Binomials 195const ( 196`) 197 for j := 0; j <= n; j++ { 198 if j > 0 { 199 g.p("\n") 200 } 201 for k := 0; k <= j; k++ { 202 g.p("\tb%d_%d = f%d / (f%d*f%d)\n", j, k, j, k, j-k) 203 } 204 } 205 g.p(")\n\n") 206} 207 208func (g *gen) factorials(n int) { 209 g.p(`// Factorials 210const ( 211 f0 = 1 212 f1 = 1 213`) 214 for i := 2; i <= n; i++ { 215 g.p("\tf%d = f%d * %d\n", i, i-1, i) 216 } 217 g.p(")\n\n") 218} 219