[MLIR][Standard] Add `dynamic_tensor_from_elements` operation

With `dynamic_tensor_from_elements` tensor values of dynamic size can be
created. The body of the operation essentially maps the index space to tensor
elements.

Declare SCF operations in the `scf` namespace to avoid name clash with the new
`std.yield` operation. Resolve ambiguities between `linalg/shape/std/scf.yield`
operations.

Differential Revision: https://reviews.llvm.org/D86276
This commit is contained in:
Frederik Gossen 2020-09-07 11:41:27 +00:00
parent 928c4b4b49
commit 136eb79a88
14 changed files with 219 additions and 26 deletions

View File

@ -19,7 +19,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def SCF_Dialect : Dialect {
let name = "scf";
let cppNamespace = "";
let cppNamespace = "scf";
}
// Base class for SCF dialect ops.
@ -39,7 +39,7 @@ class SCF_Op<string mnemonic, list<OpTrait> traits = []> :
def ForOp : SCF_Op<"for",
[DeclareOpInterfaceMethods<LoopLikeOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"YieldOp">,
SingleBlockImplicitTerminator<"scf::YieldOp">,
RecursiveSideEffects]> {
let summary = "for operation";
let description = [{
@ -183,7 +183,7 @@ def ForOp : SCF_Op<"for",
def IfOp : SCF_Op<"if",
[DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"YieldOp">, RecursiveSideEffects,
SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveSideEffects,
NoRegionArguments]> {
let summary = "if-then-else operation";
let description = [{
@ -271,7 +271,7 @@ def ParallelOp : SCF_Op<"parallel",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
RecursiveSideEffects,
SingleBlockImplicitTerminator<"YieldOp">]> {
SingleBlockImplicitTerminator<"scf::YieldOp">]> {
let summary = "parallel for operation";
let description = [{
The "scf.parallel" operation represents a loop nest taking 4 groups of SSA

View File

@ -1475,6 +1475,37 @@ def DivFOp : FloatArithmeticOp<"divf"> {
let summary = "floating point division operation";
}
//===----------------------------------------------------------------------===//
// DynamicTensorFromElementsOp
//===----------------------------------------------------------------------===//
def DynamicTensorFromElementsOp : Std_Op<"dynamic_tensor_from_elements",
[RecursiveSideEffects, SingleBlockImplicitTerminator<"YieldOp">]> {
string summary = "Creates a dynamically sized tensor from elements";
string description = [{
This operation creates a dynamically sized tensor with elements of any type.
It expects one index operand per dynamic extent of the result tensor.
The body region defines the tensor's elements. It takes index operands as
its region arguments that span the index space. The element at the given
position is yielded with the `yield` operation (see `YieldOp`).
Example:
```mlir
%tnsr = dynamic_tensor_from_elements %m, %n {
^bb0(%i : index, %j : index, %k : index):
...
yield %elem : f32
} : tensor<?x3x?f32>
```
}];
let arguments = (ins Variadic<Index>:$dynamicExtents);
let results = (outs AnyRankedTensor:$result);
let regions = (region SizedRegion<1>:$body);
}
//===----------------------------------------------------------------------===//
// ExpOp
//===----------------------------------------------------------------------===//
@ -3252,6 +3283,24 @@ def ViewOp : Std_Op<"view", [
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
def YieldOp : Std_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
HasParent<"DynamicTensorFromElementsOp">]> {
let summary = "Yield a value from a region";
let description = [{
This operation is used to yield a single value from a within a region. It
is used to create dynamically sized tensors
(see `DynamicTensorFromElementsOp`).
}];
let arguments = (ins AnyType:$value);
let assemblyFormat = "$value attr-dict `:` type($value)";
let verifier = ?;
}
//===----------------------------------------------------------------------===//
// XOrOp
//===----------------------------------------------------------------------===//

View File

@ -339,7 +339,8 @@ public:
class YieldOpConversion : public ConvertToLLVMPattern {
public:
explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {}
: ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context,
lowering_) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,

View File

@ -356,7 +356,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
// A loop is constructed with an empty "yield" terminator if there are
// no results.
rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
rewriter.create<YieldOp>(loc, forOp.getResults());
rewriter.create<scf::YieldOp>(loc, forOp.getResults());
}
rewriter.setInsertionPointToStart(forOp.getBody());
@ -391,7 +391,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
if (!yieldOperands.empty()) {
rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
rewriter.create<YieldOp>(loc, yieldOperands);
rewriter.create<scf::YieldOp>(loc, yieldOperands);
}
rewriter.replaceOp(parallelOp, loopResults);

View File

@ -905,7 +905,7 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
// YieldOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, YieldOp op) {
static void print(OpAsmPrinter &p, linalg::YieldOp op) {
p << op.getOperationName();
if (op.getNumOperands() > 0)
p << ' ' << op.getOperands();
@ -926,7 +926,8 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
// Check the operand number and types must match the element types of the
// LinalgOp interface's shaped operands.
static LogicalResult verifyYield(YieldOp op, LinalgOp linalgOpInterface) {
static LogicalResult verifyYield(linalg::YieldOp op,
LinalgOp linalgOpInterface) {
auto nOutputs = linalgOpInterface.getNumOutputs();
if (op.getNumOperands() != nOutputs)
return op.emitOpError("expected number of yield values (")
@ -946,7 +947,7 @@ static LogicalResult verifyYield(YieldOp op, LinalgOp linalgOpInterface) {
return success();
}
static LogicalResult verify(YieldOp op) {
static LogicalResult verify(linalg::YieldOp op) {
auto *parentOp = op.getParentOp();
if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
return op.emitOpError("expected single non-empty parent region");

View File

@ -659,7 +659,7 @@ private:
// Add operations from producer (except the yield operation) to the fused
// op.
for (auto &op : producerBlock.getOperations()) {
if (auto yieldOp = dyn_cast<YieldOp>(op)) {
if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
// Lookup the value the yield operation is mapped to.
Value yieldVal = yieldOp.getOperand(0);
if (Value clonedVal = mapper.lookupOrNull(yieldVal))

View File

@ -147,7 +147,7 @@ static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
}
Operation &terminator = block.back();
assert(isa<YieldOp>(terminator) &&
assert(isa<linalg::YieldOp>(terminator) &&
"expected a yield op in the end of the region");
for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
IndexedValueType O(outputBuffers[i]);

View File

@ -48,14 +48,14 @@ static bool hasMultiplyAddBody(Region &r) {
auto c = m_Val(r.getArgument(2));
// TODO: Update this detection once we have matcher support for specifying
// that any permutation of operands matches.
auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
auto pattern5 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
auto pattern6 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
auto pattern7 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
auto pattern8 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
auto pattern1 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
auto pattern2 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
auto pattern3 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
auto pattern4 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
auto pattern5 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
auto pattern6 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
auto pattern7 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
auto pattern8 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
return pattern1.match(&r.front().back()) ||
pattern2.match(&r.front().back()) ||
pattern3.match(&r.front().back()) ||

View File

@ -38,7 +38,7 @@ struct SCFInlinerInterface : public DialectInlinerInterface {
// as necessary. Required when the region has only one block.
void handleTerminator(Operation *op,
ArrayRef<Value> valuesToRepl) const final {
auto retValOp = dyn_cast<YieldOp>(op);
auto retValOp = dyn_cast<scf::YieldOp>(op);
if (!retValOp)
return;
@ -889,7 +889,7 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
return success();
}
static void print(OpAsmPrinter &p, YieldOp op) {
static void print(OpAsmPrinter &p, scf::YieldOp op) {
p << op.getOperationName();
if (op.getNumOperands() != 0)
p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
@ -899,5 +899,9 @@ static void print(OpAsmPrinter &p, YieldOp op) {
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
namespace mlir {
namespace scf {
#define GET_OP_CLASSES
#include "mlir/Dialect/SCF/SCFOps.cpp.inc"
} // namespace scf
} // namespace mlir

View File

@ -779,7 +779,7 @@ void SizeToIndexOp::getCanonicalizationPatterns(
// YieldOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(YieldOp op) {
static LogicalResult verify(shape::YieldOp op) {
auto *parentOp = op.getParentOp();
auto results = parentOp->getResults();
auto operands = op.getOperands();

View File

@ -45,7 +45,7 @@ NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
OpBuilder b = OpBuilder::atBlockEnd(body);
Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
body->getArgument(2));
b.create<YieldOp>(loc, product);
b.create<shape::YieldOp>(loc, product);
rewriter.replaceOp(op, reduce.result());
return success();

View File

@ -1312,7 +1312,6 @@ Optional<int64_t> DimOp::getConstantIndex() {
}
static LogicalResult verify(DimOp op) {
// Assume unknown index to be in range.
Optional<int64_t> index = op.getConstantIndex();
if (!index.hasValue())
@ -1634,6 +1633,67 @@ LogicalResult DmaWaitOp::verify() {
return success();
}
//===----------------------------------------------------------------------===//
// DynamicTensorFromElementsOp
//===----------------------------------------------------------------------===//
static ParseResult parseDynamicTensorFromElementsOp(OpAsmParser &parser,
OperationState &result) {
// Parse operands.
SmallVector<OpAsmParser::OperandType, 4> dynamicExtents;
Type indexTy = parser.getBuilder().getIndexType();
if (parser.parseOperandList(dynamicExtents) ||
parser.resolveOperands(dynamicExtents, indexTy, result.operands))
return failure();
// Parse body.
Region *body = result.addRegion();
if (parser.parseRegion(*body, {}, {}))
return failure();
// Parse result type.
Type resultType;
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(resultType))
return failure();
result.addTypes(resultType);
return success();
}
static void print(OpAsmPrinter &p, DynamicTensorFromElementsOp op) {
p << "dynamic_tensor_from_elements " << op.dynamicExtents();
p.printRegion(op.body());
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getType();
}
static LogicalResult verify(DynamicTensorFromElementsOp op) {
// Ensure that the tensor type has as many dynamic dimensions as are specified
// by the operands.
RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
if (op.getNumOperands() != resultTy.getNumDynamicDims())
return op.emitError("must have as many index operands as dynamic extents "
"in the result type");
// Ensure that region arguments span the index space.
if (!llvm::all_of(op.body().getArgumentTypes(),
[](Type ty) { return ty.isIndex(); }))
return op.emitError("all body arguments must be index");
if (op.body().getNumArguments() != resultTy.getRank())
return op.emitError("must have one body argument per input dimension");
// Ensure that the region yields an element of the right type.
auto yieldOp =
llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
if (yieldOp.value().getType() != resultTy.getElementType())
return op.emitOpError(
"body must be terminated with a `yield` operation of the tensor "
"element type");
return success();
}
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

View File

@ -15,3 +15,69 @@ func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
%0 = index_cast %arg0 : tensor<index> to i64
return %0 : i64
}
// -----
func @dynamic_tensor_from_elements(%m : index)
-> tensor<?x3x?xf32> {
// expected-error @+1 {{must have as many index operands as dynamic extents in the result type}}
%tnsr = dynamic_tensor_from_elements %m {
^bb0(%i : index, %j : index, %k : index):
%elem = constant 8.0 : f32
yield %elem : f32
} : tensor<?x3x?xf32>
return %tnsr : tensor<?x3x?xf32>
}
// -----
func @dynamic_tensor_from_elements(%m : index, %n : index)
-> tensor<?x3x?xf32> {
// expected-error @+1 {{must have one body argument per input dimension}}
%tnsr = dynamic_tensor_from_elements %m, %n {
^bb0(%i : index, %j : index):
%elem = constant 8.0 : f32
yield %elem : f32
} : tensor<?x3x?xf32>
return %tnsr : tensor<?x3x?xf32>
}
// -----
func @dynamic_tensor_from_elements(%m : index, %n : index)
-> tensor<?x3x?xf32> {
// expected-error @+1 {{all body arguments must be index}}
%tnsr = dynamic_tensor_from_elements %m, %n {
^bb0(%i : index, %j : index, %k : i64):
%elem = constant 8.0 : f32
yield %elem : f32
} : tensor<?x3x?xf32>
return %tnsr : tensor<?x3x?xf32>
}
// -----
func @dynamic_tensor_from_elements(%m : index, %n : index)
-> tensor<?x3x?xf32> {
// expected-error @+2 {{op expects regions to end with 'std.yield', found 'std.return'}}
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'std.yield'}}
%tnsr = dynamic_tensor_from_elements %m, %n {
^bb0(%i : index, %j : index, %k : index):
%elem = constant 8.0 : f32
return %elem : f32
} : tensor<?x3x?xf32>
return %tnsr : tensor<?x3x?xf32>
}
// -----
func @dynamic_tensor_from_elements(%m : index, %n : index)
-> tensor<?x3x?xf32> {
// expected-error @+1 {{body must be terminated with a `yield` operation of the tensor element type}}
%tnsr = dynamic_tensor_from_elements %m, %n {
^bb0(%i : index, %j : index, %k : index):
%elem = constant 8 : i32
yield %elem : i32
} : tensor<?x3x?xf32>
return %tnsr : tensor<?x3x?xf32>
}

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt -split-input-file %s | FileCheck %s
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
// CHECK-LABEL: test_index_cast
func @test_index_cast(%arg0 : index) -> i64 {
@ -22,3 +23,14 @@ func @assert(%arg : i1) {
assert %arg, "Some message in case this assertion fails."
return
}
func @dynamic_tensor_from_elements(%m : index, %n : index)
-> tensor<?x3x?xf32> {
%tnsr = dynamic_tensor_from_elements %m, %n {
^bb0(%i : index, %j : index, %k : index):
%elem = constant 8.0 : f32
yield %elem : f32
} : tensor<?x3x?xf32>
return %tnsr : tensor<?x3x?xf32>
}