1 /**************************************************************************************************
2 *                                                                                                 *
3 * This file is part of BLASFEO.                                                                   *
4 *                                                                                                 *
5 * BLASFEO -- BLAS for embedded optimization.                                                      *
6 * Copyright (C) 2019 by Gianluca Frison.                                                          *
7 * Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl.              *
8 * All rights reserved.                                                                            *
9 *                                                                                                 *
10 * The 2-Clause BSD License                                                                        *
11 *                                                                                                 *
12 * Redistribution and use in source and binary forms, with or without                              *
13 * modification, are permitted provided that the following conditions are met:                     *
14 *                                                                                                 *
15 * 1. Redistributions of source code must retain the above copyright notice, this                  *
16 *    list of conditions and the following disclaimer.                                             *
17 * 2. Redistributions in binary form must reproduce the above copyright notice,                    *
18 *    this list of conditions and the following disclaimer in the documentation                    *
19 *    and/or other materials provided with the distribution.                                       *
20 *                                                                                                 *
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND                 *
22 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED                   *
23 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE                          *
24 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR                 *
25 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES                  *
26 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;                    *
27 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND                     *
28 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT                      *
29 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS                   *
30 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.                                    *
31 *                                                                                                 *
32 * Author: Gianluca Frison, gianluca.frison (at) imtek.uni-freiburg.de                             *
33 *                                                                                                 *
34 **************************************************************************************************/
35 
36 #include <stdlib.h>
37 #include <stdio.h>
38 
39 #include "../include/blasfeo_common.h"
40 #include "../include/blasfeo_d_kernel.h"
41 
42 
43 
44 /****************************
45 * old interface
46 ****************************/
47 
// Left-multiplication by a diagonal matrix (old pointer-based interface):
//   D <- alpha * diag(dA) * B + beta * C
// dA holds the m diagonal entries; B, C and D are m x n matrices stored in
// panels of 4 rows, where consecutive panels are 4*sdb (resp. 4*sdc, 4*sdd)
// doubles apart.
// Rows are processed 4 at a time; a 1-3 row tail goes to the matching
// narrow kernel. Returns immediately when m<=0 or n<=0.
void dgemm_diag_left_lib(int m, int n, double alpha, double *dA, double *pB, int sdb, double beta, double *pC, int sdc, double *pD, int sdd)
	{

	if(m<=0 || n<=0)
		return;

	int ii = 0;

	if(beta==0.0)
		{
		// beta==0: the a0 kernel takes no C operand, so C is never read
		for( ; ii<m-3; ii+=4)
			{
			kernel_dgemm_diag_left_4_a0_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &pD[ii*sdd]);
			}
		}
	else
		{
		for( ; ii<m-3; ii+=4)
			{
			kernel_dgemm_diag_left_4_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &beta, &pC[ii*sdc], &pD[ii*sdd]);
			}
		}

	// remaining 1-3 rows
	// NOTE(review): unlike the 4-row path above, the tail kernels always
	// receive beta and pC, even when beta==0.0; if they evaluate beta*C
	// unconditionally, NaN/Inf values in C could reach D for these rows
	// only -- confirm against the kernel implementation.
	if(m-ii>0)
		{
		if(m-ii==1)
			kernel_dgemm_diag_left_1_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &beta, &pC[ii*sdc], &pD[ii*sdd]);
		else if(m-ii==2)
			kernel_dgemm_diag_left_2_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &beta, &pC[ii*sdc], &pD[ii*sdd]);
		else // m-ii==3
			kernel_dgemm_diag_left_3_lib4(n, &alpha, &dA[ii], &pB[ii*sdb], &beta, &pC[ii*sdc], &pD[ii*sdd]);
		}

	}
