From 3204ecb7d652e118f141a2ef1a4d0cfd8686f22e Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Tue, 1 Dec 2020 18:35:15 +0800 Subject: [PATCH] mindspore c++ interface Signed-off-by: zhoufeng --- build.sh | 6 +- include/api/cell.h | 38 +- include/api/context.h | 41 + include/api/graph.h | 43 + include/api/model.h | 29 +- include/api/serialization.h | 2 + include/api/types.h | 3 + mindspore/ccsrc/cxx_api/CMakeLists.txt | 36 +- mindspore/ccsrc/cxx_api/cell.cc | 49 + mindspore/ccsrc/cxx_api/context.cc | 63 + mindspore/ccsrc/cxx_api/factory.h | 83 ++ .../ccsrc/cxx_api/graph/acl/acl_graph_impl.cc | 266 ++++ .../ccsrc/cxx_api/graph/acl/acl_graph_impl.h | 73 ++ .../{model => graph}/acl/model_process.cc | 124 +- .../{model => graph}/acl/model_process.h | 27 +- mindspore/ccsrc/cxx_api/graph/graph.cc | 29 + mindspore/ccsrc/cxx_api/graph/graph_data.cc | 73 ++ mindspore/ccsrc/cxx_api/graph/graph_data.h | 48 + mindspore/ccsrc/cxx_api/graph/graph_impl.h | 51 + .../ccsrc/cxx_api/graph/ms/ms_graph_impl.cc | 334 +++++ .../ccsrc/cxx_api/graph/ms/ms_graph_impl.h | 65 + .../ccsrc/cxx_api/model/acl/acl_model.cc | 287 +--- mindspore/ccsrc/cxx_api/model/acl/acl_model.h | 63 +- .../ccsrc/cxx_api/model/acl/dvpp_process.cc | 1160 ----------------- .../ccsrc/cxx_api/model/acl/dvpp_process.h | 160 --- .../cxx_api/model/acl/model_converter.cc | 126 +- .../ccsrc/cxx_api/model/acl/model_converter.h | 8 +- mindspore/ccsrc/cxx_api/model/model.cc | 76 +- mindspore/ccsrc/cxx_api/model/model_impl.h | 74 +- mindspore/ccsrc/cxx_api/model/ms/ms_model.cc | 396 +----- mindspore/ccsrc/cxx_api/model/ms/ms_model.h | 44 +- mindspore/ccsrc/cxx_api/python_utils.cc | 65 + mindspore/ccsrc/cxx_api/python_utils.h | 27 + mindspore/ccsrc/cxx_api/serialization.cc | 68 + mindspore/ccsrc/cxx_api/types.cc | 3 + mindspore/core/ir/anf.h | 1 - tests/CMakeLists.txt | 9 +- tests/cxx_st/CMakeLists.txt | 11 + tests/cxx_st/common/common_test.cc | 40 + tests/cxx_st/common/common_test.h | 76 ++ tests/cxx_st/common/test_main.cc | 22 + 
tests/cxx_st/model/test_tensor_add.cc | 58 + 42 files changed, 1920 insertions(+), 2337 deletions(-) create mode 100644 include/api/context.h create mode 100644 include/api/graph.h create mode 100644 mindspore/ccsrc/cxx_api/context.cc create mode 100644 mindspore/ccsrc/cxx_api/factory.h create mode 100644 mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc create mode 100644 mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.h rename mindspore/ccsrc/cxx_api/{model => graph}/acl/model_process.cc (77%) rename mindspore/ccsrc/cxx_api/{model => graph}/acl/model_process.h (72%) create mode 100644 mindspore/ccsrc/cxx_api/graph/graph.cc create mode 100644 mindspore/ccsrc/cxx_api/graph/graph_data.cc create mode 100644 mindspore/ccsrc/cxx_api/graph/graph_data.h create mode 100644 mindspore/ccsrc/cxx_api/graph/graph_impl.h create mode 100644 mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.cc create mode 100644 mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.h delete mode 100644 mindspore/ccsrc/cxx_api/model/acl/dvpp_process.cc delete mode 100644 mindspore/ccsrc/cxx_api/model/acl/dvpp_process.h create mode 100644 mindspore/ccsrc/cxx_api/python_utils.cc create mode 100644 mindspore/ccsrc/cxx_api/python_utils.h create mode 100644 tests/cxx_st/CMakeLists.txt create mode 100644 tests/cxx_st/common/common_test.cc create mode 100644 tests/cxx_st/common/common_test.h create mode 100644 tests/cxx_st/common/test_main.cc create mode 100644 tests/cxx_st/model/test_tensor_add.cc diff --git a/build.sh b/build.sh index 217fa055b73..2d8a714a029 100755 --- a/build.sh +++ b/build.sh @@ -23,7 +23,7 @@ usage() { echo "Usage:" echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\" - echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|ascend|cpu|acl] \\" + echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|ascend|cpu|ascend310] \\" echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I 
arm64|arm32|x86_64] [-K] \\" echo " [-B on|off] [-E] [-l on|off] [-n full|lite|off] [-T on|off] \\" echo " [-A [cpp|java|object-c] [-C on|off] [-o on|off] [-S on|off] [-k on|off] [-W sse|neon|avx|off] \\" @@ -45,7 +45,7 @@ usage() echo " -i Enable increment building, default off" echo " -L Enable load ANF-IR as input of 'infer', default off" echo " -j[n] Set the threads when building (Default: -j8)" - echo " -e Use cpu, gpu, ascend or acl" + echo " -e Use cpu, gpu, ascend or ascend310" echo " -P Enable dump anf graph to file in ProtoBuffer format, default on" echo " -D Enable dumping of function graph ir, default on" echo " -z Compile dataset & mindrecord, default on" @@ -224,7 +224,7 @@ checkopts() ENABLE_D="on" ENABLE_CPU="on" ENABLE_SERVING="on" - elif [[ "X$OPTARG" == "Xacl" ]]; then + elif [[ "X$OPTARG" == "Xascend310" ]]; then ENABLE_SERVING="on" ENABLE_ACL="on" elif [[ "X$OPTARG" == "Xcpu" ]]; then diff --git a/include/api/cell.h b/include/api/cell.h index 4b32256b298..096bb8b1a9c 100644 --- a/include/api/cell.h +++ b/include/api/cell.h @@ -21,6 +21,7 @@ #include #include "include/api/status.h" #include "include/api/types.h" +#include "include/api/graph.h" namespace mindspore { namespace api { @@ -34,6 +35,7 @@ class MS_API CellBase { virtual ~CellBase() = default; virtual std::vector Construct(const std::vector &inputs) { return {}; } virtual std::shared_ptr Clone() const = 0; + virtual Status Run(const std::vector &inputs, std::vector *outputs) { return SUCCESS; } std::vector operator()(const std::vector &inputs) const; }; @@ -41,9 +43,7 @@ template class MS_API Cell : public CellBase { public: virtual ~Cell() = default; - std::shared_ptr Clone() const override { - return std::make_shared(static_cast(*this)); - } + std::shared_ptr Clone() const override { return std::make_shared(static_cast(*this)); } }; class MS_API ParameterCell final : public Cell { @@ -84,9 +84,33 @@ class MS_API OpCell : public OpCellBase, public std::enable_shared_from_this public: 
explicit OpCell(const std::string &name) : OpCellBase(name) {} ~OpCell() override = default; - std::shared_ptr Clone() const override { - return std::make_shared(static_cast(*this)); - } + std::shared_ptr Clone() const override { return std::make_shared(static_cast(*this)); } +}; + +class MS_API GraphCell final : public Cell { + public: + class GraphImpl; + + GraphCell() = default; + ~GraphCell() override = default; + + explicit GraphCell(const Graph &); + explicit GraphCell(Graph &&); + explicit GraphCell(const std::shared_ptr &); + + const std::shared_ptr &GetGraph() const { return graph_; } + Status Run(const std::vector &inputs, std::vector *outputs) override; + Status GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const; + Status GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const; + + private: + friend class ModelImpl; + Status Load(); + + std::shared_ptr graph_; + std::shared_ptr executor_; }; class MS_API InputAndOutput { @@ -96,7 +120,7 @@ class MS_API InputAndOutput { // no explicit InputAndOutput(const Tensor &); // NOLINT(runtime/explicit) - InputAndOutput(Tensor &&); // NOLINT(runtime/explicit) + InputAndOutput(Tensor &&); // NOLINT(runtime/explicit) InputAndOutput(const std::shared_ptr &, const std::vector &, int32_t index); diff --git a/include/api/context.h b/include/api/context.h new file mode 100644 index 00000000000..31552c95f41 --- /dev/null +++ b/include/api/context.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H +#define MINDSPORE_INCLUDE_API_CONTEXT_H + +#include +#include +#include "include/api/types.h" + +namespace mindspore { +namespace api { +class MS_API Context { + public: + static Context &Instance(); + const std::string &GetDeviceTarget() const; + Context &SetDeviceTarget(const std::string &device_target); + uint32_t GetDeviceID() const; + Context &SetDeviceID(uint32_t device_id); + + private: + Context(); + ~Context(); + class ContextImpl; + std::shared_ptr impl_; +}; +} // namespace api +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_API_CONTEXT_H diff --git a/include/api/graph.h b/include/api/graph.h new file mode 100644 index 00000000000..42ca2c85bac --- /dev/null +++ b/include/api/graph.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_INCLUDE_API_GRAPH_H +#define MINDSPORE_INCLUDE_API_GRAPH_H + +#include +#include +#include +#include +#include "include/api/status.h" +#include "include/api/types.h" + +namespace mindspore { +namespace api { +class MS_API Graph { + public: + class GraphData; + explicit Graph(const std::shared_ptr &graph_data); + explicit Graph(std::shared_ptr &&graph_data); + + enum ModelType ModelType() const; + + private: + friend class GraphCell; + friend class ModelImpl; + std::shared_ptr graph_data_; +}; +} // namespace api +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_API_GRAPH_H diff --git a/include/api/model.h b/include/api/model.h index dffa73db890..efd06aedc57 100644 --- a/include/api/model.h +++ b/include/api/model.h @@ -22,42 +22,39 @@ #include #include "include/api/status.h" #include "include/api/types.h" +#include "include/api/graph.h" +#include "include/api/cell.h" namespace mindspore { namespace api { class ModelImpl; // todo: minddata c++ interface class DataSet {}; -class NetWork {}; class MS_API Model { public: - Model(const std::string &device_type, uint32_t device_id); - Model(NetWork network, const std::string &device_type, uint32_t device_id); + explicit Model(const std::vector &network); + explicit Model(const GraphCell &graph); ~Model(); Model(const Model &) = delete; void operator=(const Model &) = delete; - Status LoadModel(const Buffer &model_data, ModelType type, const std::map &options); - Status LoadModel(const std::string &file_name, ModelType type, const std::map &options); - Status UnloadModel(); + Status Build(const std::map &options); - Status Train(const DataSet &dataset, std::map *outputs); - Status Eval(const DataSet &dataset, std::map *outputs); - Status Predict(const std::map &inputs, std::map *outputs); - Status Predict(const std::vector &inputs, std::map *outputs); + Status Train(const DataSet &dataset, bool data_sink, std::map *outputs); + Status Eval(const DataSet &dataset, bool data_sink, std::map 
*outputs); + Status Predict(const std::vector &inputs, std::vector *outputs); - Status GetInputsInfo(std::vector *tensor_list) const; - Status GetOutputsInfo(std::vector *tensor_list) const; + Status GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const; + Status GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const; - static bool CheckModelSupport(const std::string& device_type, ModelType model_type); + static bool CheckModelSupport(const std::string &device_type, ModelType model_type); private: std::shared_ptr impl_; }; - -extern MS_API const char* kDeviceTypeAscendCL; -extern MS_API const char* kDeviceTypeAscendMS; } // namespace api } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_MODEL_H diff --git a/include/api/serialization.h b/include/api/serialization.h index 4fcd08c56a0..9750337d0d1 100644 --- a/include/api/serialization.h +++ b/include/api/serialization.h @@ -23,11 +23,13 @@ #include "include/api/status.h" #include "include/api/types.h" #include "include/api/model.h" +#include "include/api/graph.h" namespace mindspore { namespace api { class MS_API Serialization { public: + static Graph LoadModel(const std::string &file, ModelType model_type); static Status LoadCheckPoint(const std::string &ckpt_file, std::map *parameters); static Status SetParameters(const std::map ¶meters, Model *model); static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data); diff --git a/include/api/types.h b/include/api/types.h index 194401cc976..459071b1d02 100644 --- a/include/api/types.h +++ b/include/api/types.h @@ -102,6 +102,9 @@ class MS_API Buffer { std::shared_ptr impl_; }; +extern MS_API const char *kDeviceTypeAscend310; +extern MS_API const char *kDeviceTypeAscend910; + constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path"; constexpr auto kModelOptionDvppCfgPath = 
"mindspore.option.dvpp_config_file_path"; constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file diff --git a/mindspore/ccsrc/cxx_api/CMakeLists.txt b/mindspore/ccsrc/cxx_api/CMakeLists.txt index c74b5cf038e..0945f70e3b8 100644 --- a/mindspore/ccsrc/cxx_api/CMakeLists.txt +++ b/mindspore/ccsrc/cxx_api/CMakeLists.txt @@ -6,22 +6,35 @@ set(LOAD_MINDIR_SRC file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc") if (ENABLE_ACL) - file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc" "model/model_converter_utils/*.cc") -elseif (ENABLE_D) - file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc") + add_compile_definitions(ENABLE_ACL) + include_directories(${CMAKE_SOURCE_DIR}/graphengine/src/ge) + include_directories(${CMAKE_BINARY_DIR}/proto/ge) + file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} + "model/acl/*.cc" + "model/model_converter_utils/*.cc" + "graph/acl/*.cc" + ) + +endif () +if (ENABLE_D) + file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc" "graph/ms/*.cc") endif () set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc - ${CMAKE_CURRENT_SOURCE_DIR}/cell.cc - ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc - ${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc - ${API_MS_INFER_SRC} - ${API_ACL_SRC} - ${API_OPS_SRC} - ${LOAD_MINDIR_SRC}) + ${CMAKE_CURRENT_SOURCE_DIR}/context.cc + ${CMAKE_CURRENT_SOURCE_DIR}/cell.cc + ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc + ${CMAKE_CURRENT_SOURCE_DIR}/python_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/graph/graph.cc + ${CMAKE_CURRENT_SOURCE_DIR}/graph/graph_data.cc + ${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc + ${API_MS_INFER_SRC} + ${API_ACL_SRC} + ${API_OPS_SRC} + ${LOAD_MINDIR_SRC}) add_library(mindspore_shared_lib SHARED ${MSLIB_SRC}) -set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}") +set_target_properties(mindspore_shared_lib 
PROPERTIES OUTPUT_NAME mindspore) target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY} -Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf) @@ -69,5 +82,6 @@ endif () if (ENABLE_D) find_library(adump_server libadump_server.a ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server}) + target_link_libraries(mindspore_shared_lib PRIVATE mindspore_core hccl_adapter) endif () diff --git a/mindspore/ccsrc/cxx_api/cell.cc b/mindspore/ccsrc/cxx_api/cell.cc index 0b684ce5aa3..7329675c0f7 100644 --- a/mindspore/ccsrc/cxx_api/cell.cc +++ b/mindspore/ccsrc/cxx_api/cell.cc @@ -14,6 +14,9 @@ * limitations under the License. */ #include "include/api/cell.h" +#include "include/api/context.h" +#include "cxx_api/factory.h" +#include "cxx_api/graph/graph_impl.h" namespace mindspore::api { std::vector CellBase::operator()(const std::vector &inputs) const { return Clone()->Construct(inputs); } @@ -51,6 +54,52 @@ ParameterCell &ParameterCell::operator=(Tensor &&tensor) { return *this; } +GraphCell::GraphCell(const Graph &graph) + : graph_(std::make_shared(graph)), + executor_(Factory::Instance().Create(Context::Instance().GetDeviceTarget())) { + MS_EXCEPTION_IF_NULL(graph_); + MS_EXCEPTION_IF_NULL(executor_); + executor_->SetGraph(graph_); +} + +GraphCell::GraphCell(const std::shared_ptr &graph) + : graph_(graph), + executor_(Factory::Instance().Create(Context::Instance().GetDeviceTarget())) { + MS_EXCEPTION_IF_NULL(graph_); + MS_EXCEPTION_IF_NULL(executor_); + executor_->SetGraph(graph_); +} + +GraphCell::GraphCell(Graph &&graph) + : graph_(std::make_shared(graph)), + executor_(Factory::Instance().Create(Context::Instance().GetDeviceTarget())) { + MS_EXCEPTION_IF_NULL(graph_); + MS_EXCEPTION_IF_NULL(executor_); + executor_->SetGraph(graph_); +} + +Status GraphCell::Run(const std::vector &inputs, std::vector *outputs) { + 
MS_EXCEPTION_IF_NULL(executor_); + return executor_->Run(inputs, outputs); +} + +Status GraphCell::Load() { + MS_EXCEPTION_IF_NULL(executor_); + return executor_->Load(); +} + +Status GraphCell::GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const { + MS_EXCEPTION_IF_NULL(executor_); + return executor_->GetInputsInfo(names, shapes, data_types, mem_sizes); +} + +Status GraphCell::GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const { + MS_EXCEPTION_IF_NULL(executor_); + return executor_->GetOutputsInfo(names, shapes, data_types, mem_sizes); +} + InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {} InputAndOutput::InputAndOutput(const Tensor &tensor) diff --git a/mindspore/ccsrc/cxx_api/context.cc b/mindspore/ccsrc/cxx_api/context.cc new file mode 100644 index 00000000000..464fc5c20f4 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/context.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "include/api/context.h" +#include "utils/log_adapter.h" + +namespace mindspore::api { +class Context::ContextImpl { + public: + ContextImpl() : device_target_("NotSet"), device_id_(0) {} + const std::string &GetDeviceTarget() const { return device_target_; } + void SetDeviceTarget(std::string_view device_target) { device_target_ = device_target; } + uint32_t GetDeviceID() const { return device_id_; } + void SetDeviceID(uint32_t device_id) { device_id_ = device_id; } + + private: + std::string device_target_; + uint32_t device_id_; +}; + +Context &Context::Instance() { + static Context context; + return context; +} + +const std::string &Context::GetDeviceTarget() const { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->GetDeviceTarget(); +} + +Context &Context::SetDeviceTarget(const std::string &device_target) { + MS_EXCEPTION_IF_NULL(impl_); + impl_->SetDeviceTarget(device_target); + return *this; +} + +uint32_t Context::GetDeviceID() const { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->GetDeviceID(); +} + +Context &Context::SetDeviceID(uint32_t device_id) { + MS_EXCEPTION_IF_NULL(impl_); + impl_->SetDeviceID(device_id); + return *this; +} + +Context::Context() : impl_(std::make_shared()) { MS_EXCEPTION_IF_NULL(impl_); } + +Context::~Context() {} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/factory.h b/mindspore/ccsrc/cxx_api/factory.h new file mode 100644 index 00000000000..7a7b45e12a9 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/factory.h @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_CXX_API_FACTORY_H +#define MINDSPORE_CCSRC_CXX_API_FACTORY_H +#include +#include +#include +#include +#include +#include +#include "utils/utils.h" + +namespace mindspore::api { +template +class Factory { + using U = std::function()>; + + public: + Factory(const Factory &) = delete; + void operator=(const Factory &) = delete; + + static Factory &Instance() { + static Factory instance; + return instance; + } + + void Register(const std::string &device_name, U &&creator) { + if (creators_.find(device_name) == creators_.end()) { + (void)creators_.emplace(device_name, creator); + } + } + + bool CheckModelSupport(const std::string &device_name) { + return std::any_of(creators_.begin(), creators_.end(), + [&device_name](const std::pair &item) { return item.first == device_name; }); + } + + std::shared_ptr Create(const std::string &device_name) { + auto iter = creators_.find(device_name); + if (creators_.end() != iter) { + MS_EXCEPTION_IF_NULL(iter->second); + return (iter->second)(); + } + + MS_LOG(ERROR) << "Unsupported device target " << device_name; + return nullptr; + } + + private: + Factory() = default; + ~Factory() = default; + std::map creators_; +}; + +template +class Registrar { + using U = std::function()>; + + public: + Registrar(const std::string &device_name, U creator) { + Factory::Instance().Register(device_name, std::move(creator)); + } + ~Registrar() = default; +}; + +#define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \ + static const Registrar 
g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \ + #DEVICE_NAME, []() { return std::make_shared(); }); +} // namespace mindspore::api +#endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H diff --git a/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc new file mode 100644 index 00000000000..9a6af5e7a74 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc @@ -0,0 +1,266 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "cxx_api/graph/acl/acl_graph_impl.h" +#include "include/api/context.h" +#include "cxx_api/model/acl/model_converter.h" +#include "cxx_api/python_utils.h" +#include "utils/log_adapter.h" + +namespace mindspore::api { +API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl); +std::weak_ptr AclGraphImpl::global_acl_env_; +std::mutex AclGraphImpl::global_acl_env_mutex_; + +AclGraphImpl::AclGraphImpl() + : init_flag_(false), + load_flag_(false), + device_type_("AscendCL"), + device_id_(Context::Instance().GetDeviceID()), + context_(nullptr), + acl_env_(nullptr) {} + +AclGraphImpl::~AclGraphImpl() { (void)FinalizeEnv(); } + +Status AclGraphImpl::Run(const std::vector &inputs, std::vector *outputs) { + MS_EXCEPTION_IF_NULL(outputs); + Status ret = Load(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Prepare model resource failed."; + return FAILED; + } + + return model_process_.PredictFromHost(inputs, outputs); +} + +Status AclGraphImpl::GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) { + Status ret = Load(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Prepare model resource failed."; + return FAILED; + } + + return model_process_.GetInputsInfo(names, shapes, data_types, mem_sizes); +} + +Status AclGraphImpl::GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) { + Status ret = Load(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Prepare model resource failed."; + return FAILED; + } + + return model_process_.GetOutputsInfo(names, shapes, data_types, mem_sizes); +} + +Status AclGraphImpl::LoadAclModel(Buffer om_data) { + MS_LOG(INFO) << "Start load acl model."; + // acl load model + uint32_t acl_model_id; + auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed."; + return FAILED; + } + + // acl init model resource + 
model_process_.set_model_id(acl_model_id); + Status ret = model_process_.PreInitModelResource(); + if (ret != SUCCESS) { + (void)aclmdlUnload(acl_model_id); + MS_LOG(ERROR) << "Pre init model resource failed."; + return FAILED; + } + + MS_LOG(INFO) << "Load acl model success."; + return SUCCESS; +} + +Status AclGraphImpl::InitEnv() { + if (init_flag_) { + return SUCCESS; + } + + aclError ret; + { + std::lock_guard lock(global_acl_env_mutex_); + acl_env_ = global_acl_env_.lock(); + if (acl_env_ != nullptr) { + MS_LOG(INFO) << "Acl has been initialized, skip."; + } else { + acl_env_ = std::make_shared(""); + if (acl_env_->GetErrno() != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Execute aclInit Failed"; + return FAILED; + } + global_acl_env_ = acl_env_; + MS_LOG(INFO) << "Acl init success"; + } + } + + ret = aclrtSetDevice(device_id_); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed"; + return FAILED; + } + MS_LOG(INFO) << "Open device " << device_id_ << " success"; + + ret = aclrtCreateContext(&context_, device_id_); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Acl create context failed"; + return FAILED; + } + MS_LOG(INFO) << "Create context success"; + + aclrtRunMode run_mode; + ret = aclrtGetRunMode(&run_mode); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Acl get run mode failed"; + return FAILED; + } + bool is_device = (run_mode == ACL_DEVICE); + model_process_.SetIsDevice(is_device); + MS_LOG(INFO) << "Get run mode success is device input/output " << is_device; + + MS_LOG(INFO) << "Init acl success, device id " << device_id_; + init_flag_ = true; + return SUCCESS; +} + +Status AclGraphImpl::FinalizeEnv() { + if (!init_flag_) { + return SUCCESS; + } + + aclError rt_ret = aclrtSetCurrentContext(context_); + if (rt_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Set the ascend device context failed"; + return FAILED; + } + + Status ret = model_process_.UnLoad(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Unload 
model inner failed."; + return FAILED; + } + + if (context_ != nullptr) { + rt_ret = aclrtDestroyContext(context_); + if (rt_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Destroy context failed"; + } + context_ = nullptr; + } + MS_LOG(INFO) << "End to destroy context"; + + rt_ret = aclrtResetDevice(device_id_); + if (rt_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Reset device " << device_id_ << " failed"; + } + MS_LOG(INFO) << "End to reset device " << device_id_; + + init_flag_ = false; + return SUCCESS; +} + +Status AclGraphImpl::Load() { + // check graph type + if (graph_->ModelType() != ModelType::kOM) { + Status ret = ConvertToOM(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Load Failed."; + return FAILED; + } + } + + const auto &graph_data = GraphImpl::MutableGraphData(); + MS_EXCEPTION_IF_NULL(graph_data); + auto om_data = graph_data->GetOMData(); + + // init + Status ret = InitEnv(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "InitEnv failed."; + return FAILED; + } + + // load model + if (!load_flag_) { + ret = LoadAclModel(om_data); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Load acl model failed."; + return ret; + } + load_flag_ = true; + } + + aclError rt_ret = aclrtSetCurrentContext(context_); + if (rt_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Set the ascend device context failed"; + return FAILED; + } + + return SUCCESS; +} + +Status AclGraphImpl::ConvertToOM() { + MS_LOG(INFO) << "Start convert to om model."; + RegAllOpFromPython(); + if (graph_ == nullptr) { + MS_LOG(ERROR) << "Invalid graph_ is null."; + return FAILED; + } + + auto &graph_data = GraphImpl::MutableGraphData(); + MS_EXCEPTION_IF_NULL(graph_data); + if (graph_->ModelType() == ModelType::kOM) { + MS_LOG(INFO) << "This model has been built, skip."; + return SUCCESS; + } else if (graph_->ModelType() == ModelType::kMindIR) { + auto func_graph = graph_data->GetFuncGraph(); + MS_EXCEPTION_IF_NULL(func_graph); + ModelConverter model_converter; + Buffer om_data = 
model_converter.LoadMindIR(func_graph); + if (om_data.Data() == nullptr || om_data.DataSize() == 0) { + MS_LOG(ERROR) << "Convert MindIR to OM failed."; + return FAILED; + } + graph_data = std::make_shared(om_data, ModelType::kOM); + MS_LOG(INFO) << "Convert MindIR to OM success."; + return SUCCESS; + } + MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType(); + return FAILED; +} + +AclGraphImpl::AclEnvGuard::AclEnvGuard(std::string_view cfg_file) { + errno_ = aclInit(cfg_file.data()); + if (errno_ != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Execute aclInit Failed"; + return; + } + MS_LOG(INFO) << "Acl init success"; +} + +AclGraphImpl::AclEnvGuard::~AclEnvGuard() { + errno_ = aclFinalize(); + if (errno_ != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Finalize acl failed"; + } + MS_LOG(INFO) << "Acl finalize success"; +} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.h b/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.h new file mode 100644 index 00000000000..97b9f3edf43 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.h @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H +#define MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H +#include +#include +#include +#include +#include +#include +#include "include/api/graph.h" +#include "cxx_api/graph/acl/model_process.h" +#include "cxx_api/graph/graph_impl.h" +#include "cxx_api/factory.h" + +namespace mindspore::api { +class AclGraphImpl : public GraphCell::GraphImpl { + public: + AclGraphImpl(); + ~AclGraphImpl() override; + + Status Run(const std::vector &inputs, std::vector *outputs) override; + Status Load() override; + Status GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) override; + Status GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) override; + + private: + class AclEnvGuard; + + Status ConvertToOM(); + Status InitEnv(); + Status FinalizeEnv(); + Status LoadAclModel(Buffer om_data); + + bool init_flag_; + bool load_flag_; + std::string device_type_; + int32_t device_id_; + aclrtContext context_; + + std::shared_ptr acl_env_; + static std::weak_ptr global_acl_env_; + static std::mutex global_acl_env_mutex_; + + ModelProcess model_process_; +}; + +class AclGraphImpl::AclEnvGuard { + public: + explicit AclEnvGuard(std::string_view cfg_file); + ~AclEnvGuard(); + aclError GetErrno() const { return errno_; } + + private: + aclError errno_; +}; +} // namespace mindspore::api +#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H diff --git a/mindspore/ccsrc/cxx_api/model/acl/model_process.cc b/mindspore/ccsrc/cxx_api/graph/acl/model_process.cc similarity index 77% rename from mindspore/ccsrc/cxx_api/model/acl/model_process.cc rename to mindspore/ccsrc/cxx_api/graph/acl/model_process.cc index 9df51e1afd4..9e1d4d6884e 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/model_process.cc +++ b/mindspore/ccsrc/cxx_api/graph/acl/model_process.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "cxx_api/model/acl/model_process.h" +#include "cxx_api/graph/acl/model_process.h" #include #include #include "utils/utils.h" @@ -35,17 +35,33 @@ static DataType TransToApiType(aclDataType data_type) { } } -static void ConstructTensorDesc(const std::vector &acl_tensor_list, std::vector *tensor_list) { - MS_EXCEPTION_IF_NULL(tensor_list); - tensor_list->clear(); +template +inline static void ClearIfNotNull(T *vec) { + if (vec != nullptr) { + vec->clear(); + } +} +template > +inline static void PushbackIfNotNull(U *vec, T &&item) { + if (vec != nullptr) { + vec->emplace_back(item); + } +} + +static void ConstructTensorDesc(const std::vector &acl_tensor_list, std::vector *names, + std::vector> *shapes, std::vector *data_types, + std::vector *mem_sizes) { + ClearIfNotNull(names); + ClearIfNotNull(shapes); + ClearIfNotNull(data_types); + ClearIfNotNull(mem_sizes); for (size_t i = 0; i < acl_tensor_list.size(); ++i) { const auto &info = acl_tensor_list[i]; - Tensor tensor_desc; - tensor_desc.SetName(info.name); - tensor_desc.SetDataType(TransToApiType(info.data_type)); - tensor_desc.SetShape(info.dims); - tensor_list->push_back(tensor_desc); + PushbackIfNotNull(names, info.name); + PushbackIfNotNull(shapes, info.dims); + PushbackIfNotNull(data_types, TransToApiType(info.data_type)); + PushbackIfNotNull(mem_sizes, info.buffer_size); } } @@ -272,7 +288,7 @@ Status ModelProcess::UnLoad() { return SUCCESS; } -Status ModelProcess::CheckAndInitInput(const std::map &inputs) { +Status ModelProcess::CheckAndInitInput(const std::vector &inputs) { aclError ret; inputs_ = aclmdlCreateDataset(); // check inputs @@ -282,29 +298,16 @@ Status ModelProcess::CheckAndInitInput(const std::map &inpu return INVALID_INPUTS; } for (size_t i = 0; i < input_infos_.size(); ++i) { - const std::string &input_name = input_infos_[i].name; - auto iter = inputs.find(input_name); - if (iter == inputs.end()) { - MS_LOG(ERROR) << "Model missing input " << input_name; - return INVALID_INPUTS; - 
} - - if (iter->second.DataSize() != input_infos_[i].buffer_size) { + if (inputs[i].DataSize() != input_infos_[i].buffer_size) { MS_LOG(ERROR) << "input " << i << " data size not match, required size " << input_infos_[i].buffer_size - << ", given count " << iter->second.DataSize(); + << ", given count " << inputs[i].DataSize(); return INVALID_INPUTS; } } // copy inputs for (size_t i = 0; i < input_infos_.size(); ++i) { const auto &info = input_infos_[i]; - auto iter = inputs.find(info.name); - if (iter == inputs.end()) { - MS_LOG(ERROR) << "Model missing input " << info.name; - return INVALID_INPUTS; - } - - const auto &input = iter->second; + const auto &input = inputs[i]; const void *data = input.Data(); void *input_buffer = nullptr; @@ -333,42 +336,7 @@ Status ModelProcess::CheckAndInitInput(const std::map &inpu return SUCCESS; } -Status ModelProcess::CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size, - size_t input_index) { - aclError ret; - inputs_ = aclmdlCreateDataset(); - // check inputs - if (input_index >= input_infos_.size()) { - MS_LOG(ERROR) << "inputs count not match, required count " << input_infos_.size() << ", given index " - << input_index; - return INVALID_INPUTS; - } - if (dvpp_outputs_buffer_dev == nullptr) { - MS_LOG(ERROR) << "input " << 0 << " cannot be null"; - return FAILED; - } - if (dvpp_outputs_buffer_size != input_infos_[input_index].buffer_size) { - MS_LOG(ERROR) << "input " << 0 << " data size not match, required size " << input_infos_[input_index].buffer_size - << ", given count " << dvpp_outputs_buffer_size; - return INVALID_INPUTS; - } - // copy inputs - auto &info = input_infos_[input_index]; - auto data_buffer = aclCreateDataBuffer(const_cast(dvpp_outputs_buffer_dev), info.buffer_size); - if (data_buffer == nullptr) { - MS_LOG(ERROR) << "Create Data Buffer failed"; - return FAILED; - } - ret = aclmdlAddDatasetBuffer(inputs_, data_buffer); - if (ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << 
"add data buffer failed"; - aclDestroyDataBuffer(data_buffer); - return FAILED; - } - return SUCCESS; -} - -Status ModelProcess::Predict(const std::map &inputs, std::map *outputs) { +Status ModelProcess::PredictFromHost(const std::vector &inputs, std::vector *outputs) { MS_EXCEPTION_IF_NULL(outputs); aclError acl_ret; Status ret = CheckAndInitInput(inputs); @@ -392,18 +360,7 @@ Status ModelProcess::Predict(const std::map &inputs, std::m return SUCCESS; } -size_t ModelProcess::GetBatchSize() const { - if (input_infos_.empty()) { - MS_LOG(ERROR) << "Model is not loaded"; - return 0; - } - if (input_infos_[0].dims.empty()) { - return 1; - } - return static_cast(input_infos_[0].dims[0]); -} - -Status ModelProcess::BuildOutputs(std::map *outputs) { +Status ModelProcess::BuildOutputs(std::vector *outputs) { MS_EXCEPTION_IF_NULL(outputs); aclError ret; // copy outputs @@ -411,14 +368,13 @@ Status ModelProcess::BuildOutputs(std::map *outputs) { aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST; for (size_t i = 0; i < output_infos_.size(); ++i) { const auto &info = output_infos_[i]; - // todo - outputs->emplace(info.name, Buffer()); - auto output = outputs->rbegin()->second; - if (!output.ResizeData(info.buffer_size)) { + outputs->emplace_back(Buffer()); + auto output = outputs->rbegin(); + if (!output->ResizeData(info.buffer_size)) { MS_LOG(ERROR) << "new output data buffer failed, data size " << info.buffer_size; return FAILED; } - ret = aclrtMemcpy(output.MutableData(), output.DataSize(), info.device_data, info.buffer_size, kind); + ret = aclrtMemcpy(output->MutableData(), output->DataSize(), info.device_data, info.buffer_size, kind); if (ret != ACL_ERROR_NONE) { MS_LOG(ERROR) << "Memcpy output " << i << " from " << (is_run_on_device_ ? 
"host" : "device") << " to host failed, memory size " << info.buffer_size; @@ -428,13 +384,15 @@ Status ModelProcess::BuildOutputs(std::map *outputs) { return SUCCESS; } -Status ModelProcess::GetInputsInfo(std::vector *tensor_list) const { - ConstructTensorDesc(input_infos_, tensor_list); +Status ModelProcess::GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const { + ConstructTensorDesc(input_infos_, names, shapes, data_types, mem_sizes); return SUCCESS; } -Status ModelProcess::GetOutputsInfo(std::vector *tensor_list) const { - ConstructTensorDesc(output_infos_, tensor_list); +Status ModelProcess::GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const { + ConstructTensorDesc(output_infos_, names, shapes, data_types, mem_sizes); return SUCCESS; } } // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/model/acl/model_process.h b/mindspore/ccsrc/cxx_api/graph/acl/model_process.h similarity index 72% rename from mindspore/ccsrc/cxx_api/model/acl/model_process.h rename to mindspore/ccsrc/cxx_api/graph/acl/model_process.h index 24f3f386236..e9c3363bd91 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/model_process.h +++ b/mindspore/ccsrc/cxx_api/graph/acl/model_process.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H -#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H +#ifndef MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H +#define MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H #include #include #include @@ -34,12 +34,6 @@ struct AclTensorInfo { std::string name; }; -struct ImagesDvppOutput { - void *buffer_device = nullptr; - size_t buffer_size = 0; - size_t input_index = 0; -}; - class ModelProcess { public: ModelProcess() @@ -53,24 +47,23 @@ class ModelProcess { ~ModelProcess() {} Status LoadModelFromFile(const std::string &file_name, uint32_t *model_id); Status UnLoad(); - Status Predict(const std::map &inputs, std::map *outputs); + Status PredictFromHost(const std::vector &inputs, std::vector *outputs); Status PreInitModelResource(); - Status GetInputsInfo(std::vector *tensor_list) const; - Status GetOutputsInfo(std::vector *tensor_list) const; + Status GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const; + Status GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const; // override this method to avoid request/reply data copy void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; } - size_t GetBatchSize() const; void set_model_id(uint32_t model_id) { model_id_ = model_id; } uint32_t model_id() const { return model_id_; } private: Status CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset); - Status CheckAndInitInput(const std::map &inputs); - Status CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size, - size_t input_index); - Status BuildOutputs(std::map *outputs); + Status CheckAndInitInput(const std::vector &inputs); + Status BuildOutputs(std::vector *outputs); Status InitInputsBuffer(); Status InitOutputsBuffer(); @@ -90,4 +83,4 @@ class ModelProcess { }; } // namespace mindspore::api -#endif // 
MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H +#endif // MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H diff --git a/mindspore/ccsrc/cxx_api/graph/graph.cc b/mindspore/ccsrc/cxx_api/graph/graph.cc new file mode 100644 index 00000000000..dfd372d4fcd --- /dev/null +++ b/mindspore/ccsrc/cxx_api/graph/graph.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "include/api/graph.h" +#include "cxx_api/graph/graph_data.h" +#include "utils/log_adapter.h" + +namespace mindspore::api { +Graph::Graph(const std::shared_ptr &graph_data) : graph_data_(graph_data) {} + +Graph::Graph(std::shared_ptr &&graph_data) : graph_data_(graph_data) {} + +ModelType Graph::ModelType() const { + MS_EXCEPTION_IF_NULL(graph_data_); + return graph_data_->ModelType(); +} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/graph/graph_data.cc b/mindspore/ccsrc/cxx_api/graph/graph_data.cc new file mode 100644 index 00000000000..e6ef8f52650 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/graph/graph_data.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "cxx_api/graph/graph_data.h" +#include "utils/log_adapter.h" +#ifdef ENABLE_ACL +#include "framework/common/helper/model_helper.h" +#endif + +namespace mindspore::api { +Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type) + : func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) { + if (model_type != ModelType::kMindIR) { + MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type; + } + func_graph_ = func_graph; + model_type_ = model_type; +} + +Graph::GraphData::GraphData(Buffer om_data, enum ModelType model_type) + : func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) { + if (model_type != ModelType::kOM) { + MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type; + } + +#ifdef ENABLE_ACL + // check om + ge::ModelHelper helper; + ge::ModelData model_data; + model_data.model_data = om_data.MutableData(); + model_data.model_len = om_data.DataSize(); + ge::Status ret = helper.LoadModel(model_data); + if (ret != ge::SUCCESS) { + MS_LOG(EXCEPTION) << "Invalid input data cannot parse to om."; + } + + om_data_ = om_data; + model_type_ = model_type; +#else + MS_LOG(EXCEPTION) << "Unsupported ModelType OM."; +#endif +} + +FuncGraphPtr Graph::GraphData::GetFuncGraph() const { + if (model_type_ != ModelType::kMindIR) { + MS_LOG(ERROR) << "Invalid ModelType " << model_type_; + return nullptr; + } + + return func_graph_; +} + +Buffer Graph::GraphData::GetOMData() const { + if (model_type_ != ModelType::kOM) { + MS_LOG(ERROR) << "Invalid ModelType " << model_type_; + return 
Buffer(); + } + + return om_data_; +} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/graph/graph_data.h b/mindspore/ccsrc/cxx_api/graph/graph_data.h new file mode 100644 index 00000000000..f52421d6547 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/graph/graph_data.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H +#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H + +#include +#include +#include +#include +#include "include/api/graph.h" +#include "include/api/types.h" +#include "ir/func_graph.h" + +namespace mindspore::api { +class Graph::GraphData { + public: + GraphData(); + + explicit GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type = kMindIR); + + GraphData(Buffer om_data, enum ModelType model_type); + + enum ModelType ModelType() const { return model_type_; } + + FuncGraphPtr GetFuncGraph() const; + + Buffer GetOMData() const; + + private: + FuncGraphPtr func_graph_; + Buffer om_data_; + enum ModelType model_type_; +}; +} // namespace mindspore::api +#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H diff --git a/mindspore/ccsrc/cxx_api/graph/graph_impl.h b/mindspore/ccsrc/cxx_api/graph/graph_impl.h new file mode 100644 index 00000000000..a2c651c4cf4 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/graph/graph_impl.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H +#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H +#include +#include +#include +#include +#include +#include +#include "include/api/cell.h" +#include "include/api/graph.h" +#include "cxx_api/graph/graph_data.h" +#include "utils/utils.h" + +namespace mindspore::api { +class GraphCell::GraphImpl { + public: + GraphImpl() = default; + virtual ~GraphImpl() = default; + + std::shared_ptr &MutableGraphData() const { return graph_->graph_data_; } + void SetGraph(const std::shared_ptr &graph) { graph_ = graph; } + + virtual Status Run(const std::vector &inputs, std::vector *outputs) = 0; + virtual Status Load() = 0; + + virtual Status GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) = 0; + virtual Status GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) = 0; + + protected: + std::shared_ptr graph_; +}; +} // namespace mindspore::api + +#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H diff --git a/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.cc new file mode 100644 index 00000000000..442842baab8 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.cc @@ -0,0 +1,334 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "cxx_api/graph/ms/ms_graph_impl.h" +#include +#include "include/api/context.h" +#include "cxx_api/factory.h" +#include "cxx_api/python_utils.h" +#include "utils/log_adapter.h" +#include "utils/context/context_extends.h" +#include "mindspore/core/base/base_ref_utils.h" +#include "backend/session/session_factory.h" +#include "backend/session/executor_manager.h" +#include "runtime/device/kernel_runtime_manager.h" + +namespace mindspore::api { +API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, MsGraphImpl); + +static DataType TransTypeId2InferDataType(TypeId type_id) { + const std::map id2type_map{ + {TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool}, + {TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8}, + {TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16}, + {TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32}, + {TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64}, + {TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16}, + {TypeId::kNumberTypeFloat32, api::kMsFloat32}, + }; + + // cppcheck-suppress stlIfFind + if (auto it = id2type_map.find(type_id); it != id2type_map.end()) { + return it->second; + } + + MS_LOG(WARNING) << "Unsupported data id " << type_id; + return api::kMsUnknown; +} + +template +inline static void ClearIfNotNull(T *vec) { + if (vec 
!= nullptr) { + vec->clear(); + } +} + +template > +inline static void PushbackIfNotNull(U *vec, T &&item) { + if (vec != nullptr) { + vec->emplace_back(item); + } +} + +MsGraphImpl::MsGraphImpl() + : session_impl_(nullptr), + graph_id_(0), + device_type_("Ascend"), + device_id_(Context::Instance().GetDeviceID()), + context_(nullptr), + inputs_(), + outputs_(), + input_names_(), + output_names_(), + load_flag_(false) {} + +MsGraphImpl::~MsGraphImpl() { (void)FinalizeEnv(); } + +Status MsGraphImpl::InitEnv() { + RegAllOpFromPython(); + auto ms_context = MsContext::GetInstance(); + if (ms_context == nullptr) { + MS_LOG(ERROR) << "Get Context failed!"; + return FAILED; + } + + ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); + ms_context->set_param(MS_CTX_DEVICE_ID, device_id_); + ms_context->set_param(MS_CTX_DEVICE_TARGET, kAscendDevice); + if (!context::OpenTsd(ms_context)) { + MS_LOG(ERROR) << "Session init OpenTsd failed!"; + return FAILED; + } + + session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice); + if (session_impl_ == nullptr) { + MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice + << " is available."; + return FAILED; + } + session_impl_->Init(device_id_); + + return SUCCESS; +} + +Status MsGraphImpl::FinalizeEnv() { + MS_LOG_INFO << "Start finalize env"; + pybind11::gil_scoped_acquire acquire; + session::ExecutorManager::Instance().Clear(); + device::KernelRuntimeManager::Instance().ClearRuntimeResource(); + auto ms_context = MsContext::GetInstance(); + if (ms_context == nullptr) { + MS_LOG(ERROR) << "Get Context failed!"; + return FAILED; + } + if (!context::CloseTsd(ms_context)) { + MS_LOG(ERROR) << "CloseTsd failed!"; + return FAILED; + } + MS_LOG(INFO) << "End finalize env"; + return SUCCESS; +} + +Status MsGraphImpl::CompileGraph(const std::shared_ptr &funcGraphPtr) { + MS_ASSERT(session_impl_ != nullptr); + try { + graph_id_ = 
session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); + pybind11::gil_scoped_release gil_release; + return SUCCESS; + } catch (std::exception &e) { + MS_LOG(ERROR) << "CompileGraph failed: " << e.what(); + return FAILED; + } +} + +std::vector MsGraphImpl::RunGraph(const std::vector &inputs) { + try { + VectorRef outputs; + session_impl_->RunGraph(graph_id_, inputs, &outputs); + return TransformVectorRefToMultiTensor(outputs); + } catch (std::exception &e) { + MS_LOG(ERROR) << "RunGraph failed: " << e.what(); + return std::vector(); + } +} + +Status MsGraphImpl::CheckModelInputs(const std::vector &inputs) const { + MS_ASSERT(session_impl_ != nullptr); + std::string error_msg; + if (!session_impl_->CheckModelInputs(graph_id_, inputs, &error_msg)) { + return Status(INVALID_INPUTS, error_msg); + } + return SUCCESS; +} + +Status MsGraphImpl::ExecuteModel(const std::vector &request, std::vector *reply) { + MS_EXCEPTION_IF_NULL(reply); + if (context_ == nullptr) { + MS_LOG(ERROR) << "rtCtx is nullptr"; + return FAILED; + } + rtError_t rt_ret = rtCtxSetCurrent(context_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Set Ascend rtCtx failed"; + return FAILED; + } + + vector inputs; + for (size_t i = 0; i < request.size(); i++) { + auto &item = request[i]; + auto input = inputs_[i]; + if (input->Size() != item.DataSize()) { + MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size " + << input->Size(); + return FAILED; + } + auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize()); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Tensor copy failed"; + return FAILED; + } + inputs.push_back(input); + } + vector outputs = RunGraph(inputs); + if (outputs.empty()) { + MS_LOG(ERROR) << "Execute Model Failed"; + return FAILED; + } + reply->clear(); + std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply), + [](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), 
tensor->Size()); }); + return SUCCESS; +} + +Status MsGraphImpl::GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) { + if (!load_flag_) { + Status ret = Load(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "PrepareModel failed."; + return ret; + } + } + + ClearIfNotNull(names); + ClearIfNotNull(shapes); + ClearIfNotNull(data_types); + ClearIfNotNull(mem_sizes); + for (size_t i = 0; i < inputs_.size(); i++) { + auto &tensor = inputs_[i]; + PushbackIfNotNull(names, input_names_[i]); + PushbackIfNotNull(shapes, tensor->shape()); + PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type())); + PushbackIfNotNull(mem_sizes, tensor->DataSize()); + } + return SUCCESS; +} + +Status MsGraphImpl::GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) { + if (!load_flag_) { + Status ret = Load(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "PrepareModel failed."; + return ret; + } + } + + ClearIfNotNull(names); + ClearIfNotNull(shapes); + ClearIfNotNull(data_types); + ClearIfNotNull(mem_sizes); + for (size_t i = 0; i < outputs_.size(); i++) { + auto &tensor = outputs_[i]; + PushbackIfNotNull(names, output_names_[i]); + PushbackIfNotNull(shapes, tensor->shape()); + PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type())); + PushbackIfNotNull(mem_sizes, tensor->DataSize()); + } + + return SUCCESS; +} + +Status MsGraphImpl::Load() { + // check graph type + if (graph_->ModelType() != ModelType::kMindIR) { + MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType(); + return INVALID_INPUTS; + } + + const auto &graph_data = GraphImpl::MutableGraphData(); + MS_EXCEPTION_IF_NULL(graph_data); + auto func_graph = graph_data->GetFuncGraph(); + + // init + Status ret = InitEnv(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "InitEnv failed."; + return FAILED; + } + + // load model + if (!load_flag_) { + ret = CompileGraph(func_graph); + if 
(ret != SUCCESS) { + MS_LOG(ERROR) << "Compile graph model failed"; + return FAILED; + } + session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_); + session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_); + if (inputs_.empty() || inputs_.size() != input_names_.size()) { + MS_LOG_ERROR << "Get model inputs info failed"; + return FAILED; + } + if (outputs_.empty() || outputs_.size() != output_names_.size()) { + MS_LOG_ERROR << "Get model outputs info failed"; + return FAILED; + } + + // save d context + rtError_t rt_ret = rtCtxGetCurrent(&context_); + if (rt_ret != RT_ERROR_NONE || context_ == nullptr) { + MS_LOG(ERROR) << "the ascend device context is null"; + return FAILED; + } + + MS_LOG(INFO) << "Load model success"; + load_flag_ = true; + } + + rtError_t rt_ret = rtCtxSetCurrent(context_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Set the ascend device context failed"; + return FAILED; + } + + return SUCCESS; +} + +Status MsGraphImpl::Run(const std::vector &inputs, std::vector *outputs) { + MS_EXCEPTION_IF_NULL(outputs); + if (!load_flag_) { + Status ret = Load(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "PrepareModel failed."; + return ret; + } + } + + if (inputs.size() != inputs_.size()) { + MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size(); + return INVALID_INPUTS; + } + + for (size_t i = 0; i < inputs_.size(); ++i) { + if (inputs[i].DataSize() != inputs_[i]->Size()) { + MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count " + << inputs[i].DataSize(); + return INVALID_INPUTS; + } + } + if (ExecuteModel(inputs, outputs) != SUCCESS) { + MS_LOG(ERROR) << "Execute Model Failed"; + return FAILED; + } + if (outputs_.size() != outputs->size()) { + MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info " + << outputs_.size(); + return FAILED; + } + + 
return SUCCESS; +} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.h b/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.h new file mode 100644 index 00000000000..b90842811b7 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.h @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H +#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H +#include +#include +#include +#include +#include +#include +#include "include/api/status.h" +#include "include/api/graph.h" +#include "cxx_api/graph/graph_impl.h" +#include "backend/session/session_basic.h" +#include "ir/anf.h" +#include "cxx_api/model/model_impl.h" +#include "runtime/context.h" + +namespace mindspore::api { +class MsGraphImpl : public GraphCell::GraphImpl { + public: + MsGraphImpl(); + ~MsGraphImpl() override; + + Status Run(const std::vector &inputs, std::vector *outputs) override; + Status Load() override; + Status GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) override; + Status GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) override; + + private: + Status InitEnv(); + Status FinalizeEnv(); + Status CompileGraph(const std::shared_ptr &funcGraphPtr); + Status CheckModelInputs(const std::vector &inputs) 
const; + std::vector RunGraph(const std::vector &inputs); + Status ExecuteModel(const std::vector &inputs, std::vector *outputs); + + std::shared_ptr session_impl_; + uint32_t graph_id_; + std::string device_type_; + uint32_t device_id_; + rtContext_t context_; + std::vector inputs_; + std::vector outputs_; + std::vector input_names_; + std::vector output_names_; + bool load_flag_; +}; +} // namespace mindspore::api +#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc index 754a5808d55..e095693c220 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc @@ -16,216 +16,57 @@ #include "cxx_api/model/acl/acl_model.h" #include -#include "utils/context/context_extends.h" +#include "cxx_api/factory.h" +#include "cxx_api/python_utils.h" namespace mindspore::api { -std::weak_ptr AclModel::global_acl_env_; -std::mutex AclModel::global_acl_env_mutex_; +API_FACTORY_REG(ModelImpl, Ascend310, AclModel); -Status AclModel::InitEnv() { - if (init_flag_) { +Status AclModel::Build(const std::map &options_map) { + MS_LOG(INFO) << "Start build model."; + MS_EXCEPTION_IF_NULL(graph_); + RegAllOpFromPython(); + std::unique_ptr options = std::make_unique(options_map); + std::string options_str = GenerateOptionsStr(options_map); + MS_EXCEPTION_IF_NULL(options); + if (graph_cell_ != nullptr && options_str == options_str_) { + MS_LOG(INFO) << "This model has been built, skip."; return SUCCESS; } - MS_EXCEPTION_IF_NULL(options_); - aclError ret; - { - std::lock_guard lock(global_acl_env_mutex_); - acl_env_ = global_acl_env_.lock(); - if (acl_env_ != nullptr) { - if (options_->dump_cfg_path.empty()) { - MS_LOG(INFO) << "Acl has been initialized, skip."; - } else { - MS_LOG(WARNING) << "Acl has been initialized, skip, so dump config will be ignored."; - } - } else { - acl_env_ = std::make_shared(options_->dump_cfg_path); - 
if (acl_env_->GetErrno() != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Execute aclInit Failed"; - return FAILED; - } - global_acl_env_ = acl_env_; - MS_LOG(INFO) << "Acl init success"; + if (graph_cell_ == nullptr && graph_->ModelType() == ModelType::kOM) { + graph_cell_ = std::make_shared(graph_); + MS_EXCEPTION_IF_NULL(graph_cell_); + if (!options_map.empty()) { + MS_LOG(WARNING) << "All build options will be ignored."; } - } - - ret = aclrtSetDevice(device_id_); - if (ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed"; - return FAILED; - } - MS_LOG(INFO) << "Open device " << device_id_ << " success"; - - ret = aclrtCreateContext(&context_, device_id_); - if (ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Acl create context failed"; - return FAILED; - } - MS_LOG(INFO) << "Create context success"; - - ret = aclrtSetCurrentContext(context_); - if (ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Acl set current context failed"; - return FAILED; - } - MS_LOG(INFO) << "Set context success"; - - ret = aclrtCreateStream(&stream_); - if (ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Acl create stream failed"; - return FAILED; - } - MS_LOG(INFO) << "Create stream success"; - - aclrtRunMode run_mode; - ret = aclrtGetRunMode(&run_mode); - if (ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Acl get run mode failed"; - return FAILED; - } - bool is_device = (run_mode == ACL_DEVICE); - model_process_.SetIsDevice(is_device); - MS_LOG(INFO) << "Get run mode success is device input/output " << is_device; - - if (dvpp_process_.InitResource(stream_) != SUCCESS) { - MS_LOG(ERROR) << "DVPP init resource failed"; - return FAILED; - } - - MS_LOG(INFO) << "Init acl success, device id " << device_id_; - init_flag_ = true; - return SUCCESS; -} - -Status AclModel::FinalizeEnv() { - if (!init_flag_) { return SUCCESS; } - dvpp_process_.Finalize(); - aclError ret; - if (stream_ != nullptr) { - ret = aclrtDestroyStream(stream_); - if (ret != ACL_ERROR_NONE) { - 
MS_LOG(ERROR) << "Destroy stream failed"; - } - stream_ = nullptr; - } - MS_LOG(INFO) << "End to destroy stream"; - if (context_ != nullptr) { - ret = aclrtDestroyContext(context_); - if (ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Destroy context failed"; - } - context_ = nullptr; - } - MS_LOG(INFO) << "End to destroy context"; - - ret = aclrtResetDevice(device_id_); - if (ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Reset devie " << device_id_ << " failed"; - } - MS_LOG(INFO) << "End to reset device " << device_id_; - - init_flag_ = false; - return SUCCESS; -} - -Status AclModel::LoadModel(const Buffer &model_data, ModelType type, - const std::map &options) { - if (load_flag_) { - MS_LOG(ERROR) << "Model has been loaded."; + auto func_graph = ModelImpl::GetFuncGraph(); + MS_EXCEPTION_IF_NULL(func_graph); + model_converter_.set_options(options.get()); + auto om_data = model_converter_.LoadMindIR(func_graph); + if (om_data.Data() == nullptr || om_data.DataSize() == 0) { + MS_LOG(ERROR) << "Load MindIR failed."; return FAILED; } - options_ = std::make_unique(options); - MS_EXCEPTION_IF_NULL(options_); - - Status ret = InitEnv(); + auto graph = std::make_shared(std::make_shared(om_data, ModelType::kOM)); + MS_EXCEPTION_IF_NULL(graph); + auto graph_cell = std::make_shared(graph); + MS_EXCEPTION_IF_NULL(graph_cell); + auto ret = ModelImpl::Load(graph_cell); if (ret != SUCCESS) { - MS_LOG(ERROR) << "InitEnv failed."; - return FAILED; + MS_LOG(ERROR) << "Load failed."; + return ret; } - Buffer om_data; - if (type == ModelType::kMindIR) { - model_converter_.set_options(options_.get()); - om_data = model_converter_.LoadMindIR(model_data); - } else if (type == ModelType::kAIR) { - model_converter_.set_options(options_.get()); - om_data = model_converter_.LoadAscendIR(model_data); - } else if (type == ModelType::kOM) { - om_data = model_data; - } else { - MS_LOG(ERROR) << "Unsupported model type " << type; - return FAILED; - } - - // acl load model - uint32_t acl_model_id; 
- auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed."; - return FAILED; - } - - // acl init model resource - model_process_.set_model_id(acl_model_id); - ret = model_process_.PreInitModelResource(); - if (ret != SUCCESS) { - (void)aclmdlUnload(acl_model_id); - MS_LOG(ERROR) << "Pre init model resource failed."; - return FAILED; - } - - // acl init dvpp - ret = dvpp_process_.InitWithJsonConfig(options_->dvpp_cfg_path); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "DVPP config file parse error."; - return FAILED; - } - - load_flag_ = true; - return SUCCESS; -} - -Status AclModel::LoadModel(const std::string &file_name, ModelType type, - const std::map &options) { - Buffer model_data = ModelConverter::ReadFile(file_name); - if (model_data.DataSize() == 0) { - MS_LOG(ERROR) << "Read file " << file_name << " failed."; - return FAILED; - } - - return LoadModel(model_data, type, options); -} - -Status AclModel::UnloadModel() { - if (!load_flag_) { - MS_LOG(WARNING) << "No model is loaded, skip unload."; - return SUCCESS; - } - - aclError rt_ret = aclrtSetCurrentContext(context_); - if (rt_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Set the ascend device context failed"; - return FAILED; - } - - Status ret = model_process_.UnLoad(); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "Unload model inner failed."; - return FAILED; - } - - ret = FinalizeEnv(); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "FinalizeEnv failed."; - return FAILED; - } - - MS_LOG(INFO) << "Unload model success."; - load_flag_ = false; + // save result + graph_cell_ = graph_cell; + options_ = std::move(options); + options_str_ = options_str; + MS_LOG(INFO) << "Build model success."; return SUCCESS; } @@ -239,45 +80,49 @@ Status AclModel::Eval(const DataSet &, std::map *) { return FAILED; } -Status AclModel::Predict(const std::map &inputs, std::map *outputs) { +Status AclModel::Predict(const 
std::vector &inputs, std::vector *outputs) { MS_EXCEPTION_IF_NULL(outputs); - if (!load_flag_) { - MS_LOG(ERROR) << "No model is loaded, predict failed."; + if (graph_ == nullptr) { + MS_LOG(ERROR) << "Invalid data, graph_ is null."; return FAILED; } - aclError rt_ret = aclrtSetCurrentContext(context_); - if (rt_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Set the ascend device context failed"; + if (graph_cell_ == nullptr) { + MS_LOG(WARNING) << "Model has not been built, it will be built with default options"; + Status ret = Build({}); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Build model failed."; + return FAILED; + } + } + + MS_EXCEPTION_IF_NULL(graph_cell_); + Status ret = graph_cell_->Run(inputs, outputs); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Run graph failed."; return FAILED; } - return model_process_.Predict(inputs, outputs); + + return SUCCESS; } -Status AclModel::GetInputsInfo(std::vector *tensor_list) const { - MS_EXCEPTION_IF_NULL(tensor_list); - return model_process_.GetInputsInfo(tensor_list); +Status AclModel::GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const { + MS_EXCEPTION_IF_NULL(graph_cell_); + return graph_cell_->GetInputsInfo(names, shapes, data_types, mem_sizes); } -Status AclModel::GetOutputsInfo(std::vector *tensor_list) const { - MS_EXCEPTION_IF_NULL(tensor_list); - return model_process_.GetOutputsInfo(tensor_list); +Status AclModel::GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const { + MS_EXCEPTION_IF_NULL(graph_cell_); + return graph_cell_->GetOutputsInfo(names, shapes, data_types, mem_sizes); } -AclModel::AclEnvGuard::AclEnvGuard(const std::string &cfg_file) { - errno_ = aclInit(common::SafeCStr(cfg_file)); - if (errno_ != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Execute aclInit Failed"; - return; +std::string AclModel::GenerateOptionsStr(const std::map &options) { + std::string ret; + for (auto &[key, 
value] : options) { + ret += key + "^" + value + "^^"; } - MS_LOG(INFO) << "Acl init success"; -} - -AclModel::AclEnvGuard::~AclEnvGuard() { - errno_ = aclFinalize(); - if (errno_ != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Finalize acl failed"; - } - MS_LOG(INFO) << "Acl finalize success"; + return ret; } } // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model.h b/mindspore/ccsrc/cxx_api/model/acl/acl_model.h index 6c7cbdf6b74..4455eba7d10 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model.h +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model.h @@ -23,77 +23,38 @@ #include #include -#include "ir/anf.h" +#include "include/api/cell.h" #include "include/api/status.h" #include "cxx_api/model/model_impl.h" -#include "cxx_api/model/acl/dvpp_process.h" -#include "cxx_api/model/acl/model_process.h" #include "cxx_api/model/acl/model_converter.h" #include "cxx_api/model/acl/acl_model_options.h" #include "ir/tensor.h" +#include "ir/anf.h" namespace mindspore::api { class AclModel : public ModelImpl { public: - explicit AclModel(uint32_t device_id) - : init_flag_(false), - load_flag_(false), - device_type_("AscendCL"), - device_id_(device_id), - context_(nullptr), - stream_(nullptr), - acl_env_(nullptr), - model_process_(), - dvpp_process_(), - model_converter_(), - options_(nullptr) {} + AclModel() : model_converter_(), options_(nullptr), options_str_() {} ~AclModel() = default; - Status LoadModel(const Buffer &model_data, ModelType type, - const std::map &options) override; - Status LoadModel(const std::string &file_name, ModelType type, - const std::map &options) override; - Status UnloadModel() override; + Status Build(const std::map &options_map) override; Status Train(const DataSet &dataset, std::map *outputs) override; Status Eval(const DataSet &dataset, std::map *outputs) override; - Status Predict(const std::map &inputs, std::map *outputs) override; + Status Predict(const std::vector &inputs, std::vector *outputs) override; - Status 
GetInputsInfo(std::vector *tensor_list) const override; - Status GetOutputsInfo(std::vector *tensor_list) const override; + Status GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const override; + Status GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const override; private: - bool init_flag_; - bool load_flag_; - std::string device_type_; - int32_t device_id_; - aclrtContext context_; - aclrtStream stream_; + static std::string GenerateOptionsStr(const std::map &options); - class AclEnvGuard; - std::shared_ptr acl_env_; - static std::weak_ptr global_acl_env_; - static std::mutex global_acl_env_mutex_; - - ModelProcess model_process_; - DvppProcess dvpp_process_; + std::shared_ptr graph_cell_; ModelConverter model_converter_; std::unique_ptr options_; - - Status InitEnv(); - Status FinalizeEnv(); + std::string options_str_; }; - -class AclModel::AclEnvGuard { - public: - explicit AclEnvGuard(const std::string &cfg_file); - ~AclEnvGuard(); - aclError GetErrno() const { return errno_; } - - private: - aclError errno_; -}; - -API_REG_MODEL(AscendCL, AclModel); } // namespace mindspore::api #endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H diff --git a/mindspore/ccsrc/cxx_api/model/acl/dvpp_process.cc b/mindspore/ccsrc/cxx_api/model/acl/dvpp_process.cc deleted file mode 100644 index 91ff286a31f..00000000000 --- a/mindspore/ccsrc/cxx_api/model/acl/dvpp_process.cc +++ /dev/null @@ -1,1160 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cxx_api/model/acl/dvpp_process.h" -#include -#include -#include -#include -#include -#include "utils/utils.h" -#include "include/api/types.h" -#include "mindspore/core/utils/ms_utils.h" - -namespace mindspore::api { -DvppProcess::DvppProcess() {} - -DvppProcess::~DvppProcess() {} - -static uint32_t ToEven(uint32_t num) { return (num + 1) / 2 * 2; } -static uint32_t ToOdd(uint32_t num) { - if (num == 0) { - return 1; - } - return (num + 1) / 2 * 2 - 1; -} - -class DvppJsonConfigParser { - public: - DvppJsonConfigParser() = default; - ~DvppJsonConfigParser() = default; - - Status InitWithJsonConfig(const std::string &json_config); - DvppDecodePara GetDecodePara() const { return decode_para_; } - DvppResizePara GetResizePara() const { return resize_para_; } - DvppCropPara GetCropPara() const { return crop_para_; } - DvppCropAndPastePara GetCropAndPastePara() const { return crop_and_paste_para_; } - bool HasResizeConfig() const { return resize_flag_; } - bool HasCropConfig() const { return crop_flag_; } - bool HasCropAndPasteConfig() const { return crop_and_paste_flag_; } - - private: - DvppDecodePara decode_para_; - DvppResizePara resize_para_; - DvppCropPara crop_para_; - DvppCropAndPastePara crop_and_paste_para_; - bool resize_flag_ = false; - bool crop_flag_ = false; - bool crop_and_paste_flag_ = false; - - Status GetStringValue(const nlohmann::json &json_item, const std::string &key, std::string *val); - Status GetIntValue(const nlohmann::json &json_item, const std::string &key, uint32_t *val); - Status ParseInputPara(const 
nlohmann::json &preprocess_item); - Status ParseDecodePara(const nlohmann::json &preprocess_item); - Status ParseResizePara(const nlohmann::json &json_item); - Status ParseCropPara(const nlohmann::json &json_item); - Status ParseCropAndPastePara(const nlohmann::json &json_item); - Status InitWithJsonConfigImp(const std::string &json_config); -}; - -Status DvppProcess::InitResource(aclrtStream stream) { - stream_ = stream; - aclError acl_ret; - dvpp_channel_desc_ = acldvppCreateChannelDesc(); - if (dvpp_channel_desc_ == nullptr) { - MS_LOG(ERROR) << "Call acldvppCreateChannelDesc failed"; - return FAILED; - } - acl_ret = acldvppCreateChannel(dvpp_channel_desc_); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call acldvppCreateChannel failed, acl return " << acl_ret; - return FAILED; - } - MS_LOG(INFO) << "End init dvpp process resource"; - return SUCCESS; -} - -void DvppProcess::DestroyResource() { - if (dvpp_channel_desc_ != nullptr) { - auto acl_ret = acldvppDestroyChannel(dvpp_channel_desc_); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call acldvppDestroyChannel failed, acl return " << acl_ret; - } - acl_ret = acldvppDestroyChannelDesc(dvpp_channel_desc_); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call acldvppDestroyChannelDesc failed, acl return " << acl_ret; - } - dvpp_channel_desc_ = nullptr; - } -} - -void DvppProcess::Finalize() { - DestroyDecodeDesc(); - DestroyVpcOutputDesc(); - DestroyResource(); - if (resize_config_ != nullptr) { - acldvppDestroyResizeConfig(resize_config_); - resize_config_ = nullptr; - } - if (crop_area_ != nullptr) { - acldvppDestroyRoiConfig(crop_area_); - crop_area_ = nullptr; - } - if (paste_area_ != nullptr) { - acldvppDestroyRoiConfig(paste_area_); - paste_area_ = nullptr; - } - if (input_pic_dev_buffer_ != nullptr) { - acldvppFree(input_pic_dev_buffer_); - } - input_pic_buffer_size_ = 0; - MS_LOG(INFO) << "End dvpp process finalize"; -} - -Status DvppProcess::InitJpegDecodePara(const 
DvppDecodePara &decode_para) { - decode_para_ = decode_para; - MS_LOG(INFO) << "Init decode para, pixel_format " << decode_para_.pixel_format; - return SUCCESS; -} - -Status DvppProcess::InitResizePara(const DvppResizePara &resize_para) { - resize_para_ = resize_para; - MS_LOG(INFO) << "Init resize para, " - << "output_width " << resize_para_.output_width << ", output_height " << resize_para_.output_height; - to_resize_flag_ = true; - to_crop_flag_ = false; - to_crop_and_paste_flag_ = false; - Status ret = InitResizeOutputDesc(); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "InitResizeOutputDesc failed"; - } - return ret; -} - -Status DvppProcess::InitCommonCropPara(uint32_t output_width, uint32_t output_height, DvppCropInfo *crop_info) { - MS_EXCEPTION_IF_NULL(crop_info); - if (crop_info->crop_type == kDvppCropTypeOffset) { - if (CheckAndAdjustRoiArea(&crop_info->crop_area) != SUCCESS) { - MS_LOG(ERROR) << "Check and adjust crop area failed"; - return FAILED; - } - MS_LOG(INFO) << "Init common crop para, crop type offset " - << ", left " << crop_info->crop_area.left << ", right " << crop_info->crop_area.right << ", top " - << crop_info->crop_area.top << ", bottom " << crop_info->crop_area.bottom << ", output_width " - << output_width << ", output_height " << output_height; - } else { - crop_info->crop_width = ToEven(crop_info->crop_width); - crop_info->crop_height = ToEven(crop_info->crop_height); - if (CheckRoiAreaWidthHeight(crop_info->crop_width, crop_info->crop_height) != SUCCESS) { - MS_LOG(ERROR) << "Check crop area width and height failed, actually width " << crop_info->crop_width << " height " - << crop_info->crop_height; - return FAILED; - } - MS_LOG(INFO) << "Init common crop para, crop type centre " - << ", crop_width " << crop_info->crop_width << ", crop_height " << crop_info->crop_height - << ", output_width " << output_width << ", output_height " << output_height; - } - return SUCCESS; -} - -Status DvppProcess::InitCropPara(const DvppCropPara 
&crop_para) { - crop_para_ = crop_para; - if (InitCommonCropPara(crop_para_.output_width, crop_para_.output_height, &crop_para_.crop_info) != SUCCESS) { - MS_LOG(ERROR) << "Init common crop para failed in InitCropPara"; - return FAILED; - } - to_crop_flag_ = true; - to_resize_flag_ = false; - to_crop_and_paste_flag_ = false; - Status ret = InitCropOutputDesc(); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "InitCropOutputDesc failed"; - } - return ret; -} - -Status DvppProcess::InitCropAndPastePara(const DvppCropAndPastePara &crop_and_paste_para) { - crop_and_paste_para_ = crop_and_paste_para; - if (InitCommonCropPara(crop_and_paste_para_.output_width, crop_and_paste_para_.output_height, - &crop_and_paste_para_.crop_info) != SUCCESS) { - MS_LOG(ERROR) << "Init common crop para failed in InitCropAndPastePara"; - return FAILED; - } - auto &paste_area = crop_and_paste_para_.paste_area; - if (CheckAndAdjustRoiArea(&paste_area) != SUCCESS) { - MS_LOG(ERROR) << "Check and adjust paste area failed"; - return FAILED; - } - MS_LOG(INFO) << "Init crop and paste para, paste info: " - << ", left " << paste_area.left << ", right " << paste_area.right << ", top " << paste_area.top - << ", bottom " << paste_area.bottom; - - to_crop_and_paste_flag_ = true; - to_crop_flag_ = false; - to_resize_flag_ = false; - Status ret = InitCropAndPasteOutputDesc(); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "InitCropAndPasteOutputDesc failed"; - } - return ret; -} - -Status DvppProcess::InputInputBuffer(const void *pic_buffer, size_t pic_buffer_size) { - aclError acl_ret; - if (pic_buffer_size != input_pic_buffer_size_) { - acldvppFree(input_pic_dev_buffer_); - input_pic_buffer_size_ = 0; - acl_ret = acldvppMalloc(&input_pic_dev_buffer_, pic_buffer_size); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call acldvppMalloc input picture buffer on device failed, buffer size " << pic_buffer_size; - return FAILED; - } - input_pic_buffer_size_ = pic_buffer_size; - } - acl_ret = - 
aclrtMemcpy(input_pic_dev_buffer_, input_pic_buffer_size_, pic_buffer, pic_buffer_size, ACL_MEMCPY_HOST_TO_DEVICE); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call aclrtMemcpy input picture buffer to device, buffer size " << pic_buffer_size; - return FAILED; - } - return SUCCESS; -} - -static void JpegErrorExitCustom(j_common_ptr cinfo) { - char jpeg_last_error_msg[JMSG_LENGTH_MAX] = {0}; - if (cinfo != nullptr && cinfo->err != nullptr && cinfo->err->format_message != nullptr) { - (*(cinfo->err->format_message))(cinfo, jpeg_last_error_msg); - } - throw std::runtime_error(jpeg_last_error_msg); -} - -Status DvppProcess::GetJpegWidthHeight(const void *pic_buffer, size_t pic_buffer_size, uint32_t *image_width, - uint32_t *image_height) { - MS_EXCEPTION_IF_NULL(image_width); - MS_EXCEPTION_IF_NULL(image_height); - struct jpeg_decompress_struct jpeg_header; - struct jpeg_error_mgr jpeg_error; - jpeg_header.err = jpeg_std_error(&jpeg_error); - jpeg_error.error_exit = JpegErrorExitCustom; - try { - jpeg_create_decompress(&jpeg_header); - jpeg_mem_src(&jpeg_header, reinterpret_cast(pic_buffer), pic_buffer_size); - (void)jpeg_read_header(&jpeg_header, TRUE); - } catch (std::runtime_error &e) { - jpeg_destroy_decompress(&jpeg_header); - MS_LOG(ERROR) << "JPEG images read failed, " << e.what(); - return INVALID_INPUTS; - } - *image_width = jpeg_header.image_width; - *image_height = jpeg_header.image_height; - - if (jpeg_header.jpeg_color_space != JCS_YCbCr) { - MS_LOG(ERROR) << "Expect color space YUV(YCbCr), current " << jpeg_header.jpeg_color_space; - jpeg_destroy_decompress(&jpeg_header); - return INVALID_INPUTS; - } - if (jpeg_header.dc_huff_tbl_ptrs[0] == nullptr) { - MS_LOG(ERROR) << "Only support Huffman code"; - jpeg_destroy_decompress(&jpeg_header); - return INVALID_INPUTS; - } - jpeg_destroy_decompress(&jpeg_header); - - const uint32_t min_width = 32; - const uint32_t max_width = 8192; - const uint32_t min_height = 32; - const uint32_t max_height = 8192; 
- if (*image_width < min_width || *image_width > max_width) { - MS_LOG(ERROR) << "Expect image width [" << min_width << ", " << max_width << "], the real image width is " - << *image_width; - return INVALID_INPUTS; - } - if (*image_height < min_height || *image_height > max_height) { - MS_LOG(ERROR) << "Expect image height [" << min_height << ", " << max_height << "], the real image height is " - << *image_height; - return INVALID_INPUTS; - } - return SUCCESS; -} - -Status DvppProcess::Process(const void *pic_buffer, size_t pic_buffer_size, void **output_device_buffer, - size_t *output_size) { - MS_EXCEPTION_IF_NULL(output_device_buffer); - MS_EXCEPTION_IF_NULL(output_size); - if (dvpp_channel_desc_ == nullptr) { - MS_LOG(ERROR) << "Process failed, dvpp not inited"; - return FAILED; - } - uint32_t image_width = 0; - uint32_t image_height = 0; - Status ret = GetJpegWidthHeight(pic_buffer, pic_buffer_size, &image_width, &image_height); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "Get jpeg image height and width failed"; - return ret; - } - MS_LOG(INFO) << "Get jpeg width " << image_width << ", height " << image_height; - ret = InitDecodeOutputDesc(image_width, image_height); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "InitDecodeOutputDesc failed"; - return FAILED; - } - ret = UpdateCropArea(image_width, image_height); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "Update crop area failed"; - return ret; - } - ret = CheckResizeImageInfo(image_width, image_height); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "Check resize para failed"; - return ret; - } - if (InputInputBuffer(pic_buffer, pic_buffer_size) != SUCCESS) { - MS_LOG(ERROR) << "InputInputBuffer failed"; - return FAILED; - } - if (ProcessDecode() != SUCCESS) { - MS_LOG(ERROR) << "Process Decode failed"; - return INVALID_INPUTS; - } - MS_LOG(INFO) << "Process Decode success"; - if (to_resize_flag_) { - if (ProcessResize() != SUCCESS) { - MS_LOG(ERROR) << "Process Resize failed"; - return INVALID_INPUTS; - } - 
MS_LOG(INFO) << "Process Resize success"; - } else if (to_crop_flag_) { - if (ProcessCrop() != SUCCESS) { - MS_LOG(ERROR) << "Process Crop failed"; - return INVALID_INPUTS; - } - MS_LOG(INFO) << "Process Crop success"; - } else if (to_crop_and_paste_flag_) { - if (ProcessCropAndPaste() != SUCCESS) { - MS_LOG(ERROR) << "Process Crop And Paste failed"; - return INVALID_INPUTS; - } - MS_LOG(INFO) << "Process Crop And Paste success"; - } - if (vpc_output_buffer_dev_ == nullptr) { - *output_device_buffer = decode_output_buffer_dev_; - *output_size = decode_output_buffer_size_; - } else { - *output_device_buffer = vpc_output_buffer_dev_; - *output_size = vpc_output_buffer_size_; - } - MS_LOG(INFO) << "Process dvpp success"; - return SUCCESS; -} - -Status DvppProcess::Process(const std::vector &pic_buffer_list, - const std::vector &pic_buffer_size_list, void **output_device_buffer, - size_t *output_size) { - MS_EXCEPTION_IF_NULL(output_device_buffer); - MS_EXCEPTION_IF_NULL(output_size); - auto batch_size = pic_buffer_list.size(); - if (batch_size == 0 || batch_size != pic_buffer_size_list.size()) { - MS_LOG(ERROR) << "Invalid batch size " << batch_size << ", pic size count" << pic_buffer_size_list.size(); - return FAILED; - } - MS_LOG(INFO) << "Begin dvpp process, batch size " << batch_size; - if (batch_size == 1) { - return Process(pic_buffer_list[0], pic_buffer_size_list[0], output_device_buffer, output_size); - } - size_t total_buffer_size = vpc_output_buffer_size_ * batch_size; - if (batch_size_ != batch_size) { - if (batch_vpc_output_buffer_dev_ != nullptr) { - acldvppFree(batch_vpc_output_buffer_dev_); - batch_vpc_output_buffer_dev_ = nullptr; - } - batch_size_ = batch_size; - auto acl_rt = acldvppMalloc(&batch_vpc_output_buffer_dev_, total_buffer_size); - if (acl_rt != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call acldvppMalloc failed, buffer size " << total_buffer_size; - return FAILED; - } - } - for (size_t i = 0; i < batch_size; i++) { - const void *pic_buffer = 
pic_buffer_list[i]; - uint32_t pic_size = pic_buffer_size_list[i]; - if (pic_buffer == nullptr || pic_size == 0) { - MS_LOG(ERROR) << "Get " << 0 << "th images failed"; - return FAILED; - } - void *output_dev_buffer_tmp = nullptr; - size_t output_buffer_size_tmp = 0; - Status ret = Process(pic_buffer, pic_size, &output_dev_buffer_tmp, &output_buffer_size_tmp); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "DVPP process failed"; - return ret; - } - aclrtMemcpy(static_cast(batch_vpc_output_buffer_dev_) + vpc_output_buffer_size_ * i, - total_buffer_size - vpc_output_buffer_size_ * i, output_dev_buffer_tmp, vpc_output_buffer_size_, - ACL_MEMCPY_DEVICE_TO_DEVICE); - - MS_LOG(INFO) << "DVPP process " << i << " th images success, input pic size " << pic_size << " output buffer size " - << output_buffer_size_tmp; - } - *output_device_buffer = batch_vpc_output_buffer_dev_; - *output_size = total_buffer_size; - MS_LOG(INFO) << "End DVPP process, batch size " << batch_size << ", output size " << output_size; - return SUCCESS; -} - -uint32_t DvppProcess::AlignmentHelper(uint32_t org_size, uint32_t alignment) const { - if (alignment == 0) { - return 0; - } - return (org_size + alignment - 1) / alignment * alignment; -} - -uint32_t DvppProcess::GetImageBufferSize(uint32_t stride_width, uint32_t stride_height, - acldvppPixelFormat pixel_format) const { - if (stride_height == 0 || stride_width == 0) { - MS_LOG(ERROR) << "Invalid stride height or width, stride_width " << stride_width << " stride_height " - << stride_height; - return 0; - } - if (UINT32_MAX / 3 < stride_height || UINT32_MAX / (3 * stride_height) < stride_width) { - MS_LOG(ERROR) << "Invalid stride height or width, stride_width " << stride_width << " stride_height " - << stride_height; - return 0; - } - if (pixel_format == PIXEL_FORMAT_YUV_SEMIPLANAR_420 || pixel_format == PIXEL_FORMAT_YVU_SEMIPLANAR_420) { - return stride_width * stride_height * 3 / 2; // 420 - } else if (pixel_format == PIXEL_FORMAT_YUV_SEMIPLANAR_422 
|| pixel_format == PIXEL_FORMAT_YVU_SEMIPLANAR_422) { - return stride_width * stride_height * 2; // 422 - } else if (pixel_format == PIXEL_FORMAT_YUV_SEMIPLANAR_444 || pixel_format == PIXEL_FORMAT_YVU_SEMIPLANAR_444) { - return stride_width * stride_height * 3; // 444 - } - MS_LOG(ERROR) << "Not support pixel format " << pixel_format; - return 0; -} - -Status DvppProcess::GetPicDescStride(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height) { - MS_EXCEPTION_IF_NULL(stride_width); - MS_EXCEPTION_IF_NULL(stride_height); - const uint32_t width_alignment = 16; - const uint32_t height_alignment = 2; - const uint32_t stride_width_minimum = 32; - const uint32_t stride_width_maximum = 4096; - const uint32_t stride_height_minimum = 6; - const uint32_t stride_height_maximum = 4096; - - *stride_width = AlignmentHelper(width, width_alignment); - *stride_height = AlignmentHelper(height, height_alignment); - if (*stride_width == 0 || *stride_height == 0) { - MS_LOG(ERROR) << "Init VPC output desc failed, get stride width or height failed"; - return FAILED; - } - if (*stride_width < stride_width_minimum || *stride_width > stride_width_maximum) { - MS_LOG(ERROR) << "Expect stride width [" << stride_width_minimum << ", " << stride_width_maximum - << "], current stride width " << stride_width << " given width " << width; - return FAILED; - } - if (*stride_height < stride_height_minimum || *stride_height > stride_height_maximum) { - MS_LOG(ERROR) << "Expect stride height [" << stride_height_minimum << ", " << stride_height_maximum - << "], current stride height " << *stride_height << " given height " << height; - return FAILED; - } - return SUCCESS; -} - -Status DvppProcess::GetPicDescStrideDecode(uint32_t width, uint32_t height, uint32_t *stride_width, - uint32_t *stride_height) { - MS_EXCEPTION_IF_NULL(stride_width); - MS_EXCEPTION_IF_NULL(stride_height); - const uint32_t width_alignment = 128; - const uint32_t height_alignment = 16; - const uint32_t 
width_minimum = 32; - const uint32_t width_maximum = 4096; // decode support 8192, dvpp(resize/crop/crop&paste) support 4096 - const uint32_t height_minimum = 32; - const uint32_t height_maximum = 4096; // decode support 8192, dvpp(resize/crop/crop&paste) support 4096 - if (width < width_minimum || width > width_maximum) { - MS_LOG(ERROR) << "Expect width [" << width_minimum << ", " << width_maximum << "], current width " << width; - return INVALID_INPUTS; - } - if (height < height_minimum || height > height_maximum) { - MS_LOG(ERROR) << "Expect height [" << height_minimum << ", " << height_maximum << "], current height " << height; - return INVALID_INPUTS; - } - *stride_width = AlignmentHelper(width, width_alignment); - *stride_height = AlignmentHelper(height, height_alignment); - if (*stride_width == 0 || *stride_height == 0) { - MS_LOG(ERROR) << "Init decode output desc failed, get stride width or height failed"; - return FAILED; - } - return SUCCESS; -} - -Status DvppProcess::InitVpcOutputDesc(uint32_t output_width, uint32_t output_height, acldvppPixelFormat pixel_format) { - DestroyVpcOutputDesc(); - uint32_t vpc_stride_width = 0; - uint32_t vpc_stride_height = 0; - if (GetPicDescStride(output_width, output_height, &vpc_stride_width, &vpc_stride_height) != SUCCESS) { - MS_LOG(ERROR) << "Init VPC output desc failed, get VPC output stride width/height failed"; - return FAILED; - } - vpc_output_buffer_size_ = GetImageBufferSize(vpc_stride_width, vpc_stride_height, pixel_format); - if (vpc_output_buffer_size_ == 0) { - MS_LOG(ERROR) << "Init VPC output desc failed, get image buffer size failed"; - return FAILED; - } - auto acl_ret = acldvppMalloc(&vpc_output_buffer_dev_, vpc_output_buffer_size_); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Init VPC output desc failed, malloc dvpp memory failed"; - return FAILED; - } - vpc_output_desc_ = acldvppCreatePicDesc(); - if (vpc_output_desc_ == nullptr) { - MS_LOG(ERROR) << "Init VPC output desc failed, create 
pic desc failed"; - return FAILED; - } - acldvppSetPicDescData(vpc_output_desc_, vpc_output_buffer_dev_); - acldvppSetPicDescSize(vpc_output_desc_, vpc_output_buffer_size_); - acldvppSetPicDescFormat(vpc_output_desc_, pixel_format); - acldvppSetPicDescWidth(vpc_output_desc_, output_width); - acldvppSetPicDescHeight(vpc_output_desc_, output_height); - acldvppSetPicDescWidthStride(vpc_output_desc_, vpc_stride_width); - acldvppSetPicDescHeightStride(vpc_output_desc_, vpc_stride_height); - MS_LOG(INFO) << "Init VPC output desc success"; - return SUCCESS; -} - -void DvppProcess::DestroyVpcOutputDesc() { - if (vpc_output_desc_ != nullptr) { - acldvppDestroyPicDesc(vpc_output_desc_); - vpc_output_desc_ = nullptr; - } - if (vpc_output_buffer_dev_ != nullptr) { - acldvppFree(vpc_output_buffer_dev_); - vpc_output_buffer_dev_ = nullptr; - } - if (batch_vpc_output_buffer_dev_ != nullptr) { - acldvppFree(batch_vpc_output_buffer_dev_); - batch_vpc_output_buffer_dev_ = nullptr; - } - vpc_output_buffer_size_ = 0; - MS_LOG(INFO) << "End destroy vpc desc"; -} - -Status DvppProcess::InitDecodeOutputDesc(uint32_t image_width, uint32_t image_height) { - if (decode_output_buffer_dev_ != nullptr && image_width == pic_width_ && image_height == pic_height_) { - return SUCCESS; - } - DestroyDecodeDesc(); - - pic_width_ = image_width; - pic_height_ = image_height; - - uint32_t stride_width = 0; - uint32_t stride_height = 0; - Status ret = GetPicDescStrideDecode(pic_width_, pic_height_, &stride_width, &stride_height); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "Init VPC output desc failed, get VPC output stride width/height failed"; - return ret; - } - - decode_output_buffer_size_ = GetImageBufferSize(stride_width, stride_height, decode_para_.pixel_format); - if (decode_output_buffer_size_ == 0) { - MS_LOG(ERROR) << "Init decode output desc failed, get image buffer size failed"; - return FAILED; - } - auto acl_ret = acldvppMalloc(&decode_output_buffer_dev_, decode_output_buffer_size_); - if 
(acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Init decode output desc failed, malloc dvpp memory failed"; - return FAILED; - } - decode_output_desc_ = acldvppCreatePicDesc(); - if (decode_output_desc_ == nullptr) { - MS_LOG(ERROR) << "Init decode output desc failed, create pic desc failed"; - return FAILED; - } - acldvppSetPicDescData(decode_output_desc_, decode_output_buffer_dev_); - acldvppSetPicDescSize(decode_output_desc_, decode_output_buffer_size_); - acldvppSetPicDescFormat(decode_output_desc_, decode_para_.pixel_format); - acldvppSetPicDescWidth(decode_output_desc_, pic_width_); - acldvppSetPicDescHeight(decode_output_desc_, pic_height_); - acldvppSetPicDescWidthStride(decode_output_desc_, stride_width); - acldvppSetPicDescHeightStride(decode_output_desc_, stride_height); - MS_LOG(INFO) << "Init decode output desc success"; - return SUCCESS; -} - -Status DvppProcess::CheckRoiAreaWidthHeight(uint32_t width, uint32_t height) { - const uint32_t min_crop_width = 10; - const uint32_t max_crop_width = 4096; - const uint32_t min_crop_height = 6; - const uint32_t max_crop_height = 4096; - - if (width < min_crop_width || width > max_crop_width) { - MS_LOG(ERROR) << "Expect roi area width in [" << min_crop_width << ", " << max_crop_width << "], actually " - << width; - return FAILED; - } - if (height < min_crop_height || height > max_crop_height) { - MS_LOG(ERROR) << "Expect roi area height in [" << min_crop_height << ", " << max_crop_height << "], actually " - << height; - return FAILED; - } - return SUCCESS; -} - -Status DvppProcess::CheckAndAdjustRoiArea(DvppRoiArea *area) { - MS_EXCEPTION_IF_NULL(area); - if (area->right < area->left) { - MS_LOG(ERROR) << "Check roi area failed, left " << area->left << ", right " << area->right; - return FAILED; - } - if (area->bottom < area->top) { - MS_LOG(ERROR) << "Check roi area failed, top " << area->top << ", bottom " << area->bottom; - return FAILED; - } - - area->left = ToEven(area->left); - area->top = 
ToEven(area->top); - area->right = ToOdd(area->right); - area->bottom = ToOdd(area->bottom); - - auto width = area->right - area->left + 1; - auto height = area->bottom - area->top + 1; - if (CheckRoiAreaWidthHeight(width, height) != SUCCESS) { - MS_LOG(ERROR) << "Check roi area width and height failed," - << " actually width " << width << " left " << area->left << ", right " << area->right - << " actually height " << height << " top " << area->top << ", bottom " << area->bottom; - return FAILED; - } - return SUCCESS; -} - -Status DvppProcess::UpdateCropArea(uint32_t image_width, uint32_t image_height) { - DvppCropInfo *crop_info = nullptr; - if (to_crop_flag_) { - crop_info = &crop_para_.crop_info; - } else if (to_crop_and_paste_flag_) { - crop_info = &crop_and_paste_para_.crop_info; - } else { - return SUCCESS; - } - if (crop_info->crop_type != kDvppCropTypeCentre) { - return SUCCESS; - } - if (image_width < crop_info->crop_width) { - MS_LOG(ERROR) << "Image width " << image_width << "smaller than crop width " << crop_info->crop_width; - return INVALID_INPUTS; - } - if (image_height < crop_info->crop_height) { - MS_LOG(ERROR) << "Image height " << image_height << "smaller than crop height " << crop_info->crop_height; - return INVALID_INPUTS; - } - uint32_t left = ToEven((image_width - crop_info->crop_width) / 2); - uint32_t top = ToEven((image_height - crop_info->crop_height) / 2); - uint32_t right = ToOdd(left + crop_info->crop_width); - uint32_t bottom = ToOdd(top + crop_info->crop_height); - - auto acl_ret = acldvppSetRoiConfig(crop_area_, left, right, top, bottom); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Update Crop Area failed"; - return FAILED; - } - MS_LOG(INFO) << "Update crop area, crop type centre, crop info: " - << ", left " << left << ", right " << right << ", top " << top << ", bottom " << bottom; - return SUCCESS; -} - -Status DvppProcess::CheckResizeImageInfo(uint32_t image_width, uint32_t image_height) const { - if 
(!to_resize_flag_) { - return SUCCESS; - } - // resize ratio required [1/32, 16] - auto check_resize_ratio = [](uint32_t before_resize, uint32_t after_resize) { - if (before_resize == 0 || after_resize == 0) { - return false; - } - if (before_resize / after_resize > 32) { - return false; - } - if (after_resize / before_resize > 16) { - return false; - } - return true; - }; - if (!check_resize_ratio(image_width, resize_para_.output_width)) { - MS_LOG(ERROR) << "Resize ratio required [1/32, 16], current width resize from " << image_width << " to " - << resize_para_.output_width; - return INVALID_INPUTS; - } - if (!check_resize_ratio(image_height, resize_para_.output_height)) { - MS_LOG(ERROR) << "Resize ratio required [1/32, 16], current height resize from " << image_height << " to " - << resize_para_.output_height; - return INVALID_INPUTS; - } - return SUCCESS; -} - -void DvppProcess::DestroyDecodeDesc() { - if (decode_output_desc_ != nullptr) { - acldvppDestroyPicDesc(decode_output_desc_); - decode_output_desc_ = nullptr; - } - if (decode_output_buffer_dev_ != nullptr) { - acldvppFree(decode_output_buffer_dev_); - decode_output_buffer_dev_ = nullptr; - } - decode_output_buffer_size_ = 0; - MS_LOG(INFO) << "End destroy decode desc"; -} - -Status DvppProcess::InitResizeOutputDesc() { - if (InitVpcOutputDesc(resize_para_.output_width, resize_para_.output_height, decode_para_.pixel_format) != SUCCESS) { - MS_LOG(ERROR) << "Init VPC output desc failed"; - return FAILED; - } - if (resize_config_ == nullptr) { - resize_config_ = acldvppCreateResizeConfig(); - if (resize_config_ == nullptr) { - MS_LOG(ERROR) << "Create Resize config failed"; - return FAILED; - } - } - return SUCCESS; -} - -Status DvppProcess::InitRoiAreaConfig(const DvppRoiArea &init_para, acldvppRoiConfig **roi_area) { - MS_EXCEPTION_IF_NULL(roi_area); - if (*roi_area == nullptr) { - *roi_area = acldvppCreateRoiConfig(init_para.left, init_para.right, init_para.top, init_para.bottom); - if (*roi_area == 
nullptr) { - MS_LOG(ERROR) << "Create Roi config failed"; - return FAILED; - } - } else { - auto acl_ret = acldvppSetRoiConfig(*roi_area, init_para.left, init_para.right, init_para.top, init_para.bottom); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Set Roi config failed"; - return FAILED; - } - } - return SUCCESS; -} - -Status DvppProcess::InitCropOutputDesc() { - if (InitVpcOutputDesc(crop_para_.output_width, crop_para_.output_height, decode_para_.pixel_format) != SUCCESS) { - MS_LOG(ERROR) << "Init VPC output desc failed"; - return FAILED; - } - if (InitRoiAreaConfig(crop_para_.crop_info.crop_area, &crop_area_) != SUCCESS) { - MS_LOG(ERROR) << "Init crop area failed"; - return FAILED; - } - return SUCCESS; -} - -Status DvppProcess::InitCropAndPasteOutputDesc() { - if (InitVpcOutputDesc(crop_and_paste_para_.output_width, crop_and_paste_para_.output_height, - decode_para_.pixel_format) != SUCCESS) { - MS_LOG(ERROR) << "Init VPC output desc failed"; - return FAILED; - } - if (InitRoiAreaConfig(crop_and_paste_para_.crop_info.crop_area, &crop_area_) != SUCCESS) { - MS_LOG(ERROR) << "Init crop area failed"; - return FAILED; - } - if (InitRoiAreaConfig(crop_and_paste_para_.paste_area, &paste_area_) != SUCCESS) { - MS_LOG(ERROR) << "Init paste area failed"; - return FAILED; - } - return SUCCESS; -} - -Status DvppProcess::ProcessDecode() { - aclError acl_ret; - acl_ret = acldvppJpegDecodeAsync(dvpp_channel_desc_, input_pic_dev_buffer_, input_pic_buffer_size_, - decode_output_desc_, stream_); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call acldvppJpegDecodeAsync failed, acl return " << acl_ret; - return FAILED; - } - acl_ret = aclrtSynchronizeStream(stream_); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call aclrtSynchronizeStream failed, acl return " << acl_ret; - return FAILED; - } - return SUCCESS; -} - -Status DvppProcess::ProcessResize() { - aclError acl_ret; - acl_ret = acldvppVpcResizeAsync(dvpp_channel_desc_, decode_output_desc_, 
vpc_output_desc_, resize_config_, stream_); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call acldvppVpcResizeAsync failed, acl return " << acl_ret; - return FAILED; - } - acl_ret = aclrtSynchronizeStream(stream_); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call aclrtSynchronizeStream failed, acl return " << acl_ret; - return FAILED; - } - return SUCCESS; -} - -Status DvppProcess::ProcessCrop() { - aclError acl_ret; - acl_ret = acldvppVpcCropAsync(dvpp_channel_desc_, decode_output_desc_, vpc_output_desc_, crop_area_, stream_); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call acldvppVpcCropAsync failed, acl return " << acl_ret; - return FAILED; - } - acl_ret = aclrtSynchronizeStream(stream_); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call aclrtSynchronizeStream failed, acl return " << acl_ret; - return FAILED; - } - return SUCCESS; -} - -Status DvppProcess::ProcessCropAndPaste() { - aclError acl_ret; - acl_ret = acldvppVpcCropAndPasteAsync(dvpp_channel_desc_, decode_output_desc_, vpc_output_desc_, crop_area_, - paste_area_, stream_); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call acldvppVpcCropAndPasteAsync failed, acl return " << acl_ret; - return FAILED; - } - acl_ret = aclrtSynchronizeStream(stream_); - if (acl_ret != ACL_ERROR_NONE) { - MS_LOG(ERROR) << "Call aclrtSynchronizeStream failed, acl return " << acl_ret; - return FAILED; - } - return SUCCESS; -} - -Status DvppJsonConfigParser::GetStringValue(const nlohmann::json &json_item, const std::string &key, std::string *val) { - MS_EXCEPTION_IF_NULL(val); - auto it = json_item.find(key); - if (it == json_item.end()) { - MS_LOG(ERROR) << "Get string item " << key << " failed"; - return FAILED; - } - if (!it->is_string()) { - MS_LOG(ERROR) << "Item " << key << " value is not string type"; - return FAILED; - } - *val = it->get(); - return SUCCESS; -} - -Status DvppJsonConfigParser::GetIntValue(const nlohmann::json &json_item, const std::string &key, 
uint32_t *val) { - MS_EXCEPTION_IF_NULL(val); - auto it = json_item.find(key); - if (it == json_item.end()) { - MS_LOG(ERROR) << "Get string item " << key << " failed"; - return FAILED; - } - if (!it->is_number_integer()) { - MS_LOG(ERROR) << "Item " << key << " value is not integer type"; - return FAILED; - } - *val = it->get(); - return SUCCESS; -} - -Status DvppJsonConfigParser::ParseInputPara(const nlohmann::json &preprocess_item) { - auto input = preprocess_item.find("input"); - if (input == preprocess_item.end()) { - MS_LOG(ERROR) << "Get input failed"; - return FAILED; - } - if (!input->is_object()) { - MS_LOG(ERROR) << "Input is not object"; - return FAILED; - } - return SUCCESS; -} - -Status DvppJsonConfigParser::ParseDecodePara(const nlohmann::json &preprocess_item) { - auto decode_para = preprocess_item.find("decode_para"); - if (decode_para == preprocess_item.end()) { - MS_LOG(ERROR) << "Get input failed"; - return FAILED; - } - if (!decode_para->is_object()) { - MS_LOG(ERROR) << "Input is not object"; - return FAILED; - } - const std::unordered_map pixel_format_map = { - {"YUV420SP", PIXEL_FORMAT_YUV_SEMIPLANAR_420}, {"YVU420SP", PIXEL_FORMAT_YVU_SEMIPLANAR_420}, - {"YUV422SP", PIXEL_FORMAT_YUV_SEMIPLANAR_422}, {"YVU422SP", PIXEL_FORMAT_YVU_SEMIPLANAR_422}, - {"YUV444SP", PIXEL_FORMAT_YUV_SEMIPLANAR_444}, {"YVU444SP", PIXEL_FORMAT_YVU_SEMIPLANAR_444}, - }; - std::string pixel_format; - if (GetStringValue(*decode_para, "out_pixel_format", &pixel_format) != SUCCESS) { - MS_LOG(ERROR) << "Get op out_pixel_format failed"; - return FAILED; - } - auto format = pixel_format_map.find(pixel_format); - if (format == pixel_format_map.end()) { - MS_LOG(ERROR) << "Unsupported out_pixel_format " << pixel_format; - return FAILED; - } - decode_para_.pixel_format = format->second; - return SUCCESS; -} - -Status DvppJsonConfigParser::ParseResizePara(const nlohmann::json &json_item) { - if (GetIntValue(json_item, "out_width", &resize_para_.output_width) != SUCCESS) { - 
return FAILED; - } - if (GetIntValue(json_item, "out_height", &resize_para_.output_height) != SUCCESS) { - return FAILED; - } - resize_flag_ = true; - return SUCCESS; -} - -Status DvppJsonConfigParser::ParseCropPara(const nlohmann::json &json_item) { - if (GetIntValue(json_item, "out_width", &crop_para_.output_width) != SUCCESS) { - return FAILED; - } - if (GetIntValue(json_item, "out_height", &crop_para_.output_height) != SUCCESS) { - return FAILED; - } - auto &crop_info = crop_para_.crop_info; - std::string crop_type = "crop_type"; - if (GetStringValue(json_item, "crop_type", &crop_type) != SUCCESS) { - return FAILED; - } - if (crop_type == "offset") { - MS_LOG(INFO) << "Crop type is 'offset'"; - crop_info.crop_type = kDvppCropTypeOffset; - auto &crop_area = crop_info.crop_area; - if (GetIntValue(json_item, "crop_left", &crop_area.left) != SUCCESS) { - return FAILED; - } - if (GetIntValue(json_item, "crop_top", &crop_area.top) != SUCCESS) { - return FAILED; - } - if (GetIntValue(json_item, "crop_right", &crop_area.right) != SUCCESS) { - return FAILED; - } - if (GetIntValue(json_item, "crop_bottom", &crop_area.bottom) != SUCCESS) { - return FAILED; - } - } else if (crop_type == "centre") { - MS_LOG(INFO) << "Crop type is 'centre'"; - if (GetIntValue(json_item, "crop_width", &crop_info.crop_width) != SUCCESS) { - return FAILED; - } - if (GetIntValue(json_item, "crop_height", &crop_info.crop_height) != SUCCESS) { - return FAILED; - } - crop_info.crop_type = kDvppCropTypeCentre; - } else { - MS_LOG(ERROR) << "Invalid crop type " << crop_type << ", expect offset or centre"; - return FAILED; - } - crop_flag_ = true; - return SUCCESS; -} - -Status DvppJsonConfigParser::ParseCropAndPastePara(const nlohmann::json &json_item) { - // crop info - if (GetIntValue(json_item, "out_width", &crop_and_paste_para_.output_width) != SUCCESS) { - return FAILED; - } - if (GetIntValue(json_item, "out_height", &crop_and_paste_para_.output_height) != SUCCESS) { - return FAILED; - } - auto 
&crop_info = crop_and_paste_para_.crop_info; - std::string crop_type = "crop_type"; - if (GetStringValue(json_item, "crop_type", &crop_type) != SUCCESS) { - return FAILED; - } - if (crop_type == "offset") { - MS_LOG(INFO) << "Crop type is 'offset'"; - crop_info.crop_type = kDvppCropTypeOffset; - auto &crop_area = crop_info.crop_area; - if (GetIntValue(json_item, "crop_left", &crop_area.left) != SUCCESS) { - return FAILED; - } - if (GetIntValue(json_item, "crop_top", &crop_area.top) != SUCCESS) { - return FAILED; - } - if (GetIntValue(json_item, "crop_right", &crop_area.right) != SUCCESS) { - return FAILED; - } - if (GetIntValue(json_item, "crop_bottom", &crop_area.bottom) != SUCCESS) { - return FAILED; - } - } else if (crop_type == "centre") { - MS_LOG(INFO) << "Crop type is 'centre'"; - if (GetIntValue(json_item, "crop_width", &crop_info.crop_width) != SUCCESS) { - return FAILED; - } - if (GetIntValue(json_item, "crop_height", &crop_info.crop_height) != SUCCESS) { - return FAILED; - } - crop_info.crop_type = kDvppCropTypeCentre; - } else { - MS_LOG(ERROR) << "Invalid crop type " << crop_type << ", expect offset or centre"; - return FAILED; - } - // paste info - auto &paste_area = crop_and_paste_para_.paste_area; - if (GetIntValue(json_item, "paste_left", &paste_area.left) != SUCCESS) { - return FAILED; - } - if (GetIntValue(json_item, "paste_top", &paste_area.top) != SUCCESS) { - return FAILED; - } - if (GetIntValue(json_item, "paste_right", &paste_area.right) != SUCCESS) { - return FAILED; - } - if (GetIntValue(json_item, "paste_bottom", &paste_area.bottom) != SUCCESS) { - return FAILED; - } - crop_and_paste_flag_ = true; - return SUCCESS; -} - -Status DvppJsonConfigParser::InitWithJsonConfigImp(const std::string &json_config) { - std::ifstream fp(json_config); - if (!fp.is_open()) { - MS_LOG(ERROR) << "Read json config file failed"; - return FAILED; - } - const auto &model_info = nlohmann::json::parse(fp); - auto preprocess_list = model_info.find("preprocess"); 
- if (preprocess_list == model_info.end()) { - MS_LOG(ERROR) << "Get preprocess failed"; - return FAILED; - } - if (!preprocess_list->is_array()) { - MS_LOG(ERROR) << "Preprocess is not array"; - return FAILED; - } - if (preprocess_list->empty()) { - MS_LOG(ERROR) << "Preprocess size is 0"; - return FAILED; - } - auto &preprocess = preprocess_list->at(0); - // input - if (ParseInputPara(preprocess) != SUCCESS) { - MS_LOG(ERROR) << "Parse input failed"; - return FAILED; - } - // decode para - if (ParseDecodePara(preprocess) != SUCCESS) { - MS_LOG(ERROR) << "Parse decode failed"; - return FAILED; - } - // ops - auto dvpp_process = preprocess.find("dvpp_process"); - if (dvpp_process == preprocess.end()) { - MS_LOG(ERROR) << "Get dvpp_process failed"; - return FAILED; - } - if (!dvpp_process->is_object()) { - MS_LOG(ERROR) << "Obj dvpp_process is not array"; - return FAILED; - } - const auto &item = *dvpp_process; - std::string op_name; - if (GetStringValue(item, "op_name", &op_name) != SUCCESS) { - return FAILED; - } - if (op_name == "resize") { - if (ParseResizePara(item) != SUCCESS) { - MS_LOG(ERROR) << "Parse resize para failed"; - return FAILED; - } - } else if (op_name == "crop") { - if (ParseCropPara(item) != SUCCESS) { - MS_LOG(ERROR) << "Parse crop para failed"; - return FAILED; - } - } else if (op_name == "crop_and_paste") { - if (ParseCropAndPastePara(item) != SUCCESS) { - MS_LOG(ERROR) << "Parse decode para failed"; - return FAILED; - } - } else { - MS_LOG(ERROR) << "Unsupported op name " << op_name << ", expect resize, crop or crop_and_paste"; - return FAILED; - } - return SUCCESS; -} - -Status DvppJsonConfigParser::InitWithJsonConfig(const std::string &json_config) { - try { - auto ret = InitWithJsonConfigImp(json_config); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "Init dvpp with json config failed, json config " << json_config; - return FAILED; - } - } catch (nlohmann::json::exception &e) { - MS_LOG(ERROR) << "Init dvpp with json config failed, json 
config " << json_config << ", error: " << e.what(); - return FAILED; - } - MS_LOG(INFO) << "Init with json config " << json_config << " success"; - return SUCCESS; -} - -Status DvppProcess::InitWithJsonConfig(const std::string &json_config) { - if (json_config.empty()) { - MS_LOG(INFO) << "No dvpp config file path set, skip."; - loaded_flag_ = false; - return SUCCESS; - } - - char real_path[PATH_MAX] = {0}; - if (realpath(common::SafeCStr(json_config), real_path) == nullptr) { - MS_LOG(WARNING) << "Dvpp json file " << json_config << " is not exist."; - loaded_flag_ = false; - return SUCCESS; - } - - DvppJsonConfigParser parser; - if (parser.InitWithJsonConfig(real_path) != SUCCESS) { - MS_LOG(ERROR) << "Init json config failed"; - return FAILED; - } - if (InitJpegDecodePara(parser.GetDecodePara()) != SUCCESS) { - MS_LOG(ERROR) << "Init decode para failed"; - return FAILED; - } - if (parser.HasResizeConfig()) { - if (InitResizePara(parser.GetResizePara()) != SUCCESS) { - MS_LOG(ERROR) << "Init resize para failed"; - return FAILED; - } - } else if (parser.HasCropConfig()) { - if (InitCropPara(parser.GetCropPara()) != SUCCESS) { - MS_LOG(ERROR) << "Init crop para failed"; - return FAILED; - } - } else if (parser.HasCropAndPasteConfig()) { - if (InitCropAndPastePara(parser.GetCropAndPastePara()) != SUCCESS) { - MS_LOG(ERROR) << "Init crop and paste para failed"; - return FAILED; - } - } - - MS_LOG(INFO) << "Dvpp config success"; - loaded_flag_ = true; - return SUCCESS; -} -} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/model/acl/dvpp_process.h b/mindspore/ccsrc/cxx_api/model/acl/dvpp_process.h deleted file mode 100644 index 105baef1689..00000000000 --- a/mindspore/ccsrc/cxx_api/model/acl/dvpp_process.h +++ /dev/null @@ -1,160 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H -#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H -#include -#include -#include -#include "acl/acl.h" -#include "acl/acl_mdl.h" -#include "acl/acl_rt.h" -#include "acl/ops/acl_dvpp.h" -#include "include/api/status.h" - -namespace mindspore::api { -struct DvppDecodePara { - acldvppPixelFormat pixel_format = PIXEL_FORMAT_YUV_SEMIPLANAR_420; -}; - -struct DvppResizePara { - uint32_t output_width = 0; - uint32_t output_height = 0; -}; - -enum DvppCropType { - // crop left,top,right,bottom is given in config - kDvppCropTypeOffset = 0, - // crop left,top,right,bottom is calculated by image width/height and output crop width/height - kDvppCropTypeCentre = 1, -}; - -struct DvppRoiArea { - uint32_t left = 0; - uint32_t top = 0; - uint32_t right = 0; - uint32_t bottom = 0; -}; - -struct DvppCropInfo { - DvppCropType crop_type = kDvppCropTypeOffset; - DvppRoiArea crop_area; // when kDvppCropTypeOffset - uint32_t crop_width = 0; // when kDvppCropTypeCentre - uint32_t crop_height = 0; // when kDvppCropTypeCentre -}; - -struct DvppCropPara { - DvppCropInfo crop_info; - uint32_t output_width = 0; - uint32_t output_height = 0; -}; - -struct DvppCropAndPastePara { - DvppCropInfo crop_info; - DvppRoiArea paste_area; - uint32_t output_width = 0; - uint32_t output_height = 0; -}; - -class DvppProcess { - public: - DvppProcess(); - ~DvppProcess(); - - Status InitResource(aclrtStream stream); - void Finalize(); - Status InitJpegDecodePara(const DvppDecodePara &decode_para); // jpeg decode + 
(resize | crop) - Status InitResizePara(const DvppResizePara &resize_para); // jpeg decode + resize - Status InitCropPara(const DvppCropPara &crop_para); // jpeg decode + crop - Status InitCropAndPastePara(const DvppCropAndPastePara &crop_and_paste_para); // jpeg decode + crop&paste - - Status InitWithJsonConfig(const std::string &json_config); - - // output device buffer will be destroy by DvppProcess itself. - Status Process(const void *pic_buffer, size_t pic_buffer_size, void **output_device_buffer, size_t *output_size); - Status Process(const std::vector &pic_buffer_list, const std::vector &pic_buffer_size_list, - void **output_device_buffer, size_t *output_size); - bool HasLoaded() const { return loaded_flag_; } - - private: - bool loaded_flag_ = false; - uint32_t pic_width_ = 0; - uint32_t pic_height_ = 0; - - DvppDecodePara decode_para_; - DvppResizePara resize_para_; - DvppCropPara crop_para_; - DvppCropAndPastePara crop_and_paste_para_; - // only one of the resize or crop flag can be true - bool to_resize_flag_ = false; - bool to_crop_flag_ = false; - bool to_crop_and_paste_flag_ = false; - - void *input_pic_dev_buffer_ = nullptr; - uint32_t input_pic_buffer_size_ = 0; - - uint32_t decode_output_buffer_size_ = 0; - void *decode_output_buffer_dev_ = nullptr; - acldvppPicDesc *decode_output_desc_ = nullptr; - - acldvppResizeConfig *resize_config_ = nullptr; - acldvppRoiConfig *crop_area_ = nullptr; - acldvppRoiConfig *paste_area_ = nullptr; - - acldvppPicDesc *vpc_output_desc_ = nullptr; - void *vpc_output_buffer_dev_ = nullptr; // vpc_output_buffer_size_ length - uint32_t vpc_output_buffer_size_ = 0; - - void *batch_vpc_output_buffer_dev_ = nullptr; // batch_size_ * vpc_output_buffer_size_ length - uint32_t batch_size_ = 0; - - aclrtStream stream_ = nullptr; - acldvppChannelDesc *dvpp_channel_desc_ = nullptr; - - uint32_t AlignmentHelper(uint32_t org_size, uint32_t alignment) const; - uint32_t GetImageBufferSize(uint32_t stride_width, uint32_t 
stride_height, acldvppPixelFormat pixel_format) const; - Status GetPicDescStride(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height); - Status GetPicDescStrideDecode(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height); - Status InputInputBuffer(const void *pic_buffer, size_t pic_buffer_size); - Status InitDecodeOutputDesc(uint32_t image_width, - uint32_t image_height); // decode_output_desc_, decode_output_buffer_dev_ - Status CheckRoiAreaWidthHeight(uint32_t width, uint32_t height); - Status CheckAndAdjustRoiArea(DvppRoiArea *area); - Status UpdateCropArea(uint32_t image_width, uint32_t image_height); - Status CheckResizeImageInfo(uint32_t image_width, uint32_t image_height) const; - void DestroyDecodeDesc(); - - Status InitVpcOutputDesc(uint32_t output_width, uint32_t output_height, - acldvppPixelFormat pixel_format); // vpc_output_desc_, vpc_output_buffer_dev_batch_ - Status InitRoiAreaConfig(const DvppRoiArea &init_para, acldvppRoiConfig **roi_area); - Status InitCommonCropPara(uint32_t out_width, uint32_t out_height, DvppCropInfo *crop_info); - Status InitResizeOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, resize_config - Status InitCropOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_ - Status InitCropAndPasteOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_, paste_area_ - void DestroyVpcOutputDesc(); - - Status ProcessDecode(); - Status ProcessResize(); - Status ProcessCrop(); - Status ProcessCropAndPaste(); - void DestroyResource(); - - Status GetJpegWidthHeight(const void *pic_buffer, size_t pic_buffer_size, uint32_t *image_width, - uint32_t *image_height); -}; -} // namespace mindspore::api - -#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H diff --git a/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc b/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc index ae506862544..a1e33d543da 100644 --- 
a/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc @@ -16,17 +16,13 @@ #include "cxx_api/model/acl/model_converter.h" #include -#include "pybind11/pybind11.h" #include "transform/graph_ir/convert.h" #include "transform/graph_ir/graph_runner.h" -#include "core/load_mindir/load_model.h" #include "mindspore/core/utils/ms_context.h" -#include "backend/kernel_compiler/oplib/oplib.h" - +#include "include/api/serialization.h" #include "graph/model.h" #include "cxx_api/model/model_converter_utils/multi_process.h" - -namespace py = pybind11; +#include "cxx_api/python_utils.h" namespace mindspore::api { namespace { @@ -74,19 +70,8 @@ bool CreateSessionAndGraphRunner() { return true; } - } // namespace -std::shared_ptr ModelConverter::ConvertMindIrToFuncGraph(const Buffer &model_data) { - try { - auto anf_graph = ConvertStreamToFuncGraph(reinterpret_cast(model_data.Data()), model_data.DataSize()); - return anf_graph; - } catch (std::exception &e) { - MS_LOG(ERROR) << "Load MindIR failed."; - return nullptr; - } -} - transform::DfGraphPtr ModelConverter::ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph) { for (auto &anf_node : anf_graph->parameters()) { MS_EXCEPTION_IF_NULL(anf_node); @@ -166,88 +151,31 @@ Buffer ModelConverter::BuildAirModel(const transform::DfGraphPtr &graph, return Buffer(model.data.get(), model.length); } -void ModelConverter::RegAllOp() { - static std::mutex init_mutex; - static bool Initialized = false; - - std::lock_guard lock(init_mutex); - if (Initialized) { - return; - } - Initialized = true; - MsContext::GetInstance()->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); - Py_Initialize(); - auto c_expression = PyImport_ImportModule("mindspore._c_expression"); - MS_EXCEPTION_IF_NULL(c_expression); - PyObject *c_expression_dict = PyModule_GetDict(c_expression); - MS_EXCEPTION_IF_NULL(c_expression_dict); - - PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, 
"OpInfoLoaderPy"); - MS_EXCEPTION_IF_NULL(op_info_loader_class); - PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class); - MS_EXCEPTION_IF_NULL(op_info_loader); - PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr); - MS_EXCEPTION_IF_NULL(op_info_loader_ins); - auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr); - MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul); - auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul); - auto all_ops_info = static_cast *>(all_ops_info_vector_addr); - for (auto op_info : *all_ops_info) { - kernel::OpLib::RegOpInfo(std::shared_ptr(op_info)); - } - all_ops_info->clear(); - delete all_ops_info; - Py_DECREF(op_info_loader); - Py_DECREF(op_info_loader_class); - Py_DECREF(c_expression_dict); - Py_DECREF(c_expression); -} - -Buffer ModelConverter::ReadFile(const std::string &file) { - Buffer buffer; - if (file.empty()) { - MS_LOG(ERROR) << "Pointer file is nullptr"; - return buffer; - } - std::string realPath = file; - std::ifstream ifs(realPath); - if (!ifs.good()) { - MS_LOG(ERROR) << "File: " << realPath << " is not exist"; - return buffer; - } - - if (!ifs.is_open()) { - MS_LOG(ERROR) << "File: " << realPath << "open failed"; - return buffer; - } - - ifs.seekg(0, std::ios::end); - size_t size = ifs.tellg(); - buffer.ResizeData(size); - if (buffer.DataSize() != size) { - MS_LOG(ERROR) << "Malloc buf failed, file: " << realPath; - ifs.close(); - return buffer; - } - - ifs.seekg(0, std::ios::beg); - ifs.read(reinterpret_cast(buffer.MutableData()), size); - ifs.close(); - - return buffer; -} - -Buffer ModelConverter::LoadMindIR(const Buffer &model_data) { - if (Py_IsInitialized() == 0) { +Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) { + if (!PythonIsInited()) { MS_LOG_INFO << "Call LoadMindIRInner directly"; - return LoadMindIRInner(model_data); + return LoadMindIRInner(func_graph); } MultiProcess 
multi_process; Buffer buffer_ret; - auto parent_process = [&model_data, &buffer_ret](MultiProcess *multi_process) -> Status { + auto parent_process = [&func_graph, &buffer_ret, this](MultiProcess *multi_process) -> Status { MS_EXCEPTION_IF_NULL(multi_process); + auto df_graph = ConvertFuncGraphToAIR(func_graph); + if (df_graph == nullptr) { + MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed."; + return FAILED; + } + ge::Model model; + ge::Buffer model_data; + model.SetGraph(*df_graph); + auto ge_ret = model.Save(model_data); + if (ge_ret != ge::SUCCESS) { + MS_LOG(ERROR) << "Save ge model to buffer failed."; + return FAILED; + } + // send original model to child - auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize()); + auto status = multi_process->SendMsg(model_data.data(), model_data.size()); if (!status.IsSuccess()) { MS_LOG_ERROR << "Send original model to child process failed"; return FAILED; @@ -277,7 +205,7 @@ Buffer ModelConverter::LoadMindIR(const Buffer &model_data) { MS_LOG_ERROR << "Receive original model from parent process failed"; return FAILED; } - Buffer model_result = LoadMindIRInner(model); + Buffer model_result = LoadAscendIRInner(model); if (model_result.DataSize() == 0) { MS_LOG_ERROR << "Convert model from MindIR to OM failed"; return FAILED; @@ -300,7 +228,7 @@ Buffer ModelConverter::LoadMindIR(const Buffer &model_data) { } Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) { - if (Py_IsInitialized() == 0) { + if (!PythonIsInited()) { MS_LOG_INFO << "Call LoadAscendIRInner directly"; return LoadAscendIRInner(model_data); } @@ -361,10 +289,8 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) { return buffer_ret; } -Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) { - RegAllOp(); - Py_Initialize(); - auto func_graph = ConvertMindIrToFuncGraph(model_data); +Buffer ModelConverter::LoadMindIRInner(const FuncGraphPtr &func_graph) { + RegAllOpFromPython(); if (func_graph == 
nullptr) { MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed."; return Buffer(); @@ -386,7 +312,7 @@ Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) { } Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) { - RegAllOp(); + RegAllOpFromPython(); ge::Model load_model = ge::Model("loadmodel", "version2"); ge::Status ret = ge::Model::Load(reinterpret_cast(model_data.Data()), model_data.DataSize(), load_model); diff --git a/mindspore/ccsrc/cxx_api/model/acl/model_converter.h b/mindspore/ccsrc/cxx_api/model/acl/model_converter.h index 21d34ed3366..410b1e7413a 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/model_converter.h +++ b/mindspore/ccsrc/cxx_api/model/acl/model_converter.h @@ -32,21 +32,17 @@ class ModelConverter { public: ModelConverter() : options_(nullptr) {} - Buffer LoadMindIR(const Buffer &model_data); + Buffer LoadMindIR(const FuncGraphPtr &func_graph); Buffer LoadAscendIR(const Buffer &model_data); void set_options(AclModelOptions *options) { options_ = options; } - static Buffer ReadFile(const std::string &file); - static void RegAllOp(); - private: - std::shared_ptr ConvertMindIrToFuncGraph(const Buffer &model_data); transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph); Buffer BuildAirModel(const transform::DfGraphPtr &graph, const std::map &acl_options); AclModelOptions *options_; - Buffer LoadMindIRInner(const Buffer &model_data); + Buffer LoadMindIRInner(const FuncGraphPtr &func_graph); Buffer LoadAscendIRInner(const Buffer &model_data); }; } // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc index 0d3e5aee62d..b3fc97ef218 100644 --- a/mindspore/ccsrc/cxx_api/model/model.cc +++ b/mindspore/ccsrc/cxx_api/model/model.cc @@ -14,93 +14,59 @@ * limitations under the License. 
*/ #include "include/api/model.h" +#include "include/api/context.h" #include "cxx_api/model/model_impl.h" +#include "cxx_api/factory.h" #include "utils/utils.h" namespace mindspore::api { -const char *kDeviceTypeAscendCL = "AscendCL"; -const char *kDeviceTypeAscendMS = "AscendMS"; - -Status Model::LoadModel(const Buffer &model_data, ModelType type, const std::map &options) { +Status Model::Build(const std::map &options) { MS_EXCEPTION_IF_NULL(impl_); - return impl_->LoadModel(model_data, type, options); + return impl_->Build(options); } -Status Model::LoadModel(const std::string &file_name, ModelType type, - const std::map &options) { - MS_EXCEPTION_IF_NULL(impl_); - return impl_->LoadModel(file_name, type, options); -} - -Status Model::UnloadModel() { - MS_EXCEPTION_IF_NULL(impl_); - return impl_->UnloadModel(); -} - -Status Model::Train(const DataSet &dataset, std::map *outputs) { +Status Model::Train(const DataSet &dataset, bool data_sink, std::map *outputs) { MS_EXCEPTION_IF_NULL(impl_); return impl_->Train(dataset, outputs); } -Status Model::Eval(const DataSet &dataset, std::map *outputs) { +Status Model::Eval(const DataSet &dataset, bool data_sink, std::map *outputs) { MS_EXCEPTION_IF_NULL(impl_); return impl_->Eval(dataset, outputs); } -Status Model::Predict(const std::map &inputs, std::map *outputs) { +Status Model::Predict(const std::vector &inputs, std::vector *outputs) { MS_EXCEPTION_IF_NULL(impl_); return impl_->Predict(inputs, outputs); } -Status Model::Predict(const std::vector &inputs, std::map *outputs) { - std::vector tensor_list; - auto ret = GetInputsInfo(&tensor_list); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "GetInputsInfo failed."; - return ret; - } - - if (inputs.size() != tensor_list.size()) { - MS_LOG(ERROR) << "Model need " << tensor_list.size() << " inputs, but given " << inputs.size(); - return FAILED; - } - - std::map inputs_with_map; - for (size_t i = 0; i < tensor_list.size(); ++i) { - inputs_with_map.emplace(tensor_list[i].Name(), 
inputs[i]); - } - - return Predict(inputs_with_map, outputs); -} - -Status Model::GetInputsInfo(std::vector *tensor_list) const { +Status Model::GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const { MS_EXCEPTION_IF_NULL(impl_); - return impl_->GetInputsInfo(tensor_list); + return impl_->GetInputsInfo(names, shapes, data_types, mem_sizes); } -Status Model::GetOutputsInfo(std::vector *tensor_list) const { +Status Model::GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const { MS_EXCEPTION_IF_NULL(impl_); - return impl_->GetOutputsInfo(tensor_list); + return impl_->GetOutputsInfo(names, shapes, data_types, mem_sizes); } -Model::Model(const std::string &device_type, uint32_t device_id) - : impl_(ModelFactory::Instance().Create(device_type, device_id)) { +Model::Model(const GraphCell &graph_cell) + : impl_(Factory::Instance().Create(Context::Instance().GetDeviceTarget())) { if (impl_ == nullptr) { - MS_LOG(EXCEPTION) << "Create session type " << device_type << " failed"; + MS_LOG(EXCEPTION) << "Create session type " << Context::Instance().GetDeviceTarget() << " failed"; } + MS_EXCEPTION_IF_NULL(graph_cell.GetGraph()); + impl_->SetGraph(std::make_shared(*graph_cell.GetGraph())); } -Model::Model(NetWork network, const std::string &device_type, uint32_t device_id) { - // todo - if (impl_ == nullptr) { - MS_LOG(EXCEPTION) << "Create session type " << device_type << " failed"; - } -} +Model::Model(const std::vector &network) { MS_LOG(EXCEPTION) << "Unsupported feature."; } Model::~Model() {} -bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) { - return ModelFactory::Instance().CheckModelSupport(device_type, model_type); +bool Model::CheckModelSupport(const std::string &device_type, ModelType) { + return Factory::Instance().CheckModelSupport(device_type); } } // namespace mindspore::api diff --git 
a/mindspore/ccsrc/cxx_api/model/model_impl.h b/mindspore/ccsrc/cxx_api/model/model_impl.h index d513a9a4c81..5ada9782b50 100644 --- a/mindspore/ccsrc/cxx_api/model/model_impl.h +++ b/mindspore/ccsrc/cxx_api/model/model_impl.h @@ -22,7 +22,10 @@ #include #include #include "include/api/model.h" +#include "include/api/graph.h" +#include "cxx_api/graph/graph_data.h" #include "utils/utils.h" +#include "ir/func_graph.h" namespace mindspore::api { class ModelImpl { @@ -30,70 +33,39 @@ class ModelImpl { ModelImpl() = default; virtual ~ModelImpl() = default; - virtual Status LoadModel(const Buffer &model_data, ModelType type, - const std::map &options) = 0; - virtual Status LoadModel(const std::string &file_name, ModelType type, - const std::map &options) = 0; - virtual Status UnloadModel() = 0; + virtual Status Build(const std::map &options) = 0; virtual Status Train(const DataSet &dataset, std::map *outputs) = 0; virtual Status Eval(const DataSet &dataset, std::map *outputs) = 0; - virtual Status Predict(const std::map &inputs, std::map *outputs) = 0; + virtual Status Predict(const std::vector &inputs, std::vector *outputs) = 0; - virtual Status GetInputsInfo(std::vector *tensor_list) const = 0; - virtual Status GetOutputsInfo(std::vector *tensor_list) const = 0; -}; + virtual Status GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const = 0; + virtual Status GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const = 0; -using ModelCreator = std::function(uint32_t device_id)>; -class ModelFactory { - public: - ModelFactory(const ModelFactory &) = delete; - void operator=(const ModelFactory &) = delete; - - static ModelFactory &Instance() { - static ModelFactory instance; - return instance; + protected: + Status Load(const std::shared_ptr &graph_cell) { + MS_EXCEPTION_IF_NULL(graph_cell); + return graph_cell->Load(); } - void Register(const std::string 
&device_name, ModelCreator &&model_creator) { - if (model_creators_.find(device_name) == model_creators_.end()) { - (void)model_creators_.emplace(device_name, model_creator); + FuncGraphPtr GetFuncGraph() const { + if (graph_->ModelType() != ModelType::kMindIR) { + return nullptr; } + + auto graph_data = graph_->graph_data_; + MS_EXCEPTION_IF_NULL(graph_data); + return graph_data->GetFuncGraph(); } - std::shared_ptr Create(const std::string &device_name, uint32_t device_id) { - auto iter = model_creators_.find(device_name); - if (model_creators_.end() != iter) { - MS_EXCEPTION_IF_NULL(iter->second); - return (iter->second)(device_id); - } - return nullptr; - } - - bool CheckModelSupport(const std::string &device_type, ModelType /*model_type*/) { - return std::any_of( - model_creators_.begin(), model_creators_.end(), - [&device_type](const std::pair &item) { return item.first == device_type; }); - } + std::shared_ptr graph_; private: - ModelFactory() = default; - ~ModelFactory() = default; - std::map model_creators_; + friend class Model; + void SetGraph(const std::shared_ptr &graph) { graph_ = graph; } }; - -class ModelRegistrar { - public: - ModelRegistrar(const std::string &device_name, ModelCreator model_creator) { - ModelFactory::Instance().Register(device_name, std::move(model_creator)); - } - ~ModelRegistrar() = default; -}; - -#define API_REG_MODEL(DEVICE_NAME, MODEL_CLASS) \ - static const ModelRegistrar g_api_model_registrar__##DEVICE_NAME##_##_reg( \ - kDeviceType##DEVICE_NAME, [](uint32_t device_id) { return std::make_shared(device_id); }); - } // namespace mindspore::api #endif // MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H diff --git a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc index f6dd6828c2e..034d464d6ba 100644 --- a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc +++ b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc @@ -16,164 +16,33 @@ #include "cxx_api/model/ms/ms_model.h" #include -#include -#include - 
-#include "load_mindir/load_model.h" -#include "backend/session/session_basic.h" -#include "backend/session/session_factory.h" -#include "backend/session/executor_manager.h" -#include "base/base_ref_utils.h" -#include "backend/kernel_compiler/oplib/oplib.h" -#include "utils/context/context_extends.h" -#include "runtime/device/kernel_runtime_manager.h" - -#include "pybind11/pybind11.h" -#include "pybind11/embed.h" - -#ifdef ENABLE_D #include "utils/ms_context.h" -#endif +#include "cxx_api/factory.h" -using std::string; -using std::vector; - -namespace py = pybind11; namespace mindspore { namespace api { -MsModel::MsModel(uint32_t device_id) : device_id_(device_id) {} -MsModel::~MsModel() = default; +API_FACTORY_REG(ModelImpl, Ascend910, MsModel); -TypeId TransInferDataType2TypeId(DataType data_type) { - const std::map type2id_map{ - {api::kMsUnknown, TypeId::kNumberTypeBegin}, {api::kMsBool, TypeId::kNumberTypeBool}, - {api::kMsInt8, TypeId::kNumberTypeInt8}, {api::kMsUint8, TypeId::kNumberTypeUInt8}, - {api::kMsInt16, TypeId::kNumberTypeInt16}, {api::kMsUint16, TypeId::kNumberTypeUInt16}, - {api::kMsInt32, TypeId::kNumberTypeInt32}, {api::kMsUint32, TypeId::kNumberTypeUInt32}, - {api::kMsInt64, TypeId::kNumberTypeInt64}, {api::kMsUint64, TypeId::kNumberTypeUInt64}, - {api::kMsFloat16, TypeId::kNumberTypeFloat16}, {api::kMsFloat32, TypeId::kNumberTypeFloat32}, - {api::kMsFloat64, TypeId::kNumberTypeFloat64}, - }; - auto it = type2id_map.find(data_type); - if (it == type2id_map.end()) { - MS_LOG_WARNING << "Unsupported MSI data type " << data_type; - return TypeId::kNumberTypeBegin; - } else { - return it->second; - } -} +Status MsModel::Build(const std::map &) { + MS_LOG(INFO) << "Start build model."; + MS_EXCEPTION_IF_NULL(graph_); -DataType TransTypeId2InferDataType(TypeId type_id) { - const std::map id2type_map{ - {TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool}, - {TypeId::kNumberTypeFloat64, api::kMsFloat64}, 
{TypeId::kNumberTypeInt8, api::kMsInt8}, - {TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16}, - {TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32}, - {TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64}, - {TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16}, - {TypeId::kNumberTypeFloat32, api::kMsFloat32}, - }; - auto it = id2type_map.find(type_id); - if (it == id2type_map.end()) { - MS_LOG_WARNING << "Unsupported data id " << type_id; - return api::kMsUnknown; - } else { - return it->second; - } -} + auto func_graph = ModelImpl::GetFuncGraph(); + MS_EXCEPTION_IF_NULL(func_graph); -Buffer MsModel::ReadFile(const std::string &file) { - if (file.empty()) { - MS_LOG(ERROR) << "file is nullptr"; - return Buffer(); - } - std::ifstream ifs(file); - if (!ifs.good()) { - MS_LOG(ERROR) << "file: " << file << " is not exist"; - return Buffer(); - } - - if (!ifs.is_open()) { - MS_LOG(ERROR) << "file: " << file << "open failed"; - return Buffer(); - } - - ifs.seekg(0, std::ios::end); - size_t size = ifs.tellg(); - Buffer buffer; - buffer.ResizeData(size); - ifs.seekg(0, std::ios::beg); - ifs.read(static_cast(buffer.MutableData()), size); - ifs.close(); - - return buffer; -} - -Status MsModel::LoadModel(const Buffer &model_data, ModelType type, const std::map &options) { - auto status = InitEnv({}); - if (status != SUCCESS) { - MS_LOG(ERROR) << "Init env failed"; - return FAILED; - } - std::shared_ptr anf_graph; - Py_Initialize(); - try { - anf_graph = ConvertStreamToFuncGraph(static_cast(model_data.Data()), model_data.DataSize()); - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference LoadModel failed"; - return FAILED; - } - Status ret = CompileGraph(anf_graph); + auto graph = std::make_shared(std::make_shared(func_graph, ModelType::kMindIR)); + MS_EXCEPTION_IF_NULL(graph); + auto graph_cell = std::make_shared(graph); + 
MS_EXCEPTION_IF_NULL(graph_cell); + auto ret = ModelImpl::Load(graph_cell); if (ret != SUCCESS) { - MS_LOG(ERROR) << "Compile graph model failed"; - return FAILED; + MS_LOG(ERROR) << "Load failed."; + return ret; } - session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_); - session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_); - if (inputs_.empty() || inputs_.size() != input_names_.size()) { - MS_LOG_ERROR << "Get model inputs info failed"; - return FAILED; - } - if (outputs_.empty() || outputs_.size() != output_names_.size()) { - MS_LOG_ERROR << "Get model outputs info failed"; - return FAILED; - } - MS_LOG(INFO) << "Load model success"; -#ifdef ENABLE_D - // set d context - rtError_t rt_ret = rtCtxGetCurrent(&context_); - if (rt_ret != RT_ERROR_NONE || context_ == nullptr) { - MS_LOG(ERROR) << "the ascend device context is null"; - return FAILED; - } -#endif - load_flag_ = true; - return SUCCESS; -} - -Status MsModel::LoadModel(const std::string &file_name, ModelType type, - const std::map &options) { - auto graphBuf = ReadFile(file_name); - if (graphBuf.DataSize() == 0) { - MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str(); - return FAILED; - } - auto status = LoadModel(graphBuf, type, options); - if (status != SUCCESS) { - MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); - return FAILED; - } - return SUCCESS; -} - -Status MsModel::UnloadModel() { - if (!load_flag_) { - MS_LOG_ERROR << "Model has not been loaded"; - return FAILED; - } - FinalizeEnv(); - load_flag_ = false; + // save result + graph_cell_ = graph_cell; + MS_LOG(INFO) << "Build model success."; return SUCCESS; } @@ -187,231 +56,42 @@ Status MsModel::Eval(const DataSet &, std::map *) { return FAILED; } -Status MsModel::Predict(const std::map &inputs, std::map *outputs) { +Status MsModel::Predict(const std::vector &inputs, std::vector *outputs) { MS_EXCEPTION_IF_NULL(outputs); - if (!load_flag_) { - 
MS_LOG(ERROR) << "No model is loaded, predict failed."; + if (graph_ == nullptr) { + MS_LOG(ERROR) << "Invalid data, graph_ is null."; return FAILED; } - if (inputs.size() != inputs_.size()) { - MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size(); - return INVALID_INPUTS; - } - std::vector request; - std::vector reply; - for (size_t i = 0; i < inputs_.size(); ++i) { - const auto &input_name = input_names_[i]; - auto iter = inputs.find(input_name); - if (iter == inputs.end()) { - MS_LOG(ERROR) << "Model missing input " << input_name; - return INVALID_INPUTS; - } - if (iter->second.DataSize() != inputs_[i]->Size()) { - MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count " - << iter->second.DataSize(); - return INVALID_INPUTS; - } - request.push_back(iter->second); - } - if (ExecuteModel(request, &reply) != SUCCESS) { - MS_LOG(ERROR) << "Execute Model Failed"; - return FAILED; - } - if (outputs_.size() != reply.size()) { - MS_LOG(ERROR) << "Predict output size " << reply.size() << " not match output size got from model info " - << outputs_.size(); - return FAILED; - } - outputs->clear(); - for (size_t i = 0; i < reply.size(); i++) { - outputs->emplace(output_names_[i], reply[i]); - } - return SUCCESS; -} - -Status MsModel::ExecuteModel(const std::vector &request, std::vector *reply) { - MS_EXCEPTION_IF_NULL(reply); -#ifdef ENABLE_D - if (context_ == nullptr) { - MS_LOG(ERROR) << "rtCtx is nullptr"; - return FAILED; - } - rtError_t rt_ret = rtCtxSetCurrent(context_); - if (rt_ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "set Ascend rtCtx failed"; - return FAILED; - } -#endif - vector inputs; - for (size_t i = 0; i < request.size(); i++) { - auto &item = request[i]; - auto input = inputs_[i]; - if (input->Size() != item.DataSize()) { - MS_LOG(ERROR) << "Predict input " << i << " data size " << item.DataSize() << " not match model input data size " - 
<< input->Size(); - return FAILED; - } - auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize()); + if (graph_cell_ == nullptr) { + MS_LOG(INFO) << "Model has not been built, it will be built with default options"; + Status ret = Build({}); if (ret != SUCCESS) { - MS_LOG(ERROR) << "Tensor copy failed"; + MS_LOG(ERROR) << "Build model failed."; return FAILED; } - inputs.push_back(input); } - vector outputs = RunGraph(inputs); - if (outputs.empty()) { - MS_LOG(ERROR) << "Execute Model Failed"; + + MS_EXCEPTION_IF_NULL(graph_cell_); + Status ret = graph_cell_->Run(inputs, outputs); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Run graph failed."; return FAILED; } - reply->clear(); - std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply), - [](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); }); + return SUCCESS; } -Status MsModel::FinalizeEnv() { - MS_LOG_INFO << "Start finalize env"; - py::gil_scoped_acquire acquire; - session::ExecutorManager::Instance().Clear(); - device::KernelRuntimeManager::Instance().ClearRuntimeResource(); - auto ms_context = MsContext::GetInstance(); - if (ms_context == nullptr) { - MS_LOG(ERROR) << "Get Context failed!"; - return FAILED; - } - if (!context::CloseTsd(ms_context)) { - MS_LOG(ERROR) << "Inference CloseTsd failed!"; - return FAILED; - } - MS_LOG_INFO << "End finalize env"; - return SUCCESS; +Status MsModel::GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const { + MS_EXCEPTION_IF_NULL(graph_cell_); + return graph_cell_->GetInputsInfo(names, shapes, data_types, mem_sizes); } -std::shared_ptr MsModel::LoadModel(const char *model_buf, size_t size, const std::string &device) { - Py_Initialize(); - MS_EXCEPTION_IF_NULL(model_buf); - try { - auto anf_graph = ConvertStreamToFuncGraph(model_buf, size); - return anf_graph; - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference LoadModel failed: " 
<< e.what(); - return nullptr; - } -} - -void MsModel::RegAllOp() { - static std::mutex init_mutex; - static bool Initialized = false; - - std::lock_guard lock(init_mutex); - if (Initialized) { - return; - } - Initialized = true; - auto ms_context_instance = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context_instance); - ms_context_instance->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); - try { - std::shared_ptr guard; - if (Py_IsInitialized() == 0) { - guard = std::make_shared(); - } - py::module c_expression = py::module::import("mindspore._c_expression"); - size_t ops_info_long = c_expression.attr("OpInfoLoaderPy")().attr("get_all_ops_info")().cast(); - auto all_ops_info = reinterpret_cast *>(static_cast(ops_info_long)); - for (auto op_info : *all_ops_info) { - kernel::OpLib::RegOpInfo(std::shared_ptr(op_info)); - } - all_ops_info->clear(); - delete all_ops_info; - } catch (const std::runtime_error &ex) { - MS_LOG_EXCEPTION << ex.what(); - } -} - -Status MsModel::CompileGraph(std::shared_ptr funcGraphPtr) { - MS_ASSERT(session_impl_ != nullptr); - try { - graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); - py::gil_scoped_release gil_release; - return SUCCESS; - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference CompileGraph failed: " << e.what(); - return FAILED; - } -} - -std::vector MsModel::RunGraph(const std::vector &inputs) { - try { - VectorRef outputs; - session_impl_->RunGraph(graph_id_, inputs, &outputs); - return TransformVectorRefToMultiTensor(outputs); - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference Rungraph failed: " << e.what(); - return std::vector(); - } -} - -Status MsModel::InitEnv(const std::unordered_map &other_options) { - RegAllOp(); - auto ms_context = MsContext::GetInstance(); - if (ms_context == nullptr) { - MS_LOG(ERROR) << "Get Context failed!"; - return FAILED; - } - ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); - ms_context->set_param(MS_CTX_DEVICE_ID, device_id_); - 
ms_context->set_param(MS_CTX_DEVICE_TARGET, kAscendDevice); - if (!context::OpenTsd(ms_context)) { - MS_LOG(ERROR) << "Session init OpenTsd failed!"; - return FAILED; - } - session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice); - if (session_impl_ == nullptr) { - MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice - << " is available."; - return FAILED; - } - session_impl_->Init(device_id_); - return SUCCESS; -} - -Status MsModel::CheckModelInputs(uint32_t graph_id, const std::vector &inputs) const { - MS_ASSERT(session_impl_ != nullptr); - std::string error_msg; - if (!session_impl_->CheckModelInputs(graph_id, inputs, &error_msg)) { - return Status(INVALID_INPUTS, error_msg); - } - return SUCCESS; -} - -Status MsModel::GetInputsInfo(std::vector *tensor_list) const { - MS_EXCEPTION_IF_NULL(tensor_list); - tensor_list->clear(); - for (size_t i = 0; i < inputs_.size(); i++) { - auto &tensor = inputs_[i]; - Tensor infer_tensor; - infer_tensor.SetName(input_names_[i]); - infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type())); - infer_tensor.SetShape(tensor->shape()); - tensor_list->push_back(infer_tensor); - } - return SUCCESS; -} - -Status MsModel::GetOutputsInfo(std::vector *tensor_list) const { - MS_EXCEPTION_IF_NULL(tensor_list); - tensor_list->clear(); - for (size_t i = 0; i < outputs_.size(); i++) { - auto &tensor = outputs_[i]; - Tensor infer_tensor; - infer_tensor.SetName(output_names_[i]); - infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type())); - infer_tensor.SetShape(tensor->shape()); - tensor_list->push_back(infer_tensor); - } - return SUCCESS; +Status MsModel::GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const { + MS_EXCEPTION_IF_NULL(graph_cell_); + return graph_cell_->GetOutputsInfo(names, shapes, data_types, mem_sizes); } } // namespace api } // namespace mindspore diff --git 
a/mindspore/ccsrc/cxx_api/model/ms/ms_model.h b/mindspore/ccsrc/cxx_api/model/ms/ms_model.h index 161aae5c200..747ff0da8b1 100644 --- a/mindspore/ccsrc/cxx_api/model/ms/ms_model.h +++ b/mindspore/ccsrc/cxx_api/model/ms/ms_model.h @@ -36,49 +36,23 @@ namespace mindspore { namespace api { class MsModel : public ModelImpl { public: - explicit MsModel(uint32_t device_id); - ~MsModel(); + MsModel() {} + ~MsModel() = default; - Status LoadModel(const Buffer &model_data, ModelType type, - const std::map &options) override; - Status LoadModel(const std::string &file_name, ModelType type, - const std::map &options) override; - Status UnloadModel() override; + Status Build(const std::map &options_map) override; Status Train(const DataSet &dataset, std::map *outputs) override; Status Eval(const DataSet &dataset, std::map *outputs) override; - Status Predict(const std::map &inputs, std::map *outputs) override; + Status Predict(const std::vector &inputs, std::vector *outputs) override; - Status GetInputsInfo(std::vector *tensor_list) const override; - Status GetOutputsInfo(std::vector *tensor_list) const override; - - Status InitEnv(const std::unordered_map &other_options); - Status FinalizeEnv(); + Status GetInputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const override; + Status GetOutputsInfo(std::vector *names, std::vector> *shapes, + std::vector *data_types, std::vector *mem_sizes) const override; private: - std::shared_ptr session_impl_ = nullptr; - uint32_t graph_id_; - std::string device_type_; - int32_t device_id_ = 0; -#ifdef ENABLE_D - rtContext_t context_ = nullptr; -#endif - std::vector inputs_; - std::vector outputs_; - std::vector input_names_; - std::vector output_names_; - bool load_flag_ = false; - - std::shared_ptr LoadModel(const char *model_buf, size_t size, const std::string &device); - Buffer ReadFile(const std::string &file); - static void RegAllOp(); - Status CompileGraph(std::shared_ptr 
funcGraphPtr); - Status CheckModelInputs(uint32_t graph_id, const std::vector &inputs) const; - std::vector RunGraph(const std::vector &inputs); - Status ExecuteModel(const std::vector &inputs, std::vector *outputs); + std::shared_ptr graph_cell_; }; - -API_REG_MODEL(AscendMS, MsModel); } // namespace api } // namespace mindspore #endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H diff --git a/mindspore/ccsrc/cxx_api/python_utils.cc b/mindspore/ccsrc/cxx_api/python_utils.cc new file mode 100644 index 00000000000..0a37f762605 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/python_utils.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "cxx_api/python_utils.h" +#include +#include +#include +#include "mindspore/core/utils/ms_context.h" +#include "pybind11/pybind11.h" +#include "backend/kernel_compiler/oplib/oplib.h" + +namespace py = pybind11; + +namespace mindspore::api { +void RegAllOpFromPython() { + static std::mutex init_mutex; + static bool Initialized = false; + + std::lock_guard lock(init_mutex); + if (Initialized) { + return; + } + Initialized = true; + MsContext::GetInstance()->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); + Py_Initialize(); + auto c_expression = PyImport_ImportModule("mindspore._c_expression"); + MS_EXCEPTION_IF_NULL(c_expression); + PyObject *c_expression_dict = PyModule_GetDict(c_expression); + MS_EXCEPTION_IF_NULL(c_expression_dict); + + PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); + MS_EXCEPTION_IF_NULL(op_info_loader_class); + PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class); + MS_EXCEPTION_IF_NULL(op_info_loader); + PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr); + MS_EXCEPTION_IF_NULL(op_info_loader_ins); + auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr); + MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul); + auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul); + auto all_ops_info = static_cast *>(all_ops_info_vector_addr); + for (auto op_info : *all_ops_info) { + kernel::OpLib::RegOpInfo(std::shared_ptr(op_info)); + } + all_ops_info->clear(); + delete all_ops_info; + Py_DECREF(op_info_loader); + Py_DECREF(op_info_loader_class); + Py_DECREF(c_expression_dict); + Py_DECREF(c_expression); +} + +bool PythonIsInited() { return Py_IsInitialized() != 0; } +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/python_utils.h b/mindspore/ccsrc/cxx_api/python_utils.h new file mode 100644 index 00000000000..5b8bee2f5c3 --- /dev/null +++ 
b/mindspore/ccsrc/cxx_api/python_utils.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H +#define MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H + +#include "pybind11/pybind11.h" + +namespace mindspore::api { +void RegAllOpFromPython(); +bool PythonIsInited(); +} // namespace mindspore::api + +#endif // MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H diff --git a/mindspore/ccsrc/cxx_api/serialization.cc b/mindspore/ccsrc/cxx_api/serialization.cc index 2bd1be56b2a..e640328b237 100644 --- a/mindspore/ccsrc/cxx_api/serialization.cc +++ b/mindspore/ccsrc/cxx_api/serialization.cc @@ -14,9 +14,77 @@ * limitations under the License. 
*/ #include "include/api/serialization.h" +#include +#include "cxx_api/graph/graph_data.h" #include "utils/log_adapter.h" +#include "mindspore/core/load_mindir/load_model.h" namespace mindspore::api { +static Buffer ReadFile(const std::string &file) { + Buffer buffer; + if (file.empty()) { + MS_LOG(ERROR) << "Pointer file is nullptr"; + return buffer; + } + + char real_path_mem[PATH_MAX] = {0}; + char *real_path_ret = nullptr; +#if defined(_WIN32) || defined(_WIN64) + real_path_ret = _fullpath(real_path_mem, common::SafeCStr(file), PATH_MAX); +#else + real_path_ret = realpath(common::SafeCStr(file), real_path_mem); +#endif + + if (real_path_ret == nullptr) { + MS_LOG(ERROR) << "File: " << file << " is not exist."; + return buffer; + } + + std::string real_path(real_path_mem); + std::ifstream ifs(real_path); + if (!ifs.good()) { + MS_LOG(ERROR) << "File: " << real_path << " is not exist"; + return buffer; + } + + if (!ifs.is_open()) { + MS_LOG(ERROR) << "File: " << real_path << "open failed"; + return buffer; + } + + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + buffer.ResizeData(size); + if (buffer.DataSize() != size) { + MS_LOG(ERROR) << "Malloc buf failed, file: " << real_path; + ifs.close(); + return buffer; + } + + ifs.seekg(0, std::ios::beg); + ifs.read(reinterpret_cast(buffer.MutableData()), size); + ifs.close(); + + return buffer; +} + +Graph Serialization::LoadModel(const std::string &file, ModelType model_type) { + Buffer data = ReadFile(file); + if (model_type == kMindIR) { + FuncGraphPtr anf_graph = nullptr; + try { + anf_graph = ConvertStreamToFuncGraph(reinterpret_cast(data.Data()), data.DataSize()); + } catch (std::exception &e) { + MS_LOG(ERROR) << "Load MindIR failed."; + } + + return Graph(std::make_shared(anf_graph, kMindIR)); + } else if (model_type == kOM) { + return Graph(std::make_shared(data, kOM)); + } + MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type; +} + Status Serialization::LoadCheckPoint(const std::string 
&ckpt_file, std::map *parameters) { MS_LOG(ERROR) << "Unsupported feature."; return FAILED; diff --git a/mindspore/ccsrc/cxx_api/types.cc b/mindspore/ccsrc/cxx_api/types.cc index 03c3aa2dbb5..74d4c1bb99e 100644 --- a/mindspore/ccsrc/cxx_api/types.cc +++ b/mindspore/ccsrc/cxx_api/types.cc @@ -19,6 +19,9 @@ #include "utils/utils.h" namespace mindspore::api { +const char *kDeviceTypeAscend310 = "Ascend310"; +const char *kDeviceTypeAscend910 = "Ascend910"; + class DataImpl { public: DataImpl() : data_() {} diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index 62c1c164a70..71fab38187b 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -422,7 +422,6 @@ inline ValuePtr MakeValue(S v) { template ::type> static S GetValue(const ValuePtr &value) { MS_EXCEPTION_IF_NULL(value); - U imm = value->cast(); if (imm == nullptr) { MS_LOG(EXCEPTION) << "Cast failed, original value: " << value->ToString() << ", type: " << value->type_name(); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d73442fc402..5c2ae3d9110 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,5 +1,10 @@ #add flags set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare") -add_subdirectory("ut") - +if (ENABLE_ACL) + add_subdirectory(cxx_st) +elseif (ENABLE_GPU OR ENABLE_D OR ENABLE_CPU) + message(fatal "No need set -e xxx when compile ut") +else () + add_subdirectory(ut) +endif() diff --git a/tests/cxx_st/CMakeLists.txt b/tests/cxx_st/CMakeLists.txt new file mode 100644 index 00000000000..c6602db41d6 --- /dev/null +++ b/tests/cxx_st/CMakeLists.txt @@ -0,0 +1,11 @@ +include_directories(${PYTHON_INCLUDE_DIRS}) +include_directories(${MS_CCSRC_PATH}) +include_directories(${CMAKE_SOURCE_DIR}/mindspore/core) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/) +include_directories(${CMAKE_BINARY_DIR}) +include_directories(${CUDA_INCLUDE_DIRS}) + +file(GLOB_RECURSE CXX_ST_SRC RELATIVE 
${CMAKE_CURRENT_SOURCE_DIR} *.cc) +add_executable(st_tests ${CXX_ST_SRC}) +target_link_libraries(st_tests PRIVATE mindspore_shared_lib mindspore::gtest) \ No newline at end of file diff --git a/tests/cxx_st/common/common_test.cc b/tests/cxx_st/common/common_test.cc new file mode 100644 index 00000000000..c50f2f62009 --- /dev/null +++ b/tests/cxx_st/common/common_test.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common_test.h" + +#ifdef __cplusplus +#if __cplusplus +extern "C" { +#endif +#endif + +namespace ST { + +void Common::SetUpTestCase() {} + +void Common::TearDownTestCase() {} + +void Common::SetUp() {} + +void Common::TearDown() {} + +} // namespace ST + +#ifdef __cplusplus +#if __cplusplus +} +#endif +#endif diff --git a/tests/cxx_st/common/common_test.h b/tests/cxx_st/common/common_test.h new file mode 100644 index 00000000000..0cc4fe2de76 --- /dev/null +++ b/tests/cxx_st/common/common_test.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TESTS_CXX_ST_COMMON_COMMON_TEST_H_ +#define TESTS_CXX_ST_COMMON_COMMON_TEST_H_ + +#include +#include +#include +#include "gtest/gtest.h" +namespace ST { +class Common : public testing::Test { + public: + // TestCase only enter once + static void SetUpTestCase(); + static void TearDownTestCase(); + + // every TEST_F macro will enter one + virtual void SetUp(); + virtual void TearDown(); + + template + void PrintData(std::string name, T *output_data, int size) { + std::cout << "The " << name << " is as follows:" << std::endl; + if (typeid(output_data[0]) == typeid(uint8_t) || typeid(output_data[0]) == typeid(int8_t)) { + for (int i = 0; i < std::min(size, 100); i++) { + std::cout << (int)output_data[i] << " "; + } + } else { + for (int i = 0; i < std::min(size, 100); i++) { + std::cout << output_data[i] << " "; + } + } + std::cout << std::endl; + } + + template + static void CompareOutputData(T *output_data, T *correct_data, int size, float err_bound) { + for (int i = 0; i < size; i++) { + T abs = fabs(output_data[i] - correct_data[i]); + ASSERT_LE(abs, err_bound); + } + } + + void ReadFile(const char *file, size_t *size, char **buf) { + ASSERT_NE(nullptr, file); + ASSERT_NE(nullptr, size); + ASSERT_NE(nullptr, buf); + std::string path = std::string(file); + std::ifstream ifs(path); + ASSERT_EQ(true, ifs.good()); + ASSERT_EQ(true, ifs.is_open()); + + ifs.seekg(0, std::ios::end); + *size = ifs.tellg(); + *buf = new char[*size]; + + ifs.seekg(0, std::ios::beg); + ifs.read(*buf, *size); + ifs.close(); + } +}; +} // namespace ST 
+#endif // TESTS_CXX_ST_COMMON_COMMON_TEST_H_ diff --git a/tests/cxx_st/common/test_main.cc b/tests/cxx_st/common/test_main.cc new file mode 100644 index 00000000000..368fba65aab --- /dev/null +++ b/tests/cxx_st/common/test_main.cc @@ -0,0 +1,22 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "gtest/gtest.h" + +GTEST_API_ int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/tests/cxx_st/model/test_tensor_add.cc b/tests/cxx_st/model/test_tensor_add.cc new file mode 100644 index 00000000000..c820af51884 --- /dev/null +++ b/tests/cxx_st/model/test_tensor_add.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include "common/common_test.h" +#include "include/api/model.h" +#include "include/api/serialization.h" +#include "include/api/context.h" + +using namespace mindspore::api; + +static const char tensor_add_file[] = "/home/workspace/mindspore_dataset/tensor_add/tensor_add.mindir"; +static const std::vector input_data_1 = {1, 2, 3, 4}; +static const std::vector input_data_2 = {2, 3, 4, 5}; + +class TestTensorAdd : public ST::Common { + public: + TestTensorAdd() {} +}; + +TEST_F(TestTensorAdd, InferMindIR) { + Context::Instance().SetDeviceTarget(kDeviceTypeAscend310).SetDeviceID(1); + auto graph = Serialization::LoadModel(tensor_add_file, ModelType::kMindIR); + Model tensor_add((GraphCell(graph))); + Status ret = tensor_add.Build({}); + ASSERT_TRUE(ret == SUCCESS); + + // prepare input + std::vector outputs; + std::vector inputs; + inputs.emplace_back(Buffer(input_data_1.data(), sizeof(float) * input_data_1.size())); + inputs.emplace_back(Buffer(input_data_2.data(), sizeof(float) * input_data_2.size())); + + // infer + ret = tensor_add.Predict(inputs, &outputs); + ASSERT_TRUE(ret == SUCCESS); + + // print + for (auto &buffer : outputs) { + const float *p = reinterpret_cast(buffer.Data()); + for (size_t i = 0; i < buffer.DataSize() / sizeof(float); ++i) { + ASSERT_LE(std::abs(p[i] - (input_data_1[i] + input_data_2[i])), 1e-4); + } + } +}