forked from OSchip/llvm-project
Revert "[mlir][Linalg] Improve region support in Linalg ops."
This reverts commit 973e133b76.
It triggers an issue in gcc5 that requires investigation; the build is
broken with:
/tmp/ccdpj3B9.s: Assembler messages:
/tmp/ccdpj3B9.s:5821: Error: symbol `_ZNSt17_Function_handlerIFvjjEUljjE2_E9_M_invokeERKSt9_Any_dataOjS6_' is already defined
/tmp/ccdpj3B9.s:5860: Error: symbol `_ZNSt14_Function_base13_Base_managerIUljjE2_E10_M_managerERSt9_Any_dataRKS3_St18_Manager_operation' is already defined
This commit is contained in:
parent
a7538fee3a
commit
3f22547fd1
|
@ -1056,6 +1056,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
||||||
//===------------------------------------------------------------------===//
|
//===------------------------------------------------------------------===//
|
||||||
// Other static interface methods.
|
// Other static interface methods.
|
||||||
//===------------------------------------------------------------------===//
|
//===------------------------------------------------------------------===//
|
||||||
|
StaticInterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Create an operation of the current type with the given location,
|
||||||
|
operands, and attributes.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"Operation *",
|
||||||
|
/*methodName=*/"create",
|
||||||
|
(ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes,
|
||||||
|
"ValueRange":$operands,
|
||||||
|
"ArrayRef<NamedAttribute>":$attributes), [{
|
||||||
|
return builder.create<ConcreteOp>(
|
||||||
|
loc, resultTypes, operands, attributes);
|
||||||
|
}]
|
||||||
|
>,
|
||||||
InterfaceMethod<
|
InterfaceMethod<
|
||||||
/*desc=*/[{
|
/*desc=*/[{
|
||||||
Clone the current operation with the given location and operands. This
|
Clone the current operation with the given location and operands. This
|
||||||
|
@ -1068,13 +1082,14 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
||||||
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
|
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
|
||||||
"ValueRange":$operands),
|
"ValueRange":$operands),
|
||||||
[{
|
[{
|
||||||
BlockAndValueMapping bvm;
|
BlockAndValueMapping map;
|
||||||
OperationState state(
|
unsigned numRegions = $_op->getNumRegions();
|
||||||
loc, ConcreteOp::getOperationName(), operands, resultTypes,
|
Operation *res = create(b, loc, resultTypes, operands, $_op->getAttrs());
|
||||||
$_op->getAttrs());
|
assert(res->getNumRegions() == numRegions && "inconsistent # regions");
|
||||||
for (Region &r : $_op->getRegions())
|
for (unsigned ridx = 0; ridx < numRegions; ++ridx)
|
||||||
r.cloneInto(state.addRegion(), bvm);
|
$_op->getRegion(ridx).cloneInto(
|
||||||
return b.createOperation(state);
|
&res->getRegion(ridx), map);
|
||||||
|
return res;
|
||||||
}]
|
}]
|
||||||
>,
|
>,
|
||||||
StaticInterfaceMethod<
|
StaticInterfaceMethod<
|
||||||
|
@ -1083,7 +1098,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
||||||
Returns a null function if this named op does not define a region
|
Returns a null function if this named op does not define a region
|
||||||
builder.
|
builder.
|
||||||
}],
|
}],
|
||||||
/*retTy=*/"std::function<void(Block &, ValueRange)>",
|
/*retTy=*/"std::function<void(Block &)>",
|
||||||
/*methodName=*/"getRegionBuilder",
|
/*methodName=*/"getRegionBuilder",
|
||||||
(ins),
|
(ins),
|
||||||
[{ return ConcreteOp::getRegionBuilder(); }]
|
[{ return ConcreteOp::getRegionBuilder(); }]
|
||||||
|
|
|
@ -110,13 +110,14 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
|
||||||
AnyStridedMemRef:$output,
|
AnyStridedMemRef:$output,
|
||||||
OptionalAttr<AffineMapAttr>:$inputPermutation,
|
OptionalAttr<AffineMapAttr>:$inputPermutation,
|
||||||
OptionalAttr<AffineMapAttr>:$outputPermutation);
|
OptionalAttr<AffineMapAttr>:$outputPermutation);
|
||||||
let regions = (region AnyRegion:$region);
|
|
||||||
|
|
||||||
let builders = [
|
// TODO: this should go away once the usage of OptionalAttr triggers emission
|
||||||
OpBuilderDAG<(ins "Value":$input, "Value":$output,
|
// of builders with default arguments left unspecified.
|
||||||
CArg<"AffineMap", "AffineMap()">:$inputPermutation,
|
let builders = [OpBuilderDAG<(ins "Value":$input, "Value":$output),
|
||||||
CArg<"AffineMap", "AffineMap()">:$outputPermutation,
|
[{
|
||||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
|
return build(
|
||||||
|
$_builder, $_state, input, output, AffineMapAttr(), AffineMapAttr());
|
||||||
|
}]>];
|
||||||
|
|
||||||
let extraClassDeclaration = structuredOpsDecls # [{
|
let extraClassDeclaration = structuredOpsDecls # [{
|
||||||
ValueRange inputs() { return getOperands().take_front(); }
|
ValueRange inputs() { return getOperands().take_front(); }
|
||||||
|
@ -145,31 +146,24 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
|
||||||
Value getSource() { return input();}
|
Value getSource() { return input();}
|
||||||
Value getTarget() { return output(); }
|
Value getTarget() { return output(); }
|
||||||
|
|
||||||
static void regionBuilder(Block &block, ValueRange captures);
|
static std::function<void(Block &)> getRegionBuilder() {
|
||||||
static std::function<void(Block &block, ValueRange captures)>
|
return nullptr;
|
||||||
getRegionBuilder() {
|
|
||||||
return &regionBuilder;
|
|
||||||
}
|
}
|
||||||
static unsigned getNumRegionArgs() { return 2; }
|
|
||||||
}];
|
}];
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let verifier = [{ return ::verify(*this); }];
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
`(` $input `,` $output `)` attr-dict `:`
|
`(` operands `)` attr-dict `:` type(operands)
|
||||||
type($input) `,` type($output)
|
|
||||||
custom<CopyOpRegion>($region, ref(type($input)), ref(type($input)))
|
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
let skipDefaultBuilders = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def FillOp : LinalgStructured_Op<"fill", []> {
|
def FillOp : LinalgStructured_Op<"fill", []> {
|
||||||
let arguments = (ins AnyShaped:$output,
|
let arguments = (ins AnyShaped:$output,
|
||||||
AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
|
AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
|
||||||
let results = (outs Optional<AnyRankedTensor>:$result);
|
let results = (outs Optional<AnyRankedTensor>:$result);
|
||||||
let regions = (region AnyRegion:$region);
|
|
||||||
let extraClassDeclaration = structuredOpsDecls # [{
|
let extraClassDeclaration = structuredOpsDecls # [{
|
||||||
ValueRange inputs() { return {}; }
|
ValueRange inputs() { return {}; }
|
||||||
ValueRange outputs() { return getOperands().take_front(); }
|
ValueRange outputs() { return getOperands().take_front(); }
|
||||||
|
@ -189,18 +183,13 @@ def FillOp : LinalgStructured_Op<"fill", []> {
|
||||||
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
|
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
|
||||||
}
|
}
|
||||||
|
|
||||||
static void regionBuilder(Block &block, ValueRange captures);
|
static std::function<void(Block &)> getRegionBuilder() {
|
||||||
static std::function<void(Block &block, ValueRange captures)>
|
return nullptr;
|
||||||
getRegionBuilder() {
|
|
||||||
return &regionBuilder;
|
|
||||||
}
|
}
|
||||||
static unsigned getNumRegionArgs() { return 1; }
|
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
`(` $output `,` $value `)` attr-dict `:`
|
`(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)?
|
||||||
type($output) `,` type($value) (`->` type($result)^)?
|
|
||||||
custom<FillOpRegion>($region, ref(type($output)), ref($value))
|
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let builders = [
|
let builders = [
|
||||||
|
@ -279,8 +268,7 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
|
||||||
return padding().getValue().getValue<int64_t>({i, 1});
|
return padding().getValue().getValue<int64_t>({i, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::function<void(Block &, ValueRange captures)> getRegionBuilder()
|
static std::function<void(Block &)> getRegionBuilder() {
|
||||||
{
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
@ -531,7 +519,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
|
||||||
library_call()->str() : "op_has_no_registered_library_name";
|
library_call()->str() : "op_has_no_registered_library_name";
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::function<void(Block &, ValueRange)> getRegionBuilder() {
|
static std::function<void(Block &)> getRegionBuilder() {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
|
|
@ -154,13 +154,7 @@ LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
|
||||||
if (in == op.input() && out == op.output())
|
if (in == op.input() && out == op.output())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
|
rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
|
||||||
if (!libraryCallName)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
|
||||||
op, libraryCallName.getValue(), TypeRange(),
|
|
||||||
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out}));
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,8 @@ Operation *mlir::edsc::makeGenericLinalgOp(
|
||||||
ArrayRef<StructuredIndexed> outputs, TypeRange resultTensorTypes,
|
ArrayRef<StructuredIndexed> outputs, TypeRange resultTensorTypes,
|
||||||
function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
|
function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
|
||||||
ArrayRef<Attribute> otherAttributes) {
|
ArrayRef<Attribute> otherAttributes) {
|
||||||
|
OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
|
||||||
|
|
||||||
// Build maps
|
// Build maps
|
||||||
SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
|
SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
|
||||||
exprsList.reserve(inputs.size() + outputs.size());
|
exprsList.reserve(inputs.size() + outputs.size());
|
||||||
|
@ -52,10 +54,13 @@ Operation *mlir::edsc::makeGenericLinalgOp(
|
||||||
resultTensorTypes,
|
resultTensorTypes,
|
||||||
inputValues,
|
inputValues,
|
||||||
outputValues,
|
outputValues,
|
||||||
maps,
|
builder.getAffineMapArrayAttr(maps),
|
||||||
iteratorStrTypes,
|
builder.getStrArrayAttr(iteratorStrTypes),
|
||||||
""/*doc*/,
|
StringAttr() /*doc*/,
|
||||||
""/*library_call*/)
|
StringAttr() /*library_call*/,
|
||||||
|
ArrayAttr() /*sparse*/
|
||||||
|
/* TODO: other attributes in op */
|
||||||
|
)
|
||||||
.getOperation();
|
.getOperation();
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
|
|
|
@ -33,53 +33,32 @@ using namespace mlir;
|
||||||
using namespace mlir::linalg;
|
using namespace mlir::linalg;
|
||||||
|
|
||||||
/// Forward declarations.
|
/// Forward declarations.
|
||||||
|
|
||||||
/// Generic entry point to create the block for the region of a LinalgOp.
|
|
||||||
/// This is used by both named structured ops created by ods-gen and by manually
|
|
||||||
/// defined C++ ops.
|
|
||||||
/// This is used by both builders and parsers.
|
|
||||||
/// This function creates the block in the region with arguments corresponding
|
|
||||||
/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
|
|
||||||
/// to be ShapedType.
|
|
||||||
template <typename NamedStructuredOpType>
|
template <typename NamedStructuredOpType>
|
||||||
static void fillStructuredOpRegion(
|
static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
|
||||||
OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
|
OperationState &result,
|
||||||
TypeRange outputTypes, ValueRange captures = {},
|
TypeRange inputTypes,
|
||||||
std::function<void(unsigned, unsigned)> errorHandler = [](unsigned,
|
TypeRange outputTypes);
|
||||||
unsigned) {});
|
|
||||||
|
|
||||||
/// Generic entry point to create both the region and the block of a LinalgOp.
|
|
||||||
template <typename NamedStructuredOpType>
|
|
||||||
static void
|
|
||||||
createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
|
|
||||||
TypeRange inputTypes, TypeRange outputTypes,
|
|
||||||
ValueRange captures = {});
|
|
||||||
|
|
||||||
/// Common parsing and printing used for both named structured ops created by
|
|
||||||
/// ods-gen and by manually defined C++ ops. Does not handle regions.
|
|
||||||
static ParseResult
|
static ParseResult
|
||||||
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
|
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
|
||||||
SmallVectorImpl<Type> &inputTypes,
|
SmallVectorImpl<Type> &inputTypes,
|
||||||
SmallVectorImpl<Type> &outputTypes);
|
SmallVectorImpl<Type> &outputTypes);
|
||||||
template <typename NamedStructuredOpType>
|
|
||||||
static void printCommonStructuredOpParts(OpAsmPrinter &p,
|
|
||||||
NamedStructuredOpType op);
|
|
||||||
|
|
||||||
/// Specific parsing and printing for named structured ops created by ods-gen.
|
|
||||||
template <typename NamedStructuredOpType>
|
template <typename NamedStructuredOpType>
|
||||||
static ParseResult
|
static ParseResult
|
||||||
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
|
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
|
||||||
TypeRange inputTypes, TypeRange outputTypes,
|
TypeRange inputTypes, TypeRange outputTypes);
|
||||||
ArrayRef<OpAsmParser::OperandType> captures = {});
|
|
||||||
|
|
||||||
static ParseResult
|
static ParseResult
|
||||||
parseNamedStructuredOpResults(OpAsmParser &parser,
|
parseNamedStructuredOpResults(OpAsmParser &parser,
|
||||||
SmallVectorImpl<Type> &resultTypes);
|
SmallVectorImpl<Type> &resultTypes);
|
||||||
|
|
||||||
template <typename NamedStructuredOpType>
|
template <typename NamedStructuredOpType>
|
||||||
static ParseResult
|
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
||||||
parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
|
OperationState &result);
|
||||||
ArrayRef<OpAsmParser::OperandType> captures = {});
|
|
||||||
|
template <typename NamedStructuredOpType>
|
||||||
|
static void printCommonStructuredOpParts(OpAsmPrinter &p,
|
||||||
|
NamedStructuredOpType op);
|
||||||
|
|
||||||
static void printNamedStructuredOpResults(OpAsmPrinter &p,
|
static void printNamedStructuredOpResults(OpAsmPrinter &p,
|
||||||
TypeRange resultTypes);
|
TypeRange resultTypes);
|
||||||
|
@ -122,136 +101,14 @@ static LogicalResult foldMemRefCast(Operation *op) {
|
||||||
return success(folded);
|
return success(folded);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// CopyOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
void CopyOp::regionBuilder(Block &block, ValueRange captures) {
|
|
||||||
using namespace edsc::intrinsics;
|
|
||||||
assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
|
|
||||||
(linalg_yield(block.getArgument(0)));
|
|
||||||
}
|
|
||||||
|
|
||||||
void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
|
|
||||||
Value output, AffineMap inputPermutation,
|
|
||||||
AffineMap outputPermutation,
|
|
||||||
ArrayRef<NamedAttribute> namedAttrs) {
|
|
||||||
result.addOperands({input, output});
|
|
||||||
result.addAttributes(namedAttrs);
|
|
||||||
if (inputPermutation)
|
|
||||||
result.addAttribute("inputPermutation",
|
|
||||||
AffineMapAttr::get(inputPermutation));
|
|
||||||
if (outputPermutation)
|
|
||||||
result.addAttribute("outputPermutation",
|
|
||||||
AffineMapAttr::get(outputPermutation));
|
|
||||||
result.addRegion();
|
|
||||||
fillStructuredOpRegion<CopyOp>(builder, *result.regions.front(),
|
|
||||||
TypeRange{input.getType()},
|
|
||||||
TypeRange{output.getType()});
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
|
|
||||||
Type outputType) {
|
|
||||||
OpBuilder opBuilder(parser.getBuilder().getContext());
|
|
||||||
fillStructuredOpRegion<CopyOp>(opBuilder, r, TypeRange{inputType},
|
|
||||||
TypeRange{outputType});
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// CopyOp region is elided when printing.
|
|
||||||
void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
|
|
||||||
|
|
||||||
static LogicalResult verify(CopyOp op) {
|
|
||||||
auto outputViewType = op.getOutputShapedType(0);
|
|
||||||
auto inputViewType = op.getInputShapedType(0);
|
|
||||||
if (inputViewType.getElementType() != outputViewType.getElementType())
|
|
||||||
return op.emitOpError("expects views of the same type");
|
|
||||||
if (inputViewType.getRank() != outputViewType.getRank())
|
|
||||||
return op.emitOpError("expects views of the same rank");
|
|
||||||
auto rank = op.getNumParallelLoops();
|
|
||||||
auto inputPermutationMap = op.inputPermutation();
|
|
||||||
if (inputPermutationMap) {
|
|
||||||
if (inputPermutationMap->getNumInputs() != rank)
|
|
||||||
return op.emitOpError("expects optional input_permutation map of rank ")
|
|
||||||
<< rank;
|
|
||||||
if (!inputPermutationMap->isPermutation())
|
|
||||||
return op.emitOpError(
|
|
||||||
"expects optional input_permutation map to be a permutation");
|
|
||||||
}
|
|
||||||
auto outputPermutationMap = op.outputPermutation();
|
|
||||||
if (outputPermutationMap) {
|
|
||||||
if (outputPermutationMap->getNumInputs() != rank)
|
|
||||||
return op.emitOpError("expects optional output_permutation map of rank ")
|
|
||||||
<< rank;
|
|
||||||
if (!outputPermutationMap->isPermutation())
|
|
||||||
return op.emitOpError(
|
|
||||||
"expects optional output_permutation map to be a permutation");
|
|
||||||
}
|
|
||||||
if (rank == 0 && inputPermutationMap)
|
|
||||||
return op.emitOpError("expected no input permutation when rank == 0");
|
|
||||||
if (rank == 0 && outputPermutationMap)
|
|
||||||
return op.emitOpError("expected no output permutation when rank == 0");
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
void CopyOp::getEffects(
|
|
||||||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
||||||
&effects) {
|
|
||||||
effects.emplace_back(MemoryEffects::Read::get(), input(),
|
|
||||||
SideEffects::DefaultResource::get());
|
|
||||||
effects.emplace_back(MemoryEffects::Write::get(), output(),
|
|
||||||
SideEffects::DefaultResource::get());
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// FillOp
|
// FillOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
void FillOp::regionBuilder(Block &block, ValueRange captures) {
|
|
||||||
using namespace edsc::intrinsics;
|
|
||||||
assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture");
|
|
||||||
(linalg_yield(captures));
|
|
||||||
}
|
|
||||||
|
|
||||||
void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
|
void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
|
||||||
Value value) {
|
Value value) {
|
||||||
build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output,
|
build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output,
|
||||||
value);
|
value);
|
||||||
fillStructuredOpRegion<FillOp>(builder, *result.regions.front(), TypeRange{},
|
|
||||||
TypeRange{output.getType()}, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
|
|
||||||
OpAsmParser::OperandType valueRef) {
|
|
||||||
OpBuilder opBuilder(parser.getBuilder().getContext());
|
|
||||||
// Resolve `valueRef` into `value` at parse time so we can build the region
|
|
||||||
// with captures.
|
|
||||||
SmallVector<Value> value;
|
|
||||||
parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value);
|
|
||||||
fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{},
|
|
||||||
TypeRange{outputType}, value);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// FillOp region is elided when printing.
|
|
||||||
void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
|
|
||||||
|
|
||||||
static LogicalResult verify(FillOp op) {
|
|
||||||
auto viewType = op.getOutputShapedType(0);
|
|
||||||
auto fillType = op.value().getType();
|
|
||||||
if (viewType.getElementType() != fillType)
|
|
||||||
return op.emitOpError("expects fill type to match view elemental type");
|
|
||||||
if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
|
|
||||||
return op.emitOpError(
|
|
||||||
"expected fill op with no result value to use memref type");
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
void FillOp::getEffects(
|
|
||||||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
||||||
&effects) {
|
|
||||||
if (output().getType().isa<MemRefType>())
|
|
||||||
effects.emplace_back(MemoryEffects::Write::get(), output(),
|
|
||||||
SideEffects::DefaultResource::get());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -576,6 +433,7 @@ void InitTensorOp::build(OpBuilder &b, OperationState &result,
|
||||||
result.addAttributes(attrs);
|
result.addAttributes(attrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static LogicalResult verify(InitTensorOp op) {
|
static LogicalResult verify(InitTensorOp op) {
|
||||||
RankedTensorType resultType = op.getType();
|
RankedTensorType resultType = op.getType();
|
||||||
SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
|
SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
|
||||||
|
@ -1556,6 +1414,68 @@ static LogicalResult verify(linalg::YieldOp op) {
|
||||||
|
|
||||||
/////// Operations corresponding to library calls defined with Tablegen ////////
|
/////// Operations corresponding to library calls defined with Tablegen ////////
|
||||||
|
|
||||||
|
void FillOp::getEffects(
|
||||||
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
||||||
|
&effects) {
|
||||||
|
if (output().getType().isa<MemRefType>())
|
||||||
|
effects.emplace_back(MemoryEffects::Write::get(), output(),
|
||||||
|
SideEffects::DefaultResource::get());
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verify(FillOp op) {
|
||||||
|
auto viewType = op.getOutputShapedType(0);
|
||||||
|
auto fillType = op.value().getType();
|
||||||
|
if (viewType.getElementType() != fillType)
|
||||||
|
return op.emitOpError("expects fill type to match view elemental type");
|
||||||
|
if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
|
||||||
|
return op.emitOpError(
|
||||||
|
"expected fill op with no result value to use memref type");
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CopyOp::getEffects(
|
||||||
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
||||||
|
&effects) {
|
||||||
|
effects.emplace_back(MemoryEffects::Read::get(), input(),
|
||||||
|
SideEffects::DefaultResource::get());
|
||||||
|
effects.emplace_back(MemoryEffects::Write::get(), output(),
|
||||||
|
SideEffects::DefaultResource::get());
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verify(CopyOp op) {
|
||||||
|
auto outputViewType = op.getOutputShapedType(0);
|
||||||
|
auto inputViewType = op.getInputShapedType(0);
|
||||||
|
if (inputViewType.getElementType() != outputViewType.getElementType())
|
||||||
|
return op.emitOpError("expects views of the same type");
|
||||||
|
if (inputViewType.getRank() != outputViewType.getRank())
|
||||||
|
return op.emitOpError("expects views of the same rank");
|
||||||
|
auto rank = op.getNumParallelLoops();
|
||||||
|
auto inputPermutationMap = op.inputPermutation();
|
||||||
|
if (inputPermutationMap) {
|
||||||
|
if (inputPermutationMap->getNumInputs() != rank)
|
||||||
|
return op.emitOpError("expects optional input_permutation map of rank ")
|
||||||
|
<< rank;
|
||||||
|
if (!inputPermutationMap->isPermutation())
|
||||||
|
return op.emitOpError(
|
||||||
|
"expects optional input_permutation map to be a permutation");
|
||||||
|
}
|
||||||
|
auto outputPermutationMap = op.outputPermutation();
|
||||||
|
if (outputPermutationMap) {
|
||||||
|
if (outputPermutationMap->getNumInputs() != rank)
|
||||||
|
return op.emitOpError("expects optional output_permutation map of rank ")
|
||||||
|
<< rank;
|
||||||
|
if (!outputPermutationMap->isPermutation())
|
||||||
|
return op.emitOpError(
|
||||||
|
"expects optional output_permutation map to be a permutation");
|
||||||
|
}
|
||||||
|
if (rank == 0 && inputPermutationMap)
|
||||||
|
return op.emitOpError("expected no input permutation when rank == 0");
|
||||||
|
if (rank == 0 && outputPermutationMap)
|
||||||
|
return op.emitOpError("expected no output permutation when rank == 0");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
template <typename LinalgPoolingOp>
|
template <typename LinalgPoolingOp>
|
||||||
static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
|
static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
|
||||||
ArrayRef<Attribute> attrs,
|
ArrayRef<Attribute> attrs,
|
||||||
|
@ -1788,25 +1708,14 @@ OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Support for named Linalg ops defined in ods-gen.
|
// Auto-generated Linalg named ops.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Generic entry point to create the block for the region of a LinalgOp.
|
|
||||||
/// This is used by both named structured ops created by ods-gen and by manually
|
|
||||||
/// defined C++ ops.
|
|
||||||
/// This is used by both builders and parsers.
|
|
||||||
/// This function creates the block in the region with arguments corresponding
|
|
||||||
/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
|
|
||||||
/// to be ShapedType.
|
|
||||||
template <typename NamedStructuredOpType>
|
template <typename NamedStructuredOpType>
|
||||||
static void
|
static void buildNamedStructuredOpRegionAndAttributesImpl(
|
||||||
fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
|
OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
|
||||||
TypeRange inputTypes, TypeRange outputTypes,
|
TypeRange outputTypes,
|
||||||
ValueRange captures,
|
std::function<void(unsigned, unsigned)> errorHandler) {
|
||||||
std::function<void(unsigned, unsigned)> errorHandler) {
|
|
||||||
assert(llvm::all_of(inputTypes, [](Type t) { return t.isa<ShapedType>(); }));
|
|
||||||
assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
|
|
||||||
|
|
||||||
// TODO: atm all operands go through getElementTypeOrSelf,
|
// TODO: atm all operands go through getElementTypeOrSelf,
|
||||||
// reconsider when we have evidence we need to.
|
// reconsider when we have evidence we need to.
|
||||||
SmallVector<Type, 8> argTypes;
|
SmallVector<Type, 8> argTypes;
|
||||||
|
@ -1816,7 +1725,7 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
|
||||||
|
|
||||||
// RAII.
|
// RAII.
|
||||||
OpBuilder::InsertionGuard guard(opBuilder);
|
OpBuilder::InsertionGuard guard(opBuilder);
|
||||||
Block *body = opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes);
|
Block *body = opBuilder.createBlock(&region, {}, argTypes);
|
||||||
unsigned actual = body->getNumArguments();
|
unsigned actual = body->getNumArguments();
|
||||||
unsigned expected = NamedStructuredOpType::getNumRegionArgs();
|
unsigned expected = NamedStructuredOpType::getNumRegionArgs();
|
||||||
if (expected != actual)
|
if (expected != actual)
|
||||||
|
@ -1824,30 +1733,53 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
|
||||||
|
|
||||||
opBuilder.setInsertionPointToStart(body);
|
opBuilder.setInsertionPointToStart(body);
|
||||||
mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
|
mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
|
||||||
NamedStructuredOpType::regionBuilder(*body, captures);
|
NamedStructuredOpType::regionBuilder(*body);
|
||||||
|
|
||||||
// indexing_maps is an auto-generated method.
|
// indexing_maps is an auto-generated method.
|
||||||
|
|
||||||
// iterator_types is an auto-generated method.
|
// iterator_types is an auto-generated method.
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generic entry point to create both the region and the block of a LinalgOp.
|
|
||||||
template <typename NamedStructuredOpType>
|
template <typename NamedStructuredOpType>
|
||||||
void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
|
void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
|
||||||
OperationState &result,
|
OperationState &result,
|
||||||
TypeRange inputTypes,
|
TypeRange inputTypes,
|
||||||
TypeRange outputTypes,
|
TypeRange outputTypes) {
|
||||||
ValueRange captures) {
|
|
||||||
Region &region = *result.addRegion();
|
Region &region = *result.addRegion();
|
||||||
fillStructuredOpRegion<NamedStructuredOpType>(
|
buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
|
||||||
opBuilder, region, inputTypes, outputTypes, captures,
|
opBuilder, region, inputTypes, outputTypes,
|
||||||
[&](unsigned expected, unsigned actual) {
|
[&](unsigned expected, unsigned actual) {
|
||||||
|
llvm::errs() << "region expects " << expected << " args, got "
|
||||||
|
<< actual;
|
||||||
assert(expected != actual && "incorrect number of arguments");
|
assert(expected != actual && "incorrect number of arguments");
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Common parsing used for both named structured ops created by ods-gen and by
|
template <typename NamedStructuredOpType>
|
||||||
/// manually defined C++ ops. Does not handle regions.
|
static ParseResult
|
||||||
|
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
|
||||||
|
TypeRange inputTypes, TypeRange outputTypes) {
|
||||||
|
ParseResult res = success();
|
||||||
|
OpBuilder opBuilder(parser.getBuilder().getContext());
|
||||||
|
buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
|
||||||
|
opBuilder, region, inputTypes, outputTypes,
|
||||||
|
[&](unsigned expected, unsigned actual) {
|
||||||
|
res = parser.emitError(parser.getCurrentLocation(),
|
||||||
|
llvm::formatv("region expects {0} args, got {1}",
|
||||||
|
expected, actual));
|
||||||
|
});
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult
|
||||||
|
parseNamedStructuredOpResults(OpAsmParser &parser,
|
||||||
|
SmallVectorImpl<Type> &resultTypes) {
|
||||||
|
if (succeeded(parser.parseOptionalArrow()))
|
||||||
|
if (parser.parseTypeList(resultTypes))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
static ParseResult
|
static ParseResult
|
||||||
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
|
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
|
||||||
SmallVectorImpl<Type> &inputTypes,
|
SmallVectorImpl<Type> &inputTypes,
|
||||||
|
@ -1888,56 +1820,8 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NamedStructuredOpType>
|
template <typename NamedStructuredOpType>
|
||||||
static void printCommonStructuredOpParts(OpAsmPrinter &p,
|
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
||||||
NamedStructuredOpType op) {
|
OperationState &result) {
|
||||||
if (!op.inputs().empty())
|
|
||||||
p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
|
|
||||||
if (!op.outputs().empty())
|
|
||||||
p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Specific parsing and printing for named structured ops created by ods-gen.
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
template <typename NamedStructuredOpType>
|
|
||||||
static ParseResult
|
|
||||||
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
|
|
||||||
TypeRange inputTypes, TypeRange outputTypes,
|
|
||||||
ArrayRef<OpAsmParser::OperandType> captures) {
|
|
||||||
ParseResult res = success();
|
|
||||||
OpBuilder opBuilder(parser.getBuilder().getContext());
|
|
||||||
// Resolve `captures` into `capturedValues` at parse time so we can build the
|
|
||||||
// region with captures.
|
|
||||||
SmallVector<Value> capturedValues;
|
|
||||||
fillStructuredOpRegion<NamedStructuredOpType>(
|
|
||||||
opBuilder, region, inputTypes, outputTypes, capturedValues,
|
|
||||||
[&](unsigned expected, unsigned actual) {
|
|
||||||
res = parser.emitError(
|
|
||||||
parser.getCurrentLocation(),
|
|
||||||
llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
|
|
||||||
"region expects {0} args, got {1}",
|
|
||||||
expected, actual));
|
|
||||||
region.front().dump();
|
|
||||||
});
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
static ParseResult
|
|
||||||
parseNamedStructuredOpResults(OpAsmParser &parser,
|
|
||||||
SmallVectorImpl<Type> &resultTypes) {
|
|
||||||
if (succeeded(parser.parseOptionalArrow()))
|
|
||||||
if (parser.parseTypeList(resultTypes))
|
|
||||||
return failure();
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename NamedStructuredOpType>
|
|
||||||
static ParseResult
|
|
||||||
parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
|
|
||||||
ArrayRef<OpAsmParser::OperandType> captures) {
|
|
||||||
// TODO: Enable when ods-gen supports captures.
|
|
||||||
assert(captures.empty() && "unexpected captures for named structured ops");
|
|
||||||
SmallVector<Type, 1> inputTypes, outputTypes;
|
SmallVector<Type, 1> inputTypes, outputTypes;
|
||||||
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
|
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -1951,7 +1835,7 @@ parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
|
||||||
|
|
||||||
std::unique_ptr<Region> region = std::make_unique<Region>();
|
std::unique_ptr<Region> region = std::make_unique<Region>();
|
||||||
if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
|
if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
|
||||||
parser, *region, inputTypes, outputTypes, captures))
|
parser, *region, inputTypes, outputTypes))
|
||||||
return failure();
|
return failure();
|
||||||
result.addRegion(std::move(region));
|
result.addRegion(std::move(region));
|
||||||
|
|
||||||
|
@ -1965,6 +1849,15 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
|
||||||
p.printOptionalArrowTypeList(resultTypes);
|
p.printOptionalArrowTypeList(resultTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename NamedStructuredOpType>
|
||||||
|
static void printCommonStructuredOpParts(OpAsmPrinter &p,
|
||||||
|
NamedStructuredOpType op) {
|
||||||
|
if (!op.inputs().empty())
|
||||||
|
p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
|
||||||
|
if (!op.outputs().empty())
|
||||||
|
p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
|
||||||
|
}
|
||||||
|
|
||||||
template <typename NamedStructuredOpType>
|
template <typename NamedStructuredOpType>
|
||||||
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
|
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
|
||||||
p << op.getOperationName();
|
p << op.getOperationName();
|
||||||
|
@ -1986,10 +1879,6 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
|
||||||
return verifyGenericOp<NamedStructuredOpType>(op);
|
return verifyGenericOp<NamedStructuredOpType>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Canonicalizers and Folders.
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
struct EraseDeadLinalgOp : public RewritePattern {
|
struct EraseDeadLinalgOp : public RewritePattern {
|
||||||
EraseDeadLinalgOp(PatternBenefit benefit = 1)
|
EraseDeadLinalgOp(PatternBenefit benefit = 1)
|
||||||
|
|
|
@ -49,7 +49,7 @@ static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
|
||||||
indexingMaps, iterators,
|
indexingMaps, iterators,
|
||||||
[&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
|
[&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
|
||||||
edsc::ScopedContext scope(bodyBuilder, loc);
|
edsc::ScopedContext scope(bodyBuilder, loc);
|
||||||
regionBuilder(*bodyBuilder.getBlock(), /*captures=*/{});
|
regionBuilder(*bodyBuilder.getBlock());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -52,6 +52,14 @@ static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
|
||||||
|
Optional<AffineMap> permutation) {
|
||||||
|
return permutation ? applyMapToValues(ScopedContext::getBuilderRef(),
|
||||||
|
ScopedContext::getLocation(),
|
||||||
|
permutation.getValue(), ivs)
|
||||||
|
: SmallVector<Value, 4>(ivs.begin(), ivs.end());
|
||||||
|
}
|
||||||
|
|
||||||
template <typename IndexedValueType, typename OpType>
|
template <typename IndexedValueType, typename OpType>
|
||||||
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
|
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
|
||||||
ArrayRef<SmallVector<Value, 8>> indexing,
|
ArrayRef<SmallVector<Value, 8>> indexing,
|
||||||
|
@ -170,6 +178,40 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
|
||||||
outputBuffers);
|
outputBuffers);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename IndexedValueType>
|
||||||
|
static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
|
||||||
|
assert(copyOp.hasBufferSemantics() &&
|
||||||
|
"expected linalg op with buffer semantics");
|
||||||
|
auto nPar = copyOp.getNumParallelLoops();
|
||||||
|
assert(nPar == allIvs.size());
|
||||||
|
auto inputIvs =
|
||||||
|
permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
|
||||||
|
auto outputIvs =
|
||||||
|
permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
|
||||||
|
SmallVector<Value, 8> iivs(inputIvs.begin(), inputIvs.end());
|
||||||
|
SmallVector<Value, 8> oivs(outputIvs.begin(), outputIvs.end());
|
||||||
|
IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
|
||||||
|
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
|
||||||
|
// an n-D loop nest; with or without permutations.
|
||||||
|
// clang-format off
|
||||||
|
nPar > 0 ? O(oivs) = I(iivs) :
|
||||||
|
O() = I();
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IndexedValueType>
|
||||||
|
static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
|
||||||
|
assert(fillOp.hasBufferSemantics() &&
|
||||||
|
"expected linalg op with buffer semantics");
|
||||||
|
auto nPar = fillOp.getNumParallelLoops();
|
||||||
|
assert(nPar == allIvs.size());
|
||||||
|
auto ivs = SmallVector<Value, 4>(allIvs.begin(), allIvs.begin() + nPar);
|
||||||
|
IndexedValueType O(fillOp.getOutputBuffer(0));
|
||||||
|
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
|
||||||
|
// an n-D loop nest; with or without permutations.
|
||||||
|
nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
|
||||||
|
}
|
||||||
|
|
||||||
// Create a padded view into the given `input` tensor using the 'indices'
|
// Create a padded view into the given `input` tensor using the 'indices'
|
||||||
// to access the tensor. `skipPadding` lists the dimensions for which no padding
|
// to access the tensor. `skipPadding` lists the dimensions for which no padding
|
||||||
// is needed e.g. the non-spatial dimensions for convolutions.
|
// is needed e.g. the non-spatial dimensions for convolutions.
|
||||||
|
@ -491,8 +533,8 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
|
||||||
assert(iterArgs.empty() && "unexpected iterArgs");
|
assert(iterArgs.empty() && "unexpected iterArgs");
|
||||||
allIvs.append(ivs.begin(), ivs.end());
|
allIvs.append(ivs.begin(), ivs.end());
|
||||||
llvm::TypeSwitch<Operation *>(op)
|
llvm::TypeSwitch<Operation *>(op)
|
||||||
.Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
|
.Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
|
||||||
IndexedGenericOp, LinalgOp>([&](auto op) {
|
PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
|
||||||
emitScalarImplementation<IndexedValueTy>(allIvs, op);
|
emitScalarImplementation<IndexedValueTy>(allIvs, op);
|
||||||
})
|
})
|
||||||
.Default([&](Operation *op) { assert(false && "unexpected op"); });
|
.Default([&](Operation *op) { assert(false && "unexpected op"); });
|
||||||
|
|
|
@ -267,7 +267,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
|
||||||
llvm::map_range(linalgOp.getShapedOperandTypes(),
|
llvm::map_range(linalgOp.getShapedOperandTypes(),
|
||||||
[](ShapedType t) { return t.getElementType(); }));
|
[](ShapedType t) { return t.getElementType(); }));
|
||||||
block->addArguments(elementTypes);
|
block->addArguments(elementTypes);
|
||||||
linalgOp.getRegionBuilder()(*block, /*captures=*/{});
|
linalgOp.getRegionBuilder()(*block);
|
||||||
}
|
}
|
||||||
Block *block = &region->front();
|
Block *block = &region->front();
|
||||||
|
|
||||||
|
@ -333,26 +333,24 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
|
||||||
|
|
||||||
// Return true if the op is an element-wise linalg op.
|
// Return true if the op is an element-wise linalg op.
|
||||||
static bool isElementwise(Operation *op) {
|
static bool isElementwise(Operation *op) {
|
||||||
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
|
auto genericOp = dyn_cast<linalg::GenericOp>(op);
|
||||||
if (!linalgOp)
|
if (!genericOp)
|
||||||
return false;
|
return false;
|
||||||
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
|
if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
|
||||||
return false;
|
return false;
|
||||||
// TODO: relax the restrictions on indexing map.
|
// TODO: relax the restrictions on indexing map.
|
||||||
for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
|
for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
|
||||||
if (!linalgOp.getOutputIndexingMap(i).isIdentity())
|
if (!genericOp.getOutputIndexingMap(i).isIdentity())
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Currently bound the input indexing map to minor identity as other
|
// Currently bound the input indexing map to minor identity as other
|
||||||
// permutations might require adding transpose ops to convert the vector read
|
// permutations might require adding transpose ops to convert the vector read
|
||||||
// to the right shape.
|
// to the right shape.
|
||||||
for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
|
for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
|
||||||
if (!linalgOp.getInputIndexingMap(i).isMinorIdentity())
|
if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (linalgOp->getNumRegions() != 1)
|
return hasOnlyScalarElementwiseOp(genericOp.getRegion());
|
||||||
return false;
|
|
||||||
return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
|
static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
|
||||||
|
@ -395,6 +393,9 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
|
||||||
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
|
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
|
||||||
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
|
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
if (isa<linalg::FillOp, linalg::CopyOp>(op))
|
||||||
|
return success();
|
||||||
if (isElementwise(op))
|
if (isElementwise(op))
|
||||||
return success();
|
return success();
|
||||||
return success(isaContractionOpInterface(linalgOp));
|
return success(isaContractionOpInterface(linalgOp));
|
||||||
|
@ -406,12 +407,43 @@ Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
|
|
||||||
edsc::ScopedContext scope(builder, op->getLoc());
|
edsc::ScopedContext scope(builder, op->getLoc());
|
||||||
|
// In the case of 0-D memrefs, return null and special case to scalar load or
|
||||||
|
// store later.
|
||||||
|
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
|
||||||
|
// Vectorize fill as a vector.broadcast.
|
||||||
|
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
|
||||||
|
<< "Rewrite linalg.fill as vector.broadcast: " << *op);
|
||||||
|
VectorizedLinalgOp res;
|
||||||
|
if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output()))
|
||||||
|
res.tensorResults.push_back(v);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
|
||||||
|
// Vectorize copy as a vector.transfer_read+vector.transfer_write.
|
||||||
|
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
|
||||||
|
<< "Rewrite linalg.copy as vector.transfer_read + "
|
||||||
|
"vector.transfer_write: "
|
||||||
|
<< *op);
|
||||||
|
Value vector = buildVectorRead(builder, copyOp.input());
|
||||||
|
VectorizedLinalgOp res;
|
||||||
|
if (Value v = buildVectorWrite(builder, vector, copyOp.output()))
|
||||||
|
res.tensorResults.push_back(v);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
if (isElementwise(op)) {
|
if (isElementwise(op)) {
|
||||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
|
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
|
||||||
<< "Vectorize linalg op as a generic: " << *op);
|
<< "Vectorize linalg op as a generic: " << *op);
|
||||||
return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
|
return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: as soon as Copy and FillOp. get a region builder, replace all the
|
||||||
|
// above by:
|
||||||
|
// if (isa<FillOp, CopyOp>(op) || isElementwise(op)) {
|
||||||
|
// LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
|
||||||
|
// << "Vectorize linalg op as a generic: " << *op);
|
||||||
|
// return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
|
||||||
|
// }
|
||||||
|
|
||||||
return vectorizeContraction(builder, cast<LinalgOp>(op));
|
return vectorizeContraction(builder, cast<LinalgOp>(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
// RUN: mlir-opt -copy-removal -split-input-file %s | FileCheck %s
|
// RUN: mlir-opt -copy-removal -split-input-file %s
|
||||||
|
//| FileCheck %s
|
||||||
|
|
||||||
// All linalg copies except the linalg.copy(%1, %9) must be removed since the
|
// All linalg copies except the linalg.copy(%1, %9) must be removed since the
|
||||||
// defining operation of %1 and its DeallocOp have been defined in another block.
|
// defining operation of %1 and its DeallocOp have been defined in another block.
|
||||||
|
@ -255,7 +256,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>, %result: memref<2xf32>)
|
||||||
%tmp2 = math.exp %gen2_arg0 : f32
|
%tmp2 = math.exp %gen2_arg0 : f32
|
||||||
linalg.yield %tmp2 : f32
|
linalg.yield %tmp2 : f32
|
||||||
}
|
}
|
||||||
linalg.copy(%temp, %result) : memref<2xf32>, memref<2xf32>
|
"linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||||
dealloc %temp : memref<2xf32>
|
dealloc %temp : memref<2xf32>
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
return
|
return
|
||||||
|
@ -291,7 +292,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){
|
||||||
linalg.yield %tmp2 : f32
|
linalg.yield %tmp2 : f32
|
||||||
}
|
}
|
||||||
// CHECK: linalg.copy
|
// CHECK: linalg.copy
|
||||||
linalg.copy(%temp, %to) : memref<2xf32>, memref<2xf32>
|
"linalg.copy"(%temp, %to) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||||
dealloc %temp : memref<2xf32>
|
dealloc %temp : memref<2xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -354,7 +355,7 @@ func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg
|
||||||
}
|
}
|
||||||
// CHECK-NOT: linalg.copy
|
// CHECK-NOT: linalg.copy
|
||||||
// CHECK-NOT: dealloc
|
// CHECK-NOT: dealloc
|
||||||
linalg.copy(%0, %arg2) : memref<4xf32>, memref<4xf32>
|
"linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||||
dealloc %0 : memref<4xf32>
|
dealloc %0 : memref<4xf32>
|
||||||
//CHECK: return
|
//CHECK: return
|
||||||
return
|
return
|
||||||
|
|
|
@ -23,7 +23,7 @@
|
||||||
// IMPL-NEXT: map2 = simplifyAffineMap(map2);
|
// IMPL-NEXT: map2 = simplifyAffineMap(map2);
|
||||||
// IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 });
|
// IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 });
|
||||||
//
|
//
|
||||||
// IMPL: void Test1Op::regionBuilder(Block &block, ValueRange captures) {
|
// IMPL: void Test1Op::regionBuilder(Block &block) {
|
||||||
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
||||||
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
|
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
|
||||||
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
|
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
|
||||||
|
@ -47,7 +47,7 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
|
||||||
// IMPL: AffineMap::get(3, 3, {d2, d1}, context)
|
// IMPL: AffineMap::get(3, 3, {d2, d1}, context)
|
||||||
// IMPL: AffineMap::get(3, 3, {d0, d1}, context)
|
// IMPL: AffineMap::get(3, 3, {d0, d1}, context)
|
||||||
//
|
//
|
||||||
// IMPL: Test2Op::regionBuilder(Block &block, ValueRange captures) {
|
// IMPL: Test2Op::regionBuilder(Block &block) {
|
||||||
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
||||||
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
|
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
|
||||||
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
|
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
|
||||||
|
@ -71,7 +71,7 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
|
||||||
// IMPL: AffineMap::get(4, 4, {d3, d2}, context)
|
// IMPL: AffineMap::get(4, 4, {d3, d2}, context)
|
||||||
// IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context)
|
// IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context)
|
||||||
//
|
//
|
||||||
// IMPL: Test3Op::regionBuilder(Block &block, ValueRange captures) {
|
// IMPL: Test3Op::regionBuilder(Block &block) {
|
||||||
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
||||||
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
|
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
|
||||||
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
|
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
|
||||||
|
|
|
@ -1871,11 +1871,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
||||||
$_builder.getI32VectorAttr({{
|
$_builder.getI32VectorAttr({{
|
||||||
static_cast<int32_t>(inputs.size()),
|
static_cast<int32_t>(inputs.size()),
|
||||||
static_cast<int32_t>(outputs.size())}));
|
static_cast<int32_t>(outputs.size())}));
|
||||||
createAndFillStructuredOpRegion<{0}>(
|
buildNamedStructuredOpRegionAndAttributes<{0}>(
|
||||||
$_builder,
|
$_builder,
|
||||||
$_state,
|
$_state,
|
||||||
TypeRange(inputs),
|
TypeRange(inputs),
|
||||||
TypeRange(outputs)/*, TODO: support captures*/);
|
TypeRange(outputs));
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilderDAG<
|
OpBuilderDAG<
|
||||||
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||||
|
@ -1889,11 +1889,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
||||||
$_builder.getI32VectorAttr({{
|
$_builder.getI32VectorAttr({{
|
||||||
static_cast<int32_t>(inputs.size()),
|
static_cast<int32_t>(inputs.size()),
|
||||||
static_cast<int32_t>(outputs.size())}));
|
static_cast<int32_t>(outputs.size())}));
|
||||||
createAndFillStructuredOpRegion<{0}>(
|
buildNamedStructuredOpRegionAndAttributes<{0}>(
|
||||||
$_builder,
|
$_builder,
|
||||||
$_state,
|
$_state,
|
||||||
TypeRange(inputs),
|
TypeRange(inputs),
|
||||||
TypeRange(outputs)/*, TODO: support captures*/);
|
TypeRange(outputs));
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilderDAG<
|
OpBuilderDAG<
|
||||||
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
|
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
|
||||||
|
@ -1907,9 +1907,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
||||||
{6}
|
{6}
|
||||||
];
|
];
|
||||||
let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
|
let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
|
||||||
let parser = [{{
|
let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
|
||||||
return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
|
|
||||||
}];
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
|
|
||||||
|
@ -1917,8 +1915,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
||||||
// Auto-generated.
|
// Auto-generated.
|
||||||
ArrayAttr iterator_types();
|
ArrayAttr iterator_types();
|
||||||
ArrayAttr indexing_maps();
|
ArrayAttr indexing_maps();
|
||||||
static void regionBuilder(Block &block, ValueRange captures);
|
static void regionBuilder(Block &block);
|
||||||
static std::function<void(Block &, ValueRange)> getRegionBuilder() {{
|
static std::function<void(Block &)> getRegionBuilder() {{
|
||||||
return regionBuilder;
|
return regionBuilder;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1982,11 +1980,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
||||||
$_builder.getI32VectorAttr({{
|
$_builder.getI32VectorAttr({{
|
||||||
static_cast<int32_t>(inputs.size()),
|
static_cast<int32_t>(inputs.size()),
|
||||||
static_cast<int32_t>(outputs.size())}));
|
static_cast<int32_t>(outputs.size())}));
|
||||||
createAndFillStructuredOpRegion<{0}>(
|
buildNamedStructuredOpRegionAndAttributes<{0}>(
|
||||||
$_builder,
|
$_builder,
|
||||||
$_state,
|
$_state,
|
||||||
TypeRange(inputs),
|
TypeRange(inputs),
|
||||||
TypeRange(outputs)/*, TODO: support captures*/);
|
TypeRange(outputs));
|
||||||
{2}
|
{2}
|
||||||
}]>
|
}]>
|
||||||
)FMT";
|
)FMT";
|
||||||
|
@ -2313,7 +2311,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
|
||||||
};
|
};
|
||||||
|
|
||||||
const char *regionBuilderFmt = R"FMT(
|
const char *regionBuilderFmt = R"FMT(
|
||||||
void {0}::regionBuilder(Block &block, ValueRange captures) {
|
void {0}::regionBuilder(Block &block) {
|
||||||
using namespace edsc;
|
using namespace edsc;
|
||||||
using namespace intrinsics;
|
using namespace intrinsics;
|
||||||
auto args = block.getArguments();
|
auto args = block.getArguments();
|
||||||
|
|
Loading…
Reference in New Issue