1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Alexander Heinecke (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_dnn_softmaxloss_forward.h"
12 #include "libxsmm_main.h"
13
14
15 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);
16 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);
17
18
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)19 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
20 libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
21 {
22 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
23 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
24 typedef float element_input_type;
25 typedef float element_output_type;
26 typedef int element_label_type;
27
28 # include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
29 #else /* should not happen */
30 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
31 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
32 #endif
33 return status;
34 }
35
36
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)37 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
38 libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
39 {
40 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
41 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
42 typedef libxsmm_bfloat16 element_input_type;
43 typedef libxsmm_bfloat16 element_output_type;
44 typedef int element_label_type;
45
46 # define LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512
47 # include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
48 # undef LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512
49 #else /* should not happen */
50 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
51 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
52 #endif
53 return status;
54 }
55
56
libxsmm_dnn_softmaxloss_st_fwd_ncnc(libxsmm_dnn_softmaxloss * handle,int start_thread,int tid)57 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
58 {
59 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
60
61 /* check if we have input, output and mask */
62 if ( handle->reg_input == 0 || handle->reg_output == 0 || handle->label == 0 ) {
63 status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
64 return status;
65 }
66
67 /* check if we are on an AVX512 platform */
68 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
69 if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
70 if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) {
71 status = libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32( handle, start_thread, tid);
72 } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) {
73 status = libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16( handle, start_thread, tid);
74 } else {
75 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
76 return status;
77 }
78 } else
79 #endif
80 {
81 if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) {
82 typedef float element_input_type;
83 typedef float element_output_type;
84 typedef int element_label_type;
85
86 # include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
87 } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) {
88 typedef libxsmm_bfloat16 element_input_type;
89 typedef libxsmm_bfloat16 element_output_type;
90 typedef int element_label_type;
91
92 # define LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16
93 # include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
94 # undef LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16
95 } else {
96 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
97 return status;
98 }
99 }
100
101 return status;
102 }
103
104