[mlir] Add MemRefTypeBuilder and refactor some MemRefType::get().

The refactored MemRefType::get() calls all intend to clone an existing
memref type while modifying a few of its fields. In fact, some of these calls
dropped the memory space during the clone. Migrate them to the cloning API so
that no field is dropped unless it is explicitly overridden.

It is close to NFC but not quite, since it improves memory-space propagation
in a few places.

Differential Revision: https://reviews.llvm.org/D73296
Author: Tim Shen
Date:   2020-01-22 13:46:11 -08:00
Commit: 3ccaac3cdd (parent 381e81a048)
8 changed files with 79 additions and 40 deletions
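
To make the intent concrete, here is a minimal before/after sketch of the
pattern being migrated (variable names are illustrative; the Builder API
itself appears in the first hunk below):

  // Before: cloning a memref type by re-listing its fields. Any field left
  // out of the MemRefType::get() call (here, the memory space) is silently
  // reset to its default rather than copied from `memrefType`.
  MemRefType newType =
      MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
                      memrefType.getAffineMaps());

  // After: start from a clone of `memrefType` and override only the fields
  // that should change; everything else, including the memory space, is
  // carried over.
  MemRefType newType = MemRefType::Builder(memrefType).setShape(newShape);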

@@ -390,6 +390,52 @@ public:
 class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
                                          detail::MemRefTypeStorage> {
 public:
+  /// This is a builder type that keeps local references to arguments. Arguments
+  /// that are passed into the builder must out-live the builder.
+  class Builder {
+  public:
+    // Build from another MemRefType.
+    explicit Builder(MemRefType other)
+        : shape(other.getShape()), elementType(other.getElementType()),
+          affineMaps(other.getAffineMaps()),
+          memorySpace(other.getMemorySpace()) {}
+
+    // Build from scratch.
+    Builder(ArrayRef<int64_t> shape, Type elementType)
+        : shape(shape), elementType(elementType), affineMaps(),
+          memorySpace(0) {}
+
+    Builder &setShape(ArrayRef<int64_t> newShape) {
+      shape = newShape;
+      return *this;
+    }
+
+    Builder &setElementType(Type newElementType) {
+      elementType = newElementType;
+      return *this;
+    }
+
+    Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
+      affineMaps = newAffineMaps;
+      return *this;
+    }
+
+    Builder &setMemorySpace(unsigned newMemorySpace) {
+      memorySpace = newMemorySpace;
+      return *this;
+    }
+
+    operator MemRefType() {
+      return MemRefType::get(shape, elementType, affineMaps, memorySpace);
+    }
+
+  private:
+    ArrayRef<int64_t> shape;
+    Type elementType;
+    ArrayRef<AffineMap> affineMaps;
+    unsigned memorySpace;
+  };
+
   using Base::Base;

   /// Get or create a new MemRefType based on shape, element type, affine
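
Two details of the builder are worth noting. Conversion back to MemRefType
happens implicitly through operator MemRefType(), so a builder chain can be
assigned straight to a MemRefType variable or passed where a MemRefType is
expected, as the call sites below do. Also, since the builder stores
ArrayRefs, the shape and affine-map arguments must outlive the builder, per
the class comment; a minimal sketch of safe usage (names are illustrative):

  // The SmallVector owns the shape storage and outlives the builder chain,
  // so the ArrayRef held by the builder stays valid.
  SmallVector<int64_t, 4> newShape(memref.getRank(), ShapedType::kDynamicSize);
  MemRefType newType = MemRefType::Builder(memref).setShape(newShape);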

@@ -41,8 +41,7 @@ public:
     auto memref = type.dyn_cast<MemRefType>();
     if (memref &&
         memref.getMemorySpace() == gpu::GPUDialect::getPrivateAddressSpace()) {
-      type = MemRefType::get(memref.getShape(), memref.getElementType(),
-                             memref.getAffineMaps());
+      type = MemRefType::Builder(memref).setMemorySpace(0);
     }
     return LLVMTypeConverter::convertType(type);

@@ -168,8 +168,8 @@ inline Type castElementType(Type t, Type newElementType) {
     case StandardTypes::Kind::UnrankedTensor:
       return UnrankedTensorType::get(newElementType);
     case StandardTypes::Kind::MemRef:
-      return MemRefType::get(st.getShape(), newElementType,
-                             st.cast<MemRefType>().getAffineMaps());
+      return MemRefType::Builder(st.cast<MemRefType>())
+          .setElementType(newElementType);
     }
   }
   assert(t.isIntOrFloat());

@@ -480,7 +480,7 @@ computeReshapeCollapsedType(MemRefType type,
   // Early-exit: if `type` is contiguous, the result must be contiguous.
   if (canonicalizeStridedLayout(type).getAffineMaps().empty())
-    return MemRefType::get(newSizes, type.getElementType(), {});
+    return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});

   // Convert back to int64_t because we don't have enough information to create
   // new strided layouts from AffineExpr only. This corresponds to a case where
@@ -499,7 +499,7 @@ computeReshapeCollapsedType(MemRefType type,
   auto layout =
       makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
   return canonicalizeStridedLayout(
-      MemRefType::get(newSizes, type.getElementType(), {layout}));
+      MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
 }

 /// Helper functions assert Attribute of the proper type in attr and returns the
@@ -613,11 +613,10 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
   unsigned rank = memRefType.getRank();
   // TODO(ntv): propagate static size and stride information when available.
   SmallVector<int64_t, 4> sizes(rank, -1); // -1 encodes dynamic size.
-  Type elementType = memRefType.getElementType();
-  result.addTypes({MemRefType::get(
-      sizes, elementType,
-      {makeStridedLinearLayoutMap(strides, offset, b->getContext())},
-      memRefType.getMemorySpace())});
+  result.addTypes({MemRefType::Builder(memRefType)
+                       .setShape(sizes)
+                       .setAffineMaps(makeStridedLinearLayoutMap(
+                           strides, offset, b->getContext()))});
 }

 static void print(OpAsmPrinter &p, SliceOp op) {
@@ -698,8 +697,8 @@ void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result,
   auto map = makeStridedLinearLayoutMap(strides, offset, b->getContext());
   map = permutationMap ? map.compose(permutationMap) : map;
   // Compute result type.
-  auto resultType = MemRefType::get(sizes, memRefType.getElementType(), map,
-                                    memRefType.getMemorySpace());
+  MemRefType resultType =
+      MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
   build(b, result, resultType, view, attrs);
   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);

@@ -350,9 +350,8 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
     }

     // Create new memref type (which will have fewer dynamic dimensions).
-    auto newMemRefType = MemRefType::get(
-        newShapeConstants, memrefType.getElementType(),
-        memrefType.getAffineMaps(), memrefType.getMemorySpace());
+    MemRefType newMemRefType =
+        MemRefType::Builder(memrefType).setShape(newShapeConstants);
     assert(static_cast<int64_t>(newOperands.size()) ==
            newMemRefType.getNumDynamicDims());
@@ -2453,9 +2452,9 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
                                     rewriter.getContext());

     // Create new memref type with constant folded dims and/or offset/strides.
-    auto newMemRefType =
-        MemRefType::get(newShapeConstants, memrefType.getElementType(), {map},
-                        memrefType.getMemorySpace());
+    MemRefType newMemRefType = MemRefType::Builder(memrefType)
+                                   .setShape(newShapeConstants)
+                                   .setAffineMaps({map});
     (void)dynamicOffsetOperandCount; // unused in opt mode
     assert(static_cast<int64_t>(newOperands.size()) ==
            dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());
@@ -2509,7 +2508,6 @@ static Type inferSubViewResultType(MemRefType memRefType) {
   auto rank = memRefType.getRank();
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  Type elementType = memRefType.getElementType();
   auto res = getStridesAndOffset(memRefType, strides, offset);
   assert(succeeded(res) && "SubViewOp expected strided memref type");
   (void)res;
@@ -2524,8 +2522,9 @@ static Type inferSubViewResultType(MemRefType memRefType) {
   auto stridedLayout =
       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
   SmallVector<int64_t, 4> sizes(rank, ShapedType::kDynamicSize);
-  return MemRefType::get(sizes, elementType, stridedLayout,
-                         memRefType.getMemorySpace());
+  return MemRefType::Builder(memRefType)
+      .setShape(sizes)
+      .setAffineMaps(stridedLayout);
 }

 void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source,
@@ -2774,9 +2773,8 @@ public:
       assert(defOp);
       staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
     }
-    MemRefType newMemRefType = MemRefType::get(
-        staticShape, subViewType.getElementType(), subViewType.getAffineMaps(),
-        subViewType.getMemorySpace());
+    MemRefType newMemRefType =
+        MemRefType::Builder(subViewType).setShape(staticShape);
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
         ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
@@ -2825,8 +2823,7 @@ public:
     AffineMap layoutMap = makeStridedLinearLayoutMap(
         staticStrides, resultOffset, rewriter.getContext());
     MemRefType newMemRefType =
-        MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
-                        layoutMap, subViewType.getMemorySpace());
+        MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
         subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
@@ -2877,8 +2874,7 @@ public:
     AffineMap layoutMap = makeStridedLinearLayoutMap(
         resultStrides, staticOffset, rewriter.getContext());
     MemRefType newMemRefType =
-        MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
-                        layoutMap, subViewType.getMemorySpace());
+        MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
     auto newSubViewOp = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
         subViewOp.sizes(), subViewOp.strides(), newMemRefType);

@@ -723,11 +723,9 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
   auto simplifiedLayoutExpr =
       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
   if (expr != simplifiedLayoutExpr)
-    return MemRefType::get(t.getShape(), t.getElementType(),
-                           {AffineMap::get(m.getNumDims(), m.getNumSymbols(),
-                                           {simplifiedLayoutExpr})});
-  return MemRefType::get(t.getShape(), t.getElementType(), {});
+    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
+        m.getNumDims(), m.getNumSymbols(), {simplifiedLayoutExpr})});
+  return MemRefType::Builder(t).setAffineMaps({});
 }

 /// Return true if the layout for `t` is compatible with strided semantics.

@@ -72,10 +72,9 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
     SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
     newShape[0] = 2;
     std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
-    auto newMemRefType =
-        MemRefType::get(newShape, oldMemRefType.getElementType(), {},
-                        oldMemRefType.getMemorySpace());
-    return newMemRefType;
+    return MemRefType::Builder(oldMemRefType)
+        .setShape(newShape)
+        .setAffineMaps({});
   };

   auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();

@@ -445,8 +445,10 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
   auto oldMemRef = allocOp.getResult();

   SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
-  auto newMemRefType = MemRefType::get(newShape, memrefType.getElementType(),
-                                       b.getMultiDimIdentityMap(newRank));
+  MemRefType newMemRefType =
+      MemRefType::Builder(memrefType)
+          .setShape(newShape)
+          .setAffineMaps(b.getMultiDimIdentityMap(newRank));
   auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);

   // Replace all uses of the old memref.