/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
11 #include "libxsmm_dnn_softmaxloss_forward.h"
12 #include "libxsmm_main.h"
13 
14 
15 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);
16 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);
17 
18 
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)19 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
20 libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
21 {
22   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
23 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
24   typedef float element_input_type;
25   typedef float element_output_type;
26   typedef int   element_label_type;
27 
28 # include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
29 #else /* should not happen */
30   LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
31   status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
32 #endif
33   return status;
34 }
35 
36 
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)37 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
38 libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
39 {
40   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
41 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
42   typedef libxsmm_bfloat16 element_input_type;
43   typedef libxsmm_bfloat16 element_output_type;
44   typedef int              element_label_type;
45 
46 # define LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512
47 # include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
48 # undef LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512
49 #else /* should not happen */
50   LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
51   status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
52 #endif
53   return status;
54 }
55 
56 
libxsmm_dnn_softmaxloss_st_fwd_ncnc(libxsmm_dnn_softmaxloss * handle,int start_thread,int tid)57 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
58 {
59   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
60 
61   /* check if we have input, output and mask */
62   if ( handle->reg_input == 0 || handle->reg_output == 0 || handle->label == 0 ) {
63     status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
64     return status;
65   }
66 
67   /* check if we are on an AVX512 platform */
68 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
69   if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
70     if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) {
71       status = libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32( handle, start_thread, tid);
72     } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) {
73       status = libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16( handle, start_thread, tid);
74     } else {
75       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
76       return status;
77     }
78   } else
79 #endif
80   {
81     if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) {
82       typedef float element_input_type;
83       typedef float element_output_type;
84       typedef int   element_label_type;
85 
86 # include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
87     } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) {
88       typedef libxsmm_bfloat16 element_input_type;
89       typedef libxsmm_bfloat16 element_output_type;
90       typedef int     element_label_type;
91 
92 # define LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16
93 # include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
94 # undef LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16
95     } else {
96       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
97       return status;
98     }
99   }
100 
101   return status;
102 }
103 
104