1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2009-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25
26 #if ! defined (octave_Sparse_perm_op_defs_h)
27 #define octave_Sparse_perm_op_defs_h 1
28
29 #include "octave-config.h"
30
31 #include "PermMatrix.h"
32 #include "lo-array-errwarn.h"
33 #include "oct-locbuf.h"
34 #include "oct-sort.h"
35 #include "quit.h"
36
37 // Matrix multiplication
38
39 template <typename SM>
octinternal_do_mul_colpm_sm(const octave_idx_type * pcol,const SM & a)40 SM octinternal_do_mul_colpm_sm (const octave_idx_type *pcol, const SM& a)
41 // Relabel the rows according to pcol.
42 {
43 const octave_idx_type nr = a.rows ();
44 const octave_idx_type nc = a.cols ();
45 const octave_idx_type nent = a.nnz ();
46 SM r (nr, nc, nent);
47
48 octave_sort<octave_idx_type> sort;
49
50 for (octave_idx_type j = 0; j <= nc; ++j)
51 r.xcidx (j) = a.cidx (j);
52
53 for (octave_idx_type j = 0; j < nc; j++)
54 {
55 octave_quit ();
56
57 OCTAVE_LOCAL_BUFFER (octave_idx_type, sidx, r.xcidx (j+1) - r.xcidx (j));
58 for (octave_idx_type i = r.xcidx (j), ii = 0; i < r.xcidx (j+1); i++)
59 {
60 sidx[ii++]=i;
61 r.xridx (i) = pcol[a.ridx (i)];
62 }
63 sort.sort (r.xridx () + r.xcidx (j), sidx, r.xcidx (j+1) - r.xcidx (j));
64 for (octave_idx_type i = r.xcidx (j), ii = 0; i < r.xcidx (j+1); i++)
65 r.xdata (i) = a.data (sidx[ii++]);
66 }
67
68 return r;
69 }
70
71 template <typename SM>
octinternal_do_mul_pm_sm(const PermMatrix & p,const SM & a)72 SM octinternal_do_mul_pm_sm (const PermMatrix& p, const SM& a)
73 {
74 const octave_idx_type nr = a.rows ();
75 if (p.cols () != nr)
76 octave::err_nonconformant ("operator *",
77 p.rows (), p.cols (), a.rows (), a.cols ());
78
79 return octinternal_do_mul_colpm_sm (p.col_perm_vec ().data (), a);
80 }
81
82 template <typename SM>
octinternal_do_mul_sm_rowpm(const SM & a,const octave_idx_type * prow)83 SM octinternal_do_mul_sm_rowpm (const SM& a, const octave_idx_type *prow)
84 // For a row permutation, iterate across the source a and stuff the
85 // results into the correct destination column in r.
86 {
87 const octave_idx_type nr = a.rows ();
88 const octave_idx_type nc = a.cols ();
89 const octave_idx_type nent = a.nnz ();
90 SM r (nr, nc, nent);
91
92 for (octave_idx_type j_src = 0; j_src < nc; ++j_src)
93 r.xcidx (prow[j_src]) = a.cidx (j_src+1) - a.cidx (j_src);
94 octave_idx_type k = 0;
95 for (octave_idx_type j = 0; j < nc; ++j)
96 {
97 const octave_idx_type tmp = r.xcidx (j);
98 r.xcidx (j) = k;
99 k += tmp;
100 }
101 r.xcidx (nc) = nent;
102
103 octave_idx_type k_src = 0;
104 for (octave_idx_type j_src = 0; j_src < nc; ++j_src)
105 {
106 octave_quit ();
107 const octave_idx_type j = prow[j_src];
108 const octave_idx_type kend_src = a.cidx (j_src + 1);
109 for (k = r.xcidx (j); k_src < kend_src; ++k, ++k_src)
110 {
111 r.xridx (k) = a.ridx (k_src);
112 r.xdata (k) = a.data (k_src);
113 }
114 }
115 assert (k_src == nent);
116
117 return r;
118 }
119
120 template <typename SM>
octinternal_do_mul_sm_colpm(const SM & a,const octave_idx_type * pcol)121 SM octinternal_do_mul_sm_colpm (const SM& a, const octave_idx_type *pcol)
122 // For a column permutation, iterate across the destination r and pull
123 // data from the correct column of a.
124 {
125 const octave_idx_type nr = a.rows ();
126 const octave_idx_type nc = a.cols ();
127 const octave_idx_type nent = a.nnz ();
128 SM r (nr, nc, nent);
129
130 for (octave_idx_type j = 0; j < nc; ++j)
131 {
132 const octave_idx_type j_src = pcol[j];
133 r.xcidx (j+1) = r.xcidx (j) + (a.cidx (j_src+1) - a.cidx (j_src));
134 }
135 assert (r.xcidx (nc) == nent);
136
137 octave_idx_type k = 0;
138 for (octave_idx_type j = 0; j < nc; ++j)
139 {
140 octave_quit ();
141 const octave_idx_type j_src = pcol[j];
142 octave_idx_type k_src;
143 const octave_idx_type kend_src = a.cidx (j_src + 1);
144 for (k_src = a.cidx (j_src); k_src < kend_src; ++k_src, ++k)
145 {
146 r.xridx (k) = a.ridx (k_src);
147 r.xdata (k) = a.data (k_src);
148 }
149 }
150 assert (k == nent);
151
152 return r;
153 }
154
155 template <typename SM>
octinternal_do_mul_sm_pm(const SM & a,const PermMatrix & p)156 SM octinternal_do_mul_sm_pm (const SM& a, const PermMatrix& p)
157 {
158 const octave_idx_type nc = a.cols ();
159 if (p.rows () != nc)
160 octave::err_nonconformant ("operator *",
161 a.rows (), a.cols (), p.rows (), p.cols ());
162
163 return octinternal_do_mul_sm_colpm (a, p.col_perm_vec ().data ());
164 }
165
166 #endif
167