[mlir][spirv] Timely fail type conversion

Per the TypeConverter API contract, returning `llvm:None` means
other conversion rules should be tried. But we only have one
rule per input type. So there is no need to try others and we can
just directly fail, which should return `nullptr`. This avoids
unnecessary checks.

Differential Revision: https://reviews.llvm.org/D100058
This commit is contained in:
Lei Zhang 2021-04-07 14:52:07 -04:00
parent 94a6fe43de
commit 004f29c0bb
1 changed files with 32 additions and 35 deletions

View File

@ -235,9 +235,9 @@ Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
}
/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
static Optional<Type>
convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
Optional<spirv::StorageClass> storageClass = {}) {
static Type convertScalarType(const spirv::TargetEnv &targetEnv,
spirv::ScalarType type,
Optional<spirv::StorageClass> storageClass = {}) {
// Get extension and capability requirements for the given type.
SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
@ -271,9 +271,9 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
}
/// Converts a vector `type` to a suitable type under the given `targetEnv`.
static Optional<Type>
convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
Optional<spirv::StorageClass> storageClass = {}) {
static Type convertVectorType(const spirv::TargetEnv &targetEnv,
VectorType type,
Optional<spirv::StorageClass> storageClass = {}) {
if (type.getRank() == 1 && type.getNumElements() == 1)
return type.getElementType();
@ -281,7 +281,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
// TODO: Vector types with more than four elements can be translated into
// array types.
LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
return llvm::None;
return nullptr;
}
// Get extension and capability requirements for the given type.
@ -298,8 +298,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
auto elementType = convertScalarType(
targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass);
if (elementType)
return VectorType::get(type.getShape(), *elementType);
return llvm::None;
return VectorType::get(type.getShape(), elementType);
return nullptr;
}
/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
@ -308,20 +308,20 @@ convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
/// create composite constants with OpConstantComposite to embed relative large
/// constant values and use OpCompositeExtract and OpCompositeInsert to
/// manipulate, like what we do for vectors.
static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
TensorType type) {
static Type convertTensorType(const spirv::TargetEnv &targetEnv,
TensorType type) {
// TODO: Handle dynamic shapes.
if (!type.hasStaticShape()) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: dynamic shape unimplemented\n");
return llvm::None;
return nullptr;
}
auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
if (!scalarType) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot convert non-scalar element type\n");
return llvm::None;
return nullptr;
}
Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
@ -329,35 +329,35 @@ static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
if (!scalarSize || !tensorSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce element count\n");
return llvm::None;
return nullptr;
}
auto arrayElemCount = *tensorSize / *scalarSize;
auto arrayElemType = convertScalarType(targetEnv, scalarType);
if (!arrayElemType)
return llvm::None;
Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
return nullptr;
Optional<int64_t> arrayElemSize = getTypeNumBytes(arrayElemType);
if (!arrayElemSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce converted element size\n");
return llvm::None;
return nullptr;
}
return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
}
static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
MemRefType type) {
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
MemRefType type) {
Optional<spirv::StorageClass> storageClass =
SPIRVTypeConverter::getStorageClassForMemorySpace(
type.getMemorySpaceAsInt());
if (!storageClass) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot convert memory space\n");
return llvm::None;
return nullptr;
}
Optional<Type> arrayElemType;
Type arrayElemType;
Type elementType = type.getElementType();
if (auto vecType = elementType.dyn_cast<VectorType>()) {
arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
@ -368,20 +368,20 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
llvm::dbgs()
<< type
<< " unhandled: can only convert scalar or vector element type\n");
return llvm::None;
return nullptr;
}
if (!arrayElemType)
return llvm::None;
return nullptr;
Optional<int64_t> elementSize = getTypeNumBytes(elementType);
if (!elementSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce element size\n");
return llvm::None;
return nullptr;
}
if (!type.hasStaticShape()) {
auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, *elementSize);
// Wrap in a struct to satisfy Vulkan interface requirements.
auto structType = spirv::StructType::get(arrayType, 0);
return spirv::PointerType::get(structType, *storageClass);
@ -391,20 +391,20 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
if (!memrefSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce element count\n");
return llvm::None;
return nullptr;
}
auto arrayElemCount = *memrefSize / *elementSize;
Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
Optional<int64_t> arrayElemSize = getTypeNumBytes(arrayElemType);
if (!arrayElemSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce converted element size\n");
return llvm::None;
return nullptr;
}
auto arrayType =
spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
// Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
// workgroup storage class do not need the struct to be laid out explicitly.
@ -418,9 +418,6 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
: targetEnv(targetAttr) {
// Add conversions. The order matters here: later ones will be tried earlier.
// All other cases failed. Then we cannot convert this type.
addConversion([](Type type) { return llvm::None; });
// Allow all SPIR-V dialect specific types. This assumes all builtin types
// adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
// were tried before.
@ -438,13 +435,13 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
addConversion([this](IntegerType intType) -> Optional<Type> {
if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
return convertScalarType(targetEnv, scalarType);
return llvm::None;
return Type();
});
addConversion([this](FloatType floatType) -> Optional<Type> {
if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
return convertScalarType(targetEnv, scalarType);
return llvm::None;
return Type();
});
addConversion([this](VectorType vectorType) {