[mlir] Start splitting the `tensor` dialect out of `std`.

This starts by moving `std.extract_element` to `tensor.extract` (this
mirrors the naming of `vector.extract`).

Curiously, `std.extract_element` supposedly works on vectors as well,
and this patch removes that functionality. I would tend to do that in
a separate patch, but I couldn't find any downstream users relying on
this, and the fact that we have `vector.extract` made it seem safe
enough to lump in here.

This also sets up the `tensor` dialect as a dependency of the `std`
dialect, as some ops that currently live in `std` depend on
`tensor.extract` via their canonicalization patterns.

Part of RFC: https://llvm.discourse.group/t/rfc-split-the-tensor-dialect-from-std/2347/2

Differential Revision: https://reviews.llvm.org/D92991
This commit is contained in:
Sean Silva 2020-12-09 17:50:03 -08:00
parent 7b3470baf8
commit cab8dda90f
42 changed files with 611 additions and 311 deletions

View File

@ -14,5 +14,6 @@ add_subdirectory(SCF)
add_subdirectory(Shape)
add_subdirectory(SPIRV)
add_subdirectory(StandardOps)
add_subdirectory(Tensor)
add_subdirectory(Tosa)
add_subdirectory(Vector)

View File

@ -10,6 +10,7 @@
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/FoldUtils.h"
@ -46,7 +47,6 @@ using folded_std_constant_float = FoldedValueBuilder<ConstantFloatOp>;
using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
using folded_std_constant_int = FoldedValueBuilder<ConstantIntOp>;
using folded_std_dim = FoldedValueBuilder<DimOp>;
using folded_std_extract_element = FoldedValueBuilder<ExtractElementOp>;
using folded_std_index_cast = FoldedValueBuilder<IndexCastOp>;
using folded_std_muli = FoldedValueBuilder<MulIOp>;
using folded_std_mulf = FoldedValueBuilder<MulFOp>;
@ -60,6 +60,7 @@ using folded_std_tensor_load = FoldedValueBuilder<TensorLoadOp>;
using folded_std_view = FoldedValueBuilder<ViewOp>;
using folded_std_zero_extendi = FoldedValueBuilder<ZeroExtendIOp>;
using folded_std_sign_extendi = FoldedValueBuilder<SignExtendIOp>;
using folded_tensor_extract = FoldedValueBuilder<tensor::ExtractOp>;
} // namespace intrinsics
} // namespace edsc
} // namespace mlir

View File

