NFC: Update std.subview op to use AttrSizedOperandSegments

This turns a few manually written helper methods into auto-generated ones.

PiperOrigin-RevId: 283339617
This commit is contained in:
Lei Zhang 2019-12-02 07:51:27 -08:00 committed by A. Unique TensorFlower
parent 4231de7897
commit 0d22a3fdc8
5 changed files with 77 additions and 109 deletions

View File

@ -1248,7 +1248,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> {
let hasCanonicalizer = 1;
}
def SubViewOp : Std_Op<"subview", [NoSideEffect]> {
def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
let summary = "memref subview operation";
let description = [{
The "subview" operation converts a memref type to another memref type
@ -1356,23 +1356,25 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> {
// TODO(b/144779634, ravishankarm) : Use different arguments for
// offsets, sizes and strides.
let arguments = (ins AnyMemRef:$source, I32Attr:$num_offsets,
I32Attr:$num_sizes, I32Attr:$num_strides,
Variadic<Index>:$operands);
let arguments = (ins
AnyMemRef:$source,
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
I32ElementsAttr:$operand_segment_sizes
);
let results = (outs AnyMemRef);
let builders = [OpBuilder<
"Builder *b, OperationState &result, Value *source, "
"ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, "
"ArrayRef<Value *> strides, Type resultType = Type(), "
"ArrayRef<NamedAttribute> attrs = {}">,
let builders = [
OpBuilder<
"Builder *builder, OperationState &result, Type resultType, Value *source">,
"Builder *b, OperationState &result, Value *source, "
"ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, "
"ArrayRef<Value *> strides, Type resultType = Type(), "
"ArrayRef<NamedAttribute> attrs = {}">,
OpBuilder<
"Builder *builder, OperationState &result, Type resultType, Value *source, "
"unsigned num_offsets, unsigned num_sizes, unsigned num_strides, "
"ArrayRef<Value *> offsets, ArrayRef<Value *> sizes, "
"ArrayRef<Value *> strides">];
"Builder *builder, OperationState &result, "
"Type resultType, Value *source">
];
let extraClassDeclaration = [{
/// Returns the type of the base memref operand.
@ -1384,28 +1386,16 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> {
MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
/// Returns as integer value the number of offset operands.
int64_t getNumOffsets() {
return num_offsets().getSExtValue();
}
int64_t getNumOffsets() { return llvm::size(offsets()); }
/// Returns as integer value the number of size operands.
int64_t getNumSizes() {
return num_sizes().getSExtValue();
}
int64_t getNumSizes() { return llvm::size(sizes()); }
/// Returns as integer value the number of stride operands.
int64_t getNumStrides() {
return num_strides().getSExtValue();
}
/// Returns the dynamic offsets for this subview operation.
operand_range getDynamicOffsets();
int64_t getNumStrides() { return llvm::size(strides()); }
/// Returns the dynamic sizes for this subview operation if specified.
operand_range getDynamicSizes();
/// Returns the dynamic strides for this subview operation if specified.
operand_range getDynamicStrides();
operand_range getDynamicSizes() { return sizes(); }
// Auxiliary range data structure and helper function that unpacks the
// offset, size and stride operands of the SubViewOp into a list of triples.

View File

@ -120,6 +120,8 @@ public:
IntegerAttr getI32IntegerAttr(int32_t value);
IntegerAttr getI64IntegerAttr(int64_t value);
DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);

View File

@ -1476,7 +1476,6 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto viewOp = cast<SubViewOp>(op);
SubViewOpOperandAdaptor adaptor(operands);
// TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support
// having multiple variadic operands where each operand can have different
// number of entries, clean all of this up.
@ -1518,7 +1517,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
return matchFailure();
// Create the descriptor.
MemRefDescriptor sourceMemRef(adaptor.source());
MemRefDescriptor sourceMemRef(operands.front());
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
// Copy the buffer pointer from the old descriptor to the new one.

View File

@ -1370,7 +1370,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
// Fold dim to the size argument of a SubViewOp.
auto memref = memrefOrTensor()->getDefiningOp();
if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
auto sizes = subview.getDynamicSizes();
auto sizes = subview.sizes();
if (!sizes.empty())
return *(sizes.begin() + getIndex());
}
@ -2563,35 +2563,23 @@ static Type inferSubViewResultType(MemRefType memRefType) {
memRefType.getMemorySpace());
}
void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType,
Value *source, unsigned num_offsets,
unsigned num_sizes, unsigned num_strides,
ArrayRef<Value *> offsets, ArrayRef<Value *> sizes,
ArrayRef<Value *> strides) {
SmallVector<Value *, 8> operands;
operands.reserve(num_offsets + num_sizes + num_strides);
operands.append(offsets.begin(), offsets.end());
operands.append(sizes.begin(), sizes.end());
operands.append(strides.begin(), strides.end());
build(b, result, resultType, source, b->getI32IntegerAttr(num_offsets),
b->getI32IntegerAttr(num_sizes), b->getI32IntegerAttr(num_strides),
operands);
}
void mlir::SubViewOp::build(Builder *b, OperationState &result, Value *source,
ArrayRef<Value *> offsets, ArrayRef<Value *> sizes,
ArrayRef<Value *> strides, Type resultType,
ArrayRef<NamedAttribute> attrs) {
if (!resultType)
resultType = inferSubViewResultType(source->getType().cast<MemRefType>());
build(b, result, resultType, source, offsets.size(), sizes.size(),
strides.size(), offsets, sizes, strides);
auto segmentAttr = b->getI32VectorAttr(
{1, static_cast<int>(offsets.size()), static_cast<int32_t>(sizes.size()),
static_cast<int32_t>(strides.size())});
build(b, result, resultType, source, offsets, sizes, strides, segmentAttr);
result.addAttributes(attrs);
}
void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType,
Value *source) {
build(b, result, resultType, source, 0, 0, 0, {}, {}, {});
build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{},
resultType);
}
static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
@ -2607,12 +2595,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) {
return failure();
}
auto builder = parser.getBuilder();
result.addAttribute("num_offsets",
builder.getI32IntegerAttr(offsetsInfo.size()));
result.addAttribute("num_sizes", builder.getI32IntegerAttr(sizesInfo.size()));
result.addAttribute("num_strides",
builder.getI32IntegerAttr(stridesInfo.size()));
result.addAttribute(
SubViewOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr({1, static_cast<int>(offsetsInfo.size()),
static_cast<int32_t>(sizesInfo.size()),
static_cast<int32_t>(stridesInfo.size())}));
return failure(
parser.parseOptionalAttrDict(result.attributes) ||
@ -2627,14 +2616,15 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
static void print(OpAsmPrinter &p, SubViewOp op) {
p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
p.printOperands(op.getDynamicOffsets());
p.printOperands(op.offsets());
p << "][";
p.printOperands(op.getDynamicSizes());
p.printOperands(op.sizes());
p << "][";
p.printOperands(op.getDynamicStrides());
p.printOperands(op.strides());
p << ']';
SmallVector<StringRef, 3> elidedAttrs = {"num_offsets", "num_sizes",
"num_strides"};
SmallVector<StringRef, 1> elidedAttrs = {
SubViewOp::getOperandSegmentSizeAttr()};
p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
}
@ -2689,14 +2679,16 @@ static LogicalResult verify(SubViewOp op) {
}
// Verify that if the shape of the subview type is static, then sizes are not
// dynamic values, and viceversa.
// dynamic values, and vice versa.
if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) ||
(op.getNumSizes() == 0 && !subViewType.hasStaticShape())) {
return op.emitError("invalid to specify dynamic sizes when subview result "
"type is statically shaped and viceversa");
}
// Verify that if dynamic sizes are specified, then the result memref type
// have full dynamic dimensions.
if (op.getNumSizes() > 0) {
// Verify that non if the shape values of the result type are static.
if (llvm::any_of(subViewType.getShape(), [](int64_t dim) {
return dim != ShapedType::kDynamicSize;
})) {
@ -2758,9 +2750,8 @@ SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
unsigned rank = getType().getRank();
res.reserve(rank);
for (unsigned i = 0; i < rank; ++i)
res.emplace_back(Range{*(getDynamicOffsets().begin() + i),
*(getDynamicSizes().begin() + i),
*(getDynamicStrides().begin() + i)});
res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i),
*(strides().begin() + i)});
return res;
}
@ -2792,13 +2783,13 @@ public:
// Follow all or nothing approach for shapes for now. If all the operands
// for sizes are constants then fold it into the type of the result memref.
if (subViewType.hasStaticShape() ||
llvm::any_of(subViewOp.getDynamicSizes(), [](Value *operand) {
llvm::any_of(subViewOp.sizes(), [](Value *operand) {
return !matchPattern(operand, m_ConstantIndex());
})) {
return matchFailure();
}
SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
for (auto size : enumerate(subViewOp.getDynamicSizes())) {
for (auto size : enumerate(subViewOp.sizes())) {
auto defOp = size.value()->getDefiningOp();
assert(defOp);
staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
@ -2808,12 +2799,12 @@ public:
subViewType.getMemorySpace());
auto newSubViewOp = rewriter.create<SubViewOp>(
subViewOp.getLoc(), subViewOp.source(),
llvm::to_vector<4>(subViewOp.getDynamicOffsets()), ArrayRef<Value *>(),
llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType);
llvm::to_vector<4>(subViewOp.offsets()), ArrayRef<Value *>(),
llvm::to_vector<4>(subViewOp.strides()), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(
llvm::to_vector<4>(subViewOp.getDynamicSizes()), subViewOp,
newSubViewOp, subViewOp.getType());
llvm::to_vector<4>(subViewOp.sizes()), subViewOp, newSubViewOp,
subViewOp.getType());
return matchSuccess();
}
};
@ -2839,14 +2830,14 @@ public:
failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
llvm::is_contained(baseStrides,
MemRefType::getDynamicStrideOrOffset()) ||
llvm::any_of(subViewOp.getDynamicStrides(), [](Value *stride) {
llvm::any_of(subViewOp.strides(), [](Value *stride) {
return !matchPattern(stride, m_ConstantIndex());
})) {
return matchFailure();
}
SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
for (auto stride : enumerate(subViewOp.getDynamicStrides())) {
for (auto stride : enumerate(subViewOp.strides())) {
auto defOp = stride.value()->getDefiningOp();
assert(defOp);
assert(baseStrides[stride.index()] > 0);
@ -2858,15 +2849,15 @@ public:
MemRefType newMemRefType =
MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
layoutMap, subViewType.getMemorySpace());
auto newSubViewOp = rewriter.create<SubViewOp>(
subViewOp.getLoc(), subViewOp.source(),
llvm::to_vector<4>(subViewOp.getDynamicOffsets()),
llvm::to_vector<4>(subViewOp.getDynamicSizes()), ArrayRef<Value *>(),
newMemRefType);
auto newSubViewOp =
rewriter.create<SubViewOp>(subViewOp.getLoc(), subViewOp.source(),
llvm::to_vector<4>(subViewOp.offsets()),
llvm::to_vector<4>(subViewOp.sizes()),
ArrayRef<Value *>(), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(
llvm::to_vector<4>(subViewOp.getDynamicStrides()), subViewOp,
newSubViewOp, subViewOp.getType());
llvm::to_vector<4>(subViewOp.strides()), subViewOp, newSubViewOp,
subViewOp.getType());
return matchSuccess();
}
};
@ -2893,14 +2884,14 @@ public:
llvm::is_contained(baseStrides,
MemRefType::getDynamicStrideOrOffset()) ||
baseOffset == MemRefType::getDynamicStrideOrOffset() ||
llvm::any_of(subViewOp.getDynamicOffsets(), [](Value *stride) {
llvm::any_of(subViewOp.offsets(), [](Value *stride) {
return !matchPattern(stride, m_ConstantIndex());
})) {
return matchFailure();
}
auto staticOffset = baseOffset;
for (auto offset : enumerate(subViewOp.getDynamicOffsets())) {
for (auto offset : enumerate(subViewOp.offsets())) {
auto defOp = offset.value()->getDefiningOp();
assert(defOp);
assert(baseStrides[offset.index()] > 0);
@ -2915,39 +2906,17 @@ public:
layoutMap, subViewType.getMemorySpace());
auto newSubViewOp = rewriter.create<SubViewOp>(
subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value *>(),
llvm::to_vector<4>(subViewOp.getDynamicSizes()),
llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType);
llvm::to_vector<4>(subViewOp.sizes()),
llvm::to_vector<4>(subViewOp.strides()), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(
llvm::to_vector<4>(subViewOp.getDynamicOffsets()), subViewOp,
newSubViewOp, subViewOp.getType());
llvm::to_vector<4>(subViewOp.offsets()), subViewOp, newSubViewOp,
subViewOp.getType());
return matchSuccess();
}
};
} // end anonymous namespace
SubViewOp::operand_range SubViewOp::getDynamicOffsets() {
auto numOffsets = getNumOffsets();
assert(getNumOperands() >= numOffsets + 1);
return {operand_begin() + 1, operand_begin() + 1 + numOffsets};
}
SubViewOp::operand_range SubViewOp::getDynamicSizes() {
auto numSizes = getNumSizes();
auto numOffsets = getNumOffsets();
assert(getNumOperands() >= numSizes + numOffsets + 1);
return {operand_begin() + 1 + numOffsets,
operand_begin() + 1 + numOffsets + numSizes};
}
SubViewOp::operand_range SubViewOp::getDynamicStrides() {
auto numSizes = getNumSizes();
auto numOffsets = getNumOffsets();
auto numStrides = getNumStrides();
assert(getNumOperands() >= numSizes + numOffsets + numStrides + 1);
return {operand_begin() + (1 + numOffsets + numSizes),
operand_begin() + (1 + numOffsets + numSizes + numStrides)};
}
void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {

View File

@ -100,6 +100,14 @@ IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
return IntegerAttr::get(getIntegerType(64), APInt(64, value));
}
DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
return DenseElementsAttr::get(
VectorType::get(static_cast<int64_t>(values.size()),
getIntegerType(32)),
values)
.cast<DenseIntElementsAttr>();
}
IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
return IntegerAttr::get(getIntegerType(32), APInt(32, value));
}