diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index 5d68328acc7e..7d93dd00d86a 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -59,6 +59,10 @@ void populateElementwiseToLinalgConversionPatterns( /// operations. std::unique_ptr> createLinalgGeneralizationPass(); +/// Create a pass to convert Linalg operations to equivalent operations that +/// work on primitive types, if possible. +std::unique_ptr createLinalgDetensorizePass(); + /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its /// producer (consumer) generic operation by expanding the dimensionality of the /// loop in the generic op. diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index a20289af3054..e51d08d3770d 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -136,4 +136,28 @@ def LinalgGeneralization : FunctionPass<"linalg-generalize-named-ops"> { let dependentDialects = ["linalg::LinalgDialect"]; } +def LinalgDetensorize : FunctionPass<"linalg-detensorize"> { + let summary = "Detensorize linalg ops"; + let constructor = "mlir::createLinalgDetensorizePass()"; + let dependentDialects = []; + + let description = [{ + Detensoring is the process through which a tensor value is convereted to one + or potentially more primitive value(s). During this process, operations with + such detensored operands are also converted to an equivalent form that works + on primitives. + + The detensoring process is driven by linalg-on-tensor ops. In particular, a + linalg-on-tensor op is checked to see whether *all* its operands can be + detensored. If so, those operands are converted to their primitive + counterparts and the linalg op is replaced by an equivalent op that takes + those new primitive values as operands. Therefore, the detensoring process + can be divided into 2 main logical phases: + + 1. Detect/match an op that can be detensored. + 2. Detensor the operands of the op and replace it with a primitive + equivalent. + }]; +} + #endif // MLIR_DIALECT_LINALG_PASSES diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index d988e245c9f7..1469371e1466 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms Bufferize.cpp CodegenStrategy.cpp + Detensorize.cpp DropUnitDims.cpp ElementwiseToLinalg.cpp Fusion.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp new file mode 100644 index 000000000000..2e2e3b94a34a --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -0,0 +1,173 @@ +//===- Detensorize.cpp - Linalg transformations as patterns ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include +#include + +using namespace mlir; +using namespace mlir::linalg; + +namespace { +/// Defines the criteria a TensorType must follow in order to be considered +/// "detensorable". +/// +/// NOTE: For now, only 0-D are supported. +/// +/// Returns true if tensorType can be detensored. +bool canBeDetensored(TensorType tensorType) { + return tensorType.hasRank() && tensorType.getRank() == 0; +} + +/// A conversion patttern for detensoring `linalg.generic` ops. +class DetensorizeGenericOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(GenericOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Block *originalBlock = op->getBlock(); + + // Gather some information about the op before inling its region. + Block *opEntryBlock = &*op.region().begin(); + YieldOp yieldOp = dyn_cast(op.region().back().getTerminator()); + + // Split the op's region before the op. This way, we have a clear insertion + // point in which the op can be inlined. + Block *newBlock = originalBlock->splitBlock(op); + rewriter.inlineRegionBefore(op.region(), newBlock); + // Now that op's region is inlined, the operands of its YieldOp are mapped + // to the materialized target values. Therefore, we can replace the op's + // uses with those of its YielOp's operands. + rewriter.replaceOp(op, yieldOp->getOperands()); + + // No need for these intermediate blocks, merge them into 1. + rewriter.mergeBlocks(opEntryBlock, originalBlock, operands); + rewriter.mergeBlocks(newBlock, originalBlock, {}); + + rewriter.eraseOp(&*Block::iterator(yieldOp)); + + return success(); + } +}; + +class DetensorizeTypeConverter : public TypeConverter { +public: + DetensorizeTypeConverter() { + addConversion([](Type type) { return type; }); + + // A TensorType that can be detensored, is converted to the underlying + // element type. + addConversion([](TensorType tensorType) -> Type { + if (canBeDetensored(tensorType)) + return tensorType.getElementType(); + + return tensorType; + }); + + // A tensor value is detensoried by extracting its element(s). + addTargetMaterialization([](OpBuilder &builder, Type type, + ValueRange inputs, Location loc) -> Value { + return builder.create(loc, inputs[0], ValueRange{}); + }); + + // A detensored value is converted back by creating a new tensor from its + // element(s). + addSourceMaterialization([](OpBuilder &builder, Type type, + ValueRange inputs, Location loc) -> Value { + auto createNewTensorOp = builder.create( + loc, inputs[0].getType(), inputs[0]); + + // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to + // a tensor instead. + return builder.create( + loc, type, createNewTensorOp, ArrayRef{}); + }); + } +}; + +/// Canonicalizes the pattern of the form +/// +/// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32> +/// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into +/// tensor +/// %extracted_element = tensor.extract %reshaped_tensor[] : tensor +/// +/// to just %element. +struct ExtractFromReshapeFromElements + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extract, + PatternRewriter &rewriter) const final { + if (extract.indices().size() != 0) + return failure(); + + auto tensorReshape = extract.tensor().getDefiningOp(); + if (tensorReshape == nullptr) + return failure(); + + auto tensorFromElements = + tensorReshape.getOperand() + .getDefiningOp(); + if (tensorFromElements == nullptr) + return failure(); + + rewriter.replaceOp(extract, tensorFromElements.getOperand(0)); + return success(); + } +}; + +/// @see LinalgDetensorize in Linalg/Passes.td for more details. +struct LinalgDetensorize : public LinalgDetensorizeBase { + void runOnFunction() override { + auto *context = &getContext(); + DetensorizeTypeConverter typeConverter; + OwningRewritePatternList patterns; + ConversionTarget target(*context); + + target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; }); + target.addLegalDialect(); + target.addDynamicallyLegalOp([&](GenericOp op) { + // If any of the operands or results cannot be detensored, the op is + // considered legal and won't be detensored. + return llvm::any_of( + op.getShapedOperandTypes(), [](ShapedType shapedType) { + assert(shapedType.isa()); + return !canBeDetensored(shapedType.cast()); + }); + }); + + patterns.insert(typeConverter, context); + + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) + signalPassFailure(); + + OwningRewritePatternList canonPatterns; + canonPatterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(getFunction(), + std::move(canonPatterns)))) + signalPassFailure(); + + // TODO Properly handle control flow within function boundaries. + } +}; +} // namespace + +std::unique_ptr mlir::createLinalgDetensorizePass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Linalg/detensorized_0d.mlir b/mlir/test/Dialect/Linalg/detensorized_0d.mlir new file mode 100644 index 000000000000..e35a34fd157d --- /dev/null +++ b/mlir/test/Dialect/Linalg/detensorized_0d.mlir @@ -0,0 +1,107 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s + +#map = affine_map<() -> ()> + +func @detensor_simple(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { + %0 = linalg.init_tensor [] : tensor + %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} + ins(%arg1, %arg2 : tensor, tensor) + outs(%0 : tensor) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %2 = addf %arg3, %arg4 : f32 + linalg.yield %2 : f32 + } -> tensor + return %1: tensor +} +// CHECK-LABEL: func @detensor_simple +// CHECK-SAME: (%[[arg1:.*]]: tensor, %[[arg2:.*]]: tensor) +// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]] +// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] +// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]] +// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]] +// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]] +// CHECK: return %[[reshaped_tensor_res]] + +func @detensor_op_sequence(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { + %0 = linalg.init_tensor [] : tensor + %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} + ins(%arg1, %arg2 : tensor, tensor) + outs(%0 : tensor) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %2 = addf %arg3, %arg4 : f32 + linalg.yield %2 : f32 + } -> tensor + + %3 = linalg.init_tensor [] : tensor + %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} + ins(%arg1, %1 : tensor, tensor) + outs(%3 : tensor) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %5 = mulf %arg3, %arg4 : f32 + linalg.yield %5 : f32 + } -> tensor + + %6 = linalg.init_tensor [] : tensor + %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} + ins(%1, %4 : tensor, tensor) + outs(%6 : tensor) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %5 = divf %arg3, %arg4 : f32 + linalg.yield %5 : f32 + } -> tensor + + return %7: tensor +} +// CHECK-LABEL: func @detensor_op_sequence +// CHECK-SAME: (%[[arg1:.*]]: tensor, %[[arg2:.*]]: tensor) +// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]] +// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] +// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]] +// CHECK-DAG: %[[arg1_val2:.*]] = tensor.extract %[[arg1]] +// CHECK: %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]] +// CHECK: %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]] +// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]] +// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]] +// CHECK: return %[[reshaped_tensor_res]] + +func @detensor_multiple_ops(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { + %0 = linalg.init_tensor [] : tensor + %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} + ins(%arg1, %arg2 : tensor, tensor) + outs(%0 : tensor) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %2 = addf %arg3, %arg4 : f32 + %3 = mulf %2, %arg4 : f32 + linalg.yield %3 : f32 + } -> tensor + return %1: tensor +} +// CHECK-LABEL: func @detensor_multiple_ops +// CHECK-SAME: (%[[arg1:.*]]: tensor, %[[arg2:.*]]: tensor) +// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]] +// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] +// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]] +// CHECK: %[[detensored_res2:.*]] = mulf %[[detensored_res]], %[[arg2_val]] +// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res2]] +// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]] +// CHECK: return %[[reshaped_tensor_res]] + +func @detensor_foreign_op(%arg1: tensor, %arg2: tensor) -> tensor attributes {iree.module.export} { + %0 = linalg.init_tensor [] : tensor + %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} + ins(%arg1, %arg2 : tensor, tensor) + outs(%0 : tensor) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %2 = "foreign.do_something"(%arg3, %arg4) {} : (f32, f32) -> f32 + linalg.yield %2 : f32 + } -> tensor + return %1: tensor +} +// CHECK-LABEL: func @detensor_foreign_op +// CHECK-SAME: (%[[arg1:.*]]: tensor, %[[arg2:.*]]: tensor) +// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]] +// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] +// CHECK: %[[detensored_res:.*]] = "foreign.do_something"(%[[arg1_val]], %[[arg2_val]]) +// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]] +// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]] +// CHECK: return %[[reshaped_tensor_res]]