@ -9,6 +9,7 @@
#define MLIR_DIALECT_STANDARDOPS_EDSC_INTRINSICS_H_
#include "mlir/Dialect/StandardOps/EDSC/Builders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
namespace mlir {
namespace edsc {
@ -28,7 +29,6 @@ using std_dealloc = OperationBuilder<DeallocOp>;
using std_divis = ValueBuilder<SignedDivIOp>;
using std_diviu = ValueBuilder<UnsignedDivIOp>;
using std_dim = ValueBuilder<DimOp>;
using std_extract_element = ValueBuilder<ExtractElementOp>;
using std_fpext = ValueBuilder<FPExtOp>;
using std_fptrunc = ValueBuilder<FPTruncOp>;
using std_im = ValueBuilder<ImOp>;
@ -52,6 +52,7 @@ using std_tensor_store = OperationBuilder<TensorStoreOp>;
using std_view = ValueBuilder<ViewOp>;
using std_zero_extendi = ValueBuilder<ZeroExtendIOp>;
using std_sign_extendi = ValueBuilder<SignExtendIOp>;
using tensor_extract = ValueBuilder<tensor::ExtractOp>;
/// Branches into `block` with `operands`.
BranchOp std_br(Block *block, ValueRange operands);

View File

@ -1669,59 +1669,6 @@ def Exp2Op : FloatUnaryOp<"exp2"> {
let summary = "base-2 exponential of the specified value";
}
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//
def ExtractElementOp : Std_Op<"extract_element",
[NoSideEffect,
TypesMatchWith<"result type matches element type of aggregate",
"aggregate", "result",
"$_self.cast<ShapedType>().getElementType()">]> {
let summary = "element extract operation";
let description = [{
The `extract_element` op reads a tensor or vector and returns one element
from it specified by an index list. The output of the 'extract_element' is a
new value with the same type as the elements of the tensor or vector. The
arity of indices matches the rank of the accessed value (i.e., if a tensor
is of rank 3, then 3 indices are required for the extract. The indices
should all be of `index` type.
Example:
```mlir
%3 = extract_element %v[%1, %2] : vector<4x4xi32>
%4 = extract_element %t[%1, %2] : tensor<4x4xi32>
%5 = extract_element %ut[%1, %2] : tensor<*xi32>
```
}];
let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate,
Variadic<Index>:$indices);
let results = (outs AnyType:$result);
let builders = [
OpBuilderDAG<(ins "Value":$aggregate, CArg<"ValueRange", "{}">:$indices), [{
auto resType = aggregate.getType().cast<ShapedType>()
.getElementType();
build($_builder, $_state, resType, aggregate, indices);
}]>];
let extraClassDeclaration = [{
Value getAggregate() { return getOperand(0); }
operand_range getIndices() {
return {operand_begin() + 1, operand_end()};
}
}];
let hasFolder = 1;
let assemblyFormat = [{
$aggregate `[` $indices `]` attr-dict `:` type($aggregate)
}];
}
//===----------------------------------------------------------------------===//
// TensorFromElementsOp
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@ -0,0 +1,2 @@
# Run TableGen over TensorOps.td for the `tensor` dialect; Tensor.h includes
# the resulting TensorOpsDialect.h.inc and TensorOps.h.inc.
add_mlir_dialect(TensorOps tensor)
# Emit the dialect documentation (-gen-dialect-doc) under Dialects/.
add_mlir_doc(TensorOps -gen-dialect-doc TensorOps Dialects/)

View File

@ -0,0 +1,31 @@
//===- Tensor.h - Tensor dialect --------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_
#define MLIR_DIALECT_TENSOR_IR_TENSOR_H_
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
//===----------------------------------------------------------------------===//
// Tensor Dialect
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc"
//===----------------------------------------------------------------------===//
// Tensor Dialect Operations
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"
#endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_

View File

@ -0,0 +1,48 @@
//===- TensorBase.td - Base definitions for tensor dialect -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef TENSOR_BASE
#define TENSOR_BASE
include "mlir/IR/OpBase.td"
// Definition of the `tensor` dialect. Ops registered under it live in the
// `::mlir::tensor` C++ namespace.
def Tensor_Dialect : Dialect {
  let name = "tensor";
  let cppNamespace = "::mlir::tensor";
  let description = [{
    The `tensor` dialect is intended to hold core tensor creation and
    manipulation ops, which are not strongly associated with any particular
    other dialect or domain abstraction. The primary smoke test of this is ops
    that make sense for any tensor element type.

    We leave it to other dialects to hold the vast swath of possible
    computations one might want to do on a tensor.

    The `tensor` type is (for better or for worse) used to represent all kinds
    of things, and supports an open-ended set of element types. Examples:

    - representing large, dense aggregations of primitive types, suitable for
      high-performance numerical computing.
    - representing shapes in the `shape` dialect, which consist of small
      1D tensors of `index` data type.
    - representing aggregations of strings or variant types.
    - representing large, sparse aggregations of primitive types, suitable
      for high-performance numerical computing.

    Thus, for the `tensor` dialect, we prefer for now to constrain the
    scope as much as possible. The expectation is that at some point
    in the future, the `tensor` dialect's scope may be broadened through a
    careful discussion of the tradeoffs.

    The `tensor` type is actually a builtin type (it lives in the builtin
    dialect), and does not live in this dialect.
  }];
}
#endif // TENSOR_BASE

View File

@ -0,0 +1,62 @@
//===- TensorOps.td - Tensor op definitions ----------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef TENSOR_OPS
#define TENSOR_OPS
include "mlir/Dialect/Tensor/IR/TensorBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// Base class for every op in the `tensor` dialect.
//
// By default an op derived from this class forwards parsing, printing and
// verification to free functions in the global namespace (`::parse<OpClass>`,
// `::print` and `::verify`).
class Tensor_Op<string mnemonic, list<OpTrait> traits = []>
    : Op<Tensor_Dialect, mnemonic, traits> {
  // Custom assembly hooks.
  let parser = [{ return ::parse$cppClass(parser, result); }];
  let printer = [{ return ::print(p, *this); }];

  // Structural verification hook.
  let verifier = [{ return ::verify(*this); }];
}
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
// `tensor.extract`: reads a single element out of a tensor.
//
// The result type is constrained (via TypesMatchWith) to be the element type
// of the `tensor` operand, which is why the assembly format only spells out
// the tensor type.
def Tensor_ExtractOp : Tensor_Op<"extract",
    [NoSideEffect,
     TypesMatchWith<"result type matches element type of tensor",
                    "tensor", "result",
                    "$_self.cast<ShapedType>().getElementType()">]> {
  let summary = "element extraction operation";
  let description = [{
    The `tensor.extract` op reads a tensor and returns one
    element from it specified by an index list. The output of the op is a
    new value with the same type as the elements of the tensor. The
    arity of indices must match the rank of the accessed value (i.e., if a
    tensor is of rank 3, then 3 indices are required for the extract). The
    indices should all be of `index` type.

    Example:

    ```mlir
    %4 = tensor.extract %t[%1, %2] : tensor<4x4xi32>
    %5 = tensor.extract %rt[%1, %2] : tensor<?x?xi32>
    %6 = tensor.extract %ut[%1, %2] : tensor<*xi32>
    ```
  }];

  let arguments = (ins AnyTensor:$tensor, Variadic<Index>:$indices);
  let results = (outs AnyType:$result);
  let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)";

  // Convenience builder that infers the result type from the element type of
  // `tensor`, so callers only pass the tensor and (optionally) the indices.
  let builders = [
    OpBuilderDAG<(ins "Value":$tensor, CArg<"ValueRange", "{}">:$indices), [{
      auto resType = tensor.getType().cast<ShapedType>().getElementType();
      build($_builder, $_state, resType, tensor, indices);
    }]>];

  let hasFolder = 1;
}
#endif // TENSOR_OPS

View File

@ -0,0 +1,5 @@
# Generate the pass declarations (Passes.h.inc) for the `tensor` dialect
# transforms from Passes.td and expose them as a public tablegen target.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Tensor)
add_public_tablegen_target(MLIRTensorTransformsIncGen)
# Emit the pass documentation (-gen-pass-doc) from the same Passes.td.
add_mlir_doc(Passes -gen-pass-doc TensorPasses ./)

View File

@ -0,0 +1,38 @@
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Bufferize.h"
namespace mlir {
// Forward declaration; only a reference is needed in this header.
class OwningRewritePatternList;
/// Adds the conversion patterns that bufferize `tensor` dialect ops to
/// `patterns`. The patterns are constructed with `typeConverter` and
/// `context`.
void populateTensorBufferizePatterns(MLIRContext *context,
BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
/// Creates an instance of `tensor` dialect bufferization pass.
std::unique_ptr<Pass> createTensorBufferizePass();
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
namespace tensor {
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
} // namespace tensor
} // end namespace mlir
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES_H_

View File

@ -0,0 +1,19 @@
//===-- Passes.td - pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
#define MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
include "mlir/Pass/PassBase.td"
// Function pass exposed under the `tensor-bufferize` argument; instances are
// obtained through mlir::createTensorBufferizePass().
def TensorBufferize : FunctionPass<"tensor-bufferize"> {
let summary = "Bufferize the `tensor` dialect";
let constructor = "mlir::createTensorBufferizePass()";
}
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES

View File

@ -35,6 +35,7 @@
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Dialect.h"
@ -66,6 +67,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
ROCDL::ROCDLDialect,
SDBMDialect,
shape::ShapeDialect,
tensor::TensorDialect,
tosa::TosaDialect>();
// clang-format on
}

View File

@ -25,6 +25,7 @@
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Transforms/Passes.h"
@ -57,6 +58,7 @@ inline void registerAllPasses() {
registerShapePasses();
spirv::registerSPIRVPasses();
registerStandardPasses();
tensor::registerTensorPasses();
tosa::registerTosaOptPasses();
}

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
@ -64,10 +65,10 @@ public:
rewriter.create<scf::ForOp>(
loc, rankDiff, greaterRank, one, llvm::None,
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
Value greaterRankOperandExtent = b.create<ExtractElementOp>(
Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
loc, greaterRankOperand, ValueRange{iv});
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
loc, lesserRankOperand, ValueRange{ivShifted});
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
@ -118,12 +119,12 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
Value outputDimension = args[0];
Value isUnchallengedDimension = b.create<CmpIOp>(
loc, CmpIPredicate::ult, outputDimension, rankDiff);
Value greaterRankOperandExtent = b.create<ExtractElementOp>(
Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
loc, greaterRankOperand, outputDimension);
// The initial dimensions of the greater-rank operand are unchallenged,
// so we can take them as-is. Otherwise, we need to do a comparison.
// We need an actual branch here (instead of a select) because the
// lesser-rank operand might be rank 0, so any extract_element would be
// lesser-rank operand might be rank 0, so any tensor.extract would be
// invalid.
auto ifOp = b.create<IfOp>(
loc, TypeRange{indexTy}, isUnchallengedDimension,
@ -140,7 +141,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
// dimensions of zero extent.
Value lesserRankOperandDimension =
b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
loc, lesserRankOperand,
ValueRange{lesserRankOperandDimension});
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
@ -262,12 +263,12 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
auto reduceResult = rewriter.create<ForOp>(
loc, rankDiff, greaterRank, one, ValueRange{init},
[&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
Value greaterRankOperandExtent =
b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
loc, greaterRankOperand, ValueRange{iv});
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
loc, lesserRankOperand, ValueRange{ivShifted});
Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
@ -316,9 +317,9 @@ LogicalResult GetExtentOpConverter::matchAndRewrite(
}
}
rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
transformed.shape(),
ValueRange{transformed.dim()});
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
transformed.shape(),
ValueRange{transformed.dim()});
return success();
}
@ -375,7 +376,8 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
auto loop = rewriter.create<scf::ForOp>(
loc, zero, rank, one, op.initVals(),
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
Value extent =
b.create<tensor::ExtractOp>(loc, transformed.shape(), iv);
SmallVector<Value, 2> mappedValues{iv, extent};
mappedValues.append(args.begin(), args.end());
@ -415,8 +417,8 @@ namespace {
/// %c1 = constant 1 : index
/// %true = constant true
/// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
/// %5 = extract_element %arg0[%arg2] : tensor<?xindex>
/// %6 = extract_element %arg1[%arg2] : tensor<?xindex>
/// %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
/// %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
/// %7 = cmpi "eq", %5, %6 : index
/// %8 = and %arg3, %7 : i1
/// scf.yield %8 : i1
@ -465,9 +467,9 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
[&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
Value conj = args[0];
Value lhsExtent =
b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
b.create<tensor::ExtractOp>(loc, transformed.lhs(), iv);
Value rhsExtent =
b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
b.create<tensor::ExtractOp>(loc, transformed.rhs(), iv);
Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
lhsExtent, rhsExtent);
Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
@ -584,7 +586,8 @@ void ConvertShapeToStandardPass::runOnOperation() {
// Setup target legality.
MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
target.addLegalDialect<StandardOpsDialect, SCFDialect>();
target
.addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>();
target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
// Setup conversion patterns.

View File

@ -15,6 +15,7 @@ add_subdirectory(SDBM)
add_subdirectory(Shape)
add_subdirectory(SPIRV)
add_subdirectory(StandardOps)
add_subdirectory(Tensor)
add_subdirectory(Tosa)
add_subdirectory(Vector)

View File

@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRStandard
MLIREDSC
MLIRIR
MLIRSideEffectInterfaces
MLIRTensor
MLIRVectorInterfaces
MLIRViewLikeInterface
)

View File

@ -9,6 +9,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
@ -153,6 +154,7 @@ static LogicalResult verifyCastOp(T op) {
}
void StandardOpsDialect::initialize() {
getContext()->loadDialect<tensor::TensorDialect>();
addOperations<DmaStartOp, DmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
@ -1863,18 +1865,18 @@ struct StaticDynamicTensorFromElements
/// <computation>
/// yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = extract_element %tensor[%c0] : tensor<?xi32>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// dynamic_tensor_from_elements operation has no side-effects.
struct ExtractElementFromDynamicTensorFromElements
: public OpRewritePattern<ExtractElementOp> {
using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
struct ExtractFromDynamicTensorFromElements
: public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractElementOp extract,
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
PatternRewriter &rewriter) const final {
auto tensorFromElements =
extract.aggregate().getDefiningOp<DynamicTensorFromElementsOp>();
extract.tensor().getDefiningOp<DynamicTensorFromElementsOp>();
if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
return failure();
@ -1894,23 +1896,22 @@ struct ExtractElementFromDynamicTensorFromElements
/// Canonicalizes the pattern of the form
///
/// %val = tensor_cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = extract_element %val[%c0] : tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = extract_element %source[%c0] : tensor<?xi32>
struct ExtractElementFromTensorCast
: public OpRewritePattern<ExtractElementOp> {
using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractElementOp extract,
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
PatternRewriter &rewriter) const final {
auto tensorCast = extract.aggregate().getDefiningOp<TensorCastOp>();
auto tensorCast = extract.tensor().getDefiningOp<TensorCastOp>();
if (!tensorCast)
return failure();
rewriter.replaceOpWithNewOp<ExtractElementOp>(extract, tensorCast.source(),
extract.getIndices());
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
extract.indices());
return success();
}
};
@ -1919,51 +1920,9 @@ struct ExtractElementFromTensorCast
void DynamicTensorFromElementsOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ExtractElementFromDynamicTensorFromElements,
ExtractElementFromTensorCast, StaticDynamicTensorFromElements>(
context);
}
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(ExtractElementOp op) {
// Verify the # indices match if we have a ranked type.
auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
if (aggregateType.hasRank() &&
aggregateType.getRank() != op.getNumOperands() - 1)
return op.emitOpError("incorrect number of indices for extract_element");
return success();
}
OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
assert(!operands.empty() && "extract_element takes at least one operand");
// The aggregate operand must be a known constant.
Attribute aggregate = operands.front();
if (!aggregate)
return {};
// If this is a splat elements attribute, simply return the value. All of the
// elements of a splat attribute are the same.
if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
return splatAggregate.getSplatValue();
// Otherwise, collect the constant indices into the aggregate.
SmallVector<uint64_t, 8> indices;
for (Attribute indice : llvm::drop_begin(operands, 1)) {
if (!indice || !indice.isa<IntegerAttr>())
return {};
indices.push_back(indice.cast<IntegerAttr>().getInt());
}
// If this is an elements attribute, query the value at the given indices.
auto elementsAttr = aggregate.dyn_cast<ElementsAttr>();
if (elementsAttr && elementsAttr.isValidIndex(indices))
return elementsAttr.getValue(indices);
return {};
// TODO: Move extract patterns to tensor::ExtractOp.
results.insert<ExtractFromDynamicTensorFromElements, ExtractFromTensorCast,
StaticDynamicTensorFromElements>(context);
}
//===----------------------------------------------------------------------===//
@ -1989,20 +1948,20 @@ namespace {
// Canonicalizes the pattern of the form
//
// %tensor = "tensor_from_elements"(%element) : (i32) -> tensor<1xi32>
// %extracted_element = extract_element %tensor[%c0] : tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
: public OpRewritePattern<ExtractElementOp> {
using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
: public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractElementOp extract,
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
PatternRewriter &rewriter) const final {
if (extract.indices().size() != 1)
return failure();
auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>(
extract.aggregate().getDefiningOp());
extract.tensor().getDefiningOp());
if (tensorFromElements == nullptr)
return failure();
@ -2216,7 +2175,7 @@ OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
}
namespace {
/// Fold a load on a tensor_to_memref operation into an extract_element on the
/// Fold a load on a tensor_to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfTensorToMemref : public OpRewritePattern<LoadOp> {
using OpRewritePattern<LoadOp>::OpRewritePattern;
@ -2227,8 +2186,8 @@ struct LoadOfTensorToMemref : public OpRewritePattern<LoadOp> {
if (!tensorToMemref)
return failure();
rewriter.replaceOpWithNewOp<ExtractElementOp>(load, tensorToMemref.tensor(),
load.indices());
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
load, tensorToMemref.tensor(), load.indices());
return success();
}
};

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
@ -88,21 +89,6 @@ public:
};
} // namespace
namespace {
class BufferizeExtractElementOp : public OpConversionPattern<ExtractElementOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ExtractElementOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
ExtractElementOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.aggregate(),
adaptor.indices());
return success();
}
};
} // namespace
namespace {
class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
public:
@ -165,7 +151,6 @@ void mlir::populateStdBufferizePatterns(MLIRContext *context,
// clang-format off
BufferizeDimOp,
BufferizeDynamicTensorFromElementsOp,
BufferizeExtractElementOp,
BufferizeSelectOp,
BufferizeTensorCastOp,
BufferizeTensorFromElementsOp
@ -183,10 +168,11 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<scf::SCFDialect>();
target.addLegalDialect<tensor::TensorDialect>();
populateStdBufferizePatterns(context, typeConverter, patterns);
target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
TensorCastOp, TensorFromElementsOp>();
target.addIllegalOp<DynamicTensorFromElementsOp, TensorCastOp,
TensorFromElementsOp>();
// We only bufferize the case of tensor selected type and scalar condition,
// as that boils down to a select over memref descriptors (don't need to
// touch the data).

View File

@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@ -0,0 +1,17 @@
# Library holding the `tensor` dialect and its ops.
add_mlir_dialect_library(MLIRTensor
  TensorDialect.cpp
  TensorOps.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/mlir/Dialect/Tensor

  DEPENDS
  MLIRTensorOpsIncGen

  LINK_COMPONENTS
  Core

  # Tensor.h includes mlir/Interfaces/SideEffectInterfaces.h and the ops use
  # the NoSideEffect trait, so the side-effect interface library is declared
  # as an explicit link dependency.
  LINK_LIBS PUBLIC
  MLIRIR
  MLIRSideEffectInterfaces
  MLIRSupport
  )

View File

@ -0,0 +1,39 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
using namespace mlir::tensor;
//===----------------------------------------------------------------------===//
// TensorDialect Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
/// Inliner hooks for the `tensor` dialect. Both queries answer `true`
/// unconditionally, i.e. neither regions nor individual operations are ever
/// rejected by this interface.
struct TensorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// Region-into-region query: always legal.
  bool isLegalToInline(Region * /*dest*/, Region * /*src*/,
                       bool /*wouldBeCloned*/,
                       BlockAndValueMapping & /*valueMapping*/) const final {
    return true;
  }

  /// Operation-into-region query: always legal.
  bool isLegalToInline(Operation * /*op*/, Region * /*dest*/,
                       bool /*wouldBeCloned*/,
                       BlockAndValueMapping & /*valueMapping*/) const final {
    return true;
  }
};
} // end anonymous namespace
/// Registers the dialect's contents: every op listed in the TableGen-generated
/// TensorOps.cpp.inc (pulled in through GET_OP_LIST) and the inliner
/// interface defined in this file.
void TensorDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
>();
addInterfaces<TensorInlinerInterface>();
}

