diff --git a/CMakeLists.txt b/CMakeLists.txt index cda70d5b74a..51bd064bb02 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,19 +69,9 @@ include_directories(${PYTHON_INCLUDE_DIRS}) set(MS_CCSRC_PATH ${CMAKE_SOURCE_DIR}/mindspore/ccsrc) set(MS_CCSRC_BUILD_PATH ${BUILD_PATH}/mindspore/mindspore/ccsrc) -if (ENABLE_GE) - link_directories(${CMAKE_SOURCE_DIR}/third_party/ge/lib) -elseif(ENABLE_D OR ENABLE_TESTCASES) +if (ENABLE_D OR ENABLE_ACL OR ENABLE_TESTCASES) include(${CMAKE_SOURCE_DIR}/cmake/dependency_graphengine.cmake) -endif() - -if (ENABLE_GE OR ENABLE_D OR ENABLE_TESTCASES) - include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc) - include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc/external) - include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc/framework) - include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc) - include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain) -endif() +endif () set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") add_subdirectory(mindspore/ccsrc) diff --git a/build.sh b/build.sh index 65b74a0eafb..2c1075c8f33 100755 --- a/build.sh +++ b/build.sh @@ -23,9 +23,9 @@ usage() { echo "Usage:" echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\" - echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" + echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|ascend|cpu|acl] \\" echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I arm64|arm32|x86_64] [-K] \\" - echo " [-B on|off] [-w on|off] [-E] [-l on|off] [-n full|lite|off] [-T on|off] \\" + echo " [-B on|off] [-E] [-l on|off] [-n full|lite|off] [-T on|off] \\" echo " [-A [cpp|java|object-c] [-C on|off] [-o on|off] [-S on|off] [-k on|off] [-W sse|neon|avx|off] \\" echo "" echo "Options:" @@ -45,7 +45,7 @@ usage() echo " -i Enable increment building, default off" echo " -L Enable load ANF-IR as input of 'infer', default off" echo " -j[n] Set the threads when building (Default: -j8)" - echo " -e Use gpu, d or cpu" + echo " -e Use cpu, gpu, ascend or acl" echo " -P Enable dump anf graph to file in ProtoBuffer format, default on" echo " -D Enable dumping of function graph ir, default on" echo " -z Compile dataset & mindrecord, default on" @@ -55,7 +55,6 @@ usage() echo " -I Enable compiling mindspore lite for arm64, arm32 or x86_64, default disable mindspore lite compilation" echo " -K Compile with AKG, default on" echo " -s Enable serving module, default off" - echo " -w Enable acl module, default off" echo " -B Enable debugger, default on" echo " -E Enable IBVERBS for parameter server, default off" echo " -l Compile with python dependency, default on" @@ -225,6 +224,9 @@ checkopts() ENABLE_D="on" ENABLE_CPU="on" ENABLE_SERVING="on" + elif [[ "X$OPTARG" == "Xacl" ]]; then + ENABLE_SERVING="on" + ENABLE_ACL="on" elif [[ "X$OPTARG" == "Xcpu" ]]; then ENABLE_CPU="on" else diff --git a/cmake/dependency_graphengine.cmake b/cmake/dependency_graphengine.cmake index 9ed90a66a31..b6912d869c4 100644 --- a/cmake/dependency_graphengine.cmake +++ b/cmake/dependency_graphengine.cmake @@ -11,7 +11,7 @@ include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake) include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake) # for UT, find slog and error_manager from local prebuild -if (NOT ENABLE_D) +if (NOT ENABLE_D AND NOT ENABLE_ACL) set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}) find_library(slog libslog.so ${GE_PREBUILD_PATH}) find_library(error_manager liberror_manager.so ${GE_PREBUILD_PATH}) @@ -28,6 +28,7 @@ elseif (DEFINED ENV{D_LINK_PATH}) message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") endif() set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) + find_library(c_sec libc_sec.so ${GE_LIB_PATH}) find_library(slog libslog.so ${GE_LIB_PATH}) find_library(mmpa libmmpa.so ${GE_LIB_PATH}) find_library(runtime libruntime.so ${GE_LIB_PATH}) @@ -44,8 +45,8 @@ else() else() set(ASCEND_PATH /usr/local/Ascend) endif() - set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common) - set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64) + set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common ${ASCEND_PATH}/driver/lib64) + set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64 ${ASCEND_PATH}/acllib/lib64 ${ASCEND_PATH}/atc/lib64) find_library(c_sec libc_sec.so ${ASCEND_DRIVER_PATH}) find_library(slog libslog.so ${ASCEND_DRIVER_PATH}) find_library(mmpa libmmpa.so ${ASCEND_DRIVER_PATH}) @@ -76,9 +77,11 @@ string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") # force __FILE__ to show relative path of file, from source directory set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst $(realpath ${CMAKE_SOURCE_DIR})/,,$(abspath $<))\"' -Wno-builtin-macro-redefined") add_subdirectory(${GE_SOURCE_DIR}/src/common/graph) -if(ENABLE_D) +if (ENABLE_ACL OR ENABLE_D) add_subdirectory(${GE_SOURCE_DIR}/src/ge/common) - add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_runtime) -endif() + if (ENABLE_D) + add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_runtime) + endif () +endif () set(CMAKE_CXX_FLAGS ${_ge_tmp_CMAKE_CXX_FLAGS}) diff --git a/cmake/mind_expression.cmake b/cmake/mind_expression.cmake index 87d082c81fd..0a18e56f38e 100644 --- a/cmake/mind_expression.cmake +++ b/cmake/mind_expression.cmake @@ -58,13 +58,22 @@ if (ENABLE_GE) include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include) include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include/external) include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include/external/graph) -elseif(ENABLE_D OR ENABLE_TESTCASES) + link_directories(${CMAKE_SOURCE_DIR}/third_party/ge/lib) +elseif(ENABLE_D OR ENABLE_ACL OR ENABLE_TESTCASES) include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc) include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/ops) include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/external) include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/external/graph) endif() +if (ENABLE_GE OR ENABLE_D OR ENABLE_ACL OR ENABLE_TESTCASES) + include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc) + include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/external) + include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/framework) + include_directories(${CMAKE_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc) + include_directories(${CMAKE_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain) +endif() + if (ENABLE_MINDDATA) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/icu4c.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/libtiff.cmake) diff --git a/cmake/options.cmake b/cmake/options.cmake index 286f4324606..3433c57a0d7 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -19,6 +19,7 @@ option(ENABLE_AKG "enable akg" OFF) option(ENABLE_DEBUGGER "enable debugger" OFF) option(ENABLE_IBVERBS "enable IBVERBS for parameter server" OFF) option(ENABLE_PYTHON "Enable python" ON) +option(ENABLE_ACL "enable acl" OFF) if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if (WIN32) diff --git a/cmake/package.cmake b/cmake/package.cmake index b3cacd8cb87..988f8089f22 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -58,6 +58,12 @@ install( COMPONENT mindspore ) +install( + TARGETS mindspore_shared_lib + LIBRARY DESTINATION ${INSTALL_LIB_DIR} + COMPONENT mindspore +) + install( TARGETS mindspore_gvar DESTINATION ${INSTALL_LIB_DIR} @@ -194,7 +200,7 @@ if (ENABLE_SERVING OR ENABLE_TESTCASES) endif () if (NOT ENABLE_GE) - if (ENABLE_D) + if (ENABLE_D OR ENABLE_ACL) if (DEFINED ENV{ASCEND_CUSTOM_PATH}) set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH}) else () @@ -203,19 +209,26 @@ if (NOT ENABLE_GE) set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common) install( - FILES - ${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so - ${CMAKE_BINARY_DIR}/graphengine/src/ge/common/libge_common.so - ${CMAKE_BINARY_DIR}/graphengine/src/ge/ge_runtime/libge_runtime.so - ${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so - DESTINATION ${INSTALL_LIB_DIR} - COMPONENT mindspore - ) - install( - TARGETS ms_profile + FILES ${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so DESTINATION ${INSTALL_LIB_DIR} COMPONENT mindspore ) + + if (ENABLE_D) + install( + TARGETS ms_profile + DESTINATION ${INSTALL_LIB_DIR} + COMPONENT mindspore + ) + install( + FILES + ${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so + ${CMAKE_BINARY_DIR}/graphengine/src/ge/common/libge_common.so + ${CMAKE_BINARY_DIR}/graphengine/src/ge/ge_runtime/libge_runtime.so + DESTINATION ${INSTALL_LIB_DIR} + COMPONENT mindspore + ) + endif () elseif (ENABLE_TESTCASES) install( FILES @@ -287,6 +300,13 @@ if (EXISTS ${CMAKE_SOURCE_DIR}/mindspore/dataset) ) endif () +## Public header files +install( + DIRECTORY ${CMAKE_SOURCE_DIR}/include + DESTINATION ${INSTALL_BASE_DIR} + COMPONENT mindspore +) + if (ENABLE_SERVING) install( TARGETS ms_serving @@ -308,8 +328,8 @@ if (ENABLE_SERVING) ) install( - FILES ${LIBEVENT_LIB_LIST} - DESTINATION ${INSTALL_LIB_DIR} - COMPONENT mindspore + FILES ${LIBEVENT_LIB_LIST} + DESTINATION ${INSTALL_LIB_DIR} + COMPONENT mindspore ) endif () diff --git a/graphengine b/graphengine index 423c0228e8c..42d217fb8ce 160000 --- a/graphengine +++ b/graphengine @@ -1 +1 @@ -Subproject commit 423c0228e8c421f2b095e40d14e9fb3b563f63aa +Subproject commit 42d217fb8cec74b1c73685b8abe94d5f1520e9fe diff --git a/include/api/cell.h b/include/api/cell.h new file mode 100644 index 00000000000..4b32256b298 --- /dev/null +++ b/include/api/cell.h @@ -0,0 +1,113 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_INCLUDE_API_CELL_H +#define MINDSPORE_INCLUDE_API_CELL_H +#include +#include +#include +#include +#include "include/api/status.h" +#include "include/api/types.h" + +namespace mindspore { +namespace api { +class InputAndOutput; +using Input = InputAndOutput; +using Output = InputAndOutput; + +class MS_API CellBase { + public: + CellBase() = default; + virtual ~CellBase() = default; + virtual std::vector Construct(const std::vector &inputs) { return {}; } + virtual std::shared_ptr Clone() const = 0; + std::vector operator()(const std::vector &inputs) const; +}; + +template +class MS_API Cell : public CellBase { + public: + virtual ~Cell() = default; + std::shared_ptr Clone() const override { + return std::make_shared(static_cast(*this)); + } +}; + +class MS_API ParameterCell final : public Cell { + public: + ParameterCell() = default; + ~ParameterCell() override = default; + + ParameterCell(const ParameterCell &); + ParameterCell &operator=(const ParameterCell &); + + ParameterCell(ParameterCell &&); + ParameterCell &operator=(ParameterCell &&); + + explicit ParameterCell(const Tensor &); + ParameterCell &operator=(const Tensor &); + + explicit ParameterCell(Tensor &&); + ParameterCell &operator=(Tensor &&); + + Tensor GetTensor() const { return tensor_; } + + private: + Tensor tensor_; +}; + +class MS_API OpCellBase : public CellBase { + public: + explicit OpCellBase(const std::string &name) : name_(name) {} + ~OpCellBase() override = default; + const std::string &GetOpType() const { return name_; } + + protected: + std::string name_; +}; + +template +class MS_API OpCell : public OpCellBase, public std::enable_shared_from_this { + public: + explicit OpCell(const std::string &name) : OpCellBase(name) {} + ~OpCell() override = default; + std::shared_ptr Clone() const override { + return std::make_shared(static_cast(*this)); + } +}; + +class MS_API InputAndOutput { + public: + InputAndOutput(); + ~InputAndOutput() = default; + + // no explicit + InputAndOutput(const Tensor &); // NOLINT(runtime/explicit) + InputAndOutput(Tensor &&); // NOLINT(runtime/explicit) + + InputAndOutput(const std::shared_ptr &, const std::vector &, int32_t index); + + int32_t GetIndex() const { return index_; } + void SetIndex(int32_t index) { index_ = index; } + + private: + std::shared_ptr cell_; + std::vector prev_; + int32_t index_; +}; +} // namespace api +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_API_CELL_H diff --git a/include/api/model.h b/include/api/model.h new file mode 100644 index 00000000000..e14b778b491 --- /dev/null +++ b/include/api/model.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_INCLUDE_API_MODEL_H +#define MINDSPORE_INCLUDE_API_MODEL_H + +#include +#include +#include +#include +#include "include/api/status.h" +#include "include/api/types.h" + +namespace mindspore { +namespace api { +class ModelImpl; +// todo: minddata c++ interface +class DataSet {}; +class NetWork {}; + +class MS_API Model { + public: + Model(const std::string &device_type, uint32_t device_id); + Model(NetWork network, const std::string &device_type, uint32_t device_id); + ~Model(); + Model(const Model &) = delete; + void operator=(const Model &) = delete; + + Status LoadModel(const Buffer &model_data, ModelType type, const std::map &options); + Status LoadModel(const std::string &file_name, ModelType type, const std::map &options); + Status UnloadModel(); + + Status Train(const DataSet &dataset, std::map *outputs); + Status Eval(const DataSet &dataset, std::map *outputs); + Status Predict(const std::map &inputs, std::map *outputs); + Status Predict(const std::vector &inputs, std::map *outputs); + + Status GetInputsInfo(std::vector *tensor_list) const; + Status GetOutputsInfo(std::vector *tensor_list) const; + + private: + std::shared_ptr impl_; +}; +} // namespace api +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_API_MODEL_H diff --git a/include/api/ops/ops.h b/include/api/ops/ops.h new file mode 100644 index 00000000000..0715bac8988 --- /dev/null +++ b/include/api/ops/ops.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_INCLUDE_API_OPS_OPS_H +#define MINDSPORE_INCLUDE_API_OPS_OPS_H + +#include +#include +#include +#include +#include "include/api/status.h" +#include "include/api/types.h" +#include "include/api/cell.h" + +namespace mindspore { +namespace api { +struct MS_API Conv2D : public OpCell { + Conv2D() : OpCell("Conv2D") {} + ~Conv2D() override = default; + std::vector Construct(const std::vector &inputs) override; + Conv2D(int out_channel, const std::vector &kernel_size, int mode = 1, const std::string &pad_mode = "valid", + const std::vector &pad = {0, 0, 0, 0}, const std::vector &stride = {1, 1, 1, 1}, + const std::vector &dilation = {1, 1, 1, 1}, int group = 1); + + Output operator()(const Input &, const Input &) const; + + int out_channel; + std::vector kernel_size; + int mode = 1; + std::string pad_mode = "valid"; + std::vector pad = {0, 0, 0, 0}; + std::vector stride = {1, 1, 1, 1}; + std::vector dilation = {1, 1, 1, 1}; + int group = 1; +}; +} // namespace api +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_API_OPS_OPS_H diff --git a/include/api/serialization.h b/include/api/serialization.h new file mode 100644 index 00000000000..4fcd08c56a0 --- /dev/null +++ b/include/api/serialization.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_INCLUDE_API_SERIALIZATION_H +#define MINDSPORE_INCLUDE_API_SERIALIZATION_H + +#include +#include +#include +#include +#include "include/api/status.h" +#include "include/api/types.h" +#include "include/api/model.h" + +namespace mindspore { +namespace api { +class MS_API Serialization { + public: + static Status LoadCheckPoint(const std::string &ckpt_file, std::map *parameters); + static Status SetParameters(const std::map ¶meters, Model *model); + static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data); + static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file); +}; +} // namespace api +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H diff --git a/include/api/status.h b/include/api/status.h new file mode 100644 index 00000000000..c8284fbaa5b --- /dev/null +++ b/include/api/status.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_INCLUDE_API_STATUS_H +#define MINDSPORE_INCLUDE_API_STATUS_H + +#include + +namespace mindspore { +namespace api { +enum StatusCode { + SUCCESS = 0, + FAILED, + INVALID_INPUTS, + // insert new status code here + UNKNOWN = 0xFFFFFFFF +}; + +class Status { + public: + Status() : status_code_(FAILED) {} + Status(enum StatusCode status_code, const std::string &status_msg = "") // NOLINT(runtime/explicit) + : status_code_(status_code), status_msg_(status_msg) {} + ~Status() = default; + + bool IsSuccess() const { return status_code_ == SUCCESS; } + enum StatusCode StatusCode() const { return status_code_; } + std::string StatusMessage() const { return status_msg_; } + bool operator==(const Status &other) const { return status_code_ == other.status_code_; } + bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; } + bool operator!=(const Status &other) const { return status_code_ != other.status_code_; } + bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; } + operator bool() const = delete; + + private: + enum StatusCode status_code_; + std::string status_msg_; +}; +} // namespace api +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_API_STATUS_H diff --git a/include/api/types.h b/include/api/types.h new file mode 100644 index 00000000000..194401cc976 --- /dev/null +++ b/include/api/types.h @@ -0,0 +1,119 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_INCLUDE_API_TYPES_H +#define MINDSPORE_INCLUDE_API_TYPES_H + +#include +#include +#include + +#define MS_API __attribute__((visibility("default"))) + +namespace mindspore { +namespace api { +enum ModelType { + kMindIR = 0, + kAIR = 1, + kOM = 2, + kONNX = 3, + // insert new data type here + kUnknownType = 0xFFFFFFFF +}; + +enum DataType { + kMsUnknown = 0, + kMsBool = 1, + kMsInt8 = 2, + kMsInt16 = 3, + kMsInt32 = 4, + kMsInt64 = 5, + kMsUint8 = 6, + kMsUint16 = 7, + kMsUint32 = 8, + kMsUint64 = 9, + kMsFloat16 = 10, + kMsFloat32 = 11, + kMsFloat64 = 12, + // insert new data type here + kInvalidDataType = 0xFFFFFFFF +}; + +class MS_API Tensor { + public: + Tensor(); + Tensor(const std::string &name, DataType type, const std::vector &shape, const void *data, size_t data_len); + ~Tensor(); + + const std::string &Name() const; + void SetName(const std::string &name); + + api::DataType DataType() const; + void SetDataType(api::DataType type); + + const std::vector &Shape() const; + void SetShape(const std::vector &shape); + + const void *Data() const; + void *MutableData(); + size_t DataSize() const; + + bool ResizeData(size_t data_len); + bool SetData(const void *data, size_t data_len); + + int64_t ElementNum() const; + static int GetTypeSize(api::DataType type); + Tensor Clone() const; + + private: + class Impl; + std::shared_ptr impl_; +}; + +class MS_API Buffer { + public: + Buffer(); + Buffer(const void *data, size_t data_len); + ~Buffer(); + + const void *Data() const; + void *MutableData(); + size_t DataSize() const; + + bool ResizeData(size_t data_len); + bool SetData(const void *data, size_t data_len); + + Buffer Clone() const; + + private: + class Impl; + std::shared_ptr impl_; +}; + +constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path"; +constexpr auto kModelOptionDvppCfgPath = "mindspore.option.dvpp_config_file_path"; +constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file +constexpr auto kModelOptionInputFormat = "mindspore.option.input_format"; // nchw or nhwc +// Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1" +constexpr auto kModelOptionInputShape = "mindspore.option.input_shape"; +constexpr auto kModelOptionDynamicBatchSize = "mindspore.option.dynamic_batch_size"; +constexpr auto kModelOptionDynamicImageSize = "mindspore.option.dynamic_image_size"; +constexpr auto kModelOptionDynamicDims = "mindspore.option.dynamic_dims"; +constexpr auto kModelOptionSerialInput = "mindspore.option.serial_inputs_name"; // separated by ';' +constexpr auto kModelOptionOutputNode = "mindspore.option.output_node"; // e.g. "node_name1:0;node_name2:1" +constexpr auto kModelOptionOutputType = "mindspore.option.output_type"; // "FP32", "UINT8" or "FP16", default as "FP32" +} // namespace api +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_API_TYPES_H diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index b60de02baed..756767c799d 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -295,6 +295,9 @@ else () target_link_libraries(mindspore ibverbs rdmacm) endif() endif() + if (ENABLE_ACL) + target_link_libraries(_c_expression PRIVATE graph) + endif () target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore proto_input -Wl,--no-whole-archive) target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module) target_link_libraries(_c_expression PRIVATE mindspore_gvar) @@ -359,3 +362,5 @@ if (CMAKE_SYSTEM_NAME MATCHES "Linux") elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") set_target_properties(inference PROPERTIES MACOSX_RPATH ON) endif () + +add_subdirectory(cxx_api) diff --git a/mindspore/ccsrc/cxx_api/CMakeLists.txt b/mindspore/ccsrc/cxx_api/CMakeLists.txt new file mode 100644 index 00000000000..23d6c4935b6 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/CMakeLists.txt @@ -0,0 +1,62 @@ +# build mindspore_shared_lib +set(LOAD_ONNX_SRC + ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/utils/load_onnx/anf_converter.cc + ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc + ) +file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc") + +if (ENABLE_ACL) + file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc") +endif () + +set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc + ${CMAKE_CURRENT_SOURCE_DIR}/cell.cc + ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc + ${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc + ${API_ACL_SRC} + ${API_OPS_SRC} + ${LOAD_ONNX_SRC}) + +add_library(mindspore_shared_lib SHARED ${MSLIB_SRC}) +set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}") + +target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY} + -Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf) + +if (ENABLE_CPU) + target_link_libraries(mindspore_shared_lib PRIVATE mindspore::dnnl mindspore::mkldnn) +endif () + +if (USE_GLOG) + target_link_libraries(mindspore_shared_lib PRIVATE mindspore::glog) +endif () + +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + target_link_options(mindspore_shared_lib PRIVATE -Wl,-init,common_log_init) +endif () + +if (ENABLE_ACL) + if (DEFINED ENV{ASCEND_CUSTOM_PATH}) + set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH}) + else () + set(ASCEND_PATH /usr/local/Ascend) + endif () + set(ACL_LIB_DIR ${ASCEND_PATH}/acllib/) + set(ATLAS_ACL_LIB_DIR ${ASCEND_PATH}/ascend-toolkit/latest/acllib) + set(ATC_DIR ${ASCEND_PATH}/atc/) + set(ATLAS_ATC_DIR ${ASCEND_PATH}/ascend-toolkit/latest/atc) + MESSAGE("acl lib dir " ${ACL_LIB_DIR} ", atc dir " ${ATC_DIR}) + MESSAGE("atlas acl lib dir " ${ATLAS_ACL_LIB_DIR} ", atc dir " ${ATLAS_ATC_DIR}) + + include_directories(${ACL_LIB_DIR}/include/) + include_directories(${ATLAS_ACL_LIB_DIR}/include/) + add_compile_definitions(ENABLE_DVPP_INTERFACE) + find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64) + find_library(acl_retr libacl_retr.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64) + find_library(acl_cblas libacl_cblas.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64) + find_library(acl_dvpp libacl_dvpp.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64) + find_library(acl_runtime libruntime.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64) + find_library(ge_compiler libge_compiler.so ${ATC_DIR}/lib64 ${ATLAS_ATC_DIR}/lib64) + target_link_libraries(mindspore_shared_lib PRIVATE ${acl} ${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime} + ${ge_compiler} mindspore::jpeg_turbo) +endif () diff --git a/mindspore/ccsrc/cxx_api/cell.cc b/mindspore/ccsrc/cxx_api/cell.cc new file mode 100644 index 00000000000..0b684ce5aa3 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/cell.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "include/api/cell.h" + +namespace mindspore::api { +std::vector CellBase::operator()(const std::vector &inputs) const { return Clone()->Construct(inputs); } + +ParameterCell::ParameterCell(const ParameterCell &cell) : tensor_(cell.tensor_.Clone()) {} +ParameterCell &ParameterCell::operator=(const ParameterCell &cell) { + if (&cell == this) { + return *this; + } + tensor_ = cell.tensor_.Clone(); + return *this; +} + +ParameterCell::ParameterCell(ParameterCell &&cell) : tensor_(cell.tensor_) {} + +ParameterCell &ParameterCell::operator=(ParameterCell &&cell) { + if (&cell == this) { + return *this; + } + tensor_ = cell.tensor_; + return *this; +} + +ParameterCell::ParameterCell(const Tensor &tensor) : tensor_(tensor.Clone()) {} + +ParameterCell &ParameterCell::operator=(const Tensor &tensor) { + tensor_ = tensor.Clone(); + return *this; +} + +ParameterCell::ParameterCell(Tensor &&tensor) : tensor_(tensor) {} + +ParameterCell &ParameterCell::operator=(Tensor &&tensor) { + tensor_ = tensor; + return *this; +} + +InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {} + +InputAndOutput::InputAndOutput(const Tensor &tensor) + : cell_(std::make_shared(tensor.Clone())), prev_(), index_(-1) {} +InputAndOutput::InputAndOutput(Tensor &&tensor) : cell_(std::make_shared(tensor)), prev_(), index_(-1) {} + +InputAndOutput::InputAndOutput(const std::shared_ptr &cell, const std::vector &prev, + int32_t index) + : cell_(cell), prev_(prev), index_(index) {} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc new file mode 100644 index 00000000000..10ecfc5b05e --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc @@ -0,0 +1,284 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cxx_api/model/acl/acl_model.h" +#include +#include "utils/context/context_extends.h" + +namespace mindspore::api { +std::weak_ptr AclModel::global_acl_env_; +std::mutex AclModel::global_acl_env_mutex_; + +Status AclModel::InitEnv() { + if (init_flag_) { + return SUCCESS; + } + + MS_EXCEPTION_IF_NULL(options_); + aclError ret; + { + std::lock_guard lock(global_acl_env_mutex_); + acl_env_ = global_acl_env_.lock(); + if (acl_env_ != nullptr) { + if (options_->dump_cfg_path.empty()) { + MS_LOG(INFO) << "Acl has been initialized, skip."; + } else { + MS_LOG(WARNING) << "Acl has been initialized, skip, so dump config will be ignored."; + } + } else { + acl_env_ = std::make_shared(options_->dump_cfg_path); + if (acl_env_->GetErrno() != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Execute aclInit Failed"; + return FAILED; + } + global_acl_env_ = acl_env_; + MS_LOG(INFO) << "Acl init success"; + } + } + + ret = aclrtSetDevice(device_id_); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed"; + return FAILED; + } + MS_LOG(INFO) << "Open device " << device_id_ << " success"; + + ret = aclrtCreateContext(&context_, device_id_); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Acl create context failed"; + return FAILED; + } + MS_LOG(INFO) << "Create context success"; + + ret = aclrtSetCurrentContext(context_); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Acl set current context failed"; + return FAILED; + } + MS_LOG(INFO) << "Set context success"; + + ret = aclrtCreateStream(&stream_); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Acl create stream failed"; + return FAILED; + } + MS_LOG(INFO) << "Create stream success"; + + aclrtRunMode run_mode; + ret = aclrtGetRunMode(&run_mode); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Acl get run mode failed"; + return FAILED; + } + bool is_device = (run_mode == ACL_DEVICE); + model_process_.SetIsDevice(is_device); + MS_LOG(INFO) << "Get run mode success is device input/output " << is_device; + + if (dvpp_process_.InitResource(stream_) != SUCCESS) { + MS_LOG(ERROR) << "DVPP init resource failed"; + return FAILED; + } + ModelConverter::RegAllOp(); + + MS_LOG(INFO) << "Init acl success, device id " << device_id_; + init_flag_ = true; + return SUCCESS; +} + +Status AclModel::FinalizeEnv() { + if (!init_flag_) { + return SUCCESS; + } + + dvpp_process_.Finalize(); + aclError ret; + if (stream_ != nullptr) { + ret = aclrtDestroyStream(stream_); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Destroy stream failed"; + } + stream_ = nullptr; + } + MS_LOG(INFO) << "End to destroy stream"; + if (context_ != nullptr) { + ret = aclrtDestroyContext(context_); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Destroy context failed"; + } + context_ = nullptr; + } + MS_LOG(INFO) << "End to destroy context"; + + ret = aclrtResetDevice(device_id_); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Reset devie " << device_id_ << " failed"; + } + MS_LOG(INFO) << "End to reset device " << device_id_; + + init_flag_ = false; + return SUCCESS; +} + +Status AclModel::LoadModel(const Buffer &model_data, ModelType type, + const std::map &options) { + if (load_flag_) { + MS_LOG(ERROR) << "Model has been loaded."; + return FAILED; + } + + options_ = std::make_unique(options); + MS_EXCEPTION_IF_NULL(options_); + + Status ret = InitEnv(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "InitEnv failed."; + return FAILED; + } + + Buffer om_data; + if (type == ModelType::kMindIR) { + model_converter_.set_options(options_.get()); + om_data = model_converter_.LoadMindIR(model_data); + } else if (type == ModelType::kAIR) { + model_converter_.set_options(options_.get()); + om_data = model_converter_.LoadAscendIR(model_data); + } else if (type == ModelType::kOM) { + om_data = model_data; + } else { + MS_LOG(ERROR) << "Unsupported model type " << type; + return FAILED; + } + + // acl load model + uint32_t acl_model_id; + auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed."; + return FAILED; + } + + // acl init model resource + model_process_.set_model_id(acl_model_id); + ret = model_process_.PreInitModelResource(); + if (ret != SUCCESS) { + (void)aclmdlUnload(acl_model_id); + MS_LOG(ERROR) << "Pre init model resource failed."; + return FAILED; + } + + // acl init dvpp + ret = dvpp_process_.InitWithJsonConfig(options_->dvpp_cfg_path); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "DVPP config file parse error."; + return FAILED; + } + + load_flag_ = true; + return SUCCESS; +} + +Status AclModel::LoadModel(const std::string &file_name, ModelType type, + const std::map &options) { + Buffer model_data = ModelConverter::ReadFile(file_name); + if (model_data.DataSize() == 0) { + MS_LOG(ERROR) << "Read file " << file_name << " failed."; + return FAILED; + } + + return LoadModel(model_data, type, options); +} + +Status AclModel::UnloadModel() { + if (!load_flag_) { + MS_LOG(WARNING) << "No model is loaded, skip unload."; + return SUCCESS; + } + + aclError rt_ret = aclrtSetCurrentContext(context_); + if (rt_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Set the ascend device context failed"; + return FAILED; + } + + Status ret = model_process_.UnLoad(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Unload model inner failed."; + return FAILED; + } + + ret = FinalizeEnv(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "FinalizeEnv failed."; + return FAILED; + } + + MS_LOG(INFO) << "Unload model success."; + load_flag_ = false; + return SUCCESS; +} + +Status AclModel::Train(const DataSet &, std::map *) { + MS_LOG(ERROR) << "Unsupported feature."; + return FAILED; +} + +Status AclModel::Eval(const DataSet &, std::map *) { + MS_LOG(ERROR) << "Unsupported feature."; + return FAILED; +} + +Status AclModel::Predict(const std::map &inputs, std::map *outputs) { + MS_EXCEPTION_IF_NULL(outputs); + if (!load_flag_) { + MS_LOG(ERROR) << "No model is loaded, predict failed."; + return FAILED; + } + + aclError rt_ret = aclrtSetCurrentContext(context_); + if (rt_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Set the ascend device context failed"; + return FAILED; + } + return model_process_.Predict(inputs, outputs); +} + +Status AclModel::GetInputsInfo(std::vector *tensor_list) const { + MS_EXCEPTION_IF_NULL(tensor_list); + return model_process_.GetInputsInfo(tensor_list); +} + +Status AclModel::GetOutputsInfo(std::vector *tensor_list) const { + MS_EXCEPTION_IF_NULL(tensor_list); + return model_process_.GetOutputsInfo(tensor_list); +} + +AclModel::AclEnvGuard::AclEnvGuard(const std::string &cfg_file) { + errno_ = aclInit(common::SafeCStr(cfg_file)); + if (errno_ != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Execute aclInit Failed"; + return; + } + MS_LOG(INFO) << "Acl init success"; +} + +AclModel::AclEnvGuard::~AclEnvGuard() { + errno_ = aclFinalize(); + if (errno_ != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Finalize acl failed"; + } + MS_LOG(INFO) << "Acl finalize success"; +} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model.h b/mindspore/ccsrc/cxx_api/model/acl/acl_model.h new file mode 100644 index 00000000000..6c7cbdf6b74 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model.h @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H +#define MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H + +#include +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "include/api/status.h" +#include "cxx_api/model/model_impl.h" +#include "cxx_api/model/acl/dvpp_process.h" +#include "cxx_api/model/acl/model_process.h" +#include "cxx_api/model/acl/model_converter.h" +#include "cxx_api/model/acl/acl_model_options.h" +#include "ir/tensor.h" + +namespace mindspore::api { +class AclModel : public ModelImpl { + public: + explicit AclModel(uint32_t device_id) + : init_flag_(false), + load_flag_(false), + device_type_("AscendCL"), + device_id_(device_id), + context_(nullptr), + stream_(nullptr), + acl_env_(nullptr), + model_process_(), + dvpp_process_(), + model_converter_(), + options_(nullptr) {} + ~AclModel() = default; + + Status LoadModel(const Buffer &model_data, ModelType type, + const std::map &options) override; + Status LoadModel(const std::string &file_name, ModelType type, + const std::map &options) override; + Status UnloadModel() override; + + Status Train(const DataSet &dataset, std::map *outputs) override; + Status Eval(const DataSet &dataset, std::map *outputs) override; + Status Predict(const std::map &inputs, std::map *outputs) override; + + Status GetInputsInfo(std::vector *tensor_list) const override; + Status GetOutputsInfo(std::vector *tensor_list) const override; + + private: + bool init_flag_; + bool load_flag_; + std::string device_type_; + int32_t device_id_; + aclrtContext context_; + aclrtStream stream_; + + class AclEnvGuard; + std::shared_ptr acl_env_; + static std::weak_ptr global_acl_env_; + static std::mutex global_acl_env_mutex_; + + ModelProcess model_process_; + DvppProcess dvpp_process_; + ModelConverter model_converter_; + std::unique_ptr options_; + + Status InitEnv(); + Status FinalizeEnv(); +}; + +class AclModel::AclEnvGuard { + public: + explicit AclEnvGuard(const std::string &cfg_file); + ~AclEnvGuard(); + aclError GetErrno() const { return errno_; } + + private: + aclError errno_; +}; + +API_REG_MODEL(AscendCL, AclModel); +} // namespace mindspore::api +#endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc new file mode 100644 index 00000000000..ac55e8ed7a0 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "cxx_api/model/acl/acl_model_options.h" +#include +#include "external/ge/ge_api_types.h" + +namespace mindspore::api { +static std::string ParseOption(const std::map &options, const std::string &key) { + auto iter = options.find(key); + if (iter != options.end()) { + return iter->second; + } + return ""; +} + +AclModelOptions::AclModelOptions(const std::map &options) { + dump_cfg_path = ParseOption(options, kModelOptionDumpCfgPath); + dvpp_cfg_path = ParseOption(options, kModelOptionDvppCfgPath); + output_node = ParseOption(options, kModelOptionOutputNode); + // to acl + insert_op_cfg_path = ParseOption(options, kModelOptionInsertOpCfgPath); + input_format = ParseOption(options, kModelOptionInputFormat); + input_shape = ParseOption(options, kModelOptionInputShape); + dynamic_batch_size = ParseOption(options, kModelOptionInputShape); + dynamic_image_size = ParseOption(options, kModelOptionInputShape); + dynamic_dims = ParseOption(options, kModelOptionInputShape); + serial_nodes_name = ParseOption(options, kModelOptionSerialInput); + output_type = ParseOption(options, kModelOptionOutputType); +} + +std::map AclModelOptions::GenAclOptions() const { + const std::map acl_options_map = { + {&insert_op_cfg_path, ge::ir_option::INSERT_OP_FILE}, + {&input_format, ge::ir_option::INPUT_FORMAT}, + {&input_shape, ge::ir_option::INPUT_SHAPE}, + {&dynamic_batch_size, ge::ir_option::DYNAMIC_BATCH_SIZE}, + {&dynamic_image_size, ge::ir_option::DYNAMIC_IMAGE_SIZE}, + {&dynamic_dims, ge::ir_option::DYNAMIC_DIMS}, + {&serial_nodes_name, ge::ir_option::INPUT_FP16_NODES}, + {&output_type, ge::ir_option::OUTPUT_TYPE}, + }; + + std::map acl_options; + for (auto [ms_option, acl_option_key] : acl_options_map) { + if (ms_option == nullptr || ms_option->empty()) { + continue; + } + acl_options.emplace(acl_option_key, *ms_option); + } + return acl_options; +} + +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.h b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.h new file mode 100644 index 00000000000..915c7492697 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_OPTION_PARSER_H +#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_OPTION_PARSER_H +#include +#include +#include +#include "include/api/types.h" +#include "include/api/status.h" + +namespace mindspore::api { +struct AclModelOptions { + std::string dump_cfg_path; + std::string dvpp_cfg_path; + std::string output_node; // todo: at convert.cc::BuildGraph(), no atc options + // build options + std::string insert_op_cfg_path; + std::string input_format; + std::string input_shape; + std::string dynamic_batch_size; + std::string dynamic_image_size; + std::string dynamic_dims; + std::string serial_nodes_name; + std::string output_type; + + explicit AclModelOptions(const std::map &options); + ~AclModelOptions() = default; + + std::map GenAclOptions() const; +}; +} // namespace mindspore::api + +#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_OPTION_PARSER_H diff --git a/mindspore/ccsrc/cxx_api/model/acl/dvpp_process.cc b/mindspore/ccsrc/cxx_api/model/acl/dvpp_process.cc new file mode 100644 index 00000000000..91ff286a31f --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/acl/dvpp_process.cc @@ -0,0 +1,1160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cxx_api/model/acl/dvpp_process.h" +#include +#include +#include +#include +#include +#include "utils/utils.h" +#include "include/api/types.h" +#include "mindspore/core/utils/ms_utils.h" + +namespace mindspore::api { +DvppProcess::DvppProcess() {} + +DvppProcess::~DvppProcess() {} + +static uint32_t ToEven(uint32_t num) { return (num + 1) / 2 * 2; } +static uint32_t ToOdd(uint32_t num) { + if (num == 0) { + return 1; + } + return (num + 1) / 2 * 2 - 1; +} + +class DvppJsonConfigParser { + public: + DvppJsonConfigParser() = default; + ~DvppJsonConfigParser() = default; + + Status InitWithJsonConfig(const std::string &json_config); + DvppDecodePara GetDecodePara() const { return decode_para_; } + DvppResizePara GetResizePara() const { return resize_para_; } + DvppCropPara GetCropPara() const { return crop_para_; } + DvppCropAndPastePara GetCropAndPastePara() const { return crop_and_paste_para_; } + bool HasResizeConfig() const { return resize_flag_; } + bool HasCropConfig() const { return crop_flag_; } + bool HasCropAndPasteConfig() const { return crop_and_paste_flag_; } + + private: + DvppDecodePara decode_para_; + DvppResizePara resize_para_; + DvppCropPara crop_para_; + DvppCropAndPastePara crop_and_paste_para_; + bool resize_flag_ = false; + bool crop_flag_ = false; + bool crop_and_paste_flag_ = false; + + Status GetStringValue(const nlohmann::json &json_item, const std::string &key, std::string *val); + Status GetIntValue(const nlohmann::json &json_item, const std::string &key, uint32_t *val); + Status ParseInputPara(const nlohmann::json &preprocess_item); + Status ParseDecodePara(const nlohmann::json &preprocess_item); + Status ParseResizePara(const nlohmann::json &json_item); + Status ParseCropPara(const nlohmann::json &json_item); + Status ParseCropAndPastePara(const nlohmann::json &json_item); + Status InitWithJsonConfigImp(const std::string &json_config); +}; + +Status DvppProcess::InitResource(aclrtStream stream) { + stream_ = stream; + aclError acl_ret; + dvpp_channel_desc_ = acldvppCreateChannelDesc(); + if (dvpp_channel_desc_ == nullptr) { + MS_LOG(ERROR) << "Call acldvppCreateChannelDesc failed"; + return FAILED; + } + acl_ret = acldvppCreateChannel(dvpp_channel_desc_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call acldvppCreateChannel failed, acl return " << acl_ret; + return FAILED; + } + MS_LOG(INFO) << "End init dvpp process resource"; + return SUCCESS; +} + +void DvppProcess::DestroyResource() { + if (dvpp_channel_desc_ != nullptr) { + auto acl_ret = acldvppDestroyChannel(dvpp_channel_desc_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call acldvppDestroyChannel failed, acl return " << acl_ret; + } + acl_ret = acldvppDestroyChannelDesc(dvpp_channel_desc_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call acldvppDestroyChannelDesc failed, acl return " << acl_ret; + } + dvpp_channel_desc_ = nullptr; + } +} + +void DvppProcess::Finalize() { + DestroyDecodeDesc(); + DestroyVpcOutputDesc(); + DestroyResource(); + if (resize_config_ != nullptr) { + acldvppDestroyResizeConfig(resize_config_); + resize_config_ = nullptr; + } + if (crop_area_ != nullptr) { + acldvppDestroyRoiConfig(crop_area_); + crop_area_ = nullptr; + } + if (paste_area_ != nullptr) { + acldvppDestroyRoiConfig(paste_area_); + paste_area_ = nullptr; + } + if (input_pic_dev_buffer_ != nullptr) { + acldvppFree(input_pic_dev_buffer_); + } + input_pic_buffer_size_ = 0; + MS_LOG(INFO) << "End dvpp process finalize"; +} + +Status DvppProcess::InitJpegDecodePara(const DvppDecodePara &decode_para) { + decode_para_ = decode_para; + MS_LOG(INFO) << "Init decode para, pixel_format " << decode_para_.pixel_format; + return SUCCESS; +} + +Status DvppProcess::InitResizePara(const DvppResizePara &resize_para) { + resize_para_ = resize_para; + MS_LOG(INFO) << "Init resize para, " + << "output_width " << resize_para_.output_width << ", output_height " << resize_para_.output_height; + to_resize_flag_ = true; + to_crop_flag_ = false; + to_crop_and_paste_flag_ = false; + Status ret = InitResizeOutputDesc(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "InitResizeOutputDesc failed"; + } + return ret; +} + +Status DvppProcess::InitCommonCropPara(uint32_t output_width, uint32_t output_height, DvppCropInfo *crop_info) { + MS_EXCEPTION_IF_NULL(crop_info); + if (crop_info->crop_type == kDvppCropTypeOffset) { + if (CheckAndAdjustRoiArea(&crop_info->crop_area) != SUCCESS) { + MS_LOG(ERROR) << "Check and adjust crop area failed"; + return FAILED; + } + MS_LOG(INFO) << "Init common crop para, crop type offset " + << ", left " << crop_info->crop_area.left << ", right " << crop_info->crop_area.right << ", top " + << crop_info->crop_area.top << ", bottom " << crop_info->crop_area.bottom << ", output_width " + << output_width << ", output_height " << output_height; + } else { + crop_info->crop_width = ToEven(crop_info->crop_width); + crop_info->crop_height = ToEven(crop_info->crop_height); + if (CheckRoiAreaWidthHeight(crop_info->crop_width, crop_info->crop_height) != SUCCESS) { + MS_LOG(ERROR) << "Check crop area width and height failed, actually width " << crop_info->crop_width << " height " + << crop_info->crop_height; + return FAILED; + } + MS_LOG(INFO) << "Init common crop para, crop type centre " + << ", crop_width " << crop_info->crop_width << ", crop_height " << crop_info->crop_height + << ", output_width " << output_width << ", output_height " << output_height; + } + return SUCCESS; +} + +Status DvppProcess::InitCropPara(const DvppCropPara &crop_para) { + crop_para_ = crop_para; + if (InitCommonCropPara(crop_para_.output_width, crop_para_.output_height, &crop_para_.crop_info) != SUCCESS) { + MS_LOG(ERROR) << "Init common crop para failed in InitCropPara"; + return FAILED; + } + to_crop_flag_ = true; + to_resize_flag_ = false; + to_crop_and_paste_flag_ = false; + Status ret = InitCropOutputDesc(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "InitCropOutputDesc failed"; + } + return ret; +} + +Status DvppProcess::InitCropAndPastePara(const DvppCropAndPastePara &crop_and_paste_para) { + crop_and_paste_para_ = crop_and_paste_para; + if (InitCommonCropPara(crop_and_paste_para_.output_width, crop_and_paste_para_.output_height, + &crop_and_paste_para_.crop_info) != SUCCESS) { + MS_LOG(ERROR) << "Init common crop para failed in InitCropAndPastePara"; + return FAILED; + } + auto &paste_area = crop_and_paste_para_.paste_area; + if (CheckAndAdjustRoiArea(&paste_area) != SUCCESS) { + MS_LOG(ERROR) << "Check and adjust paste area failed"; + return FAILED; + } + MS_LOG(INFO) << "Init crop and paste para, paste info: " + << ", left " << paste_area.left << ", right " << paste_area.right << ", top " << paste_area.top + << ", bottom " << paste_area.bottom; + + to_crop_and_paste_flag_ = true; + to_crop_flag_ = false; + to_resize_flag_ = false; + Status ret = InitCropAndPasteOutputDesc(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "InitCropAndPasteOutputDesc failed"; + } + return ret; +} + +Status DvppProcess::InputInputBuffer(const void *pic_buffer, size_t pic_buffer_size) { + aclError acl_ret; + if (pic_buffer_size != input_pic_buffer_size_) { + acldvppFree(input_pic_dev_buffer_); + input_pic_buffer_size_ = 0; + acl_ret = acldvppMalloc(&input_pic_dev_buffer_, pic_buffer_size); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call acldvppMalloc input picture buffer on device failed, buffer size " << pic_buffer_size; + return FAILED; + } + input_pic_buffer_size_ = pic_buffer_size; + } + acl_ret = + aclrtMemcpy(input_pic_dev_buffer_, input_pic_buffer_size_, pic_buffer, pic_buffer_size, ACL_MEMCPY_HOST_TO_DEVICE); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call aclrtMemcpy input picture buffer to device, buffer size " << pic_buffer_size; + return FAILED; + } + return SUCCESS; +} + +static void JpegErrorExitCustom(j_common_ptr cinfo) { + char jpeg_last_error_msg[JMSG_LENGTH_MAX] = {0}; + if (cinfo != nullptr && cinfo->err != nullptr && cinfo->err->format_message != nullptr) { + (*(cinfo->err->format_message))(cinfo, jpeg_last_error_msg); + } + throw std::runtime_error(jpeg_last_error_msg); +} + +Status DvppProcess::GetJpegWidthHeight(const void *pic_buffer, size_t pic_buffer_size, uint32_t *image_width, + uint32_t *image_height) { + MS_EXCEPTION_IF_NULL(image_width); + MS_EXCEPTION_IF_NULL(image_height); + struct jpeg_decompress_struct jpeg_header; + struct jpeg_error_mgr jpeg_error; + jpeg_header.err = jpeg_std_error(&jpeg_error); + jpeg_error.error_exit = JpegErrorExitCustom; + try { + jpeg_create_decompress(&jpeg_header); + jpeg_mem_src(&jpeg_header, reinterpret_cast(pic_buffer), pic_buffer_size); + (void)jpeg_read_header(&jpeg_header, TRUE); + } catch (std::runtime_error &e) { + jpeg_destroy_decompress(&jpeg_header); + MS_LOG(ERROR) << "JPEG images read failed, " << e.what(); + return INVALID_INPUTS; + } + *image_width = jpeg_header.image_width; + *image_height = jpeg_header.image_height; + + if (jpeg_header.jpeg_color_space != JCS_YCbCr) { + MS_LOG(ERROR) << "Expect color space YUV(YCbCr), current " << jpeg_header.jpeg_color_space; + jpeg_destroy_decompress(&jpeg_header); + return INVALID_INPUTS; + } + if (jpeg_header.dc_huff_tbl_ptrs[0] == nullptr) { + MS_LOG(ERROR) << "Only support Huffman code"; + jpeg_destroy_decompress(&jpeg_header); + return INVALID_INPUTS; + } + jpeg_destroy_decompress(&jpeg_header); + + const uint32_t min_width = 32; + const uint32_t max_width = 8192; + const uint32_t min_height = 32; + const uint32_t max_height = 8192; + if (*image_width < min_width || *image_width > max_width) { + MS_LOG(ERROR) << "Expect image width [" << min_width << ", " << max_width << "], the real image width is " + << *image_width; + return INVALID_INPUTS; + } + if (*image_height < min_height || *image_height > max_height) { + MS_LOG(ERROR) << "Expect image height [" << min_height << ", " << max_height << "], the real image height is " + << *image_height; + return INVALID_INPUTS; + } + return SUCCESS; +} + +Status DvppProcess::Process(const void *pic_buffer, size_t pic_buffer_size, void **output_device_buffer, + size_t *output_size) { + MS_EXCEPTION_IF_NULL(output_device_buffer); + MS_EXCEPTION_IF_NULL(output_size); + if (dvpp_channel_desc_ == nullptr) { + MS_LOG(ERROR) << "Process failed, dvpp not inited"; + return FAILED; + } + uint32_t image_width = 0; + uint32_t image_height = 0; + Status ret = GetJpegWidthHeight(pic_buffer, pic_buffer_size, &image_width, &image_height); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Get jpeg image height and width failed"; + return ret; + } + MS_LOG(INFO) << "Get jpeg width " << image_width << ", height " << image_height; + ret = InitDecodeOutputDesc(image_width, image_height); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "InitDecodeOutputDesc failed"; + return FAILED; + } + ret = UpdateCropArea(image_width, image_height); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Update crop area failed"; + return ret; + } + ret = CheckResizeImageInfo(image_width, image_height); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Check resize para failed"; + return ret; + } + if (InputInputBuffer(pic_buffer, pic_buffer_size) != SUCCESS) { + MS_LOG(ERROR) << "InputInputBuffer failed"; + return FAILED; + } + if (ProcessDecode() != SUCCESS) { + MS_LOG(ERROR) << "Process Decode failed"; + return INVALID_INPUTS; + } + MS_LOG(INFO) << "Process Decode success"; + if (to_resize_flag_) { + if (ProcessResize() != SUCCESS) { + MS_LOG(ERROR) << "Process Resize failed"; + return INVALID_INPUTS; + } + MS_LOG(INFO) << "Process Resize success"; + } else if (to_crop_flag_) { + if (ProcessCrop() != SUCCESS) { + MS_LOG(ERROR) << "Process Crop failed"; + return INVALID_INPUTS; + } + MS_LOG(INFO) << "Process Crop success"; + } else if (to_crop_and_paste_flag_) { + if (ProcessCropAndPaste() != SUCCESS) { + MS_LOG(ERROR) << "Process Crop And Paste failed"; + return INVALID_INPUTS; + } + MS_LOG(INFO) << "Process Crop And Paste success"; + } + if (vpc_output_buffer_dev_ == nullptr) { + *output_device_buffer = decode_output_buffer_dev_; + *output_size = decode_output_buffer_size_; + } else { + *output_device_buffer = vpc_output_buffer_dev_; + *output_size = vpc_output_buffer_size_; + } + MS_LOG(INFO) << "Process dvpp success"; + return SUCCESS; +} + +Status DvppProcess::Process(const std::vector &pic_buffer_list, + const std::vector &pic_buffer_size_list, void **output_device_buffer, + size_t *output_size) { + MS_EXCEPTION_IF_NULL(output_device_buffer); + MS_EXCEPTION_IF_NULL(output_size); + auto batch_size = pic_buffer_list.size(); + if (batch_size == 0 || batch_size != pic_buffer_size_list.size()) { + MS_LOG(ERROR) << "Invalid batch size " << batch_size << ", pic size count" << pic_buffer_size_list.size(); + return FAILED; + } + MS_LOG(INFO) << "Begin dvpp process, batch size " << batch_size; + if (batch_size == 1) { + return Process(pic_buffer_list[0], pic_buffer_size_list[0], output_device_buffer, output_size); + } + size_t total_buffer_size = vpc_output_buffer_size_ * batch_size; + if (batch_size_ != batch_size) { + if (batch_vpc_output_buffer_dev_ != nullptr) { + acldvppFree(batch_vpc_output_buffer_dev_); + batch_vpc_output_buffer_dev_ = nullptr; + } + batch_size_ = batch_size; + auto acl_rt = acldvppMalloc(&batch_vpc_output_buffer_dev_, total_buffer_size); + if (acl_rt != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call acldvppMalloc failed, buffer size " << total_buffer_size; + return FAILED; + } + } + for (size_t i = 0; i < batch_size; i++) { + const void *pic_buffer = pic_buffer_list[i]; + uint32_t pic_size = pic_buffer_size_list[i]; + if (pic_buffer == nullptr || pic_size == 0) { + MS_LOG(ERROR) << "Get " << 0 << "th images failed"; + return FAILED; + } + void *output_dev_buffer_tmp = nullptr; + size_t output_buffer_size_tmp = 0; + Status ret = Process(pic_buffer, pic_size, &output_dev_buffer_tmp, &output_buffer_size_tmp); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "DVPP process failed"; + return ret; + } + aclrtMemcpy(static_cast(batch_vpc_output_buffer_dev_) + vpc_output_buffer_size_ * i, + total_buffer_size - vpc_output_buffer_size_ * i, output_dev_buffer_tmp, vpc_output_buffer_size_, + ACL_MEMCPY_DEVICE_TO_DEVICE); + + MS_LOG(INFO) << "DVPP process " << i << " th images success, input pic size " << pic_size << " output buffer size " + << output_buffer_size_tmp; + } + *output_device_buffer = batch_vpc_output_buffer_dev_; + *output_size = total_buffer_size; + MS_LOG(INFO) << "End DVPP process, batch size " << batch_size << ", output size " << output_size; + return SUCCESS; +} + +uint32_t DvppProcess::AlignmentHelper(uint32_t org_size, uint32_t alignment) const { + if (alignment == 0) { + return 0; + } + return (org_size + alignment - 1) / alignment * alignment; +} + +uint32_t DvppProcess::GetImageBufferSize(uint32_t stride_width, uint32_t stride_height, + acldvppPixelFormat pixel_format) const { + if (stride_height == 0 || stride_width == 0) { + MS_LOG(ERROR) << "Invalid stride height or width, stride_width " << stride_width << " stride_height " + << stride_height; + return 0; + } + if (UINT32_MAX / 3 < stride_height || UINT32_MAX / (3 * stride_height) < stride_width) { + MS_LOG(ERROR) << "Invalid stride height or width, stride_width " << stride_width << " stride_height " + << stride_height; + return 0; + } + if (pixel_format == PIXEL_FORMAT_YUV_SEMIPLANAR_420 || pixel_format == PIXEL_FORMAT_YVU_SEMIPLANAR_420) { + return stride_width * stride_height * 3 / 2; // 420 + } else if (pixel_format == PIXEL_FORMAT_YUV_SEMIPLANAR_422 || pixel_format == PIXEL_FORMAT_YVU_SEMIPLANAR_422) { + return stride_width * stride_height * 2; // 422 + } else if (pixel_format == PIXEL_FORMAT_YUV_SEMIPLANAR_444 || pixel_format == PIXEL_FORMAT_YVU_SEMIPLANAR_444) { + return stride_width * stride_height * 3; // 444 + } + MS_LOG(ERROR) << "Not support pixel format " << pixel_format; + return 0; +} + +Status DvppProcess::GetPicDescStride(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height) { + MS_EXCEPTION_IF_NULL(stride_width); + MS_EXCEPTION_IF_NULL(stride_height); + const uint32_t width_alignment = 16; + const uint32_t height_alignment = 2; + const uint32_t stride_width_minimum = 32; + const uint32_t stride_width_maximum = 4096; + const uint32_t stride_height_minimum = 6; + const uint32_t stride_height_maximum = 4096; + + *stride_width = AlignmentHelper(width, width_alignment); + *stride_height = AlignmentHelper(height, height_alignment); + if (*stride_width == 0 || *stride_height == 0) { + MS_LOG(ERROR) << "Init VPC output desc failed, get stride width or height failed"; + return FAILED; + } + if (*stride_width < stride_width_minimum || *stride_width > stride_width_maximum) { + MS_LOG(ERROR) << "Expect stride width [" << stride_width_minimum << ", " << stride_width_maximum + << "], current stride width " << stride_width << " given width " << width; + return FAILED; + } + if (*stride_height < stride_height_minimum || *stride_height > stride_height_maximum) { + MS_LOG(ERROR) << "Expect stride height [" << stride_height_minimum << ", " << stride_height_maximum + << "], current stride height " << *stride_height << " given height " << height; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::GetPicDescStrideDecode(uint32_t width, uint32_t height, uint32_t *stride_width, + uint32_t *stride_height) { + MS_EXCEPTION_IF_NULL(stride_width); + MS_EXCEPTION_IF_NULL(stride_height); + const uint32_t width_alignment = 128; + const uint32_t height_alignment = 16; + const uint32_t width_minimum = 32; + const uint32_t width_maximum = 4096; // decode support 8192, dvpp(resize/crop/crop&paste) support 4096 + const uint32_t height_minimum = 32; + const uint32_t height_maximum = 4096; // decode support 8192, dvpp(resize/crop/crop&paste) support 4096 + if (width < width_minimum || width > width_maximum) { + MS_LOG(ERROR) << "Expect width [" << width_minimum << ", " << width_maximum << "], current width " << width; + return INVALID_INPUTS; + } + if (height < height_minimum || height > height_maximum) { + MS_LOG(ERROR) << "Expect height [" << height_minimum << ", " << height_maximum << "], current height " << height; + return INVALID_INPUTS; + } + *stride_width = AlignmentHelper(width, width_alignment); + *stride_height = AlignmentHelper(height, height_alignment); + if (*stride_width == 0 || *stride_height == 0) { + MS_LOG(ERROR) << "Init decode output desc failed, get stride width or height failed"; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::InitVpcOutputDesc(uint32_t output_width, uint32_t output_height, acldvppPixelFormat pixel_format) { + DestroyVpcOutputDesc(); + uint32_t vpc_stride_width = 0; + uint32_t vpc_stride_height = 0; + if (GetPicDescStride(output_width, output_height, &vpc_stride_width, &vpc_stride_height) != SUCCESS) { + MS_LOG(ERROR) << "Init VPC output desc failed, get VPC output stride width/height failed"; + return FAILED; + } + vpc_output_buffer_size_ = GetImageBufferSize(vpc_stride_width, vpc_stride_height, pixel_format); + if (vpc_output_buffer_size_ == 0) { + MS_LOG(ERROR) << "Init VPC output desc failed, get image buffer size failed"; + return FAILED; + } + auto acl_ret = acldvppMalloc(&vpc_output_buffer_dev_, vpc_output_buffer_size_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Init VPC output desc failed, malloc dvpp memory failed"; + return FAILED; + } + vpc_output_desc_ = acldvppCreatePicDesc(); + if (vpc_output_desc_ == nullptr) { + MS_LOG(ERROR) << "Init VPC output desc failed, create pic desc failed"; + return FAILED; + } + acldvppSetPicDescData(vpc_output_desc_, vpc_output_buffer_dev_); + acldvppSetPicDescSize(vpc_output_desc_, vpc_output_buffer_size_); + acldvppSetPicDescFormat(vpc_output_desc_, pixel_format); + acldvppSetPicDescWidth(vpc_output_desc_, output_width); + acldvppSetPicDescHeight(vpc_output_desc_, output_height); + acldvppSetPicDescWidthStride(vpc_output_desc_, vpc_stride_width); + acldvppSetPicDescHeightStride(vpc_output_desc_, vpc_stride_height); + MS_LOG(INFO) << "Init VPC output desc success"; + return SUCCESS; +} + +void DvppProcess::DestroyVpcOutputDesc() { + if (vpc_output_desc_ != nullptr) { + acldvppDestroyPicDesc(vpc_output_desc_); + vpc_output_desc_ = nullptr; + } + if (vpc_output_buffer_dev_ != nullptr) { + acldvppFree(vpc_output_buffer_dev_); + vpc_output_buffer_dev_ = nullptr; + } + if (batch_vpc_output_buffer_dev_ != nullptr) { + acldvppFree(batch_vpc_output_buffer_dev_); + batch_vpc_output_buffer_dev_ = nullptr; + } + vpc_output_buffer_size_ = 0; + MS_LOG(INFO) << "End destroy vpc desc"; +} + +Status DvppProcess::InitDecodeOutputDesc(uint32_t image_width, uint32_t image_height) { + if (decode_output_buffer_dev_ != nullptr && image_width == pic_width_ && image_height == pic_height_) { + return SUCCESS; + } + DestroyDecodeDesc(); + + pic_width_ = image_width; + pic_height_ = image_height; + + uint32_t stride_width = 0; + uint32_t stride_height = 0; + Status ret = GetPicDescStrideDecode(pic_width_, pic_height_, &stride_width, &stride_height); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Init VPC output desc failed, get VPC output stride width/height failed"; + return ret; + } + + decode_output_buffer_size_ = GetImageBufferSize(stride_width, stride_height, decode_para_.pixel_format); + if (decode_output_buffer_size_ == 0) { + MS_LOG(ERROR) << "Init decode output desc failed, get image buffer size failed"; + return FAILED; + } + auto acl_ret = acldvppMalloc(&decode_output_buffer_dev_, decode_output_buffer_size_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Init decode output desc failed, malloc dvpp memory failed"; + return FAILED; + } + decode_output_desc_ = acldvppCreatePicDesc(); + if (decode_output_desc_ == nullptr) { + MS_LOG(ERROR) << "Init decode output desc failed, create pic desc failed"; + return FAILED; + } + acldvppSetPicDescData(decode_output_desc_, decode_output_buffer_dev_); + acldvppSetPicDescSize(decode_output_desc_, decode_output_buffer_size_); + acldvppSetPicDescFormat(decode_output_desc_, decode_para_.pixel_format); + acldvppSetPicDescWidth(decode_output_desc_, pic_width_); + acldvppSetPicDescHeight(decode_output_desc_, pic_height_); + acldvppSetPicDescWidthStride(decode_output_desc_, stride_width); + acldvppSetPicDescHeightStride(decode_output_desc_, stride_height); + MS_LOG(INFO) << "Init decode output desc success"; + return SUCCESS; +} + +Status DvppProcess::CheckRoiAreaWidthHeight(uint32_t width, uint32_t height) { + const uint32_t min_crop_width = 10; + const uint32_t max_crop_width = 4096; + const uint32_t min_crop_height = 6; + const uint32_t max_crop_height = 4096; + + if (width < min_crop_width || width > max_crop_width) { + MS_LOG(ERROR) << "Expect roi area width in [" << min_crop_width << ", " << max_crop_width << "], actually " + << width; + return FAILED; + } + if (height < min_crop_height || height > max_crop_height) { + MS_LOG(ERROR) << "Expect roi area height in [" << min_crop_height << ", " << max_crop_height << "], actually " + << height; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::CheckAndAdjustRoiArea(DvppRoiArea *area) { + MS_EXCEPTION_IF_NULL(area); + if (area->right < area->left) { + MS_LOG(ERROR) << "Check roi area failed, left " << area->left << ", right " << area->right; + return FAILED; + } + if (area->bottom < area->top) { + MS_LOG(ERROR) << "Check roi area failed, top " << area->top << ", bottom " << area->bottom; + return FAILED; + } + + area->left = ToEven(area->left); + area->top = ToEven(area->top); + area->right = ToOdd(area->right); + area->bottom = ToOdd(area->bottom); + + auto width = area->right - area->left + 1; + auto height = area->bottom - area->top + 1; + if (CheckRoiAreaWidthHeight(width, height) != SUCCESS) { + MS_LOG(ERROR) << "Check roi area width and height failed," + << " actually width " << width << " left " << area->left << ", right " << area->right + << " actually height " << height << " top " << area->top << ", bottom " << area->bottom; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::UpdateCropArea(uint32_t image_width, uint32_t image_height) { + DvppCropInfo *crop_info = nullptr; + if (to_crop_flag_) { + crop_info = &crop_para_.crop_info; + } else if (to_crop_and_paste_flag_) { + crop_info = &crop_and_paste_para_.crop_info; + } else { + return SUCCESS; + } + if (crop_info->crop_type != kDvppCropTypeCentre) { + return SUCCESS; + } + if (image_width < crop_info->crop_width) { + MS_LOG(ERROR) << "Image width " << image_width << "smaller than crop width " << crop_info->crop_width; + return INVALID_INPUTS; + } + if (image_height < crop_info->crop_height) { + MS_LOG(ERROR) << "Image height " << image_height << "smaller than crop height " << crop_info->crop_height; + return INVALID_INPUTS; + } + uint32_t left = ToEven((image_width - crop_info->crop_width) / 2); + uint32_t top = ToEven((image_height - crop_info->crop_height) / 2); + uint32_t right = ToOdd(left + crop_info->crop_width); + uint32_t bottom = ToOdd(top + crop_info->crop_height); + + auto acl_ret = acldvppSetRoiConfig(crop_area_, left, right, top, bottom); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Update Crop Area failed"; + return FAILED; + } + MS_LOG(INFO) << "Update crop area, crop type centre, crop info: " + << ", left " << left << ", right " << right << ", top " << top << ", bottom " << bottom; + return SUCCESS; +} + +Status DvppProcess::CheckResizeImageInfo(uint32_t image_width, uint32_t image_height) const { + if (!to_resize_flag_) { + return SUCCESS; + } + // resize ratio required [1/32, 16] + auto check_resize_ratio = [](uint32_t before_resize, uint32_t after_resize) { + if (before_resize == 0 || after_resize == 0) { + return false; + } + if (before_resize / after_resize > 32) { + return false; + } + if (after_resize / before_resize > 16) { + return false; + } + return true; + }; + if (!check_resize_ratio(image_width, resize_para_.output_width)) { + MS_LOG(ERROR) << "Resize ratio required [1/32, 16], current width resize from " << image_width << " to " + << resize_para_.output_width; + return INVALID_INPUTS; + } + if (!check_resize_ratio(image_height, resize_para_.output_height)) { + MS_LOG(ERROR) << "Resize ratio required [1/32, 16], current height resize from " << image_height << " to " + << resize_para_.output_height; + return INVALID_INPUTS; + } + return SUCCESS; +} + +void DvppProcess::DestroyDecodeDesc() { + if (decode_output_desc_ != nullptr) { + acldvppDestroyPicDesc(decode_output_desc_); + decode_output_desc_ = nullptr; + } + if (decode_output_buffer_dev_ != nullptr) { + acldvppFree(decode_output_buffer_dev_); + decode_output_buffer_dev_ = nullptr; + } + decode_output_buffer_size_ = 0; + MS_LOG(INFO) << "End destroy decode desc"; +} + +Status DvppProcess::InitResizeOutputDesc() { + if (InitVpcOutputDesc(resize_para_.output_width, resize_para_.output_height, decode_para_.pixel_format) != SUCCESS) { + MS_LOG(ERROR) << "Init VPC output desc failed"; + return FAILED; + } + if (resize_config_ == nullptr) { + resize_config_ = acldvppCreateResizeConfig(); + if (resize_config_ == nullptr) { + MS_LOG(ERROR) << "Create Resize config failed"; + return FAILED; + } + } + return SUCCESS; +} + +Status DvppProcess::InitRoiAreaConfig(const DvppRoiArea &init_para, acldvppRoiConfig **roi_area) { + MS_EXCEPTION_IF_NULL(roi_area); + if (*roi_area == nullptr) { + *roi_area = acldvppCreateRoiConfig(init_para.left, init_para.right, init_para.top, init_para.bottom); + if (*roi_area == nullptr) { + MS_LOG(ERROR) << "Create Roi config failed"; + return FAILED; + } + } else { + auto acl_ret = acldvppSetRoiConfig(*roi_area, init_para.left, init_para.right, init_para.top, init_para.bottom); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Set Roi config failed"; + return FAILED; + } + } + return SUCCESS; +} + +Status DvppProcess::InitCropOutputDesc() { + if (InitVpcOutputDesc(crop_para_.output_width, crop_para_.output_height, decode_para_.pixel_format) != SUCCESS) { + MS_LOG(ERROR) << "Init VPC output desc failed"; + return FAILED; + } + if (InitRoiAreaConfig(crop_para_.crop_info.crop_area, &crop_area_) != SUCCESS) { + MS_LOG(ERROR) << "Init crop area failed"; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::InitCropAndPasteOutputDesc() { + if (InitVpcOutputDesc(crop_and_paste_para_.output_width, crop_and_paste_para_.output_height, + decode_para_.pixel_format) != SUCCESS) { + MS_LOG(ERROR) << "Init VPC output desc failed"; + return FAILED; + } + if (InitRoiAreaConfig(crop_and_paste_para_.crop_info.crop_area, &crop_area_) != SUCCESS) { + MS_LOG(ERROR) << "Init crop area failed"; + return FAILED; + } + if (InitRoiAreaConfig(crop_and_paste_para_.paste_area, &paste_area_) != SUCCESS) { + MS_LOG(ERROR) << "Init paste area failed"; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::ProcessDecode() { + aclError acl_ret; + acl_ret = acldvppJpegDecodeAsync(dvpp_channel_desc_, input_pic_dev_buffer_, input_pic_buffer_size_, + decode_output_desc_, stream_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call acldvppJpegDecodeAsync failed, acl return " << acl_ret; + return FAILED; + } + acl_ret = aclrtSynchronizeStream(stream_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call aclrtSynchronizeStream failed, acl return " << acl_ret; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::ProcessResize() { + aclError acl_ret; + acl_ret = acldvppVpcResizeAsync(dvpp_channel_desc_, decode_output_desc_, vpc_output_desc_, resize_config_, stream_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call acldvppVpcResizeAsync failed, acl return " << acl_ret; + return FAILED; + } + acl_ret = aclrtSynchronizeStream(stream_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call aclrtSynchronizeStream failed, acl return " << acl_ret; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::ProcessCrop() { + aclError acl_ret; + acl_ret = acldvppVpcCropAsync(dvpp_channel_desc_, decode_output_desc_, vpc_output_desc_, crop_area_, stream_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call acldvppVpcCropAsync failed, acl return " << acl_ret; + return FAILED; + } + acl_ret = aclrtSynchronizeStream(stream_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call aclrtSynchronizeStream failed, acl return " << acl_ret; + return FAILED; + } + return SUCCESS; +} + +Status DvppProcess::ProcessCropAndPaste() { + aclError acl_ret; + acl_ret = acldvppVpcCropAndPasteAsync(dvpp_channel_desc_, decode_output_desc_, vpc_output_desc_, crop_area_, + paste_area_, stream_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call acldvppVpcCropAndPasteAsync failed, acl return " << acl_ret; + return FAILED; + } + acl_ret = aclrtSynchronizeStream(stream_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Call aclrtSynchronizeStream failed, acl return " << acl_ret; + return FAILED; + } + return SUCCESS; +} + +Status DvppJsonConfigParser::GetStringValue(const nlohmann::json &json_item, const std::string &key, std::string *val) { + MS_EXCEPTION_IF_NULL(val); + auto it = json_item.find(key); + if (it == json_item.end()) { + MS_LOG(ERROR) << "Get string item " << key << " failed"; + return FAILED; + } + if (!it->is_string()) { + MS_LOG(ERROR) << "Item " << key << " value is not string type"; + return FAILED; + } + *val = it->get(); + return SUCCESS; +} + +Status DvppJsonConfigParser::GetIntValue(const nlohmann::json &json_item, const std::string &key, uint32_t *val) { + MS_EXCEPTION_IF_NULL(val); + auto it = json_item.find(key); + if (it == json_item.end()) { + MS_LOG(ERROR) << "Get string item " << key << " failed"; + return FAILED; + } + if (!it->is_number_integer()) { + MS_LOG(ERROR) << "Item " << key << " value is not integer type"; + return FAILED; + } + *val = it->get(); + return SUCCESS; +} + +Status DvppJsonConfigParser::ParseInputPara(const nlohmann::json &preprocess_item) { + auto input = preprocess_item.find("input"); + if (input == preprocess_item.end()) { + MS_LOG(ERROR) << "Get input failed"; + return FAILED; + } + if (!input->is_object()) { + MS_LOG(ERROR) << "Input is not object"; + return FAILED; + } + return SUCCESS; +} + +Status DvppJsonConfigParser::ParseDecodePara(const nlohmann::json &preprocess_item) { + auto decode_para = preprocess_item.find("decode_para"); + if (decode_para == preprocess_item.end()) { + MS_LOG(ERROR) << "Get input failed"; + return FAILED; + } + if (!decode_para->is_object()) { + MS_LOG(ERROR) << "Input is not object"; + return FAILED; + } + const std::unordered_map pixel_format_map = { + {"YUV420SP", PIXEL_FORMAT_YUV_SEMIPLANAR_420}, {"YVU420SP", PIXEL_FORMAT_YVU_SEMIPLANAR_420}, + {"YUV422SP", PIXEL_FORMAT_YUV_SEMIPLANAR_422}, {"YVU422SP", PIXEL_FORMAT_YVU_SEMIPLANAR_422}, + {"YUV444SP", PIXEL_FORMAT_YUV_SEMIPLANAR_444}, {"YVU444SP", PIXEL_FORMAT_YVU_SEMIPLANAR_444}, + }; + std::string pixel_format; + if (GetStringValue(*decode_para, "out_pixel_format", &pixel_format) != SUCCESS) { + MS_LOG(ERROR) << "Get op out_pixel_format failed"; + return FAILED; + } + auto format = pixel_format_map.find(pixel_format); + if (format == pixel_format_map.end()) { + MS_LOG(ERROR) << "Unsupported out_pixel_format " << pixel_format; + return FAILED; + } + decode_para_.pixel_format = format->second; + return SUCCESS; +} + +Status DvppJsonConfigParser::ParseResizePara(const nlohmann::json &json_item) { + if (GetIntValue(json_item, "out_width", &resize_para_.output_width) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "out_height", &resize_para_.output_height) != SUCCESS) { + return FAILED; + } + resize_flag_ = true; + return SUCCESS; +} + +Status DvppJsonConfigParser::ParseCropPara(const nlohmann::json &json_item) { + if (GetIntValue(json_item, "out_width", &crop_para_.output_width) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "out_height", &crop_para_.output_height) != SUCCESS) { + return FAILED; + } + auto &crop_info = crop_para_.crop_info; + std::string crop_type = "crop_type"; + if (GetStringValue(json_item, "crop_type", &crop_type) != SUCCESS) { + return FAILED; + } + if (crop_type == "offset") { + MS_LOG(INFO) << "Crop type is 'offset'"; + crop_info.crop_type = kDvppCropTypeOffset; + auto &crop_area = crop_info.crop_area; + if (GetIntValue(json_item, "crop_left", &crop_area.left) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_top", &crop_area.top) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_right", &crop_area.right) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_bottom", &crop_area.bottom) != SUCCESS) { + return FAILED; + } + } else if (crop_type == "centre") { + MS_LOG(INFO) << "Crop type is 'centre'"; + if (GetIntValue(json_item, "crop_width", &crop_info.crop_width) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_height", &crop_info.crop_height) != SUCCESS) { + return FAILED; + } + crop_info.crop_type = kDvppCropTypeCentre; + } else { + MS_LOG(ERROR) << "Invalid crop type " << crop_type << ", expect offset or centre"; + return FAILED; + } + crop_flag_ = true; + return SUCCESS; +} + +Status DvppJsonConfigParser::ParseCropAndPastePara(const nlohmann::json &json_item) { + // crop info + if (GetIntValue(json_item, "out_width", &crop_and_paste_para_.output_width) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "out_height", &crop_and_paste_para_.output_height) != SUCCESS) { + return FAILED; + } + auto &crop_info = crop_and_paste_para_.crop_info; + std::string crop_type = "crop_type"; + if (GetStringValue(json_item, "crop_type", &crop_type) != SUCCESS) { + return FAILED; + } + if (crop_type == "offset") { + MS_LOG(INFO) << "Crop type is 'offset'"; + crop_info.crop_type = kDvppCropTypeOffset; + auto &crop_area = crop_info.crop_area; + if (GetIntValue(json_item, "crop_left", &crop_area.left) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_top", &crop_area.top) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_right", &crop_area.right) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_bottom", &crop_area.bottom) != SUCCESS) { + return FAILED; + } + } else if (crop_type == "centre") { + MS_LOG(INFO) << "Crop type is 'centre'"; + if (GetIntValue(json_item, "crop_width", &crop_info.crop_width) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "crop_height", &crop_info.crop_height) != SUCCESS) { + return FAILED; + } + crop_info.crop_type = kDvppCropTypeCentre; + } else { + MS_LOG(ERROR) << "Invalid crop type " << crop_type << ", expect offset or centre"; + return FAILED; + } + // paste info + auto &paste_area = crop_and_paste_para_.paste_area; + if (GetIntValue(json_item, "paste_left", &paste_area.left) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "paste_top", &paste_area.top) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "paste_right", &paste_area.right) != SUCCESS) { + return FAILED; + } + if (GetIntValue(json_item, "paste_bottom", &paste_area.bottom) != SUCCESS) { + return FAILED; + } + crop_and_paste_flag_ = true; + return SUCCESS; +} + +Status DvppJsonConfigParser::InitWithJsonConfigImp(const std::string &json_config) { + std::ifstream fp(json_config); + if (!fp.is_open()) { + MS_LOG(ERROR) << "Read json config file failed"; + return FAILED; + } + const auto &model_info = nlohmann::json::parse(fp); + auto preprocess_list = model_info.find("preprocess"); + if (preprocess_list == model_info.end()) { + MS_LOG(ERROR) << "Get preprocess failed"; + return FAILED; + } + if (!preprocess_list->is_array()) { + MS_LOG(ERROR) << "Preprocess is not array"; + return FAILED; + } + if (preprocess_list->empty()) { + MS_LOG(ERROR) << "Preprocess size is 0"; + return FAILED; + } + auto &preprocess = preprocess_list->at(0); + // input + if (ParseInputPara(preprocess) != SUCCESS) { + MS_LOG(ERROR) << "Parse input failed"; + return FAILED; + } + // decode para + if (ParseDecodePara(preprocess) != SUCCESS) { + MS_LOG(ERROR) << "Parse decode failed"; + return FAILED; + } + // ops + auto dvpp_process = preprocess.find("dvpp_process"); + if (dvpp_process == preprocess.end()) { + MS_LOG(ERROR) << "Get dvpp_process failed"; + return FAILED; + } + if (!dvpp_process->is_object()) { + MS_LOG(ERROR) << "Obj dvpp_process is not array"; + return FAILED; + } + const auto &item = *dvpp_process; + std::string op_name; + if (GetStringValue(item, "op_name", &op_name) != SUCCESS) { + return FAILED; + } + if (op_name == "resize") { + if (ParseResizePara(item) != SUCCESS) { + MS_LOG(ERROR) << "Parse resize para failed"; + return FAILED; + } + } else if (op_name == "crop") { + if (ParseCropPara(item) != SUCCESS) { + MS_LOG(ERROR) << "Parse crop para failed"; + return FAILED; + } + } else if (op_name == "crop_and_paste") { + if (ParseCropAndPastePara(item) != SUCCESS) { + MS_LOG(ERROR) << "Parse decode para failed"; + return FAILED; + } + } else { + MS_LOG(ERROR) << "Unsupported op name " << op_name << ", expect resize, crop or crop_and_paste"; + return FAILED; + } + return SUCCESS; +} + +Status DvppJsonConfigParser::InitWithJsonConfig(const std::string &json_config) { + try { + auto ret = InitWithJsonConfigImp(json_config); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Init dvpp with json config failed, json config " << json_config; + return FAILED; + } + } catch (nlohmann::json::exception &e) { + MS_LOG(ERROR) << "Init dvpp with json config failed, json config " << json_config << ", error: " << e.what(); + return FAILED; + } + MS_LOG(INFO) << "Init with json config " << json_config << " success"; + return SUCCESS; +} + +Status DvppProcess::InitWithJsonConfig(const std::string &json_config) { + if (json_config.empty()) { + MS_LOG(INFO) << "No dvpp config file path set, skip."; + loaded_flag_ = false; + return SUCCESS; + } + + char real_path[PATH_MAX] = {0}; + if (realpath(common::SafeCStr(json_config), real_path) == nullptr) { + MS_LOG(WARNING) << "Dvpp json file " << json_config << " is not exist."; + loaded_flag_ = false; + return SUCCESS; + } + + DvppJsonConfigParser parser; + if (parser.InitWithJsonConfig(real_path) != SUCCESS) { + MS_LOG(ERROR) << "Init json config failed"; + return FAILED; + } + if (InitJpegDecodePara(parser.GetDecodePara()) != SUCCESS) { + MS_LOG(ERROR) << "Init decode para failed"; + return FAILED; + } + if (parser.HasResizeConfig()) { + if (InitResizePara(parser.GetResizePara()) != SUCCESS) { + MS_LOG(ERROR) << "Init resize para failed"; + return FAILED; + } + } else if (parser.HasCropConfig()) { + if (InitCropPara(parser.GetCropPara()) != SUCCESS) { + MS_LOG(ERROR) << "Init crop para failed"; + return FAILED; + } + } else if (parser.HasCropAndPasteConfig()) { + if (InitCropAndPastePara(parser.GetCropAndPastePara()) != SUCCESS) { + MS_LOG(ERROR) << "Init crop and paste para failed"; + return FAILED; + } + } + + MS_LOG(INFO) << "Dvpp config success"; + loaded_flag_ = true; + return SUCCESS; +} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/model/acl/dvpp_process.h b/mindspore/ccsrc/cxx_api/model/acl/dvpp_process.h new file mode 100644 index 00000000000..105baef1689 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/acl/dvpp_process.h @@ -0,0 +1,160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H +#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H +#include +#include +#include +#include "acl/acl.h" +#include "acl/acl_mdl.h" +#include "acl/acl_rt.h" +#include "acl/ops/acl_dvpp.h" +#include "include/api/status.h" + +namespace mindspore::api { +struct DvppDecodePara { + acldvppPixelFormat pixel_format = PIXEL_FORMAT_YUV_SEMIPLANAR_420; +}; + +struct DvppResizePara { + uint32_t output_width = 0; + uint32_t output_height = 0; +}; + +enum DvppCropType { + // crop left,top,right,bottom is given in config + kDvppCropTypeOffset = 0, + // crop left,top,right,bottom is calculated by image width/height and output crop width/height + kDvppCropTypeCentre = 1, +}; + +struct DvppRoiArea { + uint32_t left = 0; + uint32_t top = 0; + uint32_t right = 0; + uint32_t bottom = 0; +}; + +struct DvppCropInfo { + DvppCropType crop_type = kDvppCropTypeOffset; + DvppRoiArea crop_area; // when kDvppCropTypeOffset + uint32_t crop_width = 0; // when kDvppCropTypeCentre + uint32_t crop_height = 0; // when kDvppCropTypeCentre +}; + +struct DvppCropPara { + DvppCropInfo crop_info; + uint32_t output_width = 0; + uint32_t output_height = 0; +}; + +struct DvppCropAndPastePara { + DvppCropInfo crop_info; + DvppRoiArea paste_area; + uint32_t output_width = 0; + uint32_t output_height = 0; +}; + +class DvppProcess { + public: + DvppProcess(); + ~DvppProcess(); + + Status InitResource(aclrtStream stream); + void Finalize(); + Status InitJpegDecodePara(const DvppDecodePara &decode_para); // jpeg decode + (resize | crop) + Status InitResizePara(const DvppResizePara &resize_para); // jpeg decode + resize + Status InitCropPara(const DvppCropPara &crop_para); // jpeg decode + crop + Status InitCropAndPastePara(const DvppCropAndPastePara &crop_and_paste_para); // jpeg decode + crop&paste + + Status InitWithJsonConfig(const std::string &json_config); + + // output device buffer will be destroy by DvppProcess itself. + Status Process(const void *pic_buffer, size_t pic_buffer_size, void **output_device_buffer, size_t *output_size); + Status Process(const std::vector &pic_buffer_list, const std::vector &pic_buffer_size_list, + void **output_device_buffer, size_t *output_size); + bool HasLoaded() const { return loaded_flag_; } + + private: + bool loaded_flag_ = false; + uint32_t pic_width_ = 0; + uint32_t pic_height_ = 0; + + DvppDecodePara decode_para_; + DvppResizePara resize_para_; + DvppCropPara crop_para_; + DvppCropAndPastePara crop_and_paste_para_; + // only one of the resize or crop flag can be true + bool to_resize_flag_ = false; + bool to_crop_flag_ = false; + bool to_crop_and_paste_flag_ = false; + + void *input_pic_dev_buffer_ = nullptr; + uint32_t input_pic_buffer_size_ = 0; + + uint32_t decode_output_buffer_size_ = 0; + void *decode_output_buffer_dev_ = nullptr; + acldvppPicDesc *decode_output_desc_ = nullptr; + + acldvppResizeConfig *resize_config_ = nullptr; + acldvppRoiConfig *crop_area_ = nullptr; + acldvppRoiConfig *paste_area_ = nullptr; + + acldvppPicDesc *vpc_output_desc_ = nullptr; + void *vpc_output_buffer_dev_ = nullptr; // vpc_output_buffer_size_ length + uint32_t vpc_output_buffer_size_ = 0; + + void *batch_vpc_output_buffer_dev_ = nullptr; // batch_size_ * vpc_output_buffer_size_ length + uint32_t batch_size_ = 0; + + aclrtStream stream_ = nullptr; + acldvppChannelDesc *dvpp_channel_desc_ = nullptr; + + uint32_t AlignmentHelper(uint32_t org_size, uint32_t alignment) const; + uint32_t GetImageBufferSize(uint32_t stride_width, uint32_t stride_height, acldvppPixelFormat pixel_format) const; + Status GetPicDescStride(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height); + Status GetPicDescStrideDecode(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height); + Status InputInputBuffer(const void *pic_buffer, size_t pic_buffer_size); + Status InitDecodeOutputDesc(uint32_t image_width, + uint32_t image_height); // decode_output_desc_, decode_output_buffer_dev_ + Status CheckRoiAreaWidthHeight(uint32_t width, uint32_t height); + Status CheckAndAdjustRoiArea(DvppRoiArea *area); + Status UpdateCropArea(uint32_t image_width, uint32_t image_height); + Status CheckResizeImageInfo(uint32_t image_width, uint32_t image_height) const; + void DestroyDecodeDesc(); + + Status InitVpcOutputDesc(uint32_t output_width, uint32_t output_height, + acldvppPixelFormat pixel_format); // vpc_output_desc_, vpc_output_buffer_dev_batch_ + Status InitRoiAreaConfig(const DvppRoiArea &init_para, acldvppRoiConfig **roi_area); + Status InitCommonCropPara(uint32_t out_width, uint32_t out_height, DvppCropInfo *crop_info); + Status InitResizeOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, resize_config + Status InitCropOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_ + Status InitCropAndPasteOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_, paste_area_ + void DestroyVpcOutputDesc(); + + Status ProcessDecode(); + Status ProcessResize(); + Status ProcessCrop(); + Status ProcessCropAndPaste(); + void DestroyResource(); + + Status GetJpegWidthHeight(const void *pic_buffer, size_t pic_buffer_size, uint32_t *image_width, + uint32_t *image_height); +}; +} // namespace mindspore::api + +#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H diff --git a/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc b/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc new file mode 100644 index 00000000000..bc79310aa21 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc @@ -0,0 +1,285 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cxx_api/model/acl/model_converter.h" +#include +#include "pybind11/pybind11.h" +#include "utils/load_onnx/anf_converter.h" +#include "transform/graph_ir/convert.h" +#include "transform/graph_ir/graph_runner.h" +#include "mindspore/core/utils/ms_context.h" +#include "backend/kernel_compiler/oplib/oplib.h" + +#include "graph/model.h" + +namespace py = pybind11; + +namespace mindspore::api { +namespace { +transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) { + transform::TensorOrderMap res; + for (auto &anf_node : anf_graph->parameters()) { + MS_EXCEPTION_IF_NULL(anf_node); + auto para = anf_node->cast(); + MS_EXCEPTION_IF_NULL(para); + if (para->has_default()) { + auto value = para->default_param(); + MS_EXCEPTION_IF_NULL(value); + auto tensor = value->cast>(); + res.emplace(para->name(), tensor); + MS_LOG(INFO) << "Parameter " << para->name() << " has default value."; + } + } + return res; +} + +bool CreateSessionAndGraphRunner() { + std::shared_ptr sess = transform::DfGraphManager::GetInstance().GetGeSession(); + if (sess == nullptr) { + transform::SessionOptions options; + options["ge.trainFlag"] = "0"; + options["ge.enablePrintOpPass"] = "0"; + sess = transform::GraphRunner::NewSession(options); + if (sess == nullptr) { + MS_LOG(ERROR) << "Init data graph failed, because of create Ge session failed"; + return false; + } else { + transform::DfGraphManager::GetInstance().SetGeSession(sess); + } + } + + transform::GraphRunnerOptions options; + options.sess_ptr = sess; + auto graph_runner = std::make_shared(options); + if (graph_runner == nullptr) { + MS_LOG(ERROR) << "Create new graph runner failed"; + return false; + } else { + transform::DfGraphManager::GetInstance().SetGraphRunner(graph_runner); + } + + return true; +} + +} // namespace + +std::shared_ptr ModelConverter::ConvertMindIrToFuncGraph(const Buffer &model_data) { + try { + auto anf_graph = + lite::AnfConverter::RunAnfConverter(reinterpret_cast(model_data.Data()), model_data.DataSize()); + return anf_graph; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Load MindIR failed."; + return nullptr; + } +} + +transform::DfGraphPtr ModelConverter::ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph) { + for (auto &anf_node : anf_graph->parameters()) { + MS_EXCEPTION_IF_NULL(anf_node); + auto para = anf_node->cast(); + MS_EXCEPTION_IF_NULL(para); + // normalize name + std::string name = para->name(); + for (auto pos = name.find(':'); pos != std::string::npos; pos = name.find(':')) { + name = name.substr(0, pos) + "_" + name.substr(pos + 1); + MS_LOG(INFO) << name; + } + para->set_name(name); + } + + transform::DfGraphConvertor convertor(anf_graph); + std::string net_id = "0"; + std::string init_graph = "init_subgraph." + net_id; + std::string checkpoint_name = "save." + net_id; + + convertor.set_training(false); + (void)convertor.ConvertAllNode().InitParam(GetParams(anf_graph)).BuildGraph(); + (void)convertor.GenerateCheckpointGraph(); + if (convertor.ErrCode() != 0) { + transform::DfGraphManager::GetInstance().ClearGraph(); + MS_LOG(ERROR) << "Convert df graph failed, err:" << convertor.ErrCode(); + return nullptr; + } + (void)transform::DfGraphManager::GetInstance().AddGraph(anf_graph->ToString(), convertor.GetComputeGraph()); + (void)transform::DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph()); + (void)transform::DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph()); + + transform::Status ret = + transform::DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph()); + if (ret == transform::Status::SUCCESS) { + transform::DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph); + } + + (void)setenv("GE_TRAIN", "0", 1); + + if (!CreateSessionAndGraphRunner()) { + MS_LOG(ERROR) << "Create GE Session or GraphRunner failed."; + return nullptr; + } + + auto wrap_ptr = transform::DfGraphManager::GetInstance().GetGraphByName(anf_graph->ToString()); + if (wrap_ptr == nullptr) { + MS_LOG(ERROR) << "Get graph form DfGraphManager failed!"; + return nullptr; + } + transform::DfGraphPtr &ge_graph = wrap_ptr->graph_ptr_; + if (ge_graph == nullptr) { + MS_LOG(ERROR) << "The export graph is null"; + return nullptr; + } + + return ge_graph; +} + +Buffer ModelConverter::BuildAirModel(const transform::DfGraphPtr &graph, + const std::map &acl_options) { + ge::ModelBufferData model; + auto ge_options = acl_options; + ge_options.emplace(ge::ir_option::SOC_VERSION, "Ascend310"); + auto ret = ge::aclgrphBuildInitialize(ge_options); + if (ret != ge::SUCCESS) { + MS_LOG(ERROR) << "Call aclgrphBuildInitialize fail."; + return Buffer(); + } + + ret = ge::aclgrphBuildModel(*graph, acl_options, model); + if (ret != ge::SUCCESS) { + MS_LOG(ERROR) << "Call aclgrphBuildModel fail."; + return Buffer(); + } + + ge::aclgrphBuildFinalize(); + return Buffer(model.data.get(), model.length); +} + +void ModelConverter::RegAllOp() { + static std::mutex init_mutex; + static bool Initialized = false; + + std::lock_guard lock(init_mutex); + if (Initialized) { + return; + } + Initialized = true; + MsContext::GetInstance()->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); + Py_Initialize(); + auto c_expression = PyImport_ImportModule("mindspore._c_expression"); + MS_EXCEPTION_IF_NULL(c_expression); + PyObject *c_expression_dict = PyModule_GetDict(c_expression); + MS_EXCEPTION_IF_NULL(c_expression_dict); + + PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); + MS_EXCEPTION_IF_NULL(op_info_loader_class); + PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class); + MS_EXCEPTION_IF_NULL(op_info_loader); + PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr); + MS_EXCEPTION_IF_NULL(op_info_loader_ins); + auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr); + MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul); + auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul); + auto all_ops_info = static_cast *>(all_ops_info_vector_addr); + for (auto op_info : *all_ops_info) { + kernel::OpLib::RegOpInfo(std::shared_ptr(op_info)); + } + all_ops_info->clear(); + delete all_ops_info; + Py_DECREF(op_info_loader); + Py_DECREF(op_info_loader_class); + Py_DECREF(c_expression_dict); + Py_DECREF(c_expression); +} + +Buffer ModelConverter::ReadFile(const std::string &file) { + Buffer buffer; + if (file.empty()) { + MS_LOG(ERROR) << "Pointer file is nullptr"; + return buffer; + } + std::string realPath = file; + std::ifstream ifs(realPath); + if (!ifs.good()) { + MS_LOG(ERROR) << "File: " << realPath << " is not exist"; + return buffer; + } + + if (!ifs.is_open()) { + MS_LOG(ERROR) << "File: " << realPath << "open failed"; + return buffer; + } + + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + buffer.ResizeData(size); + if (buffer.DataSize() != size) { + MS_LOG(ERROR) << "Malloc buf failed, file: " << realPath; + ifs.close(); + return buffer; + } + + ifs.seekg(0, std::ios::beg); + ifs.read(reinterpret_cast(buffer.MutableData()), size); + ifs.close(); + + return buffer; +} + +Buffer ModelConverter::LoadMindIR(const Buffer &model_data) { + auto func_graph = ConvertMindIrToFuncGraph(model_data); + if (func_graph == nullptr) { + MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed."; + return Buffer(); + } + + auto df_graph = ConvertFuncGraphToAIR(func_graph); + if (df_graph == nullptr) { + MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed."; + return Buffer(); + } + + std::map acl_options; + if (options_ != nullptr) { + acl_options = options_->GenAclOptions(); + } + + auto om_data = BuildAirModel(df_graph, acl_options); + return om_data; +} + +Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) { + ge::Model load_model = ge::Model("loadmodel", "version2"); + ge::Status ret = + ge::Model::Load(reinterpret_cast(model_data.Data()), model_data.DataSize(), load_model); + if (ret != ge::GRAPH_SUCCESS) { + MS_LOG(ERROR) << "Load AscendIR failed, ret = " << ret; + return Buffer(); + } + + transform::DfGraphPtr df_graph = std::make_shared(load_model.GetGraph()); + if (df_graph == nullptr) { + MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed."; + return Buffer(); + } + + std::map acl_options; + if (options_ != nullptr) { + acl_options = options_->GenAclOptions(); + } + + auto om_data = BuildAirModel(df_graph, acl_options); + return om_data; +} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/model/acl/model_converter.h b/mindspore/ccsrc/cxx_api/model/acl/model_converter.h new file mode 100644 index 00000000000..6189ad56def --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/acl/model_converter.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H +#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H +#include +#include +#include +#include +#include "include/api/types.h" +#include "include/api/status.h" +#include "mindspore/core/ir/func_graph.h" +#include "transform/graph_ir/types.h" +#include "external/ge/ge_ir_build.h" +#include "cxx_api/model/acl/acl_model_options.h" + +namespace mindspore::api { +class ModelConverter { + public: + ModelConverter() : options_(nullptr) {} + + Buffer LoadMindIR(const Buffer &model_data); + Buffer LoadAscendIR(const Buffer &model_data); + + void set_options(AclModelOptions *options) { options_ = options; } + + static Buffer ReadFile(const std::string &file); + static void RegAllOp(); + + private: + std::shared_ptr ConvertMindIrToFuncGraph(const Buffer &model_data); + transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph); + Buffer BuildAirModel(const transform::DfGraphPtr &graph, const std::map &acl_options); + AclModelOptions *options_; +}; +} // namespace mindspore::api + +#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H diff --git a/mindspore/ccsrc/cxx_api/model/acl/model_process.cc b/mindspore/ccsrc/cxx_api/model/acl/model_process.cc new file mode 100644 index 00000000000..eb2afc10955 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/acl/model_process.cc @@ -0,0 +1,440 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cxx_api/model/acl/model_process.h" +#include +#include +#include "utils/utils.h" + +namespace mindspore::api { +static DataType TransToApiType(aclDataType data_type) { + static const std::map data_type_map = { + {ACL_FLOAT16, api::kMsFloat16}, {ACL_FLOAT, api::kMsFloat32}, {ACL_DOUBLE, api::kMsFloat64}, + {ACL_INT8, api::kMsInt8}, {ACL_INT16, api::kMsInt16}, {ACL_INT32, api::kMsInt32}, + {ACL_INT64, api::kMsInt64}, {ACL_UINT8, api::kMsUint8}, {ACL_UINT16, api::kMsUint16}, + {ACL_UINT32, api::kMsUint32}, {ACL_UINT64, api::kMsUint64}, {ACL_BOOL, api::kMsBool}, + }; + auto it = data_type_map.find(data_type); + if (it == data_type_map.end()) { + return api::kInvalidDataType; + } else { + return it->second; + } +} + +static void ConstructTensorDesc(const std::vector &acl_tensor_list, std::vector *tensor_list) { + MS_EXCEPTION_IF_NULL(tensor_list); + tensor_list->clear(); + + for (size_t i = 0; i < acl_tensor_list.size(); ++i) { + const auto &info = acl_tensor_list[i]; + Tensor tensor_desc; + tensor_desc.SetName(info.name); + tensor_desc.SetDataType(TransToApiType(info.data_type)); + tensor_desc.SetShape(info.dims); + tensor_list->push_back(tensor_desc); + } +} + +Status ModelProcess::PreInitModelResource() { + model_desc_ = aclmdlCreateDesc(); + aclError acl_ret = aclmdlGetDesc(model_desc_, model_id_); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Read model desc failed"; + return FAILED; + } + Status ret = InitInputsBuffer(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Create input buffer failed"; + return FAILED; + } + ret = InitOutputsBuffer(); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Create output buffer failed"; + return FAILED; + } + return SUCCESS; +} + +Status ModelProcess::LoadModelFromFile(const std::string &file_name, uint32_t *model_id) { + MS_EXCEPTION_IF_NULL(model_id); + aclError acl_ret = aclmdlLoadFromFile(file_name.c_str(), model_id); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Read model file failed, file name is " << file_name; + return FAILED; + } + MS_LOG(INFO) << "Load model success " << file_name; + model_id_ = *model_id; + if (PreInitModelResource() != SUCCESS) { + aclmdlUnload(model_id_); + MS_LOG(ERROR) << "Pre init model resource failed, file name is " << file_name; + return FAILED; + } + return SUCCESS; +} + +Status ModelProcess::InitInputsBuffer() { + aclError ret; + size_t input_size = aclmdlGetNumInputs(model_desc_); + MS_LOG(INFO) << "input_size = " << input_size; + for (size_t i = 0; i < input_size; ++i) { + auto buffer_size = aclmdlGetInputSizeByIndex(model_desc_, i); + void *data_mem_buffer = nullptr; + if (!is_run_on_device_) { // need to copy input/output to/from device + ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Malloc device input buffer faild , input size " << buffer_size; + return FAILED; + } + } + + aclmdlIODims dims; + ret = aclmdlGetInputDims(model_desc_, i, &dims); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Get input shape failed"; + if (!is_run_on_device_) { + aclrtFree(data_mem_buffer); + } + return FAILED; + } + aclDataType data_type = aclmdlGetInputDataType(model_desc_, i); + std::vector shape(dims.dims, dims.dims + dims.dimCount); + std::string input_name = aclmdlGetInputNameByIndex(model_desc_, i); + if (input_name.empty()) { + MS_LOG(WARNING) << "Get name of input " << i << " failed."; + } + MS_LOG(INFO) << "Name of input " << i << " is " << input_name; + input_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape, input_name}); + } + MS_LOG(INFO) << "Create model inputs success"; + return SUCCESS; +} + +Status ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset) { + MS_EXCEPTION_IF_NULL(data_mem_buffer); + aclError ret; + auto free_data_buffer = [this](void *dataMemBuffer) { + if (!is_run_on_device_) { + aclrtFree(dataMemBuffer); + } else { + aclrtFreeHost(dataMemBuffer); + } + }; + + if (!is_run_on_device_) { + ret = aclrtMalloc(data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Malloc device buffer faild , buffer size " << buffer_size; + return FAILED; + } + } else { + ret = aclrtMallocHost(data_mem_buffer, buffer_size); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Malloc device buffer faild , buffer size " << buffer_size; + return FAILED; + } + } + + auto data_buffer = aclCreateDataBuffer(*data_mem_buffer, buffer_size); + if (data_buffer == nullptr) { + MS_LOG(ERROR) << "Create Data Buffer failed"; + free_data_buffer(*data_mem_buffer); + return FAILED; + } + ret = aclmdlAddDatasetBuffer(dataset, data_buffer); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "add data buffer failed"; + free_data_buffer(*data_mem_buffer); + aclDestroyDataBuffer(data_buffer); + return FAILED; + } + return SUCCESS; +} + +Status ModelProcess::InitOutputsBuffer() { + aclError ret; + outputs_ = aclmdlCreateDataset(); + if (outputs_ == nullptr) { + MS_LOG(ERROR) << "Create input dataset failed"; + return FAILED; + } + size_t output_size = aclmdlGetNumOutputs(model_desc_); + MS_LOG(INFO) << "output_size = " << output_size; + for (size_t i = 0; i < output_size; ++i) { + auto buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i); + + void *data_mem_buffer = nullptr; + if (CreateDataBuffer(&data_mem_buffer, buffer_size, outputs_) != SUCCESS) { + MS_LOG(ERROR) << "add output data buffer failed, buffer size " << buffer_size; + return FAILED; + } + aclmdlIODims dims; + ret = aclmdlGetOutputDims(model_desc_, i, &dims); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Get input shape failed"; + if (!is_run_on_device_) { + aclrtFree(data_mem_buffer); + } else { + aclrtFreeHost(data_mem_buffer); + } + return FAILED; + } + aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i); + std::vector shape(dims.dims, dims.dims + dims.dimCount); + std::string output_name = aclmdlGetOutputNameByIndex(model_desc_, i); + if (output_name.empty()) { + MS_LOG(WARNING) << "Get name of output " << i << " failed."; + } + MS_LOG(INFO) << "Name of input " << i << " is " << output_name; + output_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape, output_name}); + } + MS_LOG(INFO) << "Create model output success"; + return SUCCESS; +} + +void ModelProcess::DestroyInputsDataset() { + if (inputs_ == nullptr) { + return; + } + for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(inputs_); i++) { + auto dataBuffer = aclmdlGetDatasetBuffer(inputs_, i); + aclDestroyDataBuffer(dataBuffer); + } + aclmdlDestroyDataset(inputs_); + inputs_ = nullptr; +} + +void ModelProcess::DestroyInputsDataMem() { + if (!is_run_on_device_) { + for (const auto &item : input_infos_) { + aclrtFree(item.device_data); + } + } + input_infos_.clear(); +} + +void ModelProcess::DestroyInputsBuffer() { + DestroyInputsDataMem(); + DestroyInputsDataset(); +} + +void ModelProcess::DestroyOutputsBuffer() { + for (const auto &item : output_infos_) { + if (!is_run_on_device_) { + aclrtFree(item.device_data); + } else { + aclrtFreeHost(item.device_data); + } + } + output_infos_.clear(); + + if (outputs_ == nullptr) { + return; + } + for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(outputs_); i++) { + auto dataBuffer = aclmdlGetDatasetBuffer(outputs_, i); + aclDestroyDataBuffer(dataBuffer); + } + aclmdlDestroyDataset(outputs_); + outputs_ = nullptr; +} + +Status ModelProcess::UnLoad() { + auto ret = aclmdlUnload(model_id_); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Unload model failed"; + return FAILED; + } + if (model_desc_ != nullptr) { + ret = aclmdlDestroyDesc(model_desc_); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Unload model failed"; + return FAILED; + } + model_desc_ = nullptr; + } + DestroyInputsBuffer(); + DestroyOutputsBuffer(); + MS_LOG(INFO) << "End unload model " << model_id_; + return SUCCESS; +} + +Status ModelProcess::CheckAndInitInput(const std::map &inputs) { + aclError ret; + inputs_ = aclmdlCreateDataset(); + // check inputs + if (inputs.size() != input_infos_.size()) { + MS_LOG(ERROR) << "inputs count not match, required count " << input_infos_.size() << ", given count " + << inputs.size(); + return INVALID_INPUTS; + } + for (size_t i = 0; i < input_infos_.size(); ++i) { + const std::string &input_name = input_infos_[i].name; + auto iter = inputs.find(input_name); + if (iter == inputs.end()) { + MS_LOG(ERROR) << "Model missing input " << input_name; + return INVALID_INPUTS; + } + + if (iter->second.DataSize() != input_infos_[i].buffer_size) { + MS_LOG(ERROR) << "input " << i << " data size not match, required size " << input_infos_[i].buffer_size + << ", given count " << iter->second.DataSize(); + return INVALID_INPUTS; + } + } + // copy inputs + for (size_t i = 0; i < input_infos_.size(); ++i) { + const auto &info = input_infos_[i]; + auto iter = inputs.find(info.name); + if (iter == inputs.end()) { + MS_LOG(ERROR) << "Model missing input " << info.name; + return INVALID_INPUTS; + } + + const auto &input = iter->second; + const void *data = input.Data(); + + void *input_buffer; + if (!is_run_on_device_) { + ret = aclrtMemcpy(info.device_data, info.buffer_size, data, input.DataSize(), ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Acl memcpy input " << i << " data to device failed, buffer size " << input.DataSize(); + return FAILED; + } + input_buffer = info.device_data; + } else { + input_buffer = const_cast(data); + } + auto data_buffer = aclCreateDataBuffer(input_buffer, info.buffer_size); + if (data_buffer == nullptr) { + MS_LOG(ERROR) << "Create Data Buffer failed"; + return FAILED; + } + ret = aclmdlAddDatasetBuffer(inputs_, data_buffer); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "add data buffer failed"; + aclDestroyDataBuffer(data_buffer); + return FAILED; + } + } + return SUCCESS; +} + +Status ModelProcess::CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size, + size_t input_index) { + aclError ret; + inputs_ = aclmdlCreateDataset(); + // check inputs + if (input_index >= input_infos_.size()) { + MS_LOG(ERROR) << "inputs count not match, required count " << input_infos_.size() << ", given index " + << input_index; + return INVALID_INPUTS; + } + if (dvpp_outputs_buffer_dev == nullptr) { + MS_LOG(ERROR) << "input " << 0 << " cannot be null"; + return FAILED; + } + if (dvpp_outputs_buffer_size != input_infos_[input_index].buffer_size) { + MS_LOG(ERROR) << "input " << 0 << " data size not match, required size " << input_infos_[input_index].buffer_size + << ", given count " << dvpp_outputs_buffer_size; + return INVALID_INPUTS; + } + // copy inputs + auto &info = input_infos_[input_index]; + auto data_buffer = aclCreateDataBuffer(const_cast(dvpp_outputs_buffer_dev), info.buffer_size); + if (data_buffer == nullptr) { + MS_LOG(ERROR) << "Create Data Buffer failed"; + return FAILED; + } + ret = aclmdlAddDatasetBuffer(inputs_, data_buffer); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "add data buffer failed"; + aclDestroyDataBuffer(data_buffer); + return FAILED; + } + return SUCCESS; +} + +Status ModelProcess::Predict(const std::map &inputs, std::map *outputs) { + MS_EXCEPTION_IF_NULL(outputs); + aclError acl_ret; + Status ret = CheckAndInitInput(inputs); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "check or init input failed"; + DestroyInputsDataset(); + return ret; // forward status error + } + acl_ret = aclmdlExecute(model_id_, inputs_, outputs_); + DestroyInputsDataset(); + if (acl_ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Execute Model Failed"; + return FAILED; + } + ret = BuildOutputs(outputs); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Build outputs faield"; + return FAILED; + } + MS_LOG(INFO) << "excute model success"; + return SUCCESS; +} + +size_t ModelProcess::GetBatchSize() const { + if (input_infos_.empty()) { + MS_LOG(ERROR) << "Model is not loaded"; + return 0; + } + if (input_infos_[0].dims.empty()) { + return 1; + } + return static_cast(input_infos_[0].dims[0]); +} + +Status ModelProcess::BuildOutputs(std::map *outputs) { + MS_EXCEPTION_IF_NULL(outputs); + aclError ret; + // copy outputs + outputs->clear(); + aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST; + for (size_t i = 0; i < output_infos_.size(); ++i) { + const auto &info = output_infos_[i]; + // todo + outputs->emplace(info.name, Buffer()); + auto output = outputs->rbegin()->second; + if (!output.ResizeData(info.buffer_size)) { + MS_LOG(ERROR) << "new output data buffer failed, data size " << info.buffer_size; + return FAILED; + } + ret = aclrtMemcpy(output.MutableData(), output.DataSize(), info.device_data, info.buffer_size, kind); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Memcpy output " << i << " from " << (is_run_on_device_ ? "host" : "device") + << " to host failed, memory size " << info.buffer_size; + return FAILED; + } + } + return SUCCESS; +} + +Status ModelProcess::GetInputsInfo(std::vector *tensor_list) const { + ConstructTensorDesc(input_infos_, tensor_list); + return SUCCESS; +} + +Status ModelProcess::GetOutputsInfo(std::vector *tensor_list) const { + ConstructTensorDesc(output_infos_, tensor_list); + return SUCCESS; +} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/model/acl/model_process.h b/mindspore/ccsrc/cxx_api/model/acl/model_process.h new file mode 100644 index 00000000000..24f3f386236 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/acl/model_process.h @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H +#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H +#include +#include +#include +#include "acl/acl.h" +#include "acl/acl_mdl.h" +#include "acl/acl_rt.h" +#include "include/api/status.h" +#include "include/api/types.h" + +namespace mindspore::api { +struct AclTensorInfo { + void *device_data; + size_t buffer_size; + aclDataType data_type; + std::vector dims; + std::string name; +}; + +struct ImagesDvppOutput { + void *buffer_device = nullptr; + size_t buffer_size = 0; + size_t input_index = 0; +}; + +class ModelProcess { + public: + ModelProcess() + : model_id_(0xffffffff), + is_run_on_device_(false), + model_desc_(nullptr), + inputs_(nullptr), + outputs_(nullptr), + input_infos_(), + output_infos_() {} + ~ModelProcess() {} + Status LoadModelFromFile(const std::string &file_name, uint32_t *model_id); + Status UnLoad(); + Status Predict(const std::map &inputs, std::map *outputs); + Status PreInitModelResource(); + Status GetInputsInfo(std::vector *tensor_list) const; + Status GetOutputsInfo(std::vector *tensor_list) const; + + // override this method to avoid request/reply data copy + void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; } + + size_t GetBatchSize() const; + void set_model_id(uint32_t model_id) { model_id_ = model_id; } + uint32_t model_id() const { return model_id_; } + + private: + Status CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset); + Status CheckAndInitInput(const std::map &inputs); + Status CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size, + size_t input_index); + Status BuildOutputs(std::map *outputs); + Status InitInputsBuffer(); + Status InitOutputsBuffer(); + + void DestroyInputsDataset(); + void DestroyInputsDataMem(); + void DestroyInputsBuffer(); + void DestroyOutputsBuffer(); + + uint32_t model_id_; + // if run one device(AICPU), there is no need to alloc device memory and copy inputs to(/outputs from) device + bool is_run_on_device_; + aclmdlDesc *model_desc_; + aclmdlDataset *inputs_; + aclmdlDataset *outputs_; + std::vector input_infos_; + std::vector output_infos_; +}; +} // namespace mindspore::api + +#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc new file mode 100644 index 00000000000..27a27f3797b --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/model.cc @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "include/api/model.h" +#include "cxx_api/model/model_impl.h" +#include "utils/utils.h" + +namespace mindspore::api { +Status Model::LoadModel(const Buffer &model_data, ModelType type, const std::map &options) { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->LoadModel(model_data, type, options); +} + +Status Model::LoadModel(const std::string &file_name, ModelType type, + const std::map &options) { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->LoadModel(file_name, type, options); +} + +Status Model::UnloadModel() { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->UnloadModel(); +} + +Status Model::Train(const DataSet &dataset, std::map *outputs) { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->Train(dataset, outputs); +} + +Status Model::Eval(const DataSet &dataset, std::map *outputs) { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->Eval(dataset, outputs); +} + +Status Model::Predict(const std::map &inputs, std::map *outputs) { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->Predict(inputs, outputs); +} + +Status Model::Predict(const std::vector &inputs, std::map *outputs) { + std::vector tensor_list; + auto ret = GetInputsInfo(&tensor_list); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "GetInputsInfo failed."; + return ret; + } + + if (inputs.size() != tensor_list.size()) { + MS_LOG(ERROR) << "Model need " << tensor_list.size() << " inputs, but given " << inputs.size(); + return FAILED; + } + + std::map inputs_with_map; + for (size_t i = 0; i < tensor_list.size(); ++i) { + inputs_with_map.emplace(tensor_list[i].Name(), inputs[i]); + } + + return Predict(inputs_with_map, outputs); +} + +Status Model::GetInputsInfo(std::vector *tensor_list) const { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->GetInputsInfo(tensor_list); +} + +Status Model::GetOutputsInfo(std::vector *tensor_list) const { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->GetOutputsInfo(tensor_list); +} + +Model::Model(const std::string &device_type, uint32_t device_id) + : impl_(ModelFactory::Instance().Create(device_type, device_id)) { + if (impl_ == nullptr) { + MS_LOG(EXCEPTION) << "Create session type " << device_type << " failed"; + } +} + +Model::Model(NetWork network, const std::string &device_type, uint32_t device_id) { + // todo + if (impl_ == nullptr) { + MS_LOG(EXCEPTION) << "Create session type " << device_type << " failed"; + } +} + +Model::~Model() {} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/model/model_impl.h b/mindspore/ccsrc/cxx_api/model/model_impl.h new file mode 100644 index 00000000000..3ef26a6c3e9 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/model_impl.h @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H +#define MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H +#include +#include +#include +#include +#include +#include +#include "include/api/model.h" +#include "utils/utils.h" + +namespace mindspore::api { +class ModelImpl { + public: + ModelImpl() = default; + virtual ~ModelImpl() = default; + + virtual Status LoadModel(const Buffer &model_data, ModelType type, + const std::map &options) = 0; + virtual Status LoadModel(const std::string &file_name, ModelType type, + const std::map &options) = 0; + virtual Status UnloadModel() = 0; + + virtual Status Train(const DataSet &dataset, std::map *outputs) = 0; + virtual Status Eval(const DataSet &dataset, std::map *outputs) = 0; + virtual Status Predict(const std::map &inputs, std::map *outputs) = 0; + + virtual Status GetInputsInfo(std::vector *tensor_list) const = 0; + virtual Status GetOutputsInfo(std::vector *tensor_list) const = 0; +}; + +using ModelCreator = std::function(uint32_t device_id)>; +class ModelFactory { + public: + ModelFactory(const ModelFactory &) = delete; + void operator=(const ModelFactory &) = delete; + + static ModelFactory &Instance() { + static ModelFactory instance; + return instance; + } + + void Register(const std::string &device_name, ModelCreator &&model_creator) { + if (model_creators_.find(device_name) == model_creators_.end()) { + (void)model_creators_.emplace(device_name, model_creator); + } + } + + std::shared_ptr Create(const std::string &device_name, uint32_t device_id) { + auto iter = model_creators_.find(device_name); + if (model_creators_.end() != iter) { + MS_EXCEPTION_IF_NULL(iter->second); + return (iter->second)(device_id); + } + return nullptr; + } + + private: + ModelFactory() = default; + ~ModelFactory() = default; + std::map model_creators_; +}; + +class ModelRegistrar { + public: + ModelRegistrar(const std::string &device_name, ModelCreator model_creator) { + ModelFactory::Instance().Register(device_name, std::move(model_creator)); + } + ~ModelRegistrar() = default; +}; + +#define API_REG_MODEL(DEVICE_NAME, MODEL_CLASS) \ + static const ModelRegistrar g_api_model_registrar__##DEVICE_NAME##_##_reg( \ + #DEVICE_NAME, [](uint32_t device_id) { return std::make_shared(device_id); }); + +} // namespace mindspore::api + +#endif // MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H diff --git a/mindspore/ccsrc/cxx_api/ops/ops.cc b/mindspore/ccsrc/cxx_api/ops/ops.cc new file mode 100644 index 00000000000..1d028a6d8d6 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/ops/ops.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "include/api/ops/ops.h" + +namespace mindspore::api { +Conv2D::Conv2D(int out_channel, const std::vector &kernel_size, int mode, const std::string &pad_mode, + const std::vector &pad, const std::vector &stride, const std::vector &dilation, int group) + : OpCell("Conv2D"), + out_channel(out_channel), + kernel_size(kernel_size), + mode(mode), + pad_mode(pad_mode), + pad(pad), + stride(stride), + dilation(dilation), + group(group) {} + +Output Conv2D::operator()(const Input &input1, const Input &input2) const { + return CellBase::operator()({input1, input2})[0]; +} + +std::vector Conv2D::Construct(const std::vector &inputs) { + return {Output(shared_from_this(), inputs, 1)}; +} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/serialization.cc b/mindspore/ccsrc/cxx_api/serialization.cc new file mode 100644 index 00000000000..2bd1be56b2a --- /dev/null +++ b/mindspore/ccsrc/cxx_api/serialization.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "include/api/serialization.h" +#include "utils/log_adapter.h" + +namespace mindspore::api { +Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map *parameters) { + MS_LOG(ERROR) << "Unsupported feature."; + return FAILED; +} + +Status Serialization::SetParameters(const std::map ¶meters, Model *model) { + MS_LOG(ERROR) << "Unsupported feature."; + return FAILED; +} + +Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) { + MS_LOG(ERROR) << "Unsupported feature."; + return FAILED; +} + +Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file) { + MS_LOG(ERROR) << "Unsupported feature."; + return FAILED; +} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/cxx_api/types.cc b/mindspore/ccsrc/cxx_api/types.cc new file mode 100644 index 00000000000..03c3aa2dbb5 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/types.cc @@ -0,0 +1,226 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "include/api/types.h" +#include +#include "securec/include/securec.h" +#include "utils/utils.h" + +namespace mindspore::api { +class DataImpl { + public: + DataImpl() : data_() {} + ~DataImpl() = default; + DataImpl(const void *data, size_t data_len) { SetData(data, data_len); } + + const void *Data() const { return data_.data(); } + void *MutableData() { return data_.data(); } + size_t DataSize() const { return data_.size(); } + + bool ResizeData(size_t data_len) { + data_.resize(data_len); + return true; + } + + bool SetData(const void *data, size_t data_len) { + ResizeData(data_len); + if (DataSize() != data_len) { + MS_LOG(ERROR) << "Set data failed, tensor current data size " << DataSize() << " not match data len " << data_len; + return false; + } + + if (data == nullptr) { + return data_len == 0; + } + + if (MutableData() == nullptr) { + MS_LOG(ERROR) << "Set data failed, data len " << data_len; + return false; + } + + auto ret = memcpy_s(MutableData(), DataSize(), data, data_len); + if (ret != 0) { + MS_LOG(ERROR) << "Set data memcpy_s failed, ret = " << ret; + return false; + } + return true; + } + + protected: + std::vector data_; +}; + +class Buffer::Impl : public DataImpl { + public: + Impl() : DataImpl() {} + ~Impl() = default; + Impl(const void *data, size_t data_len) : DataImpl(data, data_len) {} +}; + +class Tensor::Impl : public DataImpl { + public: + Impl() : DataImpl(), name_(), type_(DataType::kMsUnknown), shape_() {} + ~Impl() = default; + Impl(const std::string &name, api::DataType type, const std::vector &shape, const void *data, + size_t data_len) + : DataImpl(data, data_len), name_(name), type_(type), shape_(shape) {} + + const std::string &Name() const { return name_; } + void SetName(const std::string &name) { name_ = name; } + + api::DataType DataType() const { return type_; } + void SetDataType(api::DataType type) { type_ = type; } + + void SetShape(const std::vector &shape) { shape_ = shape; } + const std::vector &Shape() const { return shape_; } + + int64_t ElementNum() const { + std::vector shapex = Shape(); + return std::accumulate(shapex.begin(), shapex.end(), 1LL, std::multiplies()); + } + + static int GetTypeSize(api::DataType type) { + static const std::map type_size_map = { + {kMsBool, sizeof(bool)}, {kMsFloat64, sizeof(double)}, {kMsInt8, sizeof(int8_t)}, + {kMsUint8, sizeof(uint8_t)}, {kMsInt16, sizeof(int16_t)}, {kMsUint16, sizeof(uint16_t)}, + {kMsInt32, sizeof(int32_t)}, {kMsUint32, sizeof(uint32_t)}, {kMsInt64, sizeof(int64_t)}, + {kMsUint64, sizeof(uint64_t)}, {kMsFloat16, sizeof(uint16_t)}, {kMsFloat32, sizeof(float)}, + }; + auto it = type_size_map.find(type); + if (it != type_size_map.end()) { + return it->second; + } + + MS_LOG(WARNING) << "Cannot find data type " << type; + return 0; + } + + private: + std::string name_; + api::DataType type_; + std::vector shape_; +}; + +Tensor::Tensor() : impl_(std::make_shared()) {} +Tensor::Tensor(const std::string &name, api::DataType type, const std::vector &shape, const void *data, + size_t data_len) + : impl_(std::make_shared(name, type, shape, data, data_len)) {} +Tensor::~Tensor() = default; + +Tensor Tensor::Clone() const { + MS_EXCEPTION_IF_NULL(impl_); + Tensor ret; + ret.impl_ = std::make_shared(*impl_); + return ret; +} + +const std::string &Tensor::Name() const { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->Name(); +} + +void Tensor::SetName(const std::string &name) { + MS_EXCEPTION_IF_NULL(impl_); + impl_->SetName(name); +} + +DataType Tensor::DataType() const { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->DataType(); +} + +void Tensor::SetDataType(api::DataType type) { + MS_EXCEPTION_IF_NULL(impl_); + impl_->SetDataType(type); +} + +const std::vector &Tensor::Shape() const { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->Shape(); +} + +void Tensor::SetShape(const std::vector &shape) { + MS_EXCEPTION_IF_NULL(impl_); + impl_->SetShape(shape); +} + +const void *Tensor::Data() const { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->Data(); +} + +void *Tensor::MutableData() { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->MutableData(); +} + +size_t Tensor::DataSize() const { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->DataSize(); +} + +bool Tensor::ResizeData(size_t data_len) { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->ResizeData(data_len); +} + +bool Tensor::SetData(const void *data, size_t data_len) { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->SetData(data, data_len); +} + +int64_t Tensor::ElementNum() const { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->ElementNum(); +} + +int Tensor::GetTypeSize(api::DataType type) { return Impl::GetTypeSize(type); } + +Buffer::Buffer() : impl_(std::make_shared()) {} +Buffer::Buffer(const void *data, size_t data_len) : impl_(std::make_shared(data, data_len)) {} +Buffer::~Buffer() = default; + +Buffer Buffer::Clone() const { + MS_EXCEPTION_IF_NULL(impl_); + Buffer ret; + ret.impl_ = std::make_shared(*impl_); + return ret; +} + +const void *Buffer::Data() const { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->Data(); +} + +void *Buffer::MutableData() { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->MutableData(); +} + +size_t Buffer::DataSize() const { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->DataSize(); +} + +bool Buffer::ResizeData(size_t data_len) { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->ResizeData(data_len); +} + +bool Buffer::SetData(const void *data, size_t data_len) { + MS_EXCEPTION_IF_NULL(impl_); + return impl_->SetData(data, data_len); +} +} // namespace mindspore::api diff --git a/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt b/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt index 3f062609d5d..9279b31d0d5 100644 --- a/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt +++ b/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt @@ -1,5 +1,6 @@ -if (ENABLE_GE OR ENABLE_D) +if (ENABLE_GE OR ENABLE_D OR ENABLE_ACL) file(GLOB_RECURSE _TRANSFORM_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") + list(REMOVE_ITEM _TRANSFORM_SRC_LIST "graph_ir/op_declare/hcom_ops_declare.cc") set_property(SOURCE ${_TRANSFORM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_GE_ADPT) add_library(_mindspore_transform_graph_ir_obj OBJECT ${_TRANSFORM_SRC_LIST}) diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc index 1ff54ef2c32..06135c4877b 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.cc +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -1579,7 +1579,6 @@ OperatorPtr DfGraphConvertor::ConvertParameter(const AnfNodePtr node) { // build index for parameter using name std::string name = std::static_pointer_cast(node)->name(); params_[name] = node; - std::ostringstream ss; ss << "op" << node.get(); op_draw_name_[node.get()] = ss.str(); diff --git a/serving/CMakeLists.txt b/serving/CMakeLists.txt index 0b9433bcd91..6a9e26eda42 100644 --- a/serving/CMakeLists.txt +++ b/serving/CMakeLists.txt @@ -76,8 +76,6 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} list(APPEND SERVING_SRC "main.cc" ${hw_proto_srcs} ${hw_grpc_srcs} ${CORE_SRC_LIST}) -option(ENABLE_ACL "enable acl" OFF) - if (ENABLE_ACL) if (DEFINED ENV{ASCEND_CUSTOM_PATH}) set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH}) @@ -85,9 +83,11 @@ if (ENABLE_ACL) set(ASCEND_PATH /usr/local/Ascend) endif () set(ACL_LIB_DIR ${ASCEND_PATH}/acllib/) - MESSAGE("acl lib dir " ${ACL_LIB_DIR}) + set(ATLAS_ACL_LIB_DIR ${ASCEND_PATH}/ascend-toolkit/latest/acllib) + MESSAGE("hisi acl lib dir " ${ACL_LIB_DIR} " ,atlas acl lib dir " ${ATLAS_ACL_LIB_DIR}) include_directories(${ACL_LIB_DIR}/include/) + include_directories(${ATLAS_ACL_LIB_DIR}/include/) file(GLOB_RECURSE ACL_SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "acl/*.cc") list(APPEND SERVING_SRC ${ACL_SESSION_SRC_LIST}) endif () @@ -112,10 +112,13 @@ endif () if (ENABLE_ACL) add_compile_definitions(ENABLE_ACL) add_compile_definitions(ENABLE_DVPP_INTERFACE) - set(ALC_LIB_SO ${ACL_LIB_DIR}/lib64/libruntime.so ${ACL_LIB_DIR}/lib64/libascendcl.so - ${ACL_LIB_DIR}/lib64/libacl_retr.so ${ACL_LIB_DIR}/lib64/libacl_cblas.so - ${ACL_LIB_DIR}/lib64/libacl_dvpp.so) - target_link_libraries(ms_serving ${ALC_LIB_SO}) + find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64) + find_library(acl_retr libacl_retr.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64) + find_library(acl_cblas libacl_cblas.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64) + find_library(acl_dvpp libacl_dvpp.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64) + find_library(acl_runtime libruntime.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64) + + target_link_libraries(ms_serving ${acl} ${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime}) target_link_libraries(ms_serving jpeg_turbo::jpeg securec) else () target_link_libraries(ms_serving inference mindspore_gvar) diff --git a/setup.py b/setup.py index 8e548bf3c52..4d43536c2fe 100644 --- a/setup.py +++ b/setup.py @@ -130,7 +130,11 @@ package_data = { 'lib/*.so*', 'lib/*.a', '.commit_id', - 'ms_serving' + 'ms_serving', + 'include/*', + 'include/*/*', + 'include/*/*/*', + 'include/*/*/*/*' ] } diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 05092db3c2f..2648edcf6f6 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -56,6 +56,7 @@ if(ENABLE_MINDDATA) ./utils/*.cc ./vm/*.cc ./ps/*.cc + ./cxx_api/*.cc ) if(NOT ENABLE_PYTHON) @@ -176,7 +177,7 @@ if (USE_GLOG) target_link_libraries(ut_tests PRIVATE mindspore::glog) endif() -target_link_libraries(ut_tests PRIVATE mindspore securec graph) +target_link_libraries(ut_tests PRIVATE mindspore mindspore_shared_lib securec graph) # link grpc if (EXISTS ${grpc_ROOT}/lib64) diff --git a/tests/ut/cpp/cxx_api/types_test.cc b/tests/ut/cpp/cxx_api/types_test.cc new file mode 100644 index 00000000000..c222bd5b307 --- /dev/null +++ b/tests/ut/cpp/cxx_api/types_test.cc @@ -0,0 +1,169 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "common/common_test.h" +#include "include/api/types.h" + +namespace mindspore { +class TestCxxApiTypes : public UT::Common { + public: + TestCxxApiTypes() = default; +}; + +TEST_F(TestCxxApiTypes, test_tensor_set_name_SUCCESS) { + std::string tensor_name_before = "TEST1"; + std::string tensor_name_after = "TEST2"; + api::Tensor tensor1(tensor_name_before, api::DataType::kMsFloat32, {}, nullptr, 0); + api::Tensor tensor2 = tensor1; + api::Tensor tensor3 = tensor1.Clone(); + + // name + ASSERT_EQ(tensor1.Name(), tensor_name_before); + ASSERT_EQ(tensor2.Name(), tensor_name_before); + ASSERT_EQ(tensor3.Name(), tensor_name_before); + + tensor1.SetName(tensor_name_after); + ASSERT_EQ(tensor1.Name(), tensor_name_after); + ASSERT_EQ(tensor2.Name(), tensor_name_after); + ASSERT_EQ(tensor3.Name(), tensor_name_before); +} + +TEST_F(TestCxxApiTypes, test_tensor_set_dtype_SUCCESS) { + api::Tensor tensor1("", api::DataType::kMsFloat32, {}, nullptr, 0); + api::Tensor tensor2 = tensor1; + api::Tensor tensor3 = tensor1.Clone(); + + // dtype + ASSERT_EQ(tensor1.DataType(), api::DataType::kMsFloat32); + ASSERT_EQ(tensor2.DataType(), api::DataType::kMsFloat32); + ASSERT_EQ(tensor3.DataType(), api::DataType::kMsFloat32); + + tensor1.SetDataType(api::DataType::kMsUint32); + ASSERT_EQ(tensor1.DataType(), api::DataType::kMsUint32); + ASSERT_EQ(tensor2.DataType(), api::DataType::kMsUint32); + ASSERT_EQ(tensor3.DataType(), api::DataType::kMsFloat32); +} + +TEST_F(TestCxxApiTypes, test_tensor_set_shape_SUCCESS) { + std::vector shape = {3, 4, 5, 6}; + api::Tensor tensor1("", api::DataType::kMsFloat32, {}, nullptr, 0); + api::Tensor tensor2 = tensor1; + api::Tensor tensor3 = tensor1.Clone(); + + // shape + ASSERT_EQ(tensor1.Shape(), std::vector()); + ASSERT_EQ(tensor2.Shape(), std::vector()); + ASSERT_EQ(tensor3.Shape(), std::vector()); + + tensor1.SetShape(shape); + ASSERT_EQ(tensor1.Shape(), shape); + ASSERT_EQ(tensor2.Shape(), shape); + ASSERT_EQ(tensor3.Shape(), std::vector()); +} + + +TEST_F(TestCxxApiTypes, test_tensor_util_SUCCESS) { + std::vector shape = {3, 4, 5, 6}; + std::vector data(3 * 4 * 5 * 6, 123); + api::Tensor tensor1("", api::DataType::kMsFloat32, shape, data.data(), data.size() * sizeof(uint32_t)); + + // data + ASSERT_EQ(api::Tensor::GetTypeSize(api::DataType::kMsUint32), sizeof(uint32_t)); + ASSERT_EQ(tensor1.ElementNum(), 3 * 4 * 5 * 6); +} + +TEST_F(TestCxxApiTypes, test_tensor_data_ref_and_copy_SUCCESS) { + std::vector shape = {3, 4, 5, 6}; + std::vector data(3 * 4 * 5 * 6, 123); + api::Tensor tensor1("", api::DataType::kMsFloat32, shape, data.data(), data.size() * sizeof(uint32_t)); + api::Tensor tensor2 = tensor1; + api::Tensor tensor3 = tensor1.Clone(); + + // data + ASSERT_EQ(tensor1.DataSize(), tensor2.DataSize()); + ASSERT_EQ(tensor1.DataSize(), tensor3.DataSize()); + ASSERT_EQ(tensor1.Data(), tensor2.MutableData()); + ASSERT_NE(tensor1.Data(), tensor3.Data()); +} + +TEST_F(TestCxxApiTypes, test_tensor_resize_data_SUCCESS) { + std::vector shape = {3, 4, 5, 6}; + std::vector data(3 * 4 * 5 * 6, 123); + api::Tensor tensor1("", api::DataType::kMsFloat32, shape, data.data(), data.size() * sizeof(uint32_t)); + + // data + ASSERT_EQ(tensor1.ResizeData(0), true); +} + +TEST_F(TestCxxApiTypes, test_tensor_set_data_wrong_data_size_FAILED) { + std::vector shape = {3, 4, 5, 6}; + std::vector data(3 * 4 * 5 * 6, 123); + api::Tensor tensor1("", api::DataType::kMsFloat32, shape, data.data(), data.size() * sizeof(uint32_t)); + + // data + ASSERT_EQ(tensor1.SetData(nullptr, 1), false); + ASSERT_EQ(tensor1.SetData(data.data(), 0), false); +} + +TEST_F(TestCxxApiTypes, test_tensor_set_data_SUCCESS) { + std::vector shape = {3, 4, 5, 6}; + std::vector data(3 * 4 * 5 * 6, 123); + api::Tensor tensor1("", api::DataType::kMsFloat32, shape, data.data(), data.size() * sizeof(uint32_t)); + + // data + ASSERT_EQ(tensor1.SetData(nullptr, 0), true); + ASSERT_EQ(tensor1.SetData(data.data(), data.size() * sizeof(uint32_t)), true); +} + +TEST_F(TestCxxApiTypes, test_buffer_data_ref_and_copy_SUCCESS) { + std::vector data(3 * 4 * 5 * 6, 123); + api::Buffer buffer1(data.data(), data.size() * sizeof(uint32_t)); + api::Buffer buffer2 = buffer1; + api::Buffer buffer3 = buffer1.Clone(); + + // data + ASSERT_EQ(buffer1.DataSize(), buffer2.DataSize()); + ASSERT_EQ(buffer1.DataSize(), buffer3.DataSize()); + ASSERT_EQ(buffer1.Data(), buffer2.MutableData()); + ASSERT_NE(buffer1.Data(), buffer3.Data()); +} + +TEST_F(TestCxxApiTypes, test_buffer_resize_data_SUCCESS) { + std::vector data(3 * 4 * 5 * 6, 123); + api::Buffer buffer1(data.data(), data.size() * sizeof(uint32_t)); + + // data + ASSERT_EQ(buffer1.ResizeData(0), true); +} + +TEST_F(TestCxxApiTypes, test_buffer_set_data_wrong_data_size_FAILED) { + std::vector data(3 * 4 * 5 * 6, 123); + api::Buffer buffer1(data.data(), data.size() * sizeof(uint32_t)); + + // data + ASSERT_EQ(buffer1.SetData(nullptr, 1), false); + ASSERT_EQ(buffer1.SetData(data.data(), 0), false); +} + +TEST_F(TestCxxApiTypes, test_buffer_set_data_SUCCESS) { + std::vector data(3 * 4 * 5 * 6, 123); + api::Buffer buffer1(data.data(), data.size() * sizeof(uint32_t)); + + // data + ASSERT_EQ(buffer1.SetData(nullptr, 0), true); + ASSERT_EQ(buffer1.SetData(data.data(), data.size() * sizeof(uint32_t)), true); +} +} // namespace mindspore