Revert "[mlir][Linalg] Improve region support in Linalg ops."

This reverts commit 973e133b76.

It triggers an issue in gcc5 that requires investigation; the build is
broken with:

/tmp/ccdpj3B9.s: Assembler messages:
/tmp/ccdpj3B9.s:5821: Error: symbol `_ZNSt17_Function_handlerIFvjjEUljjE2_E9_M_invokeERKSt9_Any_dataOjS6_' is already defined
/tmp/ccdpj3B9.s:5860: Error: symbol `_ZNSt14_Function_base13_Base_managerIUljjE2_E10_M_managerERSt9_Any_dataRKS3_St18_Manager_operation' is already defined
This commit is contained in:
Mehdi Amini 2021-02-12 18:15:15 +00:00
parent a7538fee3a
commit 3f22547fd1
11 changed files with 280 additions and 316 deletions

View File

@ -1056,6 +1056,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
//===------------------------------------------------------------------===//
// Other static interface methods.
//===------------------------------------------------------------------===//
// Static factory hook. The default implementation forwards `loc`,
// `resultTypes`, `operands` and `attributes` to `builder.create<ConcreteOp>`
// and returns the outcome as an `Operation *`.
StaticInterfaceMethod<
/*desc=*/[{
Create an operation of the current type with the given location,
operands, and attributes.
}],
/*retTy=*/"Operation *",
/*methodName=*/"create",
(ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes,
"ValueRange":$operands,
"ArrayRef<NamedAttribute>":$attributes), [{
return builder.create<ConcreteOp>(
loc, resultTypes, operands, attributes);
}]
>,
InterfaceMethod<
/*desc=*/[{
Clone the current operation with the given location and operands. This
@ -1068,13 +1082,14 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
"ValueRange":$operands),
[{
BlockAndValueMapping bvm;
OperationState state(
loc, ConcreteOp::getOperationName(), operands, resultTypes,
$_op->getAttrs());
for (Region &r : $_op->getRegions())
r.cloneInto(state.addRegion(), bvm);
return b.createOperation(state);
BlockAndValueMapping map;
unsigned numRegions = $_op->getNumRegions();
Operation *res = create(b, loc, resultTypes, operands, $_op->getAttrs());
assert(res->getNumRegions() == numRegions && "inconsistent # regions");
for (unsigned ridx = 0; ridx < numRegions; ++ridx)
$_op->getRegion(ridx).cloneInto(
&res->getRegion(ridx), map);
return res;
}]
>,
StaticInterfaceMethod<
@ -1083,7 +1098,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
Returns a null function if this named op does not define a region
builder.
}],
/*retTy=*/"std::function<void(Block &, ValueRange)>",
/*retTy=*/"std::function<void(Block &)>",
/*methodName=*/"getRegionBuilder",
(ins),
[{ return ConcreteOp::getRegionBuilder(); }]

View File

