Skip to content

enable bfloat16 RNE conversion with icx/icpx #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions cmake/CPU.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,15 @@ endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pedantic")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=redundant-decls")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=old-style-cast")
# TODO: Add flags based on the native machine.
#
# Append the AVX-512 / F16C code-generation flags only when the C++ compiler
# probe succeeded (CHECK_SSE(CXX "AVX512" ...) in cmake/Modules/FindAVX.cmake).
# The flags are appended to CMAKE_CXX_FLAGS, so the CXX probe result is the one
# that matters; the C and C++ compilers may differ in what they accept.
# The flags must not be added unconditionally as well: that would bypass the
# probe and pass unsupported options to compilers without AVX-512 support.
if(CXX_AVX512_FOUND)
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bw")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512vl")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c")
endif()

# AVX512_BF16 is the preprocessor symbol the C++ sources test to select the
# native FP32 -> BF16 conversion intrinsic; define it only together with the
# -mavx512bf16 flag that makes the intrinsic available.
if(CXX_AVX512_BF16_FOUND)
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bf16 -DAVX512_BF16")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
# These flags are not available in GCC-4.8.5. Set only when using clang.
# Compared against https://gcc.gnu.org/onlinedocs/gcc-4.8.5/gcc/Option-Summary.html
Expand Down
19 changes: 17 additions & 2 deletions cmake/Modules/FindAVX.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ SET(AVX512_CODE "
}
")

# Probe program for AVX512_BF16 support; presumably handed to the compiler
# check inside CHECK_SSE via the <type>_CODE naming convention (TODO confirm).
# The input vector is initialised and the result is consumed so the probe has
# no uninitialised read and does not fail under warnings-as-errors settings.
SET(AVX512_BF16_CODE "
  #include <stdint.h>
  #include <immintrin.h>

  int main() {
    // detect avx512f and avx512bf16
    __m512 src = _mm512_setzero_ps();
    __m256bh dst = _mm512_cvtneps_pbh(src);
    (void)dst;
    return 0;
  }
")

MACRO(CHECK_SSE lang type flags)
SET(__FLAG_I 1)
SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
Expand Down Expand Up @@ -43,5 +55,8 @@ MACRO(CHECK_SSE lang type flags)
MARK_AS_ADVANCED(${lang}_${type}_FOUND ${lang}_${type}_FLAGS)
ENDMACRO()

# Probe AVX-512 (F/BW/VL) support for the C and the C++ compiler.
# The third argument is a ;-separated list of candidate flag sets; its first
# entry is a single space, i.e. "no extra flags". How the candidates are tried
# is defined in CHECK_SSE above (presumably in order, first match wins --
# TODO confirm). Results are exposed as <lang>_AVX512_FOUND / <lang>_AVX512_FLAGS.
# This variant also offers the MSVC-style /arch:AVX512 switch.
CHECK_SSE(C "AVX512" " ;-mavx512f -mavx512bw -mavx512vl ;/arch:AVX512")
CHECK_SSE(CXX "AVX512" " ;-mavx512f -mavx512bw -mavx512vl ;/arch:AVX512")
# Same probe restricted to the GCC/Clang-style flag set.
CHECK_SSE(C "AVX512" " ;-mavx512f -mavx512bw -mavx512vl")
CHECK_SSE(CXX "AVX512" " ;-mavx512f -mavx512bw -mavx512vl")

# Probe AVX512_BF16 support (uses AVX512_BF16_CODE); results are exposed as
# <lang>_AVX512_BF16_FOUND / <lang>_AVX512_BF16_FLAGS.
CHECK_SSE(C "AVX512_BF16" " ;-mavx512f -mavx512bf16")
CHECK_SSE(CXX "AVX512_BF16" " ;-mavx512f -mavx512bf16")
2 changes: 1 addition & 1 deletion torch_ipex/csrc/cpu/ExtendOPs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ static inline void mm_backward(at::BFloat16 *out, const at::BFloat16 *in1, const
}
// mm backward w/ fp32
mm_ker(tmp_in2, tmp_in1, tmp_out);
trunc_fp32_to_bf16(out, tmp_out, vector_nums * vector_size);
cvt_fp32_to_bf16(out, tmp_out, vector_nums * vector_size);
}

