[MLIR][LinAlg] Start detensoring implementation.

This commit is the first baby step towards detensoring in
linalg-on-tensors.

Detensoring is the process through which a tensor value is convereted to one
or potentially more primitive value(s). During this process, operations with
such detensored operands are also converted to an equivalen form that works
on primitives.

The detensoring process is driven by linalg-on-tensor ops. In particular, a
linalg-on-tensor op is checked to see whether *all* its operands can be
detensored. If so, those operands are converted to thier primitive
counterparts and the linalg op is replaced by an equivalent op that takes
those new primitive values as operands.

This works towards handling github/google/iree#1159.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96271
This commit is contained in:
KareemErgawy-TomTom 2021-02-16 07:42:41 +01:00
parent c61e511f38
commit 67e0d58de4
5 changed files with 309 additions and 0 deletions

View File

@ -59,6 +59,10 @@ void populateElementwiseToLinalgConversionPatterns(
/// operations.
std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
/// Create a pass to convert Linalg operations to equivalent operations that
/// work on primitive types, if possible.
std::unique_ptr<Pass> createLinalgDetensorizePass();
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.

View File

@ -136,4 +136,28 @@ def LinalgGeneralization : FunctionPass<"linalg-generalize-named-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
def LinalgDetensorize : FunctionPass<"linalg-detensorize"> {
let summary = "Detensorize linalg ops";
let constructor = "mlir::createLinalgDetensorizePass()";
let dependentDialects = [];
let description = [{
Detensoring is the process through which a tensor value is convereted to one
or potentially more primitive value(s). During this process, operations with
such detensored operands are also converted to an equivalent form that works
on primitives.
The detensoring process is driven by linalg-on-tensor ops. In particular, a
linalg-on-tensor op is checked to see whether *all* its operands can be
detensored. If so, those operands are converted to their primitive
counterparts and the linalg op is replaced by an equivalent op that takes
those new primitive values as operands. Therefore, the detensoring process
can be divided into 2 main logical phases:
1. Detect/match an op that can be detensored.
2. Detensor the operands of the op and replace it with a primitive
equivalent.
}];
}
#endif // MLIR_DIALECT_LINALG_PASSES

View File

@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRLinalgTransforms
Bufferize.cpp
CodegenStrategy.cpp
Detensorize.cpp
DropUnitDims.cpp
ElementwiseToLinalg.cpp
Fusion.cpp

View File

@ -0,0 +1,173 @@
//===- Detensorize.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <iterator>
#include <memory>
using namespace mlir;
using namespace mlir::linalg;
namespace {
/// Defines the criteria a TensorType must follow in order to be considered
/// "detensorable".
///
/// NOTE: For now, only 0-D are supported.
///
/// Returns true if tensorType can be detensored.
bool canBeDetensored(TensorType tensorType) {
return tensorType.hasRank() && tensorType.getRank() == 0;
}
/// A conversion patttern for detensoring `linalg.generic` ops.
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(GenericOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Block *originalBlock = op->getBlock();
// Gather some information about the op before inling its region.
Block *opEntryBlock = &*op.region().begin();
YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());
// Split the op's region before the op. This way, we have a clear insertion
// point in which the op can be inlined.
Block *newBlock = originalBlock->splitBlock(op);
rewriter.inlineRegionBefore(op.region(), newBlock);
// Now that op's region is inlined, the operands of its YieldOp are mapped
// to the materialized target values. Therefore, we can replace the op's
// uses with those of its YielOp's operands.
rewriter.replaceOp(op, yieldOp->getOperands());
// No need for these intermediate blocks, merge them into 1.
rewriter.mergeBlocks(opEntryBlock, originalBlock, operands);
rewriter.mergeBlocks(newBlock, originalBlock, {});
rewriter.eraseOp(&*Block::iterator(yieldOp));
return success();
}
};
class DetensorizeTypeConverter : public TypeConverter {
public:
DetensorizeTypeConverter() {
addConversion([](Type type) { return type; });
// A TensorType that can be detensored, is converted to the underlying
// element type.
addConversion([](TensorType tensorType) -> Type {
if (canBeDetensored(tensorType))
return tensorType.getElementType();
return tensorType;
});
// A tensor value is detensoried by extracting its element(s).
addTargetMaterialization([](OpBuilder &builder, Type type,
ValueRange inputs, Location loc) -> Value {
return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
});
// A detensored value is converted back by creating a new tensor from its
// element(s).
addSourceMaterialization([](OpBuilder &builder, Type type,
ValueRange inputs, Location loc) -> Value {
auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
loc, inputs[0].getType(), inputs[0]);
// FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
// a tensor<dtype> instead.
return builder.create<linalg::TensorReshapeOp>(
loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
});
}
};
/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
/// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into
/// tensor<i32>
/// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
///
/// to just %element.
struct ExtractFromReshapeFromElements
: public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
PatternRewriter &rewriter) const final {
if (extract.indices().size() != 0)
return failure();
auto tensorReshape = extract.tensor().getDefiningOp<TensorReshapeOp>();
if (tensorReshape == nullptr)
return failure();
auto tensorFromElements =
tensorReshape.getOperand()
.getDefiningOp<mlir::tensor::FromElementsOp>();
if (tensorFromElements == nullptr)
return failure();
rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
return success();
}
};
/// @see LinalgDetensorize in Linalg/Passes.td for more details.
struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
void runOnFunction() override {
auto *context = &getContext();
DetensorizeTypeConverter typeConverter;
OwningRewritePatternList patterns;
ConversionTarget target(*context);
target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
target.addLegalDialect<linalg::LinalgDialect>();
target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
// If any of the operands or results cannot be detensored, the op is
// considered legal and won't be detensored.
return llvm::any_of(
op.getShapedOperandTypes(), [](ShapedType shapedType) {
assert(shapedType.isa<TensorType>());
return !canBeDetensored(shapedType.cast<TensorType>());
});
});
patterns.insert<DetensorizeGenericOp>(typeConverter, context);
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
OwningRewritePatternList canonPatterns;
canonPatterns.insert<ExtractFromReshapeFromElements>(context);
if (failed(applyPatternsAndFoldGreedily(getFunction(),
std::move(canonPatterns))))
signalPassFailure();
// TODO Properly handle control flow within function boundaries.
}
};
} // namespace
std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
return std::make_unique<LinalgDetensorize>();
}

