forked from mindspore-Ecosystem/mindspore
check modeltype
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
c2d120e714
commit
539d88552a
|
@ -25,6 +25,7 @@
|
|||
namespace mindspore {
|
||||
constexpr auto kDeviceTypeAscend310 = "Ascend310";
|
||||
constexpr auto kDeviceTypeAscend910 = "Ascend910";
|
||||
constexpr auto kDeviceTypeGPU = "GPU";
|
||||
|
||||
struct MS_API Context {
|
||||
virtual ~Context() = default;
|
||||
|
|
|
@ -20,6 +20,13 @@
|
|||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
const std::map<std::string, std::set<ModelType>> kSupportedModelMap = {
|
||||
{kDeviceTypeAscend310, {kOM, kMindIR}},
|
||||
{kDeviceTypeAscend910, {kMindIR}},
|
||||
{kDeviceTypeGPU, {kMindIR}},
|
||||
};
|
||||
}
|
||||
Status Model::Build() {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->Build();
|
||||
|
@ -61,8 +68,21 @@ Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context>
|
|||
|
||||
Model::~Model() {}
|
||||
|
||||
bool Model::CheckModelSupport(const std::string &device_type, ModelType) {
|
||||
return Factory<ModelImpl>::Instance().CheckModelSupport(device_type);
|
||||
}
|
||||
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
|
||||
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto first_iter = kSupportedModelMap.find(device_type);
|
||||
if (first_iter == kSupportedModelMap.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto secend_iter = first_iter->second.find(model_type);
|
||||
if (secend_iter == first_iter->second.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue