Add a region to linalg.generic

This CL extends the Linalg GenericOp with an alternative way of specifying the body of the computation based on a single block region. The "fun" attribute becomes optional.
Either a SymbolRef "fun" attribute or a single block region must be specified to describe the side-effect-free computation. Upon lowering to loops, the new region body is inlined in the innermost loop.

The parser, verifier and pretty printer are extended.
Appropriate roundtrip, negative and lowering to loop tests are added.

PiperOrigin-RevId: 261895568
This commit is contained in:
Nicolas Vasilache 2019-08-06 05:50:10 -07:00 committed by A. Unique TensorFlower
parent 24647750d4
commit 4b422a51ed
9 changed files with 289 additions and 36 deletions

View File

@ -335,12 +335,13 @@ def GenericOp : LinalgLibraryBase_Op<"generic", []> {
```
}];
let arguments = (ins Variadic<View>:$views,
SymbolRefAttr:$fun,
AffineMapArrayAttr:$indexing_maps,
I64ArrayAttr:$n_loop_types,
I64ArrayAttr:$n_views,
OptionalAttr<StrAttr>:$doc,
OptionalAttr<SymbolRefAttr>:$fun,
OptionalAttr<StrAttr>:$library_call);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = [{
SmallVector<StringRef, 8> linalgTraitAttrNames() {
return SmallVector<StringRef, 8>{
@ -386,8 +387,10 @@ def GenericOp : LinalgLibraryBase_Op<"generic", []> {
return getNumParallelLoops() + getNumReductionLoops() +
getNumWindowLoops();
}
StringRef getFunName() {
return fun();
FuncOp getFunction() {
auto moduleOp = getParentOfType<ModuleOp>();
return fun().hasValue() ?
moduleOp.lookupSymbol<FuncOp>(fun().getValue()) : FuncOp();
}
StringRef getLibraryCallName() {
return library_call().hasValue() ? library_call().getValue() : "";

View File

@ -20,6 +20,8 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Linalg/IR/LinalgTraits.h"
#include "mlir/Linalg/IR/LinalgTypes.h"

View File

@ -213,4 +213,15 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
}];
}
def YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
Arguments<(ins Variadic<AnyType>:$values)> {
let summary = "Linalg yield operation";
let description = [{
"linalg.yield" is a special terminator operation for blocks inside regions
in linalg ops. It returns values to the immediately enclosing linalg op.
linalg.yield %f0, %f1 : f32, f32
}];
}
#endif // LINALG_OPS

View File

@ -113,6 +113,11 @@ public:
}
/// Return the number of input and output views.
unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
/// Return the `i`-th view type.
mlir::linalg::ViewType getViewType(unsigned i) {
return (i < nInputs()) ? getInputViewType(i)
: getOutputViewType(i - nInputs());
}
/// Return the range over input and output views.
Operation::operand_range getInputsAndOutputs() {
auto range = this->getOperation()->getOperands();

View File

@ -538,13 +538,15 @@ static void print(OpAsmPrinter *p, GenericOp op) {
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
*p << op.getOperationName() << " " << dictAttr << " ";
p->printOperands(op.getOperands());
if (!op.region().empty())
p->printRegion(op.region());
p->printOptionalAttrDict(op.getAttrs(), attrNames);
*p << ": ";
interleaveComma(op.getOperandTypes(), *p);
}
static ParseResult parseGenericOp(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
SmallVector<OpAsmParser::OperandType, 8> operandsInfo, regionOperandsInfo;
DictionaryAttr dictAttr;
// Parse the core linalg traits that must check into a dictAttr.
// The name is unimportant as we will overwrite result->attributes.
@ -556,8 +558,13 @@ static ParseResult parseGenericOp(OpAsmParser *parser, OperationState *result) {
result->attributes.assign(dictAttr.getValue().begin(),
dictAttr.getValue().end());
Region &region = *result->addRegion();
SmallVector<Type, 8> operandTypes, regionTypes;
// Optional attributes may be added.
SmallVector<Type, 8> operandTypes;
// Either Optional "fun" attribute or region must be specified.
if (!dictAttr.get("fun") &&
parser->parseOptionalRegion(region, regionOperandsInfo, regionTypes))
return failure();
if (parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(operandTypes))
return failure();
@ -572,18 +579,36 @@ static LogicalResult verify(GenericOp op) {
if (nViews != llvm::size(op.views()))
return op.emitError("op expected exactly ") << nViews << " view operands";
auto m = op.getParentOfType<ModuleOp>();
auto fun = m.lookupSymbol<FuncOp>(op.fun());
if (!fun || !fun.getType())
return op.emitError(
"op expected fun attribute to refer to a defined symbol");
auto &region = op.region();
auto funOp = op.getFunction();
auto funType = funOp ? funOp.getType() : FunctionType();
if (!region.empty()) {
if (region.getBlocks().size() != 1)
return op.emitError("op expected region with 1 block");
auto funType = fun.getType();
if (funType.getNumInputs() != nViews)
return op.emitError("op expected fun arguments to match number of views");
if (funType.getNumResults() != op.getNumOutputs())
return op.emitError(
"op expected fun results to match number of output views");
auto &block = region.getBlocks().front();
if (block.getNumArguments() != nViews)
return op.emitError(
"op expected number of block arguments to match number of views");
for (unsigned i = 0; i < nViews; ++i) {
auto viewType = op.getViewType(i);
if (viewType.getElementType() != block.getArgument(i)->getType())
return op.emitError("op expected block argument ")
<< i << " of the same type as elemental type of "
<< ((i < nInputViews) ? "input " : "output ")
<< "view: " << viewType;
}
} else {
if (!funOp || !funOp.getType())
return op.emitError(
"op expected fun attribute to refer to a defined symbol");
if (funType.getNumInputs() != nViews)
return op.emitError("op expected fun arguments to match number of views");
if (funType.getNumResults() != op.getNumOutputs())
return op.emitError(
"op expected fun results to match number of output views");
}
auto nLoops = op.getNumLoops();
SmallVector<AffineMap, 4> indexingMaps;
@ -615,15 +640,18 @@ static LogicalResult verify(GenericOp op) {
return op.emitError("op expected indexing_map #")
<< idx << " results to match view rank: " << view;
if (funType.getInput(idx) != view.getElementType())
return op.emitError("op expected fun argument ")
<< idx << " to match view element type: " << view.getElementType();
if (funType) {
if (funType.getInput(idx) != view.getElementType())
return op.emitError("op expected fun argument ")
<< idx
<< " to match view element type: " << view.getElementType();
if (idx >= nInputViews)
if (funType.getResult(idx - nInputViews) != view.getElementType())
return op.emitError("op expected fun result ")
<< idx << " to match output view element type: "
<< view.getElementType();
if (idx >= nInputViews)
if (funType.getResult(idx - nInputViews) != view.getElementType())
return op.emitError("op expected fun result ")
<< idx << " to match output view element type: "
<< view.getElementType();
}
}
auto concatMap = concatAffineMaps(indexingMaps);
@ -635,6 +663,56 @@ static LogicalResult verify(GenericOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
static ParseResult parseYieldOp(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type, 2> types;
llvm::SMLoc loc = parser->getCurrentLocation();
return failure(parser->parseOperandList(opInfo) ||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
parser->resolveOperands(opInfo, types, loc, result->operands));
}
static void print(OpAsmPrinter *p, YieldOp op) {
*p << op.getOperationName();
if (op.getNumOperands() > 0) {
*p << ' ';
p->printOperands(op.operand_begin(), op.operand_end());
*p << " : ";
interleaveComma(op.getOperands(), *p,
[&](Value *e) { p->printType(e->getType()); });
}
}
static LogicalResult verify(YieldOp op) {
auto *parentOp = op.getParentOp();
if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
return op.emitOpError("op expected single non-empty parent region");
auto genericOp = dyn_cast<GenericOp>(parentOp);
if (!genericOp)
return op.emitOpError("op expected '")
<< GenericOp::getOperationName() << "' parent op";
// The operand number and types must match the view element types.
auto nOutputViews = genericOp.getNumOutputs();
if (op.getNumOperands() != nOutputViews)
return op.emitOpError("op expected ")
<< nOutputViews << " operand to match enclosing linalg.generic op";
for (unsigned i = 0; i != nOutputViews; ++i) {
auto elementType = genericOp.getOutputViewType(i).getElementType();
if (op.getOperand(i)->getType() != elementType)
return op.emitError("type of return operand ")
<< i << " (" << op.getOperand(i)->getType()
<< ") doesn't match view element type (" << elementType << ")";
}
return success();
}
static void print(OpAsmPrinter *p, SubViewOp op) {
*p << op.getOperationName() << " " << *op.getOperand(0) << "[";
auto ranges = op.getRanges();

View File

@ -20,6 +20,7 @@
#include "mlir/EDSC/Helpers.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Linalg/IR/LinalgOps.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
@ -241,17 +242,44 @@ public:
linalg_load(genericOp.getOutput(i), indexing);
}
// 2. Emit call.
auto m = genericOp.getParentOfType<ModuleOp>();
auto fun = m.lookupSymbol<FuncOp>(genericOp.fun());
Operation *callOp = call(fun, indexedValues);
assert(callOp->getNumResults() == genericOp.getNumOutputs());
auto funcOp = genericOp.getFunction();
if (funcOp) {
// 2. Emit call.
Operation *callOp = call(funcOp, indexedValues);
assert(callOp->getNumResults() == genericOp.getNumOutputs());
// 3. Emit linalg_store.
for (unsigned i = 0, e = nOutputs; i < e; ++i) {
ValueHandleArray indexing(foldedAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
// 3. Emit linalg_store.
for (unsigned i = 0, e = nOutputs; i < e; ++i) {
ValueHandleArray indexing(foldedAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
}
} else {
// TODO(ntv): When a region inliner exists, use it.
// 2. Inline region, currently only works for a single basic block.
BlockAndValueMapping map;
auto &block = genericOp.region().front();
for (auto it : llvm::zip(block.getArguments(), indexedValues))
map.map(std::get<0>(it), std::get<1>(it));
for (auto &op : block) {
// Skip terminator.
if (&op == &block.back())
continue;
assert(op.getNumRegions() == 0);
auto *newOp = b.clone(op, map);
for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
map.map(std::get<0>(it), std::get<1>(it));
}
// 3. Emit linalg_store.
auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
assert(yieldOp->getNumOperands() == nOutputs);
for (unsigned i = 0, e = nOutputs; i < e; ++i) {
ValueHandleArray indexing(foldedAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
linalg_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
indexing);
}
}
}
};

View File

@ -1,5 +1,9 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
////////////////////////////////////////////////////////////////////////////////
///////////////////////////// Function Attribute tests /////////////////////////
////////////////////////////////////////////////////////////////////////////////
// -----
// CHECK-LABEL: at_least_2_operands
@ -194,3 +198,78 @@ func @singular_maps(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>) {
} %arg0, %arg1: !linalg.view<?xf32>, !linalg.view<?xf32>
return
}
////////////////////////////////////////////////////////////////////////////////
///////////////////////////// Region tests /////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// -----
// CHECK-LABEL: empty_region
func @empty_region(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected region with 1 block}}
linalg.generic {
indexing_maps = [ () -> (0) ],
n_views = [1, 1],
n_loop_types = [0, 0, 0]
} %arg0, %arg0 {
^bb1:
^bb2:
}: !linalg.view<f32>, !linalg.view<f32>
return
}
// -----
// CHECK-LABEL: mismatched_num_arguments
func @mismatched_num_arguments(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected number of block arguments to match number of views}}
linalg.generic {
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
n_loop_types = [0, 0, 0]
} %arg0 {
^bb:
}: !linalg.view<f32>
return
}
// -----
// CHECK-LABEL: block_arg_type
func @block_arg_type(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected block argument 0 of the same type as elemental type of output view: '!linalg.view<f32>'}}
linalg.generic {
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
n_loop_types = [0, 0, 0]
} %arg0 {
^bb(%i: i1):
}: !linalg.view<f32>
return
}
// -----
// CHECK-LABEL: fun_result_0_element_type
func @fun_result_0_element_type(%arg0: !linalg.view<?xf32>) {
// expected-error @+8 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}}
linalg.generic {
indexing_maps = [ (i) -> (i) ],
n_views = [0, 1],
n_loop_types = [1, 0, 0]
} %arg0 {
^bb(%i: f32):
%0 = constant 0: i1
linalg.yield %0: i1
}: !linalg.view<?xf32>
return
}
// -----
// CHECK-LABEL: wrong_yield_parent
func @fun_result_0_element_type(%arg0: !linalg.view<?xf32>) {
// expected-error @+1 {{op expected 'linalg.generic' parent op}}
linalg.yield %arg0: !linalg.view<?xf32>
}

View File

@ -219,13 +219,13 @@ func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
library_call = "external_function_name",
doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
}
func @generic(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?x?xf32>, %arg2: !linalg.view<?x?x?xf32>) {
func @generic_function(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?x?xf32>, %arg2: !linalg.view<?x?x?xf32>) {
linalg.generic #trait %arg0, %arg1, %arg2:
!linalg.view<?x?xf32>, !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
return
}
// CHECK-LABEL: @foo
// CHECK-LABEL: @generic
// CHECK-LABEL: @generic_function
// CHECK: loop.for %[[i:.*]] = {{.*}}
// CHECK: loop.for %[[j:.*]] = {{.*}}
// CHECK: loop.for %[[k:.*]] = {{.*}}
@ -235,3 +235,31 @@ func @generic(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?x?xf32>, %arg
// CHECK: %[[res:.*]]:2 = call @foo(%[[a]], %[[b]], %[[c]]) : (f32, f32, f32) -> (f32, f32)
// CHECK: linalg.store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : !linalg.view<?x?x?xf32>
// CHECK: linalg.store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : !linalg.view<?x?x?xf32>
#trait2 = {
n_views = [1, 2],
n_loop_types = [3, 0, 0],
indexing_maps = #accesses,
library_call = "external_function_name",
doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
}
func @generic_region(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?x?xf32>, %arg2: !linalg.view<?x?x?xf32>) {
linalg.generic #trait2 %arg0, %arg1, %arg2 {
^bb0(%a: f32, %b: f32, %c: f32):
%d = mulf %a, %b : f32
%e = addf %c, %d : f32
linalg.yield %d, %e : f32, f32
}: !linalg.view<?x?xf32>, !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
return
}
// CHECK-LABEL: @generic_region
// CHECK: loop.for %[[i:.*]] = {{.*}}
// CHECK: loop.for %[[j:.*]] = {{.*}}
// CHECK: loop.for %[[k:.*]] = {{.*}}
// CHECK: %[[a:.*]] = linalg.load %{{.*}}[%[[i]], %[[j]]] : !linalg.view<?x?xf32>
// CHECK: %[[b:.*]] = linalg.load %{{.*}}[%[[i]], %[[j]], %[[k]]] : !linalg.view<?x?x?xf32>
// CHECK: %[[c:.*]] = linalg.load %{{.*}}[%[[i]], %[[k]], %[[j]]] : !linalg.view<?x?x?xf32>
// CHECK: %[[d:.*]] = mulf %[[a]], %[[b]] : f32
// CHECK: %[[e:.*]] = addf %[[c]], %[[d]] : f32
// CHECK: linalg.store %[[d]], %{{.*}}[%[[i]], %[[j]], %[[k]]] : !linalg.view<?x?x?xf32>
// CHECK: linalg.store %[[e]], %{{.*}}[%[[i]], %[[k]], %[[j]]] : !linalg.view<?x?x?xf32>

View File

@ -179,3 +179,22 @@ func @generic(%arg0: !linalg.view<?x?xvector<3x4xi4>>, %arg1: !linalg.view<?x?x?
// CHECK-LABEL: func @foo
// CHECK-LABEL: func @generic
// CHECK: linalg.generic {fun = @foo, indexing_maps = [#map2, #map3], library_call = "external_function_name", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {foo = 1 : i64}: !linalg.view<?x?xvector<3x4xi4>>, !linalg.view<?x?x?xf32>
#trait2 = {
indexing_maps = #accesses,
n_views = [1, 1],
n_loop_types = [3, 0, 0],
library_call = "external_function_name"
}
func @generic_region(%arg0: !linalg.view<?x?xvector<3x4xi4>>, %arg1: !linalg.view<?x?x?xf32>) {
linalg.generic #trait2 %arg0, %arg1 {
^bb(%a: vector<3x4xi4>, %b: f32) :
linalg.yield %b : f32
} {foo = 1}: !linalg.view<?x?xvector<3x4xi4>>, !linalg.view<?x?x?xf32>
return
}
// CHECK-LABEL: func @generic_region
// CHECK: linalg.generic {indexing_maps = [#map2, #map3], library_call = "external_function_name", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {
// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // no predecessors
// CHECK: linalg.yield %{{.*}} : f32
// CHECK: } {foo = 1 : i64}: !linalg.view<?x?xvector<3x4xi4>>, !linalg.view<?x?x?xf32>