[mlir] Flip MemRef dialect to _Both (NFC)

This commit is contained in:
Jacques Pienaar 2022-06-26 20:45:25 -07:00
parent 24e53b01d5
commit 655dc02cb0
3 changed files with 15 additions and 41 deletions

View File

@ -22,9 +22,7 @@ def MemRef_Dialect : Dialect {
let dependentDialects = ["arith::ArithmeticDialect"];
let hasConstantMaterializer = 1;
// TODO: This has overlapping accessors with generated when switched to
// prefixed. Fix and update to _Both & then _Prefixed.
let emitAccessorPrefix = kEmitAccessorPrefix_Raw;
let emitAccessorPrefix = kEmitAccessorPrefix_Both;
}
#endif // MEMREF_BASE

View File

@ -90,16 +90,13 @@ class AllocLikeOp<string mnemonic,
static_cast<int32_t>(dynamicSizes.size()),
static_cast<int32_t>(symbolOperands.size())}));
if (alignment)
$_state.addAttribute(getAlignmentAttrName(), alignment);
$_state.addAttribute(getAlignmentAttrStrName(), alignment);
}]>];
let extraClassDeclaration = [{
static StringRef getAlignmentAttrName() { return "alignment"; }
static StringRef getAlignmentAttrStrName() { return "alignment"; }
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
/// Returns the dynamic sizes for this alloc operation if specified.
operand_range getDynamicSizes() { return dynamicSizes(); }
}];
let assemblyFormat = [{
@ -407,11 +404,6 @@ def CopyOp : MemRef_Op<"copy",
Arg<AnyRankedOrUnrankedMemRef, "the memref to copy to",
[MemWrite]>:$target);
let extraClassDeclaration = [{
Value getSource() { return source();}
Value getTarget() { return target(); }
}];
let assemblyFormat = [{
$source `,` $target attr-dict `:` type($source) `to` type($target)
}];
@ -602,6 +594,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
Value getTagMemRef() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
}
// Returns the rank (number of indices) of the tag MemRefType.
unsigned getTagMemRefRank() {
return getTagMemRef().getType().cast<MemRefType>().getRank();
@ -691,21 +684,10 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
type($tagMemRef)
}];
let extraClassDeclaration = [{
/// Returns the Tag MemRef associated with the DMA operation being waited
/// on.
Value getTagMemRef() { return tagMemRef(); }
/// Returns the tag memref index for this DMA operation.
operand_range getTagIndices() { return tagIndices(); }
/// Returns the rank (number of indices) of the tag memref.
unsigned getTagMemRefRank() {
return getTagMemRef().getType().cast<MemRefType>().getRank();
}
/// Returns the number of elements transferred in the associated DMA
/// operation.
Value getNumElements() { return numElements(); }
}];
let hasFolder = 1;
let hasVerifier = 1;
@ -950,8 +932,6 @@ def LoadOp : MemRef_Op<"load",
MemRefType getMemRefType() {
return getMemRef().getType().cast<MemRefType>();
}
operand_range getIndices() { return {operand_begin() + 1, operand_end()}; }
}];
let hasFolder = 1;
@ -993,9 +973,9 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
MemRefType getMemRefType() {
return memref().getType().cast<MemRefType>();
}
static StringRef getLocalityHintAttrName() { return "localityHint"; }
static StringRef getIsWriteAttrName() { return "isWrite"; }
static StringRef getIsDataCacheAttrName() { return "isDataCache"; }
static StringRef getLocalityHintAttrStrName() { return "localityHint"; }
static StringRef getIsWriteAttrStrName() { return "isWrite"; }
static StringRef getIsDataCacheAttrStrName() { return "isDataCache"; }
}];
let hasCustomAssemblyFormat = 1;
@ -1192,7 +1172,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
[NoSideEffect, ViewLikeOpInterface])>,
Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
Results<(outs AnyStridedMemRef:$result)>{
code commonExtraClassDeclaration = [{
SmallVector<AffineMap, 4> getReassociationMaps();
@ -1442,10 +1422,6 @@ def MemRef_StoreOp : MemRef_Op<"store",
MemRefType getMemRefType() {
return getMemRef().getType().cast<MemRefType>();
}
operand_range getIndices() {
return {operand_begin() + 2, operand_end()};
}
}];
let hasFolder = 1;
@ -1756,7 +1732,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [NoSideEffect]>,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
let extraClassDeclaration = [{
static StringRef getPermutationAttrName() { return "permutation"; }
static StringRef getPermutationAttrStrName() { return "permutation"; }
ShapedType getShapedType() { return in().getType().cast<ShapedType>(); }
}];

View File

@ -1415,7 +1415,7 @@ ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
return parser.emitError(parser.getNameLoc(),
"rw specifier has to be 'read' or 'write'");
result.addAttribute(
PrefetchOp::getIsWriteAttrName(),
PrefetchOp::getIsWriteAttrStrName(),
parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
if (!cacheType.equals("data") && !cacheType.equals("instr"))
@ -1423,7 +1423,7 @@ ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
"cache type has to be 'data' or 'instr'");
result.addAttribute(
PrefetchOp::getIsDataCacheAttrName(),
PrefetchOp::getIsDataCacheAttrStrName(),
parser.getBuilder().getBoolAttr(cacheType.equals("data")));
return success();
@ -1932,7 +1932,7 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
auto srcType = src.getType().cast<MemRefType>();
MemRefType resultType = computeCollapsedType(srcType, reassociation);
build(b, result, resultType, src, attrs);
result.addAttribute(getReassociationAttrName(),
result.addAttribute(::mlir::getReassociationAttrName(),
getReassociationIndicesAttribute(b, reassociation));
}
@ -2663,13 +2663,13 @@ void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
build(b, result, resultType, in, attrs);
result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
}
// transpose $in $permutation attr-dict : type($in) `to` type(results)
void TransposeOp::print(OpAsmPrinter &p) {
p << " " << in() << " " << permutation();
p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
p << " : " << in().getType() << " to " << getType();
}
@ -2685,7 +2685,7 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
parser.addTypeToList(dstType, result.types))
return failure();
result.addAttribute(TransposeOp::getPermutationAttrName(),
result.addAttribute(TransposeOp::getPermutationAttrStrName(),
AffineMapAttr::get(permutation));
return success();
}