[mlir] Add custom directive hooks for printing mixed integer or value operands.

Add printer and parser hooks for a custom directive that allows
parsing and printing of idioms that can represent a list of values
each of which is either an integer or an SSA value. For example in

`subview %source[%offset_0, 1] [4, %size_1] [%stride_0, 3]`

each of the list (which represents offset, size and strides) is a mix
of either statically know integer values or dynamically computed SSA
values. Since this is used in many places adding a custom directive to
parse/print this idiom allows using assembly format on operations
which use this idiom.

Differential Revision: https://reviews.llvm.org/D95773
This commit is contained in:
MaheshRavishankar 2021-02-01 19:03:12 -08:00
parent 87f8a08ce3
commit 342d4662e1
8 changed files with 150 additions and 418 deletions

View File

@ -45,6 +45,11 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
let results = (outs AnyTensor:$result);
let assemblyFormat = [{
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) attr-dict
`:` type($result)
}];
let verifier = [{ return ::verify(*this); }];
let extraClassDeclaration = [{
@ -118,7 +123,7 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
}
def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
[AttrSizedOperandSegments, SingleBlockImplicitTerminator<"YieldOp">]> {
[AttrSizedOperandSegments]> {
let summary = "tensor pad operation";
let description = [{
`linalg.pad_tensor` is an operation that pads the `source` tensor
@ -181,10 +186,16 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
I64ArrayAttr:$static_low,
I64ArrayAttr:$static_high);
let regions = (region AnyRegion:$region);
let regions = (region SizedRegion<1>:$region);
let results = (outs AnyTensor:$result);
let assemblyFormat = [{
$source `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
`high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
$region attr-dict `:` type($source) `to` type($result)
}];
let extraClassDeclaration = [{
static StringRef getStaticLowAttrName() {
return "static_low";

View File

@ -1956,6 +1956,19 @@ def MemRefReinterpretCastOp:
);
let results = (outs AnyMemRef:$result);
let assemblyFormat = [{
$source `to` `offset` `` `:`
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
`` `,` `sizes` `` `:`
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) `` `,` `strides`
`` `:`
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
attr-dict `:` type($source) `to` type($result)
}];
let parser=?;
let printer=?;
let builders = [
// Build a ReinterpretCastOp with mixed static and dynamic entries.
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source,
@ -2931,6 +2944,14 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
);
let results = (outs AnyMemRef:$result);
let assemblyFormat = [{
$source ``
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
attr-dict `:` type($source) `to` type($result)
}];
let builders = [
// Build a SubViewOp with mixed static and dynamic entries and custom
// result type. If the type passed is nullptr, it is inferred.
@ -3053,6 +3074,14 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source ``
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
attr-dict `:` type($source) `to` type($result)
}];
let builders = [
// Build a SubTensorOp with mixed static and dynamic entries and inferred
// result type.
@ -3115,7 +3144,10 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
//===----------------------------------------------------------------------===//
def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
"subtensor_insert", [OffsetSizeAndStrideOpInterface]> {
"subtensor_insert",
[OffsetSizeAndStrideOpInterface,
TypesMatchWith<"expected result type to match dest type",
"dest", "result", "$_self">]> {
let summary = "subtensor_insert operation";
let description = [{
The "subtensor_insert" operation insert a tensor `source` into another
@ -3159,6 +3191,16 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source `into` $dest ``
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
attr-dict `:` type($source) `into` type($dest)
}];
let verifier = ?;
let builders = [
// Build a SubTensorInsertOp with mixed static and dynamic entries.
OpBuilderDAG<(ins "Value":$source, "Value":$dest,

View File

@ -36,78 +36,68 @@ LogicalResult verify(OffsetSizeAndStrideOpInterface op);
#include "mlir/Interfaces/ViewLikeInterface.h.inc"
namespace mlir {
/// Print a list with either (1) the static integer value in `arrayAttr` if
/// `isDynamic` evaluates to false or (2) the next value otherwise.
/// This allows idiomatic printing of mixed value and integer attributes in a
/// Printer hook for custom directive in assemblyFormat.
///
/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
///
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with
/// either (1) the static integer value in `integers` if the value is
/// ShapedType::kDynamicStrideOrOffset or (2) the next value otherwise. This
/// allows idiomatic printing of mixed value and integer attributes in a
/// list. E.g. `[%arg0, 7, 42, %arg42]`.
void printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
ArrayAttr arrayAttr,
llvm::function_ref<bool(int64_t)> isDynamic);
void printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &printer,
Operation *op,
OperandRange values,
ArrayAttr integers);
/// Print part of an op of the form:
/// ```
/// <optional-offset-prefix>`[` offset-list `]`
/// <optional-size-prefix>`[` size-list `]`
/// <optional-stride-prefix>[` stride-list `]`
/// ```
void printOffsetsSizesAndStrides(
OpAsmPrinter &p, OffsetSizeAndStrideOpInterface op,
StringRef offsetPrefix = "", StringRef sizePrefix = " ",
StringRef stridePrefix = " ",
ArrayRef<StringRef> elidedAttrs =
OffsetSizeAndStrideOpInterface::getSpecialAttrNames());
/// Printer hook for custom directive in assemblyFormat.
///
/// custom<OperandsOrIntegersSizesList>($values, $integers)
///
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with
/// either (1) the static integer value in `integers` if the value is
/// ShapedType::kDynamicSize or (2) the next value otherwise. This
/// allows idiomatic printing of mixed value and integer attributes in a
/// list. E.g. `[%arg0, 7, 42, %arg42]`.
void printOperandsOrIntegersSizesList(OpAsmPrinter &printer, Operation *op,
OperandRange values, ArrayAttr integers);
/// Parse a mixed list with either (1) static integer values or (2) SSA values.
/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
/// encode the position of SSA values. Add the parsed SSA values to `ssa`
/// in-order.
/// Pasrer hook for custom directive in assemblyFormat.
///
/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
///
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with
/// either (1) static integer values or (2) SSA values. Fill `integers` with
/// the integer ArrayAttr, where ShapedType::kDynamicStrideOrOffset encodes the
/// position of SSA values. Add the parsed SSA values to `values` in-order.
//
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
/// 2. `ssa` is filled with "[%arg0, %arg1]".
ParseResult
parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
StringRef attrName, int64_t dynVal,
SmallVectorImpl<OpAsmParser::OperandType> &ssa);
ParseResult parseOperandsOrIntegersOffsetsOrStridesList(
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
ArrayAttr &integers);
/// Parse trailing part of an op of the form:
/// ```
/// <optional-offset-prefix>`[` offset-list `]`
/// <optional-size-prefix>`[` size-list `]`
/// <optional-stride-prefix>[` stride-list `]`
/// ```
/// Each entry in the offset, size and stride list either resolves to an integer
/// constant or an operand of index type.
/// Constants are added to the `result` as named integer array attributes with
/// name `OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName()` (resp.
/// `getStaticSizesAttrName()`, `getStaticStridesAttrName()`).
/// Pasrer hook for custom directive in assemblyFormat.
///
/// Append the number of offset, size and stride operands to `segmentSizes`
/// before adding it to `result` as the named attribute:
/// `OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr()`.
/// custom<OperandsOrIntegersSizesList>($values, $integers)
///
/// Offset, size and stride operands resolution occurs after `preResolutionFn`
/// to give a chance to leading operands to resolve first, after parsing the
/// types.
ParseResult parseOffsetsSizesAndStrides(
OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
preResolutionFn = nullptr,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix =
nullptr,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix =
nullptr,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix =
nullptr);
/// `preResolutionFn`-less version of `parseOffsetsSizesAndStrides`.
ParseResult parseOffsetsSizesAndStrides(
OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix =
nullptr,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix =
nullptr,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix =
nullptr);
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with
/// either (1) static integer values or (2) SSA values. Fill `integers` with
/// the integer ArrayAttr, where ShapedType::kDynamicSize encodes the
/// position of SSA values. Add the parsed SSA values to `values` in-order.
//
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
/// 2. `ssa` is filled with "[%arg0, %arg1]".
ParseResult parseOperandsOrIntegersSizesList(
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
ArrayAttr &integers);
/// Verify that a the `values` has as many elements as the number of entries in
/// `attr` for which `isDynamic` evaluates to true.

View File

@ -676,30 +676,6 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
// InitTensorOp
//===----------------------------------------------------------------------===//
static ParseResult parseInitTensorOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType srcInfo;
Type dstType;
SmallVector<OpAsmParser::OperandType, 2> sizeInfo;
IndexType indexType = parser.getBuilder().getIndexType();
if (failed(parseListOfOperandsOrIntegers(
parser, result, InitTensorOp::getStaticSizesAttrName(),
ShapedType::kDynamicSize, sizeInfo)) ||
failed(parser.parseOptionalAttrDict(result.attributes)) ||
failed(parser.parseColonType(dstType)) ||
failed(parser.resolveOperands(sizeInfo, indexType, result.operands)))
return failure();
return parser.addTypeToList(dstType, result.types);
}
static void print(OpAsmPrinter &p, InitTensorOp op) {
p << op.getOperation()->getName() << ' ';
printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
ShapedType::isDynamic);
p.printOptionalAttrDict(op.getAttrs(),
InitTensorOp::getStaticSizesAttrName());
p << " : " << op.getType();
}
static LogicalResult verify(InitTensorOp op) {
RankedTensorType resultType = op.getType();
@ -981,8 +957,6 @@ static LogicalResult verify(PadTensorOp op) {
}
auto &region = op.region();
if (!llvm::hasSingleElement(region))
return op.emitOpError("expected region with 1 block");
unsigned rank = resultType.getRank();
Block &block = region.front();
if (block.getNumArguments() != rank)
@ -1020,67 +994,6 @@ RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
return RankedTensorType::get(resultShape, sourceType.getElementType());
}
static ParseResult parsePadTensorOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType baseInfo;
SmallVector<OpAsmParser::OperandType, 8> operands;
SmallVector<Type, 8> types;
if (parser.parseOperand(baseInfo))
return failure();
IndexType indexType = parser.getBuilder().getIndexType();
SmallVector<OpAsmParser::OperandType, 4> lowPadding, highPadding;
if (parser.parseKeyword("low") ||
parseListOfOperandsOrIntegers(parser, result,
PadTensorOp::getStaticLowAttrName(),
ShapedType::kDynamicSize, lowPadding))
return failure();
if (parser.parseKeyword("high") ||
parseListOfOperandsOrIntegers(parser, result,
PadTensorOp::getStaticHighAttrName(),
ShapedType::kDynamicSize, highPadding))
return failure();
SmallVector<OpAsmParser::OperandType, 8> regionOperands;
std::unique_ptr<Region> region = std::make_unique<Region>();
SmallVector<Type, 8> operandTypes, regionTypes;
if (parser.parseRegion(*region, regionOperands, regionTypes))
return failure();
result.addRegion(std::move(region));
Type srcType, dstType;
if (parser.parseColonType(srcType) || parser.parseKeywordType("to", dstType))
return failure();
if (parser.addTypeToList(dstType, result.types))
return failure();
SmallVector<int, 4> segmentSizesFinal = {1}; // source tensor
segmentSizesFinal.append({static_cast<int>(lowPadding.size()),
static_cast<int>(highPadding.size())});
result.addAttribute(
OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
return failure(
parser.parseOptionalAttrDict(result.attributes) ||
parser.resolveOperand(baseInfo, srcType, result.operands) ||
parser.resolveOperands(lowPadding, indexType, result.operands) ||
parser.resolveOperands(highPadding, indexType, result.operands));
}
static void print(OpAsmPrinter &p, PadTensorOp op) {
p << op->getName().getStringRef() << ' ';
p << op.source();
p << " low";
printListOfOperandsOrIntegers(p, op.low(), op.static_low(),
ShapedType::isDynamic);
p << " high";
printListOfOperandsOrIntegers(p, op.high(), op.static_high(),
ShapedType::isDynamic);
p.printRegion(op.region());
p << " : " << op.source().getType() << " to " << op.getType();
}
/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
/// it is a Value or into `staticVec` if it is an IntegerAttr.
/// In the case of a Value, a copy of the `sentinel` value is also pushed to

View File

@ -2148,67 +2148,6 @@ void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}
/// Print a memref_reinterpret_cast op of the form:
/// ```
/// `memref_reinterpret_cast` ssa-name to
/// offset: `[` offset `]`
/// sizes: `[` size-list `]`
/// strides:`[` stride-list `]`
/// `:` any-memref-type to strided-memref-type
/// ```
static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) {
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
p << op.source() << " ";
printOffsetsSizesAndStrides(
p, op, /*offsetPrefix=*/"to offset: ", /*sizePrefix=*/", sizes: ",
/*stridePrefix=*/", strides: ");
p << ": " << op.source().getType() << " to " << op.getType();
}
/// Parse a memref_reinterpret_cast op of the form:
/// ```
/// `memref_reinterpret_cast` ssa-name to
/// offset: `[` offset `]`
/// sizes: `[` size-list `]`
/// strides:`[` stride-list `]`
/// `:` any-memref-type to strided-memref-type
/// ```
static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser,
OperationState &result) {
// Parse `operand`
OpAsmParser::OperandType srcInfo;
if (parser.parseOperand(srcInfo))
return failure();
auto parseOffsetPrefix = [](OpAsmParser &parser) {
return failure(parser.parseKeyword("to") || parser.parseKeyword("offset") ||
parser.parseColon());
};
auto parseSizePrefix = [](OpAsmParser &parser) {
return failure(parser.parseComma() || parser.parseKeyword("sizes") ||
parser.parseColon());
};
auto parseStridePrefix = [](OpAsmParser &parser) {
return failure(parser.parseComma() || parser.parseKeyword("strides") ||
parser.parseColon());
};
Type srcType, dstType;
auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
return failure(parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(srcType) ||
parser.parseKeywordType("to", dstType) ||
parser.resolveOperand(srcInfo, srcType, result.operands));
};
if (failed(parseOffsetsSizesAndStrides(parser, result,
/*segmentSizes=*/{1}, // source memref
preResolutionFn, parseOffsetPrefix,
parseSizePrefix, parseStridePrefix)))
return failure();
return parser.addTypeToList(dstType, result.types);
}
// TODO: ponder whether we want to allow missing trailing sizes/strides that are
// completed automatically, like we have for subview and subtensor.
static LogicalResult verify(MemRefReinterpretCastOp op) {
@ -2892,45 +2831,6 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
sourceMemRefType.getMemorySpace());
}
/// Print a subview op of the form:
/// ```
/// `subview` ssa-name
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
/// `:` strided-memref-type `to` strided-memref-type
/// ```
static void print(OpAsmPrinter &p, SubViewOp op) {
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
p << op.source();
printOffsetsSizesAndStrides(p, op);
p << " : " << op.source().getType() << " to " << op.getType();
}
/// Parse a subview op of the form:
/// ```
/// `subview` ssa-name
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
/// `:` strided-memref-type `to` strided-memref-type
/// ```
static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType srcInfo;
if (parser.parseOperand(srcInfo))
return failure();
Type srcType, dstType;
auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
return failure(parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(srcType) ||
parser.parseKeywordType("to", dstType) ||
parser.resolveOperand(srcInfo, srcType, result.operands));
};
if (failed(parseOffsetsSizesAndStrides(parser, result,
/*segmentSizes=*/{1}, // source memref
preResolutionFn)))
return failure();
return parser.addTypeToList(dstType, result.types);
}
// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
@ -3466,46 +3366,6 @@ OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
// SubTensorOp
//===----------------------------------------------------------------------===//
/// Print a subtensor op of the form:
/// ```
/// `subtensor` ssa-name
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
/// `:` ranked-tensor-type `to` ranked-tensor-type
/// ```
static void print(OpAsmPrinter &p, SubTensorOp op) {
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
p << op.source();
printOffsetsSizesAndStrides(p, op);
p << " : " << op.getSourceType() << " to " << op.getType();
}
/// Parse a subtensor op of the form:
/// ```
/// `subtensor` ssa-name
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
/// `:` ranked-tensor-type `to` ranked-tensor-type
/// ```
static ParseResult parseSubTensorOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType srcInfo;
if (parser.parseOperand(srcInfo))
return failure();
Type srcType, dstType;
auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
return failure(parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(srcType) ||
parser.parseKeywordType("to", dstType) ||
parser.resolveOperand(srcInfo, srcType, result.operands));
};
if (failed(parseOffsetsSizesAndStrides(parser, result,
/*segmentSizes=*/{1}, // source tensor
preResolutionFn)))
return failure();
return parser.addTypeToList(dstType, result.types);
}
/// A subtensor result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
@ -3612,49 +3472,6 @@ void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
// SubTensorInsertOp
//===----------------------------------------------------------------------===//
/// Print a subtensor_insert op of the form:
/// ```
/// `subtensor_insert` ssa-name `into` ssa-name
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
/// `:` ranked-tensor-type `into` ranked-tensor-type
/// ```
static void print(OpAsmPrinter &p, SubTensorInsertOp op) {
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
p << op.source() << " into " << op.dest();
printOffsetsSizesAndStrides(p, op);
p << " : " << op.getSourceType() << " into " << op.getType();
}
/// Parse a subtensor_insert op of the form:
/// ```
/// `subtensor_insert` ssa-name `into` ssa-name
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
/// `:` ranked-tensor-type `into` ranked-tensor-type
/// ```
static ParseResult parseSubTensorInsertOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType srcInfo, dstInfo;
if (parser.parseOperand(srcInfo) || parser.parseKeyword("into") ||
parser.parseOperand(dstInfo))
return failure();
Type srcType, dstType;
auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
return failure(parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(srcType) ||
parser.parseKeywordType("into", dstType) ||
parser.resolveOperand(srcInfo, srcType, result.operands) ||
parser.resolveOperand(dstInfo, dstType, result.operands));
};
if (failed(parseOffsetsSizesAndStrides(
parser, result,
/*segmentSizes=*/{1, 1}, // source tensor, destination tensor
preResolutionFn)))
return failure();
return parser.addTypeToList(dstType, result.types);
}
// Build a SubTensorInsertOp with mixed static and dynamic entries.
void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest,
@ -3691,13 +3508,6 @@ void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
/// Verifier for SubViewOp.
static LogicalResult verify(SubTensorInsertOp op) {
if (op.getType() != op.dest().getType())
return op.emitError("expected result type to be ") << op.dest().getType();
return success();
}
//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

