Skip to content

Commit 054697e

Browse files
chunyuan-wEikanWang
authored andcommitted
enable int8 for LSTM
1 parent 4017ffb commit 054697e

File tree

11 files changed

+397
-59
lines changed

11 files changed

+397
-59
lines changed

torch_ipex/csrc/cpu/CustomOPs.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -714,15 +714,16 @@ class NewRNNLayerOp : public torch::autograd::Function<NewRNNLayerOp> {
714714
public:
715715
static std::vector<at::Tensor> _forward(const at::Tensor& input, const at::Tensor& w1, const at::Tensor& w2,
716716
const at::Tensor& w3, const at::Tensor& w4, const at::Tensor& hx, const at::Tensor& cx, bool reverse, int64_t mode,
717-
int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes) {
717+
int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes,
718+
const std::vector<float>& scales = {}, const std::vector<int32_t>& shift = {}, bool quantized = false) {
718719
#if defined(IPEX_PROFILE_OP)
719720
RECORD_FUNCTION("NewRNNLayerOp::_forward", std::vector<c10::IValue>({}));
720721
#endif
721722
try {
722723
if (torch_ipex::check_auto_dnnl() &&
723724
input.device().type() == c10::DeviceType::XPU) {
724725
return torch_ipex::cpu::AtenIpexCPUDev::dil_rnn_layer(
725-
input, w1, w2, w3, w4, hx, cx, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes);
726+
input, w1, w2, w3, w4, hx, cx, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, scales, shift, quantized);
726727
}
727728
} catch (std::exception &e) {
728729
#if defined(_DEBUG)
@@ -783,6 +784,7 @@ class NewRNNLayerOp : public torch::autograd::Function<NewRNNLayerOp> {
783784
grad_inputs[3], grad_inputs[4], grad_inputs[5],
784785
grad_inputs[6], at::Tensor(), at::Tensor(),
785786
at::Tensor(), at::Tensor(), at::Tensor(),
787+
at::Tensor(), at::Tensor(), at::Tensor(),
786788
at::Tensor(), at::Tensor(), at::Tensor()};
787789
}
788790
} catch (std::exception &e) {

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2075,17 +2075,60 @@ at::Tensor AtenIpexCPUDev::dil_cat(at::TensorList tensors, int64_t dim) {
20752075
dim = at::legacy_cat_wrap_dim(dim, tensors);
20762076
std::vector<dil::tensor> x;
20772077
at::Tensor tensors_contiguous[tensors.size()];
2078+
2079+
bool has_scale = false;
2080+
bool has_shift = false;
2081+
dil::scale_t data_scale;
2082+
std::vector<int32_t> data_shift;
2083+
20782084
for (auto i = 0; i < tensors.size(); i++) {
20792085
IPEX_CHECK(!(tensors[i].dim() == 1 && tensors[i].sizes()[0] == 0),
20802086
"Currently Mkldnn cat operators do not support empty tensor.");
20812087
tensors_contiguous[i] = tensors[i].is_contiguous() ? tensors[i] : tensors[i].contiguous();
20822088

20832089
dbl::comm::reorder_to_bf16_for_mix_prec(tensors_contiguous[i], true);
20842090

2085-
x.push_back(dbl::comm::try_gen_dil_tensor(tensors_contiguous[i]));
2091+
auto dil_input = dbl::comm::try_gen_dil_tensor(tensors_contiguous[i]);
2092+
2093+
// TODO: verify using a simpler way??
2094+
if (i == 0) {
2095+
if (dil_input.has_scale()) {
2096+
IPEX_CHECK(dil_input.get_scale().size() == 1, "only support scale size == 1");
2097+
has_scale = true;
2098+
data_scale = dil_input.get_scale();
2099+
}
2100+
if (dil_input.has_zero_point()) {
2101+
IPEX_CHECK(dil_input.get_zero_point().size() == 1, "only support zero point size == 1");
2102+
has_shift = true;
2103+
data_shift = dil_input.get_zero_point();
2104+
}
2105+
} else {
2106+
IPEX_CHECK(dil_input.has_scale() == has_scale, "tensors to cat should have same scale");
2107+
if (dil_input.has_scale()) {
2108+
IPEX_CHECK(dil_input.get_scale().size() == 1, "only support scale size == 1");
2109+
IPEX_CHECK(dil_input.get_scale()[0] == data_scale[0], "tensors to cat should have same scale");
2110+
}
2111+
IPEX_CHECK(dil_input.has_zero_point() == has_shift, "tensors to cat should have same zero point");
2112+
if (dil_input.has_zero_point()) {
2113+
IPEX_CHECK(dil_input.get_zero_point().size() == 1, "only support zero point size == 1");
2114+
IPEX_CHECK(dil_input.get_zero_point()[0] == data_shift[0], "tensors to cat should have same zero point");
2115+
}
2116+
}
2117+
2118+
x.push_back(dil_input);
20862119
}
20872120
dil::tensor y;
20882121
dil::concat::compute(x, dim, y);
2122+
2123+
// For bidirectional LSTM output cat
2124+
if (has_scale){
2125+
y.set_scale(data_scale);
2126+
}
2127+
2128+
if (has_shift) {
2129+
y.set_zero_point(data_shift);
2130+
}
2131+
20892132
return dbl::comm::gen_aten_tensor_by(std::move(y));
20902133
}
20912134

@@ -2597,10 +2640,12 @@ at::Tensor& AtenIpexCPUDev::dil_copy_(
25972640

25982641
std::vector<at::Tensor> AtenIpexCPUDev::dil_rnn_layer(const at::Tensor& input, const at::Tensor& w1, const at::Tensor& w2,
25992642
const at::Tensor& w3, const at::Tensor& w4, const at::Tensor& hx, const at::Tensor& cx, bool reverse, int64_t mode,
2600-
int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes) {
2643+
int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes,
2644+
const std::vector<float>& scales, const std::vector<int32_t>& shift, bool quantized) {
26012645
DEBUG("AtenIpexCPUDev::dil_rnn_layer\n");
2646+
26022647
return dbl::rnn::mkldnn_rnn_layer(input, w1, w2, w3, w4, hx, cx, reverse, mode,
2603-
hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes);
2648+
hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, scales, shift, quantized);
26042649
}
26052650

26062651
std::vector<at::Tensor> AtenIpexCPUDev::dil_rnn_layer_backward(const at::Tensor& input, const at::Tensor& w1, const at::Tensor& w2,

torch_ipex/csrc/cpu/DevOPs.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class AtenIpexCPUDev {
9494
static at::Tensor dil_shuffle(const at::Tensor & self, at::IntArrayRef view_shape, int64_t dim0, int64_t dim1);
9595
static std::tuple<at::Tensor,at::Tensor> dil__pack_padded_sequence(const at::Tensor & input, const at::Tensor & lengths, bool batch_first);
9696
static at::Tensor& dil_copy_(at::Tensor & self, const at::Tensor & src, bool non_blocking);
97-
static std::vector<at::Tensor> dil_rnn_layer(const at::Tensor& input, const at::Tensor& w1, const at::Tensor& w2, const at::Tensor& w3, const at::Tensor& w4, const at::Tensor& hx, const at::Tensor& cx, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes);
97+
static std::vector<at::Tensor> dil_rnn_layer(const at::Tensor& input, const at::Tensor& w1, const at::Tensor& w2, const at::Tensor& w3, const at::Tensor& w4, const at::Tensor& hx, const at::Tensor& cx, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const std::vector<float>& scales, const std::vector<int32_t>& shift, bool quantized);
9898
static std::vector<at::Tensor> dil_rnn_layer_backward(const at::Tensor& input, const at::Tensor& w1, const at::Tensor& w2, const at::Tensor& w3, const at::Tensor& w4, const at::Tensor& hx, const at::Tensor& cx, const at::Tensor& output, const at::Tensor& hy, const at::Tensor& cy, const at::Tensor& grad_output, const at::Tensor& grad_hy, const at::Tensor& grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes);
9999
static at::Tensor dil_upsample_nearest1d(const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales);
100100
static at::Tensor dil_upsample_nearest1d_backward(const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales);

torch_ipex/csrc/cpu/ExtendOPs.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
#include "CustomOPs.h"
44
#include "DevOPs.h"
55
#include "FusionOPs.h"
6+
#include "dbl/Common.h"
67
#include "aten/aten.hpp"
78
#include "bf16/vec/bf16_vec_kernel.h"
89
#include "dil/dil.hpp"
10+
#include "torch_ipex/csrc/cpu/int8/Config.h"
911
#include "xsmm/libxsmm_utils.h"
1012
#include <ATen/Parallel.h>
1113
#include <ATen/MatrixRef.h>
@@ -465,16 +467,19 @@ std::vector<at::Tensor> rnn_layer(const at::Tensor& input,
465467
at::TensorList weights, const at::Tensor& hx,
466468
const at::Tensor& cx, bool reverse, int64_t mode,
467469
int64_t hidden_size, int64_t num_layers, bool train,
468-
bool bidirectional, at::IntArrayRef batch_sizes) {
470+
bool bidirectional, at::IntArrayRef batch_sizes,
471+
const std::vector<float>& scales,
472+
const std::vector<int32_t>& shift,
473+
bool quantized) {
469474
TORCH_CHECK(weights.size() == 2 || weights.size() == 4);
470475
if (weights.size() == 4) {
471476
if (at::GradMode::is_enabled())
472477
return NewRNNLayerOp::apply(input, weights[0], weights[1], weights[2], weights[3], hx, cx, reverse, mode, hidden_size, num_layers, true, train, bidirectional, batch_sizes);
473-
return NewRNNLayerOp::_forward(input, weights[0], weights[1], weights[2], weights[3], hx, cx, reverse, mode, hidden_size, num_layers, true, train, bidirectional, batch_sizes);
478+
return NewRNNLayerOp::_forward(input, weights[0], weights[1], weights[2], weights[3], hx, cx, reverse, mode, hidden_size, num_layers, true, train, bidirectional, batch_sizes, scales, shift, quantized);
474479
} else {
475480
if (at::GradMode::is_enabled())
476481
return NewRNNLayerOp::apply(input, weights[0], weights[1], at::zeros(weights[0].sizes(), weights[0].options()), at::zeros(weights[1].sizes(), weights[1].options()), hx, cx, reverse, mode, hidden_size, num_layers, false, train, bidirectional, batch_sizes);
477-
return NewRNNLayerOp::_forward(input, weights[0], weights[1], at::zeros(weights[0].sizes(), weights[0].options()), at::zeros(weights[1].sizes(), weights[1].options()), hx, cx, reverse, mode, hidden_size, num_layers, false, train, bidirectional, batch_sizes);
482+
return NewRNNLayerOp::_forward(input, weights[0], weights[1], at::zeros(weights[0].sizes(), weights[0].options()), at::zeros(weights[1].sizes(), weights[1].options()), hx, cx, reverse, mode, hidden_size, num_layers, false, train, bidirectional, batch_sizes, scales, shift, quantized);
478483
}
479484
}
480485
// MKLDNN RNN integration notes:
@@ -514,6 +519,27 @@ std::vector<at::Tensor> rnn(
514519
at::MatrixRef<at::Tensor> weights{weight, static_cast<size_t>(weight_stride0)};
515520

516521
auto num_directions = bidirectional ? 2 : 1;
522+
523+
// no need to do calibration for the output in lstm, will use the scale & zero point of the input
524+
// to dequantize the output from u8 to f32, need to add an "output" here but actually unused
525+
// For LSTM, we only need to calibrate the input to the first layer
526+
// TODO: add int8 for gru and rnn.
527+
if (check_auto_mix_int8_fp32() && check_int8_calibration() && static_cast<dil::rnn_kind>(mode) == dil::rnn_kind::LSTM) {
528+
int64_t num_ops_id = Int8OptConfig::fetch_and_add_ops_id();
529+
insert_or_updata_observer({input}, {input}, "lstm", num_ops_id, /*asymmetric*/true);
530+
}
531+
532+
bool quantized = false;
533+
std::vector<std::vector<float>> scales = {};
534+
std::vector<std::vector<int32_t>> shift = {};
535+
if (check_auto_mix_int8_fp32() && !check_int8_calibration() && static_cast<dil::rnn_kind>(mode) == dil::rnn_kind::LSTM) {
536+
int64_t num_ops_id = Int8OptConfig::fetch_and_add_ops_id();
537+
quantized = torch_ipex::cpu::dbl::comm::get_int8_quantized_status(num_ops_id);
538+
std::tie(scales, shift) = torch_ipex::cpu::dbl::comm::get_int8_asymmetric(num_ops_id);
539+
IPEX_CHECK(scales.size() > 0, "incorrect scale size");
540+
IPEX_CHECK(shift.size() > 0, "incorrect shift size");
541+
}
542+
517543
auto layer_input = input;
518544
std::vector<at::Tensor> layer_output(num_directions);
519545
std::vector<at::Tensor> layer_hy(num_layers * num_directions);
@@ -525,7 +551,7 @@ std::vector<at::Tensor> rnn(
525551
auto layer_hx = hx[index];
526552
auto layer_cx = cx[index];
527553
auto reverse = (direction > 0);
528-
auto outputs = rnn_layer(layer_input, layer_weights, layer_hx, layer_cx, reverse, mode, hidden_size, num_layers, train, bidirectional, batch_sizes);
554+
auto outputs = rnn_layer(layer_input, layer_weights, layer_hx, layer_cx, reverse, mode, hidden_size, num_layers, train, bidirectional, batch_sizes, scales[0], shift[0], quantized);
529555
layer_output[direction] = outputs[0];
530556
layer_hy[index] = outputs[1];
531557
layer_cy[index] = outputs[2];

torch_ipex/csrc/cpu/ShadeDataContext.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,16 @@ struct ShadeDataContext {
227227
return res;
228228
}
229229

230+
static inline bool isTensorMixPrecision(const at::Tensor &tensor, MIX_PREC_TYPE mix_dtype) {
231+
// Check whether tensor is mix_type.
232+
void *raw_context = tensor.storage().data_ptr().get_context();
233+
ShadeDataContext *shade_data_context = (ShadeDataContext*)raw_context;
234+
if (shade_data_context->mix_prec_type == mix_dtype && mix_dtype != MIX_PREC_TYPE::NONE) {
235+
return true;
236+
}
237+
return false;
238+
}
239+
230240
/**
231241
* Check if the input aten tensor is a parameter.
232242
*

0 commit comments

Comments
 (0)