[mlir] Add clone method to ShapedType

Allow clients to create a new ShapedType of the same "container" type
but with different element or shape. First use case is when refining
shape during shape inference without needing to consider which
ShapedType is being refined.

Differential Revision: https://reviews.llvm.org/D96682
This commit is contained in:
Jacques Pienaar 2021-02-15 11:04:16 -08:00
parent 6c5f17e701
commit 381a65fa06
4 changed files with 204 additions and 0 deletions

View File

@ -88,6 +88,11 @@ public:
static constexpr int64_t kDynamicStrideOrOffset =
std::numeric_limits<int64_t>::min();
/// Return clone of this type with new shape and element type.
ShapedType clone(ArrayRef<int64_t> shape, Type elementType);
ShapedType clone(ArrayRef<int64_t> shape);
ShapedType clone(Type elementType);
/// Return the element type.
Type getElementType() const;

View File

@ -197,6 +197,75 @@ LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;
ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
if (auto other = dyn_cast<MemRefType>()) {
MemRefType::Builder b(other);
b.setShape(shape);
b.setElementType(elementType);
return b;
}
if (auto other = dyn_cast<UnrankedMemRefType>()) {
MemRefType::Builder b(shape, elementType);
b.setMemorySpace(other.getMemorySpace());
return b;
}
if (isa<TensorType>())
return RankedTensorType::get(shape, elementType);
if (isa<VectorType>())
return VectorType::get(shape, elementType);
llvm_unreachable("Unhandled ShapedType clone case");
}
ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
if (auto other = dyn_cast<MemRefType>()) {
MemRefType::Builder b(other);
b.setShape(shape);
return b;
}
if (auto other = dyn_cast<UnrankedMemRefType>()) {
MemRefType::Builder b(shape, other.getElementType());
b.setShape(shape);
b.setMemorySpace(other.getMemorySpace());
return b;
}
if (isa<TensorType>())
return RankedTensorType::get(shape, getElementType());
if (isa<VectorType>())
return VectorType::get(shape, getElementType());
llvm_unreachable("Unhandled ShapedType clone case");
}
ShapedType ShapedType::clone(Type elementType) {
if (auto other = dyn_cast<MemRefType>()) {
MemRefType::Builder b(other);
b.setElementType(elementType);
return b;
}
if (auto other = dyn_cast<UnrankedMemRefType>()) {
return UnrankedMemRefType::get(elementType, other.getMemorySpace());
}
if (isa<TensorType>()) {
if (hasRank())
return RankedTensorType::get(getShape(), elementType);
return UnrankedTensorType::get(elementType);
}
if (isa<VectorType>())
return VectorType::get(getShape(), elementType);
llvm_unreachable("Unhandled ShapedType clone hit");
}
Type ShapedType::getElementType() const {
return static_cast<ImplType *>(impl)->elementType;
}

View File

@ -2,6 +2,7 @@ add_mlir_unittest(MLIRIRTests
AttributeTest.cpp
DialectTest.cpp
OperationSupportTest.cpp
ShapedTypeTest.cpp
)
target_link_libraries(MLIRIRTests
PRIVATE

View File

@ -0,0 +1,129 @@
//===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "llvm/ADT/SmallVector.h"
#include "gtest/gtest.h"
#include <cstdint>
using namespace mlir;
using namespace mlir::detail;
namespace {
TEST(ShapedTypeTest, CloneMemref) {
MLIRContext context;
Type i32 = IntegerType::get(&context, 32);
Type f32 = FloatType::getF32(&context);
int memSpace = 7;
Type memrefOriginalType = i32;
llvm::SmallVector<int64_t> memrefOriginalShape({10, 20});
AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);
ShapedType memrefType =
MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
.setMemorySpace(memSpace)
.setAffineMaps(map);
// Update shape.
llvm::SmallVector<int64_t> memrefNewShape({30, 40});
ASSERT_NE(memrefOriginalShape, memrefNewShape);
ASSERT_EQ(memrefType.clone(memrefNewShape),
(MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
.setMemorySpace(memSpace)
.setAffineMaps(map));
// Update type.
Type memrefNewType = f32;
ASSERT_NE(memrefOriginalType, memrefNewType);
ASSERT_EQ(memrefType.clone(memrefNewType),
(MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
.setMemorySpace(memSpace)
.setAffineMaps(map));
// Update both.
ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
(MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
.setMemorySpace(memSpace)
.setAffineMaps(map));
// Test unranked memref cloning.
ShapedType unrankedTensorType =
UnrankedMemRefType::get(memrefOriginalType, memSpace);
ASSERT_EQ(unrankedTensorType.clone(memrefNewShape),
(MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
.setMemorySpace(memSpace));
ASSERT_EQ(unrankedTensorType.clone(memrefNewType),
UnrankedMemRefType::get(memrefNewType, memSpace));
ASSERT_EQ(unrankedTensorType.clone(memrefNewShape, memrefNewType),
(MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
.setMemorySpace(memSpace));
}
TEST(ShapedTypeTest, CloneTensor) {
MLIRContext context;
Type i32 = IntegerType::get(&context, 32);
Type f32 = FloatType::getF32(&context);
Type tensorOriginalType = i32;
llvm::SmallVector<int64_t> tensorOriginalShape({10, 20});
// Test ranked tensor cloning.
ShapedType tensorType =
RankedTensorType::get(tensorOriginalShape, tensorOriginalType);
// Update shape.
llvm::SmallVector<int64_t> tensorNewShape({30, 40});
ASSERT_NE(tensorOriginalShape, tensorNewShape);
ASSERT_EQ(tensorType.clone(tensorNewShape),
RankedTensorType::get(tensorNewShape, tensorOriginalType));
// Update type.
Type tensorNewType = f32;
ASSERT_NE(tensorOriginalType, tensorNewType);
ASSERT_EQ(tensorType.clone(tensorNewType),
RankedTensorType::get(tensorOriginalShape, tensorNewType));
// Update both.
ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
RankedTensorType::get(tensorNewShape, tensorNewType));
// Test unranked tensor cloning.
ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
ASSERT_EQ(unrankedTensorType.clone(tensorNewShape),
RankedTensorType::get(tensorNewShape, tensorOriginalType));
ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
UnrankedTensorType::get(tensorNewType));
ASSERT_EQ(unrankedTensorType.clone(tensorNewShape),
RankedTensorType::get(tensorNewShape, tensorOriginalType));
}
TEST(ShapedTypeTest, CloneVector) {
MLIRContext context;
Type i32 = IntegerType::get(&context, 32);
Type f32 = FloatType::getF32(&context);
Type vectorOriginalType = i32;
llvm::SmallVector<int64_t> vectorOriginalShape({10, 20});
ShapedType vectorType =
VectorType::get(vectorOriginalShape, vectorOriginalType);
// Update shape.
llvm::SmallVector<int64_t> vectorNewShape({30, 40});
ASSERT_NE(vectorOriginalShape, vectorNewShape);
ASSERT_EQ(vectorType.clone(vectorNewShape),
VectorType::get(vectorNewShape, vectorOriginalType));
// Update type.
Type vectorNewType = f32;
ASSERT_NE(vectorOriginalType, vectorNewType);
ASSERT_EQ(vectorType.clone(vectorNewType),
VectorType::get(vectorOriginalShape, vectorNewType));
// Update both.
ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType),
VectorType::get(vectorNewShape, vectorNewType));
}
} // end namespace