Skip to content

Commit 8418986

Browse files
authored
Merge pull request #11 from hongzhen1/ws
enable bfloat16 RNE conversion with icx/icpx
2 parents 8eaf535 + cec2af8 commit 8418986

File tree

6 files changed

+47
-23
lines changed

6 files changed

+47
-23
lines changed

cmake/CPU.cmake

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,15 @@ endif()
5858
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pedantic")
5959
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=redundant-decls")
6060
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=old-style-cast")
61-
# TODO: Add flags basing on native machine
62-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f")
63-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bw")
64-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512vl")
65-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c")
61+
# Append the AVX-512 (F/BW/VL) and F16C options to the C++ compile flags only
# when the AVX512 compiler check succeeded, instead of unconditionally.
# NOTE(review): the guard reads the C-language result (C_AVX512_FOUND) although
# the flags are appended to CMAKE_CXX_FLAGS and FindAVX.cmake also invokes
# CHECK_SSE for CXX -- confirm which result is intended here.
# NOTE(review): -mf16c is not part of the flag set passed to the AVX512 check,
# so its availability is presumably assumed rather than detected -- confirm.
IF (C_AVX512_FOUND)
62+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f")
63+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bw")
64+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512vl")
65+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c")
66+
ENDIF()
67+
# When the AVX512_BF16 compiler check succeeded, enable the BF16 instruction
# set and define the AVX512_BF16 preprocessor macro for the C++ sources
# (vec_type_cvt.h tests this macro to pick the FP32 -> BF16 conversion path).
# NOTE(review): as with the AVX512 guard, this reads the C-language result while
# modifying CMAKE_CXX_FLAGS -- confirm that is intended.
IF (C_AVX512_BF16_FOUND)
68+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bf16 -DAVX512_BF16")
69+
ENDIF()
6670
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
6771
# These flags are not available in GCC-4.8.5. Set only when using clang.
6872
# Compared against https://gcc.gnu.org/onlinedocs/gcc-4.8.5/gcc/Option-Summary.html

cmake/Modules/FindAVX.cmake

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@ SET(AVX512_CODE "
1515
}
1616
")
1717

18+
# Probe program for AVX512_BF16 support: it references the _mm512_cvtneps_pbh
# intrinsic so that it can only build when the compiler accepts that intrinsic.
# Presumably consumed by CHECK_SSE through the "AVX512_BF16" type name, the same
# way AVX512_CODE is -- the macro body is not fully visible here, confirm.
# NOTE(review): `src` is declared without an initializer before being passed to
# the intrinsic; harmless if the probe is only compiled, worth initializing if
# the check also executes the program -- confirm against CHECK_SSE.
SET(AVX512_BF16_CODE "
19+
#include <stdint.h>
20+
#include <immintrin.h>
21+
22+
int main() {
23+
__m512 src;
24+
// detect avx512f and avx512bf16
25+
_mm512_cvtneps_pbh(src);
26+
return 0;
27+
}
28+
")
29+
1830
MACRO(CHECK_SSE lang type flags)
1931
SET(__FLAG_I 1)
2032
SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
@@ -43,5 +55,8 @@ MACRO(CHECK_SSE lang type flags)
4355
MARK_AS_ADVANCED(${lang}_${type}_FOUND ${lang}_${type}_FLAGS)
4456
ENDMACRO()
4557

46-
CHECK_SSE(C "AVX512" " ;-mavx512f -mavx512bw -mavx512vl ;/arch:AVX512")
47-
CHECK_SSE(CXX "AVX512" " ;-mavx512f -mavx512bw -mavx512vl ;/arch:AVX512")
58+
# Run the AVX512 and AVX512_BF16 probes for both the C and C++ compilers.
# Each call passes a semicolon-separated list of candidate flag sets; the first
# candidate is blank (no extra flags), the second is the GCC/Clang-style option
# set. The MSVC-style "/arch:AVX512" candidate is no longer offered.
# Results are exposed through ${lang}_${type}_FOUND / ${lang}_${type}_FLAGS
# (e.g. C_AVX512_BF16_FOUND), which the macro marks as advanced variables.
CHECK_SSE(C "AVX512" " ;-mavx512f -mavx512bw -mavx512vl")
59+
CHECK_SSE(CXX "AVX512" " ;-mavx512f -mavx512bw -mavx512vl")
60+
61+
CHECK_SSE(C "AVX512_BF16" " ;-mavx512f -mavx512bf16")
62+
CHECK_SSE(CXX "AVX512_BF16" " ;-mavx512f -mavx512bf16")

torch_ipex/csrc/cpu/ExtendOPs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ static inline void mm_backward(at::BFloat16 *out, const at::BFloat16 *in1, const
189189
}
190190
// mm backward w/ fp32
191191
mm_ker(tmp_in2, tmp_in1, tmp_out);
192-
trunc_fp32_to_bf16(out, tmp_out, vector_nums * vector_size);
192+
cvt_fp32_to_bf16(out, tmp_out, vector_nums * vector_size);
193193
}
194194

195195
template<typename T>

torch_ipex/csrc/cpu/bf16/Converter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#if defined(AVX512)
44
#include "vec/vec_type_cvt.h"
55
#define BF16_2_FP32(dst, src, len) cvt_bf16_to_fp32(dst, src, len)
6-
#define FP32_2_BF16(dst, src, len) trunc_fp32_to_bf16(dst, src, len)
6+
#define FP32_2_BF16(dst, src, len) cvt_fp32_to_bf16(dst, src, len)
77
#else
88
#define BF16_2_FP32(dst, src, len)
99
#define FP32_2_BF16(dst, src, len)

