@@ -15,16 +15,16 @@ thread_local std::unordered_map<c10::TensorImpl *, val_type> cached_casts;
15
15
16
16
thread_local int nesting = 0 ;
17
17
18
- thread_local at::ScalarType current_target_dtype = at::kFloat ;
18
+ thread_local at::ScalarType current_target_dtype = at::kBFloat16 ;
19
19
} // namespace
20
20
21
21
bool is_autocast_enabled () {
22
- return c10::impl::tls_is_dispatch_key_included (c10::DispatchKey::AutocastCPU);
22
+ return ! c10::impl::tls_is_dispatch_key_excluded (c10::DispatchKey::AutocastCPU);
23
23
}
24
24
25
25
void set_autocast_enabled (bool new_enabled) {
26
- c10::impl::tls_set_dispatch_key_included (c10:: DispatchKey::AutocastCPU,
27
- new_enabled);
26
+ c10::impl::tls_set_dispatch_key_excluded ( DispatchKey::AutocastCPU,
27
+ ! new_enabled);
28
28
}
29
29
30
30
at::ScalarType get_autocast_dtype () {
@@ -176,23 +176,14 @@ MAKE_REGISTER_FUNC(ADD_NS(mm), "mm", Tensor (const Tensor &, const Tensor &), us
176
176
MAKE_REGISTER_FUNC (ADD_NS(baddbmm), " baddbmm" , Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), user_defined_dtype)
177
177
MAKE_REGISTER_FUNC (ADD_NS(addmm), " addmm" , Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), user_defined_dtype)
178
178
MAKE_REGISTER_FUNC (ADD_NS(addbmm), " addbmm" , Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), user_defined_dtype)
179
+ MAKE_REGISTER_FUNC (ADD_NS(conv_transpose1d), " conv_transpose1d" , Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&,
180
+ IntArrayRef, IntArrayRef, IntArrayRef, int64_t , IntArrayRef), user_defined_dtype)
181
+ MAKE_REGISTER_FUNC (ADD_NS(conv_transpose2d), " conv_transpose2d.input" , Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&,
182
+ IntArrayRef, IntArrayRef, IntArrayRef, int64_t , IntArrayRef), user_defined_dtype)
179
183
180
184
// fp32 cast policy
181
- MAKE_REGISTER_FUNC (ADD_NS(convolution), " convolution" , Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool , IntArrayRef, int64_t ), fp32)
182
185
MAKE_REGISTER_FUNC (ADD_NS(avg_pool2d), " avg_pool2d" , Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool , bool , c10::optional<int64_t >), fp32)
183
186
MAKE_REGISTER_FUNC (ADD_NS(avg_pool3d), " avg_pool3d" , Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool , bool , c10::optional<int64_t >), fp32)
184
- MAKE_REGISTER_FUNC (ADD_NS(upsample_nearest1d), " upsample_nearest1d" , Tensor (const Tensor &, IntArrayRef, c10::optional<double >), fp32)
185
- MAKE_REGISTER_FUNC (ADD_NS(upsample_nearest1d), " upsample_nearest1d.vec" , Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double >>), fp32)
186
- MAKE_REGISTER_FUNC (ADD_NS(upsample_nearest2d), " upsample_nearest2d" , Tensor (const Tensor &, IntArrayRef, c10::optional<double >, c10::optional<double >), fp32)
187
- MAKE_REGISTER_FUNC (ADD_NS(upsample_nearest2d), " upsample_nearest2d.vec" , Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double >>), fp32)
188
- MAKE_REGISTER_FUNC (ADD_NS(upsample_nearest3d), " upsample_nearest3d" , Tensor (const Tensor &, IntArrayRef, c10::optional<double >, c10::optional<double >, c10::optional<double >), fp32)
189
- MAKE_REGISTER_FUNC (ADD_NS(upsample_nearest3d), " upsample_nearest3d.vec" , Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double >>), fp32)
190
- MAKE_REGISTER_FUNC (ADD_NS(upsample_linear1d), " upsample_linear1d" , Tensor (const Tensor &, IntArrayRef, bool , c10::optional<double >), fp32)
191
- MAKE_REGISTER_FUNC (ADD_NS(upsample_linear1d), " upsample_linear1d.vec" , Tensor (const Tensor &, c10::optional<IntArrayRef>, bool , c10::optional<ArrayRef<double >>), fp32)
192
- MAKE_REGISTER_FUNC (ADD_NS(upsample_bilinear2d), " upsample_bilinear2d" , Tensor (const Tensor &, IntArrayRef, bool , c10::optional<double >, c10::optional<double >), fp32)
193
- MAKE_REGISTER_FUNC (ADD_NS(upsample_bilinear2d), " upsample_bilinear2d.vec" , Tensor (const Tensor &, c10::optional<IntArrayRef>, bool , c10::optional<ArrayRef<double >>), fp32)
194
- MAKE_REGISTER_FUNC (ADD_NS(upsample_trilinear3d), " upsample_trilinear3d" , Tensor (const Tensor &, IntArrayRef, bool , c10::optional<double >, c10::optional<double >, c10::optional<double >), fp32)
195
- MAKE_REGISTER_FUNC (ADD_NS(upsample_trilinear3d), " upsample_trilinear3d.vec" , Tensor (const Tensor &, c10::optional<IntArrayRef>, bool , c10::optional<ArrayRef<double >>), fp32)
196
187
MAKE_REGISTER_FUNC (ADD_NS(binary_cross_entropy), " binary_cross_entropy" , Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, int64_t ), fp32)
197
188
MAKE_REGISTER_FUNC (ADD_NS(binary_cross_entropy_with_logits), " binary_cross_entropy_with_logits" , Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t ), fp32)
198
189
MAKE_REGISTER_FUNC (ADD_NS(pow), " pow.Tensor_Scalar" , Tensor (const Tensor &, const Scalar &), fp32)
0 commit comments