[mlir][Standard] Add canonicalizer for dynamic_tensor_from_elements

This add canonicalizer for

- extracting an element from a dynamic_tensor_from_elements
- propagating constant operands to the type of dynamic_tensor_from_elements

Differential Revision: https://reviews.llvm.org/D87525
This commit is contained in:
Stephan Herhut 2020-09-14 11:54:55 +02:00
parent c0809f8d79
commit c897a7fb3e
3 changed files with 177 additions and 3 deletions

View File

@ -1511,6 +1511,8 @@ def DynamicTensorFromElementsOp : Std_Op<"dynamic_tensor_from_elements",
"ValueRange dynamicExtents, "
"function_ref<void(OpBuilder &, Location, ValueRange)>">,
];
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -11,6 +11,7 @@
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
@ -1730,6 +1731,101 @@ void DynamicTensorFromElementsOp::build(
bodyBuilder(b, result.location, bodyBlock->getArguments());
}
namespace {
/// Canonicalizes dynamic_tensor_from_elements operations with a constant
/// operand into the equivalent operation with the operand expressed in the
/// result type, instead. We also insert a type cast to make sure that the
/// resulting IR is still well-typed.
struct StaticDynamicTensorFromElements
: public OpRewritePattern<DynamicTensorFromElementsOp> {
using OpRewritePattern<DynamicTensorFromElementsOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicTensorFromElementsOp tensorFromElements,
PatternRewriter &rewriter) const final {
auto resultType =
tensorFromElements.getResult().getType().cast<RankedTensorType>();
if (resultType.hasStaticShape())
return failure();
SmallVector<Value, 4> newOperands;
SmallVector<int64_t, 4> newShape;
auto operandsIt = tensorFromElements.dynamicExtents().begin();
for (int64_t dim : resultType.getShape()) {
if (dim != RankedTensorType::kDynamicSize) {
newShape.push_back(dim);
continue;
}
APInt index;
if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
newShape.push_back(RankedTensorType::kDynamicSize);
newOperands.push_back(*operandsIt++);
continue;
}
newShape.push_back(index.getSExtValue());
operandsIt++;
}
if (newOperands.size() == tensorFromElements.dynamicExtents().size())
return failure();
auto loc = tensorFromElements.getLoc();
auto newOp = rewriter.create<DynamicTensorFromElementsOp>(
loc, RankedTensorType::get(newShape, resultType.getElementType()),
newOperands);
rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
newOp.body().begin());
rewriter.replaceOpWithNewOp<TensorCastOp>(tensorFromElements, resultType,
newOp);
return success();
}
};
/// Canonicalizes the pattern of the form
///
/// %tensor = dynamic_tensor_from_elements %x {
/// ^bb0(%arg0: index): // no predecessors
/// <computation>
/// yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = extract_element %tensor[%c0] : tensor<?xi32>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// dynamic_tensor_from_elements operation has no side-effects.
struct ExtractElementFromDynamicTensorFromElements
: public OpRewritePattern<ExtractElementOp> {
using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractElementOp extract,
PatternRewriter &rewriter) const final {
auto tensorFromElements =
extract.aggregate().getDefiningOp<DynamicTensorFromElementsOp>();
if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
return failure();
BlockAndValueMapping mapping;
Block *body = tensorFromElements.getBody();
mapping.map(body->getArguments(), extract.indices());
for (auto &op : body->without_terminator())
rewriter.clone(op, mapping);
auto yield = cast<YieldOp>(body->getTerminator());
rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
return success();
}
};
} // namespace
void DynamicTensorFromElementsOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ExtractElementFromDynamicTensorFromElements,
StaticDynamicTensorFromElements>(context);
}
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//
@ -1807,16 +1903,16 @@ struct ExtractElementFromTensorFromElements
if (extract.indices().size() != 1)
return failure();
auto tensor_from_elements = dyn_cast_or_null<TensorFromElementsOp>(
auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>(
extract.aggregate().getDefiningOp());
if (tensor_from_elements == nullptr)
if (tensorFromElements == nullptr)
return failure();
APInt index;
if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
return failure();
rewriter.replaceOp(extract,
tensor_from_elements.getOperand(index.getZExtValue()));
tensorFromElements.getOperand(index.getZExtValue()));
return success();
}
};

View File

@ -986,3 +986,79 @@ func @extract_element_from_tensor_from_elements(%element : index) -> index {
// CHECK: [[ARG]] : index
return %extracted_element : index
}
// -----
// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements
// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
func @extract_element_from_dynamic_tensor_from_elements(%idx: index, %tensor: tensor<*xf32>) -> index {
%size = rank %tensor : tensor<*xf32>
// CHECK-NEXT: %[[RES:.*]] = dim %[[TENSOR]], %[[IDX]]
%0 = dynamic_tensor_from_elements %size {
^bb0(%arg0: index):
%1 = dim %tensor, %arg0 : tensor<*xf32>
yield %1 : index
} : tensor<?xindex>
%1 = extract_element %0[%idx] : tensor<?xindex>
// CHECK-NEXT: return %[[RES]]
return %1 : index
}
// -----
// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements_2d
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
func @extract_element_from_dynamic_tensor_from_elements_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
%size = rank %tensor : tensor<*xf32>
// CHECK-NEXT: %[[DIM0:.*]] = dim %[[TENSOR]], %[[IDX0]]
// CHECK-NEXT: %[[DIM1:.*]] = dim %[[TENSOR]], %[[IDX1]]
// CHECK-NEXT: %[[RES:.*]] = addi %[[DIM0]], %[[DIM1]]
%0 = dynamic_tensor_from_elements %size, %size {
^bb0(%arg0: index, %arg1: index):
%1 = dim %tensor, %arg0 : tensor<*xf32>
%2 = dim %tensor, %arg1 : tensor<*xf32>
%3 = addi %1, %2 : index
yield %3 : index
} : tensor<?x?xindex>
%4 = extract_element %0[%idx0, %idx1] : tensor<?x?xindex>
// CHECK-NEXT: return %[[RES]]
return %4 : index
}
// -----
// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements_sideeffects
// CHECK-SAME: %[[IDX:.*]]: index
func @extract_element_from_dynamic_tensor_from_elements_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index {
%size = rank %tensor : tensor<*xf32>
%mem = alloc(%size) : memref<?xindex>
// CHECK: %[[DTENSOR:.*]] = dynamic_tensor_from_elements
%0 = dynamic_tensor_from_elements %size {
^bb0(%arg0: index):
%1 = dim %tensor, %arg0 : tensor<*xf32>
store %1, %mem[%arg0] : memref<?xindex>
yield %1 : index
} : tensor<?xindex>
// CHECK: %[[RES:.*]] = extract_element %[[DTENSOR]][%[[IDX]]]
%1 = extract_element %0[%idx] : tensor<?xindex>
// CHECK-NEXT: return %[[RES]]
return %1 : index
}
// -----
// CHECK-LABEL: @static_dynamic_tensor_from_elements
// CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index)
func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> {
%c5 = constant 5 : index
// CHECK: dynamic_tensor_from_elements %[[SIZE1]], %[[SIZE4]]
%0 = dynamic_tensor_from_elements %size1, %c5, %size4 {
^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
%1 = constant 32 : index
yield %1 : index
// CHECK: : tensor<3x?x5x7x?xindex>
} : tensor<3x?x?x7x?xindex>
// CHECK: tensor_cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
return %0 : tensor<3x?x?x7x?xindex>
}