View File

@ -0,0 +1,60 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::tensor;
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
/// Verifies a `tensor.extract` op: when the tensor operand has a ranked type,
/// the number of indices must equal its rank. Operands that are not
/// RankedTensorType are accepted without an index-count check.
static LogicalResult verify(ExtractOp op) {
  auto rankedType = op.tensor().getType().dyn_cast<RankedTensorType>();
  if (!rankedType)
    return success();

  auto numIndices = static_cast<int64_t>(op.indices().size());
  if (rankedType.getRank() == numIndices)
    return success();

  // NOTE(review): the message still uses the old `extract_element` name; it is
  // kept unchanged here because diagnostics checks may match on it.
  return op.emitOpError("incorrect number of indices for extract_element");
}
/// Constant-folds a `tensor.extract`.
///
/// `operands` holds the constant attribute (or null) for each operand: the
/// tensor first, followed by the indices. Folding succeeds when the tensor is
/// a splat constant, or when the tensor is an elements attribute and every
/// index is a constant integer forming a valid position. Otherwise an empty
/// result is returned.
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // Nothing to fold unless the tensor itself is a known constant.
  Attribute source = operands.front();
  if (!source)
    return {};

  // Every element of a splat is identical, so the indices do not matter.
  if (auto splat = source.dyn_cast<SplatElementsAttr>())
    return splat.getSplatValue();

  // Gather the constant indices; bail out on the first non-constant one.
  SmallVector<uint64_t, 8> position;
  for (Attribute indexAttr : operands.drop_front()) {
    auto constantIndex = indexAttr.dyn_cast_or_null<IntegerAttr>();
    if (!constantIndex)
      return {};
    position.push_back(constantIndex.getInt());
  }

  // Look the element up if the attribute supports indexed access and the
  // position is in bounds.
  if (auto elements = source.dyn_cast<ElementsAttr>())
    if (elements.isValidIndex(position))
      return elements.getValue(position);
  return {};
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"

