NFC: Update std.subview op to use AttrSizedOperandSegments
This turns a few manually written helper methods into auto-generated ones.

PiperOrigin-RevId: 283339617
parent 4231de7897
commit 0d22a3fdc8
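
For context: the AttrSizedOperandSegments trait lets an op with several variadic operand lists record the length of each list in a dense i32 elements attribute named operand_segment_sizes, from which ODS auto-generates per-group accessors such as offsets(), sizes(), and strides(). A minimal standalone sketch of the partitioning idea, with hypothetical helper names (this is not the generated code):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // Computes the half-open [begin, end) index range of one operand group
    // from the per-group lengths stored in operand_segment_sizes. For
    // std.subview the segments are {1 /*source*/, numOffsets, numSizes,
    // numStrides}.
    static std::pair<std::size_t, std::size_t>
    segmentRange(const std::vector<std::int32_t> &segments, std::size_t index) {
      std::size_t begin = 0;
      for (std::size_t i = 0; i < index; ++i)
        begin += static_cast<std::size_t>(segments[i]);
      return {begin, begin + static_cast<std::size_t>(segments[index])};
    }

    int main() {
      // A rank-2 subview: one source, two offsets, two sizes, two strides.
      std::vector<std::int32_t> segments = {1, 2, 2, 2};
      auto sizesGroup = segmentRange(segments, /*index=*/2);
      assert(sizesGroup.first == 3 && sizesGroup.second == 5); // operands 3, 4
      return 0;
    }

Because every group length is explicit in the attribute, the hand-written getDynamicOffsets/getDynamicSizes/getDynamicStrides helpers deleted below become redundant.
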
@@ -1248,7 +1248,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> {
   let hasCanonicalizer = 1;
 }
 
-def SubViewOp : Std_Op<"subview", [NoSideEffect]> {
+def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
   let summary = "memref subview operation";
   let description = [{
     The "subview" operation converts a memref type to another memref type
@@ -1356,23 +1356,25 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> {
 
   // TODO(b/144779634, ravishankarm) : Use different arguments for
   // offsets, sizes and strides.
-  let arguments = (ins AnyMemRef:$source, I32Attr:$num_offsets,
-                   I32Attr:$num_sizes, I32Attr:$num_strides,
-                   Variadic<Index>:$operands);
+  let arguments = (ins
+    AnyMemRef:$source,
+    Variadic<Index>:$offsets,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides,
+    I32ElementsAttr:$operand_segment_sizes
+  );
   let results = (outs AnyMemRef);
 
-  let builders = [OpBuilder<
-    "Builder *b, OperationState &result, Value *source, "
-    "ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, "
-    "ArrayRef<Value *> strides, Type resultType = Type(), "
-    "ArrayRef<NamedAttribute> attrs = {}">,
-    OpBuilder<
-    "Builder *builder, OperationState &result, Type resultType, Value *source">,
-    OpBuilder<
-    "Builder *builder, OperationState &result, Type resultType, Value *source, "
-    "unsigned num_offsets, unsigned num_sizes, unsigned num_strides, "
-    "ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, "
-    "ArrayRef<Value *> strides">];
+  let builders = [
+    OpBuilder<
+      "Builder *b, OperationState &result, Value *source, "
+      "ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, "
+      "ArrayRef<Value *> strides, Type resultType = Type(), "
+      "ArrayRef<NamedAttribute> attrs = {}">,
+    OpBuilder<
+      "Builder *builder, OperationState &result, "
+      "Type resultType, Value *source">
+  ];
 
   let extraClassDeclaration = [{
     /// Returns the type of the base memref operand.
@@ -1384,28 +1386,16 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> {
     MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
 
     /// Returns as integer value the number of offset operands.
-    int64_t getNumOffsets() {
-      return num_offsets().getSExtValue();
-    }
+    int64_t getNumOffsets() { return llvm::size(offsets()); }
 
     /// Returns as integer value the number of size operands.
-    int64_t getNumSizes() {
-      return num_sizes().getSExtValue();
-    }
+    int64_t getNumSizes() { return llvm::size(sizes()); }
 
     /// Returns as integer value the number of stride operands.
-    int64_t getNumStrides() {
-      return num_strides().getSExtValue();
-    }
-
-    /// Returns the dynamic offsets for this subview operation.
-    operand_range getDynamicOffsets();
+    int64_t getNumStrides() { return llvm::size(strides()); }
 
     /// Returns the dynamic sizes for this subview operation if specified.
-    operand_range getDynamicSizes();
-
-    /// Returns the dynamic strides for this subview operation if specified.
-    operand_range getDynamicStrides();
+    operand_range getDynamicSizes() { return sizes(); }
 
     // Auxiliary range data structure and helper function that unpacks the
     // offset, size and stride operands of the SubViewOp into a list of triples.
@@ -120,6 +120,8 @@ public:
   IntegerAttr getI32IntegerAttr(int32_t value);
   IntegerAttr getI64IntegerAttr(int64_t value);
 
+  DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
+
   ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
   ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
   ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
@@ -1476,7 +1476,6 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
       ConversionPatternRewriter &rewriter) const override {
     auto loc = op->getLoc();
     auto viewOp = cast<SubViewOp>(op);
-    SubViewOpOperandAdaptor adaptor(operands);
     // TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support
     // having multiple variadic operands where each operand can have different
    // number of entries, clean all of this up.
@@ -1518,7 +1517,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
       return matchFailure();
 
     // Create the descriptor.
-    MemRefDescriptor sourceMemRef(adaptor.source());
+    MemRefDescriptor sourceMemRef(operands.front());
     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
 
     // Copy the buffer pointer from the old descriptor to the new one.
@@ -1370,7 +1370,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
   // Fold dim to the size argument of a SubViewOp.
   auto memref = memrefOrTensor()->getDefiningOp();
   if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
-    auto sizes = subview.getDynamicSizes();
+    auto sizes = subview.sizes();
     if (!sizes.empty())
       return *(sizes.begin() + getIndex());
   }
@@ -2563,35 +2563,23 @@ static Type inferSubViewResultType(MemRefType memRefType) {
                          memRefType.getMemorySpace());
 }
 
-void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType,
-                            Value *source, unsigned num_offsets,
-                            unsigned num_sizes, unsigned num_strides,
-                            ArrayRef<Value *> offsets, ArrayRef<Value *> sizes,
-                            ArrayRef<Value *> strides) {
-  SmallVector<Value *, 8> operands;
-  operands.reserve(num_offsets + num_sizes + num_strides);
-  operands.append(offsets.begin(), offsets.end());
-  operands.append(sizes.begin(), sizes.end());
-  operands.append(strides.begin(), strides.end());
-  build(b, result, resultType, source, b->getI32IntegerAttr(num_offsets),
-        b->getI32IntegerAttr(num_sizes), b->getI32IntegerAttr(num_strides),
-        operands);
-}
-
 void mlir::SubViewOp::build(Builder *b, OperationState &result, Value *source,
                             ArrayRef<Value *> offsets, ArrayRef<Value *> sizes,
                             ArrayRef<Value *> strides, Type resultType,
                             ArrayRef<NamedAttribute> attrs) {
   if (!resultType)
     resultType = inferSubViewResultType(source->getType().cast<MemRefType>());
-  build(b, result, resultType, source, offsets.size(), sizes.size(),
-        strides.size(), offsets, sizes, strides);
+  auto segmentAttr = b->getI32VectorAttr(
+      {1, static_cast<int>(offsets.size()), static_cast<int32_t>(sizes.size()),
+       static_cast<int32_t>(strides.size())});
+  build(b, result, resultType, source, offsets, sizes, strides, segmentAttr);
   result.addAttributes(attrs);
 }
 
 void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType,
                             Value *source) {
-  build(b, result, resultType, source, 0, 0, 0, {}, {}, {});
+  build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{},
+        resultType);
 }
 
 static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
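
With the reworked builders, call sites pass the three operand groups directly and build() synthesizes the segment attribute itself; the canonicalization patterns later in this commit use exactly this form. A hedged usage sketch (header path and surrounding setup assumed, using the Value * API of this revision):

    #include "mlir/Dialect/StandardOps/Ops.h" // assumed header location
    #include "mlir/IR/Builders.h"

    using namespace mlir;

    // Creates a subview without spelling out operand counts; omitting the
    // result type lets build() infer it via inferSubViewResultType().
    static SubViewOp createSubView(OpBuilder &b, Location loc, Value *source,
                                   ArrayRef<Value *> offsets,
                                   ArrayRef<Value *> sizes,
                                   ArrayRef<Value *> strides) {
      return b.create<SubViewOp>(loc, source, offsets, sizes, strides);
    }
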
@@ -2607,12 +2595,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
       parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) {
     return failure();
   }
 
   auto builder = parser.getBuilder();
-  result.addAttribute("num_offsets",
-                      builder.getI32IntegerAttr(offsetsInfo.size()));
-  result.addAttribute("num_sizes", builder.getI32IntegerAttr(sizesInfo.size()));
-  result.addAttribute("num_strides",
-                      builder.getI32IntegerAttr(stridesInfo.size()));
+  result.addAttribute(
+      SubViewOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({1, static_cast<int>(offsetsInfo.size()),
+                                static_cast<int32_t>(sizesInfo.size()),
+                                static_cast<int32_t>(stridesInfo.size())}));
 
   return failure(
       parser.parseOptionalAttrDict(result.attributes) ||
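
For a rank-2 subview parsed with two offsets, two sizes, and two strides, the attribute recorded here comes out as operand_segment_sizes = dense<[1, 2, 2, 2]> : vector<4xi32>, the leading 1 covering the single $source operand. A standalone mirror of that bookkeeping (illustrative only, hypothetical helper name):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Builds the segment-size vector the way parseSubViewOp does: one entry
    // per operand group, in declaration order (source, offsets, sizes,
    // strides).
    static std::vector<std::int32_t>
    makeSegmentSizes(std::size_t numOffsets, std::size_t numSizes,
                     std::size_t numStrides) {
      return {1, static_cast<std::int32_t>(numOffsets),
              static_cast<std::int32_t>(numSizes),
              static_cast<std::int32_t>(numStrides)};
    }
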
@@ -2627,14 +2616,15 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
 
 static void print(OpAsmPrinter &p, SubViewOp op) {
   p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
-  p.printOperands(op.getDynamicOffsets());
+  p.printOperands(op.offsets());
   p << "][";
-  p.printOperands(op.getDynamicSizes());
+  p.printOperands(op.sizes());
   p << "][";
-  p.printOperands(op.getDynamicStrides());
+  p.printOperands(op.strides());
   p << ']';
-  SmallVector<StringRef, 3> elidedAttrs = {"num_offsets", "num_sizes",
-                                           "num_strides"};
+
+  SmallVector<StringRef, 1> elidedAttrs = {
+      SubViewOp::getOperandSegmentSizeAttr()};
   p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
   p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
 }
@@ -2689,14 +2679,16 @@ static LogicalResult verify(SubViewOp op) {
   }
 
   // Verify that if the shape of the subview type is static, then sizes are not
-  // dynamic values, and viceversa.
+  // dynamic values, and vice versa.
   if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) ||
       (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) {
     return op.emitError("invalid to specify dynamic sizes when subview result "
                         "type is statically shaped and viceversa");
   }
 
   // Verify that if dynamic sizes are specified, then the result memref type
   // have full dynamic dimensions.
   if (op.getNumSizes() > 0) {
     // Verify that none of the shape values of the result type are static.
     if (llvm::any_of(subViewType.getShape(), [](int64_t dim) {
           return dim != ShapedType::kDynamicSize;
         })) {
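
The check above couples operands to shapes: dynamic size operands must be present exactly when the result type's shape is not fully static. A standalone sketch of that invariant (hypothetical helper; assumes MLIR's convention at this revision that dynamic dimensions are encoded by a negative sentinel, ShapedType::kDynamicSize):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Returns true when the number of dynamic size operands is consistent
    // with the result shape: zero for a fully static shape, nonzero otherwise.
    static bool sizesConsistent(const std::vector<std::int64_t> &resultShape,
                                std::size_t numSizeOperands) {
      bool hasStaticShape = true;
      for (std::int64_t dim : resultShape)
        if (dim < 0) // dynamic-dimension sentinel
          hasStaticShape = false;
      return hasStaticShape == (numSizeOperands == 0);
    }
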
@@ -2758,9 +2750,8 @@ SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
   unsigned rank = getType().getRank();
   res.reserve(rank);
   for (unsigned i = 0; i < rank; ++i)
-    res.emplace_back(Range{*(getDynamicOffsets().begin() + i),
-                           *(getDynamicSizes().begin() + i),
-                           *(getDynamicStrides().begin() + i)});
+    res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i),
+                           *(strides().begin() + i)});
   return res;
 }
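
getRanges() relies on the three generated ranges each having exactly rank entries, zipping them into per-dimension (offset, size, stride) triples. A standalone sketch of that zip, with plain ints standing in for Value * operands:

    #include <cstddef>
    #include <vector>

    struct Range {
      int offset, size, stride;
    };

    // Zips equally sized offset/size/stride lists into per-dimension triples,
    // mirroring the loop in SubViewOp::getRanges().
    static std::vector<Range> zipRanges(const std::vector<int> &offsets,
                                        const std::vector<int> &sizes,
                                        const std::vector<int> &strides) {
      std::vector<Range> res;
      res.reserve(offsets.size());
      for (std::size_t i = 0; i < offsets.size(); ++i)
        res.push_back(Range{offsets[i], sizes[i], strides[i]});
      return res;
    }
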
@@ -2792,13 +2783,13 @@ public:
     // Follow all or nothing approach for shapes for now. If all the operands
     // for sizes are constants then fold it into the type of the result memref.
     if (subViewType.hasStaticShape() ||
-        llvm::any_of(subViewOp.getDynamicSizes(), [](Value *operand) {
+        llvm::any_of(subViewOp.sizes(), [](Value *operand) {
           return !matchPattern(operand, m_ConstantIndex());
         })) {
       return matchFailure();
     }
     SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
-    for (auto size : enumerate(subViewOp.getDynamicSizes())) {
+    for (auto size : enumerate(subViewOp.sizes())) {
       auto defOp = size.value()->getDefiningOp();
       assert(defOp);
       staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
@@ -2808,12 +2799,12 @@ public:
         subViewType.getMemorySpace());
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(),
-        llvm::to_vector<4>(subViewOp.getDynamicOffsets()), ArrayRef<Value *>(),
-        llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType);
+        llvm::to_vector<4>(subViewOp.offsets()), ArrayRef<Value *>(),
+        llvm::to_vector<4>(subViewOp.strides()), newMemRefType);
     // Insert a memref_cast for compatibility of the uses of the op.
     rewriter.replaceOpWithNewOp<MemRefCastOp>(
-        llvm::to_vector<4>(subViewOp.getDynamicSizes()), subViewOp,
-        newSubViewOp, subViewOp.getType());
+        llvm::to_vector<4>(subViewOp.sizes()), subViewOp, newSubViewOp,
+        subViewOp.getType());
     return matchSuccess();
   }
 };
@@ -2839,14 +2830,14 @@ public:
         failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
         llvm::is_contained(baseStrides,
                            MemRefType::getDynamicStrideOrOffset()) ||
-        llvm::any_of(subViewOp.getDynamicStrides(), [](Value *stride) {
+        llvm::any_of(subViewOp.strides(), [](Value *stride) {
          return !matchPattern(stride, m_ConstantIndex());
         })) {
       return matchFailure();
     }
 
     SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
-    for (auto stride : enumerate(subViewOp.getDynamicStrides())) {
+    for (auto stride : enumerate(subViewOp.strides())) {
       auto defOp = stride.value()->getDefiningOp();
       assert(defOp);
       assert(baseStrides[stride.index()] > 0);
@@ -2858,15 +2849,15 @@ public:
     MemRefType newMemRefType =
         MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
                         layoutMap, subViewType.getMemorySpace());
-    auto newSubViewOp = rewriter.create<SubViewOp>(
-        subViewOp.getLoc(), subViewOp.source(),
-        llvm::to_vector<4>(subViewOp.getDynamicOffsets()),
-        llvm::to_vector<4>(subViewOp.getDynamicSizes()), ArrayRef<Value *>(),
-        newMemRefType);
+    auto newSubViewOp =
+        rewriter.create<SubViewOp>(subViewOp.getLoc(), subViewOp.source(),
+                                   llvm::to_vector<4>(subViewOp.offsets()),
+                                   llvm::to_vector<4>(subViewOp.sizes()),
+                                   ArrayRef<Value *>(), newMemRefType);
     // Insert a memref_cast for compatibility of the uses of the op.
     rewriter.replaceOpWithNewOp<MemRefCastOp>(
-        llvm::to_vector<4>(subViewOp.getDynamicStrides()), subViewOp,
-        newSubViewOp, subViewOp.getType());
+        llvm::to_vector<4>(subViewOp.strides()), subViewOp, newSubViewOp,
+        subViewOp.getType());
     return matchSuccess();
   }
 };
@@ -2893,14 +2884,14 @@ public:
         llvm::is_contained(baseStrides,
                            MemRefType::getDynamicStrideOrOffset()) ||
         baseOffset == MemRefType::getDynamicStrideOrOffset() ||
-        llvm::any_of(subViewOp.getDynamicOffsets(), [](Value *stride) {
+        llvm::any_of(subViewOp.offsets(), [](Value *stride) {
          return !matchPattern(stride, m_ConstantIndex());
         })) {
       return matchFailure();
     }
 
     auto staticOffset = baseOffset;
-    for (auto offset : enumerate(subViewOp.getDynamicOffsets())) {
+    for (auto offset : enumerate(subViewOp.offsets())) {
       auto defOp = offset.value()->getDefiningOp();
       assert(defOp);
       assert(baseStrides[offset.index()] > 0);
@@ -2915,39 +2906,17 @@ public:
                         layoutMap, subViewType.getMemorySpace());
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value *>(),
-        llvm::to_vector<4>(subViewOp.getDynamicSizes()),
-        llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType);
+        llvm::to_vector<4>(subViewOp.sizes()),
+        llvm::to_vector<4>(subViewOp.strides()), newMemRefType);
     // Insert a memref_cast for compatibility of the uses of the op.
     rewriter.replaceOpWithNewOp<MemRefCastOp>(
-        llvm::to_vector<4>(subViewOp.getDynamicOffsets()), subViewOp,
-        newSubViewOp, subViewOp.getType());
+        llvm::to_vector<4>(subViewOp.offsets()), subViewOp, newSubViewOp,
+        subViewOp.getType());
     return matchSuccess();
   }
 };
 
 } // end anonymous namespace
 
-SubViewOp::operand_range SubViewOp::getDynamicOffsets() {
-  auto numOffsets = getNumOffsets();
-  assert(getNumOperands() >= numOffsets + 1);
-  return {operand_begin() + 1, operand_begin() + 1 + numOffsets};
-}
-
-SubViewOp::operand_range SubViewOp::getDynamicSizes() {
-  auto numSizes = getNumSizes();
-  auto numOffsets = getNumOffsets();
-  assert(getNumOperands() >= numSizes + numOffsets + 1);
-  return {operand_begin() + 1 + numOffsets,
-          operand_begin() + 1 + numOffsets + numSizes};
-}
-
-SubViewOp::operand_range SubViewOp::getDynamicStrides() {
-  auto numSizes = getNumSizes();
-  auto numOffsets = getNumOffsets();
-  auto numStrides = getNumStrides();
-  assert(getNumOperands() >= numSizes + numOffsets + numStrides + 1);
-  return {operand_begin() + (1 + numOffsets + numSizes),
-          operand_begin() + (1 + numOffsets + numSizes + numStrides)};
-}
-
 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
@@ -100,6 +100,14 @@ IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
 }
 
+DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
+  return DenseElementsAttr::get(
+             VectorType::get(static_cast<int64_t>(values.size()),
+                             getIntegerType(32)),
+             values)
+      .cast<DenseIntElementsAttr>();
+}
+
 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
   return IntegerAttr::get(getIntegerType(32), APInt(32, value));
 }
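
The new helper is a thin wrapper that types a DenseElementsAttr as a vector<Nxi32>. A hedged usage sketch (include paths assumed for the tree at this revision):

    #include "mlir/IR/Attributes.h" // assumed header locations
    #include "mlir/IR/Builders.h"

    using namespace mlir;

    // Produces the attribute printed as dense<[1, 2, 2, 2]> : vector<4xi32>,
    // i.e. the operand_segment_sizes value for a rank-2 subview.
    static DenseIntElementsAttr makeSubViewSegments(Builder &b) {
      return b.getI32VectorAttr({1, 2, 2, 2});
    }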