Skip to content

Commit b051d72

Browse files
authored
Add AVX512 macro in CMake to enable AVX512 (#22)
1 parent 101fb32 commit b051d72

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

cmake/CPU.cmake

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
1717

1818
FIND_PACKAGE(AVX)
1919

20-
IF (NOT C_AVX512_FOUND)
20+
IF (NOT C_AVX512_FOUND AND NOT CXX_AVX512_FOUND)
2121
message(FATAL_ERROR "Please build IPEX on Machines that support AVX512.")
2222
ENDIF()
2323

@@ -58,13 +58,14 @@ endif()
5858
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pedantic")
5959
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=redundant-decls")
6060
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=old-style-cast")
61-
IF (C_AVX512_FOUND)
61+
IF (C_AVX512_FOUND OR CXX_AVX512_FOUND)
62+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DAVX512")
6263
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f")
6364
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bw")
6465
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512vl")
6566
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c")
6667
ENDIF()
67-
IF (C_AVX512_BF16_FOUND)
68+
IF (C_AVX512_BF16_FOUND OR CXX_AVX512_BF16_FOUND)
6869
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bf16 -DAVX512_BF16")
6970
ENDIF()
7071
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")

torch_ipex/csrc/cpu/bf16/Converter.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "Converter.h"
22

3+
#include <ATen/Tensor.h>
4+
35
#if defined(AVX512)
46
#include "vec/vec_type_cvt.h"
57
#define BF16_2_FP32(dst, src, len) cvt_bf16_to_fp32(dst, src, len)
@@ -15,11 +17,11 @@ namespace bf16 {
1517
namespace converter {
1618

1719
void bf16_to_fp32(void *dst, const void *src, int len) {
18-
BF16_2_FP32(dst, src, len);
20+
BF16_2_FP32((float *)dst, (at::BFloat16 *)src, len);
1921
}
2022

2123
void fp32_to_bf16(void *dst, const void *src, int len) {
22-
FP32_2_BF16(dst, src, len);
24+
FP32_2_BF16((at::BFloat16 *)dst, (float *)src, len);
2325
}
2426

2527
} // namespace converter

0 commit comments

Comments
 (0)