Skip to content

Commit 90b9c49

Browse files
committed
[llvm] Expose type and element count-related APIs on TensorSpec
Added a mechanism to check the element type, get the total element count, and the size of an element. Differential Revision: https://reviews.llvm.org/D85250
1 parent ac70b37 commit 90b9c49

File tree

3 files changed

+37
-2
lines changed

3 files changed

+37
-2
lines changed

llvm/include/llvm/Analysis/Utils/TFUtils.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,18 @@ class TensorSpec final {
6666

6767
bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
6868

69+
/// Get the number of elements in a tensor with this shape.
70+
size_t getElementCount() const { return ElementCount; }
71+
/// Get the size, in bytes, of one element.
72+
size_t getElementByteSize() const;
73+
74+
template <typename T> bool isElementType() const {
75+
return getDataType<T>() == TypeIndex;
76+
}
77+
6978
private:
7079
TensorSpec(const std::string &Name, int Port, int TypeIndex,
71-
const std::vector<int64_t> &Shape)
72-
: Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape) {}
80+
const std::vector<int64_t> &Shape);
7381

7482
template <typename T> static int getDataType() {
7583
llvm_unreachable("Undefined tensor type");
@@ -79,6 +87,7 @@ class TensorSpec final {
7987
int Port = 0;
8088
int TypeIndex = 0;
8189
std::vector<int64_t> Shape;
90+
size_t ElementCount = 0;
8291
};
8392

8493
Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,

llvm/lib/Analysis/TFUtils.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "tensorflow/c/c_api_experimental.h"
2525

2626
#include <cassert>
27+
#include <numeric>
2728

2829
using namespace llvm;
2930

@@ -84,6 +85,16 @@ class EvaluationResultImpl {
8485
std::vector<TF_Tensor *> Output;
8586
};
8687

88+
size_t TensorSpec::getElementByteSize() const {
89+
return TF_DataTypeSize(static_cast<TF_DataType>(TypeIndex));
90+
}
91+
92+
TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex,
93+
const std::vector<int64_t> &Shape)
94+
: Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape),
95+
ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
96+
std::multiplies<int64_t>())) {}
97+
8798
Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
8899
const json::Value &Value) {
89100
auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {

llvm/unittests/Analysis/TFUtilsTest.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,18 @@ TEST(TFUtilsTest, JSONParsingInvalidTensorType) {
123123
auto Spec = getTensorSpecFromJSON(Ctx, *Value);
124124
EXPECT_FALSE(Spec.hasValue());
125125
}
126+
127+
TEST(TFUtilsTest, TensorSpecSizesAndTypes) {
128+
auto Spec1D = TensorSpec::createSpec<int16_t>("Hi1", {1});
129+
auto Spec2D = TensorSpec::createSpec<int16_t>("Hi2", {1, 1});
130+
auto Spec1DLarge = TensorSpec::createSpec<float>("Hi3", {10});
131+
auto Spec3DLarge = TensorSpec::createSpec<float>("Hi3", {2, 4, 10});
132+
EXPECT_TRUE(Spec1D.isElementType<int16_t>());
133+
EXPECT_FALSE(Spec3DLarge.isElementType<double>());
134+
EXPECT_EQ(Spec1D.getElementCount(), 1);
135+
EXPECT_EQ(Spec2D.getElementCount(), 1);
136+
EXPECT_EQ(Spec1DLarge.getElementCount(), 10);
137+
EXPECT_EQ(Spec3DLarge.getElementCount(), 80);
138+
EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float));
139+
EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t));
140+
}

0 commit comments

Comments
 (0)