[mlir] Switch {collapse,expand}_shape ops to the declarative assembly format

Same functionality, a lot less code.
This commit is contained in:
Benjamin Kramer 2022-02-17 19:56:38 +01:00
parent 4dfa68e483
commit 1af15de6b7
6 changed files with 8 additions and 109 deletions

View File

@ -1240,9 +1240,12 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
Value getViewSource() { return src(); }
}];
let assemblyFormat = [{
$src $reassociation attr-dict `:` type($src) `into` type($result)
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

View File

@ -732,9 +732,12 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
}
}];
let assemblyFormat = [{
$src $reassociation attr-dict `:` type($src) `into` type($result)
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

View File

@ -74,31 +74,6 @@ getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
bool isReassociationValid(ArrayRef<AffineMap> reassociation,
int *invalidIndex = nullptr);
/// Parse a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
/// linalg::(Tensor)CollapseShapeOp.
ParseResult parseReshapeLikeOp(OpAsmParser &parser, OperationState &result);
/// Print a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
/// linalg::(Tensor)CollapseShapeOp.
template <typename ReshapeLikeOp>
void printReshapeOp(OpAsmPrinter &p, ReshapeLikeOp op) {
p << ' ' << op.src() << " [";
llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) {
p << '[';
auto arrayAttr = attr.template cast<ArrayAttr>();
llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) {
p << attr.cast<IntegerAttr>().getInt();
});
p << ']';
});
p << "] ";
p.printOptionalAttrDict(op->getAttrs(),
/*elidedAttrs=*/{getReassociationAttrName()});
p << ": " << op.src().getType() << " into " << op.getType();
}
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
ArrayRef<Attribute> operands) {

View File

@ -1370,21 +1370,6 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
getReassociationIndices());
}
ParseResult ExpandShapeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseReshapeLikeOp(parser, result);
}
void ExpandShapeOp::print(OpAsmPrinter &p) {
::mlir::printReshapeOp<ExpandShapeOp>(p, *this);
}
ParseResult CollapseShapeOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseReshapeLikeOp(parser, result);
}
void CollapseShapeOp::print(OpAsmPrinter &p) {
::mlir::printReshapeOp<CollapseShapeOp>(p, *this);
}
/// Detect whether memref dims [dim, dim + extent) can be reshaped without
/// copies.
static bool isReshapableDimBand(unsigned dim, unsigned extent,

View File

@ -733,17 +733,6 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
getReassociationIndices());
}
ParseResult ExpandShapeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseReshapeLikeOp(parser, result);
}
void ExpandShapeOp::print(OpAsmPrinter &p) { printReshapeOp(p, *this); }
ParseResult CollapseShapeOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseReshapeLikeOp(parser, result);
}
void CollapseShapeOp::print(OpAsmPrinter &p) { printReshapeOp(p, *this); }
/// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
static RankedTensorType
computeTensorReshapeCollapsedType(RankedTensorType type,

View File

@ -91,62 +91,6 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
return reassociationMap;
}
ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
OperationState &result) {
// Parse the operand.
OpAsmParser::OperandType src;
if (parser.parseOperand(src))
return failure();
// Parse reassociation indices.
Builder &b = parser.getBuilder();
SmallVector<Attribute, 4> reassociation;
if (parser.parseLSquare())
return failure();
while (true) {
if (succeeded(parser.parseOptionalRSquare()))
break;
if (parser.parseLSquare())
return failure();
SmallVector<int64_t> indices;
while (true) {
int64_t index;
if (parser.parseInteger(index))
return failure();
indices.push_back(index);
if (succeeded(parser.parseOptionalComma()))
continue;
if (failed(parser.parseRSquare()))
return failure();
break;
}
reassociation.push_back(b.getI64ArrayAttr(indices));
if (succeeded(parser.parseOptionalComma()))
continue;
if (failed(parser.parseRSquare()))
return failure();
break;
}
result.addAttribute(getReassociationAttrName(),
b.getArrayAttr(reassociation));
// Parse optional attributes.
parser.parseOptionalAttrDict(result.attributes);
// Parse types.
Type srcType;
Type resultType;
if (parser.parseColon() || parser.parseType(srcType) ||
parser.resolveOperand(src, srcType, result.operands) ||
parser.parseKeyword("into") || parser.parseType(resultType))
return failure();
result.addTypes(resultType);
return success();
}
Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
ArrayRef<ReassociationIndices> producerReassociations,
ArrayRef<ReassociationIndices> consumerReassociations,