View File

@ -0,0 +1,64 @@
//===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of `tensor` dialect ops
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
/// Conversion pattern that replaces a `tensor.extract` with a `LoadOp` built
/// from the remapped tensor operand and indices supplied by the conversion
/// framework.
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Wrap the remapped operands so they can be accessed by name.
    tensor::ExtractOp::Adaptor converted(operands);
    // NOTE(review): the remapped tensor operand is presumably a memref after
    // type conversion -- confirm against BufferizeTypeConverter.
    Value source = converted.tensor();
    rewriter.replaceOpWithNewOp<LoadOp>(op, source, converted.indices());
    return success();
  }
};
} // namespace
/// Appends the bufferization patterns for `tensor` dialect ops to `patterns`.
/// The only pattern added here is BufferizeExtractOp, constructed with
/// `typeConverter` and `context`.
void mlir::populateTensorBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<BufferizeExtractOp>(typeConverter, context);
}
namespace {
/// Function pass that bufferizes `tensor` dialect ops by running a partial
/// dialect conversion with the patterns from populateTensorBufferizePatterns.
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    BufferizeTypeConverter converter;
    OwningRewritePatternList conversionPatterns;
    ConversionTarget conversionTarget(*ctx);

    populateTensorBufferizePatterns(ctx, converter, conversionPatterns);

    // `tensor.extract` must be converted away; ops of the standard dialect
    // are accepted as legal results of the conversion.
    conversionTarget.addIllegalOp<tensor::ExtractOp>();
    conversionTarget.addLegalDialect<StandardOpsDialect>();

    LogicalResult status = applyPartialConversion(
        getFunction(), conversionTarget, std::move(conversionPatterns));
    if (failed(status))
      signalPassFailure();
  }
};
} // namespace
/// Factory for the `tensor` dialect bufferization pass; returns a newly
/// allocated TensorBufferizePass owned by the caller.
std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
  return std::make_unique<TensorBufferizePass>();
}

View File

@ -0,0 +1,15 @@
# Library holding the `tensor` dialect transformation passes; its only source
# file is Bufferize.cpp.
add_mlir_dialect_library(MLIRTensorTransforms
Bufferize.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
# Presumably the TableGen target that generates Passes.h.inc, which
# PassDetail.h includes -- confirm against the include-side CMakeLists.
DEPENDS
MLIRTensorTransformsIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRTensor
MLIRTransforms
)

View File

@ -0,0 +1,21 @@
//===- PassDetail.h - Tensor Pass class details -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef DIALECT_TENSOR_TRANSFORMS_PASSDETAIL_H_
#define DIALECT_TENSOR_TRANSFORMS_PASSDETAIL_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
#define GEN_PASS_CLASSES
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
} // end namespace mlir
#endif // DIALECT_TENSOR_TRANSFORMS_PASSDETAIL_H_

View File

@ -13,6 +13,7 @@
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
@ -59,6 +60,23 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
assert(matchPattern(constOp, m_Constant()));
return constOp;
}
// TODO: To facilitate splitting the std dialect (PR48490), have a special case
// for falling back to std.constant. Eventually, we will have separate ops
// tensor.constant, int.constant, float.constant, etc. that live in their
// respective dialects, which will allow each dialect to implement the
// materializeConstant hook above.
//
// The special case is needed because in the interim state while we are
// splitting out those dialects from std, the std dialect depends on the
// tensor dialect, which makes it impossible for the tensor dialect to use
// std.constant (it would be a cyclic dependency) as part of its
// materializeConstant hook.
//
// If the dialect is unable to materialize a constant, check to see if the
// standard constant can be used.
if (ConstantOp::isBuildableWith(value, type))
return builder.create<ConstantOp>(loc, type, value);
return nullptr;
}

View File

@ -16,9 +16,9 @@
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
// CHECK: %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IVSHIFTED]]] : tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[IVSHIFTED]]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LESSER_RANK_OPERAND_EXTENT]], %[[C1]] : index
// CHECK: %[[EXTENTS_AGREE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[LESSER_RANK_OPERAND_EXTENT]] : index

View File

