1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2009-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING.  If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if ! defined (octave_Sparse_perm_op_defs_h)
27 #define octave_Sparse_perm_op_defs_h 1
28 
29 #include "octave-config.h"
30 
31 #include "PermMatrix.h"
32 #include "lo-array-errwarn.h"
33 #include "oct-locbuf.h"
34 #include "oct-sort.h"
35 #include "quit.h"
36 
37 // Matrix multiplication
38 
39 template <typename SM>
octinternal_do_mul_colpm_sm(const octave_idx_type * pcol,const SM & a)40 SM octinternal_do_mul_colpm_sm (const octave_idx_type *pcol, const SM& a)
41 // Relabel the rows according to pcol.
42 {
43   const octave_idx_type nr = a.rows ();
44   const octave_idx_type nc = a.cols ();
45   const octave_idx_type nent = a.nnz ();
46   SM r (nr, nc, nent);
47 
48   octave_sort<octave_idx_type> sort;
49 
50   for (octave_idx_type j = 0; j <= nc; ++j)
51     r.xcidx (j) = a.cidx (j);
52 
53   for (octave_idx_type j = 0; j < nc; j++)
54     {
55       octave_quit ();
56 
57       OCTAVE_LOCAL_BUFFER (octave_idx_type, sidx, r.xcidx (j+1) - r.xcidx (j));
58       for (octave_idx_type i = r.xcidx (j), ii = 0; i < r.xcidx (j+1); i++)
59         {
60           sidx[ii++]=i;
61           r.xridx (i) = pcol[a.ridx (i)];
62         }
63       sort.sort (r.xridx () + r.xcidx (j), sidx, r.xcidx (j+1) - r.xcidx (j));
64       for (octave_idx_type i = r.xcidx (j), ii = 0; i < r.xcidx (j+1); i++)
65         r.xdata (i) = a.data (sidx[ii++]);
66     }
67 
68   return r;
69 }
70 
71 template <typename SM>
octinternal_do_mul_pm_sm(const PermMatrix & p,const SM & a)72 SM octinternal_do_mul_pm_sm (const PermMatrix& p, const SM& a)
73 {
74   const octave_idx_type nr = a.rows ();
75   if (p.cols () != nr)
76     octave::err_nonconformant ("operator *",
77                                p.rows (), p.cols (), a.rows (), a.cols ());
78 
79   return octinternal_do_mul_colpm_sm (p.col_perm_vec ().data (), a);
80 }
81 
82 template <typename SM>
octinternal_do_mul_sm_rowpm(const SM & a,const octave_idx_type * prow)83 SM octinternal_do_mul_sm_rowpm (const SM& a, const octave_idx_type *prow)
84 // For a row permutation, iterate across the source a and stuff the
85 // results into the correct destination column in r.
86 {
87   const octave_idx_type nr = a.rows ();
88   const octave_idx_type nc = a.cols ();
89   const octave_idx_type nent = a.nnz ();
90   SM r (nr, nc, nent);
91 
92   for (octave_idx_type j_src = 0; j_src < nc; ++j_src)
93     r.xcidx (prow[j_src]) = a.cidx (j_src+1) - a.cidx (j_src);
94   octave_idx_type k = 0;
95   for (octave_idx_type j = 0; j < nc; ++j)
96     {
97       const octave_idx_type tmp = r.xcidx (j);
98       r.xcidx (j) = k;
99       k += tmp;
100     }
101   r.xcidx (nc) = nent;
102 
103   octave_idx_type k_src = 0;
104   for (octave_idx_type j_src = 0; j_src < nc; ++j_src)
105     {
106       octave_quit ();
107       const octave_idx_type j = prow[j_src];
108       const octave_idx_type kend_src = a.cidx (j_src + 1);
109       for (k = r.xcidx (j); k_src < kend_src; ++k, ++k_src)
110         {
111           r.xridx (k) = a.ridx (k_src);
112           r.xdata (k) = a.data (k_src);
113         }
114     }
115   assert (k_src == nent);
116 
117   return r;
118 }
119 
120 template <typename SM>
octinternal_do_mul_sm_colpm(const SM & a,const octave_idx_type * pcol)121 SM octinternal_do_mul_sm_colpm (const SM& a, const octave_idx_type *pcol)
122 // For a column permutation, iterate across the destination r and pull
123 // data from the correct column of a.
124 {
125   const octave_idx_type nr = a.rows ();
126   const octave_idx_type nc = a.cols ();
127   const octave_idx_type nent = a.nnz ();
128   SM r (nr, nc, nent);
129 
130   for (octave_idx_type j = 0; j < nc; ++j)
131     {
132       const octave_idx_type j_src = pcol[j];
133       r.xcidx (j+1) = r.xcidx (j) + (a.cidx (j_src+1) - a.cidx (j_src));
134     }
135   assert (r.xcidx (nc) == nent);
136 
137   octave_idx_type k = 0;
138   for (octave_idx_type j = 0; j < nc; ++j)
139     {
140       octave_quit ();
141       const octave_idx_type j_src = pcol[j];
142       octave_idx_type k_src;
143       const octave_idx_type kend_src = a.cidx (j_src + 1);
144       for (k_src = a.cidx (j_src); k_src < kend_src; ++k_src, ++k)
145         {
146           r.xridx (k) = a.ridx (k_src);
147           r.xdata (k) = a.data (k_src);
148         }
149     }
150   assert (k == nent);
151 
152   return r;
153 }
154 
155 template <typename SM>
octinternal_do_mul_sm_pm(const SM & a,const PermMatrix & p)156 SM octinternal_do_mul_sm_pm (const SM& a, const PermMatrix& p)
157 {
158   const octave_idx_type nc = a.cols ();
159   if (p.rows () != nc)
160     octave::err_nonconformant ("operator *",
161                                a.rows (), a.cols (), p.rows (), p.cols ());
162 
163   return octinternal_do_mul_sm_colpm (a, p.col_perm_vec ().data ());
164 }
165 
166 #endif
167