[Orc] Use recursive mutexes for Error serialization.

Errors can be nested, so we need recursive locking for serialization /
deserialization.

llvm-svn: 301147
This commit is contained in:
Lang Hames 2017-04-23 23:36:13 +00:00
parent 148e49f3c8
commit 70eccdc727
1 changed files with 39 additions and 28 deletions

View File

@ -348,7 +348,7 @@ public:
// key of the deserializers map to save us from duplicating the string in
// the serializer. This should be changed to use a stringpool if we switch
// to a map type that may move keys in memory.
std::lock_guard<std::mutex> Lock(DeserializersMutex);
std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex);
auto I =
Deserializers.insert(Deserializers.begin(),
std::make_pair(std::move(Name),
@ -358,7 +358,7 @@ public:
{
assert(KeyName != nullptr && "No keyname pointer");
std::lock_guard<std::mutex> Lock(SerializersMutex);
std::lock_guard<std::recursive_mutex> Lock(SerializersMutex);
// FIXME: Move capture Serialize once we have C++14.
Serializers[ErrorInfoT::classID()] =
[KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error {
@ -372,12 +372,13 @@ public:
}
static Error serialize(ChannelT &C, Error &&Err) {
std::lock_guard<std::mutex> Lock(SerializersMutex);
std::lock_guard<std::recursive_mutex> Lock(SerializersMutex);
if (!Err)
return serializeSeq(C, std::string());
return handleErrors(std::move(Err),
[&C](const ErrorInfoBase &EIB) {
[&C, &Lock](const ErrorInfoBase &EIB) {
auto SI = Serializers.find(EIB.dynamicClassID());
if (SI == Serializers.end())
return serializeAsStringError(C, EIB);
@ -386,7 +387,7 @@ public:
}
static Error deserialize(ChannelT &C, Error &Err) {
std::lock_guard<std::mutex> Lock(DeserializersMutex);
std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex);
std::string Key;
if (auto Err = deserializeSeq(C, Key))
@ -406,8 +407,6 @@ public:
private:
static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) {
assert(EIB.dynamicClassID() != StringError::classID() &&
"StringError serialization not registered");
std::string ErrMsg;
{
raw_string_ostream ErrMsgStream(ErrMsg);
@ -417,17 +416,17 @@ private:
inconvertibleErrorCode()));
}
static std::mutex SerializersMutex;
static std::mutex DeserializersMutex;
static std::recursive_mutex SerializersMutex;
static std::recursive_mutex DeserializersMutex;
static std::map<const void*, WrappedErrorSerializer> Serializers;
static std::map<std::string, WrappedErrorDeserializer> Deserializers;
};
template <typename ChannelT>
std::mutex SerializationTraits<ChannelT, Error>::SerializersMutex;
std::recursive_mutex SerializationTraits<ChannelT, Error>::SerializersMutex;
template <typename ChannelT>
std::mutex SerializationTraits<ChannelT, Error>::DeserializersMutex;
std::recursive_mutex SerializationTraits<ChannelT, Error>::DeserializersMutex;
template <typename ChannelT>
std::map<const void*,
@ -439,27 +438,39 @@ std::map<std::string,
typename SerializationTraits<ChannelT, Error>::WrappedErrorDeserializer>
SerializationTraits<ChannelT, Error>::Deserializers;
/// Registers a serializer and deserializer for the given error type on the
/// given channel type.
template <typename ChannelT, typename ErrorInfoT, typename SerializeFtor,
typename DeserializeFtor>
void registerErrorSerialization(std::string Name, SerializeFtor &&Serialize,
DeserializeFtor &&Deserialize) {
SerializationTraits<ChannelT, Error>::template registerErrorType<ErrorInfoT>(
std::move(Name),
std::forward<SerializeFtor>(Serialize),
std::forward<DeserializeFtor>(Deserialize));
}
/// Registers serialization/deserialization for StringError.
template <typename ChannelT>
void registerStringError() {
static bool AlreadyRegistered = false;
if (!AlreadyRegistered) {
SerializationTraits<ChannelT, Error>::
template registerErrorType<StringError>(
"StringError",
[](ChannelT &C, const StringError &SE) {
return serializeSeq(C, SE.getMessage());
},
[](ChannelT &C, Error &Err) {
ErrorAsOutParameter EAO(&Err);
std::string Msg;
if (auto E2 = deserializeSeq(C, Msg))
return E2;
Err =
make_error<StringError>(std::move(Msg),
orcError(
OrcErrorCode::UnknownErrorCodeFromRemote));
return Error::success();
});
registerErrorSerialization<ChannelT, StringError>(
"StringError",
[](ChannelT &C, const StringError &SE) {
return serializeSeq(C, SE.getMessage());
},
[](ChannelT &C, Error &Err) -> Error {
ErrorAsOutParameter EAO(&Err);
std::string Msg;
if (auto E2 = deserializeSeq(C, Msg))
return E2;
Err =
make_error<StringError>(std::move(Msg),
orcError(
OrcErrorCode::UnknownErrorCodeFromRemote));
return Error::success();
});
AlreadyRegistered = true;
}
}