Commit 05543d5

refactor_autocast_mechanism (#29)
1 parent c837979 commit 05543d5

5 files changed, +57 -29 lines changed

intel_pytorch_extension_py/amp/autocast_mode.py

Lines changed: 3 additions & 3 deletions
@@ -8,12 +8,12 @@

 class autocast(object):
     def __init__(self, enabled=True, configure=conf.AmpConf(torch.bfloat16)):
-        supported_dtype = [torch.float32, torch.bfloat16, torch.int8]
+        supported_dtype = [torch.bfloat16, torch.int8]
         if configure.dtype not in supported_dtype :
             warnings.warn("In CPU autocast, but the target dtype is not supported. Disable the autocast.")
-            warnings.warn("Supported dtype input is: torch.float32, torch.bfloat16, torch.int8.")
+            warnings.warn("Supported dtype input is: torch.bfloat16, torch.int8.")
             enabled = False
-            configure = conf.AmpConf(torch.float32)
+            configure = conf.AmpConf(torch.bfloat16)
         self._enabled = enabled
         self._dtype = configure.dtype
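With this change the Python frontend only accepts torch.bfloat16 and torch.int8 as autocast targets, and an unsupported dtype now falls back to a bfloat16 configuration instead of float32. A minimal usage sketch, not part of this commit; the `import intel_pytorch_extension as ipex` alias and the `ipex.amp.autocast` path are assumptions about how the package is exposed:

import torch
import intel_pytorch_extension as ipex  # assumed install/import name

model = torch.nn.Linear(64, 64)
x = torch.randn(8, 64)

# The default (and fallback) target dtype is now torch.bfloat16; requesting
# torch.float32 would disable autocast with a warning.
with ipex.amp.autocast(enabled=True):
    y = model(x)

print(y.dtype)  # expected: torch.bfloat16, since addmm follows the user-defined dtype policy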

torch_ipex/csrc/autocast_kernel.cpp

Lines changed: 4 additions & 4 deletions
@@ -51,9 +51,9 @@ at::Tensor conv_transpose3d(const at::Tensor& input, const at::Tensor& weight, c
 #if defined(ENABLE_AUTOCAST_VERBOSE)
   verbose::OpNameGuard op_name("conv_transpose3d");
 #endif
-  return at::conv_transpose3d(cpu_cached_cast(at::kFloat, input),
-                              cpu_cached_cast(at::kFloat, weight),
-                              cpu_cached_cast(at::kFloat, bias),
+  return at::conv_transpose3d(cpu_cached_cast(target_type, input),
+                              cpu_cached_cast(target_type, weight),
+                              cpu_cached_cast(target_type, bias),
                               stride, padding, output_padding, groups, dilation);
 }

@@ -222,7 +222,7 @@ at::Tensor gelu(const at::Tensor& input) {
     return int8::gelu(input);
   }
   // convert to fp32 path.
-  return at::gelu(cpu_cached_cast(at::kFloat, input));
+  return at::gelu(input);
 }

 } // autocast
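After this change conv_transpose3d casts its arguments to the active target dtype instead of unconditionally casting to fp32, and gelu no longer inserts a cast at all, so it runs in whatever dtype it receives. A small sketch of the expected behavior, with the same assumptions as above about the import alias and the public autocast path; shapes are arbitrary:

import torch
import torch.nn.functional as F
import intel_pytorch_extension as ipex  # assumed install/import name

x = torch.randn(1, 2, 4, 4, 4)  # fp32 input  (N, C_in, D, H, W)
w = torch.randn(2, 3, 3, 3, 3)  # fp32 weight (C_in, C_out, kD, kH, kW)

with ipex.amp.autocast(enabled=True):
    # conv_transpose3d now casts to the user-defined target dtype (bfloat16 by default)
    # instead of always falling back to fp32.
    y = F.conv_transpose3d(x, w)
    # gelu no longer forces fp32; it keeps the dtype of its input.
    z = F.gelu(y)

print(y.dtype, z.dtype)  # expected: torch.bfloat16 torch.bfloat16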

torch_ipex/csrc/autocast_mode.cpp

Lines changed: 8 additions & 17 deletions
@@ -15,16 +15,16 @@ thread_local std::unordered_map<c10::TensorImpl *, val_type> cached_casts;

 thread_local int nesting = 0;

-thread_local at::ScalarType current_target_dtype = at::kFloat;
+thread_local at::ScalarType current_target_dtype = at::kBFloat16;
 } // namespace

 bool is_autocast_enabled() {
-  return c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::AutocastCPU);
+  return !c10::impl::tls_is_dispatch_key_excluded(c10::DispatchKey::AutocastCPU);
 }

 void set_autocast_enabled(bool new_enabled) {
-  c10::impl::tls_set_dispatch_key_included(c10::DispatchKey::AutocastCPU,
-                                           new_enabled);
+  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCPU,
+                                           !new_enabled);
 }

 at::ScalarType get_autocast_dtype() {

@@ -176,23 +176,14 @@ MAKE_REGISTER_FUNC(ADD_NS(mm), "mm", Tensor (const Tensor &, const Tensor &), us
 MAKE_REGISTER_FUNC(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), user_defined_dtype)
 MAKE_REGISTER_FUNC(ADD_NS(addmm), "addmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), user_defined_dtype)
 MAKE_REGISTER_FUNC(ADD_NS(addbmm), "addbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), user_defined_dtype)
+MAKE_REGISTER_FUNC(ADD_NS(conv_transpose1d), "conv_transpose1d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&,
+                   IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), user_defined_dtype)
+MAKE_REGISTER_FUNC(ADD_NS(conv_transpose2d), "conv_transpose2d.input", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&,
+                   IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), user_defined_dtype)

 // fp32 cast policy
-MAKE_REGISTER_FUNC(ADD_NS(convolution), "convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t), fp32)
 MAKE_REGISTER_FUNC(ADD_NS(avg_pool2d), "avg_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
 MAKE_REGISTER_FUNC(ADD_NS(avg_pool3d), "avg_pool3d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
-MAKE_REGISTER_FUNC(ADD_NS(upsample_nearest1d), "upsample_nearest1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
-MAKE_REGISTER_FUNC(ADD_NS(upsample_nearest1d), "upsample_nearest1d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double>>), fp32)
-MAKE_REGISTER_FUNC(ADD_NS(upsample_nearest2d), "upsample_nearest2d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>, c10::optional<double>), fp32)
-MAKE_REGISTER_FUNC(ADD_NS(upsample_nearest2d), "upsample_nearest2d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double>>), fp32)
-MAKE_REGISTER_FUNC(ADD_NS(upsample_nearest3d), "upsample_nearest3d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>, c10::optional<double>, c10::optional<double>), fp32)
-MAKE_REGISTER_FUNC(ADD_NS(upsample_nearest3d), "upsample_nearest3d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double>>), fp32)
-MAKE_REGISTER_FUNC(ADD_NS(upsample_linear1d), "upsample_linear1d", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<double>), fp32)
-MAKE_REGISTER_FUNC(ADD_NS(upsample_linear1d), "upsample_linear1d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, bool, c10::optional<ArrayRef<double>>), fp32)
-MAKE_REGISTER_FUNC(ADD_NS(upsample_bilinear2d), "upsample_bilinear2d", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<double>, c10::optional<double>), fp32)
-MAKE_REGISTER_FUNC(ADD_NS(upsample_bilinear2d), "upsample_bilinear2d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, bool, c10::optional<ArrayRef<double>>), fp32)
-MAKE_REGISTER_FUNC(ADD_NS(upsample_trilinear3d), "upsample_trilinear3d", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<double>, c10::optional<double>, c10::optional<double>), fp32)
-MAKE_REGISTER_FUNC(ADD_NS(upsample_trilinear3d), "upsample_trilinear3d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, bool, c10::optional<ArrayRef<double>>), fp32)
 MAKE_REGISTER_FUNC(ADD_NS(binary_cross_entropy), "binary_cross_entropy", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, int64_t), fp32)
 MAKE_REGISTER_FUNC(ADD_NS(binary_cross_entropy_with_logits), "binary_cross_entropy_with_logits", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t), fp32)
 MAKE_REGISTER_FUNC(ADD_NS(pow), "pow.Tensor_Scalar", Tensor (const Tensor &, const Scalar &), fp32)
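Besides switching the enabled check to the excluded-dispatch-key form and dropping convolution and the upsample ops from the forced-fp32 list, this hunk registers conv_transpose1d and conv_transpose2d.input under the user_defined_dtype policy, alongside mm/addmm and friends. A sketch of what the new registrations are expected to do, under the same import-path assumptions as the earlier examples:

import torch
import torch.nn.functional as F
import intel_pytorch_extension as ipex  # assumed install/import name

x1 = torch.randn(1, 2, 8)     # (N, C_in, L)    for conv_transpose1d
w1 = torch.randn(2, 3, 3)     # (C_in, C_out, kW)
x2 = torch.randn(1, 2, 8, 8)  # (N, C_in, H, W) for conv_transpose2d
w2 = torch.randn(2, 3, 3, 3)  # (C_in, C_out, kH, kW)

with ipex.amp.autocast(enabled=True):
    # user_defined_dtype policy: arguments are cast to the configured target dtype
    # (bfloat16 by default) before the op runs.
    y1 = F.conv_transpose1d(x1, w1)
    y2 = F.conv_transpose2d(x2, w2)

print(y1.dtype, y2.dtype)  # expected: torch.bfloat16 torch.bfloat16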

torch_ipex/csrc/init_python_bindings.cpp

Lines changed: 0 additions & 1 deletion
@@ -77,7 +77,6 @@ void InitIpexModuleBindings(py::module m) {
   m.def("disable_jit_opt", []() { AutoOptConfig::singleton().set_jit_fuse(false); });
   m.def("get_jit_opt", []() { return AutoOptConfig::singleton().get_jit_fuse(); });

-
   // int8 path
   m.def("clear_autocast_cache_int8", &torch_ipex::autocast::int8::clear_autocast_cache_int8);
   m.def("enable_int8_calibration", []() { AutoOptConfig::singleton().set_int8_calibration(true); });

torch_patches/autocast.patch

Lines changed: 42 additions & 4 deletions
@@ -1,12 +1,50 @@
 diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h
-index b32f991df3..99bf28b380 100644
+index ff6a84ebbe..b3d3153169 100644
 --- a/c10/core/DispatchKey.h
 +++ b/c10/core/DispatchKey.h
-@@ -227,6 +227,7 @@ enum class DispatchKey : uint8_t {
+@@ -228,7 +228,7 @@ enum class DispatchKey : uint8_t {
+
   // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
   // and inputs are saved for backward in the post-autocast type.
-  Autocast,
+-  // AutocastCPU,
 +  AutocastCPU,
+   AutocastCUDA,

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
-  // There are a number of alternative modes which may want to handle before
+diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp
+index 272cf33118..8358e931f0 100644
+--- a/c10/core/DispatchKeySet.cpp
++++ b/c10/core/DispatchKeySet.cpp
+@@ -78,8 +78,8 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
+
+ DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
+   switch (t) {
+-    // case DispatchKey::CPU:
+-    //   return DispatchKeySet(DispatchKey::AutocastCPU);
++    case DispatchKey::CPU:
++      return DispatchKeySet(DispatchKey::AutocastCPU);
+     case DispatchKey::CUDA:
+       return DispatchKeySet(DispatchKey::AutocastCUDA);
+     default:
+diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h
+index 223355203c..e11572f23a 100644
+--- a/c10/core/DispatchKeySet.h
++++ b/c10/core/DispatchKeySet.h
+@@ -223,7 +223,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
+ });
+
+ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
+-  // DispatchKey::AutocastCPU,
++  DispatchKey::AutocastCPU,
+   DispatchKey::AutocastCUDA,
+ });
+
+@@ -234,7 +234,7 @@ constexpr DispatchKeySet default_included_set = DispatchKeySet({
+ });
+
+ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
+-  // DispatchKey::AutocastCPU,
++  DispatchKey::AutocastCPU,
+   DispatchKey::AutocastCUDA,
+ });
