1 /***************************************************************************
2 Copyright (c) 2020, The OpenBLAS Project
3 All rights reserved.
4 Redistribution and use in source and binary forms, with or without
5 modification, are permitted provided that the following conditions are
6 met:
7 1. Redistributions of source code must retain the above copyright
8 notice, this list of conditions and the following disclaimer.
9 2. Redistributions in binary form must reproduce the above copyright
10 notice, this list of conditions and the following disclaimer in
11 the documentation and/or other materials provided with the
12 distribution.
13 3. Neither the name of the OpenBLAS project nor the names of
14 its contributors may be used to endorse or promote products
15 derived from this software without specific prior written permission.
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 *****************************************************************************/
27
28 #include "common.h"
29 #include <math.h>
30 #include <float.h>
31
/*
 * Precision dispatch for the kernel below.
 *
 * Exactly one branch is active per build: single precision unless DOUBLE is
 * defined. Each branch maps the generic names used by the kernel onto the
 * LMUL=8 vector intrinsics for that element width (older, non-`__riscv_`
 * prefixed intrinsic names), plus LMUL=1 float vectors (FLOAT_V_T_M1 and
 * friends) that the kernel uses as reduction operands/results.
 */
#if !defined(DOUBLE)

/* ---- single precision: 32-bit elements, LMUL = 8, masks are vbool4_t ---- */

/* scalar helper */
#define ABS fabsf

/* vector-length control */
#define VSETVL(n) vsetvl_e32m8(n)
#define VSETVL_MAX vsetvlmax_e32m1()

/* vector / mask types */
#define FLOAT_V_T vfloat32m8_t
#define FLOAT_V_T_M1 vfloat32m1_t
#define UINT_V_T vuint32m8_t
#define MASK_T vbool4_t

/* loads: unit-stride and byte-strided */
#define VLEV_FLOAT vle_v_f32m8
#define VLSEV_FLOAT vlse_v_f32m8

/* floating-point splat, max and max-reduction */
#define VFMVVF_FLOAT vfmv_v_f_f32m8
#define VFMVVF_FLOAT_M1 vfmv_v_f_f32m1
#define VFMAXVV_FLOAT vfmax_vv_f32m8
#define VFREDMAXVS_FLOAT vfredmax_vs_f32m8_f32m1

/* comparisons producing masks, and first-set-bit search */
#define VMFLTVV_FLOAT vmflt_vv_f32m8_b4
#define VMFGEVF_FLOAT vmfge_vf_f32m8_b4
#define VMFIRSTM vmfirst_m_b4

/* unsigned index vectors */
#define VMVVX_UINT vmv_v_x_u32m8
#define VIDV_UINT vid_v_u32m8
#define VIDV_MASK_UINT vid_v_u32m8_m
#define VADDVX_UINT vadd_vx_u32m8
#define VADDVX_MASK_UINT vadd_vx_u32m8_m

#else

/* ---- double precision: 64-bit elements, LMUL = 8, masks are vbool8_t ---- */

/* scalar helper */
#define ABS fabs

/* vector-length control */
#define VSETVL(n) vsetvl_e64m8(n)
#define VSETVL_MAX vsetvlmax_e64m1()

/* vector / mask types */
#define FLOAT_V_T vfloat64m8_t
#define FLOAT_V_T_M1 vfloat64m1_t
#define UINT_V_T vuint64m8_t
#define MASK_T vbool8_t

/* loads: unit-stride and byte-strided */
#define VLEV_FLOAT vle_v_f64m8
#define VLSEV_FLOAT vlse_v_f64m8

/* floating-point splat, max and max-reduction */
#define VFMVVF_FLOAT vfmv_v_f_f64m8
#define VFMVVF_FLOAT_M1 vfmv_v_f_f64m1
#define VFMAXVV_FLOAT vfmax_vv_f64m8
#define VFREDMAXVS_FLOAT vfredmax_vs_f64m8_f64m1

/* comparisons producing masks, and first-set-bit search */
#define VMFLTVV_FLOAT vmflt_vv_f64m8_b8
#define VMFGEVF_FLOAT vmfge_vf_f64m8_b8
#define VMFIRSTM vmfirst_m_b8

/* unsigned index vectors */
#define VMVVX_UINT vmv_v_x_u64m8
#define VIDV_UINT vid_v_u64m8
#define VIDV_MASK_UINT vid_v_u64m8_m
#define VADDVX_UINT vadd_vx_u64m8
#define VADDVX_MASK_UINT vadd_vx_u64m8_m

#endif
79
80
/*
 * I?MAX kernel (RISC-V vector): return the 1-based index of the FIRST element
 * of x holding the largest signed value (no absolute value is taken), or 0
 * when n <= 0 or inc_x <= 0.
 *
 * n      number of elements to examine
 * x      input vector; element i is x[i * inc_x]
 * inc_x  distance between consecutive elements, in FLOAT units (must be > 0)
 *
 * Strategy: every vector lane keeps its running maximum together with the
 * element index at which that maximum was first seen (the update uses a
 * strict '<', so a later equal value never replaces an earlier one). After the
 * main loop the lanes are reduced to the global maximum, and among the lanes
 * holding it the smallest recorded element index is chosen, so ties resolve
 * to the first occurrence. A partial tail (n not a multiple of the vector
 * length) is scanned separately and only wins if it is strictly larger.
 *
 * NOTE(review): per-lane indices are stored in UINT_V_T, which is 32 bits wide
 * for single precision, so single-precision inputs with 2^32 or more elements
 * are not supported by the main loop.
 */
BLASLONG CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
{
    BLASLONG i, j = 0;
    BLASLONG max_index = 0;
    if (n <= 0 || inc_x <= 0) return(max_index);

    FLOAT maxf;
    FLOAT_V_T vx, v_max;
    UINT_V_T v_max_index;
    MASK_T mask;
    FLOAT_V_T_M1 v_res, v_min;
    unsigned int gvl, k;
    /* Byte stride for strided loads; BLASLONG so a large inc_x cannot wrap. */
    BLASLONG stride_x = inc_x * (BLASLONG)sizeof(FLOAT);

    gvl = VSETVL_MAX;
    v_res = VFMVVF_FLOAT_M1(0, gvl);
    /* Seed with -INFINITY rather than -FLT_MAX: in double precision, values
     * below -FLT_MAX must still compare greater than the initial maximum. */
    v_min = VFMVVF_FLOAT_M1(-INFINITY, gvl);

    gvl = VSETVL(n);
    v_max = VFMVVF_FLOAT(-INFINITY, gvl);
    v_max_index = VMVVX_UINT(0, gvl);

    if (inc_x == 1) {
        for (i = 0; i < n / gvl; i++) {
            vx = VLEV_FLOAT(&x[j], gvl);

            // lanes where the new element is strictly greater record index j + lane
            mask = VMFLTVV_FLOAT(v_max, vx, gvl);
            v_max_index = VIDV_MASK_UINT(mask, v_max_index, gvl);
            v_max_index = VADDVX_MASK_UINT(mask, v_max_index, v_max_index, j, gvl);

            v_max = VFMAXVV_FLOAT(v_max, vx, gvl);
            j += gvl;
        }
    } else {
        for (i = 0; i < n / gvl; i++) {
            /* j * inc_x is computed in BLASLONG, so the offset cannot wrap */
            vx = VLSEV_FLOAT(&x[j * inc_x], stride_x, gvl);

            // lanes where the new element is strictly greater record index j + lane
            mask = VMFLTVV_FLOAT(v_max, vx, gvl);
            v_max_index = VIDV_MASK_UINT(mask, v_max_index, gvl);
            v_max_index = VADDVX_MASK_UINT(mask, v_max_index, v_max_index, j, gvl);

            v_max = VFMAXVV_FLOAT(v_max, vx, gvl);
            j += gvl;
        }
    }

    /* Global maximum over all lanes of the main loop. */
    v_res = VFREDMAXVS_FLOAT(v_res, v_max, v_min, gvl);
    maxf = v_res[0];

    /* Several lanes may hold maxf; the lowest lane number is not necessarily
     * the earliest element, so take the smallest recorded element index.
     * maxf is the reduction of these very lanes, so at least one matches and
     * the sentinel n is always replaced. */
    max_index = n;
    for (k = 0; k < gvl; k++) {
        if (v_max[k] == maxf && (BLASLONG)v_max_index[k] < max_index)
            max_index = (BLASLONG)v_max_index[k];
    }

    if (j < n) {
        /* The remaining n - j < gvl elements fit in a single vector. */
        gvl = VSETVL(n - j);
        if (inc_x == 1)
            vx = VLEV_FLOAT(&x[j], gvl);
        else
            vx = VLSEV_FLOAT(&x[j * inc_x], stride_x, gvl);

        v_res = VFREDMAXVS_FLOAT(v_res, vx, v_min, gvl);
        FLOAT cur_maxf = v_res[0];
        /* strict '>' keeps the earlier index on a tie with the main part */
        if (cur_maxf > maxf) {
            /* tail lanes are consecutive elements, so the first matching lane
             * is the first occurrence within the tail */
            mask = VMFGEVF_FLOAT(vx, cur_maxf, gvl);
            max_index = j + VMFIRSTM(mask, gvl);
        }
    }
    return(max_index + 1);
}
180
181
182