1 /***************************************************************************
2 Copyright (c) 2020, The OpenBLAS Project
3 All rights reserved.
4 Redistribution and use in source and binary forms, with or without
5 modification, are permitted provided that the following conditions are
6 met:
7 1. Redistributions of source code must retain the above copyright
8 notice, this list of conditions and the following disclaimer.
9 2. Redistributions in binary form must reproduce the above copyright
10 notice, this list of conditions and the following disclaimer in
11 the documentation and/or other materials provided with the
12 distribution.
13 3. Neither the name of the OpenBLAS project nor the names of
14 its contributors may be used to endorse or promote products
15 derived from this software without specific prior written permission.
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 *****************************************************************************/
27 
28 #include "common.h"
29 #include <math.h>
30 #include <float.h>
31 
/*
 * Precision-dependent aliases for the RISC-V vector intrinsics used by the
 * kernel in this file.  When DOUBLE is defined the 64-bit element variants
 * (e64 / f64 / u64) are selected, otherwise the 32-bit variants
 * (e32 / f32 / u32).  Data and index vectors use the m8 register grouping;
 * the *_M1 aliases name the single-register (m1) type used for reduction
 * results.
 */
#if defined(DOUBLE)

/* NOTE(review): ABS is not referenced by the kernel in this file. */
#define ABS fabs
/* Vector-length requests: VSETVL for m8 data, VSETVL_MAX for m1 vectors. */
#define VSETVL(n) vsetvl_e64m8(n)
#define VSETVL_MAX vsetvlmax_e64m1()
/* Floating-point vector types. */
#define FLOAT_V_T vfloat64m8_t
#define FLOAT_V_T_M1 vfloat64m1_t
/* Unit-stride and strided loads. */
#define VLEV_FLOAT vle_v_f64m8
#define VLSEV_FLOAT vlse_v_f64m8
/* Max-reduction of an m8 vector into an m1 result. */
#define VFREDMAXVS_FLOAT vfredmax_vs_f64m8_f64m1
/* Mask type matching f64m8 and the compare that produces it. */
#define MASK_T vbool8_t
#define VMFLTVV_FLOAT vmflt_vv_f64m8_b8
/* Scalar broadcast into m8 / m1 floating-point vectors. */
#define VFMVVF_FLOAT vfmv_v_f_f64m8
#define VFMVVF_FLOAT_M1 vfmv_v_f_f64m1
/* Element-wise maximum, compare against a scalar, first-set-bit of a mask. */
#define VFMAXVV_FLOAT vfmax_vv_f64m8
#define VMFGEVF_FLOAT vmfge_vf_f64m8_b8
#define VMFIRSTM vmfirst_m_b8
/* Unsigned index vector (same element width as the data) and its helpers. */
#define UINT_V_T vuint64m8_t
#define VIDV_MASK_UINT vid_v_u64m8_m
#define VIDV_UINT vid_v_u64m8
#define VADDVX_MASK_UINT vadd_vx_u64m8_m
#define VADDVX_UINT vadd_vx_u64m8
#define VMVVX_UINT vmv_v_x_u64m8
#else

/* NOTE(review): ABS is not referenced by the kernel in this file. */
#define ABS fabsf
/* Vector-length requests: VSETVL for m8 data, VSETVL_MAX for m1 vectors. */
#define VSETVL(n) vsetvl_e32m8(n)
#define VSETVL_MAX vsetvlmax_e32m1()
/* Floating-point vector types. */
#define FLOAT_V_T vfloat32m8_t
#define FLOAT_V_T_M1 vfloat32m1_t
/* Unit-stride and strided loads. */
#define VLEV_FLOAT vle_v_f32m8
#define VLSEV_FLOAT vlse_v_f32m8
/* Max-reduction of an m8 vector into an m1 result. */
#define VFREDMAXVS_FLOAT vfredmax_vs_f32m8_f32m1
/* Mask type matching f32m8 and the compare that produces it. */
#define MASK_T vbool4_t
#define VMFLTVV_FLOAT vmflt_vv_f32m8_b4
/* Scalar broadcast into m8 / m1 floating-point vectors. */
#define VFMVVF_FLOAT vfmv_v_f_f32m8
#define VFMVVF_FLOAT_M1 vfmv_v_f_f32m1
/* Element-wise maximum, compare against a scalar, first-set-bit of a mask. */
#define VFMAXVV_FLOAT vfmax_vv_f32m8
#define VMFGEVF_FLOAT vmfge_vf_f32m8_b4
#define VMFIRSTM vmfirst_m_b4
/* Unsigned index vector (same element width as the data) and its helpers. */
#define UINT_V_T vuint32m8_t
#define VIDV_MASK_UINT vid_v_u32m8_m
#define VIDV_UINT vid_v_u32m8
#define VADDVX_MASK_UINT vadd_vx_u32m8_m
#define VADDVX_UINT vadd_vx_u32m8
#define VMVVX_UINT vmv_v_x_u32m8
#endif
79 
80 
/*
 * Index-of-maximum kernel written with RISC-V vector intrinsics.
 *
 * Parameters:
 *   n      number of elements to examine
 *   x      input vector
 *   inc_x  distance between consecutive elements, in units of FLOAT
 *
 * Returns the 1-based position of an element holding the largest value, or
 * 0 when n <= 0 or inc_x <= 0.
 *
 * Method: the input is consumed in blocks of gvl elements.  Each vector lane
 * keeps the largest value it has seen (v_max) together with the 0-based
 * position at which it first saw that value (v_max_index).  After the full
 * blocks, the lanes are reduced to a single maximum and the position stored
 * in the lowest-numbered lane that holds this maximum is taken.  A remaining
 * partial block replaces that answer only when it contains a strictly larger
 * value.
 *
 * Ties: when the maximum occurs more than once, the reported position is the
 * one recorded by the lowest-numbered lane reaching the maximum; this is not
 * necessarily the smallest position in the input.
 *
 * NOTE(review): in the single-precision build the position vector has 32-bit
 * elements, so positions are only tracked correctly for inputs shorter than
 * 2^32 elements.
 */
