Update some of the derived type classes to use getImpl instead of a static_cast.

PiperOrigin-RevId: 240084937
This commit is contained in:
River Riddle 2019-03-24 23:51:05 -07:00 committed by jpienaar
parent e510de0305
commit 63e8725bc2
4 changed files with 16 additions and 28 deletions

View File

@ -167,7 +167,7 @@ public:
} }
/// Utility for easy access to the storage instance. /// Utility for easy access to the storage instance.
ImplType *getImpl() { return static_cast<ImplType *>(type); } ImplType *getImpl() const { return static_cast<ImplType *>(type); }
}; };
using ImplType = TypeStorage; using ImplType = TypeStorage;

View File

@ -50,9 +50,7 @@ IntegerType IntegerType::getChecked(unsigned width, MLIRContext *context,
return Base::getChecked(location, context, StandardTypes::Integer, width); return Base::getChecked(location, context, StandardTypes::Integer, width);
} }
unsigned IntegerType::getWidth() const { unsigned IntegerType::getWidth() const { return getImpl()->width; }
return static_cast<ImplType *>(type)->width;
}
/// Float Type. /// Float Type.
@ -221,9 +219,7 @@ bool VectorType::verifyConstructionInvariants(llvm::Optional<Location> loc,
return false; return false;
} }
ArrayRef<int64_t> VectorType::getShape() const { ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
return static_cast<ImplType *>(type)->getShape();
}
/// TensorType /// TensorType
@ -269,11 +265,7 @@ bool RankedTensorType::verifyConstructionInvariants(
} }
ArrayRef<int64_t> RankedTensorType::getShape() const { ArrayRef<int64_t> RankedTensorType::getShape() const {
return static_cast<ImplType *>(type)->getShape(); return getImpl()->getShape();
}
ArrayRef<int64_t> MemRefType::getShape() const {
return static_cast<ImplType *>(type)->getShape();
} }
/// UnrankedTensorType /// UnrankedTensorType
@ -351,6 +343,10 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
cleanedAffineMapComposition, memorySpace); cleanedAffineMapComposition, memorySpace);
} }
ArrayRef<int64_t> MemRefType::getShape() const {
return static_cast<ImplType *>(type)->getShape();
}
Type MemRefType::getElementType() const { Type MemRefType::getElementType() const {
return static_cast<ImplType *>(type)->elementType; return static_cast<ImplType *>(type)->elementType;
} }
@ -376,11 +372,7 @@ TupleType TupleType::get(ArrayRef<Type> elementTypes, MLIRContext *context) {
} }
/// Return the elements types for this tuple. /// Return the elements types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
return static_cast<ImplType *>(type)->getTypes();
}
/// Return the number of element types. /// Return the number of element types.
unsigned TupleType::size() const { unsigned TupleType::size() const { return getImpl()->size(); }
return static_cast<ImplType *>(type)->size();
}

View File

@ -41,15 +41,13 @@ FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results,
} }
ArrayRef<Type> FunctionType::getInputs() const { ArrayRef<Type> FunctionType::getInputs() const {
return static_cast<ImplType *>(type)->getInputs(); return getImpl()->getInputs();
} }
unsigned FunctionType::getNumResults() const { unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
return static_cast<ImplType *>(type)->numResults;
}
ArrayRef<Type> FunctionType::getResults() const { ArrayRef<Type> FunctionType::getResults() const {
return static_cast<ImplType *>(type)->getResults(); return getImpl()->getResults();
} }
/// UnknownType /// UnknownType
@ -66,13 +64,11 @@ UnknownType UnknownType::getChecked(Identifier dialect, StringRef typeData,
/// Returns the dialect namespace of the unknown type. /// Returns the dialect namespace of the unknown type.
Identifier UnknownType::getDialectNamespace() const { Identifier UnknownType::getDialectNamespace() const {
return static_cast<ImplType *>(type)->dialectNamespace; return getImpl()->dialectNamespace;
} }
/// Returns the raw type data of the unknown type. /// Returns the raw type data of the unknown type.
StringRef UnknownType::getTypeData() const { StringRef UnknownType::getTypeData() const { return getImpl()->typeData; }
return static_cast<ImplType *>(type)->typeData;
}
/// Verify the construction of an unknown type. /// Verify the construction of an unknown type.
bool UnknownType::verifyConstructionInvariants(llvm::Optional<Location> loc, bool UnknownType::verifyConstructionInvariants(llvm::Optional<Location> loc,

View File

@ -57,7 +57,7 @@ LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
} }
llvm::Type *LLVMType::getUnderlyingType() const { llvm::Type *LLVMType::getUnderlyingType() const {
return static_cast<ImplType *>(type)->underlyingType; return getImpl()->underlyingType;
} }
/*---- LLVM IR Dialect and its registration ----------------------------- */ /*---- LLVM IR Dialect and its registration ----------------------------- */