template<typename T>
Expand Down
2 changes: 1 addition & 1 deletion torch_ipex/csrc/cpu/bf16/Converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#if defined(AVX512)
#include "vec/vec_type_cvt.h"
#define BF16_2_FP32(dst, src, len) cvt_bf16_to_fp32(dst, src, len)
#define FP32_2_BF16(dst, src, len) trunc_fp32_to_bf16(dst, src, len)
#define FP32_2_BF16(dst, src, len) cvt_fp32_to_bf16(dst, src, len)
#else
#define BF16_2_FP32(dst, src, len)
#define FP32_2_BF16(dst, src, len)
Expand Down
4 changes: 2 additions & 2 deletions torch_ipex/csrc/cpu/bf16/vec/bf16_vec_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ inline void add_ker(at::BFloat16 *inout, at::BFloat16 *in, int len) {
auto x1 = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(inout + i)));
auto x2 = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(in + i)));
x1 = _mm512_add_ps(x1, x2);
_mm256_storeu_si256((__m256i*)(inout + i), trunc_fp32_to_bf16(x1));
_mm256_storeu_si256((__m256i*)(inout + i), cvt_fp32_to_bf16(x1));
}
if(i < len) {
auto mask = (1 << (len - i)) - 1;
auto x1 = cvt_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, inout + i));
auto x2 = cvt_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, in + i));
x1 = _mm512_add_ps(x1, x2);
_mm256_mask_storeu_epi16(inout + i, mask, trunc_fp32_to_bf16(x1));
_mm256_mask_storeu_epi16(inout + i, mask, cvt_fp32_to_bf16(x1));
}
}

Expand Down
29 changes: 17 additions & 12 deletions torch_ipex/csrc/cpu/bf16/vec/vec_type_cvt.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,6 @@ inline __m512 cvt_bf16_to_fp32(const __m256i src) {
return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2));
}

// Truncating conversion from FP32 to BF16.
// Takes 16 packed floats and returns 16 packed 16-bit values, each holding the
// upper 16 bits of the corresponding float. The lower 16 bits are dropped
// without any rounding.
inline __m256i trunc_fp32_to_bf16(const __m512 src) {
// Shift every 128-bit lane right by 2 bytes so that the high half of each
// 32-bit element ends up in that element's low 16 bits.
auto y = _mm512_bsrli_epi128(_mm512_castps_si512(src), 2);
// Narrow each 32-bit element to its low 16 bits (plain truncation).
return _mm512_cvtepi32_epi16(y);
}

inline void cvt_bf16_to_fp32(float *dst, const at::BFloat16 *src, int len) {
int i = 0;
for (; i < len - 15; i += 16) {
Expand All @@ -25,18 +19,29 @@ inline void cvt_bf16_to_fp32(float *dst, const at::BFloat16 *src, int len) {
}
}

inline void trunc_fp32_to_bf16(at::BFloat16 *dst, const float *src, int len) {
// Conversion from FP32 to BF16
inline __m256i trunc_fp32_to_bf16(const __m512 src) {
auto y = _mm512_bsrli_epi128(_mm512_castps_si512(src), 2);
return _mm512_cvtepi32_epi16(y);
}

// Converts 16 packed FP32 values to 16 packed BF16 values.
// When AVX512_BF16 is defined the native _mm512_cvtneps_pbh instruction is
// used; otherwise the conversion falls back to plain truncation.
inline __m256i cvt_fp32_to_bf16(const __m512 src) {
#if defined(AVX512_BF16)
  // _mm512_cvtneps_pbh returns __m256bh. Reinterpret it explicitly, because an
  // implicit conversion to __m256i is rejected by compilers that do not enable
  // lax vector conversions for this type.
  return (__m256i)_mm512_cvtneps_pbh(src);
#else
  return trunc_fp32_to_bf16(src);
#endif
}

// Converts `len` FP32 values read from `src` into BF16 values written to `dst`.
// Each result is stored exactly once via cvt_fp32_to_bf16(__m512); a preceding
// truncating store to the same address would only be overwritten.
inline void cvt_fp32_to_bf16(at::BFloat16 *dst, const float *src, int len) {
  int i = 0;
  // Main loop: convert 16 elements per iteration while a full vector remains.
  for (; i < len - 15; i += 16) {
    auto f32 = _mm512_loadu_ps(src + i);
    _mm256_storeu_si256((__m256i *)(dst + i), cvt_fp32_to_bf16(f32));
  }
  // Tail: between 1 and 15 elements remain. Build a mask with one bit per
  // remaining element so the masked load/store touch only those elements.
  if (i < len) {
    const __mmask16 mask = static_cast<__mmask16>((1u << (len - i)) - 1u);
    auto f32 = _mm512_maskz_loadu_ps(mask, src + i);
    _mm256_mask_storeu_epi16(dst + i, mask, cvt_fp32_to_bf16(f32));
  }
}

// TODO:
// RNE conversion from FP32 to BF16