Mix bf16 fp32 #21

Closed · wants to merge 5 commits
7 changes: 4 additions & 3 deletions cmake/CPU.cmake
@@ -17,7 +17,7 @@ list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)

FIND_PACKAGE(AVX)

-IF (NOT C_AVX512_FOUND)
+IF (NOT C_AVX512_FOUND AND NOT CXX_AVX512_FOUND)
  message(FATAL_ERROR "Please build IPEX on Machines that support AVX512.")
ENDIF()

@@ -58,13 +58,14 @@ endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pedantic")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=redundant-decls")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=old-style-cast")
-IF (C_AVX512_FOUND)
+IF (C_AVX512_FOUND OR CXX_AVX512_FOUND)
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DAVX512")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bw")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512vl")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c")
ENDIF()
-IF (C_AVX512_BF16_FOUND)
+IF (C_AVX512_BF16_FOUND OR CXX_AVX512_BF16_FOUND)
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bf16 -DAVX512_BF16")
ENDIF()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
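For context, a minimal sketch of how the AVX512 / AVX512_BF16 defines added here are typically consumed on the C++ side (Converter.cpp below gates on AVX512 in exactly this way; the intrinsic named in the comments is illustrative, not taken from this diff):

    // Choose a bf16<->fp32 conversion strategy at compile time.
    #if defined(AVX512_BF16)
      // Native bf16 conversion instructions are available (e.g. _mm512_cvtneps_pbh).
    #elif defined(AVX512)
      // Emulate the conversion with generic AVX512F/BW/VL integer operations.
    #else
      // Scalar fallback; correct but slow.
    #endif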
53 changes: 50 additions & 3 deletions scripts/cpu/gen-dense-cpu-ops.py
@@ -65,6 +65,12 @@
    'aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor',
]

+_FN_BF16_FUNCS_WITH_SIMPLE_ATEN_SIG = [
+    'aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor',
+    'aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor',
+    'aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)',
+]
+
_SHALLOW_FALLBACK_TO_CPU_TENSOR_LIST = 'shallowFallbackToCPUTensorList'
_SHALLOW_FALLBACK_TO_CPU_TENSOR = 'shallowFallbackToCPUTensor'
_SHALLOW_UPGRADE_TO_DPCPP_TENSOR = 'shallowUpgradeToDPCPPTensor'
@@ -125,6 +131,7 @@ class AtenIpexCPUDefault {{
#include "utils.h"
#include "DevOPs.h"
#include "dbl/DNNLChecker.h"
#include "bf16/BF16Checker.h"

namespace torch_ipex {{
namespace cpu {{
@@ -177,6 +184,13 @@ def is_dnnl_func(self, simple_aten_sig):
                return True
        return False

+    def is_bf16_func(self, simple_aten_sig):
+        stripped_str = simple_aten_sig.replace(' ', '')
+        for item in _FN_BF16_FUNCS_WITH_SIMPLE_ATEN_SIG:
+            if stripped_str == item.replace(' ', ''):
+                return True
+        return False
+
    def is_bypass_func(self, cpp_sig):
        for frx in _FN_BYPASS_REGEX:
            if re.match(frx, cpp_sig.def_name):
@@ -256,6 +270,13 @@ def get_func_dec(self, cpp_sig):
            func_dec_str = func_dec_str.replace(key, _TYPE_NSMAP[key])
        return func_dec_str

+    def get_tensor_parameter(self, cpp_sig):
+        tensor_param_vars = []
+        for param in cpp_sig.input_params:
+            if param.core_type == 'Tensor':
+                tensor_param_vars.append(param.name)
+        return tensor_param_vars
+
    def gen_func_signature(self, cpp_func_str):
        cpp_func_str_h = cpp_func_str
        for key in _TYPE_NSMAP:
@@ -269,6 +290,33 @@ def gen_func_signature(self, cpp_func_str):

        return (cpp_func_str_h, cpp_func_str_cpp)

+    def gen_bf16_code(self, cpp_sig, aten_func_sig_str):
+        # In-place tensors are deliberately not supported here.
+        code = ''
+
+        reorder_func_name = 'reorderTensorToScalaraType'
+        if self.is_dnnl_func(aten_func_sig_str):
+            reorder_func_name = 'reorderTensorToScalarTypeForDNNL'
+
+        tensor_param_vars = self.get_tensor_parameter(cpp_sig)
+        if not self.is_bf16_func(aten_func_sig_str):
+            code += '  if (check_mix_bf16_fp32()) {\n'
+            for tensor_param_var in tensor_param_vars:
+                code += '    bridge::{}({}, at::kFloat);\n'.format(reorder_func_name, tensor_param_var)
+            code += '  }\n'
+            return code
+        else:
+            code += '  if (check_mix_bf16_fp32()) {\n'
+            code += '    std::vector<at::Tensor> dnnl_input_tensors;\n'
+            for tensor_param_var in tensor_param_vars:
+                code += '    dnnl_input_tensors.push_back({});\n'.format(tensor_param_var)
+            code += '    if (bf16::chk::bf16_support_the_tensors(dnnl_input_tensors)) {\n'
+            for tensor_param_var in tensor_param_vars:
+                code += '      bridge::{}({}, at::kBFloat16);\n'.format(reorder_func_name, tensor_param_var)
+            code += '    }\n'
+            code += '  }\n'
+            return code

    def gen_dnnl_code(self, cpp_sig, aten_func_sig_str):
        code = ''

@@ -278,11 +326,9 @@ def is_out_func(fname):
        if not self.is_dnnl_func(aten_func_sig_str):
            return code

+        dnnl_tensor_param_vars = self.get_tensor_parameter(cpp_sig)
        param_vars = []
-        dnnl_tensor_param_vars = []
        for param in cpp_sig.input_params:
-            if param.core_type == 'Tensor':
-                dnnl_tensor_param_vars.append(param.name)
            param_vars.append(param.name)

        code += '  try {\n'
@@ -472,6 +518,7 @@ def is_conv_overrideable_func(fname):
        # Gen definition code for cpp file
        code = '{} {{\n'.format(cpp_func_str_cpp)

+        code += self.gen_bf16_code(cpp_sig, aten_func_sig_str)
        if is_conv_overrideable_func(cpp_sig.def_name):
            code += '  return AtenIpexCPUDev::dil_{}({});\n'.format(cpp_sig.def_name, ', '.join([param.name for param in cpp_sig.input_params]))
        else:
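To make gen_bf16_code's output concrete: for aten::linear, one of the _FN_BF16_FUNCS entries above, and assuming linear is also registered as a DNNL function (so the reorderTensorToScalarTypeForDNNL helper is selected), the emitted preamble would look roughly like this, with parameter names taken from the ATen signature:

    if (check_mix_bf16_fp32()) {
      std::vector<at::Tensor> dnnl_input_tensors;
      dnnl_input_tensors.push_back(input);
      dnnl_input_tensors.push_back(weight);
      dnnl_input_tensors.push_back(bias);
      // Only reorder to bf16 when every input owns its whole storage.
      if (bf16::chk::bf16_support_the_tensors(dnnl_input_tensors)) {
        bridge::reorderTensorToScalarTypeForDNNL(input, at::kBFloat16);
        bridge::reorderTensorToScalarTypeForDNNL(weight, at::kBFloat16);
        bridge::reorderTensorToScalarTypeForDNNL(bias, at::kBFloat16);
      }
    }

Ops that are not in the bf16 whitelist get the opposite guard: every tensor argument is reordered back to at::kFloat, which keeps mixed bf16/fp32 execution confined to the whitelisted ops.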
59 changes: 59 additions & 0 deletions tests/cpu/test_bf16_lazy_reorder.py
@@ -0,0 +1,59 @@
"""Tests for lazy reorder."""
from __future__ import division
from __future__ import print_function

import os
import math
import time
import random
import unittest
from functools import reduce
import copy
import sys
import torch
import _torch_ipex as ipex
ipex._initialize_aten_bindings()
import intel_pytorch_extension

import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd import gradcheck
from torch.autograd.gradcheck import gradgradcheck
from torch._six import inf, nan

from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
    TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
    IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, do_test_empty_full, \
    IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
    skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf

def get_rand_seed():
    return int(time.time() * 1000000000)

device = torch.device("dpcpp:0")
class TestConv(TestCase):
    def test_Conv2d_with_cpu(self):
        rand_seed = int(get_rand_seed())
        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
        torch.manual_seed(rand_seed)
        conv_cpu = torch.nn.Conv2d(1, 1, (3, 3))

        conv_dpcpp = torch.nn.Conv2d(1, 1, (3, 3)).to(device=device)
        conv_dpcpp.weight.data = conv_cpu.weight.data.to(device=device)
        conv_dpcpp.bias.data = conv_cpu.bias.data.to(device=device)

        input_cpu = torch.rand((1, 1, 7, 7))
        input_dpcpp = input_cpu.to(device=device)

        ipex.enable_auto_dnnl()
        ipex.enable_mix_bf16_fp32()
        self.assertEqual(input_dpcpp.dtype, torch.float)
        out_dpcpp = conv_dpcpp(input_dpcpp)
        out_cpu = conv_cpu(input_cpu)
        self.assertEqual(out_dpcpp.dtype, torch.bfloat16)
        self.assertEqual(out_dpcpp, out_cpu, 1e-2)

if __name__ == '__main__':
    test = unittest.main()
17 changes: 11 additions & 6 deletions torch_ipex/csrc/aten_ipex_bridge.cpp
@@ -409,8 +409,7 @@ std::vector<at::Tensor> shallowFallbackToCPUTensorList(const at::TensorList& ten
void reorderTensorToScalarTypeForDNNL(const at::Tensor& ipexTensor, at::ScalarType dstScalarType) {
  TORCH_CHECK(dstScalarType == at::kBFloat16 || dstScalarType == at::kFloat);
  auto tensor_dtype = ipexTensor.scalar_type();
-  TORCH_CHECK(tensor_dtype == at::kBFloat16 || tensor_dtype == at::kFloat);
-  if (tensor_dtype == dstScalarType)
+  if ((tensor_dtype != at::kBFloat16 && tensor_dtype != at::kFloat) || tensor_dtype == dstScalarType)
    return;

  if (check_tensor_own_shade_context(ipexTensor)) {
@@ -421,6 +420,7 @@ void reorderTensorToScalarTypeForDNNL(const at::Tensor& ipexTensor, at::ScalarTy
    IPEXTensorImpl* ipex_tensor_impl = (IPEXTensorImpl *)ipexTensor.unsafeGetTensorImpl();
    ipex_tensor_impl->reset_data_type(dstScalarType);
    ipex_tensor_impl->storage().unsafeGetStorageImpl()->set_dtype(at::scalarTypeToTypeMeta(dstScalarType));
+    ipex_tensor_impl->set_storage(ipexTensor.storage());
    return;
  }
}
@@ -433,14 +433,18 @@ void reorderTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dst
  if (!(ipexTensor.defined()))
    return;

+  TORCH_CHECK(dstScalarType == at::kBFloat16 || dstScalarType == at::kFloat);
+
  auto tensor_dtype = ipexTensor.scalar_type();
-  TORCH_CHECK(tensor_dtype == at::kBFloat16 || tensor_dtype == at::kFloat);
-  if (tensor_dtype == dstScalarType)
+  if ((tensor_dtype != at::kBFloat16 && tensor_dtype != at::kFloat) || tensor_dtype == dstScalarType)
    return;

+  if (ipexTensor.is_sparse()) {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ipexTensor.layout() == c10::kSparse);
+    auto&& ipex_values = ipexTensor._values();
+    reorderTensorToScalaraType(ipex_values, dstScalarType);
+  }

  if (!check_tensor_own_whole_storage(ipexTensor)) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
    return;
  }

@@ -470,6 +474,7 @@ void reorderTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dst
  }

  ipexTensor.unsafeGetTensorImpl()->set_storage(storage_impl);
+  attachShadeDataConext(ipexTensor);
}


25 changes: 25 additions & 0 deletions torch_ipex/csrc/cpu/bf16/BF16Checker.cpp
@@ -0,0 +1,25 @@
#include "BF16Checker.h"

#include "torch_ipex/csrc/utils.h"
#include "torch_ipex/csrc/auto_opt_config.h"

namespace torch_ipex {
namespace cpu {
namespace bf16 {
namespace chk {

bool bf16_support_the_tensors(const std::vector<at::Tensor> &tensor_vec) {
  for (auto it = tensor_vec.begin(); it != tensor_vec.end(); ++it) {
    if (!check_tensor_own_whole_storage(*it)) {
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
      return false;
    }
  }

  return true;
}

} // namespace chk
} // namespace bf16
} // namespace cpu
} // namespace torch_ipex
20 changes: 20 additions & 0 deletions torch_ipex/csrc/cpu/bf16/BF16Checker.h
@@ -0,0 +1,20 @@
#pragma once

#include <ATen/Tensor.h>
#include <vector>

namespace torch_ipex {
namespace cpu {
namespace bf16 {
namespace chk {

/**
 * Check whether the input tensors can be supported by the BF16 ops.
 *
 * @param tensor_vec input tensors.
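 * @return true if every tensor owns its whole storage (so an in-place reorder is safe); false otherwise.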
 */
bool bf16_support_the_tensors(const std::vector<at::Tensor> &tensor_vec);

} // namespace chk
} // namespace bf16
} // namespace cpu
} // namespace torch_ipex
6 changes: 4 additions & 2 deletions torch_ipex/csrc/cpu/bf16/Converter.cpp
@@ -1,5 +1,7 @@
#include "Converter.h"

+#include <ATen/Tensor.h>
+
#if defined(AVX512)
#include "vec/vec_type_cvt.h"
#define BF16_2_FP32(dst, src, len) cvt_bf16_to_fp32(dst, src, len)
@@ -15,11 +17,11 @@ namespace bf16 {
namespace converter {

void bf16_to_fp32(void *dst, const void *src, int len) {
-  BF16_2_FP32(dst, src, len);
+  BF16_2_FP32((float *)dst, (at::BFloat16 *)src, len);
}

void fp32_to_bf16(void *dst, const void *src, int len) {
-  FP32_2_BF16(dst, src, len);
+  FP32_2_BF16((at::BFloat16 *)dst, (float *)src, len);
}

} // namespace converter
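The casts make the macro arguments match the typed signatures in vec_type_cvt.h. For reference, a minimal scalar sketch of the bit manipulation these conversions perform (the real kernels use AVX512/AVX512-BF16 intrinsics, and NaN handling is omitted here): a bf16 value is simply the upper 16 bits of the corresponding fp32 pattern.

    #include <cstdint>
    #include <cstring>

    // Widen bf16 -> fp32: shift into the high half; low mantissa bits become zero.
    static float bf16_bits_to_fp32(uint16_t src) {
      uint32_t bits = static_cast<uint32_t>(src) << 16;
      float out;
      std::memcpy(&out, &bits, sizeof(out));
      return out;
    }

    // Narrow fp32 -> bf16 with round-to-nearest-even on the dropped bits.
    static uint16_t fp32_bits_to_bf16(float src) {
      uint32_t bits;
      std::memcpy(&bits, &src, sizeof(bits));
      uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);
      return static_cast<uint16_t>((bits + rounding) >> 16);
    }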
4 changes: 4 additions & 0 deletions torch_ipex/csrc/utils.cpp
@@ -106,6 +106,10 @@ bool check_auto_dnnl() {
  return AutoOptConfig::singleton().get_auto_dnnl();
}

+bool check_mix_bf16_fp32() {
+  return AutoOptConfig::singleton().get_mix_bf16_fp32();
+}

bool check_tensor_own_whole_storage(const at::Tensor& tensor) {
  if (!(tensor.defined()))
    return false;
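auto_opt_config.h is not part of this diff, so here is only a plausible sketch of the flag this accessor reads, modeled on the existing get_auto_dnnl() pattern; everything except the get_mix_bf16_fp32 name is an assumption:

    // Hypothetical shape of the singleton behind check_mix_bf16_fp32().
    class AutoOptConfig {
     public:
      static AutoOptConfig& singleton() {
        static AutoOptConfig instance;
        return instance;
      }
      // Presumably toggled by the Python-level enable_mix_bf16_fp32() binding.
      void set_mix_bf16_fp32(bool flag) { mix_bf16_fp32_ = flag; }
      bool get_mix_bf16_fp32() const { return mix_bf16_fp32_; }

     private:
      AutoOptConfig() = default;
      bool mix_bf16_fp32_ = false;
    };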
1 change: 1 addition & 0 deletions torch_ipex/csrc/utils.h
@@ -18,6 +18,7 @@ bool get_device_count(c10::Device dev_type, c10::DeviceIndex *count);
dil::data_type get_dil_data_type(at::ScalarType);
at::ScalarType get_at_data_type(dil::data_type);
bool check_auto_dnnl();
+bool check_mix_bf16_fp32();
bool check_tensor_own_whole_storage(const at::Tensor& tensor);
bool check_tensor_own_shade_context(const at::Tensor& tensor);
bool check_aten_dil_shape_info(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor);