forked from OSchip/llvm-project
[mlir] Add custom directive hooks for printing mixed integer or value operands.
Add printer and parser hooks for a custom directive that allows parsing and printing of idioms that can represent a list of values each of which is either an integer or an SSA value. For example in `subview %source[%offset_0, 1] [4, %size_1] [%stride_0, 3]` each of the list (which represents offset, size and strides) is a mix of either statically know integer values or dynamically computed SSA values. Since this is used in many places adding a custom directive to parse/print this idiom allows using assembly format on operations which use this idiom. Differential Revision: https://reviews.llvm.org/D95773
This commit is contained in:
parent
87f8a08ce3
commit
342d4662e1
|
@ -45,6 +45,11 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
|
|||
|
||||
let results = (outs AnyTensor:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) attr-dict
|
||||
`:` type($result)
|
||||
}];
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
@ -118,7 +123,7 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
|
|||
}
|
||||
|
||||
def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
|
||||
[AttrSizedOperandSegments, SingleBlockImplicitTerminator<"YieldOp">]> {
|
||||
[AttrSizedOperandSegments]> {
|
||||
let summary = "tensor pad operation";
|
||||
let description = [{
|
||||
`linalg.pad_tensor` is an operation that pads the `source` tensor
|
||||
|
@ -181,10 +186,16 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
|
|||
I64ArrayAttr:$static_low,
|
||||
I64ArrayAttr:$static_high);
|
||||
|
||||
let regions = (region AnyRegion:$region);
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
|
||||
let results = (outs AnyTensor:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
|
||||
`high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
|
||||
$region attr-dict `:` type($source) `to` type($result)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getStaticLowAttrName() {
|
||||
return "static_low";
|
||||
|
|
|
@ -1956,6 +1956,19 @@ def MemRefReinterpretCastOp:
|
|||
);
|
||||
let results = (outs AnyMemRef:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source `to` `offset` `` `:`
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
|
||||
`` `,` `sizes` `` `:`
|
||||
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) `` `,` `strides`
|
||||
`` `:`
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
|
||||
attr-dict `:` type($source) `to` type($result)
|
||||
}];
|
||||
|
||||
let parser=?;
|
||||
let printer=?;
|
||||
|
||||
let builders = [
|
||||
// Build a ReinterpretCastOp with mixed static and dynamic entries.
|
||||
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source,
|
||||
|
@ -2931,6 +2944,14 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
|
|||
);
|
||||
let results = (outs AnyMemRef:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source ``
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
|
||||
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
|
||||
attr-dict `:` type($source) `to` type($result)
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
// Build a SubViewOp with mixed static and dynamic entries and custom
|
||||
// result type. If the type passed is nullptr, it is inferred.
|
||||
|
@ -3053,6 +3074,14 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
|
|||
);
|
||||
let results = (outs AnyRankedTensor:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source ``
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
|
||||
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
|
||||
attr-dict `:` type($source) `to` type($result)
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
// Build a SubTensorOp with mixed static and dynamic entries and inferred
|
||||
// result type.
|
||||
|
@ -3115,7 +3144,10 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
|
||||
"subtensor_insert", [OffsetSizeAndStrideOpInterface]> {
|
||||
"subtensor_insert",
|
||||
[OffsetSizeAndStrideOpInterface,
|
||||
TypesMatchWith<"expected result type to match dest type",
|
||||
"dest", "result", "$_self">]> {
|
||||
let summary = "subtensor_insert operation";
|
||||
let description = [{
|
||||
The "subtensor_insert" operation insert a tensor `source` into another
|
||||
|
@ -3159,6 +3191,16 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
|
|||
);
|
||||
let results = (outs AnyRankedTensor:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source `into` $dest ``
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
|
||||
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
|
||||
attr-dict `:` type($source) `into` type($dest)
|
||||
}];
|
||||
|
||||
let verifier = ?;
|
||||
|
||||
let builders = [
|
||||
// Build a SubTensorInsertOp with mixed static and dynamic entries.
|
||||
OpBuilderDAG<(ins "Value":$source, "Value":$dest,
|
||||
|
|
|
@ -36,78 +36,68 @@ LogicalResult verify(OffsetSizeAndStrideOpInterface op);
|
|||
#include "mlir/Interfaces/ViewLikeInterface.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
/// Print a list with either (1) the static integer value in `arrayAttr` if
|
||||
/// `isDynamic` evaluates to false or (2) the next value otherwise.
|
||||
/// This allows idiomatic printing of mixed value and integer attributes in a
|
||||
|
||||
/// Printer hook for custom directive in assemblyFormat.
|
||||
///
|
||||
/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
|
||||
///
|
||||
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
|
||||
/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with
|
||||
/// either (1) the static integer value in `integers` if the value is
|
||||
/// ShapedType::kDynamicStrideOrOffset or (2) the next value otherwise. This
|
||||
/// allows idiomatic printing of mixed value and integer attributes in a
|
||||
/// list. E.g. `[%arg0, 7, 42, %arg42]`.
|
||||
void printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
|
||||
ArrayAttr arrayAttr,
|
||||
llvm::function_ref<bool(int64_t)> isDynamic);
|
||||
void printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &printer,
|
||||
Operation *op,
|
||||
OperandRange values,
|
||||
ArrayAttr integers);
|
||||
|
||||
/// Print part of an op of the form:
|
||||
/// ```
|
||||
/// <optional-offset-prefix>`[` offset-list `]`
|
||||
/// <optional-size-prefix>`[` size-list `]`
|
||||
/// <optional-stride-prefix>[` stride-list `]`
|
||||
/// ```
|
||||
void printOffsetsSizesAndStrides(
|
||||
OpAsmPrinter &p, OffsetSizeAndStrideOpInterface op,
|
||||
StringRef offsetPrefix = "", StringRef sizePrefix = " ",
|
||||
StringRef stridePrefix = " ",
|
||||
ArrayRef<StringRef> elidedAttrs =
|
||||
OffsetSizeAndStrideOpInterface::getSpecialAttrNames());
|
||||
/// Printer hook for custom directive in assemblyFormat.
|
||||
///
|
||||
/// custom<OperandsOrIntegersSizesList>($values, $integers)
|
||||
///
|
||||
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
|
||||
/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with
|
||||
/// either (1) the static integer value in `integers` if the value is
|
||||
/// ShapedType::kDynamicSize or (2) the next value otherwise. This
|
||||
/// allows idiomatic printing of mixed value and integer attributes in a
|
||||
/// list. E.g. `[%arg0, 7, 42, %arg42]`.
|
||||
void printOperandsOrIntegersSizesList(OpAsmPrinter &printer, Operation *op,
|
||||
OperandRange values, ArrayAttr integers);
|
||||
|
||||
/// Parse a mixed list with either (1) static integer values or (2) SSA values.
|
||||
/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
|
||||
/// encode the position of SSA values. Add the parsed SSA values to `ssa`
|
||||
/// in-order.
|
||||
/// Pasrer hook for custom directive in assemblyFormat.
|
||||
///
|
||||
/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
|
||||
///
|
||||
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
|
||||
/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with
|
||||
/// either (1) static integer values or (2) SSA values. Fill `integers` with
|
||||
/// the integer ArrayAttr, where ShapedType::kDynamicStrideOrOffset encodes the
|
||||
/// position of SSA values. Add the parsed SSA values to `values` in-order.
|
||||
//
|
||||
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
|
||||
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
|
||||
/// 2. `ssa` is filled with "[%arg0, %arg1]".
|
||||
ParseResult
|
||||
parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
|
||||
StringRef attrName, int64_t dynVal,
|
||||
SmallVectorImpl<OpAsmParser::OperandType> &ssa);
|
||||
ParseResult parseOperandsOrIntegersOffsetsOrStridesList(
|
||||
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
|
||||
ArrayAttr &integers);
|
||||
|
||||
/// Parse trailing part of an op of the form:
|
||||
/// ```
|
||||
/// <optional-offset-prefix>`[` offset-list `]`
|
||||
/// <optional-size-prefix>`[` size-list `]`
|
||||
/// <optional-stride-prefix>[` stride-list `]`
|
||||
/// ```
|
||||
/// Each entry in the offset, size and stride list either resolves to an integer
|
||||
/// constant or an operand of index type.
|
||||
/// Constants are added to the `result` as named integer array attributes with
|
||||
/// name `OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName()` (resp.
|
||||
/// `getStaticSizesAttrName()`, `getStaticStridesAttrName()`).
|
||||
/// Pasrer hook for custom directive in assemblyFormat.
|
||||
///
|
||||
/// Append the number of offset, size and stride operands to `segmentSizes`
|
||||
/// before adding it to `result` as the named attribute:
|
||||
/// `OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr()`.
|
||||
/// custom<OperandsOrIntegersSizesList>($values, $integers)
|
||||
///
|
||||
/// Offset, size and stride operands resolution occurs after `preResolutionFn`
|
||||
/// to give a chance to leading operands to resolve first, after parsing the
|
||||
/// types.
|
||||
ParseResult parseOffsetsSizesAndStrides(
|
||||
OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
|
||||
preResolutionFn = nullptr,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix =
|
||||
nullptr,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix =
|
||||
nullptr,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix =
|
||||
nullptr);
|
||||
/// `preResolutionFn`-less version of `parseOffsetsSizesAndStrides`.
|
||||
ParseResult parseOffsetsSizesAndStrides(
|
||||
OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix =
|
||||
nullptr,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix =
|
||||
nullptr,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix =
|
||||
nullptr);
|
||||
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
|
||||
/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with
|
||||
/// either (1) static integer values or (2) SSA values. Fill `integers` with
|
||||
/// the integer ArrayAttr, where ShapedType::kDynamicSize encodes the
|
||||
/// position of SSA values. Add the parsed SSA values to `values` in-order.
|
||||
//
|
||||
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
|
||||
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
|
||||
/// 2. `ssa` is filled with "[%arg0, %arg1]".
|
||||
ParseResult parseOperandsOrIntegersSizesList(
|
||||
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
|
||||
ArrayAttr &integers);
|
||||
|
||||
/// Verify that a the `values` has as many elements as the number of entries in
|
||||
/// `attr` for which `isDynamic` evaluates to true.
|
||||
|
|
|
@ -676,30 +676,6 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
|
|||
// InitTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseInitTensorOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType srcInfo;
|
||||
Type dstType;
|
||||
SmallVector<OpAsmParser::OperandType, 2> sizeInfo;
|
||||
IndexType indexType = parser.getBuilder().getIndexType();
|
||||
if (failed(parseListOfOperandsOrIntegers(
|
||||
parser, result, InitTensorOp::getStaticSizesAttrName(),
|
||||
ShapedType::kDynamicSize, sizeInfo)) ||
|
||||
failed(parser.parseOptionalAttrDict(result.attributes)) ||
|
||||
failed(parser.parseColonType(dstType)) ||
|
||||
failed(parser.resolveOperands(sizeInfo, indexType, result.operands)))
|
||||
return failure();
|
||||
return parser.addTypeToList(dstType, result.types);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, InitTensorOp op) {
|
||||
p << op.getOperation()->getName() << ' ';
|
||||
printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
|
||||
ShapedType::isDynamic);
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
InitTensorOp::getStaticSizesAttrName());
|
||||
p << " : " << op.getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(InitTensorOp op) {
|
||||
RankedTensorType resultType = op.getType();
|
||||
|
@ -981,8 +957,6 @@ static LogicalResult verify(PadTensorOp op) {
|
|||
}
|
||||
|
||||
auto ®ion = op.region();
|
||||
if (!llvm::hasSingleElement(region))
|
||||
return op.emitOpError("expected region with 1 block");
|
||||
unsigned rank = resultType.getRank();
|
||||
Block &block = region.front();
|
||||
if (block.getNumArguments() != rank)
|
||||
|
@ -1020,67 +994,6 @@ RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
|
|||
return RankedTensorType::get(resultShape, sourceType.getElementType());
|
||||
}
|
||||
|
||||
static ParseResult parsePadTensorOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType baseInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 8> operands;
|
||||
SmallVector<Type, 8> types;
|
||||
if (parser.parseOperand(baseInfo))
|
||||
return failure();
|
||||
|
||||
IndexType indexType = parser.getBuilder().getIndexType();
|
||||
SmallVector<OpAsmParser::OperandType, 4> lowPadding, highPadding;
|
||||
if (parser.parseKeyword("low") ||
|
||||
parseListOfOperandsOrIntegers(parser, result,
|
||||
PadTensorOp::getStaticLowAttrName(),
|
||||
ShapedType::kDynamicSize, lowPadding))
|
||||
return failure();
|
||||
if (parser.parseKeyword("high") ||
|
||||
parseListOfOperandsOrIntegers(parser, result,
|
||||
PadTensorOp::getStaticHighAttrName(),
|
||||
ShapedType::kDynamicSize, highPadding))
|
||||
return failure();
|
||||
|
||||
SmallVector<OpAsmParser::OperandType, 8> regionOperands;
|
||||
std::unique_ptr<Region> region = std::make_unique<Region>();
|
||||
SmallVector<Type, 8> operandTypes, regionTypes;
|
||||
if (parser.parseRegion(*region, regionOperands, regionTypes))
|
||||
return failure();
|
||||
result.addRegion(std::move(region));
|
||||
|
||||
Type srcType, dstType;
|
||||
if (parser.parseColonType(srcType) || parser.parseKeywordType("to", dstType))
|
||||
return failure();
|
||||
|
||||
if (parser.addTypeToList(dstType, result.types))
|
||||
return failure();
|
||||
|
||||
SmallVector<int, 4> segmentSizesFinal = {1}; // source tensor
|
||||
segmentSizesFinal.append({static_cast<int>(lowPadding.size()),
|
||||
static_cast<int>(highPadding.size())});
|
||||
result.addAttribute(
|
||||
OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
|
||||
parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
|
||||
return failure(
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.resolveOperand(baseInfo, srcType, result.operands) ||
|
||||
parser.resolveOperands(lowPadding, indexType, result.operands) ||
|
||||
parser.resolveOperands(highPadding, indexType, result.operands));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, PadTensorOp op) {
|
||||
p << op->getName().getStringRef() << ' ';
|
||||
p << op.source();
|
||||
p << " low";
|
||||
printListOfOperandsOrIntegers(p, op.low(), op.static_low(),
|
||||
ShapedType::isDynamic);
|
||||
p << " high";
|
||||
printListOfOperandsOrIntegers(p, op.high(), op.static_high(),
|
||||
ShapedType::isDynamic);
|
||||
p.printRegion(op.region());
|
||||
p << " : " << op.source().getType() << " to " << op.getType();
|
||||
}
|
||||
|
||||
/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
|
||||
/// it is a Value or into `staticVec` if it is an IntegerAttr.
|
||||
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
|
||||
|
|
|
@ -2148,67 +2148,6 @@ void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
|
|||
build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
|
||||
}
|
||||
|
||||
/// Print a memref_reinterpret_cast op of the form:
|
||||
/// ```
|
||||
/// `memref_reinterpret_cast` ssa-name to
|
||||
/// offset: `[` offset `]`
|
||||
/// sizes: `[` size-list `]`
|
||||
/// strides:`[` stride-list `]`
|
||||
/// `:` any-memref-type to strided-memref-type
|
||||
/// ```
|
||||
static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) {
|
||||
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
|
||||
p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
|
||||
p << op.source() << " ";
|
||||
printOffsetsSizesAndStrides(
|
||||
p, op, /*offsetPrefix=*/"to offset: ", /*sizePrefix=*/", sizes: ",
|
||||
/*stridePrefix=*/", strides: ");
|
||||
p << ": " << op.source().getType() << " to " << op.getType();
|
||||
}
|
||||
|
||||
/// Parse a memref_reinterpret_cast op of the form:
|
||||
/// ```
|
||||
/// `memref_reinterpret_cast` ssa-name to
|
||||
/// offset: `[` offset `]`
|
||||
/// sizes: `[` size-list `]`
|
||||
/// strides:`[` stride-list `]`
|
||||
/// `:` any-memref-type to strided-memref-type
|
||||
/// ```
|
||||
static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
// Parse `operand`
|
||||
OpAsmParser::OperandType srcInfo;
|
||||
if (parser.parseOperand(srcInfo))
|
||||
return failure();
|
||||
|
||||
auto parseOffsetPrefix = [](OpAsmParser &parser) {
|
||||
return failure(parser.parseKeyword("to") || parser.parseKeyword("offset") ||
|
||||
parser.parseColon());
|
||||
};
|
||||
auto parseSizePrefix = [](OpAsmParser &parser) {
|
||||
return failure(parser.parseComma() || parser.parseKeyword("sizes") ||
|
||||
parser.parseColon());
|
||||
};
|
||||
auto parseStridePrefix = [](OpAsmParser &parser) {
|
||||
return failure(parser.parseComma() || parser.parseKeyword("strides") ||
|
||||
parser.parseColon());
|
||||
};
|
||||
|
||||
Type srcType, dstType;
|
||||
auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
|
||||
return failure(parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(srcType) ||
|
||||
parser.parseKeywordType("to", dstType) ||
|
||||
parser.resolveOperand(srcInfo, srcType, result.operands));
|
||||
};
|
||||
if (failed(parseOffsetsSizesAndStrides(parser, result,
|
||||
/*segmentSizes=*/{1}, // source memref
|
||||
preResolutionFn, parseOffsetPrefix,
|
||||
parseSizePrefix, parseStridePrefix)))
|
||||
return failure();
|
||||
return parser.addTypeToList(dstType, result.types);
|
||||
}
|
||||
|
||||
// TODO: ponder whether we want to allow missing trailing sizes/strides that are
|
||||
// completed automatically, like we have for subview and subtensor.
|
||||
static LogicalResult verify(MemRefReinterpretCastOp op) {
|
||||
|
@ -2892,45 +2831,6 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
|
|||
sourceMemRefType.getMemorySpace());
|
||||
}
|
||||
|
||||
/// Print a subview op of the form:
|
||||
/// ```
|
||||
/// `subview` ssa-name
|
||||
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
|
||||
/// `:` strided-memref-type `to` strided-memref-type
|
||||
/// ```
|
||||
static void print(OpAsmPrinter &p, SubViewOp op) {
|
||||
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
|
||||
p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
|
||||
p << op.source();
|
||||
printOffsetsSizesAndStrides(p, op);
|
||||
p << " : " << op.source().getType() << " to " << op.getType();
|
||||
}
|
||||
|
||||
/// Parse a subview op of the form:
|
||||
/// ```
|
||||
/// `subview` ssa-name
|
||||
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
|
||||
/// `:` strided-memref-type `to` strided-memref-type
|
||||
/// ```
|
||||
static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::OperandType srcInfo;
|
||||
if (parser.parseOperand(srcInfo))
|
||||
return failure();
|
||||
Type srcType, dstType;
|
||||
auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
|
||||
return failure(parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(srcType) ||
|
||||
parser.parseKeywordType("to", dstType) ||
|
||||
parser.resolveOperand(srcInfo, srcType, result.operands));
|
||||
};
|
||||
|
||||
if (failed(parseOffsetsSizesAndStrides(parser, result,
|
||||
/*segmentSizes=*/{1}, // source memref
|
||||
preResolutionFn)))
|
||||
return failure();
|
||||
return parser.addTypeToList(dstType, result.types);
|
||||
}
|
||||
|
||||
// Build a SubViewOp with mixed static and dynamic entries and custom result
|
||||
// type. If the type passed is nullptr, it is inferred.
|
||||
void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
|
||||
|
@ -3466,46 +3366,6 @@ OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
|
|||
// SubTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Print a subtensor op of the form:
|
||||
/// ```
|
||||
/// `subtensor` ssa-name
|
||||
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
|
||||
/// `:` ranked-tensor-type `to` ranked-tensor-type
|
||||
/// ```
|
||||
static void print(OpAsmPrinter &p, SubTensorOp op) {
|
||||
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
|
||||
p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
|
||||
p << op.source();
|
||||
printOffsetsSizesAndStrides(p, op);
|
||||
p << " : " << op.getSourceType() << " to " << op.getType();
|
||||
}
|
||||
|
||||
/// Parse a subtensor op of the form:
|
||||
/// ```
|
||||
/// `subtensor` ssa-name
|
||||
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
|
||||
/// `:` ranked-tensor-type `to` ranked-tensor-type
|
||||
/// ```
|
||||
static ParseResult parseSubTensorOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType srcInfo;
|
||||
if (parser.parseOperand(srcInfo))
|
||||
return failure();
|
||||
Type srcType, dstType;
|
||||
auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
|
||||
return failure(parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(srcType) ||
|
||||
parser.parseKeywordType("to", dstType) ||
|
||||
parser.resolveOperand(srcInfo, srcType, result.operands));
|
||||
};
|
||||
|
||||
if (failed(parseOffsetsSizesAndStrides(parser, result,
|
||||
/*segmentSizes=*/{1}, // source tensor
|
||||
preResolutionFn)))
|
||||
return failure();
|
||||
return parser.addTypeToList(dstType, result.types);
|
||||
}
|
||||
|
||||
/// A subtensor result type can be fully inferred from the source type and the
|
||||
/// static representation of offsets, sizes and strides. Special sentinels
|
||||
/// encode the dynamic case.
|
||||
|
@ -3612,49 +3472,6 @@ void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
// SubTensorInsertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Print a subtensor_insert op of the form:
|
||||
/// ```
|
||||
/// `subtensor_insert` ssa-name `into` ssa-name
|
||||
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
|
||||
/// `:` ranked-tensor-type `into` ranked-tensor-type
|
||||
/// ```
|
||||
static void print(OpAsmPrinter &p, SubTensorInsertOp op) {
|
||||
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
|
||||
p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
|
||||
p << op.source() << " into " << op.dest();
|
||||
printOffsetsSizesAndStrides(p, op);
|
||||
p << " : " << op.getSourceType() << " into " << op.getType();
|
||||
}
|
||||
|
||||
/// Parse a subtensor_insert op of the form:
|
||||
/// ```
|
||||
/// `subtensor_insert` ssa-name `into` ssa-name
|
||||
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
|
||||
/// `:` ranked-tensor-type `into` ranked-tensor-type
|
||||
/// ```
|
||||
static ParseResult parseSubTensorInsertOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType srcInfo, dstInfo;
|
||||
if (parser.parseOperand(srcInfo) || parser.parseKeyword("into") ||
|
||||
parser.parseOperand(dstInfo))
|
||||
return failure();
|
||||
Type srcType, dstType;
|
||||
auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
|
||||
return failure(parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(srcType) ||
|
||||
parser.parseKeywordType("into", dstType) ||
|
||||
parser.resolveOperand(srcInfo, srcType, result.operands) ||
|
||||
parser.resolveOperand(dstInfo, dstType, result.operands));
|
||||
};
|
||||
|
||||
if (failed(parseOffsetsSizesAndStrides(
|
||||
parser, result,
|
||||
/*segmentSizes=*/{1, 1}, // source tensor, destination tensor
|
||||
preResolutionFn)))
|
||||
return failure();
|
||||
return parser.addTypeToList(dstType, result.types);
|
||||
}
|
||||
|
||||
// Build a SubTensorInsertOp with mixed static and dynamic entries.
|
||||
void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
|
||||
Value source, Value dest,
|
||||
|
@ -3691,13 +3508,6 @@ void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
|
|||
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
|
||||
}
|
||||
|
||||
/// Verifier for SubViewOp.
|
||||
static LogicalResult verify(SubTensorInsertOp op) {
|
||||
if (op.getType() != op.dest().getType())
|
||||
return op.emitError("expected result type to be ") << op.dest().getType();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorLoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -69,14 +69,18 @@ LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
void mlir::printListOfOperandsOrIntegers(
|
||||
OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
|
||||
llvm::function_ref<bool(int64_t)> isDynamic) {
|
||||
template <int64_t dynVal>
|
||||
static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values,
|
||||
ArrayAttr arrayAttr) {
|
||||
p << '[';
|
||||
if (arrayAttr.empty()) {
|
||||
p << "]";
|
||||
return;
|
||||
}
|
||||
unsigned idx = 0;
|
||||
llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
|
||||
int64_t val = a.cast<IntegerAttr>().getInt();
|
||||
if (isDynamic(val))
|
||||
if (val == dynVal)
|
||||
p << values[idx++];
|
||||
else
|
||||
p << val;
|
||||
|
@ -84,32 +88,31 @@ void mlir::printListOfOperandsOrIntegers(
|
|||
p << ']';
|
||||
}
|
||||
|
||||
void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p,
|
||||
OffsetSizeAndStrideOpInterface op,
|
||||
StringRef offsetPrefix,
|
||||
StringRef sizePrefix,
|
||||
StringRef stridePrefix,
|
||||
ArrayRef<StringRef> elidedAttrs) {
|
||||
p << offsetPrefix;
|
||||
printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
|
||||
ShapedType::isDynamicStrideOrOffset);
|
||||
p << sizePrefix;
|
||||
printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
|
||||
ShapedType::isDynamic);
|
||||
p << stridePrefix;
|
||||
printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
|
||||
ShapedType::isDynamicStrideOrOffset);
|
||||
p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
|
||||
void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p,
|
||||
Operation *op,
|
||||
OperandRange values,
|
||||
ArrayAttr integers) {
|
||||
return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>(
|
||||
p, values, integers);
|
||||
}
|
||||
|
||||
ParseResult mlir::parseListOfOperandsOrIntegers(
|
||||
OpAsmParser &parser, OperationState &result, StringRef attrName,
|
||||
int64_t dynVal, SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
|
||||
void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op,
|
||||
OperandRange values,
|
||||
ArrayAttr integers) {
|
||||
return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values,
|
||||
integers);
|
||||
}
|
||||
|
||||
template <int64_t dynVal>
|
||||
static ParseResult
|
||||
parseOperandsOrIntegersImpl(OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::OperandType> &values,
|
||||
ArrayAttr &integers) {
|
||||
if (failed(parser.parseLSquare()))
|
||||
return failure();
|
||||
// 0-D.
|
||||
if (succeeded(parser.parseOptionalRSquare())) {
|
||||
result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
|
||||
integers = parser.getBuilder().getArrayAttr({});
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -118,7 +121,7 @@ ParseResult mlir::parseListOfOperandsOrIntegers(
|
|||
OpAsmParser::OperandType operand;
|
||||
auto res = parser.parseOptionalOperand(operand);
|
||||
if (res.hasValue() && succeeded(res.getValue())) {
|
||||
ssa.push_back(operand);
|
||||
values.push_back(operand);
|
||||
attrVals.push_back(dynVal);
|
||||
} else {
|
||||
IntegerAttr attr;
|
||||
|
@ -134,59 +137,20 @@ ParseResult mlir::parseListOfOperandsOrIntegers(
|
|||
return failure();
|
||||
break;
|
||||
}
|
||||
|
||||
auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
|
||||
result.addAttribute(attrName, arrayAttr);
|
||||
integers = parser.getBuilder().getI64ArrayAttr(attrVals);
|
||||
return success();
|
||||
}
|
||||
|
||||
ParseResult mlir::parseOffsetsSizesAndStrides(
|
||||
OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
|
||||
return parseOffsetsSizesAndStrides(
|
||||
parser, result, segmentSizes, nullptr, parseOptionalOffsetPrefix,
|
||||
parseOptionalSizePrefix, parseOptionalStridePrefix);
|
||||
ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList(
|
||||
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
|
||||
ArrayAttr &integers) {
|
||||
return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>(
|
||||
parser, values, integers);
|
||||
}
|
||||
|
||||
ParseResult mlir::parseOffsetsSizesAndStrides(
|
||||
OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
|
||||
preResolutionFn,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
|
||||
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
|
||||
SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
|
||||
auto indexType = parser.getBuilder().getIndexType();
|
||||
if ((parseOptionalOffsetPrefix && parseOptionalOffsetPrefix(parser)) ||
|
||||
parseListOfOperandsOrIntegers(
|
||||
parser, result,
|
||||
OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
|
||||
ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
|
||||
(parseOptionalSizePrefix && parseOptionalSizePrefix(parser)) ||
|
||||
parseListOfOperandsOrIntegers(
|
||||
parser, result,
|
||||
OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
|
||||
ShapedType::kDynamicSize, sizesInfo) ||
|
||||
(parseOptionalStridePrefix && parseOptionalStridePrefix(parser)) ||
|
||||
parseListOfOperandsOrIntegers(
|
||||
parser, result,
|
||||
OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
|
||||
ShapedType::kDynamicStrideOrOffset, stridesInfo))
|
||||
return failure();
|
||||
// Add segment sizes to result
|
||||
SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(),
|
||||
segmentSizes.end());
|
||||
segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()),
|
||||
static_cast<int>(sizesInfo.size()),
|
||||
static_cast<int>(stridesInfo.size())});
|
||||
result.addAttribute(
|
||||
OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
|
||||
parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
|
||||
return failure(
|
||||
(preResolutionFn && preResolutionFn(parser, result)) ||
|
||||
parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
|
||||
parser.resolveOperands(sizesInfo, indexType, result.operands) ||
|
||||
parser.resolveOperands(stridesInfo, indexType, result.operands));
|
||||
ParseResult mlir::parseOperandsOrIntegersSizesList(
|
||||
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
|
||||
ArrayAttr &integers) {
|
||||
return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
|
||||
integers);
|
||||
}
|
||||
|
|
|
@ -643,7 +643,7 @@ func @pad_number_of_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9
|
|||
// -----
|
||||
|
||||
func @pad_no_block(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
|
||||
// expected-error @+1 {{expected region with 1 block}}
|
||||
// expected-error @+1 {{op region #0 ('region') failed to verify constraint: region with 1 blocks}}
|
||||
%0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
|
||||
} : tensor<?x4xi32> to tensor<?x9xi32>
|
||||
return %0 : tensor<?x9xi32>
|
||||
|
|
|
@ -1581,6 +1581,8 @@ static void genSpacePrinter(bool value, OpMethodBody &body,
|
|||
if (value) {
|
||||
body << " p << ' ';\n";
|
||||
lastWasPunctuation = false;
|
||||
} else {
|
||||
lastWasPunctuation = true;
|
||||
}
|
||||
shouldEmitSpace = false;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue