NFC: Use TypeSwitch to simplify existing code.

PiperOrigin-RevId: 286066371
This commit is contained in:
River Riddle 2019-12-17 14:57:07 -08:00 committed by A. Unique TensorFlower
parent 6fa3bd5b3e
commit 74278dd01e
11 changed files with 116 additions and 175 deletions

View File

@ -21,6 +21,7 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@ -84,20 +85,15 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
#define dispatch(CLASS) \
if (CLASS *node = llvm::dyn_cast<CLASS>(expr)) \
return dump(node);
dispatch(VarDeclExprAST);
dispatch(LiteralExprAST);
dispatch(NumberExprAST);
dispatch(VariableExprAST);
dispatch(ReturnExprAST);
dispatch(BinaryExprAST);
dispatch(CallExprAST);
dispatch(PrintExprAST);
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
mlir::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
[&](auto *node) { dump(node); })
.Default([&](ExprAST *) {
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
});
}
/// A variable declaration is printing the variable name, the type, and then

View File

@ -21,6 +21,7 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@ -84,20 +85,15 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
#define dispatch(CLASS) \
if (CLASS *node = llvm::dyn_cast<CLASS>(expr)) \
return dump(node);
dispatch(VarDeclExprAST);
dispatch(LiteralExprAST);
dispatch(NumberExprAST);
dispatch(VariableExprAST);
dispatch(ReturnExprAST);
dispatch(BinaryExprAST);
dispatch(CallExprAST);
dispatch(PrintExprAST);
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
mlir::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
[&](auto *node) { dump(node); })
.Default([&](ExprAST *) {
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
});
}
/// A variable declaration is printing the variable name, the type, and then

View File

@ -21,6 +21,7 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@ -84,20 +85,15 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
#define dispatch(CLASS) \
if (CLASS *node = llvm::dyn_cast<CLASS>(expr)) \
return dump(node);
dispatch(VarDeclExprAST);
dispatch(LiteralExprAST);
dispatch(NumberExprAST);
dispatch(VariableExprAST);
dispatch(ReturnExprAST);
dispatch(BinaryExprAST);
dispatch(CallExprAST);
dispatch(PrintExprAST);
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
mlir::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
[&](auto *node) { dump(node); })
.Default([&](ExprAST *) {
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
});
}
/// A variable declaration is printing the variable name, the type, and then

View File

@ -21,6 +21,7 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@ -84,20 +85,15 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
#define dispatch(CLASS) \
if (CLASS *node = llvm::dyn_cast<CLASS>(expr)) \
return dump(node);
dispatch(VarDeclExprAST);
dispatch(LiteralExprAST);
dispatch(NumberExprAST);
dispatch(VariableExprAST);
dispatch(ReturnExprAST);
dispatch(BinaryExprAST);
dispatch(CallExprAST);
dispatch(PrintExprAST);
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
mlir::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
[&](auto *node) { dump(node); })
.Default([&](ExprAST *) {
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
});
}
/// A variable declaration is printing the variable name, the type, and then

View File

@ -21,6 +21,7 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@ -84,20 +85,15 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
#define dispatch(CLASS) \
if (CLASS *node = llvm::dyn_cast<CLASS>(expr)) \
return dump(node);
dispatch(VarDeclExprAST);
dispatch(LiteralExprAST);
dispatch(NumberExprAST);
dispatch(VariableExprAST);
dispatch(ReturnExprAST);
dispatch(BinaryExprAST);
dispatch(CallExprAST);
dispatch(PrintExprAST);
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
mlir::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
[&](auto *node) { dump(node); })
.Default([&](ExprAST *) {
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
});
}
/// A variable declaration is printing the variable name, the type, and then

View File

@ -21,6 +21,7 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@ -84,20 +85,15 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
#define dispatch(CLASS) \
if (CLASS *node = llvm::dyn_cast<CLASS>(expr)) \
return dump(node);
dispatch(VarDeclExprAST);
dispatch(LiteralExprAST);
dispatch(NumberExprAST);
dispatch(VariableExprAST);
dispatch(ReturnExprAST);
dispatch(BinaryExprAST);
dispatch(CallExprAST);
dispatch(PrintExprAST);
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
mlir::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
[&](auto *node) { dump(node); })
.Default([&](ExprAST *) {
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
});
}
/// A variable declaration is printing the variable name, the type, and then

View File

@ -21,6 +21,7 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@ -86,21 +87,15 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
#define dispatch(CLASS) \
if (CLASS *node = llvm::dyn_cast<CLASS>(expr)) \
return dump(node);
dispatch(VarDeclExprAST);
dispatch(LiteralExprAST);
dispatch(StructLiteralExprAST);
dispatch(NumberExprAST);
dispatch(VariableExprAST);
dispatch(ReturnExprAST);
dispatch(BinaryExprAST);
dispatch(CallExprAST);
dispatch(PrintExprAST);
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
mlir::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, StructLiteralExprAST, VarDeclExprAST,
VariableExprAST>([&](auto *node) { dump(node); })
.Default([&](ExprAST *) {
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
});
}
/// A variable declaration is printing the variable name, the type, and then

View File

@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Passes.h"
@ -49,11 +50,9 @@ std::unique_ptr<OpPassBase<FuncOp>> mlir::createMemRefBoundCheckPass() {
void MemRefBoundCheck::runOnFunction() {
getFunction().walk([](Operation *opInst) {
if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
boundCheckLoadOrStoreOp(loadOp);
} else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
boundCheckLoadOrStoreOp(storeOp);
}
TypeSwitch<Operation *>(opInst).Case<AffineLoadOp, AffineStoreOp>(
[](auto op) { boundCheckLoadOrStoreOp(op); });
// TODO(bondhugula): do this for DMA ops as well.
});
}