@ -74,12 +74,12 @@ func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : index) -> index {
// -----
// Express `get_extent` as `std.extract_element`.
// Express `get_extent` as `tensor.extract`.
// CHECK-LABEL: @get_extent_from_extent_tensor
// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
-> index {
// CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
// CHECK: %[[RESULT:.*]] = tensor.extract %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
// CHECK: return %[[RESULT]] : index
%result = shape.get_extent %extents, %idx : tensor<?xindex>, index -> index
return %result : index
@ -180,7 +180,7 @@ func @shape_reduce(%shape : tensor<?xindex>) -> index {
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
// CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
// CHECK-NEXT: %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]]
// CHECK-NEXT: %[[EXTENT:.*]] = tensor.extract %[[SHAPE]][%[[I]]]
// CHECK-NEXT: %[[NEW_ACC:.*]] = muli %[[ACC]], %[[EXTENT]] : index
// CHECK-NEXT: scf.yield %[[NEW_ACC]] : index
// CHECK-NEXT: }
@ -277,8 +277,8 @@ func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[INIT:.*]] = constant true
// CHECK: %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
// CHECK: %[[EXTENT_A:.*]] = extract_element %[[A]][%[[I]]] : tensor<?xindex>
// CHECK: %[[EXTENT_B:.*]] = extract_element %[[B]][%[[I]]] : tensor<?xindex>
// CHECK: %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
// CHECK: %[[EXTENT_B:.*]] = tensor.extract %[[B]][%[[I]]] : tensor<?xindex>
// CHECK: %[[EXTENT_EQ:.*]] = cmpi "eq", %[[EXTENT_A]], %[[EXTENT_B]]
// CHECK: %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
// CHECK: scf.yield %[[CONJ_NEXT]] : i1
@ -324,12 +324,12 @@ func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
// CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
// CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
// CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
// CHECK: %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
// CHECK: } else {
// CHECK: %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
// CHECK: %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
// CHECK: scf.yield %[[BROADCASTED_EXTENT]] : index
@ -364,12 +364,12 @@ func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xinde
// CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
// CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
// CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
// CHECK: %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
// CHECK: } else {
// CHECK: %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
// CHECK: %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
// CHECK: scf.yield %[[BROADCASTED_EXTENT]] : index
@ -407,10 +407,10 @@ func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
// CHECK: %[[TRUE:.*]] = constant true
// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[I:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
// CHECK: %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[I]]] : tensor<?xindex>
// CHECK: %[[LARGER_EXTENT:.*]] = tensor.extract %[[LARGER_SHAPE]]{{\[}}%[[I]]] : tensor<?xindex>
// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
// CHECK: %[[SMALLER_EXTENT_INDEX:.*]] = subi %[[I]], %[[RANK_DIFF]] : index
// CHECK: %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[SMALLER_EXTENT_INDEX]]] : tensor<?xindex>
// CHECK: %[[SMALLER_EXTENT:.*]] = tensor.extract %[[SMALLER_SHAPE]]{{\[}}%[[SMALLER_EXTENT_INDEX]]] : tensor<?xindex>
// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
@ -445,10 +445,10 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
// CHECK: %[[TRUE:.*]] = constant true
// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[VAL_16:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
// CHECK: %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[VAL_16]]] : tensor<?xindex>
// CHECK: %[[LARGER_EXTENT:.*]] = tensor.extract %[[LARGER_SHAPE]]{{\[}}%[[VAL_16]]] : tensor<?xindex>
// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
// CHECK: %[[LHS_EXTENT_INDEX:.*]] = subi %[[VAL_16]], %[[RANK_DIFF]] : index
// CHECK: %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[LHS_EXTENT_INDEX]]] : tensor<?xindex>
// CHECK: %[[SMALLER_EXTENT:.*]] = tensor.extract %[[SMALLER_SHAPE]]{{\[}}%[[LHS_EXTENT_INDEX]]] : tensor<?xindex>
// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1

View File

@ -395,7 +395,7 @@ func @scalar_indexed_generic_fusion
ins(%arg1 : tensor<i32>) {
^bb0(%arg2: i32): // no predecessors
%3 = index_cast %arg2 : i32 to index
%4 = extract_element %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
%4 = tensor.extract %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
linalg.yield %4 : f32
} -> tensor<f32>
%1 = linalg.generic
@ -418,6 +418,6 @@ func @scalar_indexed_generic_fusion
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel"]
// CHECK-SAME: ins(%[[ARG1]] : tensor<i32>)
// CHECK: extract_element %[[ARG0]]
// CHECK: tensor.extract %[[ARG0]]
// CHECK: linalg.yield
// CHECK return %[[T0]]

View File

@ -61,18 +61,6 @@ func @dynamic_tensor_from_elements_static_and_dynamic(%arg0: index) -> tensor<16
return %result : tensor<16x?xindex>
}
// CHECK-LABEL: func @extract_element(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> f32 {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32>
// CHECK: %[[RET:.*]] = load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
// CHECK: return %[[RET]] : f32
// CHECK: }
func @extract_element(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
%0 = extract_element %arg0[%arg1] : tensor<?xf32>
return %0 : f32
}
// CHECK-LABEL: func @select(
// CHECK-SAME: %[[PRED:.*]]: i1,
// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
@ -138,14 +126,14 @@ func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
// The dynamic_tensor_from_elements op clones each op in its body.
// Make sure that regions nested within such ops are recursively converted.
// CHECK-LABEL: func @recursively_convert_cloned_regions
func @recursively_convert_cloned_regions(%arg0: tensor<?xindex>, %arg1: index, %arg2: i1) -> tensor<?xindex> {
func @recursively_convert_cloned_regions(%arg0: tensor<*xf32>, %arg1: index, %arg2: i1) -> tensor<?xindex> {
%tensor = dynamic_tensor_from_elements %arg1 {
^bb0(%iv: index):
%48 = scf.if %arg2 -> (index) {
scf.yield %iv : index
} else {
// CHECK-NOT: extract_element
%50 = extract_element %arg0[%iv] : tensor<?xindex>
// CHECK-NOT: dim{{.*}}tensor
%50 = dim %arg0, %iv : tensor<*xf32>
scf.yield %50 : index
}
yield %48 : index

View File

@ -46,11 +46,11 @@ func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
}
// Test case: Folding of load(tensor_to_memref(%v, %idxs))
// -> extract_element(%v, %idx)
// -> tensor.extract(%v, %idx)
// CHECK-LABEL: func @load_from_tensor_to_memref(
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK: %[[RES:.*]] = extract_element %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
// CHECK: %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
// CHECK-NOT: load
// CHECK: return %[[RES]] : f32
func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {

View File

@ -0,0 +1,13 @@
// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
// CHECK-LABEL: func @extract(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> f32 {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32>
// CHECK: %[[RET:.*]] = load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
// CHECK: return %[[RET]] : f32
// CHECK: }
func @extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
%0 = tensor.extract %arg0[%arg1] : tensor<?xf32>
return %0 : f32
}

View File

@ -0,0 +1,33 @@
// RUN: mlir-opt %s -canonicalize | FileCheck %s
// -----
// CHECK-LABEL: func @fold_extract
func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
%const_0 = constant 0 : index
%const_1 = constant 1 : index
%const_3 = constant 3 : index
// Fold an extract into a splat.
// CHECK-NEXT: [[C4:%.+]] = constant 4.{{0*}}e+00 : f32
%0 = constant dense<4.0> : tensor<4xf32>
%ext_1 = tensor.extract %0[%arg0] : tensor<4xf32>
// Fold an extract into a sparse with a sparse index.
// CHECK-NEXT: [[CM2:%.+]] = constant -2.{{0*}}e+00 : f16
%1 = constant sparse<[[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : tensor<4x4x4xf16>
%ext_2 = tensor.extract %1[%const_1, %const_1, %const_1] : tensor<4x4x4xf16>
// Fold an extract into a sparse with a non sparse index.
// CHECK-NEXT: [[C0:%.+]] = constant 0.{{0*}}e+00 : f16
%2 = constant sparse<[[1, 1, 1]], [-2.0]> : tensor<1x1x1xf16>
%ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<1x1x1xf16>
// Fold an extract into a dense tensor.
// CHECK-NEXT: [[C64:%.+]] = constant 64 : i32
%3 = constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
%ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]]
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
}

View File

@ -0,0 +1,9 @@
// RUN: mlir-opt <%s -verify-diagnostics
// -----
// Checks that an extract whose number of indices differs from the rank of its
// ranked tensor operand is diagnosed.
// NOTE(review): despite the function name, this passes zero indices to a
// rank-1 tensor, i.e. too few rather than too many.
func @extract_too_many_indices(%arg0: tensor<?xf32>) {
// expected-error@+1 {{incorrect number of indices for extract_element}}
%0 = tensor.extract %arg0[] : tensor<?xf32>
return
}

View File

@ -0,0 +1,10 @@
// RUN: mlir-opt <%s | mlir-opt | FileCheck %s
// CHECK-LABEL: func @extract(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?x?x?xf32>,
// CHECK-SAME: %[[INDEX:.*]]: index) {
func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
// CHECK: tensor.extract %[[TENSOR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
%0 = tensor.extract %arg0[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
return
}

View File

@ -672,19 +672,6 @@ func @calls(%arg0: i32) {
return
}
// CHECK-LABEL: func @extract_element(%arg0: tensor<*xi32>, %arg1: tensor<4x4xf32>) -> i32 {
func @extract_element(%arg0: tensor<*xi32>, %arg1 : tensor<4x4xf32>) -> i32 {
%c0 = "std.constant"() {value = 0: index} : () -> index
// CHECK: %0 = extract_element %arg0[%c0, %c0, %c0, %c0] : tensor<*xi32>
%0 = extract_element %arg0[%c0, %c0, %c0, %c0] : tensor<*xi32>
// CHECK: %1 = extract_element %arg1[%c0, %c0] : tensor<4x4xf32>
%1 = extract_element %arg1[%c0, %c0] : tensor<4x4xf32>
return %0 : i32
}
// CHECK-LABEL: func @tensor_from_elements() {
func @tensor_from_elements() {
%c0 = "std.constant"() {value = 0: index} : () -> index
@ -972,4 +959,3 @@ func @subtensor_insert(%t: tensor<8x16x4xf32>, %t2: tensor<16x32x8xf32>, %idx :
// CHECK-LABEL: func private @legacy_visibility_syntax
func @legacy_visibility_syntax() attributes { sym_visibility = "private" }

View File

@ -541,61 +541,6 @@ func @cmpf_canonical_type_mismatch(%a : f32, %b : f64) { // expected-note {{prio
// -----
func @extract_element_no_operands() {
// expected-error@+1 {{op expected 1 or more operands}}
%0 = "std.extract_element"() : () -> f32
return
}
// -----
func @extract_element_no_indices(%v : vector<3xf32>) {
// expected-error@+1 {{incorrect number of indices for extract_element}}
%0 = "std.extract_element"(%v) : (vector<3xf32>) -> f32
return
}
// -----
func @extract_element_invalid_index_type(%v : vector<3xf32>, %i : i32) {
// expected-error@+1 {{operand #1 must be index}}
%0 = "std.extract_element"(%v, %i) : (vector<3xf32>, i32) -> f32
return
}
// -----
func @extract_element_element_result_type_mismatch(%v : vector<3xf32>, %i : index) {
// expected-error@+1 {{result type matches element type of aggregate}}
%0 = "std.extract_element"(%v, %i) : (vector<3xf32>, index) -> f64
return
}
// -----
func @extract_element_vector_too_many_indices(%v : vector<3xf32>, %i : index) {
// expected-error@+1 {{incorrect number of indices for extract_element}}
%0 = "std.extract_element"(%v, %i, %i) : (vector<3xf32>, index, index) -> f32
return
}
// -----
func @extract_element_tensor_too_many_indices(%t : tensor<2x3xf32>, %i : index) {
// expected-error@+1 {{incorrect number of indices for extract_element}}
%0 = "std.extract_element"(%t, %i, %i, %i) : (tensor<2x3xf32>, index, index, index) -> f32
return
}
// -----
func @extract_element_tensor_too_few_indices(%t : tensor<2x3xf32>, %i : index) {
// expected-error@+1 {{incorrect number of indices for extract_element}}
%0 = "std.extract_element"(%t, %i) : (tensor<2x3xf32>, index) -> f32 return
}
// -----
func @tensor_from_elements_wrong_result_type() {
// expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
%c0 = constant 0 : i32

View File

@ -1040,21 +1040,21 @@ func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: i
// -----
// CHECK-LABEL: func @extract_element_from_tensor_from_elements
func @extract_element_from_tensor_from_elements(%element : index) -> index {
// CHECK-LABEL: func @extract_from_tensor_from_elements
func @extract_from_tensor_from_elements(%element : index) -> index {
// CHECK-SAME: ([[ARG:%.*]]: index)
%c0 = constant 0 : index
%tensor = tensor_from_elements %element : tensor<1xindex>
%extracted_element = extract_element %tensor[%c0] : tensor<1xindex>
%extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex>
// CHECK: [[ARG]] : index
return %extracted_element : index
}
// -----
// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements
// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements
// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
func @extract_element_from_dynamic_tensor_from_elements(%idx: index, %tensor: tensor<*xf32>) -> index {
func @extract_from_dynamic_tensor_from_elements(%idx: index, %tensor: tensor<*xf32>) -> index {
%size = rank %tensor : tensor<*xf32>
// CHECK-NEXT: %[[RES:.*]] = dim %[[TENSOR]], %[[IDX]]
%0 = dynamic_tensor_from_elements %size {
@ -1062,16 +1062,16 @@ func @extract_element_from_dynamic_tensor_from_elements(%idx: index, %tensor: te
%1 = dim %tensor, %arg0 : tensor<*xf32>
yield %1 : index
} : tensor<?xindex>
%1 = extract_element %0[%idx] : tensor<?xindex>
%1 = tensor.extract %0[%idx] : tensor<?xindex>
// CHECK-NEXT: return %[[RES]]
return %1 : index
}
// -----
// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements_2d
// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements_2d
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
func @extract_element_from_dynamic_tensor_from_elements_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
func @extract_from_dynamic_tensor_from_elements_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
%size = rank %tensor : tensor<*xf32>
// CHECK-NEXT: %[[DIM0:.*]] = dim %[[TENSOR]], %[[IDX0]]
// CHECK-NEXT: %[[DIM1:.*]] = dim %[[TENSOR]], %[[IDX1]]
@ -1083,16 +1083,16 @@ func @extract_element_from_dynamic_tensor_from_elements_2d(%idx0: index, %idx1:
%3 = addi %1, %2 : index
yield %3 : index
} : tensor<?x?xindex>
%4 = extract_element %0[%idx0, %idx1] : tensor<?x?xindex>
%4 = tensor.extract %0[%idx0, %idx1] : tensor<?x?xindex>
// CHECK-NEXT: return %[[RES]]
return %4 : index
}
// -----
// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements_sideeffects
// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements_sideeffects
// CHECK-SAME: %[[IDX:.*]]: index
func @extract_element_from_dynamic_tensor_from_elements_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index {
func @extract_from_dynamic_tensor_from_elements_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index {
%size = rank %tensor : tensor<*xf32>
%mem = alloc(%size) : memref<?xindex>
// CHECK: %[[DTENSOR:.*]] = dynamic_tensor_from_elements
@ -1102,8 +1102,8 @@ func @extract_element_from_dynamic_tensor_from_elements_sideeffects(%idx: index,
store %1, %mem[%arg0] : memref<?xindex>
yield %1 : index
} : tensor<?xindex>
// CHECK: %[[RES:.*]] = extract_element %[[DTENSOR]][%[[IDX]]]
%1 = extract_element %0[%idx] : tensor<?xindex>
// CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]]
%1 = tensor.extract %0[%idx] : tensor<?xindex>
// CHECK-NEXT: return %[[RES]]
return %1 : index
}
@ -1205,14 +1205,14 @@ func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
// -----
// CHECK-LABEL: func @extract_element_from_tensor_cast
// CHECK-LABEL: func @extract_from_tensor_cast
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
func @extract_element_from_tensor_cast(%tensor: tensor<*xf32>) -> f32 {
func @extract_from_tensor_cast(%tensor: tensor<*xf32>) -> f32 {
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
%c0 = constant 0 : index
// CHECK-NOT: tensor_cast
%casted = tensor_cast %tensor : tensor<*xf32> to tensor<?xf32>
// CHECK-NEXT: extract_element %[[TENSOR]][%[[C0]]]
%result = extract_element %casted[%c0] : tensor<?xf32>
// CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
%result = tensor.extract %casted[%c0] : tensor<?xf32>
return %result : f32
}

View File

@ -716,38 +716,6 @@ func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1,
// -----
// CHECK-LABEL: func @fold_extract_element
func @fold_extract_element(%arg0 : index) -> (f32, f16, f16, i32) {
%const_0 = constant 0 : index
%const_1 = constant 1 : index
%const_3 = constant 3 : index
// Fold an extract into a splat.
// CHECK-NEXT: [[C4:%.+]] = constant 4.{{0*}}e+00 : f32
%0 = constant dense<4.0> : tensor<4xf32>
%ext_1 = extract_element %0[%arg0] : tensor<4xf32>
// Fold an extract into a sparse with a sparse index.
// CHECK-NEXT: [[CM2:%.+]] = constant -2.{{0*}}e+00 : f16
%1 = constant sparse<[[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : vector<4x4x4xf16>
%ext_2 = extract_element %1[%const_1, %const_1, %const_1] : vector<4x4x4xf16>
// Fold an extract into a sparse with a non sparse index.
// CHECK-NEXT: [[C0:%.+]] = constant 0.{{0*}}e+00 : f16
%2 = constant sparse<[[1, 1, 1]], [-2.0]> : vector<1x1x1xf16>
%ext_3 = extract_element %2[%const_0, %const_0, %const_0] : vector<1x1x1xf16>
// Fold an extract into a dense tensor.
// CHECK-NEXT: [[C64:%.+]] = constant 64 : i32
%3 = constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
%ext_4 = extract_element %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]]
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
}
// -----
// CHECK-LABEL: func @fold_rank
func @fold_rank() -> (index) {
%const_0 = constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>

View File

@ -38,7 +38,7 @@ syn match mlirType /x\s*\zsvector/
" TODO: this list is not exhaustive.
syn keyword mlirOps alloc alloca addf addi and call call_indirect cmpf cmpi
syn keyword mlirOps constant dealloc divf dma_start dma_wait dim exp
syn keyword mlirOps extract_element getTensor index_cast load log memref_cast
syn keyword mlirOps getTensor index_cast load log memref_cast
syn keyword mlirOps memref_shape_cast mulf muli negf powf prefetch rsqrt sitofp
syn keyword mlirOps splat store select sqrt subf subi subview tanh tensor_cast
syn keyword mlirOps view