forked from OSchip/llvm-project
[mlir] Add MemRefTypeBuilder and refactor some MemRefType::get().
The refactored MemRefType::get() calls all intend to clone from another memref type, with some modifications. In fact, some calls dropped memory space during the cloning. Migrate them to the cloning API so that nothing gets dropped if they are not explicitly listed. It's close to NFC but not quite, as it helps with propagating memory spaces in some places. Differential Revision: https://reviews.llvm.org/D73296
This commit is contained in:
parent
381e81a048
commit
3ccaac3cdd
|
@ -390,6 +390,52 @@ public:
|
|||
class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
|
||||
detail::MemRefTypeStorage> {
|
||||
public:
|
||||
/// This is a builder type that keeps local references to arguments. Arguments
|
||||
/// that are passed into the builder must out-live the builder.
|
||||
class Builder {
|
||||
public:
|
||||
// Build from another MemRefType.
|
||||
explicit Builder(MemRefType other)
|
||||
: shape(other.getShape()), elementType(other.getElementType()),
|
||||
affineMaps(other.getAffineMaps()),
|
||||
memorySpace(other.getMemorySpace()) {}
|
||||
|
||||
// Build from scratch.
|
||||
Builder(ArrayRef<int64_t> shape, Type elementType)
|
||||
: shape(shape), elementType(elementType), affineMaps(), memorySpace(0) {
|
||||
}
|
||||
|
||||
Builder &setShape(ArrayRef<int64_t> newShape) {
|
||||
shape = newShape;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder &setElementType(Type newElementType) {
|
||||
elementType = newElementType;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
|
||||
affineMaps = newAffineMaps;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder &setMemorySpace(unsigned newMemorySpace) {
|
||||
memorySpace = newMemorySpace;
|
||||
return *this;
|
||||
}
|
||||
|
||||
operator MemRefType() {
|
||||
return MemRefType::get(shape, elementType, affineMaps, memorySpace);
|
||||
}
|
||||
|
||||
private:
|
||||
ArrayRef<int64_t> shape;
|
||||
Type elementType;
|
||||
ArrayRef<AffineMap> affineMaps;
|
||||
unsigned memorySpace;
|
||||
};
|
||||
|
||||
using Base::Base;
|
||||
|
||||
/// Get or create a new MemRefType based on shape, element type, affine
|
||||
|
|
|
@ -41,8 +41,7 @@ public:
|
|||
auto memref = type.dyn_cast<MemRefType>();
|
||||
if (memref &&
|
||||
memref.getMemorySpace() == gpu::GPUDialect::getPrivateAddressSpace()) {
|
||||
type = MemRefType::get(memref.getShape(), memref.getElementType(),
|
||||
memref.getAffineMaps());
|
||||
type = MemRefType::Builder(memref).setMemorySpace(0);
|
||||
}
|
||||
|
||||
return LLVMTypeConverter::convertType(type);
|
||||
|
|
|
@ -168,8 +168,8 @@ inline Type castElementType(Type t, Type newElementType) {
|
|||
case StandardTypes::Kind::UnrankedTensor:
|
||||
return UnrankedTensorType::get(newElementType);
|
||||
case StandardTypes::Kind::MemRef:
|
||||
return MemRefType::get(st.getShape(), newElementType,
|
||||
st.cast<MemRefType>().getAffineMaps());
|
||||
return MemRefType::Builder(st.cast<MemRefType>())
|
||||
.setElementType(newElementType);
|
||||
}
|
||||
}
|
||||
assert(t.isIntOrFloat());
|
||||
|
|
|
@ -480,7 +480,7 @@ computeReshapeCollapsedType(MemRefType type,
|
|||
|
||||
// Early-exit: if `type` is contiguous, the result must be contiguous.
|
||||
if (canonicalizeStridedLayout(type).getAffineMaps().empty())
|
||||
return MemRefType::get(newSizes, type.getElementType(), {});
|
||||
return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
|
||||
|
||||
// Convert back to int64_t because we don't have enough information to create
|
||||
// new strided layouts from AffineExpr only. This corresponds to a case where
|
||||
|
@ -499,7 +499,7 @@ computeReshapeCollapsedType(MemRefType type,
|
|||
auto layout =
|
||||
makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
|
||||
return canonicalizeStridedLayout(
|
||||
MemRefType::get(newSizes, type.getElementType(), {layout}));
|
||||
MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
|
||||
}
|
||||
|
||||
/// Helper functions assert Attribute of the proper type in attr and returns the
|
||||
|
@ -613,11 +613,10 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
|
|||
unsigned rank = memRefType.getRank();
|
||||
// TODO(ntv): propagate static size and stride information when available.
|
||||
SmallVector<int64_t, 4> sizes(rank, -1); // -1 encodes dynamic size.
|
||||
Type elementType = memRefType.getElementType();
|
||||
result.addTypes({MemRefType::get(
|
||||
sizes, elementType,
|
||||
{makeStridedLinearLayoutMap(strides, offset, b->getContext())},
|
||||
memRefType.getMemorySpace())});
|
||||
result.addTypes({MemRefType::Builder(memRefType)
|
||||
.setShape(sizes)
|
||||
.setAffineMaps(makeStridedLinearLayoutMap(
|
||||
strides, offset, b->getContext()))});
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, SliceOp op) {
|
||||
|
@ -698,8 +697,8 @@ void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result,
|
|||
auto map = makeStridedLinearLayoutMap(strides, offset, b->getContext());
|
||||
map = permutationMap ? map.compose(permutationMap) : map;
|
||||
// Compute result type.
|
||||
auto resultType = MemRefType::get(sizes, memRefType.getElementType(), map,
|
||||
memRefType.getMemorySpace());
|
||||
MemRefType resultType =
|
||||
MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
|
||||
|
||||
build(b, result, resultType, view, attrs);
|
||||
result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
|
||||
|
|
|
@ -350,9 +350,8 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
|
|||
}
|
||||
|
||||
// Create new memref type (which will have fewer dynamic dimensions).
|
||||
auto newMemRefType = MemRefType::get(
|
||||
newShapeConstants, memrefType.getElementType(),
|
||||
memrefType.getAffineMaps(), memrefType.getMemorySpace());
|
||||
MemRefType newMemRefType =
|
||||
MemRefType::Builder(memrefType).setShape(newShapeConstants);
|
||||
assert(static_cast<int64_t>(newOperands.size()) ==
|
||||
newMemRefType.getNumDynamicDims());
|
||||
|
||||
|
@ -2453,9 +2452,9 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
|||
rewriter.getContext());
|
||||
|
||||
// Create new memref type with constant folded dims and/or offset/strides.
|
||||
auto newMemRefType =
|
||||
MemRefType::get(newShapeConstants, memrefType.getElementType(), {map},
|
||||
memrefType.getMemorySpace());
|
||||
MemRefType newMemRefType = MemRefType::Builder(memrefType)
|
||||
.setShape(newShapeConstants)
|
||||
.setAffineMaps({map});
|
||||
(void)dynamicOffsetOperandCount; // unused in opt mode
|
||||
assert(static_cast<int64_t>(newOperands.size()) ==
|
||||
dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());
|
||||
|
@ -2509,7 +2508,6 @@ static Type inferSubViewResultType(MemRefType memRefType) {
|
|||
auto rank = memRefType.getRank();
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
Type elementType = memRefType.getElementType();
|
||||
auto res = getStridesAndOffset(memRefType, strides, offset);
|
||||
assert(succeeded(res) && "SubViewOp expected strided memref type");
|
||||
(void)res;
|
||||
|
@ -2524,8 +2522,9 @@ static Type inferSubViewResultType(MemRefType memRefType) {
|
|||
auto stridedLayout =
|
||||
makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
|
||||
SmallVector<int64_t, 4> sizes(rank, ShapedType::kDynamicSize);
|
||||
return MemRefType::get(sizes, elementType, stridedLayout,
|
||||
memRefType.getMemorySpace());
|
||||
return MemRefType::Builder(memRefType)
|
||||
.setShape(sizes)
|
||||
.setAffineMaps(stridedLayout);
|
||||
}
|
||||
|
||||
void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source,
|
||||
|
@ -2774,9 +2773,8 @@ public:
|
|||
assert(defOp);
|
||||
staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
|
||||
}
|
||||
MemRefType newMemRefType = MemRefType::get(
|
||||
staticShape, subViewType.getElementType(), subViewType.getAffineMaps(),
|
||||
subViewType.getMemorySpace());
|
||||
MemRefType newMemRefType =
|
||||
MemRefType::Builder(subViewType).setShape(staticShape);
|
||||
auto newSubViewOp = rewriter.create<SubViewOp>(
|
||||
subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
|
||||
ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
|
||||
|
@ -2825,8 +2823,7 @@ public:
|
|||
AffineMap layoutMap = makeStridedLinearLayoutMap(
|
||||
staticStrides, resultOffset, rewriter.getContext());
|
||||
MemRefType newMemRefType =
|
||||
MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
|
||||
layoutMap, subViewType.getMemorySpace());
|
||||
MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
|
||||
auto newSubViewOp = rewriter.create<SubViewOp>(
|
||||
subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
|
||||
subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
|
||||
|
@ -2877,8 +2874,7 @@ public:
|
|||
AffineMap layoutMap = makeStridedLinearLayoutMap(
|
||||
resultStrides, staticOffset, rewriter.getContext());
|
||||
MemRefType newMemRefType =
|
||||
MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
|
||||
layoutMap, subViewType.getMemorySpace());
|
||||
MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
|
||||
auto newSubViewOp = rewriter.create<SubViewOp>(
|
||||
subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
|
||||
subViewOp.sizes(), subViewOp.strides(), newMemRefType);
|
||||
|
|
|
@ -723,11 +723,9 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
|
|||
auto simplifiedLayoutExpr =
|
||||
simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
|
||||
if (expr != simplifiedLayoutExpr)
|
||||
return MemRefType::get(t.getShape(), t.getElementType(),
|
||||
{AffineMap::get(m.getNumDims(), m.getNumSymbols(),
|
||||
{simplifiedLayoutExpr})});
|
||||
|
||||
return MemRefType::get(t.getShape(), t.getElementType(), {});
|
||||
return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
|
||||
m.getNumDims(), m.getNumSymbols(), {simplifiedLayoutExpr})});
|
||||
return MemRefType::Builder(t).setAffineMaps({});
|
||||
}
|
||||
|
||||
/// Return true if the layout for `t` is compatible with strided semantics.
|
||||
|
|
|
@ -72,10 +72,9 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
|
|||
SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
|
||||
newShape[0] = 2;
|
||||
std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
|
||||
auto newMemRefType =
|
||||
MemRefType::get(newShape, oldMemRefType.getElementType(), {},
|
||||
oldMemRefType.getMemorySpace());
|
||||
return newMemRefType;
|
||||
return MemRefType::Builder(oldMemRefType)
|
||||
.setShape(newShape)
|
||||
.setAffineMaps({});
|
||||
};
|
||||
|
||||
auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
|
||||
|
|
|
@ -445,8 +445,10 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
|
|||
auto oldMemRef = allocOp.getResult();
|
||||
SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
|
||||
|
||||
auto newMemRefType = MemRefType::get(newShape, memrefType.getElementType(),
|
||||
b.getMultiDimIdentityMap(newRank));
|
||||
MemRefType newMemRefType =
|
||||
MemRefType::Builder(memrefType)
|
||||
.setShape(newShape)
|
||||
.setAffineMaps(b.getMultiDimIdentityMap(newRank));
|
||||
auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
|
||||
|
||||
// Replace all uses of the old memref.
|
||||
|
|
Loading…
Reference in New Issue