View File

@ -69,14 +69,18 @@ LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
return success();
}
void mlir::printListOfOperandsOrIntegers(
OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
llvm::function_ref<bool(int64_t)> isDynamic) {
template <int64_t dynVal>
static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values,
ArrayAttr arrayAttr) {
p << '[';
if (arrayAttr.empty()) {
p << "]";
return;
}
unsigned idx = 0;
llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
int64_t val = a.cast<IntegerAttr>().getInt();
if (isDynamic(val))
if (val == dynVal)
p << values[idx++];
else
p << val;
@ -84,32 +88,31 @@ void mlir::printListOfOperandsOrIntegers(
p << ']';
}
void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p,
OffsetSizeAndStrideOpInterface op,
StringRef offsetPrefix,
StringRef sizePrefix,
StringRef stridePrefix,
ArrayRef<StringRef> elidedAttrs) {
p << offsetPrefix;
printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
ShapedType::isDynamicStrideOrOffset);
p << sizePrefix;
printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
ShapedType::isDynamic);
p << stridePrefix;
printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
ShapedType::isDynamicStrideOrOffset);
p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p,
Operation *op,
OperandRange values,
ArrayAttr integers) {
return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>(
p, values, integers);
}
ParseResult mlir::parseListOfOperandsOrIntegers(
OpAsmParser &parser, OperationState &result, StringRef attrName,
int64_t dynVal, SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op,
OperandRange values,
ArrayAttr integers) {
return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values,
integers);
}
template <int64_t dynVal>
static ParseResult
parseOperandsOrIntegersImpl(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::OperandType> &values,
ArrayAttr &integers) {
if (failed(parser.parseLSquare()))
return failure();
// 0-D.
if (succeeded(parser.parseOptionalRSquare())) {
result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
integers = parser.getBuilder().getArrayAttr({});
return success();
}
@ -118,7 +121,7 @@ ParseResult mlir::parseListOfOperandsOrIntegers(
OpAsmParser::OperandType operand;
auto res = parser.parseOptionalOperand(operand);
if (res.hasValue() && succeeded(res.getValue())) {
ssa.push_back(operand);
values.push_back(operand);
attrVals.push_back(dynVal);
} else {
IntegerAttr attr;
@ -134,59 +137,20 @@ ParseResult mlir::parseListOfOperandsOrIntegers(
return failure();
break;
}
auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
result.addAttribute(attrName, arrayAttr);
integers = parser.getBuilder().getI64ArrayAttr(attrVals);
return success();
}
ParseResult mlir::parseOffsetsSizesAndStrides(
OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
return parseOffsetsSizesAndStrides(
parser, result, segmentSizes, nullptr, parseOptionalOffsetPrefix,
parseOptionalSizePrefix, parseOptionalStridePrefix);
ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList(
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
ArrayAttr &integers) {
return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>(
parser, values, integers);
}
ParseResult mlir::parseOffsetsSizesAndStrides(
OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
preResolutionFn,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
auto indexType = parser.getBuilder().getIndexType();
if ((parseOptionalOffsetPrefix && parseOptionalOffsetPrefix(parser)) ||
parseListOfOperandsOrIntegers(
parser, result,
OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
(parseOptionalSizePrefix && parseOptionalSizePrefix(parser)) ||
parseListOfOperandsOrIntegers(
parser, result,
OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
ShapedType::kDynamicSize, sizesInfo) ||
(parseOptionalStridePrefix && parseOptionalStridePrefix(parser)) ||
parseListOfOperandsOrIntegers(
parser, result,
OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
ShapedType::kDynamicStrideOrOffset, stridesInfo))
return failure();
// Add segment sizes to result
SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(),
segmentSizes.end());
segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()),
static_cast<int>(sizesInfo.size()),
static_cast<int>(stridesInfo.size())});
result.addAttribute(
OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
return failure(
(preResolutionFn && preResolutionFn(parser, result)) ||
parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
parser.resolveOperands(sizesInfo, indexType, result.operands) ||
parser.resolveOperands(stridesInfo, indexType, result.operands));
ParseResult mlir::parseOperandsOrIntegersSizesList(
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
ArrayAttr &integers) {
return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
integers);
}

View File

@ -643,7 +643,7 @@ func @pad_number_of_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9
// -----
func @pad_no_block(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
// expected-error @+1 {{expected region with 1 block}}
// expected-error @+1 {{op region #0 ('region') failed to verify constraint: region with 1 blocks}}
%0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
} : tensor<?x4xi32> to tensor<?x9xi32>
return %0 : tensor<?x9xi32>

View File

@ -1581,6 +1581,8 @@ static void genSpacePrinter(bool value, OpMethodBody &body,
if (value) {
body << " p << ' ';\n";
lastWasPunctuation = false;
} else {
lastWasPunctuation = true;
}
shouldEmitSpace = false;
}