1// Copyright ©2015 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mat
6
7import "fmt"
8
9// Product calculates the product of the given factors and places the result in
10// the receiver. The order of multiplication operations is optimized to minimize
11// the number of floating point operations on the basis that all matrix
12// multiplications are general.
13func (m *Dense) Product(factors ...Matrix) {
14	// The operation order optimisation is the naive O(n^3) dynamic
15	// programming approach and does not take into consideration
16	// finer-grained optimisations that might be available.
17	//
18	// TODO(kortschak) Consider using the O(nlogn) or O(mlogn)
19	// algorithms that are available. e.g.
20	//
21	// e.g. http://www.jofcis.com/publishedpapers/2014_10_10_4299_4306.pdf
22	//
23	// In the case that this is replaced, retain this code in
24	// tests to compare against.
25
26	r, c := m.Dims()
27	switch len(factors) {
28	case 0:
29		if r != 0 || c != 0 {
30			panic(ErrShape)
31		}
32		return
33	case 1:
34		m.reuseAs(factors[0].Dims())
35		m.Copy(factors[0])
36		return
37	case 2:
38		// Don't do work that we know the answer to.
39		m.Mul(factors[0], factors[1])
40		return
41	}
42
43	p := newMultiplier(m, factors)
44	p.optimize()
45	result := p.multiply()
46	m.reuseAs(result.Dims())
47	m.Copy(result)
48	putWorkspace(result)
49}
50
// debugProductWalk enables debugging output for Product. When it is true,
// optimize prints the dimension chain and multiply/multiplySubchain print
// each multiplication step and the final cost using fmt.Printf.
const debugProductWalk = false
53
// multiplier performs operation order optimisation and tree traversal
// for a chain of matrix factors.
type multiplier struct {
	// factors is the ordered set of
	// factors to multiply.
	factors []Matrix
	// dims is the chain of factor
	// dimensions. newMultiplier allocates it
	// with len(factors)+1 elements, and optimize
	// prices joining subchains i..k and k+1..j
	// as dims[i]*dims[k+1]*dims[j+1].
	dims []int

	// table contains the dynamic
	// programming costs and subchain
	// division indices. It is filled in
	// by optimize and read back by
	// multiplySubchain.
	table table
}
68
69func newMultiplier(m *Dense, factors []Matrix) *multiplier {
70	// Check size early, but don't yet
71	// allocate data for m.
72	r, c := m.Dims()
73	fr, fc := factors[0].Dims() // newMultiplier is only called with len(factors) > 2.
74	if !m.IsZero() {
75		if fr != r {
76			panic(ErrShape)
77		}
78		if _, lc := factors[len(factors)-1].Dims(); lc != c {
79			panic(ErrShape)
80		}
81	}
82
83	dims := make([]int, len(factors)+1)
84	dims[0] = r
85	dims[len(dims)-1] = c
86	pc := fc
87	for i, f := range factors[1:] {
88		cr, cc := f.Dims()
89		dims[i+1] = cr
90		if pc != cr {
91			panic(ErrShape)
92		}
93		pc = cc
94	}
95
96	return &multiplier{
97		factors: factors,
98		dims:    dims,
99		table:   newTable(len(factors)),
100	}
101}
102
103// optimize determines an optimal matrix multiply operation order.
104func (p *multiplier) optimize() {
105	if debugProductWalk {
106		fmt.Printf("chain dims: %v\n", p.dims)
107	}
108	const maxInt = int(^uint(0) >> 1)
109	for f := 1; f < len(p.factors); f++ {
110		for i := 0; i < len(p.factors)-f; i++ {
111			j := i + f
112			p.table.set(i, j, entry{cost: maxInt})
113			for k := i; k < j; k++ {
114				cost := p.table.at(i, k).cost + p.table.at(k+1, j).cost + p.dims[i]*p.dims[k+1]*p.dims[j+1]
115				if cost < p.table.at(i, j).cost {
116					p.table.set(i, j, entry{cost: cost, k: k})
117				}
118			}
119		}
120	}
121}
122
123// multiply walks the optimal operation tree found by optimize,
124// leaving the final result in the stack. It returns the
125// product, which may be copied but should be returned to
126// the workspace pool.
127func (p *multiplier) multiply() *Dense {
128	result, _ := p.multiplySubchain(0, len(p.factors)-1)
129	if debugProductWalk {
130		r, c := result.Dims()
131		fmt.Printf("\tpop result (%d×%d) cost=%d\n", r, c, p.table.at(0, len(p.factors)-1).cost)
132	}
133	return result.(*Dense)
134}
135
136func (p *multiplier) multiplySubchain(i, j int) (m Matrix, intermediate bool) {
137	if i == j {
138		return p.factors[i], false
139	}
140
141	a, aTmp := p.multiplySubchain(i, p.table.at(i, j).k)
142	b, bTmp := p.multiplySubchain(p.table.at(i, j).k+1, j)
143
144	ar, ac := a.Dims()
145	br, bc := b.Dims()
146	if ac != br {
147		// Panic with a string since this
148		// is not a user-facing panic.
149		panic(ErrShape.Error())
150	}
151
152	if debugProductWalk {
153		fmt.Printf("\tpush f[%d] (%d×%d)%s * f[%d] (%d×%d)%s\n",
154			i, ar, ac, result(aTmp), j, br, bc, result(bTmp))
155	}
156
157	r := getWorkspace(ar, bc, false)
158	r.Mul(a, b)
159	if aTmp {
160		putWorkspace(a.(*Dense))
161	}
162	if bTmp {
163		putWorkspace(b.(*Dense))
164	}
165	return r, true
166}
167
// entry is a single cell of the dynamic programming table.
type entry struct {
	k    int // k is the chain subdivision index.
	cost int // cost is the cost of the operation.
}
172
// table is a row major n×n dynamic programming table.
// The cell for row i and column j is stored at entries[i*n+j].
type table struct {
	n       int
	entries []entry
}
178
179func newTable(n int) table {
180	return table{n: n, entries: make([]entry, n*n)}
181}
182
183func (t table) at(i, j int) entry     { return t.entries[i*t.n+j] }
184func (t table) set(i, j int, e entry) { t.entries[i*t.n+j] = e }
185
// result flags an operand in the debug output of multiplySubchain. It
// formats as " (popped result)" when true and as the empty string when
// false.
type result bool

// String implements fmt.Stringer.
func (r result) String() string {
	if !r {
		return ""
	}
	return " (popped result)"
}
194