@ -110,13 +110,14 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
AnyStridedMemRef:$output,
OptionalAttr<AffineMapAttr>:$inputPermutation,
OptionalAttr<AffineMapAttr>:$outputPermutation);
let regions = (region AnyRegion:$region);
let builders = [
OpBuilderDAG<(ins "Value":$input, "Value":$output,
CArg<"AffineMap", "AffineMap()">:$inputPermutation,
CArg<"AffineMap", "AffineMap()">:$outputPermutation,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
// TODO: this should go away once the usage of OptionalAttr triggers emission
// of builders with default arguments left unspecified.
let builders = [OpBuilderDAG<(ins "Value":$input, "Value":$output),
[{
return build(
$_builder, $_state, input, output, AffineMapAttr(), AffineMapAttr());
}]>];
let extraClassDeclaration = structuredOpsDecls # [{
ValueRange inputs() { return getOperands().take_front(); }
@ -145,31 +146,24 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
Value getSource() { return input();}
Value getTarget() { return output(); }
static void regionBuilder(Block &block, ValueRange captures);
static std::function<void(Block &block, ValueRange captures)>
getRegionBuilder() {
return &regionBuilder;
static std::function<void(Block &)> getRegionBuilder() {
return nullptr;
}
static unsigned getNumRegionArgs() { return 2; }
}];
let verifier = [{ return ::verify(*this); }];
let assemblyFormat = [{
`(` $input `,` $output `)` attr-dict `:`
type($input) `,` type($output)
custom<CopyOpRegion>($region, ref(type($input)), ref(type($input)))
`(` operands `)` attr-dict `:` type(operands)
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let skipDefaultBuilders = 1;
}
def FillOp : LinalgStructured_Op<"fill", []> {
let arguments = (ins AnyShaped:$output,
AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
let results = (outs Optional<AnyRankedTensor>:$result);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = structuredOpsDecls # [{
ValueRange inputs() { return {}; }
ValueRange outputs() { return getOperands().take_front(); }
@ -189,18 +183,13 @@ def FillOp : LinalgStructured_Op<"fill", []> {
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
}
static void regionBuilder(Block &block, ValueRange captures);
static std::function<void(Block &block, ValueRange captures)>
getRegionBuilder() {
return &regionBuilder;
static std::function<void(Block &)> getRegionBuilder() {
return nullptr;
}
static unsigned getNumRegionArgs() { return 1; }
}];
let assemblyFormat = [{
`(` $output `,` $value `)` attr-dict `:`
type($output) `,` type($value) (`->` type($result)^)?
custom<FillOpRegion>($region, ref(type($output)), ref($value))
`(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)?
}];
let builders = [
@ -279,8 +268,7 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
return padding().getValue().getValue<int64_t>({i, 1});
}
static std::function<void(Block &, ValueRange captures)> getRegionBuilder()
{
static std::function<void(Block &)> getRegionBuilder() {
return nullptr;
}
}];
@ -531,7 +519,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
library_call()->str() : "op_has_no_registered_library_name";
}
static std::function<void(Block &, ValueRange)> getRegionBuilder() {
static std::function<void(Block &)> getRegionBuilder() {
return nullptr;
}
}];

View File

@ -154,13 +154,7 @@ LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
if (in == op.input() && out == op.output())
return failure();
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
if (!libraryCallName)
return failure();
rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, libraryCallName.getValue(), TypeRange(),
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out}));
rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
return success();
}

View File

@ -27,6 +27,8 @@ Operation *mlir::edsc::makeGenericLinalgOp(
ArrayRef<StructuredIndexed> outputs, TypeRange resultTensorTypes,
function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
ArrayRef<Attribute> otherAttributes) {
OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
// Build maps
SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
exprsList.reserve(inputs.size() + outputs.size());
@ -52,10 +54,13 @@ Operation *mlir::edsc::makeGenericLinalgOp(
resultTensorTypes,
inputValues,
outputValues,
maps,
iteratorStrTypes,
""/*doc*/,
""/*library_call*/)
builder.getAffineMapArrayAttr(maps),
builder.getStrArrayAttr(iteratorStrTypes),
StringAttr() /*doc*/,
StringAttr() /*library_call*/,
ArrayAttr() /*sparse*/
/* TODO: other attributes in op */
)
.getOperation();
// clang-format on

View File

@ -33,53 +33,32 @@ using namespace mlir;
using namespace mlir::linalg;
/// Forward declarations.
/// Generic entry point to create the block for the region of a LinalgOp.
/// This is used by both named structured ops created by ods-gen and by manually
/// defined C++ ops.
/// This is used by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
/// to be ShapedType.
template <typename NamedStructuredOpType>
static void fillStructuredOpRegion(
OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
TypeRange outputTypes, ValueRange captures = {},
std::function<void(unsigned, unsigned)> errorHandler = [](unsigned,
unsigned) {});
static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
OperationState &result,
TypeRange inputTypes,
TypeRange outputTypes);
/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
static void
createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
TypeRange inputTypes, TypeRange outputTypes,
ValueRange captures = {});
/// Common parsing and printing used for both named structured ops created by
/// ods-gen and by manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes,
SmallVectorImpl<Type> &outputTypes);
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
NamedStructuredOpType op);
/// Specific parsing and printing for named structured ops created by ods-gen.
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
TypeRange inputTypes, TypeRange outputTypes,
ArrayRef<OpAsmParser::OperandType> captures = {});
TypeRange inputTypes, TypeRange outputTypes);
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
SmallVectorImpl<Type> &resultTypes);
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
ArrayRef<OpAsmParser::OperandType> captures = {});
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result);
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
NamedStructuredOpType op);
static void printNamedStructuredOpResults(OpAsmPrinter &p,
TypeRange resultTypes);
@ -122,136 +101,14 @@ static LogicalResult foldMemRefCast(Operation *op) {
return success(folded);
}
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
/// Region body builder for CopyOp.
///
/// Asserts that `block` has exactly two arguments and applies the
/// `linalg_yield` EDSC intrinsic to the first one. `captures` is not read.
void CopyOp::regionBuilder(Block &block, ValueRange captures) {
// Makes `linalg_yield` visible without qualification.
using namespace edsc::intrinsics;
assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
// NOTE(review): the outer parentheses presumably keep this an expression
// statement rather than something parsed as a declaration; confirm before
// removing them.
(linalg_yield(block.getArgument(0)));
}
/// Populates `result` for a CopyOp: the two operands, the caller-supplied
/// attributes, the optional permutation attributes (only when the
/// corresponding map is non-null) and a single region whose body is produced
/// by `fillStructuredOpRegion<CopyOp>`.
void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
                   Value output, AffineMap inputPermutation,
                   AffineMap outputPermutation,
                   ArrayRef<NamedAttribute> namedAttrs) {
  result.addOperands({input, output});
  result.addAttributes(namedAttrs);

  // A default-constructed (null) AffineMap means "no permutation": skip the
  // attribute entirely in that case.
  if (inputPermutation) {
    auto inputMapAttr = AffineMapAttr::get(inputPermutation);
    result.addAttribute("inputPermutation", inputMapAttr);
  }
  if (outputPermutation) {
    auto outputMapAttr = AffineMapAttr::get(outputPermutation);
    result.addAttribute("outputPermutation", outputMapAttr);
  }

  // Add the region, then fill the first region of the state. The TypeRange
  // temporaries are built inside the call expression on purpose.
  result.addRegion();
  Region &body = *result.regions.front();
  fillStructuredOpRegion<CopyOp>(builder, body, TypeRange{input.getType()},
                                 TypeRange{output.getType()});
}
/// Custom-directive parser for the CopyOp region. Nothing is read from the
/// input; the region `r` is filled by `fillStructuredOpRegion<CopyOp>` from
/// the given input/output types and success is always returned.
ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
                              Type outputType) {
  // Use a fresh builder on the parser's context to build the region body.
  auto *context = parser.getBuilder().getContext();
  OpBuilder bodyBuilder(context);
  fillStructuredOpRegion<CopyOp>(bodyBuilder, r, TypeRange{inputType},
                                 TypeRange{outputType});
  return success();
}
/// Printer counterpart of `parseCopyOpRegion`. The CopyOp region is elided
/// when printing, so the body is empty and every argument is ignored.
void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
/// Verifies a CopyOp:
///  - input and output views have the same element type and the same rank;
///  - each optional permutation map, when present, has as many inputs as the
///    op has parallel loops and is a permutation;
///  - no permutation map is present when the op has zero parallel loops.
static LogicalResult verify(CopyOp op) {
  auto dstViewType = op.getOutputShapedType(0);
  auto srcViewType = op.getInputShapedType(0);
  if (srcViewType.getElementType() != dstViewType.getElementType())
    return op.emitOpError("expects views of the same type");
  if (srcViewType.getRank() != dstViewType.getRank())
    return op.emitOpError("expects views of the same rank");

  auto numLoops = op.getNumParallelLoops();

  // Input permutation checks (only when the attribute is present).
  auto inputPerm = op.inputPermutation();
  if (inputPerm && inputPerm->getNumInputs() != numLoops)
    return op.emitOpError("expects optional input_permutation map of rank ")
           << numLoops;
  if (inputPerm && !inputPerm->isPermutation())
    return op.emitOpError(
        "expects optional input_permutation map to be a permutation");

  // Output permutation checks (only when the attribute is present).
  auto outputPerm = op.outputPermutation();
  if (outputPerm && outputPerm->getNumInputs() != numLoops)
    return op.emitOpError("expects optional output_permutation map of rank ")
           << numLoops;
  if (outputPerm && !outputPerm->isPermutation())
    return op.emitOpError(
        "expects optional output_permutation map to be a permutation");

  // With no parallel loops, neither permutation may be specified.
  if (numLoops == 0) {
    if (inputPerm)
      return op.emitOpError("expected no input permutation when rank == 0");
    if (outputPerm)
      return op.emitOpError("expected no output permutation when rank == 0");
  }
  return success();
}
/// Appends the memory effects of a CopyOp to `effects`: a read of `input()`
/// followed by a write of `output()`, both on the default resource.
void CopyOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  auto *resource = SideEffects::DefaultResource::get();
  effects.emplace_back(MemoryEffects::Read::get(), input(), resource);
  effects.emplace_back(MemoryEffects::Write::get(), output(), resource);
}
//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//
/// Region body builder for FillOp.
///
/// Asserts that exactly one captured value is supplied and applies the
/// `linalg_yield` EDSC intrinsic to `captures`. `block` is not read.
void FillOp::regionBuilder(Block &block, ValueRange captures) {
// Makes `linalg_yield` visible without qualification.
using namespace edsc::intrinsics;
assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture");
// NOTE(review): the outer parentheses presumably prevent
// `linalg_yield(captures);` from being parsed as a declaration; confirm
// before removing them.
(linalg_yield(captures));
}
/// Convenience builder for FillOp taking only the output and the fill value.
///
/// Delegates to the `build` overload that also takes a result type, passing
/// the output type cast to RankedTensorType (null when the output is not a
/// ranked tensor), then fills the first region of `result` with `value` as
/// the single capture.
void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
                   Value value) {
  Type outputType = output.getType();
  auto resultType = outputType.dyn_cast<RankedTensorType>();
  build(builder, result, resultType, output, value);

  // The region is expected to have been added by the delegated `build`.
  Region &body = *result.regions.front();
  fillStructuredOpRegion<FillOp>(builder, body, TypeRange{},
                                 TypeRange{outputType}, value);
}
/// Custom-directive parser for the FillOp region.
///
/// Nothing is read from the input. `valueRef` is resolved against the element
/// type of `outputType` (or `outputType` itself when it has no element type)
/// and the resolved value is passed as the single capture to
/// `fillStructuredOpRegion<FillOp>`, which fills `r`.
///
/// Returns failure if the operand cannot be resolved. Continuing in that case
/// would build the region with an empty capture list, which
/// `FillOp::regionBuilder` rejects with an assertion.
ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
                              OpAsmParser::OperandType valueRef) {
  OpBuilder opBuilder(parser.getBuilder().getContext());
  // Resolve `valueRef` into `value` at parse time so we can build the region
  // with captures.
  SmallVector<Value> value;
  if (parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value))
    return failure();
  fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{},
                                 TypeRange{outputType}, value);
  return success();
}
/// Printer counterpart of `parseFillOpRegion`. The FillOp region is elided
/// when printing, so the body is empty and every argument is ignored.
void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
/// Verifies a FillOp:
///  - the fill value type equals the element type of the output;
///  - an op without results must have a memref output.
static LogicalResult verify(FillOp op) {
  auto outputType = op.getOutputShapedType(0);
  auto valueType = op.value().getType();
  if (outputType.getElementType() != valueType)
    return op.emitOpError("expects fill type to match view elemental type");

  // Either the op produces a result, or its output must be a memref.
  if (op.getNumResults() || outputType.isa<MemRefType>())
    return success();
  return op.emitOpError(
      "expected fill op with no result value to use memref type");
}
/// Appends the memory effects of a FillOp to `effects`: a write of `output()`
/// on the default resource when the output is a memref, nothing otherwise.
void FillOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (!output().getType().isa<MemRefType>())
    return;
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
}
//===----------------------------------------------------------------------===//
@ -576,6 +433,7 @@ void InitTensorOp::build(OpBuilder &b, OperationState &result,
result.addAttributes(attrs);
}
static LogicalResult verify(InitTensorOp op) {
RankedTensorType resultType = op.getType();
SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
@ -1556,6 +1414,68 @@ static LogicalResult verify(linalg::YieldOp op) {
/////// Operations corresponding to library calls defined with Tablegen ////////
/// Appends the memory effects of a FillOp to `effects`: a write of `output()`
/// on the default resource when the output is a memref, nothing otherwise.
void FillOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (!output().getType().isa<MemRefType>())
    return;
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
}
/// Verifies a FillOp:
///  - the fill value type equals the element type of the output;
///  - an op without results must have a memref output.
static LogicalResult verify(FillOp op) {
  auto outputType = op.getOutputShapedType(0);
  auto valueType = op.value().getType();
  if (outputType.getElementType() != valueType)
    return op.emitOpError("expects fill type to match view elemental type");

  // Either the op produces a result, or its output must be a memref.
  if (op.getNumResults() || outputType.isa<MemRefType>())
    return success();
  return op.emitOpError(
      "expected fill op with no result value to use memref type");
}
/// Appends the memory effects of a CopyOp to `effects`: a read of `input()`
/// followed by a write of `output()`, both on the default resource.
void CopyOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  auto *resource = SideEffects::DefaultResource::get();
  effects.emplace_back(MemoryEffects::Read::get(), input(), resource);
  effects.emplace_back(MemoryEffects::Write::get(), output(), resource);
}
/// Verifies a CopyOp:
///  - input and output views have the same element type and the same rank;
///  - each optional permutation map, when present, has as many inputs as the
///    op has parallel loops and is a permutation;
///  - no permutation map is present when the op has zero parallel loops.
static LogicalResult verify(CopyOp op) {
  auto dstViewType = op.getOutputShapedType(0);
  auto srcViewType = op.getInputShapedType(0);
  if (srcViewType.getElementType() != dstViewType.getElementType())
    return op.emitOpError("expects views of the same type");
  if (srcViewType.getRank() != dstViewType.getRank())
    return op.emitOpError("expects views of the same rank");

  auto numLoops = op.getNumParallelLoops();

  // Input permutation checks (only when the attribute is present).
  auto inputPerm = op.inputPermutation();
  if (inputPerm && inputPerm->getNumInputs() != numLoops)
    return op.emitOpError("expects optional input_permutation map of rank ")
           << numLoops;
  if (inputPerm && !inputPerm->isPermutation())
    return op.emitOpError(
        "expects optional input_permutation map to be a permutation");

  // Output permutation checks (only when the attribute is present).
  auto outputPerm = op.outputPermutation();
  if (outputPerm && outputPerm->getNumInputs() != numLoops)
    return op.emitOpError("expects optional output_permutation map of rank ")
           << numLoops;
  if (outputPerm && !outputPerm->isPermutation())
    return op.emitOpError(
        "expects optional output_permutation map to be a permutation");

  // With no parallel loops, neither permutation may be specified.
  if (numLoops == 0) {
    if (inputPerm)
      return op.emitOpError("expected no input permutation when rank == 0");
    if (outputPerm)
      return op.emitOpError("expected no output permutation when rank == 0");
  }
  return success();
}
template <typename LinalgPoolingOp>
static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
ArrayRef<Attribute> attrs,
@ -1788,25 +1708,14 @@ OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
}
//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
// Auto-generated Linalg named ops.
//===----------------------------------------------------------------------===//
/// Generic entry point to create the block for the region of a LinalgOp.
/// This is used by both named structured ops created by ods-gen and by manually
/// defined C++ ops.
/// This is used by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
/// to be ShapedType.
template <typename NamedStructuredOpType>
static void
fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
TypeRange inputTypes, TypeRange outputTypes,
ValueRange captures,
std::function<void(unsigned, unsigned)> errorHandler) {
assert(llvm::all_of(inputTypes, [](Type t) { return t.isa<ShapedType>(); }));
assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
static void buildNamedStructuredOpRegionAndAttributesImpl(
OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
TypeRange outputTypes,
std::function<void(unsigned, unsigned)> errorHandler) {
// TODO: atm all operands go through getElementTypeOrSelf,
// reconsider when we have evidence we need to.
SmallVector<Type, 8> argTypes;
@ -1816,7 +1725,7 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
// RAII.
OpBuilder::InsertionGuard guard(opBuilder);
Block *body = opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes);
Block *body = opBuilder.createBlock(&region, {}, argTypes);
unsigned actual = body->getNumArguments();
unsigned expected = NamedStructuredOpType::getNumRegionArgs();
if (expected != actual)
@ -1824,30 +1733,53 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
opBuilder.setInsertionPointToStart(body);
mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
NamedStructuredOpType::regionBuilder(*body, captures);
NamedStructuredOpType::regionBuilder(*body);
// indexing_maps is an auto-generated method.
// iterator_types is an auto-generated method.
}
/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
OperationState &result,
TypeRange inputTypes,
TypeRange outputTypes,
ValueRange captures) {
void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
OperationState &result,
TypeRange inputTypes,
TypeRange outputTypes) {
Region &region = *result.addRegion();
fillStructuredOpRegion<NamedStructuredOpType>(
opBuilder, region, inputTypes, outputTypes, captures,
buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
opBuilder, region, inputTypes, outputTypes,
[&](unsigned expected, unsigned actual) {
llvm::errs() << "region expects " << expected << " args, got "
<< actual;
assert(expected != actual && "incorrect number of arguments");
});
}
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
/// Builds the region of a named structured op at parse time.
///
/// A fresh builder on the parser's context is handed to
/// `buildNamedStructuredOpRegionAndAttributesImpl`. If that helper reports an
/// expected/actual argument-count pair through its error handler, a parser
/// error is emitted at the current location and returned; otherwise the
/// result is success.
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
                             TypeRange inputTypes, TypeRange outputTypes) {
  ParseResult status = success();
  // Records the diagnostic as the function's result.
  auto reportArgMismatch = [&](unsigned expected, unsigned actual) {
    status = parser.emitError(parser.getCurrentLocation(),
                              llvm::formatv("region expects {0} args, got {1}",
                                            expected, actual));
  };

  OpBuilder bodyBuilder(parser.getBuilder().getContext());
  buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
      bodyBuilder, region, inputTypes, outputTypes, reportArgMismatch);
  return status;
}
/// Parses an optional `->` followed by a type list into `resultTypes`.
/// Succeeds when no arrow is present; fails only when an arrow is present and
/// the type list that follows cannot be parsed.
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  // Short-circuiting keeps the type list from being parsed without an arrow.
  if (succeeded(parser.parseOptionalArrow()) &&
      parser.parseTypeList(resultTypes))
    return failure();
  return success();
}
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes,
@ -1888,56 +1820,8 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
}
/// Prints the ` ins(... : ...)` and ` outs(... : ...)` clauses of a
/// structured op. A clause is omitted when the corresponding operand list is
/// empty.
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
                                         NamedStructuredOpType op) {
  const bool hasInputs = !op.inputs().empty();
  if (hasInputs) {
    p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
  }
  const bool hasOutputs = !op.outputs().empty();
  if (hasOutputs) {
    p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
  }
}
//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//
/// Builds the region of an ods-gen named structured op at parse time.
///
/// The region is filled by `fillStructuredOpRegion` using a fresh builder on
/// the parser's context. `captures` is not resolved in this function: the
/// region is built with an empty list of captured values. If the helper
/// reports an expected/actual argument-count pair, a parser error is emitted
/// at the current location, the first block of the region is dumped, and the
/// error is returned.
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
                             TypeRange inputTypes, TypeRange outputTypes,
                             ArrayRef<OpAsmParser::OperandType> captures) {
  ParseResult status = success();
  // Records the diagnostic as the function's result and dumps the block.
  auto reportArgMismatch = [&](unsigned expected, unsigned actual) {
    status = parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      expected, actual));
    region.front().dump();
  };

  OpBuilder bodyBuilder(parser.getBuilder().getContext());
  // Placeholder for captures resolved at parse time; left empty here.
  SmallVector<Value> capturedValues;
  fillStructuredOpRegion<NamedStructuredOpType>(bodyBuilder, region,
                                                inputTypes, outputTypes,
                                                capturedValues,
                                                reportArgMismatch);
  return status;
}
/// Parses an optional `->` followed by a type list into `resultTypes`.
/// Succeeds when no arrow is present; fails only when an arrow is present and
/// the type list that follows cannot be parsed.
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  // Short-circuiting keeps the type list from being parsed without an arrow.
  if (succeeded(parser.parseOptionalArrow()) &&
      parser.parseTypeList(resultTypes))
    return failure();
  return success();
}
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
ArrayRef<OpAsmParser::OperandType> captures) {
// TODO: Enable when ods-gen supports captures.
assert(captures.empty() && "unexpected captures for named structured ops");
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<Type, 1> inputTypes, outputTypes;
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
return failure();
@ -1951,7 +1835,7 @@ parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
std::unique_ptr<Region> region = std::make_unique<Region>();
if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
parser, *region, inputTypes, outputTypes, captures))
parser, *region, inputTypes, outputTypes))
return failure();
result.addRegion(std::move(region));
@ -1965,6 +1849,15 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
p.printOptionalArrowTypeList(resultTypes);
}
/// Prints the ` ins(... : ...)` and ` outs(... : ...)` clauses of a
/// structured op. A clause is omitted when the corresponding operand list is
/// empty.
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
                                         NamedStructuredOpType op) {
  const bool hasInputs = !op.inputs().empty();
  if (hasInputs) {
    p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
  }
  const bool hasOutputs = !op.outputs().empty();
  if (hasOutputs) {
    p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
  }
}
template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
p << op.getOperationName();
@ -1986,10 +1879,6 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
return verifyGenericOp<NamedStructuredOpType>(op);
}
//===----------------------------------------------------------------------===//
// Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
namespace {
struct EraseDeadLinalgOp : public RewritePattern {
EraseDeadLinalgOp(PatternBenefit benefit = 1)

View File

@ -49,7 +49,7 @@ static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
indexingMaps, iterators,
[&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
edsc::ScopedContext scope(bodyBuilder, loc);
regionBuilder(*bodyBuilder.getBlock(), /*captures=*/{});
regionBuilder(*bodyBuilder.getBlock());
});
}

View File

@ -52,6 +52,14 @@ static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
return res;
}
/// Returns `ivs` mapped through `permutation` when a permutation is given
/// (using the current ScopedContext builder and location), or a plain copy of
/// `ivs` otherwise.
static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
                                        Optional<AffineMap> permutation) {
  if (!permutation)
    return SmallVector<Value, 4>(ivs.begin(), ivs.end());
  return applyMapToValues(ScopedContext::getBuilderRef(),
                          ScopedContext::getLocation(),
                          permutation.getValue(), ivs);
}
template <typename IndexedValueType, typename OpType>
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
ArrayRef<SmallVector<Value, 8>> indexing,
@ -170,6 +178,40 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
outputBuffers);
}
/// Emits the scalar body of a CopyOp with buffer semantics: one element of
/// the input is assigned to one element of the output. Each side is indexed
/// by the induction variables after applying its optional permutation.
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
  assert(copyOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto numParallel = copyOp.getNumParallelLoops();
  assert(numParallel == allIvs.size());

  // Permute the induction variables, input side first, then output side.
  auto permutedIn =
      permuteIvs(allIvs.take_front(numParallel), copyOp.inputPermutation());
  auto permutedOut =
      permuteIvs(allIvs.take_front(numParallel), copyOp.outputPermutation());
  SmallVector<Value, 8> inputIndices(permutedIn.begin(), permutedIn.end());
  SmallVector<Value, 8> outputIndices(permutedOut.begin(), permutedOut.end());

  IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
  // Indexed assignment for an n-D loop nest, index-free assignment for 0-D.
  if (numParallel > 0)
    O(outputIndices) = I(inputIndices);
  else
    O() = I();
}
/// Emits the scalar body of a FillOp with buffer semantics: the fill value is
/// assigned to one element of the output buffer, indexed by the induction
/// variables.
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
  assert(fillOp.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto numParallel = fillOp.getNumParallelLoops();
  assert(numParallel == allIvs.size());

  SmallVector<Value, 4> indices(allIvs.begin(), allIvs.begin() + numParallel);
  IndexedValueType O(fillOp.getOutputBuffer(0));
  // Indexed assignment for an n-D loop nest, index-free assignment for 0-D.
  if (numParallel > 0)
    O(indices) = fillOp.value();
  else
    O() = fillOp.value();
}
// Create a padded view into the given `input` tensor using the 'indices'
// to access the tensor. `skipPadding` lists the dimensions for which no padding
// is needed e.g. the non-spatial dimensions for convolutions.
@ -491,8 +533,8 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
assert(iterArgs.empty() && "unexpected iterArgs");
allIvs.append(ivs.begin(), ivs.end());
llvm::TypeSwitch<Operation *>(op)
.Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
IndexedGenericOp, LinalgOp>([&](auto op) {
.Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
emitScalarImplementation<IndexedValueTy>(allIvs, op);
})
.Default([&](Operation *op) { assert(false && "unexpected op"); });

View File

@ -267,7 +267,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
llvm::map_range(linalgOp.getShapedOperandTypes(),
[](ShapedType t) { return t.getElementType(); }));
block->addArguments(elementTypes);
linalgOp.getRegionBuilder()(*block, /*captures=*/{});
linalgOp.getRegionBuilder()(*block);
}
Block *block = &region->front();
@ -333,26 +333,24 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
// Return true if the op is an element-wise linalg op.
static bool isElementwise(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
auto genericOp = dyn_cast<linalg::GenericOp>(op);
if (!genericOp)
return false;
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
return false;
// TODO: relax the restrictions on indexing map.
for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
if (!linalgOp.getOutputIndexingMap(i).isIdentity())
for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
if (!genericOp.getOutputIndexingMap(i).isIdentity())
return false;
}
// Currently bound the input indexing map to minor identity as other
// permutations might require adding transpose ops to convert the vector read
// to the right shape.
for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
if (!linalgOp.getInputIndexingMap(i).isMinorIdentity())
for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
return false;
}
if (linalgOp->getNumRegions() != 1)
return false;
return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
return hasOnlyScalarElementwiseOp(genericOp.getRegion());
}
static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
@ -395,6 +393,9 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
return failure();
if (isa<linalg::FillOp, linalg::CopyOp>(op))
return success();
if (isElementwise(op))
return success();
return success(isaContractionOpInterface(linalgOp));
@ -406,12 +407,43 @@ Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
return llvm::None;
edsc::ScopedContext scope(builder, op->getLoc());
// In the case of 0-D memrefs, return null and special case to scalar load or
// store later.
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
// Vectorize fill as a vector.broadcast.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Rewrite linalg.fill as vector.broadcast: " << *op);
VectorizedLinalgOp res;
if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output()))
res.tensorResults.push_back(v);
return res;
}
if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
// Vectorize copy as a vector.transfer_read+vector.transfer_write.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Rewrite linalg.copy as vector.transfer_read + "
"vector.transfer_write: "
<< *op);
Value vector = buildVectorRead(builder, copyOp.input());
VectorizedLinalgOp res;
if (Value v = buildVectorWrite(builder, vector, copyOp.output()))
res.tensorResults.push_back(v);
return res;
}
if (isElementwise(op)) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Vectorize linalg op as a generic: " << *op);
return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
}
// TODO: as soon as CopyOp and FillOp get a region builder, replace all the
// above by:
// if (isa<FillOp, CopyOp>(op) || isElementwise(op)) {
// LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
// << "Vectorize linalg op as a generic: " << *op);
// return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
// }
return vectorizeContraction(builder, cast<LinalgOp>(op));
}

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt -copy-removal -split-input-file %s | FileCheck %s
// RUN: mlir-opt -copy-removal -split-input-file %s
//| FileCheck %s
// All linalg copies except the linalg.copy(%1, %9) must be removed since the
// defining operation of %1 and its DeallocOp have been defined in another block.
@ -255,7 +256,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>, %result: memref<2xf32>)
%tmp2 = math.exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
}
linalg.copy(%temp, %result) : memref<2xf32>, memref<2xf32>
"linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> ()
dealloc %temp : memref<2xf32>
// CHECK: return
return
@ -291,7 +292,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){
linalg.yield %tmp2 : f32
}
// CHECK: linalg.copy
linalg.copy(%temp, %to) : memref<2xf32>, memref<2xf32>
"linalg.copy"(%temp, %to) : (memref<2xf32>, memref<2xf32>) -> ()
dealloc %temp : memref<2xf32>
return
}
@ -354,7 +355,7 @@ func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg
}
// CHECK-NOT: linalg.copy
// CHECK-NOT: dealloc
linalg.copy(%0, %arg2) : memref<4xf32>, memref<4xf32>
"linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
dealloc %0 : memref<4xf32>
//CHECK: return
return

View File

@ -23,7 +23,7 @@
// IMPL-NEXT: map2 = simplifyAffineMap(map2);
// IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 });
//
// IMPL: void Test1Op::regionBuilder(Block &block, ValueRange captures) {
// IMPL: void Test1Op::regionBuilder(Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
@ -47,7 +47,7 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
// IMPL: AffineMap::get(3, 3, {d2, d1}, context)
// IMPL: AffineMap::get(3, 3, {d0, d1}, context)
//
// IMPL: Test2Op::regionBuilder(Block &block, ValueRange captures) {
// IMPL: Test2Op::regionBuilder(Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
@ -71,7 +71,7 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
// IMPL: AffineMap::get(4, 4, {d3, d2}, context)
// IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context)
//
// IMPL: Test3Op::regionBuilder(Block &block, ValueRange captures) {
// IMPL: Test3Op::regionBuilder(Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);

View File

@ -1871,11 +1871,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));
createAndFillStructuredOpRegion<{0}>(
buildNamedStructuredOpRegionAndAttributes<{0}>(
$_builder,
$_state,
TypeRange(inputs),
TypeRange(outputs)/*, TODO: support captures*/);
TypeRange(outputs));
}]>,
OpBuilderDAG<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@ -1889,11 +1889,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));
createAndFillStructuredOpRegion<{0}>(
buildNamedStructuredOpRegionAndAttributes<{0}>(
$_builder,
$_state,
TypeRange(inputs),
TypeRange(outputs)/*, TODO: support captures*/);
TypeRange(outputs));
}]>,
OpBuilderDAG<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@ -1907,9 +1907,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
{6}
];
let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
let parser = [{{
return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
}];
let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
let hasFolder = 1;
let hasCanonicalizer = 1;
@ -1917,8 +1915,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
// Auto-generated.
ArrayAttr iterator_types();
ArrayAttr indexing_maps();
static void regionBuilder(Block &block, ValueRange captures);
static std::function<void(Block &, ValueRange)> getRegionBuilder() {{
static void regionBuilder(Block &block);
static std::function<void(Block &)> getRegionBuilder() {{
return regionBuilder;
}
@ -1982,11 +1980,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));
createAndFillStructuredOpRegion<{0}>(
buildNamedStructuredOpRegionAndAttributes<{0}>(
$_builder,
$_state,
TypeRange(inputs),
TypeRange(outputs)/*, TODO: support captures*/);
TypeRange(outputs));
{2}
}]>
)FMT";
@ -2313,7 +2311,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
};
const char *regionBuilderFmt = R"FMT(
void {0}::regionBuilder(Block &block, ValueRange captures) {
void {0}::regionBuilder(Block &block) {
using namespace edsc;
using namespace intrinsics;
auto args = block.getArguments();