Skip to content

Commit 4ec1990

Browse files
Jerry-Ge and Tai78641 authored
[mlir][tosa] Remove FullyConnectedOp from TOSA Dialect (#126152)
This patch removes FullyConncected Operator from the TOSA Dialect and all associated tests and transforms. This is part of the TOSA v1.0 alignment effort: https://discourse.llvm.org/t/rfc-tosa-dialect-increment-to-v1-0/83708 Signed-off-by: Tai Ly <tai.ly@arm.com> Co-authored-by: Tai Ly <tai.ly@arm.com>
1 parent ac217ee commit 4ec1990

File tree

17 files changed

+4
-651
lines changed

17 files changed

+4
-651
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,6 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
150150
outputShape, acc_type);
151151
}]>;
152152

153-
// The tosa.fully_connected op has its own builder as it does not have
154-
// strides/dilation/padding.
155-
def Tosa_FCOpQuantInfoBuilder : OpBuilder<
156-
(ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias),
157-
[{
158-
buildFCOpWithQuantInfo($_builder, $_state, outputType,
159-
input, weight, bias);
160-
}]>;
161-
162153
// The tosa.matmul op is also intended to be generated where a fully_connected
163154
// op must be constructed where the weight is not a constant. In this case,
164155
// the fully_connected op must be expressed using matmul.

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -224,32 +224,6 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
224224
}];
225225
}
226226

227-
//===----------------------------------------------------------------------===//
228-
// Operator: fully_connected
229-
//===----------------------------------------------------------------------===//
230-
def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
231-
let summary = "Fully Connected operator";
232-
233-
let description = [{
234-
Performs a fully connected network.
235-
}];
236-
237-
let arguments = (ins
238-
Tosa_Tensor2D:$input,
239-
TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
240-
Tosa_Tensor1D:$bias,
241-
OptionalAttr<I32Attr>:$input_zp,
242-
OptionalAttr<I32Attr>:$weight_zp
243-
);
244-
245-
let results = (outs
246-
Tosa_Tensor2D:$output
247-
);
248-
249-
let builders = [Tosa_FCOpQuantInfoBuilder];
250-
let hasVerifier = 1;
251-
}
252-
253227
//===----------------------------------------------------------------------===//
254228
// Operator: matmul
255229
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
8181
"number">;
8282

8383
// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
84-
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
84+
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
8585
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
8686
Tosa_QuantizedInt, AnyFloat]>;
8787

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ namespace tosa {
2626

2727
// Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
2828
// The rewrites can be selectively added to a conversion pass.
29-
void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
3029
void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
3130
RewritePatternSet &patterns);
3231
void populateTosaDecomposeDepthwise(MLIRContext *ctx,

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 0 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -607,84 +607,6 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
607607
}
608608
};
609609

610-
class FullyConnectedConverter
611-
: public OpConversionPattern<tosa::FullyConnectedOp> {
612-
public:
613-
using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
614-
LogicalResult
615-
matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
616-
ConversionPatternRewriter &rewriter) const final {
617-
Location loc = op.getLoc();
618-
auto outputTy = cast<ShapedType>(op.getType());
619-
auto input = op.getInput();
620-
auto inputTy = cast<ShapedType>(input.getType());
621-
622-
auto bias = op.getBias();
623-
624-
auto weight = op.getWeight();
625-
auto weightTy = cast<ShapedType>(weight.getType());
626-
auto weightShape = weightTy.getShape();
627-
628-
auto outputETy = outputTy.getElementType();
629-
630-
SmallVector<Value> dynDims;
631-
dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
632-
633-
if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
634-
dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
635-
}
636-
637-
if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
638-
dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
639-
}
640-
641-
SmallVector<Value> filteredDims = condenseValues(dynDims);
642-
643-
SmallVector<int64_t> permutation = {1, 0};
644-
auto permutationAttr = rewriter.getI64TensorAttr(permutation);
645-
Value permutationValue =
646-
rewriter.create<arith::ConstantOp>(loc, permutationAttr);
647-
648-
SmallVector<int64_t> newWeightShape = {weightShape[1], weightShape[0]};
649-
Type newWeightTy =
650-
RankedTensorType::get(newWeightShape, weightTy.getElementType());
651-
652-
Value transposedWeight = rewriter.create<tosa::TransposeOp>(
653-
loc, newWeightTy, weight, permutationValue);
654-
655-
Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
656-
loc, outputTy.getShape(), outputETy, filteredDims);
657-
658-
Value broadcastBias =
659-
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
660-
661-
if (!op.getInputZp() && !op.getWeightZp()) {
662-
Value matmul = rewriter
663-
.create<linalg::MatmulOp>(
664-
loc, TypeRange{op.getType()},
665-
ValueRange{input, transposedWeight}, broadcastBias)
666-
->getResult(0);
667-
668-
rewriter.replaceOp(op, matmul);
669-
return success();
670-
}
671-
672-
auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
673-
auto outputZp =
674-
rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
675-
Value matmul =
676-
rewriter
677-
.create<linalg::QuantizedMatmulOp>(
678-
loc, TypeRange{op.getType()},
679-
ValueRange{input, transposedWeight, inputZp, outputZp},
680-
broadcastBias)
681-
->getResult(0);
682-
683-
rewriter.replaceOp(op, matmul);
684-
return success();
685-
}
686-
};
687-
688610
class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
689611
public:
690612
using OpConversionPattern::OpConversionPattern;
@@ -1090,7 +1012,6 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
10901012
DepthwiseConvConverter,
10911013
MatMulConverter,
10921014
AvgPool2dConverter,
1093-
FullyConnectedConverter,
10941015
TransposeConverter
10951016
>(patterns->getContext());
10961017

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ struct TosaToLinalgNamed
6262
target.addIllegalOp<tosa::MaxPool2dOp>();
6363
target.addIllegalOp<tosa::AvgPool2dOp>();
6464
target.addIllegalOp<tosa::MatMulOp>();
65-
target.addIllegalOp<tosa::FullyConnectedOp>();
6665
target.addIllegalOp<tosa::TransposeOp>();
6766

