diff --git a/cmake/CPU.cmake b/cmake/CPU.cmake index 4440f42f2..633b0bba7 100644 --- a/cmake/CPU.cmake +++ b/cmake/CPU.cmake @@ -13,6 +13,14 @@ SET(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE) set(DPCPP_CPU_ROOT "${PROJECT_SOURCE_DIR}/torch_ipex/csrc/cpu") add_subdirectory(${DPCPP_THIRD_PARTY_ROOT}/mkl-dnn) +list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) + +FIND_PACKAGE(AVX) + +IF (NOT C_AVX512_FOUND) + message(FATAL_ERROR "Please build IPEX on Machines that support AVX512.") +ENDIF() + # Define build type IF(CMAKE_BUILD_TYPE MATCHES Debug) message("Debug build.") diff --git a/cmake/Modules/FindAVX.cmake b/cmake/Modules/FindAVX.cmake new file mode 100644 index 000000000..bb3abbab8 --- /dev/null +++ b/cmake/Modules/FindAVX.cmake @@ -0,0 +1,47 @@ +INCLUDE(CheckCSourceCompiles) +INCLUDE(CheckCXXSourceCompiles) + +SET(AVX512_CODE " + #include + #include + + int main() { + __m256i src; + __mmask16 mask; + int16_t addr[16]; + // detect avx512f, avx512bw and avx512vl. + _mm512_cvtepi16_epi32(_mm256_mask_loadu_epi16(src, mask, (void *)addr)); + return 0; + } +") + +MACRO(CHECK_SSE lang type flags) + SET(__FLAG_I 1) + SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) + FOREACH(__FLAG ${flags}) + IF(NOT ${lang}_${type}_FOUND) + SET(CMAKE_REQUIRED_FLAGS ${__FLAG}) + IF(lang STREQUAL "CXX") + CHECK_CXX_SOURCE_COMPILES("${${type}_CODE}" ${lang}_HAS_${type}_${__FLAG_I}) + ELSE() + CHECK_C_SOURCE_COMPILES("${${type}_CODE}" ${lang}_HAS_${type}_${__FLAG_I}) + ENDIF() + IF(${lang}_HAS_${type}_${__FLAG_I}) + SET(${lang}_${type}_FOUND TRUE CACHE BOOL "${lang} ${type} support") + SET(${lang}_${type}_FLAGS "${__FLAG}" CACHE STRING "${lang} ${type} flags") + ENDIF() + MATH(EXPR __FLAG_I "${__FLAG_I}+1") + ENDIF() + ENDFOREACH() + SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) + + IF(NOT ${lang}_${type}_FOUND) + SET(${lang}_${type}_FOUND FALSE CACHE BOOL "${lang} ${type} support") + SET(${lang}_${type}_FLAGS "" CACHE STRING "${lang} ${type} flags") + ENDIF() + + MARK_AS_ADVANCED(${lang}_${type}_FOUND ${lang}_${type}_FLAGS) +ENDMACRO() + +CHECK_SSE(C "AVX512" " ;-mavx512f -mavx512bw -mavx512vl ;/arch:AVX512") +CHECK_SSE(CXX "AVX512" " ;-mavx512f -mavx512bw -mavx512vl ;/arch:AVX512")