Skip to content

Commit 6b2bf99

Browse files
mihaimaruseacGeeta Chavan
authored and
Geeta Chavan
committed
Validate that a and b are proper sparse tensors
PiperOrigin-RevId: 373274848 Change-Id: I3a665ac3a29dee9fb69bdf408a939330cb93ea75
1 parent 12a6ead commit 6b2bf99

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ class SparseSparseBinaryOpShared : public OpKernel {
150150

151151
const int64 a_nnz = a_indices_t->dim_size(0);
152152
const int64 b_nnz = b_indices_t->dim_size(0);
153+
153154
const auto a_values = a_values_t->vec<T>();
154155
const auto b_values = b_values_t->vec<T>();
155156

@@ -166,6 +167,14 @@ class SparseSparseBinaryOpShared : public OpKernel {
166167
"Input shapes should be a vector but received shapes ",
167168
a_shape_t->shape().DebugString(), " and ",
168169
b_shape_t->shape().DebugString()));
170+
const int num_dims = a_indices_t->dim_size(1);
171+
OP_REQUIRES(
172+
ctx, a_shape_t->NumElements() == num_dims,
173+
errors::InvalidArgument("Second dimension of a_indices and length of "
174+
"a_shape must match, got ",
175+
num_dims, " and ", a_shape_t->NumElements()));
176+
OP_REQUIRES(ctx, num_dims > 0,
177+
errors::InvalidArgument("Tensors must not be empty"));
169178
OP_REQUIRES(ctx, a_shape_t->IsSameSize(*b_shape_t),
170179
errors::InvalidArgument(
171180
"Operands do not have the same ranks; got shapes: ",
@@ -180,12 +189,6 @@ class SparseSparseBinaryOpShared : public OpKernel {
180189
" for dimension ", i));
181190
}
182191

183-
OP_REQUIRES(
184-
ctx, a_indices_t->dim_size(1) == b_indices_t->dim_size(1),
185-
errors::InvalidArgument(
186-
"Indices' dimensions do not match: got ", a_indices_t->dim_size(1),
187-
" and ", b_indices_t->dim_size(1), " for the second dimension."));
188-
const int num_dims = a_indices_t->dim_size(1);
189192
const auto a_indices_mat = a_indices_t->matrix<int64>();
190193
const auto b_indices_mat = b_indices_t->matrix<int64>();
191194
std::vector<T> a_augmented_values, b_augmented_values;

0 commit comments

Comments
 (0)