forked from OSchip/llvm-project
[mlir] Switch {collapse,expand}_shape ops to the declarative assembly format
Same functionality, a lot less code.
This commit is contained in:
parent
4dfa68e483
commit
1af15de6b7
|
@ -1240,9 +1240,12 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
|
||||||
Value getViewSource() { return src(); }
|
Value getViewSource() { return src(); }
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$src $reassociation attr-dict `:` type($src) `into` type($result)
|
||||||
|
}];
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -732,9 +732,12 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$src $reassociation attr-dict `:` type($src) `into` type($result)
|
||||||
|
}];
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -74,31 +74,6 @@ getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
|
||||||
bool isReassociationValid(ArrayRef<AffineMap> reassociation,
|
bool isReassociationValid(ArrayRef<AffineMap> reassociation,
|
||||||
int *invalidIndex = nullptr);
|
int *invalidIndex = nullptr);
|
||||||
|
|
||||||
/// Parse a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
|
|
||||||
/// linalg::(Tensor)CollapseShapeOp.
|
|
||||||
ParseResult parseReshapeLikeOp(OpAsmParser &parser, OperationState &result);
|
|
||||||
|
|
||||||
/// Print a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
|
|
||||||
/// linalg::(Tensor)CollapseShapeOp.
|
|
||||||
template <typename ReshapeLikeOp>
|
|
||||||
void printReshapeOp(OpAsmPrinter &p, ReshapeLikeOp op) {
|
|
||||||
p << ' ' << op.src() << " [";
|
|
||||||
|
|
||||||
llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) {
|
|
||||||
p << '[';
|
|
||||||
auto arrayAttr = attr.template cast<ArrayAttr>();
|
|
||||||
llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) {
|
|
||||||
p << attr.cast<IntegerAttr>().getInt();
|
|
||||||
});
|
|
||||||
p << ']';
|
|
||||||
});
|
|
||||||
|
|
||||||
p << "] ";
|
|
||||||
p.printOptionalAttrDict(op->getAttrs(),
|
|
||||||
/*elidedAttrs=*/{getReassociationAttrName()});
|
|
||||||
p << ": " << op.src().getType() << " into " << op.getType();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
|
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
|
||||||
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
|
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
|
||||||
ArrayRef<Attribute> operands) {
|
ArrayRef<Attribute> operands) {
|
||||||
|
|
|
@ -1370,21 +1370,6 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
|
||||||
getReassociationIndices());
|
getReassociationIndices());
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult ExpandShapeOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
return parseReshapeLikeOp(parser, result);
|
|
||||||
}
|
|
||||||
void ExpandShapeOp::print(OpAsmPrinter &p) {
|
|
||||||
::mlir::printReshapeOp<ExpandShapeOp>(p, *this);
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult CollapseShapeOp::parse(OpAsmParser &parser,
|
|
||||||
OperationState &result) {
|
|
||||||
return parseReshapeLikeOp(parser, result);
|
|
||||||
}
|
|
||||||
void CollapseShapeOp::print(OpAsmPrinter &p) {
|
|
||||||
::mlir::printReshapeOp<CollapseShapeOp>(p, *this);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Detect whether memref dims [dim, dim + extent) can be reshaped without
|
/// Detect whether memref dims [dim, dim + extent) can be reshaped without
|
||||||
/// copies.
|
/// copies.
|
||||||
static bool isReshapableDimBand(unsigned dim, unsigned extent,
|
static bool isReshapableDimBand(unsigned dim, unsigned extent,
|
||||||
|
|
|
@ -733,17 +733,6 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
|
||||||
getReassociationIndices());
|
getReassociationIndices());
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult ExpandShapeOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
return parseReshapeLikeOp(parser, result);
|
|
||||||
}
|
|
||||||
void ExpandShapeOp::print(OpAsmPrinter &p) { printReshapeOp(p, *this); }
|
|
||||||
|
|
||||||
ParseResult CollapseShapeOp::parse(OpAsmParser &parser,
|
|
||||||
OperationState &result) {
|
|
||||||
return parseReshapeLikeOp(parser, result);
|
|
||||||
}
|
|
||||||
void CollapseShapeOp::print(OpAsmPrinter &p) { printReshapeOp(p, *this); }
|
|
||||||
|
|
||||||
/// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
|
/// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
|
||||||
static RankedTensorType
|
static RankedTensorType
|
||||||
computeTensorReshapeCollapsedType(RankedTensorType type,
|
computeTensorReshapeCollapsedType(RankedTensorType type,
|
||||||
|
|
|
@ -91,62 +91,6 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
|
||||||
return reassociationMap;
|
return reassociationMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
|
|
||||||
OperationState &result) {
|
|
||||||
// Parse the operand.
|
|
||||||
OpAsmParser::OperandType src;
|
|
||||||
if (parser.parseOperand(src))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Parse reassociation indices.
|
|
||||||
Builder &b = parser.getBuilder();
|
|
||||||
SmallVector<Attribute, 4> reassociation;
|
|
||||||
if (parser.parseLSquare())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
if (succeeded(parser.parseOptionalRSquare()))
|
|
||||||
break;
|
|
||||||
if (parser.parseLSquare())
|
|
||||||
return failure();
|
|
||||||
SmallVector<int64_t> indices;
|
|
||||||
while (true) {
|
|
||||||
int64_t index;
|
|
||||||
if (parser.parseInteger(index))
|
|
||||||
return failure();
|
|
||||||
indices.push_back(index);
|
|
||||||
|
|
||||||
if (succeeded(parser.parseOptionalComma()))
|
|
||||||
continue;
|
|
||||||
if (failed(parser.parseRSquare()))
|
|
||||||
return failure();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
reassociation.push_back(b.getI64ArrayAttr(indices));
|
|
||||||
if (succeeded(parser.parseOptionalComma()))
|
|
||||||
continue;
|
|
||||||
if (failed(parser.parseRSquare()))
|
|
||||||
return failure();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.addAttribute(getReassociationAttrName(),
|
|
||||||
b.getArrayAttr(reassociation));
|
|
||||||
|
|
||||||
// Parse optional attributes.
|
|
||||||
parser.parseOptionalAttrDict(result.attributes);
|
|
||||||
|
|
||||||
// Parse types.
|
|
||||||
Type srcType;
|
|
||||||
Type resultType;
|
|
||||||
if (parser.parseColon() || parser.parseType(srcType) ||
|
|
||||||
parser.resolveOperand(src, srcType, result.operands) ||
|
|
||||||
parser.parseKeyword("into") || parser.parseType(resultType))
|
|
||||||
return failure();
|
|
||||||
result.addTypes(resultType);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
|
Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
|
||||||
ArrayRef<ReassociationIndices> producerReassociations,
|
ArrayRef<ReassociationIndices> producerReassociations,
|
||||||
ArrayRef<ReassociationIndices> consumerReassociations,
|
ArrayRef<ReassociationIndices> consumerReassociations,
|
||||||
|
|
Loading…
Reference in New Issue