diff --git a/cmake/CPU.cmake b/cmake/CPU.cmake index 14979570e..08a246f35 100644 --- a/cmake/CPU.cmake +++ b/cmake/CPU.cmake @@ -58,11 +58,15 @@ endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pedantic") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=redundant-decls") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=old-style-cast") -# TODO: Add flags basing on native machine -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bw") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512vl") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c") +IF (C_AVX512_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bw") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512vl") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c") +ENDIF() +IF (C_AVX512_BF16_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bf16 -DAVX512_BF16") +ENDIF() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") # These flags are not available in GCC-4.8.5. Set only when using clang. # Compared against https://gcc.gnu.org/onlinedocs/gcc-4.8.5/gcc/Option-Summary.html diff --git a/cmake/Modules/FindAVX.cmake b/cmake/Modules/FindAVX.cmake index bb3abbab8..f6ce079dd 100644 --- a/cmake/Modules/FindAVX.cmake +++ b/cmake/Modules/FindAVX.cmake @@ -15,6 +15,18 @@ SET(AVX512_CODE " } ") +SET(AVX512_BF16_CODE " + #include + #include + + int main() { + __m512 src; + // detect avx512f and avx512bf16 + _mm512_cvtneps_pbh(src); + return 0; + } +") + MACRO(CHECK_SSE lang type flags) SET(__FLAG_I 1) SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) @@ -43,5 +55,8 @@ MACRO(CHECK_SSE lang type flags) MARK_AS_ADVANCED(${lang}_${type}_FOUND ${lang}_${type}_FLAGS) ENDMACRO() -CHECK_SSE(C "AVX512" " ;-mavx512f -mavx512bw -mavx512vl ;/arch:AVX512") -CHECK_SSE(CXX "AVX512" " ;-mavx512f -mavx512bw -mavx512vl ;/arch:AVX512") +CHECK_SSE(C "AVX512" " ;-mavx512f -mavx512bw -mavx512vl") +CHECK_SSE(CXX "AVX512" " ;-mavx512f -mavx512bw -mavx512vl") + +CHECK_SSE(C "AVX512_BF16" " ;-mavx512f -mavx512bf16") +CHECK_SSE(CXX "AVX512_BF16" " ;-mavx512f -mavx512bf16") diff --git a/torch_ipex/csrc/cpu/ExtendOPs.cpp b/torch_ipex/csrc/cpu/ExtendOPs.cpp index dd81bdb29..636bef071 100644 --- a/torch_ipex/csrc/cpu/ExtendOPs.cpp +++ b/torch_ipex/csrc/cpu/ExtendOPs.cpp @@ -189,7 +189,7 @@ static inline void mm_backward(at::BFloat16 *out, const at::BFloat16 *in1, const } // mm backward w/ fp32 mm_ker(tmp_in2, tmp_in1, tmp_out); - trunc_fp32_to_bf16(out, tmp_out, vector_nums * vector_size); + cvt_fp32_to_bf16(out, tmp_out, vector_nums * vector_size); } template diff --git a/torch_ipex/csrc/cpu/bf16/Converter.cpp b/torch_ipex/csrc/cpu/bf16/Converter.cpp index 0fa971226..97ebfef75 100644 --- a/torch_ipex/csrc/cpu/bf16/Converter.cpp +++ b/torch_ipex/csrc/cpu/bf16/Converter.cpp @@ -3,7 +3,7 @@ #if defined(AVX512) #include "vec/vec_type_cvt.h" #define BF16_2_FP32(dst, src, len) cvt_bf16_to_fp32(dst, src, len) -#define FP32_2_BF16(dst, src, len) trunc_fp32_to_bf16(dst, src, len) +#define FP32_2_BF16(dst, src, len) cvt_fp32_to_bf16(dst, src, len) #else #define BF16_2_FP32(dst, src, len) #define FP32_2_BF16(dst, src, len) diff --git a/torch_ipex/csrc/cpu/bf16/vec/bf16_vec_kernel.h b/torch_ipex/csrc/cpu/bf16/vec/bf16_vec_kernel.h index 4cd65f8c9..f58620a55 100644 --- a/torch_ipex/csrc/cpu/bf16/vec/bf16_vec_kernel.h +++ b/torch_ipex/csrc/cpu/bf16/vec/bf16_vec_kernel.h @@ -46,14 +46,14 @@ inline void add_ker(at::BFloat16 *inout, at::BFloat16 *in, int len) { auto x1 = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(inout + i))); auto x2 = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(in + i))); x1 = _mm512_add_ps(x1, x2); - _mm256_storeu_si256((__m256i*)(inout + i), trunc_fp32_to_bf16(x1)); + _mm256_storeu_si256((__m256i*)(inout + i), cvt_fp32_to_bf16(x1)); } if(i < len) { auto mask = (1 << (len - i)) - 1; auto x1 = cvt_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, inout + i)); auto x2 = cvt_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, in + i)); x1 = _mm512_add_ps(x1, x2); - _mm256_mask_storeu_epi16(inout + i, mask, trunc_fp32_to_bf16(x1)); + _mm256_mask_storeu_epi16(inout + i, mask, cvt_fp32_to_bf16(x1)); } } diff --git a/torch_ipex/csrc/cpu/bf16/vec/vec_type_cvt.h b/torch_ipex/csrc/cpu/bf16/vec/vec_type_cvt.h index 5fb8906d1..895c624f3 100644 --- a/torch_ipex/csrc/cpu/bf16/vec/vec_type_cvt.h +++ b/torch_ipex/csrc/cpu/bf16/vec/vec_type_cvt.h @@ -6,12 +6,6 @@ inline __m512 cvt_bf16_to_fp32(const __m256i src) { return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2)); } -// Truncate conversion from FP32 to BF16 -inline __m256i trunc_fp32_to_bf16(const __m512 src) { - auto y = _mm512_bsrli_epi128(_mm512_castps_si512(src), 2); - return _mm512_cvtepi32_epi16(y); -} - inline void cvt_bf16_to_fp32(float *dst, const at::BFloat16 *src, int len) { int i = 0; for (; i < len - 15; i += 16) { @@ -25,18 +19,29 @@ inline void cvt_bf16_to_fp32(float *dst, const at::BFloat16 *src, int len) { } } -inline void trunc_fp32_to_bf16(at::BFloat16 *dst, const float *src, int len) { +// Conversion from FP32 to BF16 +inline __m256i trunc_fp32_to_bf16(const __m512 src) { + auto y = _mm512_bsrli_epi128(_mm512_castps_si512(src), 2); + return _mm512_cvtepi32_epi16(y); +} + +inline __m256i cvt_fp32_to_bf16(const __m512 src) { +#if defined(AVX512_BF16) + return _mm512_cvtneps_pbh(src); +#else + return trunc_fp32_to_bf16(src); +#endif +} + +inline void cvt_fp32_to_bf16(at::BFloat16 *dst, const float *src, int len) { int i = 0; for (; i < len - 15; i += 16) { auto f32 = _mm512_loadu_ps(src + i); - _mm256_storeu_si256((__m256i *)(dst + i), trunc_fp32_to_bf16(f32)); + _mm256_storeu_si256((__m256i *)(dst + i), cvt_fp32_to_bf16(f32)); } if (i < len) { auto mask = (1 << (len - i )) - 1; auto f32 = _mm512_maskz_loadu_ps(mask, src + i); - _mm256_mask_storeu_epi16(dst + i, mask, trunc_fp32_to_bf16(f32)); + _mm256_mask_storeu_epi16(dst + i, mask, cvt_fp32_to_bf16(f32)); } } - -// TODO: -// RNE conversion from FP32 to BF16