View File

@ -0,0 +1,107 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
#map = affine_map<() -> ()>
func @detensor_simple(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
%0 = linalg.init_tensor [] : tensor<f32>
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
outs(%0 : tensor<f32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%2 = addf %arg3, %arg4 : f32
linalg.yield %2 : f32
} -> tensor<f32>
return %1: tensor<f32>
}
// CHECK-LABEL: func @detensor_simple
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
// CHECK: return %[[reshaped_tensor_res]]
func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
%0 = linalg.init_tensor [] : tensor<f32>
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
outs(%0 : tensor<f32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%2 = addf %arg3, %arg4 : f32
linalg.yield %2 : f32
} -> tensor<f32>
%3 = linalg.init_tensor [] : tensor<f32>
%4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
ins(%arg1, %1 : tensor<f32>, tensor<f32>)
outs(%3 : tensor<f32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%5 = mulf %arg3, %arg4 : f32
linalg.yield %5 : f32
} -> tensor<f32>
%6 = linalg.init_tensor [] : tensor<f32>
%7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
ins(%1, %4 : tensor<f32>, tensor<f32>)
outs(%6 : tensor<f32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%5 = divf %arg3, %arg4 : f32
linalg.yield %5 : f32
} -> tensor<f32>
return %7: tensor<f32>
}
// CHECK-LABEL: func @detensor_op_sequence
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
// CHECK-DAG: %[[arg1_val2:.*]] = tensor.extract %[[arg1]]
// CHECK: %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]]
// CHECK: %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]]
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
// CHECK: return %[[reshaped_tensor_res]]
func @detensor_multiple_ops(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
%0 = linalg.init_tensor [] : tensor<f32>
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
outs(%0 : tensor<f32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%2 = addf %arg3, %arg4 : f32
%3 = mulf %2, %arg4 : f32
linalg.yield %3 : f32
} -> tensor<f32>
return %1: tensor<f32>
}
// CHECK-LABEL: func @detensor_multiple_ops
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
// CHECK: %[[detensored_res2:.*]] = mulf %[[detensored_res]], %[[arg2_val]]
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res2]]
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
// CHECK: return %[[reshaped_tensor_res]]
func @detensor_foreign_op(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
%0 = linalg.init_tensor [] : tensor<f32>
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
outs(%0 : tensor<f32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%2 = "foreign.do_something"(%arg3, %arg4) {} : (f32, f32) -> f32
linalg.yield %2 : f32
} -> tensor<f32>
return %1: tensor<f32>
}
// CHECK-LABEL: func @detensor_foreign_op
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
// CHECK: %[[detensored_res:.*]] = "foreign.do_something"(%[[arg1_val]], %[[arg2_val]])
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
// CHECK: return %[[reshaped_tensor_res]]