6867
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 3 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -566,26 +566,9 @@ static void buildTransConvOpWithQuantInfo(
566566
result.addTypes(finalOutputType);
567567
}
568568

569-
/// The tosa.fully_connected op has its own builder as it does not have
570-
/// strides/dilation/padding.
571-
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
572-
Type outputType, Value input, Value weight,
573-
Value bias) {
574-
575-
result.addOperands({input, weight, bias});
576-
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
577-
if (quantAttr) {
578-
result.addAttribute("quantization_info", quantAttr);
579-
result.addTypes(
580-
buildConvOpResultTypeInfo(builder, outputType, input, weight));
581-
} else {
582-
result.addTypes(outputType);
583-
}
584-
}
585-
586-
/// The tosa.matmul op is also intended to be generated where a
587-
/// fully_connected op must be constructed where the weight is not a constant.
588-
/// In this case, the fully_connected op must be expressed using matmul.
569+
/// The tosa.matmul op is also intended to be generated where a fully_connected
570+
/// op must be constructed where the weight is not a constant. In this case,
571+
/// the fully_connected op must be expressed using matmul.
589572
/// TODO: Add link to the leglization document explaining this.
590573
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
591574
OperationState &result, Type outputType,
@@ -889,76 +872,6 @@ bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
889872
return succeeded(verifyCompatibleShape(l[0], r[0]));
890873
}
891874

892-
LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
893-
MLIRContext *context, ::std::optional<Location> location,
894-
FullyConnectedOp::Adaptor adaptor,
895-
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
896-
ShapeAdaptor inputShape(adaptor.getInput().getType());
897-
ShapeAdaptor weightShape(adaptor.getWeight().getType());
898-
ShapeAdaptor biasShape(adaptor.getBias().getType());
899-
900-
// All shapes are dynamic.
901-
SmallVector<int64_t> outShape;
902-
outShape.resize(2, ShapedType::kDynamic);
903-
904-
if (inputShape.hasRank()) {
905-
outShape[0] = inputShape.getDimSize(0);
906-
}
907-
908-
if (weightShape.hasRank()) {
909-
outShape[1] = weightShape.getDimSize(0);
910-
}
911-
912-
if (biasShape.hasRank()) {
913-
outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
914-
: outShape[1];
915-
}
916-
917-
inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
918-
return success();
919-
}
920-
921-
LogicalResult FullyConnectedOp::verify() {
922-
// All TOSA conv ops have an input() and weight().
923-
auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
924-
925-
RankedTensorType weightType =
926-
llvm::dyn_cast<RankedTensorType>(getWeight().getType());
927-
928-
// Must be ranked tensor types
929-
if (!inputType) {
930-
emitOpError("expect a ranked tensor for input, got ") << getInput();
931-
return failure();
932-
}
933-
if (!weightType) {
934-
emitOpError("expect a ranked tensor for weight, got ") << getWeight();
935-
return failure();
936-
}
937-
938-
auto inputEType = inputType.getElementType();
939-
auto weightEType = weightType.getElementType();
940-
941-
bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
942-
bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
943-
944-
// Either both must be quantized or both unquantized.
945-
if (inputIsQuant != weightIsQuant) {
946-
emitOpError(
947-
"expect both input and weight to be float or not together, got ")
948-
<< inputEType << " and " << weightEType;
949-
return failure();
950-
}
951-
952-
// Quantized type must have constructed the quantizationattr, and unquantized
953-
// types should not have a quantizationattr.
954-
if ((inputIsQuant && !getInputZp()) || (!inputIsQuant && getInputZp())) {
955-
emitOpError("input zero point is required for quantized type, and not "
956-
"allowed for float type");
957-
return failure();
958-
}
959-
return success();
960-
}
961-
962875
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
963876
MLIRContext *context, ::std::optional<Location> location,
964877
MatMulOp::Adaptor adaptor,

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
add_mlir_dialect_library(MLIRTosaTransforms
22
TosaDecomposeTransposeConv.cpp
3-
TosaDecomposeConv2D.cpp
43
TosaDecomposeDepthwise.cpp
54
TosaFolders.cpp
65
TosaInferShapes.cpp

0 commit comments

Comments
 (0)