BLASLONG CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
{
	if (n <= 0 || inc_x <= 0) return(0);

	// Starting value for every running maximum.  -INFINITY is used instead of
	// -FLT_MAX so that double inputs below -FLT_MAX, and inputs containing
	// -inf, are still compared correctly against the initial state.
	const FLOAT lowest = -INFINITY;

	BLASLONG i=0, j=0;
	BLASLONG max_index = 0;   // 0-based position of the current best element
	BLASLONG lane = 0;        // lane number returned by the mask search
	FLOAT maxf = lowest;

	FLOAT_V_T vx, v_max;
	UINT_V_T v_max_index;
	MASK_T mask;
	unsigned int gvl = 0;
	FLOAT_V_T_M1 v_res, v_min;

	// m1 helpers for the reductions: v_min supplies the scalar start value.
	gvl = VSETVL_MAX;
	v_res = VFMVVF_FLOAT_M1(0, gvl);
	v_min = VFMVVF_FLOAT_M1(lowest, gvl);

	if(inc_x == 1){
		gvl = VSETVL(n);
		v_max_index = VMVVX_UINT(0, gvl);
		v_max = VFMVVF_FLOAT(lowest, gvl);
		for(i=0,j=0; i < n/gvl; i++){
			vx = VLEV_FLOAT(&x[j], gvl);

			// lanes whose new element is greater than the running maximum
			mask = VMFLTVV_FLOAT(v_max, vx, gvl);
			// for those lanes only: position = lane id + block start j
			v_max_index = VIDV_MASK_UINT(mask, v_max_index, gvl);
			v_max_index = VADDVX_MASK_UINT(mask, v_max_index, v_max_index, j,gvl);

			// update the running maxima and advance the block start
			v_max = VFMAXVV_FLOAT(v_max, vx, gvl);
			j += gvl;
		}
		// collapse the per-lane maxima and look up the matching position
		v_res = VFREDMAXVS_FLOAT(v_res, v_max, v_min, gvl);
		maxf = v_res[0];
		mask = VMFGEVF_FLOAT(v_max, maxf, gvl);
		lane = VMFIRSTM(mask,gvl);
		max_index = v_max_index[lane];

		if(j < n){
			// partial block: fewer than gvl elements remain
			gvl = VSETVL(n-j);
			v_max = VLEV_FLOAT(&x[j], gvl);

			v_res = VFREDMAXVS_FLOAT(v_res, v_max, v_min, gvl);
			FLOAT cur_maxf = v_res[0];
			if(cur_maxf > maxf){
				// positions of the tail elements
				v_max_index = VIDV_UINT(gvl);
				v_max_index = VADDVX_UINT(v_max_index, j, gvl);

				mask = VMFGEVF_FLOAT(v_max, cur_maxf, gvl);
				lane = VMFIRSTM(mask,gvl);
				max_index = v_max_index[lane];
			}
		}
	}else{
		gvl = VSETVL(n);
		// Byte stride and element offsets are kept in BLASLONG so that
		// large n * inc_x products do not wrap in 32 bits.
		BLASLONG stride_x = inc_x * (BLASLONG)sizeof(FLOAT);
		BLASLONG idx = 0, inc_v = (BLASLONG)gvl * inc_x;

		v_max = VFMVVF_FLOAT(lowest, gvl);
		v_max_index = VMVVX_UINT(0, gvl);
		for(i=0,j=0; i < n/gvl; i++){
			vx = VLSEV_FLOAT(&x[idx], stride_x, gvl);

			// lanes whose new element is greater than the running maximum
			mask = VMFLTVV_FLOAT(v_max, vx, gvl);
			// for those lanes only: position = lane id + block start j
			v_max_index = VIDV_MASK_UINT(mask, v_max_index, gvl);
			v_max_index = VADDVX_MASK_UINT(mask, v_max_index, v_max_index, j,gvl);

			// update the running maxima, the logical position j and the
			// storage offset idx
			v_max = VFMAXVV_FLOAT(v_max, vx, gvl);
			j += gvl;
			idx += inc_v;
		}
		// collapse the per-lane maxima and look up the matching position
		v_res = VFREDMAXVS_FLOAT(v_res, v_max, v_min, gvl);
		maxf = v_res[0];
		mask = VMFGEVF_FLOAT(v_max, maxf, gvl);
		lane = VMFIRSTM(mask,gvl);
		max_index = v_max_index[lane];

		if(j < n){
			// partial block: fewer than gvl elements remain
			gvl = VSETVL(n-j);
			v_max = VLSEV_FLOAT(&x[idx], stride_x, gvl);

			v_res = VFREDMAXVS_FLOAT(v_res, v_max, v_min, gvl);
			FLOAT cur_maxf = v_res[0];
			if(cur_maxf > maxf){
				// positions of the tail elements
				v_max_index = VIDV_UINT(gvl);
				v_max_index = VADDVX_UINT(v_max_index, j, gvl);

				mask = VMFGEVF_FLOAT(v_max, cur_maxf, gvl);
				lane = VMFIRSTM(mask,gvl);
				max_index = v_max_index[lane];
			}
		}
	}
	// convert the 0-based position to the 1-based result
	return(max_index+1);
}
180 
181 
182