forked from OSchip/llvm-project
Add a new utility class TypeSwitch to ADT.
This class provides a simplified mechanism for defining a switch over a set of types using llvm casting functionality. More specifically, this allows for defining a switch over a value of type T where each case corresponds to a type(CaseT) that can be used with dyn_cast<CaseT>(...). An example is shown below: // Traditional piece of code: Operation *op = ...; if (auto constant = dyn_cast<ConstantOp>(op)) ...; else if (auto return = dyn_cast<ReturnOp>(op)) ...; else ...; // New piece of code: Operation *op = ...; TypeSwitch<Operation *>(op) .Case<ConstantOp>([](ConstantOp constant) { ... }) .Case<ReturnOp>([](ReturnOp return) { ... }) .Default([](Operation *op) { ... }); Aside from the above, TypeSwitch supports return values, void return, multiple types per case, etc. The usability is intended to be very similar to StringSwitch. (Using c++14 template lambdas makes everything even nicer) More complex example of how this makes certain things easier: LogicalResult process(Constant op); LogicalResult process(ReturnOp op); LogicalResult process(FuncOp op); TypeSwitch<Operation *, LogicalResult>(op) .Case<ConstantOp, ReturnOp, FuncOp>([](auto op) { return process(op); }) .Default([](Operation *op) { return op->emitError() << "could not be processed"; }); PiperOrigin-RevId: 286003613
This commit is contained in:
parent
6e581e29a4
commit
f44cf23297
|
@ -0,0 +1,185 @@
|
|||
//===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements the TypeSwitch template, which mimics a switch()
|
||||
// statement whose cases are type names.
|
||||
//
|
||||
//===-----------------------------------------------------------------------===/
|
||||
|
||||
#ifndef MLIR_SUPPORT_TYPESWITCH_H
|
||||
#define MLIR_SUPPORT_TYPESWITCH_H
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace detail {
|
||||
|
||||
template <typename DerivedT, typename T> class TypeSwitchBase {
|
||||
public:
|
||||
TypeSwitchBase(const T &value) : value(value) {}
|
||||
TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {}
|
||||
~TypeSwitchBase() = default;
|
||||
|
||||
/// TypeSwitchBase is not copyable.
|
||||
TypeSwitchBase(const TypeSwitchBase &) = delete;
|
||||
void operator=(const TypeSwitchBase &) = delete;
|
||||
void operator=(TypeSwitchBase &&other) = delete;
|
||||
|
||||
/// Invoke a case on the derived class with multiple case types.
|
||||
template <typename CaseT, typename CaseT2, typename... CaseTs,
|
||||
typename CallableT>
|
||||
DerivedT &Case(CallableT &&caseFn) {
|
||||
DerivedT &derived = static_cast<DerivedT &>(*this);
|
||||
return derived.template Case<CaseT>(caseFn)
|
||||
.template Case<CaseT2, CaseTs...>(caseFn);
|
||||
}
|
||||
|
||||
/// Invoke a case on the derived class, inferring the type of the Case from
|
||||
/// the first input of the given callable.
|
||||
/// Note: This inference rules for this overload are very simple: strip
|
||||
/// pointers and references.
|
||||
template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
|
||||
using Traits = FunctionTraits<std::decay_t<CallableT>>;
|
||||
using CaseT = std::remove_cv_t<std::remove_pointer_t<
|
||||
std::remove_reference_t<typename Traits::template arg_t<0>>>>;
|
||||
|
||||
DerivedT &derived = static_cast<DerivedT &>(*this);
|
||||
return derived.template Case<CaseT>(std::forward<CallableT>(caseFn));
|
||||
}
|
||||
|
||||
protected:
|
||||
/// Trait to check whether `ValueT` provides a 'dyn_cast' method with type
|
||||
/// `CastT`.
|
||||
template <typename ValueT, typename CastT>
|
||||
using has_dyn_cast_t =
|
||||
decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
|
||||
|
||||
/// Attempt to dyn_cast the given `value` to `CastT`. This overload is
|
||||
/// selected if `value` already has a suitable dyn_cast method.
|
||||
template <typename CastT, typename ValueT>
|
||||
static auto castValue(
|
||||
ValueT value,
|
||||
typename std::enable_if_t<
|
||||
is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
|
||||
return value.template dyn_cast<CastT>();
|
||||
}
|
||||
|
||||
/// Attempt to dyn_cast the given `value` to `CastT`. This overload is
|
||||
/// selected if llvm::dyn_cast should be used.
|
||||
template <typename CastT, typename ValueT>
|
||||
static auto castValue(
|
||||
ValueT value,
|
||||
typename std::enable_if_t<
|
||||
!is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
|
||||
return dyn_cast<CastT>(value);
|
||||
}
|
||||
|
||||
/// The root value we are switching on.
|
||||
const T value;
|
||||
};
|
||||
} // end namespace detail
|
||||
|
||||
/// This class implements a switch-like dispatch statement for a value of 'T'
|
||||
/// using dyn_cast functionality. Each `Case<T>` takes a callable to be invoked
|
||||
/// if the root value isa<T>, the callable is invoked with the result of
|
||||
/// dyn_cast<T>() as a parameter.
|
||||
///
|
||||
/// Example:
|
||||
/// Operation *op = ...;
|
||||
/// LogicalResult result = TypeSwitch<Operation *, LogicalResult>(op)
|
||||
/// .Case<ConstantOp>([](ConstantOp op) { ... })
|
||||
/// .Default([](Operation *op) { ... });
|
||||
///
|
||||
template <typename T, typename ResultT = void>
|
||||
class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
|
||||
public:
|
||||
using BaseT = detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T>;
|
||||
using BaseT::BaseT;
|
||||
using BaseT::Case;
|
||||
TypeSwitch(TypeSwitch &&other) = default;
|
||||
|
||||
/// Add a case on the given type.
|
||||
template <typename CaseT, typename CallableT>
|
||||
TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) {
|
||||
if (result)
|
||||
return *this;
|
||||
|
||||
// Check to see if CaseT applies to 'value'.
|
||||
if (auto caseValue = BaseT::template castValue<CaseT>(this->value))
|
||||
result = caseFn(caseValue);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// As a default, invoke the given callable within the root value.
|
||||
template <typename CallableT>
|
||||
LLVM_NODISCARD ResultT Default(CallableT &&defaultFn) {
|
||||
if (result)
|
||||
return std::move(*result);
|
||||
return defaultFn(this->value);
|
||||
}
|
||||
|
||||
LLVM_NODISCARD
|
||||
operator ResultT() {
|
||||
assert(result && "Fell off the end of a type-switch");
|
||||
return std::move(*result);
|
||||
}
|
||||
|
||||
private:
|
||||
/// The pointer to the result of this switch statement, once known,
|
||||
/// null before that.
|
||||
Optional<ResultT> result;
|
||||
};
|
||||
|
||||
/// Specialization of TypeSwitch for void returning callables.
|
||||
template <typename T>
|
||||
class TypeSwitch<T, void>
|
||||
: public detail::TypeSwitchBase<TypeSwitch<T, void>, T> {
|
||||
public:
|
||||
using BaseT = detail::TypeSwitchBase<TypeSwitch<T, void>, T>;
|
||||
using BaseT::BaseT;
|
||||
using BaseT::Case;
|
||||
TypeSwitch(TypeSwitch &&other) = default;
|
||||
|
||||
/// Add a case on the given type.
|
||||
template <typename CaseT, typename CallableT>
|
||||
TypeSwitch<T, void> &Case(CallableT &&caseFn) {
|
||||
if (foundMatch)
|
||||
return *this;
|
||||
|
||||
// Check to see if any of the types apply to 'value'.
|
||||
if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) {
|
||||
caseFn(caseValue);
|
||||
foundMatch = true;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// As a default, invoke the given callable within the root value.
|
||||
template <typename CallableT> void Default(CallableT &&defaultFn) {
|
||||
if (!foundMatch)
|
||||
defaultFn(this->value);
|
||||
}
|
||||
|
||||
private:
|
||||
/// A flag detailing if we have already found a match.
|
||||
bool foundMatch = false;
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_SUPPORT_TYPESWITCH_H
|
|
@ -344,6 +344,44 @@ template <typename ContainerTy> bool has_single_element(ContainerTy &&c) {
|
|||
auto it = std::begin(c), e = std::end(c);
|
||||
return it != e && std::next(it) == e;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Extra additions to <type_traits>
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class provides various trait information about a callable object.
|
||||
/// * To access the number of arguments: Traits::num_args
|
||||
/// * To access the type of an argument: Traits::arg_t<i>
|
||||
/// * To access the type of the result: Traits::result_t<i>
|
||||
template <typename T, bool isClass = std::is_class<T>::value>
|
||||
struct FunctionTraits : public FunctionTraits<decltype(&T::operator())> {};
|
||||
|
||||
/// Overload for class function types.
|
||||
template <typename ClassType, typename ReturnType, typename... Args>
|
||||
struct FunctionTraits<ReturnType (ClassType::*)(Args...) const, false> {
|
||||
/// The number of arguments to this function.
|
||||
enum { num_args = sizeof...(Args) };
|
||||
|
||||
/// The result type of this function.
|
||||
using result_t = ReturnType;
|
||||
|
||||
/// The type of an argument to this function.
|
||||
template <size_t i>
|
||||
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
|
||||
};
|
||||
/// Overload for non-class function types.
|
||||
template <typename ReturnType, typename... Args>
|
||||
struct FunctionTraits<ReturnType (*)(Args...), false> {
|
||||
/// The number of arguments to this function.
|
||||
enum { num_args = sizeof...(Args) };
|
||||
|
||||
/// The result type of this function.
|
||||
using result_t = ReturnType;
|
||||
|
||||
/// The type of an argument to this function.
|
||||
template <size_t i>
|
||||
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
// Allow tuples to be usable as DenseMap keys.
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
add_mlir_unittest(MLIRADTTests
|
||||
TypeSwitchTest.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(MLIRADTTests PRIVATE MLIRSupport LLVMSupport)
|
|
@ -0,0 +1,97 @@
|
|||
//===- TypeSwitchTest.cpp - TypeSwitch unit tests -------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/ADT/TypeSwitch.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// Utility classes to setup casting functionality.
|
||||
struct Base {
|
||||
enum Kind { DerivedA, DerivedB, DerivedC, DerivedD, DerivedE };
|
||||
Kind kind;
|
||||
};
|
||||
template <Base::Kind DerivedKind> struct DerivedImpl : Base {
|
||||
DerivedImpl() : Base{DerivedKind} {}
|
||||
static bool classof(const Base *base) { return base->kind == DerivedKind; }
|
||||
};
|
||||
struct DerivedA : public DerivedImpl<Base::DerivedA> {};
|
||||
struct DerivedB : public DerivedImpl<Base::DerivedB> {};
|
||||
struct DerivedC : public DerivedImpl<Base::DerivedC> {};
|
||||
struct DerivedD : public DerivedImpl<Base::DerivedD> {};
|
||||
struct DerivedE : public DerivedImpl<Base::DerivedE> {};
|
||||
} // end anonymous namespace
|
||||
|
||||
TEST(StringSwitchTest, CaseResult) {
|
||||
auto translate = [](auto value) {
|
||||
return TypeSwitch<Base *, int>(&value)
|
||||
.Case<DerivedA>([](DerivedA *) { return 0; })
|
||||
.Case([](DerivedB *) { return 1; })
|
||||
.Case([](DerivedC *) { return 2; })
|
||||
.Default([](Base *) { return -1; });
|
||||
};
|
||||
EXPECT_EQ(0, translate(DerivedA()));
|
||||
EXPECT_EQ(1, translate(DerivedB()));
|
||||
EXPECT_EQ(2, translate(DerivedC()));
|
||||
EXPECT_EQ(-1, translate(DerivedD()));
|
||||
}
|
||||
|
||||
TEST(StringSwitchTest, CasesResult) {
|
||||
auto translate = [](auto value) {
|
||||
return TypeSwitch<Base *, int>(&value)
|
||||
.Case<DerivedA, DerivedB, DerivedD>([](auto *) { return 0; })
|
||||
.Case([](DerivedC *) { return 1; })
|
||||
.Default([](Base *) { return -1; });
|
||||
};
|
||||
EXPECT_EQ(0, translate(DerivedA()));
|
||||
EXPECT_EQ(0, translate(DerivedB()));
|
||||
EXPECT_EQ(1, translate(DerivedC()));
|
||||
EXPECT_EQ(0, translate(DerivedD()));
|
||||
EXPECT_EQ(-1, translate(DerivedE()));
|
||||
}
|
||||
|
||||
TEST(StringSwitchTest, CaseVoid) {
|
||||
auto translate = [](auto value) {
|
||||
int result = -2;
|
||||
TypeSwitch<Base *>(&value)
|
||||
.Case([&](DerivedA *) { result = 0; })
|
||||
.Case([&](DerivedB *) { result = 1; })
|
||||
.Case([&](DerivedC *) { result = 2; })
|
||||
.Default([&](Base *) { result = -1; });
|
||||
return result;
|
||||
};
|
||||
EXPECT_EQ(0, translate(DerivedA()));
|
||||
EXPECT_EQ(1, translate(DerivedB()));
|
||||
EXPECT_EQ(2, translate(DerivedC()));
|
||||
EXPECT_EQ(-1, translate(DerivedD()));
|
||||
}
|
||||
|
||||
TEST(StringSwitchTest, CasesVoid) {
|
||||
auto translate = [](auto value) {
|
||||
int result = -1;
|
||||
TypeSwitch<Base *>(&value)
|
||||
.Case<DerivedA, DerivedB, DerivedD>([&](auto *) { result = 0; })
|
||||
.Case([&](DerivedC *) { result = 1; });
|
||||
return result;
|
||||
};
|
||||
EXPECT_EQ(0, translate(DerivedA()));
|
||||
EXPECT_EQ(0, translate(DerivedB()));
|
||||
EXPECT_EQ(1, translate(DerivedC()));
|
||||
EXPECT_EQ(0, translate(DerivedD()));
|
||||
EXPECT_EQ(-1, translate(DerivedE()));
|
||||
}
|
|
@ -5,6 +5,7 @@ function(add_mlir_unittest test_dirname)
|
|||
add_unittest(MLIRUnitTests ${test_dirname} ${ARGN})
|
||||
endfunction()
|
||||
|
||||
add_subdirectory(ADT)
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Pass)
|
||||
|
|
Loading…
Reference in New Issue