[mlir][Standard] Add canonicalizer for dynamic_tensor_from_elements
This adds canonicalizers for:
- extracting an element from a dynamic_tensor_from_elements
- propagating constant operands to the type of dynamic_tensor_from_elements

Differential Revision: https://reviews.llvm.org/D87525
This commit is contained in:
parent c0809f8d79
commit c897a7fb3e
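As a quick illustration of the first of these rewrites, here is a minimal sketch of the extract_element folding (function and value names below are made up for illustration; the real test cases are in the canonicalize test hunk at the end of this diff):

func @fold_extract_example(%t: tensor<*xf32>, %idx: index) -> index {
  %rank = rank %t : tensor<*xf32>
  // Tensor whose elements are the dimension sizes of %t.
  %0 = dynamic_tensor_from_elements %rank {
  ^bb0(%i: index):
    %d = dim %t, %i : tensor<*xf32>
    yield %d : index
  } : tensor<?xindex>
  %r = extract_element %0[%idx] : tensor<?xindex>
  return %r : index
}

Assuming the dynamic_tensor_from_elements has no side effects, canonicalization replaces the extract_element with the inlined body, roughly %d = dim %t, %idx followed by return %d, and the tensor construction itself becomes dead.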
@@ -1511,6 +1511,8 @@ def DynamicTensorFromElementsOp : Std_Op<"dynamic_tensor_from_elements",
               "ValueRange dynamicExtents, "
               "function_ref<void(OpBuilder &, Location, ValueRange)>">,
   ];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Matchers.h"
@@ -1730,6 +1731,101 @@ void DynamicTensorFromElementsOp::build(
   bodyBuilder(b, result.location, bodyBlock->getArguments());
 }
 
+namespace {
+
+/// Canonicalizes dynamic_tensor_from_elements operations with a constant
+/// operand into the equivalent operation with the operand expressed in the
+/// result type, instead. We also insert a type cast to make sure that the
+/// resulting IR is still well-typed.
+struct StaticDynamicTensorFromElements
+    : public OpRewritePattern<DynamicTensorFromElementsOp> {
+  using OpRewritePattern<DynamicTensorFromElementsOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DynamicTensorFromElementsOp tensorFromElements,
+                                PatternRewriter &rewriter) const final {
+    auto resultType =
+        tensorFromElements.getResult().getType().cast<RankedTensorType>();
+
+    if (resultType.hasStaticShape())
+      return failure();
+
+    SmallVector<Value, 4> newOperands;
+    SmallVector<int64_t, 4> newShape;
+    auto operandsIt = tensorFromElements.dynamicExtents().begin();
+
+    for (int64_t dim : resultType.getShape()) {
+      if (dim != RankedTensorType::kDynamicSize) {
+        newShape.push_back(dim);
+        continue;
+      }
+      APInt index;
+      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
+        newShape.push_back(RankedTensorType::kDynamicSize);
+        newOperands.push_back(*operandsIt++);
+        continue;
+      }
+      newShape.push_back(index.getSExtValue());
+      operandsIt++;
+    }
+
+    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
+      return failure();
+
+    auto loc = tensorFromElements.getLoc();
+    auto newOp = rewriter.create<DynamicTensorFromElementsOp>(
+        loc, RankedTensorType::get(newShape, resultType.getElementType()),
+        newOperands);
+    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
+                                newOp.body().begin());
+    rewriter.replaceOpWithNewOp<TensorCastOp>(tensorFromElements, resultType,
+                                              newOp);
+    return success();
+  }
+};
+
+/// Canonicalizes the pattern of the form
+///
+/// %tensor = dynamic_tensor_from_elements %x {
+///   ^bb0(%arg0: index):  // no predecessors
+///   <computation>
+///   yield %1 : index
+/// } : tensor<?xindex>
+/// %extracted_element = extract_element %tensor[%c0] : tensor<?xi32>
+///
+/// to just <computation> with %arg0 replaced by %c0. We only do this if the
+/// dynamic_tensor_from_elements operation has no side-effects.
+struct ExtractElementFromDynamicTensorFromElements
+    : public OpRewritePattern<ExtractElementOp> {
+  using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractElementOp extract,
+                                PatternRewriter &rewriter) const final {
+    auto tensorFromElements =
+        extract.aggregate().getDefiningOp<DynamicTensorFromElementsOp>();
+    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
+      return failure();
+
+    BlockAndValueMapping mapping;
+    Block *body = tensorFromElements.getBody();
+    mapping.map(body->getArguments(), extract.indices());
+    for (auto &op : body->without_terminator())
+      rewriter.clone(op, mapping);
+
+    auto yield = cast<YieldOp>(body->getTerminator());
+
+    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
+    return success();
+  }
+};
+
+} // namespace
+
+void DynamicTensorFromElementsOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ExtractElementFromDynamicTensorFromElements,
+                 StaticDynamicTensorFromElements>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
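To make the StaticDynamicTensorFromElements rewrite above concrete, here is a rough before/after sketch (names are illustrative; the @static_dynamic_tensor_from_elements test in the last hunk checks the same behavior). A constant extent operand is folded into the result type, and a tensor_cast restores the original type for existing users:

func @static_extent_example(%size: index) -> tensor<?x?xindex> {
  %c5 = constant 5 : index
  // The second extent is a constant, so it can move into the result type.
  %0 = dynamic_tensor_from_elements %size, %c5 {
  ^bb0(%i: index, %j: index):
    %v = constant 32 : index
    yield %v : index
  } : tensor<?x?xindex>
  return %0 : tensor<?x?xindex>
}

After canonicalization the body is kept as-is, but the op produces tensor<?x5xindex> with only %size as a dynamic extent, and a tensor_cast back to tensor<?x?xindex> is inserted for the return.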
@@ -1807,16 +1903,16 @@ struct ExtractElementFromTensorFromElements
     if (extract.indices().size() != 1)
       return failure();
 
-    auto tensor_from_elements = dyn_cast_or_null<TensorFromElementsOp>(
+    auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>(
         extract.aggregate().getDefiningOp());
-    if (tensor_from_elements == nullptr)
+    if (tensorFromElements == nullptr)
       return failure();
 
     APInt index;
     if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
       return failure();
     rewriter.replaceOp(extract,
-                       tensor_from_elements.getOperand(index.getZExtValue()));
+                       tensorFromElements.getOperand(index.getZExtValue()));
     return success();
   }
 };
@@ -986,3 +986,79 @@ func @extract_element_from_tensor_from_elements(%element : index) -> index {
   // CHECK: [[ARG]] : index
   return %extracted_element : index
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements
+// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
+func @extract_element_from_dynamic_tensor_from_elements(%idx: index, %tensor: tensor<*xf32>) -> index {
+  %size = rank %tensor : tensor<*xf32>
+  // CHECK-NEXT: %[[RES:.*]] = dim %[[TENSOR]], %[[IDX]]
+  %0 = dynamic_tensor_from_elements %size {
+  ^bb0(%arg0: index):
+    %1 = dim %tensor, %arg0 : tensor<*xf32>
+    yield %1 : index
+  } : tensor<?xindex>
+  %1 = extract_element %0[%idx] : tensor<?xindex>
+  // CHECK-NEXT: return %[[RES]]
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements_2d
+// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
+func @extract_element_from_dynamic_tensor_from_elements_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
+  %size = rank %tensor : tensor<*xf32>
+  // CHECK-NEXT: %[[DIM0:.*]] = dim %[[TENSOR]], %[[IDX0]]
+  // CHECK-NEXT: %[[DIM1:.*]] = dim %[[TENSOR]], %[[IDX1]]
+  // CHECK-NEXT: %[[RES:.*]] = addi %[[DIM0]], %[[DIM1]]
+  %0 = dynamic_tensor_from_elements %size, %size {
+  ^bb0(%arg0: index, %arg1: index):
+    %1 = dim %tensor, %arg0 : tensor<*xf32>
+    %2 = dim %tensor, %arg1 : tensor<*xf32>
+    %3 = addi %1, %2 : index
+    yield %3 : index
+  } : tensor<?x?xindex>
+  %4 = extract_element %0[%idx0, %idx1] : tensor<?x?xindex>
+  // CHECK-NEXT: return %[[RES]]
+  return %4 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements_sideeffects
+// CHECK-SAME: %[[IDX:.*]]: index
+func @extract_element_from_dynamic_tensor_from_elements_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index {
+  %size = rank %tensor : tensor<*xf32>
+  %mem = alloc(%size) : memref<?xindex>
+  // CHECK: %[[DTENSOR:.*]] = dynamic_tensor_from_elements
+  %0 = dynamic_tensor_from_elements %size {
+  ^bb0(%arg0: index):
+    %1 = dim %tensor, %arg0 : tensor<*xf32>
+    store %1, %mem[%arg0] : memref<?xindex>
+    yield %1 : index
+  } : tensor<?xindex>
+  // CHECK: %[[RES:.*]] = extract_element %[[DTENSOR]][%[[IDX]]]
+  %1 = extract_element %0[%idx] : tensor<?xindex>
+  // CHECK-NEXT: return %[[RES]]
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @static_dynamic_tensor_from_elements
+// CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index)
+func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> {
+  %c5 = constant 5 : index
+  // CHECK: dynamic_tensor_from_elements %[[SIZE1]], %[[SIZE4]]
+  %0 = dynamic_tensor_from_elements %size1, %c5, %size4 {
+  ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    %1 = constant 32 : index
+    yield %1 : index
+  // CHECK: : tensor<3x?x5x7x?xindex>
+  } : tensor<3x?x?x7x?xindex>
+  // CHECK: tensor_cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
+  return %0 : tensor<3x?x?x7x?xindex>
+}