View File

@ -21,6 +21,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@ -232,25 +233,19 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
}
// Dispatch based on the actual type. Return null type on error.
Type LLVMTypeConverter::convertStandardType(Type type) {
if (auto funcType = type.dyn_cast<FunctionType>())
return convertFunctionType(funcType);
if (auto intType = type.dyn_cast<IntegerType>())
return convertIntegerType(intType);
if (auto floatType = type.dyn_cast<FloatType>())
return convertFloatType(floatType);
if (auto indexType = type.dyn_cast<IndexType>())
return convertIndexType(indexType);
if (auto memRefType = type.dyn_cast<MemRefType>())
return convertMemRefType(memRefType);
if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
return convertUnrankedMemRefType(memRefType);
if (auto vectorType = type.dyn_cast<VectorType>())
return convertVectorType(vectorType);
if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
return llvmType;
return {};
Type LLVMTypeConverter::convertStandardType(Type t) {
return TypeSwitch<Type, Type>(t)
.Case([&](FloatType type) { return convertFloatType(type); })
.Case([&](FunctionType type) { return convertFunctionType(type); })
.Case([&](IndexType type) { return convertIndexType(type); })
.Case([&](IntegerType type) { return convertIntegerType(type); })
.Case([&](MemRefType type) { return convertMemRefType(type); })
.Case([&](UnrankedMemRefType type) {
return convertUnrankedMemRefType(type);
})
.Case([&](VectorType type) { return convertVectorType(type); })
.Case([](LLVM::LLVMType type) { return type; })
.Default([](Type) { return Type(); });
}
LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,

View File

@ -21,6 +21,7 @@
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
@ -1634,54 +1635,33 @@ Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
return success();
}
LogicalResult Serializer::processOperation(Operation *op) {
LLVM_DEBUG(llvm::dbgs() << "[op] '" << op->getName() << "'\n");
LogicalResult Serializer::processOperation(Operation *opInst) {
LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
// First dispatch the ops that do not directly mirror an instruction from
// the SPIR-V spec.
if (auto addressOfOp = dyn_cast<spirv::AddressOfOp>(op)) {
return processAddressOfOp(addressOfOp);
}
if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
return processBranchOp(branchOp);
}
if (auto condBranchOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
return processBranchConditionalOp(condBranchOp);
}
if (auto constOp = dyn_cast<spirv::ConstantOp>(op)) {
return processConstantOp(constOp);
}
if (auto fnOp = dyn_cast<FuncOp>(op)) {
return processFuncOp(fnOp);
}
if (auto varOp = dyn_cast<spirv::VariableOp>(op)) {
return processVariableOp(varOp);
}
if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
return processGlobalVariableOp(varOp);
}
if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) {
return processSelectionOp(selectionOp);
}
if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) {
return processLoopOp(loopOp);
}
if (isa<spirv::ModuleEndOp>(op)) {
return success();
}
if (auto refOpOp = dyn_cast<spirv::ReferenceOfOp>(op)) {
return processReferenceOfOp(refOpOp);
}
if (auto specConstOp = dyn_cast<spirv::SpecConstantOp>(op)) {
return processSpecConstantOp(specConstOp);
}
if (auto undefOp = dyn_cast<spirv::UndefOp>(op)) {
return processUndefOp(undefOp);
}
return TypeSwitch<Operation *, LogicalResult>(opInst)
.Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
.Case([&](spirv::BranchOp op) { return processBranchOp(op); })
.Case([&](spirv::BranchConditionalOp op) {
return processBranchConditionalOp(op);
})
.Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
.Case([&](FuncOp op) { return processFuncOp(op); })
.Case([&](spirv::GlobalVariableOp op) {
return processGlobalVariableOp(op);
})
.Case([&](spirv::LoopOp op) { return processLoopOp(op); })
.Case([&](spirv::ModuleEndOp) { return success(); })
.Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
.Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
.Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
.Case([&](spirv::UndefOp op) { return processUndefOp(op); })
.Case([&](spirv::VariableOp op) { return processVariableOp(op); })
// Then handle all the ops that directly mirror SPIR-V instructions with
// auto-generated methods.
return dispatchToAutogenSerialization(op);
// Then handle all the ops that directly mirror SPIR-V instructions with
// auto-generated methods.
.Default([&](auto *op) { return dispatchToAutogenSerialization(op); });
}
namespace {

View File

@ -22,6 +22,7 @@
#include "mlir/Transforms/Utils.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Dominance.h"
@ -47,14 +48,9 @@ static bool isMemRefDereferencingOp(Operation &op) {
/// Return the AffineMapAttr associated with memory 'op' on 'memref'.
static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) {
if (auto loadOp = dyn_cast<AffineLoadOp>(op))
return loadOp.getAffineMapAttrForMemRef(memref);
else if (auto storeOp = dyn_cast<AffineStoreOp>(op))
return storeOp.getAffineMapAttrForMemRef(memref);
else if (auto dmaStart = dyn_cast<AffineDmaStartOp>(op))
return dmaStart.getAffineMapAttrForMemRef(memref);
assert(isa<AffineDmaWaitOp>(op));
return cast<AffineDmaWaitOp>(op).getAffineMapAttrForMemRef(memref);
return TypeSwitch<Operation *, NamedAttribute>(op)
.Case<AffineDmaStartOp, AffineLoadOp, AffineStoreOp, AffineDmaWaitOp>(
[=](auto op) { return op.getAffineMapAttrForMemRef(memref); });
}
// Perform the replacement in `op`.