@@ -399,55 +399,14 @@ LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
399
399
return failure ();
400
400
}
401
401
}
402
-
403
- // For shapes that were created by some operations, we can obtain partial
404
- // information on the shapes and sometimes determine if they will be
405
- // broadcastable with that.
406
- struct CstrBroadcastablePartialInfo
407
- : public OpRewritePattern<CstrBroadcastableOp> {
408
- using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
409
-
410
- LogicalResult matchAndRewrite (CstrBroadcastableOp op,
411
- PatternRewriter &rewriter) const override {
412
- SmallVector<int64_t , 6 > lhsShape, rhsShape;
413
- if (failed (getShapeVec (op.lhs (), lhsShape)))
414
- return failure ();
415
- if (failed (getShapeVec (op.rhs (), rhsShape)))
416
- return failure ();
417
- if (!OpTrait::util::staticallyKnownBroadcastable (lhsShape, rhsShape))
418
- return failure ();
419
-
420
- rewriter.replaceOpWithNewOp <ConstWitnessOp>(op.getOperation (), true );
421
- return success ();
422
- }
423
- };
424
-
425
- // Scalars are always broadcastable.
426
- struct CstrBroadcastableScalar : public OpRewritePattern <CstrBroadcastableOp> {
427
- using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
428
-
429
- LogicalResult matchAndRewrite (CstrBroadcastableOp op,
430
- PatternRewriter &rewriter) const override {
431
- SmallVector<int64_t , 6 > shape;
432
- if (failed (getShapeVec (op.lhs (), shape)) || shape.size () > 0 )
433
- return failure ();
434
- if (failed (getShapeVec (op.rhs (), shape)) || shape.size () > 0 )
435
- return failure ();
436
-
437
- rewriter.replaceOpWithNewOp <ConstWitnessOp>(op.getOperation (), true );
438
- return success ();
439
- }
440
- };
441
-
442
402
} // namespace
443
403
444
404
void CstrBroadcastableOp::getCanonicalizationPatterns (
445
405
OwningRewritePatternList &patterns, MLIRContext *context) {
446
406
// Canonicalization patterns have overlap with the considerations during
447
407
// folding in case additional shape information is inferred at some point that
448
408
// does not result in folding.
449
- patterns.insert <CstrBroadcastableEqOps, CstrBroadcastablePartialInfo,
450
- CstrBroadcastableScalar>(context);
409
+ patterns.insert <CstrBroadcastableEqOps>(context);
451
410
}
452
411
453
412
OpFoldResult CstrBroadcastableOp::fold (ArrayRef<Attribute> operands) {
0 commit comments