/dports/math/libxsmm/libxsmm-1.16.3/samples/deeplearning/gxm/include/ |
H A D | common.hpp | 61 static __m512 gxm_fp32_to_bfp16_rne_adjustment_avx512f(__m512 vfp32) { in gxm_fp32_to_bfp16_rne_adjustment_avx512f() argument 66 __m512i vfp32_as_int = _mm512_castps_si512(vfp32); in gxm_fp32_to_bfp16_rne_adjustment_avx512f() 70 return _mm512_fixupimm_ps(_mm512_castsi512_ps(vfp32_as_int), vfp32, selector, 0); in gxm_fp32_to_bfp16_rne_adjustment_avx512f() 73 static __m256i gxm_fp32_to_bfp16_truncate_avx512f(__m512 vfp32) { in gxm_fp32_to_bfp16_truncate_avx512f() argument 74 __m512i vbfp16_32 = _mm512_srai_epi32(_mm512_castps_si512(vfp32), 16); in gxm_fp32_to_bfp16_truncate_avx512f()
|
/dports/math/libxsmm/libxsmm-1.16.3/samples/deeplearning/gxm/src/ |
H A D | reduce_weight_grads_bf16.c | 21 __m512 vfp32 = _mm512_add_ps(vfp32_l, vfp32_r); 22 __m512 vfp32rne = gxm_fp32_to_bfp16_rne_adjustment_avx512f(vfp32); 36 __m512 vfp32 = _mm512_add_ps(vfp32_l, vfp32_r); 37 __m512 vfp32rne = gxm_fp32_to_bfp16_rne_adjustment_avx512f(vfp32);
|
H A D | FullyConnected.cpp | 452 __m512 vfp32 = gxm_fp32_to_bfp16_rne_adjustment_avx512f(_mm512_loadu_ps(in + i)); in convert_f32_bf16() local 453 __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f(vfp32); in convert_f32_bf16() 479 __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 ); in convert_bf16_f32() local 480 _mm512_storeu_ps( out+i, vfp32 ); in convert_bf16_f32() 491 __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 ); in convert_bf16_f32() local 492 _mm512_storeu_ps( out+i, vfp32 ); in convert_bf16_f32()
|
H A D | FCXSMM.cpp | 410 __m512 vfp32 = _mm512_add_ps(vfp32_l, vfp32_r); in weightUpdate() local 411 __m512 vfp32rne = gxm_fp32_to_bfp16_rne_adjustment_avx512f(vfp32); in weightUpdate() 425 __m512 vfp32 = _mm512_add_ps(vfp32_l, vfp32_r); in weightUpdate() local 426 __m512 vfp32rne = gxm_fp32_to_bfp16_rne_adjustment_avx512f(vfp32); in weightUpdate()
|
H A D | DummyData.cpp | 120 __m512 vfp32 = gxm_fp32_to_bfp16_rne_adjustment_avx512f( _mm512_loadu_ps( in+i ) ); in convert_f32_bf16() local 121 __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f( vfp32 ); in convert_f32_bf16()
|
H A D | Conv.cpp | 528 __m512 vfp32 = gxm_fp32_to_bfp16_rne_adjustment_avx512f( _mm512_loadu_ps( in+i ) ); in convert_f32_bf16() local 529 __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f( vfp32 ); in convert_f32_bf16() 555 __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 ); in convert_bf16_f32() local 556 _mm512_storeu_ps( out+i, vfp32 ); in convert_bf16_f32() 567 __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 ); in convert_bf16_f32() local 568 _mm512_storeu_ps( out+i, vfp32 ); in convert_bf16_f32()
|
H A D | Split.cpp | 145 __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 ); in convert_bf16_f32() local 146 _mm512_storeu_ps( out+i, vfp32 ); in convert_bf16_f32()
|
H A D | Pooling.cpp | 205 __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 ); in convert_bf16_f32() local 206 _mm512_storeu_ps( out+i, vfp32 ); in convert_bf16_f32()
|
H A D | Solver.cpp | 65 __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 ); in convert_bf16_f32() local 66 _mm512_storeu_ps( outp+i, vfp32 ); in convert_bf16_f32() 80 __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 ); in convert_bf16_f32() local 81 _mm512_storeu_ps( out+i, vfp32 ); in convert_bf16_f32()
|
H A D | ConvXSMM.cpp | 777 __m512 vfp32 = _mm512_add_ps(vfp32_l, vfp32_r); in weightUpdate() local 778 __m512 vfp32rne = gxm_fp32_to_bfp16_rne_adjustment_avx512f(vfp32); in weightUpdate() 792 __m512 vfp32 = _mm512_add_ps(vfp32_l, vfp32_r); in weightUpdate() local 793 __m512 vfp32rne = gxm_fp32_to_bfp16_rne_adjustment_avx512f(vfp32); in weightUpdate()
|
H A D | Engine.cpp | 888 __m512 vfp32 = gxm_fp32_to_bfp16_rne_adjustment_avx512f( _mm512_loadu_ps( in+i ) ); in convert_f32_bf16() local 889 __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f( vfp32 ); in convert_f32_bf16() 915 __m512 vfp32 = gxm_fp32_to_bfp16_rne_adjustment_avx512f(_mm512_loadu_ps(inp + i)); in convert_f32_bf16() local 916 __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f(vfp32); in convert_f32_bf16() 930 __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 ); in convert_bf16_f32() local 931 _mm512_storeu_ps( out+i, vfp32 ); in convert_bf16_f32()
|
H A D | FusedConvBN.cpp | 580 __m512 vfp32 = gxm_fp32_to_bfp16_rne_adjustment_avx512f(_mm512_loadu_ps(in + i)); in convert_f32_bf16() local 581 __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f(vfp32); in convert_f32_bf16() 595 __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 ); in convert_bf16_f32() local 596 _mm512_storeu_ps( out+i, vfp32 ); in convert_bf16_f32()
|
H A D | JitterData.cpp | 521 __m512 vfp32 = gxm_fp32_to_bfp16_rne_adjustment_avx512f(_mm512_loadu_ps(in + i)); in convert_f32_bf16() local 522 __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f(vfp32); in convert_f32_bf16()
|
H A D | FusedBNorm.cpp | 314 __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 ); in convert_bf16_f32() local 315 _mm512_storeu_ps( out+i, vfp32 ); in convert_bf16_f32()
|