diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 32c9a2da932d..e789ab729e15 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -390,6 +390,52 @@ public:
 class MemRefType : public Type::TypeBase&lt;MemRefType, ShapedType,
                                          detail::MemRefTypeStorage&gt; {
 public:
+  /// This is a builder type that keeps local references to arguments. Arguments
+  /// that are passed into the builder must out-live the builder.
+  class Builder {
+  public:
+    // Build from another MemRefType.
+    explicit Builder(MemRefType other)
+        : shape(other.getShape()), elementType(other.getElementType()),
+          affineMaps(other.getAffineMaps()),
+          memorySpace(other.getMemorySpace()) {}
+
+    // Build from scratch.
+    Builder(ArrayRef&lt;int64_t&gt; shape, Type elementType)
+        : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) {
+    }
+
+    Builder &setShape(ArrayRef&lt;int64_t&gt; newShape) {
+      shape = newShape;
+      return *this;
+    }
+
+    Builder &setElementType(Type newElementType) {
+      elementType = newElementType;
+      return *this;
+    }
+
+    Builder &setAffineMaps(ArrayRef&lt;AffineMap&gt; newAffineMaps) {
+      affineMaps = newAffineMaps;
+      return *this;
+    }
+
+    Builder &setMemorySpace(unsigned newMemorySpace) {
+      memorySpace = newMemorySpace;
+      return *this;
+    }
+
+    operator MemRefType() {
+      return MemRefType::get(shape, elementType, affineMaps, memorySpace);
+    }
+
+  private:
+    ArrayRef&lt;int64_t&gt; shape;
+    Type elementType;
+    ArrayRef&lt;AffineMap&gt; affineMaps;
+    unsigned memorySpace;
+  };
+
   using Base::Base;
 
   /// Get or create a new MemRefType based on shape, element type, affine
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index a2c4ec677bd3..ab47dcfb685f 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -41,8 +41,7 @@ public:
     auto memref = type.dyn_cast&lt;MemRefType&gt;();
     if (memref && memref.getMemorySpace() ==
                       gpu::GPUDialect::getPrivateAddressSpace()) {
-      type = MemRefType::get(memref.getShape(), memref.getElementType(),
-                             memref.getAffineMaps());
+      type = MemRefType::Builder(memref).setMemorySpace(0);
     }
 
     return LLVMTypeConverter::convertType(type);
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
index 414acce798f1..0fa359024fe8 100644
--- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
@@ -168,8 +168,8 @@ inline Type castElementType(Type t, Type newElementType) {
     case StandardTypes::Kind::UnrankedTensor:
       return UnrankedTensorType::get(newElementType);
     case StandardTypes::Kind::MemRef:
-      return MemRefType::get(st.getShape(), newElementType,
-                             st.cast&lt;MemRefType&gt;().getAffineMaps());
+      return MemRefType::Builder(st.cast&lt;MemRefType&gt;())
+          .setElementType(newElementType);
     }
   }
   assert(t.isIntOrFloat());
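To illustrate the pattern the remaining conversions follow, here is a minimal usage sketch. It is not part of the patch, and `memref` stands for any in-scope MemRefType: the explicit constructor captures all four fields of the existing type, a setter overrides one of them, and the implicit `operator MemRefType()` materializes the result through `MemRefType::get`.

```cpp
// Sketch only: `memref` is an assumed in-scope MemRefType.
// Copy every field from `memref`, override the memory space, and let the
// implicit conversion call MemRefType::get(shape, elementType, affineMaps, 0).
MemRefType hostView = MemRefType::Builder(memref).setMemorySpace(0);
```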
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1cb62eaac291..fb18fbf02f38 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -480,7 +480,7 @@ computeReshapeCollapsedType(MemRefType type,
 
   // Early-exit: if `type` is contiguous, the result must be contiguous.
   if (canonicalizeStridedLayout(type).getAffineMaps().empty())
-    return MemRefType::get(newSizes, type.getElementType(), {});
+    return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
 
   // Convert back to int64_t because we don't have enough information to create
   // new strided layouts from AffineExpr only. This corresponds to a case where
@@ -499,7 +499,7 @@ computeReshapeCollapsedType(MemRefType type,
   auto layout =
       makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
   return canonicalizeStridedLayout(
-      MemRefType::get(newSizes, type.getElementType(), {layout}));
+      MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
 }
 
 /// Helper functions assert Attribute of the proper type in attr and returns the
@@ -613,11 +613,10 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
   unsigned rank = memRefType.getRank();
   // TODO(ntv): propagate static size and stride information when available.
   SmallVector&lt;int64_t, 4&gt; sizes(rank, -1); // -1 encodes dynamic size.
-  Type elementType = memRefType.getElementType();
-  result.addTypes({MemRefType::get(
-      sizes, elementType,
-      {makeStridedLinearLayoutMap(strides, offset, b-&gt;getContext())},
-      memRefType.getMemorySpace())});
+  result.addTypes({MemRefType::Builder(memRefType)
+                       .setShape(sizes)
+                       .setAffineMaps(makeStridedLinearLayoutMap(
+                           strides, offset, b-&gt;getContext()))});
 }
 
 static void print(OpAsmPrinter &p, SliceOp op) {
@@ -698,8 +697,8 @@ void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result,
   auto map = makeStridedLinearLayoutMap(strides, offset, b-&gt;getContext());
   map = permutationMap ? map.compose(permutationMap) : map;
   // Compute result type.
-  auto resultType = MemRefType::get(sizes, memRefType.getElementType(), map,
-                                    memRefType.getMemorySpace());
+  MemRefType resultType =
+      MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
 
   build(b, result, resultType, view, attrs);
   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
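A note on the SliceOp and TransposeOp hunks above: `setAffineMaps` takes an `ArrayRef&lt;AffineMap&gt;`, yet these call sites pass a single `AffineMap`. This compiles because `llvm::ArrayRef` has an implicit one-element constructor; the temporary it points at lives until the end of the full expression, which is exactly when the builder converts to a `MemRefType`. A sketch, with all identifiers assumed to be in scope:

```cpp
// `base`, `sizes`, `strides`, `offset`, and `ctx` are assumed in scope.
AffineMap layout = makeStridedLinearLayoutMap(strides, offset, ctx);
// A single AffineMap binds to ArrayRef<AffineMap> via the implicit
// one-element constructor; this is safe because the conversion to
// MemRefType happens within the same full expression.
MemRefType result =
    MemRefType::Builder(base).setShape(sizes).setAffineMaps(layout);
```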
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index fded6082273c..b1e0096c26e2 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -350,9 +350,8 @@ struct SimplifyAllocConst : public OpRewritePattern&lt;AllocOp&gt; {
     }
 
     // Create new memref type (which will have fewer dynamic dimensions).
-    auto newMemRefType = MemRefType::get(
-        newShapeConstants, memrefType.getElementType(),
-        memrefType.getAffineMaps(), memrefType.getMemorySpace());
+    MemRefType newMemRefType =
+        MemRefType::Builder(memrefType).setShape(newShapeConstants);
     assert(static_cast&lt;int64_t&gt;(newOperands.size()) ==
            newMemRefType.getNumDynamicDims());
 
@@ -2453,9 +2452,9 @@ struct ViewOpShapeFolder : public OpRewritePattern&lt;ViewOp&gt; {
                              rewriter.getContext());
 
     // Create new memref type with constant folded dims and/or offset/strides.
-    auto newMemRefType =
-        MemRefType::get(newShapeConstants, memrefType.getElementType(), {map},
-                        memrefType.getMemorySpace());
+    MemRefType newMemRefType = MemRefType::Builder(memrefType)
+                                   .setShape(newShapeConstants)
+                                   .setAffineMaps({map});
     (void)dynamicOffsetOperandCount; // unused in opt mode
     assert(static_cast&lt;int64_t&gt;(newOperands.size()) ==
            dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());
@@ -2509,7 +2508,6 @@ static Type inferSubViewResultType(MemRefType memRefType) {
   auto rank = memRefType.getRank();
   int64_t offset;
   SmallVector&lt;int64_t, 4&gt; strides;
-  Type elementType = memRefType.getElementType();
   auto res = getStridesAndOffset(memRefType, strides, offset);
   assert(succeeded(res) && "SubViewOp expected strided memref type");
   (void)res;
@@ -2524,8 +2522,9 @@ static Type inferSubViewResultType(MemRefType memRefType) {
   auto stridedLayout =
       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
   SmallVector&lt;int64_t, 4&gt; sizes(rank, ShapedType::kDynamicSize);
-  return MemRefType::get(sizes, elementType, stridedLayout,
-                         memRefType.getMemorySpace());
+  return MemRefType::Builder(memRefType)
+      .setShape(sizes)
+      .setAffineMaps(stridedLayout);
 }
 
 void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source,
@@ -2774,9 +2773,8 @@ public:
       assert(defOp);
      staticShape[size.index()] = cast&lt;ConstantIndexOp&gt;(defOp).getValue();
     }
-    MemRefType newMemRefType = MemRefType::get(
-        staticShape, subViewType.getElementType(), subViewType.getAffineMaps(),
-        subViewType.getMemorySpace());
+    MemRefType newMemRefType =
+        MemRefType::Builder(subViewType).setShape(staticShape);
     auto newSubViewOp = rewriter.create&lt;SubViewOp&gt;(
         subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
         ArrayRef&lt;Value&gt;(), subViewOp.strides(), newMemRefType);
@@ -2825,8 +2823,7 @@ public:
     AffineMap layoutMap = makeStridedLinearLayoutMap(
         staticStrides, resultOffset, rewriter.getContext());
     MemRefType newMemRefType =
-        MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
-                        layoutMap, subViewType.getMemorySpace());
+        MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
     auto newSubViewOp = rewriter.create&lt;SubViewOp&gt;(
         subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
         subViewOp.sizes(), ArrayRef&lt;Value&gt;(), newMemRefType);
@@ -2877,8 +2874,7 @@ public:
     AffineMap layoutMap = makeStridedLinearLayoutMap(
         resultStrides, staticOffset, rewriter.getContext());
     MemRefType newMemRefType =
-        MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
-                        layoutMap, subViewType.getMemorySpace());
+        MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
     auto newSubViewOp = rewriter.create&lt;SubViewOp&gt;(
         subViewOp.getLoc(), subViewOp.source(), ArrayRef&lt;Value&gt;(),
         subViewOp.sizes(), subViewOp.strides(), newMemRefType);
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 6f9f6a86a8dc..bd12cff65f7e 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -723,11 +723,9 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
   auto simplifiedLayoutExpr =
       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
   if (expr != simplifiedLayoutExpr)
-    return MemRefType::get(t.getShape(), t.getElementType(),
-                           {AffineMap::get(m.getNumDims(), m.getNumSymbols(),
-                                           {simplifiedLayoutExpr})});
-
-  return MemRefType::get(t.getShape(), t.getElementType(), {});
+    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
+        m.getNumDims(), m.getNumSymbols(), {simplifiedLayoutExpr})});
+  return MemRefType::Builder(t).setAffineMaps({});
 }
 
 /// Return true if the layout for `t` is compatible with strided semantics.
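In `canonicalizeStridedLayout` above, `.setAffineMaps({})` clears the affine-map composition entirely, so the result falls back to the default identity layout while the shape, element type, and memory space of `t` are carried over unchanged. A minimal sketch of that direction, with `t` assumed in scope:

```cpp
// Sketch: `t` is an assumed in-scope MemRefType with some strided layout.
// An empty braced list binds to ArrayRef<AffineMap> and erases the layout;
// everything else is copied from `t` by the Builder constructor.
MemRefType identity = MemRefType::Builder(t).setAffineMaps({});
```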
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index a9a41a6afa4b..58a2f17f5aa0 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -72,10 +72,9 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
     SmallVector&lt;int64_t, 4&gt; newShape(1 + oldMemRefType.getRank());
     newShape[0] = 2;
     std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
-    auto newMemRefType =
-        MemRefType::get(newShape, oldMemRefType.getElementType(), {},
-                        oldMemRefType.getMemorySpace());
-    return newMemRefType;
+    return MemRefType::Builder(oldMemRefType)
+        .setShape(newShape)
+        .setAffineMaps({});
   };
 
   auto oldMemRefType = oldMemRef.getType().cast&lt;MemRefType&gt;();
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index ca41ac2e58ac..5e824135164d 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -445,8 +445,10 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
   auto oldMemRef = allocOp.getResult();
 
   SmallVector&lt;Value, 4&gt; symbolOperands(allocOp.getSymbolicOperands());
-  auto newMemRefType = MemRefType::get(newShape, memrefType.getElementType(),
-                                       b.getMultiDimIdentityMap(newRank));
+  MemRefType newMemRefType =
+      MemRefType::Builder(memrefType)
+          .setShape(newShape)
+          .setAffineMaps(b.getMultiDimIdentityMap(newRank));
   auto newAlloc = b.create&lt;AllocOp&gt;(allocOp.getLoc(), newMemRefType);
 
   // Replace all uses of the old memref.
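One caveat worth keeping in mind, per the class comment: the Builder stores `ArrayRef`s rather than owning copies, so its arguments must outlive it. Every call site in this patch converts the builder to a `MemRefType` within a single full expression, which keeps the referenced data alive long enough. A sketch of the safe pattern and the dangling one to avoid, with `oldType` as a hypothetical in-scope MemRefType:

```cpp
// Safe: the braced list lives to the end of the full expression, and the
// Builder is converted to a MemRefType within that same expression.
MemRefType ok = MemRefType::Builder(oldType).setShape({2, 4});

// Dangling (do not do this): the temporary SmallVector dies at the end of
// the setShape statement, leaving the Builder's ArrayRef pointing at freed
// storage by the time the conversion to MemRefType runs.
// MemRefType::Builder b(oldType);
// b.setShape(SmallVector<int64_t, 2>{2, 4});
// MemRefType bad = b;
```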