forked from OSchip/llvm-project
[mlir] Start splitting the `tensor` dialect out of `std`.
This starts by moving `std.extract_element` to `tensor.extract` (this mirrors the naming of `vector.extract`). Curiously, `std.extract_element` supposedly works on vectors as well, and this patch removes that functionality. I would tend to do that in separate patch, but I couldn't find any downstream users relying on this, and the fact that we have `vector.extract` made it seem safe enough to lump in here. This also sets up the `tensor` dialect as a dependency of the `std` dialect, as some ops that currently live in `std` depend on `tensor.extract` via their canonicalization patterns. Part of RFC: https://llvm.discourse.group/t/rfc-split-the-tensor-dialect-from-std/2347/2 Differential Revision: https://reviews.llvm.org/D92991
This commit is contained in:
parent
7b3470baf8
commit
cab8dda90f
|
@ -14,5 +14,6 @@ add_subdirectory(SCF)
|
|||
add_subdirectory(Shape)
|
||||
add_subdirectory(SPIRV)
|
||||
add_subdirectory(StandardOps)
|
||||
add_subdirectory(Tensor)
|
||||
add_subdirectory(Tosa)
|
||||
add_subdirectory(Vector)
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
|
@ -46,7 +47,6 @@ using folded_std_constant_float = FoldedValueBuilder<ConstantFloatOp>;
|
|||
using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
|
||||
using folded_std_constant_int = FoldedValueBuilder<ConstantIntOp>;
|
||||
using folded_std_dim = FoldedValueBuilder<DimOp>;
|
||||
using folded_std_extract_element = FoldedValueBuilder<ExtractElementOp>;
|
||||
using folded_std_index_cast = FoldedValueBuilder<IndexCastOp>;
|
||||
using folded_std_muli = FoldedValueBuilder<MulIOp>;
|
||||
using folded_std_mulf = FoldedValueBuilder<MulFOp>;
|
||||
|
@ -60,6 +60,7 @@ using folded_std_tensor_load = FoldedValueBuilder<TensorLoadOp>;
|
|||
using folded_std_view = FoldedValueBuilder<ViewOp>;
|
||||
using folded_std_zero_extendi = FoldedValueBuilder<ZeroExtendIOp>;
|
||||
using folded_std_sign_extendi = FoldedValueBuilder<SignExtendIOp>;
|
||||
using folded_tensor_extract = FoldedValueBuilder<tensor::ExtractOp>;
|
||||
} // namespace intrinsics
|
||||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#define MLIR_DIALECT_STANDARDOPS_EDSC_INTRINSICS_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace edsc {
|
||||
|
@ -28,7 +29,6 @@ using std_dealloc = OperationBuilder<DeallocOp>;
|
|||
using std_divis = ValueBuilder<SignedDivIOp>;
|
||||
using std_diviu = ValueBuilder<UnsignedDivIOp>;
|
||||
using std_dim = ValueBuilder<DimOp>;
|
||||
using std_extract_element = ValueBuilder<ExtractElementOp>;
|
||||
using std_fpext = ValueBuilder<FPExtOp>;
|
||||
using std_fptrunc = ValueBuilder<FPTruncOp>;
|
||||
using std_im = ValueBuilder<ImOp>;
|
||||
|
@ -52,6 +52,7 @@ using std_tensor_store = OperationBuilder<TensorStoreOp>;
|
|||
using std_view = ValueBuilder<ViewOp>;
|
||||
using std_zero_extendi = ValueBuilder<ZeroExtendIOp>;
|
||||
using std_sign_extendi = ValueBuilder<SignExtendIOp>;
|
||||
using tensor_extract = ValueBuilder<tensor::ExtractOp>;
|
||||
|
||||
/// Branches into `block` with `operands`.
|
||||
BranchOp std_br(Block *block, ValueRange operands);
|
||||
|
|
|
@ -1669,59 +1669,6 @@ def Exp2Op : FloatUnaryOp<"exp2"> {
|
|||
let summary = "base-2 exponential of the specified value";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtractElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ExtractElementOp : Std_Op<"extract_element",
|
||||
[NoSideEffect,
|
||||
TypesMatchWith<"result type matches element type of aggregate",
|
||||
"aggregate", "result",
|
||||
"$_self.cast<ShapedType>().getElementType()">]> {
|
||||
let summary = "element extract operation";
|
||||
let description = [{
|
||||
The `extract_element` op reads a tensor or vector and returns one element
|
||||
from it specified by an index list. The output of the 'extract_element' is a
|
||||
new value with the same type as the elements of the tensor or vector. The
|
||||
arity of indices matches the rank of the accessed value (i.e., if a tensor
|
||||
is of rank 3, then 3 indices are required for the extract. The indices
|
||||
should all be of `index` type.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%3 = extract_element %v[%1, %2] : vector<4x4xi32>
|
||||
%4 = extract_element %t[%1, %2] : tensor<4x4xi32>
|
||||
%5 = extract_element %ut[%1, %2] : tensor<*xi32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate,
|
||||
Variadic<Index>:$indices);
|
||||
let results = (outs AnyType:$result);
|
||||
|
||||
let builders = [
|
||||
OpBuilderDAG<(ins "Value":$aggregate, CArg<"ValueRange", "{}">:$indices), [{
|
||||
auto resType = aggregate.getType().cast<ShapedType>()
|
||||
.getElementType();
|
||||
build($_builder, $_state, resType, aggregate, indices);
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Value getAggregate() { return getOperand(0); }
|
||||
|
||||
operand_range getIndices() {
|
||||
return {operand_begin() + 1, operand_end()};
|
||||
}
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$aggregate `[` $indices `]` attr-dict `:` type($aggregate)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFromElementsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
|
@ -0,0 +1,2 @@
|
|||
add_mlir_dialect(TensorOps tensor)
|
||||
add_mlir_doc(TensorOps -gen-dialect-doc TensorOps Dialects/)
|
|
@ -0,0 +1,31 @@
|
|||
//===- Tensor.h - Tensor dialect --------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_
|
||||
#define MLIR_DIALECT_TENSOR_IR_TENSOR_H_
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tensor Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tensor Dialect Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"
|
||||
|
||||
#endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_
|
|
@ -0,0 +1,48 @@
|
|||
//===- TensorBase.td - Base definitions for tensor dialect -*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TENSOR_BASE
|
||||
#define TENSOR_BASE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Tensor_Dialect : Dialect {
|
||||
let name = "tensor";
|
||||
let cppNamespace = "::mlir::tensor";
|
||||
let description = [{
|
||||
The `tensor` dialect is intended to hold core tensor creation and
|
||||
manipulation ops, which are not strongly associated with any particular
|
||||
other dialect or domain abstraction. The primary smoke test of this is ops
|
||||
that make sense for any tensor element type.
|
||||
|
||||
We leave it to other dialects to hold the vast swath of possible
|
||||
computations one might want to do on a tensor.
|
||||
|
||||
The `tensor` type is (for better or for worse) used to represent all kinds
|
||||
of things, and supports an open-ended set of element types. Examples:
|
||||
|
||||
- representing large, dense aggregations of primitive types, suitable for
|
||||
high-performance numerical computing.
|
||||
- representing shapes in the `shape` dialect, which consist of small
|
||||
1D tensors of `index` data type.
|
||||
- representing aggregations of strings or “variant” types.
|
||||
- representing large, sparse aggregations of primitive types, suitable
|
||||
for high-performance numerical computing.
|
||||
|
||||
Thus, for the `tensor` dialect, we prefer for now to constrain the
|
||||
scope as much as possible. The expectation is that at some point
|
||||
in the future, the `tensor` dialect’s scope may be broadened through a
|
||||
careful discussion of the tradeoffs.
|
||||
|
||||
The `tensor` type is actually a builtin type (it lives in the builtin
|
||||
dialect), and does not live in this dialect.
|
||||
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TENSOR_BASE
|
|
@ -0,0 +1,62 @@
|
|||
//===- TensorOps.td - Tensor op definitions ----------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TENSOR_OPS
|
||||
#define TENSOR_OPS
|
||||
|
||||
include "mlir/Dialect/Tensor/IR/TensorBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
class Tensor_Op<string mnemonic, list<OpTrait> traits = []>
|
||||
: Op<Tensor_Dialect, mnemonic, traits> {
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtractOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_ExtractOp : Tensor_Op<"extract",
|
||||
[NoSideEffect,
|
||||
TypesMatchWith<"result type matches element type of tensor",
|
||||
"tensor", "result",
|
||||
"$_self.cast<ShapedType>().getElementType()">]> {
|
||||
let summary = "element extraction operation";
|
||||
let description = [{
|
||||
The `tensor.extract` op reads a tensor and returns one
|
||||
element from it specified by an index list. The output of the op is a
|
||||
new value with the same type as the elements of the tensor. The
|
||||
arity of indices must match the rank of the accessed value (i.e., if a
|
||||
tensor is of rank 3, then 3 indices are required for the extract. The
|
||||
indices should all be of `index` type.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%4 = tensor.extract %t[%1, %2] : tensor<4x4xi32>
|
||||
%5 = tensor.extract %rt[%1, %2] : tensor<?x?xi32>
|
||||
%6 = tensor.extract %ut[%1, %2] : tensor<*xi32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTensor:$tensor, Variadic<Index>:$indices);
|
||||
let results = (outs AnyType:$result);
|
||||
let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)";
|
||||
|
||||
let builders = [
|
||||
OpBuilderDAG<(ins "Value":$tensor, CArg<"ValueRange", "{}">:$indices), [{
|
||||
auto resType = tensor.getType().cast<ShapedType>().getElementType();
|
||||
build($_builder, $_state, resType, tensor, indices);
|
||||
}]>];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
#endif // TENSOR_OPS
|
|
@ -0,0 +1,5 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Tensor)
|
||||
add_public_tablegen_target(MLIRTensorTransformsIncGen)
|
||||
|
||||
add_mlir_doc(Passes -gen-pass-doc TensorPasses ./)
|
|
@ -0,0 +1,38 @@
|
|||
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES_H_
|
||||
#define MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/Bufferize.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class OwningRewritePatternList;
|
||||
|
||||
void populateTensorBufferizePatterns(MLIRContext *context,
|
||||
BufferizeTypeConverter &typeConverter,
|
||||
OwningRewritePatternList &patterns);
|
||||
|
||||
/// Creates an instance of `tensor` dialect bufferization pass.
|
||||
std::unique_ptr<Pass> createTensorBufferizePass();
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace tensor {
|
||||
/// Generate the code for registering passes.
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
|
||||
} // namespace tensor
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES_H_
|
|
@ -0,0 +1,19 @@
|
|||
//===-- Passes.td - pass definition file -------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
|
||||
#define MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def TensorBufferize : FunctionPass<"tensor-bufferize"> {
|
||||
let summary = "Bufferize the `tensor` dialect";
|
||||
let constructor = "mlir::createTensorBufferizePass()";
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
|
|
@ -35,6 +35,7 @@
|
|||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
@ -66,6 +67,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
|
|||
ROCDL::ROCDLDialect,
|
||||
SDBMDialect,
|
||||
shape::ShapeDialect,
|
||||
tensor::TensorDialect,
|
||||
tosa::TosaDialect>();
|
||||
// clang-format on
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "mlir/Dialect/SPIRV/Passes.h"
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
|
@ -57,6 +58,7 @@ inline void registerAllPasses() {
|
|||
registerShapePasses();
|
||||
spirv::registerSPIRVPasses();
|
||||
registerStandardPasses();
|
||||
tensor::registerTensorPasses();
|
||||
tosa::registerTosaOptPasses();
|
||||
}
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
|
@ -64,10 +65,10 @@ public:
|
|||
rewriter.create<scf::ForOp>(
|
||||
loc, rankDiff, greaterRank, one, llvm::None,
|
||||
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
|
||||
Value greaterRankOperandExtent = b.create<ExtractElementOp>(
|
||||
Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
|
||||
loc, greaterRankOperand, ValueRange{iv});
|
||||
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
|
||||
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
|
||||
Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
|
||||
loc, lesserRankOperand, ValueRange{ivShifted});
|
||||
|
||||
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
|
@ -118,12 +119,12 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
|
|||
Value outputDimension = args[0];
|
||||
Value isUnchallengedDimension = b.create<CmpIOp>(
|
||||
loc, CmpIPredicate::ult, outputDimension, rankDiff);
|
||||
Value greaterRankOperandExtent = b.create<ExtractElementOp>(
|
||||
Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
|
||||
loc, greaterRankOperand, outputDimension);
|
||||
// The initial dimensions of the greater-rank operand are unchallenged,
|
||||
// so we can take them as-is. Otherwise, we need to do a comparison.
|
||||
// We need an actual branch here (instead of a select) because the
|
||||
// lesser-rank operand might be rank 0, so any extract_element would be
|
||||
// lesser-rank operand might be rank 0, so any tensor.extract would be
|
||||
// invalid.
|
||||
auto ifOp = b.create<IfOp>(
|
||||
loc, TypeRange{indexTy}, isUnchallengedDimension,
|
||||
|
@ -140,7 +141,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
|
|||
// dimensions of zero extent.
|
||||
Value lesserRankOperandDimension =
|
||||
b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
|
||||
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
|
||||
Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
|
||||
loc, lesserRankOperand,
|
||||
ValueRange{lesserRankOperandDimension});
|
||||
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
|
||||
|
@ -262,12 +263,12 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
|
|||
auto reduceResult = rewriter.create<ForOp>(
|
||||
loc, rankDiff, greaterRank, one, ValueRange{init},
|
||||
[&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
|
||||
Value greaterRankOperandExtent =
|
||||
b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
|
||||
Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
|
||||
loc, greaterRankOperand, ValueRange{iv});
|
||||
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
|
||||
loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
|
||||
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
|
||||
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
|
||||
Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
|
||||
loc, lesserRankOperand, ValueRange{ivShifted});
|
||||
Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
|
||||
loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
|
||||
|
@ -316,9 +317,9 @@ LogicalResult GetExtentOpConverter::matchAndRewrite(
|
|||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
|
||||
transformed.shape(),
|
||||
ValueRange{transformed.dim()});
|
||||
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
|
||||
transformed.shape(),
|
||||
ValueRange{transformed.dim()});
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -375,7 +376,8 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
|
|||
auto loop = rewriter.create<scf::ForOp>(
|
||||
loc, zero, rank, one, op.initVals(),
|
||||
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
|
||||
Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
|
||||
Value extent =
|
||||
b.create<tensor::ExtractOp>(loc, transformed.shape(), iv);
|
||||
|
||||
SmallVector<Value, 2> mappedValues{iv, extent};
|
||||
mappedValues.append(args.begin(), args.end());
|
||||
|
@ -415,8 +417,8 @@ namespace {
|
|||
/// %c1 = constant 1 : index
|
||||
/// %true = constant true
|
||||
/// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
|
||||
/// %5 = extract_element %arg0[%arg2] : tensor<?xindex>
|
||||
/// %6 = extract_element %arg1[%arg2] : tensor<?xindex>
|
||||
/// %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
|
||||
/// %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
|
||||
/// %7 = cmpi "eq", %5, %6 : index
|
||||
/// %8 = and %arg3, %7 : i1
|
||||
/// scf.yield %8 : i1
|
||||
|
@ -465,9 +467,9 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
|
|||
[&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
|
||||
Value conj = args[0];
|
||||
Value lhsExtent =
|
||||
b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
|
||||
b.create<tensor::ExtractOp>(loc, transformed.lhs(), iv);
|
||||
Value rhsExtent =
|
||||
b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
|
||||
b.create<tensor::ExtractOp>(loc, transformed.rhs(), iv);
|
||||
Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
|
||||
lhsExtent, rhsExtent);
|
||||
Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
|
||||
|
@ -584,7 +586,8 @@ void ConvertShapeToStandardPass::runOnOperation() {
|
|||
// Setup target legality.
|
||||
MLIRContext &ctx = getContext();
|
||||
ConversionTarget target(ctx);
|
||||
target.addLegalDialect<StandardOpsDialect, SCFDialect>();
|
||||
target
|
||||
.addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>();
|
||||
target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
|
||||
|
||||
// Setup conversion patterns.
|
||||
|
|
|
@ -15,6 +15,7 @@ add_subdirectory(SDBM)
|
|||
add_subdirectory(Shape)
|
||||
add_subdirectory(SPIRV)
|
||||
add_subdirectory(StandardOps)
|
||||
add_subdirectory(Tensor)
|
||||
add_subdirectory(Tosa)
|
||||
add_subdirectory(Vector)
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRStandard
|
|||
MLIREDSC
|
||||
MLIRIR
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRTensor
|
||||
MLIRVectorInterfaces
|
||||
MLIRViewLikeInterface
|
||||
)
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
||||
#include "mlir/Dialect/CommonFolders.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
|
@ -153,6 +154,7 @@ static LogicalResult verifyCastOp(T op) {
|
|||
}
|
||||
|
||||
void StandardOpsDialect::initialize() {
|
||||
getContext()->loadDialect<tensor::TensorDialect>();
|
||||
addOperations<DmaStartOp, DmaWaitOp,
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
|
||||
|
@ -1863,18 +1865,18 @@ struct StaticDynamicTensorFromElements
|
|||
/// <computation>
|
||||
/// yield %1 : index
|
||||
/// } : tensor<?xindex>
|
||||
/// %extracted_element = extract_element %tensor[%c0] : tensor<?xi32>
|
||||
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
|
||||
///
|
||||
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
|
||||
/// dynamic_tensor_from_elements operation has no side-effects.
|
||||
struct ExtractElementFromDynamicTensorFromElements
|
||||
: public OpRewritePattern<ExtractElementOp> {
|
||||
using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
|
||||
struct ExtractFromDynamicTensorFromElements
|
||||
: public OpRewritePattern<tensor::ExtractOp> {
|
||||
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ExtractElementOp extract,
|
||||
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
||||
PatternRewriter &rewriter) const final {
|
||||
auto tensorFromElements =
|
||||
extract.aggregate().getDefiningOp<DynamicTensorFromElementsOp>();
|
||||
extract.tensor().getDefiningOp<DynamicTensorFromElementsOp>();
|
||||
if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
|
||||
return failure();
|
||||
|
||||
|
@ -1894,23 +1896,22 @@ struct ExtractElementFromDynamicTensorFromElements
|
|||
/// Canonicalizes the pattern of the form
|
||||
///
|
||||
/// %val = tensor_cast %source : : tensor<?xi32> to tensor<2xi32>
|
||||
/// %extracted_element = extract_element %val[%c0] : tensor<2xi32>
|
||||
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
|
||||
///
|
||||
/// to
|
||||
///
|
||||
/// %extracted_element = extract_element %source[%c0] : tensor<?xi32>
|
||||
struct ExtractElementFromTensorCast
|
||||
: public OpRewritePattern<ExtractElementOp> {
|
||||
using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
|
||||
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
|
||||
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
|
||||
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ExtractElementOp extract,
|
||||
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
||||
PatternRewriter &rewriter) const final {
|
||||
auto tensorCast = extract.aggregate().getDefiningOp<TensorCastOp>();
|
||||
auto tensorCast = extract.tensor().getDefiningOp<TensorCastOp>();
|
||||
if (!tensorCast)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<ExtractElementOp>(extract, tensorCast.source(),
|
||||
extract.getIndices());
|
||||
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
|
||||
extract.indices());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -1919,51 +1920,9 @@ struct ExtractElementFromTensorCast
|
|||
|
||||
void DynamicTensorFromElementsOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<ExtractElementFromDynamicTensorFromElements,
|
||||
ExtractElementFromTensorCast, StaticDynamicTensorFromElements>(
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtractElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(ExtractElementOp op) {
|
||||
// Verify the # indices match if we have a ranked type.
|
||||
auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
|
||||
if (aggregateType.hasRank() &&
|
||||
aggregateType.getRank() != op.getNumOperands() - 1)
|
||||
return op.emitOpError("incorrect number of indices for extract_element");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(!operands.empty() && "extract_element takes at least one operand");
|
||||
|
||||
// The aggregate operand must be a known constant.
|
||||
Attribute aggregate = operands.front();
|
||||
if (!aggregate)
|
||||
return {};
|
||||
|
||||
// If this is a splat elements attribute, simply return the value. All of the
|
||||
// elements of a splat attribute are the same.
|
||||
if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
|
||||
return splatAggregate.getSplatValue();
|
||||
|
||||
// Otherwise, collect the constant indices into the aggregate.
|
||||
SmallVector<uint64_t, 8> indices;
|
||||
for (Attribute indice : llvm::drop_begin(operands, 1)) {
|
||||
if (!indice || !indice.isa<IntegerAttr>())
|
||||
return {};
|
||||
indices.push_back(indice.cast<IntegerAttr>().getInt());
|
||||
}
|
||||
|
||||
// If this is an elements attribute, query the value at the given indices.
|
||||
auto elementsAttr = aggregate.dyn_cast<ElementsAttr>();
|
||||
if (elementsAttr && elementsAttr.isValidIndex(indices))
|
||||
return elementsAttr.getValue(indices);
|
||||
return {};
|
||||
// TODO: Move extract patterns to tensor::ExtractOp.
|
||||
results.insert<ExtractFromDynamicTensorFromElements, ExtractFromTensorCast,
|
||||
StaticDynamicTensorFromElements>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1989,20 +1948,20 @@ namespace {
|
|||
// Canonicalizes the pattern of the form
|
||||
//
|
||||
// %tensor = "tensor_from_elements(%element) : (i32) -> tensor<1xi32>
|
||||
// %extracted_element = extract_element %tensor[%c0] : tensor<1xi32>
|
||||
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
|
||||
//
|
||||
// to just %element.
|
||||
struct ExtractElementFromTensorFromElements
|
||||
: public OpRewritePattern<ExtractElementOp> {
|
||||
using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
|
||||
: public OpRewritePattern<tensor::ExtractOp> {
|
||||
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ExtractElementOp extract,
|
||||
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
||||
PatternRewriter &rewriter) const final {
|
||||
if (extract.indices().size() != 1)
|
||||
return failure();
|
||||
|
||||
auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>(
|
||||
extract.aggregate().getDefiningOp());
|
||||
extract.tensor().getDefiningOp());
|
||||
if (tensorFromElements == nullptr)
|
||||
return failure();
|
||||
|
||||
|
@ -2216,7 +2175,7 @@ OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
|
|||
}
|
||||
|
||||
namespace {
|
||||
/// Fold a load on a tensor_to_memref operation into an extract_element on the
|
||||
/// Fold a load on a tensor_to_memref operation into an tensor.extract on the
|
||||
/// corresponding tensor.
|
||||
struct LoadOfTensorToMemref : public OpRewritePattern<LoadOp> {
|
||||
using OpRewritePattern<LoadOp>::OpRewritePattern;
|
||||
|
@ -2227,8 +2186,8 @@ struct LoadOfTensorToMemref : public OpRewritePattern<LoadOp> {
|
|||
if (!tensorToMemref)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<ExtractElementOp>(load, tensorToMemref.tensor(),
|
||||
load.indices());
|
||||
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
|
||||
load, tensorToMemref.tensor(), load.indices());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
|
@ -88,21 +89,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class BufferizeExtractElementOp : public OpConversionPattern<ExtractElementOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ExtractElementOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
ExtractElementOp::Adaptor adaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.aggregate(),
|
||||
adaptor.indices());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
|
||||
public:
|
||||
|
@ -165,7 +151,6 @@ void mlir::populateStdBufferizePatterns(MLIRContext *context,
|
|||
// clang-format off
|
||||
BufferizeDimOp,
|
||||
BufferizeDynamicTensorFromElementsOp,
|
||||
BufferizeExtractElementOp,
|
||||
BufferizeSelectOp,
|
||||
BufferizeTensorCastOp,
|
||||
BufferizeTensorFromElementsOp
|
||||
|
@ -183,10 +168,11 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
|
|||
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalDialect<scf::SCFDialect>();
|
||||
target.addLegalDialect<tensor::TensorDialect>();
|
||||
|
||||
populateStdBufferizePatterns(context, typeConverter, patterns);
|
||||
target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
|
||||
TensorCastOp, TensorFromElementsOp>();
|
||||
target.addIllegalOp<DynamicTensorFromElementsOp, TensorCastOp,
|
||||
TensorFromElementsOp>();
|
||||
// We only bufferize the case of tensor selected type and scalar condition,
|
||||
// as that boils down to a select over memref descriptors (don't need to
|
||||
// touch the data).
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
|
@ -0,0 +1,17 @@
|
|||
add_mlir_dialect_library(MLIRTensor
|
||||
TensorDialect.cpp
|
||||
TensorOps.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/mlir/Dialect/Tensor
|
||||
|
||||
DEPENDS
|
||||
MLIRTensorOpsIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRSupport
|
||||
)
|
|
@ -0,0 +1,39 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tensor;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorDialect Dialect Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct TensorInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
|
||||
BlockAndValueMapping &valueMapping) const final {
|
||||
return true;
|
||||
}
|
||||
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
|
||||
BlockAndValueMapping &) const final {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void TensorDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
|
||||
>();
|
||||
addInterfaces<TensorInlinerInterface>();
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tensor;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtractOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(ExtractOp op) {
|
||||
// Verify the # indices match if we have a ranked type.
|
||||
if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
|
||||
if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
|
||||
return op.emitOpError("incorrect number of indices for extract_element");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
|
||||
// The tensor operand must be a known constant.
|
||||
Attribute tensor = operands.front();
|
||||
if (!tensor)
|
||||
return {};
|
||||
// If this is a splat elements attribute, simply return the value. All of the
|
||||
// elements of a splat attribute are the same.
|
||||
if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
|
||||
return splatTensor.getSplatValue();
|
||||
|
||||
// Otherwise, collect the constant indices into the tensor.
|
||||
SmallVector<uint64_t, 8> indices;
|
||||
for (Attribute indice : llvm::drop_begin(operands, 1)) {
|
||||
if (!indice || !indice.isa<IntegerAttr>())
|
||||
return {};
|
||||
indices.push_back(indice.cast<IntegerAttr>().getInt());
|
||||
}
|
||||
|
||||
// If this is an elements attribute, query the value at the given indices.
|
||||
auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
|
||||
if (elementsAttr && elementsAttr.isValidIndex(indices))
|
||||
return elementsAttr.getValue(indices);
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
|
|
@ -0,0 +1,64 @@
|
|||
//===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements bufferization of `tensor` dialect ops
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Transforms/Bufferize.h"
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tensor::ExtractOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
tensor::ExtractOp::Adaptor adaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.tensor(),
|
||||
adaptor.indices());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::populateTensorBufferizePatterns(
|
||||
MLIRContext *context, BufferizeTypeConverter &typeConverter,
|
||||
OwningRewritePatternList &patterns) {
|
||||
patterns.insert<BufferizeExtractOp>(typeConverter, context);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
|
||||
void runOnFunction() override {
|
||||
auto *context = &getContext();
|
||||
BufferizeTypeConverter typeConverter;
|
||||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(*context);
|
||||
|
||||
populateTensorBufferizePatterns(context, typeConverter, patterns);
|
||||
target.addIllegalOp<tensor::ExtractOp>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
|
||||
if (failed(
|
||||
applyPartialConversion(getFunction(), target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
|
||||
return std::make_unique<TensorBufferizePass>();
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
add_mlir_dialect_library(MLIRTensorTransforms
|
||||
Bufferize.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
|
||||
|
||||
DEPENDS
|
||||
MLIRTensorTransformsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRTensor
|
||||
MLIRTransforms
|
||||
)
|
|
@ -0,0 +1,21 @@
|
|||
//===- PassDetail.h - GPU Pass class details --------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef DIALECT_TENSOR_TRANSFORMS_PASSDETAIL_H_
|
||||
#define DIALECT_TENSOR_TRANSFORMS_PASSDETAIL_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // DIALECT_TENSOR_TRANSFORMS_PASSDETAIL_H_
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
@ -59,6 +60,23 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
|
|||
assert(matchPattern(constOp, m_Constant()));
|
||||
return constOp;
|
||||
}
|
||||
|
||||
// TODO: To faciliate splitting the std dialect (PR48490), have a special case
|
||||
// for falling back to std.constant. Eventually, we will have separate ops
|
||||
// tensor.constant, int.constant, float.constant, etc. that live in their
|
||||
// respective dialects, which will allow each dialect to implement the
|
||||
// materializeConstant hook above.
|
||||
//
|
||||
// The special case is needed because in the interim state while we are
|
||||
// splitting out those dialects from std, the std dialect depends on the
|
||||
// tensor dialect, which makes it impossible for the tensor dialect to use
|
||||
// std.constant (it would be a cyclic dependency) as part of its
|
||||
// materializeConstant hook.
|
||||
//
|
||||
// If the dialect is unable to materialize a constant, check to see if the
|
||||
// standard constant can be used.
|
||||
if (ConstantOp::isBuildableWith(value, type))
|
||||
return builder.create<ConstantOp>(loc, type, value);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -16,9 +16,9 @@
|
|||
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
|
||||
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
|
||||
// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
|
||||
// CHECK: %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IVSHIFTED]]] : tensor<?xindex>
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[IVSHIFTED]]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LESSER_RANK_OPERAND_EXTENT]], %[[C1]] : index
|
||||
// CHECK: %[[EXTENTS_AGREE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[LESSER_RANK_OPERAND_EXTENT]] : index
|
||||
|
|
|
@ -74,12 +74,12 @@ func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : index) -> index {
|
|||
|
||||
// -----
|
||||
|
||||
// Express `get_extent` as `std.extract_element`.
|
||||
// Express `get_extent` as `std.tensor.extract`.
|
||||
// CHECK-LABEL: @get_extent_from_extent_tensor
|
||||
// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
|
||||
func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
|
||||
-> index {
|
||||
// CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
|
||||
// CHECK: %[[RESULT:.*]] = tensor.extract %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
|
||||
// CHECK: return %[[RESULT]] : index
|
||||
%result = shape.get_extent %extents, %idx : tensor<?xindex>, index -> index
|
||||
return %result : index
|
||||
|
@ -180,7 +180,7 @@ func @shape_reduce(%shape : tensor<?xindex>) -> index {
|
|||
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
|
||||
// CHECK-NEXT: %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]]
|
||||
// CHECK-NEXT: %[[EXTENT:.*]] = tensor.extract %[[SHAPE]][%[[I]]]
|
||||
// CHECK-NEXT: %[[NEW_ACC:.*]] = muli %[[ACC]], %[[EXTENT]] : index
|
||||
// CHECK-NEXT: scf.yield %[[NEW_ACC]] : index
|
||||
// CHECK-NEXT: }
|
||||
|
@ -277,8 +277,8 @@ func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
|
|||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[INIT:.*]] = constant true
|
||||
// CHECK: %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
|
||||
// CHECK: %[[EXTENT_A:.*]] = extract_element %[[A]][%[[I]]] : tensor<?xindex>
|
||||
// CHECK: %[[EXTENT_B:.*]] = extract_element %[[B]][%[[I]]] : tensor<?xindex>
|
||||
// CHECK: %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
|
||||
// CHECK: %[[EXTENT_B:.*]] = tensor.extract %[[B]][%[[I]]] : tensor<?xindex>
|
||||
// CHECK: %[[EXTENT_EQ:.*]] = cmpi "eq", %[[EXTENT_A]], %[[EXTENT_B]]
|
||||
// CHECK: %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
|
||||
// CHECK: scf.yield %[[CONJ_NEXT]] : i1
|
||||
|
@ -324,12 +324,12 @@ func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
|
|||
// CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
|
||||
// CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
|
||||
// CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
|
||||
// CHECK: %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
|
||||
// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
|
||||
// CHECK: %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
|
||||
// CHECK: scf.yield %[[BROADCASTED_EXTENT]] : index
|
||||
|
@ -364,12 +364,12 @@ func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xinde
|
|||
// CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
|
||||
// CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
|
||||
// CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
|
||||
// CHECK: %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
|
||||
// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
|
||||
// CHECK: %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
|
||||
// CHECK: scf.yield %[[BROADCASTED_EXTENT]] : index
|
||||
|
@ -407,10 +407,10 @@ func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
|
|||
// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
|
||||
// CHECK: %[[TRUE:.*]] = constant true
|
||||
// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[I:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
|
||||
// CHECK: %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[I]]] : tensor<?xindex>
|
||||
// CHECK: %[[LARGER_EXTENT:.*]] = tensor.extract %[[LARGER_SHAPE]]{{\[}}%[[I]]] : tensor<?xindex>
|
||||
// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
|
||||
// CHECK: %[[SMALLER_EXTENT_INDEX:.*]] = subi %[[I]], %[[RANK_DIFF]] : index
|
||||
// CHECK: %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[SMALLER_EXTENT_INDEX]]] : tensor<?xindex>
|
||||
// CHECK: %[[SMALLER_EXTENT:.*]] = tensor.extract %[[SMALLER_SHAPE]]{{\[}}%[[SMALLER_EXTENT_INDEX]]] : tensor<?xindex>
|
||||
// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
|
||||
// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
|
||||
// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
|
||||
|
@ -445,10 +445,10 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
|
|||
// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
|
||||
// CHECK: %[[TRUE:.*]] = constant true
|
||||
// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[VAL_16:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
|
||||
// CHECK: %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[VAL_16]]] : tensor<?xindex>
|
||||
// CHECK: %[[LARGER_EXTENT:.*]] = tensor.extract %[[LARGER_SHAPE]]{{\[}}%[[VAL_16]]] : tensor<?xindex>
|
||||
// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
|
||||
// CHECK: %[[LHS_EXTENT_INDEX:.*]] = subi %[[VAL_16]], %[[RANK_DIFF]] : index
|
||||
// CHECK: %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[LHS_EXTENT_INDEX]]] : tensor<?xindex>
|
||||
// CHECK: %[[SMALLER_EXTENT:.*]] = tensor.extract %[[SMALLER_SHAPE]]{{\[}}%[[LHS_EXTENT_INDEX]]] : tensor<?xindex>
|
||||
// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
|
||||
// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
|
||||
// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
|
||||
|
|
|
@ -395,7 +395,7 @@ func @scalar_indexed_generic_fusion
|
|||
ins(%arg1 : tensor<i32>) {
|
||||
^bb0(%arg2: i32): // no predecessors
|
||||
%3 = index_cast %arg2 : i32 to index
|
||||
%4 = extract_element %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
|
||||
%4 = tensor.extract %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
|
||||
linalg.yield %4 : f32
|
||||
} -> tensor<f32>
|
||||
%1 = linalg.generic
|
||||
|
@ -418,6 +418,6 @@ func @scalar_indexed_generic_fusion
|
|||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel"]
|
||||
// CHECK-SAME: ins(%[[ARG1]] : tensor<i32>)
|
||||
// CHECK: extract_element %[[ARG0]]
|
||||
// CHECK: tensor.extract %[[ARG0]]
|
||||
// CHECK: linalg.yield
|
||||
// CHECK return %[[T0]]
|
||||
|
|
|
@ -61,18 +61,6 @@ func @dynamic_tensor_from_elements_static_and_dynamic(%arg0: index) -> tensor<16
|
|||
return %result : tensor<16x?xindex>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @extract_element(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>,
|
||||
// CHECK-SAME: %[[IDX:.*]]: index) -> f32 {
|
||||
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32>
|
||||
// CHECK: %[[RET:.*]] = load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
|
||||
// CHECK: return %[[RET]] : f32
|
||||
// CHECK: }
|
||||
func @extract_element(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
|
||||
%0 = extract_element %arg0[%arg1] : tensor<?xf32>
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @select(
|
||||
// CHECK-SAME: %[[PRED:.*]]: i1,
|
||||
// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
|
||||
|
@ -138,14 +126,14 @@ func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
|
|||
// The dynamic_tensor_from_elements op clones each op in its body.
|
||||
// Make sure that regions nested within such ops are recursively converted.
|
||||
// CHECK-LABEL: func @recursively_convert_cloned_regions
|
||||
func @recursively_convert_cloned_regions(%arg0: tensor<?xindex>, %arg1: index, %arg2: i1) -> tensor<?xindex> {
|
||||
func @recursively_convert_cloned_regions(%arg0: tensor<*xf32>, %arg1: index, %arg2: i1) -> tensor<?xindex> {
|
||||
%tensor = dynamic_tensor_from_elements %arg1 {
|
||||
^bb0(%iv: index):
|
||||
%48 = scf.if %arg2 -> (index) {
|
||||
scf.yield %iv : index
|
||||
} else {
|
||||
// CHECK-NOT: extract_element
|
||||
%50 = extract_element %arg0[%iv] : tensor<?xindex>
|
||||
// CHECK-NOT: dim{{.*}}tensor
|
||||
%50 = dim %arg0, %iv : tensor<*xf32>
|
||||
scf.yield %50 : index
|
||||
}
|
||||
yield %48 : index
|
||||
|
|
|
@ -46,11 +46,11 @@ func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
|
|||
}
|
||||
|
||||
// Test case: Folding of load(tensor_to_memref(%v, %idxs))
|
||||
// -> extract_element(%v, %idx)
|
||||
// -> tensor.extract(%v, %idx)
|
||||
// CHECK-LABEL: func @load_from_tensor_to_memref(
|
||||
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
|
||||
// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
|
||||
// CHECK: %[[RES:.*]] = extract_element %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
|
||||
// CHECK: %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
|
||||
// CHECK-NOT: load
|
||||
// CHECK: return %[[RES]] : f32
|
||||
func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @extract(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>,
|
||||
// CHECK-SAME: %[[IDX:.*]]: index) -> f32 {
|
||||
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32>
|
||||
// CHECK: %[[RET:.*]] = load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
|
||||
// CHECK: return %[[RET]] : f32
|
||||
// CHECK: }
|
||||
func @extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
|
||||
%0 = tensor.extract %arg0[%arg1] : tensor<?xf32>
|
||||
return %0 : f32
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
// RUN: mlir-opt %s -canonicalize | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fold_extract
|
||||
func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
|
||||
%const_0 = constant 0 : index
|
||||
%const_1 = constant 1 : index
|
||||
%const_3 = constant 3 : index
|
||||
|
||||
// Fold an extract into a splat.
|
||||
// CHECK-NEXT: [[C4:%.+]] = constant 4.{{0*}}e+00 : f32
|
||||
%0 = constant dense<4.0> : tensor<4xf32>
|
||||
%ext_1 = tensor.extract %0[%arg0] : tensor<4xf32>
|
||||
|
||||
// Fold an extract into a sparse with a sparse index.
|
||||
// CHECK-NEXT: [[CM2:%.+]] = constant -2.{{0*}}e+00 : f16
|
||||
%1 = constant sparse<[[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : tensor<4x4x4xf16>
|
||||
%ext_2 = tensor.extract %1[%const_1, %const_1, %const_1] : tensor<4x4x4xf16>
|
||||
|
||||
// Fold an extract into a sparse with a non sparse index.
|
||||
// CHECK-NEXT: [[C0:%.+]] = constant 0.{{0*}}e+00 : f16
|
||||
%2 = constant sparse<[[1, 1, 1]], [-2.0]> : tensor<1x1x1xf16>
|
||||
%ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<1x1x1xf16>
|
||||
|
||||
// Fold an extract into a dense tensor.
|
||||
// CHECK-NEXT: [[C64:%.+]] = constant 64 : i32
|
||||
%3 = constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
|
||||
%ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
|
||||
|
||||
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]]
|
||||
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
// RUN: mlir-opt <%s -verify-diagnostics
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_too_many_indices(%arg0: tensor<?xf32>) {
|
||||
// expected-error@+1 {{incorrect number of indices for extract_element}}
|
||||
%0 = tensor.extract %arg0[] : tensor<?xf32>
|
||||
return
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
// RUN: mlir-opt <%s | mlir-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @extract(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?x?x?xf32>,
|
||||
// CHECK-SAME: %[[INDEX:.*]]: index) {
|
||||
func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
|
||||
// CHECK: tensor.extract %[[TENSOR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
|
||||
%0 = tensor.extract %arg0[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
|
||||
return
|
||||
}
|
|
@ -672,19 +672,6 @@ func @calls(%arg0: i32) {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @extract_element(%arg0: tensor<*xi32>, %arg1: tensor<4x4xf32>) -> i32 {
|
||||
func @extract_element(%arg0: tensor<*xi32>, %arg1 : tensor<4x4xf32>) -> i32 {
|
||||
%c0 = "std.constant"() {value = 0: index} : () -> index
|
||||
|
||||
// CHECK: %0 = extract_element %arg0[%c0, %c0, %c0, %c0] : tensor<*xi32>
|
||||
%0 = extract_element %arg0[%c0, %c0, %c0, %c0] : tensor<*xi32>
|
||||
|
||||
// CHECK: %1 = extract_element %arg1[%c0, %c0] : tensor<4x4xf32>
|
||||
%1 = extract_element %arg1[%c0, %c0] : tensor<4x4xf32>
|
||||
|
||||
return %0 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tensor_from_elements() {
|
||||
func @tensor_from_elements() {
|
||||
%c0 = "std.constant"() {value = 0: index} : () -> index
|
||||
|
@ -972,4 +959,3 @@ func @subtensor_insert(%t: tensor<8x16x4xf32>, %t2: tensor<16x32x8xf32>, %idx :
|
|||
|
||||
// CHECK-LABEL: func private @legacy_visibility_syntax
|
||||
func @legacy_visibility_syntax() attributes { sym_visibility = "private" }
|
||||
|
||||
|
|
|
@ -541,61 +541,6 @@ func @cmpf_canonical_type_mismatch(%a : f32, %b : f64) { // expected-note {{prio
|
|||
|
||||
// -----
|
||||
|
||||
func @extract_element_no_operands() {
|
||||
// expected-error@+1 {{op expected 1 or more operands}}
|
||||
%0 = "std.extract_element"() : () -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_element_no_indices(%v : vector<3xf32>) {
|
||||
// expected-error@+1 {{incorrect number of indices for extract_element}}
|
||||
%0 = "std.extract_element"(%v) : (vector<3xf32>) -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_element_invalid_index_type(%v : vector<3xf32>, %i : i32) {
|
||||
// expected-error@+1 {{operand #1 must be index}}
|
||||
%0 = "std.extract_element"(%v, %i) : (vector<3xf32>, i32) -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_element_element_result_type_mismatch(%v : vector<3xf32>, %i : index) {
|
||||
// expected-error@+1 {{result type matches element type of aggregate}}
|
||||
%0 = "std.extract_element"(%v, %i) : (vector<3xf32>, index) -> f64
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_element_vector_too_many_indices(%v : vector<3xf32>, %i : index) {
|
||||
// expected-error@+1 {{incorrect number of indices for extract_element}}
|
||||
%0 = "std.extract_element"(%v, %i, %i) : (vector<3xf32>, index, index) -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_element_tensor_too_many_indices(%t : tensor<2x3xf32>, %i : index) {
|
||||
// expected-error@+1 {{incorrect number of indices for extract_element}}
|
||||
%0 = "std.extract_element"(%t, %i, %i, %i) : (tensor<2x3xf32>, index, index, index) -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_element_tensor_too_few_indices(%t : tensor<2x3xf32>, %i : index) {
|
||||
// expected-error@+1 {{incorrect number of indices for extract_element}}
|
||||
%0 = "std.extract_element"(%t, %i) : (tensor<2x3xf32>, index) -> f32 return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_from_elements_wrong_result_type() {
|
||||
// expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
|
||||
%c0 = constant 0 : i32
|
||||
|
|
|
@ -1040,21 +1040,21 @@ func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: i
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @extract_element_from_tensor_from_elements
|
||||
func @extract_element_from_tensor_from_elements(%element : index) -> index {
|
||||
// CHECK-LABEL: func @extract_from_tensor_from_elements
|
||||
func @extract_from_tensor_from_elements(%element : index) -> index {
|
||||
// CHECK-SAME: ([[ARG:%.*]]: index)
|
||||
%c0 = constant 0 : index
|
||||
%tensor = tensor_from_elements %element : tensor<1xindex>
|
||||
%extracted_element = extract_element %tensor[%c0] : tensor<1xindex>
|
||||
%extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex>
|
||||
// CHECK: [[ARG]] : index
|
||||
return %extracted_element : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements
|
||||
// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements
|
||||
// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
|
||||
func @extract_element_from_dynamic_tensor_from_elements(%idx: index, %tensor: tensor<*xf32>) -> index {
|
||||
func @extract_from_dynamic_tensor_from_elements(%idx: index, %tensor: tensor<*xf32>) -> index {
|
||||
%size = rank %tensor : tensor<*xf32>
|
||||
// CHECK-NEXT: %[[RES:.*]] = dim %[[TENSOR]], %[[IDX]]
|
||||
%0 = dynamic_tensor_from_elements %size {
|
||||
|
@ -1062,16 +1062,16 @@ func @extract_element_from_dynamic_tensor_from_elements(%idx: index, %tensor: te
|
|||
%1 = dim %tensor, %arg0 : tensor<*xf32>
|
||||
yield %1 : index
|
||||
} : tensor<?xindex>
|
||||
%1 = extract_element %0[%idx] : tensor<?xindex>
|
||||
%1 = tensor.extract %0[%idx] : tensor<?xindex>
|
||||
// CHECK-NEXT: return %[[RES]]
|
||||
return %1 : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements_2d
|
||||
// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements_2d
|
||||
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
|
||||
func @extract_element_from_dynamic_tensor_from_elements_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
|
||||
func @extract_from_dynamic_tensor_from_elements_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
|
||||
%size = rank %tensor : tensor<*xf32>
|
||||
// CHECK-NEXT: %[[DIM0:.*]] = dim %[[TENSOR]], %[[IDX0]]
|
||||
// CHECK-NEXT: %[[DIM1:.*]] = dim %[[TENSOR]], %[[IDX1]]
|
||||
|
@ -1083,16 +1083,16 @@ func @extract_element_from_dynamic_tensor_from_elements_2d(%idx0: index, %idx1:
|
|||
%3 = addi %1, %2 : index
|
||||
yield %3 : index
|
||||
} : tensor<?x?xindex>
|
||||
%4 = extract_element %0[%idx0, %idx1] : tensor<?x?xindex>
|
||||
%4 = tensor.extract %0[%idx0, %idx1] : tensor<?x?xindex>
|
||||
// CHECK-NEXT: return %[[RES]]
|
||||
return %4 : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements_sideeffects
|
||||
// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements_sideeffects
|
||||
// CHECK-SAME: %[[IDX:.*]]: index
|
||||
func @extract_element_from_dynamic_tensor_from_elements_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index {
|
||||
func @extract_from_dynamic_tensor_from_elements_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index {
|
||||
%size = rank %tensor : tensor<*xf32>
|
||||
%mem = alloc(%size) : memref<?xindex>
|
||||
// CHECK: %[[DTENSOR:.*]] = dynamic_tensor_from_elements
|
||||
|
@ -1102,8 +1102,8 @@ func @extract_element_from_dynamic_tensor_from_elements_sideeffects(%idx: index,
|
|||
store %1, %mem[%arg0] : memref<?xindex>
|
||||
yield %1 : index
|
||||
} : tensor<?xindex>
|
||||
// CHECK: %[[RES:.*]] = extract_element %[[DTENSOR]][%[[IDX]]]
|
||||
%1 = extract_element %0[%idx] : tensor<?xindex>
|
||||
// CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]]
|
||||
%1 = tensor.extract %0[%idx] : tensor<?xindex>
|
||||
// CHECK-NEXT: return %[[RES]]
|
||||
return %1 : index
|
||||
}
|
||||
|
@ -1205,14 +1205,14 @@ func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @extract_element_from_tensor_cast
|
||||
// CHECK-LABEL: func @extract_from_tensor_cast
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
|
||||
func @extract_element_from_tensor_cast(%tensor: tensor<*xf32>) -> f32 {
|
||||
func @extract_from_tensor_cast(%tensor: tensor<*xf32>) -> f32 {
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
||||
%c0 = constant 0 : index
|
||||
// CHECK-NOT: tensor_cast
|
||||
%casted = tensor_cast %tensor : tensor<*xf32> to tensor<?xf32>
|
||||
// CHECK-NEXT: extract_element %[[TENSOR]][%[[C0]]]
|
||||
%result = extract_element %casted[%c0] : tensor<?xf32>
|
||||
// CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
|
||||
%result = tensor.extract %casted[%c0] : tensor<?xf32>
|
||||
return %result : f32
|
||||
}
|
||||
|
|
|
@ -716,38 +716,6 @@ func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fold_extract_element
|
||||
func @fold_extract_element(%arg0 : index) -> (f32, f16, f16, i32) {
|
||||
%const_0 = constant 0 : index
|
||||
%const_1 = constant 1 : index
|
||||
%const_3 = constant 3 : index
|
||||
|
||||
// Fold an extract into a splat.
|
||||
// CHECK-NEXT: [[C4:%.+]] = constant 4.{{0*}}e+00 : f32
|
||||
%0 = constant dense<4.0> : tensor<4xf32>
|
||||
%ext_1 = extract_element %0[%arg0] : tensor<4xf32>
|
||||
|
||||
// Fold an extract into a sparse with a sparse index.
|
||||
// CHECK-NEXT: [[CM2:%.+]] = constant -2.{{0*}}e+00 : f16
|
||||
%1 = constant sparse<[[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : vector<4x4x4xf16>
|
||||
%ext_2 = extract_element %1[%const_1, %const_1, %const_1] : vector<4x4x4xf16>
|
||||
|
||||
// Fold an extract into a sparse with a non sparse index.
|
||||
// CHECK-NEXT: [[C0:%.+]] = constant 0.{{0*}}e+00 : f16
|
||||
%2 = constant sparse<[[1, 1, 1]], [-2.0]> : vector<1x1x1xf16>
|
||||
%ext_3 = extract_element %2[%const_0, %const_0, %const_0] : vector<1x1x1xf16>
|
||||
|
||||
// Fold an extract into a dense tensor.
|
||||
// CHECK-NEXT: [[C64:%.+]] = constant 64 : i32
|
||||
%3 = constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
|
||||
%ext_4 = extract_element %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
|
||||
|
||||
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]]
|
||||
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fold_rank
|
||||
func @fold_rank() -> (index) {
|
||||
%const_0 = constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
|
||||
|
|
|
@ -38,7 +38,7 @@ syn match mlirType /x\s*\zsvector/
|
|||
" TODO: this list is not exhaustive.
|
||||
syn keyword mlirOps alloc alloca addf addi and call call_indirect cmpf cmpi
|
||||
syn keyword mlirOps constant dealloc divf dma_start dma_wait dim exp
|
||||
syn keyword mlirOps extract_element getTensor index_cast load log memref_cast
|
||||
syn keyword mlirOps getTensor index_cast load log memref_cast
|
||||
syn keyword mlirOps memref_shape_cast mulf muli negf powf prefetch rsqrt sitofp
|
||||
syn keyword mlirOps splat store select sqrt subf subi subview tanh tensor_cast
|
||||
syn keyword mlirOps view
|
||||
|
|
Loading…
Reference in New Issue