mindspore/predict/common/module_registry.h

98 lines
2.3 KiB
C++

/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_COMMON_MODULE_REGISTRY_H_
#define PREDICT_COMMON_MODULE_REGISTRY_H_
#include <memory>
#include <string>
#include <unordered_map>
#include "common/mslog.h"
#define MSPREDICT_API __attribute__((visibility("default")))
namespace mindspore {
namespace predict {
/**
 * Common polymorphic root shared by every object that can be stored in ModuleRegistry.
 *
 * The class holds no state. It exists so that modules of unrelated product types can live
 * behind a single `const ModuleBase *` pointer type, and its virtual destructor makes
 * deleting a derived module through a ModuleBase pointer run the derived destructor.
 */
class ModuleBase {
 public:
  ModuleBase() = default;
  virtual ~ModuleBase() = default;
};

// Typed module interface. It is only forward-declared here; its definition is not visible in
// this file. ModuleRegistry casts stored ModuleBase pointers to Module<T> in order to call
// its Create() / GetInstance().
template <class T>
class Module;
class ModuleRegistry {
public:
ModuleRegistry() = default;
virtual ~ModuleRegistry() = default;
template <class T>
bool Register(const std::string &name, const T &t) {
modules[name] = &t;
return true;
}
template <class T>
std::shared_ptr<T> Create(const std::string &name) {
auto it = modules.find(name);
if (it == modules.end()) {
return nullptr;
}
auto *module = (Module<T> *)it->second;
if (module == nullptr) {
return nullptr;
} else {
return module->Create();
}
}
template <class T>
T *GetInstance(const std::string &name) {
auto it = modules.find(name);
if (it == modules.end()) {
return nullptr;
}
auto *module = (Module<T> *)it->second;
if (module == nullptr) {
return nullptr;
} else {
return module->GetInstance();
}
}
protected:
std::unordered_map<std::string, const ModuleBase *> modules;
};
ModuleRegistry *GetRegistryInstance() MSPREDICT_API;
/**
 * Helper whose constructor registers a module with the registry returned by
 * GetRegistryInstance(). Registration happens as a side effect of constructing the object.
 * NOTE(review): presumably declared as a namespace-scope static so registration runs during
 * static initialization; confirm against call sites.
 */
template <class T>
class ModuleRegistrar {
 public:
  /**
   * Registers `module` under `name`. If GetRegistryInstance() yields nullptr, a warning is
   * logged and nothing is registered.
   *
   * @param name key under which the module is registered.
   * @param module module to register; the registry keeps only its address, so it must outlive
   *               every lookup made through the registry.
   */
  ModuleRegistrar(const std::string &name, const T &module) {
    auto *registry = GetRegistryInstance();
    if (registry == nullptr) {
      MS_LOGW("registryInstance is nullptr.");
      return;
    }
    registry->Register(name, module);
  }

  // A temporary argument would be destroyed right after registration, leaving a dangling
  // pointer in the registry, so rvalue arguments are rejected at compile time.
  ModuleRegistrar(const std::string &name, const T &&module) = delete;
};
} // namespace predict
} // namespace mindspore
#endif // PREDICT_COMMON_MODULE_REGISTRY_H_