check modeltype

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
zhoufeng 2021-02-03 14:43:42 +08:00
parent c2d120e714
commit 539d88552a
2 changed files with 24 additions and 3 deletions

View File

@ -25,6 +25,7 @@
namespace mindspore {
constexpr auto kDeviceTypeAscend310 = "Ascend310";
constexpr auto kDeviceTypeAscend910 = "Ascend910";
constexpr auto kDeviceTypeGPU = "GPU";
struct MS_API Context {
virtual ~Context() = default;

View File

@ -20,6 +20,13 @@
#include "utils/utils.h"
namespace mindspore {
namespace {
const std::map<std::string, std::set<ModelType>> kSupportedModelMap = {
{kDeviceTypeAscend310, {kOM, kMindIR}},
{kDeviceTypeAscend910, {kMindIR}},
{kDeviceTypeGPU, {kMindIR}},
};
}
Status Model::Build() {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->Build();
@ -61,8 +68,21 @@ Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context>
Model::~Model() {}
bool Model::CheckModelSupport(const std::string &device_type, ModelType) {
return Factory<ModelImpl>::Instance().CheckModelSupport(device_type);
}
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type)) {
return false;
}
auto first_iter = kSupportedModelMap.find(device_type);
if (first_iter == kSupportedModelMap.end()) {
return false;
}
auto secend_iter = first_iter->second.find(model_type);
if (secend_iter == first_iter->second.end()) {
return false;
}
return true;
}
} // namespace mindspore