[MLIR][Standard] Add `dynamic_tensor_from_elements` operation

With `dynamic_tensor_from_elements`, tensor values of dynamic size can be created. The body of the operation essentially maps the index space to tensor elements.

Declare SCF operations in the `scf` namespace to avoid a name clash with the new `std.yield` operation, and resolve the resulting ambiguities between the `linalg`, `shape`, `std`, and `scf` `yield` operations.

Differential Revision: https://reviews.llvm.org/D86276
This commit is contained in:
parent
928c4b4b49
commit
136eb79a88
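To illustrate what the new operation looks like in use, here is a minimal sketch that mirrors the round-trip test added at the end of this diff (the function name `@example` is illustrative only): the op takes one `index` operand per dynamic extent of the result type, and its body yields one element per point of the index space with the new `std.yield`.

```mlir
// Builds a tensor<?x3x?xf32>; %m and %n bind the two dynamic extents, the
// static extent 3 needs no operand. The body receives one index argument per
// result dimension and yields the element for that position.
func @example(%m : index, %n : index) -> tensor<?x3x?xf32> {
  %tnsr = dynamic_tensor_from_elements %m, %n {
  ^bb0(%i : index, %j : index, %k : index):
    %elem = constant 8.0 : f32
    yield %elem : f32
  } : tensor<?x3x?xf32>
  return %tnsr : tensor<?x3x?xf32>
}
```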
@@ -19,7 +19,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 def SCF_Dialect : Dialect {
   let name = "scf";
-  let cppNamespace = "";
+  let cppNamespace = "scf";
 }
 
 // Base class for SCF dialect ops.
@@ -39,7 +39,7 @@ class SCF_Op<string mnemonic, list<OpTrait> traits = []> :
 def ForOp : SCF_Op<"for",
       [DeclareOpInterfaceMethods<LoopLikeOpInterface>,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
-       SingleBlockImplicitTerminator<"YieldOp">,
+       SingleBlockImplicitTerminator<"scf::YieldOp">,
        RecursiveSideEffects]> {
   let summary = "for operation";
   let description = [{
@@ -183,7 +183,7 @@ def ForOp : SCF_Op<"for",
 
 def IfOp : SCF_Op<"if",
       [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
-       SingleBlockImplicitTerminator<"YieldOp">, RecursiveSideEffects,
+       SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveSideEffects,
        NoRegionArguments]> {
   let summary = "if-then-else operation";
   let description = [{
@@ -271,7 +271,7 @@ def ParallelOp : SCF_Op<"parallel",
     [AttrSizedOperandSegments,
      DeclareOpInterfaceMethods<LoopLikeOpInterface>,
      RecursiveSideEffects,
-     SingleBlockImplicitTerminator<"YieldOp">]> {
+     SingleBlockImplicitTerminator<"scf::YieldOp">]> {
   let summary = "parallel for operation";
   let description = [{
     The "scf.parallel" operation represents a loop nest taking 4 groups of SSA
@@ -1475,6 +1475,37 @@ def DivFOp : FloatArithmeticOp<"divf"> {
   let summary = "floating point division operation";
 }
 
+//===----------------------------------------------------------------------===//
+// DynamicTensorFromElementsOp
+//===----------------------------------------------------------------------===//
+
+def DynamicTensorFromElementsOp : Std_Op<"dynamic_tensor_from_elements",
+    [RecursiveSideEffects, SingleBlockImplicitTerminator<"YieldOp">]> {
+  string summary = "Creates a dynamically sized tensor from elements";
+  string description = [{
+    This operation creates a dynamically sized tensor with elements of any
+    type. It expects one index operand per dynamic extent of the result
+    tensor.
+
+    The body region defines the tensor's elements. It takes index operands as
+    its region arguments that span the index space. The element at the given
+    position is yielded with the `yield` operation (see `YieldOp`).
+
+    Example:
+
+    ```mlir
+      %tnsr = dynamic_tensor_from_elements %m, %n {
+      ^bb0(%i : index, %j : index, %k : index):
+        ...
+        yield %elem : f32
+      } : tensor<?x3x?xf32>
+    ```
+  }];
+
+  let arguments = (ins Variadic<Index>:$dynamicExtents);
+  let results = (outs AnyRankedTensor:$result);
+  let regions = (region SizedRegion<1>:$body);
+}
+
 //===----------------------------------------------------------------------===//
 // ExpOp
 //===----------------------------------------------------------------------===//
@@ -3252,6 +3283,24 @@ def ViewOp : Std_Op<"view", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+def YieldOp : Std_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
+                               HasParent<"DynamicTensorFromElementsOp">]> {
+  let summary = "Yield a value from a region";
+  let description = [{
+    This operation is used to yield a single value from within a region. It
+    is used to create dynamically sized tensors
+    (see `DynamicTensorFromElementsOp`).
+  }];
+
+  let arguments = (ins AnyType:$value);
+  let assemblyFormat = "$value attr-dict `:` type($value)";
+  let verifier = ?;
+}
+
 //===----------------------------------------------------------------------===//
 // XOrOp
 //===----------------------------------------------------------------------===//
@@ -339,7 +339,8 @@ public:
 class YieldOpConversion : public ConvertToLLVMPattern {
 public:
   explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
-      : ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {}
+      : ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context,
+                             lowering_) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -356,7 +356,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
       // A loop is constructed with an empty "yield" terminator if there are
       // no results.
       rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
-      rewriter.create<YieldOp>(loc, forOp.getResults());
+      rewriter.create<scf::YieldOp>(loc, forOp.getResults());
     }
 
     rewriter.setInsertionPointToStart(forOp.getBody());
@@ -391,7 +391,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
 
   if (!yieldOperands.empty()) {
     rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
-    rewriter.create<YieldOp>(loc, yieldOperands);
+    rewriter.create<scf::YieldOp>(loc, yieldOperands);
   }
 
   rewriter.replaceOp(parallelOp, loopResults);
@@ -905,7 +905,7 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
 // YieldOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, YieldOp op) {
+static void print(OpAsmPrinter &p, linalg::YieldOp op) {
   p << op.getOperationName();
   if (op.getNumOperands() > 0)
     p << ' ' << op.getOperands();
@@ -926,7 +926,8 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
 
 // Check the operand number and types must match the element types of the
 // LinalgOp interface's shaped operands.
-static LogicalResult verifyYield(YieldOp op, LinalgOp linalgOpInterface) {
+static LogicalResult verifyYield(linalg::YieldOp op,
+                                 LinalgOp linalgOpInterface) {
   auto nOutputs = linalgOpInterface.getNumOutputs();
   if (op.getNumOperands() != nOutputs)
     return op.emitOpError("expected number of yield values (")
@@ -946,7 +947,7 @@ static LogicalResult verifyYield(YieldOp op, LinalgOp linalgOpInterface) {
   return success();
 }
 
-static LogicalResult verify(YieldOp op) {
+static LogicalResult verify(linalg::YieldOp op) {
   auto *parentOp = op.getParentOp();
   if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
     return op.emitOpError("expected single non-empty parent region");
@@ -659,7 +659,7 @@ private:
     // Add operations from producer (except the yield operation) to the fused
     // op.
     for (auto &op : producerBlock.getOperations()) {
-      if (auto yieldOp = dyn_cast<YieldOp>(op)) {
+      if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
        // Lookup the value the yield operation is mapped to.
        Value yieldVal = yieldOp.getOperand(0);
        if (Value clonedVal = mapper.lookupOrNull(yieldVal))
@@ -147,7 +147,7 @@ static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
   }
 
   Operation &terminator = block.back();
-  assert(isa<YieldOp>(terminator) &&
+  assert(isa<linalg::YieldOp>(terminator) &&
          "expected a yield op in the end of the region");
   for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
     IndexedValueType O(outputBuffers[i]);
@@ -48,14 +48,14 @@ static bool hasMultiplyAddBody(Region &r) {
   auto c = m_Val(r.getArgument(2));
   // TODO: Update this detection once we have matcher support for specifying
   // that any permutation of operands matches.
-  auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
-  auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
-  auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
-  auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
-  auto pattern5 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
-  auto pattern6 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
-  auto pattern7 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
-  auto pattern8 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
+  auto pattern1 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
+  auto pattern2 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
+  auto pattern3 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
+  auto pattern4 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
+  auto pattern5 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
+  auto pattern6 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
+  auto pattern7 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
+  auto pattern8 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
   return pattern1.match(&r.front().back()) ||
          pattern2.match(&r.front().back()) ||
          pattern3.match(&r.front().back()) ||
@@ -38,7 +38,7 @@ struct SCFInlinerInterface : public DialectInlinerInterface {
   // as necessary. Required when the region has only one block.
   void handleTerminator(Operation *op,
                         ArrayRef<Value> valuesToRepl) const final {
-    auto retValOp = dyn_cast<YieldOp>(op);
+    auto retValOp = dyn_cast<scf::YieldOp>(op);
     if (!retValOp)
       return;
 
@@ -889,7 +889,7 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-static void print(OpAsmPrinter &p, YieldOp op) {
+static void print(OpAsmPrinter &p, scf::YieldOp op) {
   p << op.getOperationName();
   if (op.getNumOperands() != 0)
     p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
@@ -899,5 +899,9 @@ static void print(OpAsmPrinter &p, YieldOp op) {
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
+namespace mlir {
+namespace scf {
 #define GET_OP_CLASSES
 #include "mlir/Dialect/SCF/SCFOps.cpp.inc"
+} // namespace scf
+} // namespace mlir
@@ -779,7 +779,7 @@ void SizeToIndexOp::getCanonicalizationPatterns(
 // YieldOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(YieldOp op) {
+static LogicalResult verify(shape::YieldOp op) {
   auto *parentOp = op.getParentOp();
   auto results = parentOp->getResults();
   auto operands = op.getOperands();
@@ -45,7 +45,7 @@ NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
   OpBuilder b = OpBuilder::atBlockEnd(body);
   Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
                                   body->getArgument(2));
-  b.create<YieldOp>(loc, product);
+  b.create<shape::YieldOp>(loc, product);
 
   rewriter.replaceOp(op, reduce.result());
   return success();
@@ -1312,7 +1312,6 @@ Optional<int64_t> DimOp::getConstantIndex() {
 }
 
 static LogicalResult verify(DimOp op) {
-
   // Assume unknown index to be in range.
   Optional<int64_t> index = op.getConstantIndex();
   if (!index.hasValue())
@@ -1634,6 +1633,67 @@ LogicalResult DmaWaitOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DynamicTensorFromElementsOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseDynamicTensorFromElementsOp(OpAsmParser &parser,
+                                                    OperationState &result) {
+  // Parse operands.
+  SmallVector<OpAsmParser::OperandType, 4> dynamicExtents;
+  Type indexTy = parser.getBuilder().getIndexType();
+  if (parser.parseOperandList(dynamicExtents) ||
+      parser.resolveOperands(dynamicExtents, indexTy, result.operands))
+    return failure();
+
+  // Parse body.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, {}, {}))
+    return failure();
+
+  // Parse result type.
+  Type resultType;
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(resultType))
+    return failure();
+  result.addTypes(resultType);
+
+  return success();
+}
+
+static void print(OpAsmPrinter &p, DynamicTensorFromElementsOp op) {
+  p << "dynamic_tensor_from_elements " << op.dynamicExtents();
+  p.printRegion(op.body());
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.getType();
+}
+
+static LogicalResult verify(DynamicTensorFromElementsOp op) {
+  // Ensure that the tensor type has as many dynamic dimensions as are
+  // specified by the operands.
+  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
+  if (op.getNumOperands() != resultTy.getNumDynamicDims())
+    return op.emitError("must have as many index operands as dynamic extents "
+                        "in the result type");
+
+  // Ensure that region arguments span the index space.
+  if (!llvm::all_of(op.body().getArgumentTypes(),
+                    [](Type ty) { return ty.isIndex(); }))
+    return op.emitError("all body arguments must be index");
+  if (op.body().getNumArguments() != resultTy.getRank())
+    return op.emitError("must have one body argument per input dimension");
+
+  // Ensure that the region yields an element of the right type.
+  auto yieldOp =
+      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
+  if (yieldOp.value().getType() != resultTy.getElementType())
+    return op.emitOpError(
+        "body must be terminated with a `yield` operation of the tensor "
+        "element type");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractElementOp
+//===----------------------------------------------------------------------===//
@@ -15,3 +15,69 @@ func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
   %0 = index_cast %arg0 : tensor<index> to i64
   return %0 : i64
 }
+
+// -----
+
+func @dynamic_tensor_from_elements(%m : index)
+    -> tensor<?x3x?xf32> {
+  // expected-error @+1 {{must have as many index operands as dynamic extents in the result type}}
+  %tnsr = dynamic_tensor_from_elements %m {
+  ^bb0(%i : index, %j : index, %k : index):
+    %elem = constant 8.0 : f32
+    yield %elem : f32
+  } : tensor<?x3x?xf32>
+  return %tnsr : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @dynamic_tensor_from_elements(%m : index, %n : index)
+    -> tensor<?x3x?xf32> {
+  // expected-error @+1 {{must have one body argument per input dimension}}
+  %tnsr = dynamic_tensor_from_elements %m, %n {
+  ^bb0(%i : index, %j : index):
+    %elem = constant 8.0 : f32
+    yield %elem : f32
+  } : tensor<?x3x?xf32>
+  return %tnsr : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @dynamic_tensor_from_elements(%m : index, %n : index)
+    -> tensor<?x3x?xf32> {
+  // expected-error @+1 {{all body arguments must be index}}
+  %tnsr = dynamic_tensor_from_elements %m, %n {
+  ^bb0(%i : index, %j : index, %k : i64):
+    %elem = constant 8.0 : f32
+    yield %elem : f32
+  } : tensor<?x3x?xf32>
+  return %tnsr : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @dynamic_tensor_from_elements(%m : index, %n : index)
+    -> tensor<?x3x?xf32> {
+  // expected-error @+2 {{op expects regions to end with 'std.yield', found 'std.return'}}
+  // expected-note @+1 {{in custom textual format, the absence of terminator implies 'std.yield'}}
+  %tnsr = dynamic_tensor_from_elements %m, %n {
+  ^bb0(%i : index, %j : index, %k : index):
+    %elem = constant 8.0 : f32
+    return %elem : f32
+  } : tensor<?x3x?xf32>
+  return %tnsr : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @dynamic_tensor_from_elements(%m : index, %n : index)
+    -> tensor<?x3x?xf32> {
+  // expected-error @+1 {{body must be terminated with a `yield` operation of the tensor element type}}
+  %tnsr = dynamic_tensor_from_elements %m, %n {
+  ^bb0(%i : index, %j : index, %k : index):
+    %elem = constant 8 : i32
+    yield %elem : i32
+  } : tensor<?x3x?xf32>
+  return %tnsr : tensor<?x3x?xf32>
+}
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -split-input-file %s | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: test_index_cast
 func @test_index_cast(%arg0 : index) -> i64 {
@@ -22,3 +23,14 @@ func @assert(%arg : i1) {
   assert %arg, "Some message in case this assertion fails."
   return
 }
+
+func @dynamic_tensor_from_elements(%m : index, %n : index)
+    -> tensor<?x3x?xf32> {
+  %tnsr = dynamic_tensor_from_elements %m, %n {
+  ^bb0(%i : index, %j : index, %k : index):
+    %elem = constant 8.0 : f32
+    yield %elem : f32
+  } : tensor<?x3x?xf32>
+  return %tnsr : tensor<?x3x?xf32>
+}