84 
85 
86 
// Right-multiplication by a diagonal matrix (old pointer-based interface):
//   D <- alpha * A * diag(dB) + beta * C
// dB holds the n diagonal entries; A, C and D are m x n matrices stored in
// panels of 4 rows, and column j of a panel starts at offset j*4.
// Columns are consumed four at a time; 1-3 leftover columns are handed to
// the matching narrow kernel. Returns immediately when m<=0 or n<=0.
void dgemm_diag_right_lib(int m, int n, double alpha, double *pA, int sda, double *dB, double beta, double *pC, int sdc, double *pD, int sdd)
	{

	const int bs = 4;

	int jj = 0;

	// empty result: nothing to compute
	if(m<=0 || n<=0)
		return;

	if(beta!=0.0)
		{
		// general case: C is passed to the kernel together with beta
		while(jj<n-3)
			{
			kernel_dgemm_diag_right_4_lib4(m, &alpha, pA+jj*bs, sda, dB+jj, &beta, pC+jj*bs, sdc, pD+jj*bs, sdd);
			jj += 4;
			}
		}
	else
		{
		// beta==0: the a0 kernel takes no C operand
		while(jj<n-3)
			{
			kernel_dgemm_diag_right_4_a0_lib4(m, &alpha, pA+jj*bs, sda, dB+jj, pD+jj*bs, sdd);
			jj += 4;
			}
		}

	// at this point 0..3 columns remain
	switch(n-jj)
		{
		case 1:
			kernel_dgemm_diag_right_1_lib4(m, &alpha, pA+jj*bs, sda, dB+jj, &beta, pC+jj*bs, sdc, pD+jj*bs, sdd);
			break;
		case 2:
			kernel_dgemm_diag_right_2_lib4(m, &alpha, pA+jj*bs, sda, dB+jj, &beta, pC+jj*bs, sdc, pD+jj*bs, sdd);
			break;
		case 3:
			kernel_dgemm_diag_right_3_lib4(m, &alpha, pA+jj*bs, sda, dB+jj, &beta, pC+jj*bs, sdc, pD+jj*bs, sdd);
			break;
		default: // no leftover columns
			break;
		}

	}
123 
124 
125 
126 /****************************
127 * new interface
128 ****************************/
129 
130 
131 
132 #if defined(LA_HIGH_PERFORMANCE)
133 
134 
135 
136 // dgemm with A diagonal matrix (stored as strvec)
blasfeo_dgemm_dn(int m,int n,double alpha,struct blasfeo_dvec * sA,int ai,struct blasfeo_dmat * sB,int bi,int bj,double beta,struct blasfeo_dmat * sC,int ci,int cj,struct blasfeo_dmat * sD,int di,int dj)137 void blasfeo_dgemm_dn(int m, int n, double alpha, struct blasfeo_dvec *sA, int ai, struct blasfeo_dmat *sB, int bi, int bj, double beta, struct blasfeo_dmat *sC, int ci, int cj, struct blasfeo_dmat *sD, int di, int dj)
138 	{
139 	if(m<=0 | n<=0)
140 		return;
141 	if(bi!=0 | ci!=0 | di!=0)
142 		{
143 		printf("\nblasfeo_dgemm_dn: feature not implemented yet: bi=%d, ci=%d, di=%d\n", bi, ci, di);
144 		exit(1);
145 		}
146 
147 	// invalidate stored inverse diagonal of result matrix
148 	sD->use_dA = 0;
149 
150 	const int bs = 4;
151 	int sdb = sB->cn;
152 	int sdc = sC->cn;
153 	int sdd = sD->cn;
154 	double *dA = sA->pa + ai;
155 	double *pB = sB->pA + bj*bs;
156 	double *pC = sC->pA + cj*bs;
157 	double *pD = sD->pA + dj*bs;
158 	dgemm_diag_left_lib(m, n, alpha, dA, pB, sdb, beta, pC, sdc, pD, sdd);
159 	return;
160 	}
161 
162 
163 
164 // dgemm with B diagonal matrix (stored as strvec)
blasfeo_dgemm_nd(int m,int n,double alpha,struct blasfeo_dmat * sA,int ai,int aj,struct blasfeo_dvec * sB,int bi,double beta,struct blasfeo_dmat * sC,int ci,int cj,struct blasfeo_dmat * sD,int di,int dj)165 void blasfeo_dgemm_nd(int m, int n, double alpha, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sB, int bi, double beta, struct blasfeo_dmat *sC, int ci, int cj, struct blasfeo_dmat *sD, int di, int dj)
166 	{
167 	if(m<=0 | n<=0)
168 		return;
169 	if(ai!=0 | ci!=0 | di!=0)
170 		{
171 		printf("\nblasfeo_dgemm_nd: feature not implemented yet: ai=%d, ci=%d, di=%d\n", ai, ci, di);
172 		exit(1);
173 		}
174 
175 	// invalidate stored inverse diagonal of result matrix
176 	sD->use_dA = 0;
177 
178 	const int bs = 4;
179 	int sda = sA->cn;
180 	int sdc = sC->cn;
181 	int sdd = sD->cn;
182 	double *pA = sA->pA + aj*bs;
183 	double *dB = sB->pa + bi;
184 	double *pC = sC->pA + cj*bs;
185 	double *pD = sD->pA + dj*bs;
186 	dgemm_diag_right_lib(m, n, alpha, pA, sda, dB, beta, pC, sdc, pD, sdd);
187 	return;
188 	}
189 
190 
191 
192 #else
193 
194 #error : wrong LA choice
195 
196 #endif
197 
198 
199 
200