torch_ipex/csrc/cpu/bf16/vec/bf16_vec_kernel.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ inline void add_ker(at::BFloat16 *inout, at::BFloat16 *in, int len) {
4646
auto x1 = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(inout + i)));
4747
auto x2 = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(in + i)));
4848
x1 = _mm512_add_ps(x1, x2);
49-
_mm256_storeu_si256((__m256i*)(inout + i), trunc_fp32_to_bf16(x1));
49+
_mm256_storeu_si256((__m256i*)(inout + i), cvt_fp32_to_bf16(x1));
5050
}
5151
if(i < len) {
5252
auto mask = (1 << (len - i)) - 1;
5353
auto x1 = cvt_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, inout + i));
5454
auto x2 = cvt_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, in + i));
5555
x1 = _mm512_add_ps(x1, x2);
56-
_mm256_mask_storeu_epi16(inout + i, mask, trunc_fp32_to_bf16(x1));
56+
_mm256_mask_storeu_epi16(inout + i, mask, cvt_fp32_to_bf16(x1));
5757
}
5858
}
5959

torch_ipex/csrc/cpu/bf16/vec/vec_type_cvt.h

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,6 @@ inline __m512 cvt_bf16_to_fp32(const __m256i src) {
66
return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2));
77
}
88

9-
// Truncate conversion from FP32 to BF16
10-
inline __m256i trunc_fp32_to_bf16(const __m512 src) {
11-
auto y = _mm512_bsrli_epi128(_mm512_castps_si512(src), 2);
12-
return _mm512_cvtepi32_epi16(y);
13-
}
14-
159
inline void cvt_bf16_to_fp32(float *dst, const at::BFloat16 *src, int len) {
1610
int i = 0;
1711
for (; i < len - 15; i += 16) {
@@ -25,18 +19,29 @@ inline void cvt_bf16_to_fp32(float *dst, const at::BFloat16 *src, int len) {
2519
}
2620
}
2721

28-
inline void trunc_fp32_to_bf16(at::BFloat16 *dst, const float *src, int len) {
22+
// Truncating conversion from FP32 to BF16: for each of the 16 floats in `src`,
// keeps the upper 16 bits and discards the lower 16 bits (no rounding).
// Returns the 16 resulting 16-bit values packed into a 256-bit integer vector.
23+
inline __m256i trunc_fp32_to_bf16(const __m512 src) {
24+
// Byte-shift every 128-bit lane right by 2 bytes so that the high half of each
// 32-bit element ends up in that element's low 16 bits.
auto y = _mm512_bsrli_epi128(_mm512_castps_si512(src), 2);
25+
// Narrow each 32-bit element to its low 16 bits; whatever the shift moved into
// the high halves is dropped here.
return _mm512_cvtepi32_epi16(y);
26+
}
27+
28+
// Converts the 16 packed FP32 values in `src` to 16 packed BF16 values.
// When the build defines AVX512_BF16 the native conversion intrinsic is used
// (the round-to-nearest-even path this change enables); otherwise the
// conversion falls back to plain truncation via trunc_fp32_to_bf16.
// NOTE(review): _mm512_cvtneps_pbh is declared as returning __m256bh, while this
// function returns __m256i, so the return relies on the compiler accepting an
// implicit vector conversion -- confirm on every supported compiler.
inline __m256i cvt_fp32_to_bf16(const __m512 src) {
29+
#if defined(AVX512_BF16)
30+
return _mm512_cvtneps_pbh(src);
31+
#else
32+
return trunc_fp32_to_bf16(src);
33+
#endif
34+
}
35+
36+
// Converts `len` FP32 values read from `src` into BF16 values written to `dst`.
// Full groups of 16 elements use unaligned 512-bit loads and 256-bit stores;
// a final partial group (fewer than 16 elements) uses masked load/store so no
// memory beyond `len` elements is read or written. Does nothing when len <= 0.
// Each group is converted with the vector overload cvt_fp32_to_bf16(__m512).
inline void cvt_fp32_to_bf16(at::BFloat16 *dst, const float *src, int len) {
2937
int i = 0;
3038
// Main loop: runs only while at least 16 elements remain.
for (; i < len - 15; i += 16) {
3139
auto f32 = _mm512_loadu_ps(src + i);
32-
_mm256_storeu_si256((__m256i *)(dst + i), trunc_fp32_to_bf16(f32));
40+
_mm256_storeu_si256((__m256i *)(dst + i), cvt_fp32_to_bf16(f32));
3341
}
3442
if (i < len) {
3543
// Tail: here 1 <= len - i <= 15, so the mask has exactly (len - i) low bits
// set, one per remaining element.
auto mask = (1 << (len - i )) - 1;
3644
auto f32 = _mm512_maskz_loadu_ps(mask, src + i);
37-
_mm256_mask_storeu_epi16(dst + i, mask, trunc_fp32_to_bf16(f32));
45+
_mm256_mask_storeu_epi16(dst + i, mask, cvt_fp32_to_bf16(f32));
3846
}
3947
}
40-
41-
// TODO:
42-
// RNE conversion from FP32 to BF16

0 commit comments

Comments
 (0)