mindspore cxx api for 310 inference
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
fd16535017
commit
183742009f
|
@ -69,19 +69,9 @@ include_directories(${PYTHON_INCLUDE_DIRS})
|
|||
set(MS_CCSRC_PATH ${CMAKE_SOURCE_DIR}/mindspore/ccsrc)
|
||||
set(MS_CCSRC_BUILD_PATH ${BUILD_PATH}/mindspore/mindspore/ccsrc)
|
||||
|
||||
if (ENABLE_GE)
|
||||
link_directories(${CMAKE_SOURCE_DIR}/third_party/ge/lib)
|
||||
elseif(ENABLE_D OR ENABLE_TESTCASES)
|
||||
if (ENABLE_D OR ENABLE_ACL OR ENABLE_TESTCASES)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/dependency_graphengine.cmake)
|
||||
endif()
|
||||
|
||||
if (ENABLE_GE OR ENABLE_D OR ENABLE_TESTCASES)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc/external)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc/framework)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
|
||||
add_subdirectory(mindspore/ccsrc)
|
||||
|
|
10
build.sh
10
build.sh
|
@ -23,9 +23,9 @@ usage()
|
|||
{
|
||||
echo "Usage:"
|
||||
echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\"
|
||||
echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\"
|
||||
echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|ascend|cpu|acl] \\"
|
||||
echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I arm64|arm32|x86_64] [-K] \\"
|
||||
echo " [-B on|off] [-w on|off] [-E] [-l on|off] [-n full|lite|off] [-T on|off] \\"
|
||||
echo " [-B on|off] [-E] [-l on|off] [-n full|lite|off] [-T on|off] \\"
|
||||
echo " [-A [cpp|java|object-c] [-C on|off] [-o on|off] [-S on|off] [-k on|off] [-W sse|neon|avx|off] \\"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
|
@ -45,7 +45,7 @@ usage()
|
|||
echo " -i Enable increment building, default off"
|
||||
echo " -L Enable load ANF-IR as input of 'infer', default off"
|
||||
echo " -j[n] Set the threads when building (Default: -j8)"
|
||||
echo " -e Use gpu, d or cpu"
|
||||
echo " -e Use cpu, gpu, ascend or acl"
|
||||
echo " -P Enable dump anf graph to file in ProtoBuffer format, default on"
|
||||
echo " -D Enable dumping of function graph ir, default on"
|
||||
echo " -z Compile dataset & mindrecord, default on"
|
||||
|
@ -55,7 +55,6 @@ usage()
|
|||
echo " -I Enable compiling mindspore lite for arm64, arm32 or x86_64, default disable mindspore lite compilation"
|
||||
echo " -K Compile with AKG, default on"
|
||||
echo " -s Enable serving module, default off"
|
||||
echo " -w Enable acl module, default off"
|
||||
echo " -B Enable debugger, default on"
|
||||
echo " -E Enable IBVERBS for parameter server, default off"
|
||||
echo " -l Compile with python dependency, default on"
|
||||
|
@ -225,6 +224,9 @@ checkopts()
|
|||
ENABLE_D="on"
|
||||
ENABLE_CPU="on"
|
||||
ENABLE_SERVING="on"
|
||||
elif [[ "X$OPTARG" == "Xacl" ]]; then
|
||||
ENABLE_SERVING="on"
|
||||
ENABLE_ACL="on"
|
||||
elif [[ "X$OPTARG" == "Xcpu" ]]; then
|
||||
ENABLE_CPU="on"
|
||||
else
|
||||
|
|
|
@ -11,7 +11,7 @@ include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake)
|
|||
include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake)
|
||||
|
||||
# for UT, find slog and error_manager from local prebuild
|
||||
if (NOT ENABLE_D)
|
||||
if (NOT ENABLE_D AND NOT ENABLE_ACL)
|
||||
set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR})
|
||||
find_library(slog libslog.so ${GE_PREBUILD_PATH})
|
||||
find_library(error_manager liberror_manager.so ${GE_PREBUILD_PATH})
|
||||
|
@ -28,6 +28,7 @@ elseif (DEFINED ENV{D_LINK_PATH})
|
|||
message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
|
||||
endif()
|
||||
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
|
||||
find_library(c_sec libc_sec.so ${GE_LIB_PATH})
|
||||
find_library(slog libslog.so ${GE_LIB_PATH})
|
||||
find_library(mmpa libmmpa.so ${GE_LIB_PATH})
|
||||
find_library(runtime libruntime.so ${GE_LIB_PATH})
|
||||
|
@ -44,8 +45,8 @@ else()
|
|||
else()
|
||||
set(ASCEND_PATH /usr/local/Ascend)
|
||||
endif()
|
||||
set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common)
|
||||
set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64)
|
||||
set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common ${ASCEND_PATH}/driver/lib64)
|
||||
set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64 ${ASCEND_PATH}/acllib/lib64 ${ASCEND_PATH}/atc/lib64)
|
||||
find_library(c_sec libc_sec.so ${ASCEND_DRIVER_PATH})
|
||||
find_library(slog libslog.so ${ASCEND_DRIVER_PATH})
|
||||
find_library(mmpa libmmpa.so ${ASCEND_DRIVER_PATH})
|
||||
|
@ -76,9 +77,11 @@ string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
|||
# force __FILE__ to show relative path of file, from source directory
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst $(realpath ${CMAKE_SOURCE_DIR})/,,$(abspath $<))\"' -Wno-builtin-macro-redefined")
|
||||
add_subdirectory(${GE_SOURCE_DIR}/src/common/graph)
|
||||
if(ENABLE_D)
|
||||
if (ENABLE_ACL OR ENABLE_D)
|
||||
add_subdirectory(${GE_SOURCE_DIR}/src/ge/common)
|
||||
add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_runtime)
|
||||
endif()
|
||||
if (ENABLE_D)
|
||||
add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_runtime)
|
||||
endif ()
|
||||
endif ()
|
||||
|
||||
set(CMAKE_CXX_FLAGS ${_ge_tmp_CMAKE_CXX_FLAGS})
|
||||
|
|
|
@ -58,13 +58,22 @@ if (ENABLE_GE)
|
|||
include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include/external)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include/external/graph)
|
||||
elseif(ENABLE_D OR ENABLE_TESTCASES)
|
||||
link_directories(${CMAKE_SOURCE_DIR}/third_party/ge/lib)
|
||||
elseif(ENABLE_D OR ENABLE_ACL OR ENABLE_TESTCASES)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/ops)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/external)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/external/graph)
|
||||
endif()
|
||||
|
||||
if (ENABLE_GE OR ENABLE_D OR ENABLE_ACL OR ENABLE_TESTCASES)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/external)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/framework)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
|
||||
endif()
|
||||
|
||||
if (ENABLE_MINDDATA)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/icu4c.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/libtiff.cmake)
|
||||
|
|
|
@ -19,6 +19,7 @@ option(ENABLE_AKG "enable akg" OFF)
|
|||
option(ENABLE_DEBUGGER "enable debugger" OFF)
|
||||
option(ENABLE_IBVERBS "enable IBVERBS for parameter server" OFF)
|
||||
option(ENABLE_PYTHON "Enable python" ON)
|
||||
option(ENABLE_ACL "enable acl" OFF)
|
||||
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
if (WIN32)
|
||||
|
|
|
@ -58,6 +58,12 @@ install(
|
|||
COMPONENT mindspore
|
||||
)
|
||||
|
||||
install(
|
||||
TARGETS mindspore_shared_lib
|
||||
LIBRARY DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
||||
install(
|
||||
TARGETS mindspore_gvar
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
|
@ -194,7 +200,7 @@ if (ENABLE_SERVING OR ENABLE_TESTCASES)
|
|||
endif ()
|
||||
|
||||
if (NOT ENABLE_GE)
|
||||
if (ENABLE_D)
|
||||
if (ENABLE_D OR ENABLE_ACL)
|
||||
if (DEFINED ENV{ASCEND_CUSTOM_PATH})
|
||||
set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH})
|
||||
else ()
|
||||
|
@ -203,19 +209,26 @@ if (NOT ENABLE_GE)
|
|||
set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common)
|
||||
|
||||
install(
|
||||
FILES
|
||||
${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so
|
||||
${CMAKE_BINARY_DIR}/graphengine/src/ge/common/libge_common.so
|
||||
${CMAKE_BINARY_DIR}/graphengine/src/ge/ge_runtime/libge_runtime.so
|
||||
${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
install(
|
||||
TARGETS ms_profile
|
||||
FILES ${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
||||
if (ENABLE_D)
|
||||
install(
|
||||
TARGETS ms_profile
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
install(
|
||||
FILES
|
||||
${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so
|
||||
${CMAKE_BINARY_DIR}/graphengine/src/ge/common/libge_common.so
|
||||
${CMAKE_BINARY_DIR}/graphengine/src/ge/ge_runtime/libge_runtime.so
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
endif ()
|
||||
elseif (ENABLE_TESTCASES)
|
||||
install(
|
||||
FILES
|
||||
|
@ -287,6 +300,13 @@ if (EXISTS ${CMAKE_SOURCE_DIR}/mindspore/dataset)
|
|||
)
|
||||
endif ()
|
||||
|
||||
## Public header files
|
||||
install(
|
||||
DIRECTORY ${CMAKE_SOURCE_DIR}/include
|
||||
DESTINATION ${INSTALL_BASE_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
||||
if (ENABLE_SERVING)
|
||||
install(
|
||||
TARGETS ms_serving
|
||||
|
@ -308,8 +328,8 @@ if (ENABLE_SERVING)
|
|||
)
|
||||
|
||||
install(
|
||||
FILES ${LIBEVENT_LIB_LIST}
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
FILES ${LIBEVENT_LIB_LIST}
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
endif ()
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 423c0228e8c421f2b095e40d14e9fb3b563f63aa
|
||||
Subproject commit 42d217fb8cec74b1c73685b8abe94d5f1520e9fe
|
|
@ -0,0 +1,113 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_CELL_H
|
||||
#define MINDSPORE_INCLUDE_API_CELL_H
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
class InputAndOutput;
|
||||
using Input = InputAndOutput;
|
||||
using Output = InputAndOutput;
|
||||
|
||||
class MS_API CellBase {
|
||||
public:
|
||||
CellBase() = default;
|
||||
virtual ~CellBase() = default;
|
||||
virtual std::vector<Output> Construct(const std::vector<Input> &inputs) { return {}; }
|
||||
virtual std::shared_ptr<CellBase> Clone() const = 0;
|
||||
std::vector<Output> operator()(const std::vector<Input> &inputs) const;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class MS_API Cell : public CellBase {
|
||||
public:
|
||||
virtual ~Cell() = default;
|
||||
std::shared_ptr<CellBase> Clone() const override {
|
||||
return std::make_shared<T>(static_cast<const T&>(*this));
|
||||
}
|
||||
};
|
||||
|
||||
class MS_API ParameterCell final : public Cell<ParameterCell> {
|
||||
public:
|
||||
ParameterCell() = default;
|
||||
~ParameterCell() override = default;
|
||||
|
||||
ParameterCell(const ParameterCell &);
|
||||
ParameterCell &operator=(const ParameterCell &);
|
||||
|
||||
ParameterCell(ParameterCell &&);
|
||||
ParameterCell &operator=(ParameterCell &&);
|
||||
|
||||
explicit ParameterCell(const Tensor &);
|
||||
ParameterCell &operator=(const Tensor &);
|
||||
|
||||
explicit ParameterCell(Tensor &&);
|
||||
ParameterCell &operator=(Tensor &&);
|
||||
|
||||
Tensor GetTensor() const { return tensor_; }
|
||||
|
||||
private:
|
||||
Tensor tensor_;
|
||||
};
|
||||
|
||||
class MS_API OpCellBase : public CellBase {
|
||||
public:
|
||||
explicit OpCellBase(const std::string &name) : name_(name) {}
|
||||
~OpCellBase() override = default;
|
||||
const std::string &GetOpType() const { return name_; }
|
||||
|
||||
protected:
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class MS_API OpCell : public OpCellBase, public std::enable_shared_from_this<T> {
|
||||
public:
|
||||
explicit OpCell(const std::string &name) : OpCellBase(name) {}
|
||||
~OpCell() override = default;
|
||||
std::shared_ptr<CellBase> Clone() const override {
|
||||
return std::make_shared<T>(static_cast<const T&>(*this));
|
||||
}
|
||||
};
|
||||
|
||||
class MS_API InputAndOutput {
|
||||
public:
|
||||
InputAndOutput();
|
||||
~InputAndOutput() = default;
|
||||
|
||||
// no explicit
|
||||
InputAndOutput(const Tensor &); // NOLINT(runtime/explicit)
|
||||
InputAndOutput(Tensor &&); // NOLINT(runtime/explicit)
|
||||
|
||||
InputAndOutput(const std::shared_ptr<CellBase> &, const std::vector<InputAndOutput> &, int32_t index);
|
||||
|
||||
int32_t GetIndex() const { return index_; }
|
||||
void SetIndex(int32_t index) { index_ = index; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<CellBase> cell_;
|
||||
std::vector<InputAndOutput> prev_;
|
||||
int32_t index_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_CELL_H
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_MODEL_H
|
||||
#define MINDSPORE_INCLUDE_API_MODEL_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
class ModelImpl;
|
||||
// todo: minddata c++ interface
|
||||
class DataSet {};
|
||||
class NetWork {};
|
||||
|
||||
class MS_API Model {
|
||||
public:
|
||||
Model(const std::string &device_type, uint32_t device_id);
|
||||
Model(NetWork network, const std::string &device_type, uint32_t device_id);
|
||||
~Model();
|
||||
Model(const Model &) = delete;
|
||||
void operator=(const Model &) = delete;
|
||||
|
||||
Status LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options);
|
||||
Status LoadModel(const std::string &file_name, ModelType type, const std::map<std::string, std::string> &options);
|
||||
Status UnloadModel();
|
||||
|
||||
Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs);
|
||||
Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs);
|
||||
Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs);
|
||||
Status Predict(const std::vector<Buffer> &inputs, std::map<std::string, Buffer> *outputs);
|
||||
|
||||
Status GetInputsInfo(std::vector<Tensor> *tensor_list) const;
|
||||
Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<ModelImpl> impl_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_MODEL_H
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_OPS_OPS_H
|
||||
#define MINDSPORE_INCLUDE_API_OPS_OPS_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/cell.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
struct MS_API Conv2D : public OpCell<Conv2D> {
|
||||
Conv2D() : OpCell("Conv2D") {}
|
||||
~Conv2D() override = default;
|
||||
std::vector<Output> Construct(const std::vector<Input> &inputs) override;
|
||||
Conv2D(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid",
|
||||
const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1},
|
||||
const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1);
|
||||
|
||||
Output operator()(const Input &, const Input &) const;
|
||||
|
||||
int out_channel;
|
||||
std::vector<int> kernel_size;
|
||||
int mode = 1;
|
||||
std::string pad_mode = "valid";
|
||||
std::vector<int> pad = {0, 0, 0, 0};
|
||||
std::vector<int> stride = {1, 1, 1, 1};
|
||||
std::vector<int> dilation = {1, 1, 1, 1};
|
||||
int group = 1;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_OPS_OPS_H
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
||||
#define MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/model.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
class MS_API Serialization {
|
||||
public:
|
||||
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
|
||||
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_STATUS_H
|
||||
#define MINDSPORE_INCLUDE_API_STATUS_H
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
enum StatusCode {
|
||||
SUCCESS = 0,
|
||||
FAILED,
|
||||
INVALID_INPUTS,
|
||||
// insert new status code here
|
||||
UNKNOWN = 0xFFFFFFFF
|
||||
};
|
||||
|
||||
class Status {
|
||||
public:
|
||||
Status() : status_code_(FAILED) {}
|
||||
Status(enum StatusCode status_code, const std::string &status_msg = "") // NOLINT(runtime/explicit)
|
||||
: status_code_(status_code), status_msg_(status_msg) {}
|
||||
~Status() = default;
|
||||
|
||||
bool IsSuccess() const { return status_code_ == SUCCESS; }
|
||||
enum StatusCode StatusCode() const { return status_code_; }
|
||||
std::string StatusMessage() const { return status_msg_; }
|
||||
bool operator==(const Status &other) const { return status_code_ == other.status_code_; }
|
||||
bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; }
|
||||
bool operator!=(const Status &other) const { return status_code_ != other.status_code_; }
|
||||
bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; }
|
||||
operator bool() const = delete;
|
||||
|
||||
private:
|
||||
enum StatusCode status_code_;
|
||||
std::string status_msg_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_STATUS_H
|
|
@ -0,0 +1,119 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_TYPES_H
|
||||
#define MINDSPORE_INCLUDE_API_TYPES_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#define MS_API __attribute__((visibility("default")))
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
enum ModelType {
|
||||
kMindIR = 0,
|
||||
kAIR = 1,
|
||||
kOM = 2,
|
||||
kONNX = 3,
|
||||
// insert new data type here
|
||||
kUnknownType = 0xFFFFFFFF
|
||||
};
|
||||
|
||||
enum DataType {
|
||||
kMsUnknown = 0,
|
||||
kMsBool = 1,
|
||||
kMsInt8 = 2,
|
||||
kMsInt16 = 3,
|
||||
kMsInt32 = 4,
|
||||
kMsInt64 = 5,
|
||||
kMsUint8 = 6,
|
||||
kMsUint16 = 7,
|
||||
kMsUint32 = 8,
|
||||
kMsUint64 = 9,
|
||||
kMsFloat16 = 10,
|
||||
kMsFloat32 = 11,
|
||||
kMsFloat64 = 12,
|
||||
// insert new data type here
|
||||
kInvalidDataType = 0xFFFFFFFF
|
||||
};
|
||||
|
||||
class MS_API Tensor {
|
||||
public:
|
||||
Tensor();
|
||||
Tensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data, size_t data_len);
|
||||
~Tensor();
|
||||
|
||||
const std::string &Name() const;
|
||||
void SetName(const std::string &name);
|
||||
|
||||
api::DataType DataType() const;
|
||||
void SetDataType(api::DataType type);
|
||||
|
||||
const std::vector<int64_t> &Shape() const;
|
||||
void SetShape(const std::vector<int64_t> &shape);
|
||||
|
||||
const void *Data() const;
|
||||
void *MutableData();
|
||||
size_t DataSize() const;
|
||||
|
||||
bool ResizeData(size_t data_len);
|
||||
bool SetData(const void *data, size_t data_len);
|
||||
|
||||
int64_t ElementNum() const;
|
||||
static int GetTypeSize(api::DataType type);
|
||||
Tensor Clone() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
class MS_API Buffer {
|
||||
public:
|
||||
Buffer();
|
||||
Buffer(const void *data, size_t data_len);
|
||||
~Buffer();
|
||||
|
||||
const void *Data() const;
|
||||
void *MutableData();
|
||||
size_t DataSize() const;
|
||||
|
||||
bool ResizeData(size_t data_len);
|
||||
bool SetData(const void *data, size_t data_len);
|
||||
|
||||
Buffer Clone() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path";
|
||||
constexpr auto kModelOptionDvppCfgPath = "mindspore.option.dvpp_config_file_path";
|
||||
constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file
|
||||
constexpr auto kModelOptionInputFormat = "mindspore.option.input_format"; // nchw or nhwc
|
||||
// Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"
|
||||
constexpr auto kModelOptionInputShape = "mindspore.option.input_shape";
|
||||
constexpr auto kModelOptionDynamicBatchSize = "mindspore.option.dynamic_batch_size";
|
||||
constexpr auto kModelOptionDynamicImageSize = "mindspore.option.dynamic_image_size";
|
||||
constexpr auto kModelOptionDynamicDims = "mindspore.option.dynamic_dims";
|
||||
constexpr auto kModelOptionSerialInput = "mindspore.option.serial_inputs_name"; // separated by ';'
|
||||
constexpr auto kModelOptionOutputNode = "mindspore.option.output_node"; // e.g. "node_name1:0;node_name2:1"
|
||||
constexpr auto kModelOptionOutputType = "mindspore.option.output_type"; // "FP32", "UINT8" or "FP16", default as "FP32"
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_TYPES_H
|
|
@ -295,6 +295,9 @@ else ()
|
|||
target_link_libraries(mindspore ibverbs rdmacm)
|
||||
endif()
|
||||
endif()
|
||||
if (ENABLE_ACL)
|
||||
target_link_libraries(_c_expression PRIVATE graph)
|
||||
endif ()
|
||||
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore proto_input -Wl,--no-whole-archive)
|
||||
target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module)
|
||||
target_link_libraries(_c_expression PRIVATE mindspore_gvar)
|
||||
|
@ -359,3 +362,5 @@ if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
|||
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||
set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
|
||||
endif ()
|
||||
|
||||
add_subdirectory(cxx_api)
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
# build mindspore_shared_lib
|
||||
set(LOAD_ONNX_SRC
|
||||
${CMAKE_SOURCE_DIR}/mindspore/ccsrc/utils/load_onnx/anf_converter.cc
|
||||
${CMAKE_SOURCE_DIR}/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc
|
||||
)
|
||||
file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc")
|
||||
|
||||
if (ENABLE_ACL)
|
||||
file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc")
|
||||
endif ()
|
||||
|
||||
set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cell.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc
|
||||
${API_ACL_SRC}
|
||||
${API_OPS_SRC}
|
||||
${LOAD_ONNX_SRC})
|
||||
|
||||
add_library(mindspore_shared_lib SHARED ${MSLIB_SRC})
|
||||
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}")
|
||||
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
|
||||
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf)
|
||||
|
||||
if (ENABLE_CPU)
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE mindspore::dnnl mindspore::mkldnn)
|
||||
endif ()
|
||||
|
||||
if (USE_GLOG)
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE mindspore::glog)
|
||||
endif ()
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
target_link_options(mindspore_shared_lib PRIVATE -Wl,-init,common_log_init)
|
||||
endif ()
|
||||
|
||||
if (ENABLE_ACL)
|
||||
if (DEFINED ENV{ASCEND_CUSTOM_PATH})
|
||||
set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH})
|
||||
else ()
|
||||
set(ASCEND_PATH /usr/local/Ascend)
|
||||
endif ()
|
||||
set(ACL_LIB_DIR ${ASCEND_PATH}/acllib/)
|
||||
set(ATLAS_ACL_LIB_DIR ${ASCEND_PATH}/ascend-toolkit/latest/acllib)
|
||||
set(ATC_DIR ${ASCEND_PATH}/atc/)
|
||||
set(ATLAS_ATC_DIR ${ASCEND_PATH}/ascend-toolkit/latest/atc)
|
||||
MESSAGE("acl lib dir " ${ACL_LIB_DIR} ", atc dir " ${ATC_DIR})
|
||||
MESSAGE("atlas acl lib dir " ${ATLAS_ACL_LIB_DIR} ", atc dir " ${ATLAS_ATC_DIR})
|
||||
|
||||
include_directories(${ACL_LIB_DIR}/include/)
|
||||
include_directories(${ATLAS_ACL_LIB_DIR}/include/)
|
||||
add_compile_definitions(ENABLE_DVPP_INTERFACE)
|
||||
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
||||
find_library(acl_retr libacl_retr.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
||||
find_library(acl_cblas libacl_cblas.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
||||
find_library(acl_dvpp libacl_dvpp.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
||||
find_library(acl_runtime libruntime.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
||||
find_library(ge_compiler libge_compiler.so ${ATC_DIR}/lib64 ${ATLAS_ATC_DIR}/lib64)
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE ${acl} ${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime}
|
||||
${ge_compiler} mindspore::jpeg_turbo)
|
||||
endif ()
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/cell.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }
|
||||
|
||||
ParameterCell::ParameterCell(const ParameterCell &cell) : tensor_(cell.tensor_.Clone()) {}
|
||||
ParameterCell &ParameterCell::operator=(const ParameterCell &cell) {
|
||||
if (&cell == this) {
|
||||
return *this;
|
||||
}
|
||||
tensor_ = cell.tensor_.Clone();
|
||||
return *this;
|
||||
}
|
||||
|
||||
ParameterCell::ParameterCell(ParameterCell &&cell) : tensor_(cell.tensor_) {}
|
||||
|
||||
ParameterCell &ParameterCell::operator=(ParameterCell &&cell) {
|
||||
if (&cell == this) {
|
||||
return *this;
|
||||
}
|
||||
tensor_ = cell.tensor_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
ParameterCell::ParameterCell(const Tensor &tensor) : tensor_(tensor.Clone()) {}
|
||||
|
||||
ParameterCell &ParameterCell::operator=(const Tensor &tensor) {
|
||||
tensor_ = tensor.Clone();
|
||||
return *this;
|
||||
}
|
||||
|
||||
ParameterCell::ParameterCell(Tensor &&tensor) : tensor_(tensor) {}
|
||||
|
||||
ParameterCell &ParameterCell::operator=(Tensor &&tensor) {
|
||||
tensor_ = tensor;
|
||||
return *this;
|
||||
}
|
||||
|
||||
InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}
|
||||
|
||||
InputAndOutput::InputAndOutput(const Tensor &tensor)
|
||||
: cell_(std::make_shared<ParameterCell>(tensor.Clone())), prev_(), index_(-1) {}
|
||||
InputAndOutput::InputAndOutput(Tensor &&tensor) : cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}
|
||||
|
||||
InputAndOutput::InputAndOutput(const std::shared_ptr<CellBase> &cell, const std::vector<InputAndOutput> &prev,
|
||||
int32_t index)
|
||||
: cell_(cell), prev_(prev), index_(index) {}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,284 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cxx_api/model/acl/acl_model.h"
|
||||
#include <memory>
|
||||
#include "utils/context/context_extends.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
// Process-wide ACL environment: weak handle shared by all AclModel instances,
// plus the mutex that serializes first-time initialization.
std::weak_ptr<AclModel::AclEnvGuard> AclModel::global_acl_env_;
std::mutex AclModel::global_acl_env_mutex_;
|
||||
|
||||
// Initialize the per-instance ACL runtime: shared aclInit guard, device,
// context, stream, run-mode probe, DVPP resources and op registration.
// Idempotent: returns SUCCESS immediately if already initialized.
Status AclModel::InitEnv() {
  if (init_flag_) {
    return SUCCESS;
  }

  MS_EXCEPTION_IF_NULL(options_);
  aclError ret;
  {
    // aclInit may run only once per process; share the guard through a
    // weak_ptr so ACL is finalized when the last model releases it.
    std::lock_guard<std::mutex> lock(global_acl_env_mutex_);
    acl_env_ = global_acl_env_.lock();
    if (acl_env_ != nullptr) {
      if (options_->dump_cfg_path.empty()) {
        MS_LOG(INFO) << "Acl has been initialized, skip.";
      } else {
        // aclInit already ran with some other (or no) dump config; ours
        // can no longer be applied.
        MS_LOG(WARNING) << "Acl has been initialized, skip, so dump config will be ignored.";
      }
    } else {
      acl_env_ = std::make_shared<AclEnvGuard>(options_->dump_cfg_path);
      if (acl_env_->GetErrno() != ACL_ERROR_NONE) {
        MS_LOG(ERROR) << "Execute aclInit Failed";
        return FAILED;
      }
      global_acl_env_ = acl_env_;
      MS_LOG(INFO) << "Acl init success";
    }
  }

  // NOTE(review): the failure paths below do not release earlier
  // acquisitions (device/context are leaked if a later step fails) —
  // consider unwinding or calling FinalizeEnv() on error.
  ret = aclrtSetDevice(device_id_);
  if (ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed";
    return FAILED;
  }
  MS_LOG(INFO) << "Open device " << device_id_ << " success";

  ret = aclrtCreateContext(&context_, device_id_);
  if (ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Acl create context failed";
    return FAILED;
  }
  MS_LOG(INFO) << "Create context success";

  ret = aclrtSetCurrentContext(context_);
  if (ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Acl set current context failed";
    return FAILED;
  }
  MS_LOG(INFO) << "Set context success";

  ret = aclrtCreateStream(&stream_);
  if (ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Acl create stream failed";
    return FAILED;
  }
  MS_LOG(INFO) << "Create stream success";

  // Detect whether we run on device or host; ModelProcess uses this to
  // decide whether input/output buffers need host<->device copies.
  aclrtRunMode run_mode;
  ret = aclrtGetRunMode(&run_mode);
  if (ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Acl get run mode failed";
    return FAILED;
  }
  bool is_device = (run_mode == ACL_DEVICE);
  model_process_.SetIsDevice(is_device);
  MS_LOG(INFO) << "Get run mode success is device input/output " << is_device;

  if (dvpp_process_.InitResource(stream_) != SUCCESS) {
    MS_LOG(ERROR) << "DVPP init resource failed";
    return FAILED;
  }
  ModelConverter::RegAllOp();

  MS_LOG(INFO) << "Init acl success, device id " << device_id_;
  init_flag_ = true;
  return SUCCESS;
}
|
||||
|
||||
// Tear down the per-instance resources acquired by InitEnv (DVPP, stream,
// context, device). Destruction errors are logged but not propagated, so this
// always returns SUCCESS once it runs. No-op if not initialized.
// NOTE(review): aclFinalize itself happens later, when the last shared
// AclEnvGuard is destroyed.
Status AclModel::FinalizeEnv() {
  if (!init_flag_) {
    return SUCCESS;
  }

  dvpp_process_.Finalize();
  aclError ret;
  if (stream_ != nullptr) {
    ret = aclrtDestroyStream(stream_);
    if (ret != ACL_ERROR_NONE) {
      MS_LOG(ERROR) << "Destroy stream failed";
    }
    stream_ = nullptr;
  }
  MS_LOG(INFO) << "End to destroy stream";
  if (context_ != nullptr) {
    ret = aclrtDestroyContext(context_);
    if (ret != ACL_ERROR_NONE) {
      MS_LOG(ERROR) << "Destroy context failed";
    }
    context_ = nullptr;
  }
  MS_LOG(INFO) << "End to destroy context";

  ret = aclrtResetDevice(device_id_);
  if (ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Reset devie " << device_id_ << " failed";
  }
  MS_LOG(INFO) << "End to reset device " << device_id_;

  init_flag_ = false;
  return SUCCESS;
}
|
||||
|
||||
// Load a model held in memory. kMindIR / kAIR inputs are first compiled into
// an offline OM image through ModelConverter; kOM is loaded as-is.
// Initializes the ACL environment on first use. Fails if a model is already
// loaded on this instance.
Status AclModel::LoadModel(const Buffer &model_data, ModelType type,
                           const std::map<std::string, std::string> &options) {
  if (load_flag_) {
    MS_LOG(ERROR) << "Model has been loaded.";
    return FAILED;
  }

  options_ = std::make_unique<AclModelOptions>(options);
  MS_EXCEPTION_IF_NULL(options_);

  Status ret = InitEnv();
  if (ret != SUCCESS) {
    MS_LOG(ERROR) << "InitEnv failed.";
    return FAILED;
  }

  // Normalize the input into an OM image that aclmdlLoadFromMem accepts.
  Buffer om_data;
  if (type == ModelType::kMindIR) {
    model_converter_.set_options(options_.get());
    om_data = model_converter_.LoadMindIR(model_data);
  } else if (type == ModelType::kAIR) {
    model_converter_.set_options(options_.get());
    om_data = model_converter_.LoadAscendIR(model_data);
  } else if (type == ModelType::kOM) {
    om_data = model_data;
  } else {
    MS_LOG(ERROR) << "Unsupported model type " << type;
    return FAILED;
  }

  // acl load model
  uint32_t acl_model_id;
  auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id);
  if (acl_ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed.";
    return FAILED;
  }

  // acl init model resource
  model_process_.set_model_id(acl_model_id);
  ret = model_process_.PreInitModelResource();
  if (ret != SUCCESS) {
    (void)aclmdlUnload(acl_model_id);
    MS_LOG(ERROR) << "Pre init model resource failed.";
    return FAILED;
  }

  // acl init dvpp
  ret = dvpp_process_.InitWithJsonConfig(options_->dvpp_cfg_path);
  if (ret != SUCCESS) {
    // fix: release the already-loaded acl model on this failure path too;
    // previously it leaked because only load_flag_ stayed false.
    (void)aclmdlUnload(acl_model_id);
    MS_LOG(ERROR) << "DVPP config file parse error.";
    return FAILED;
  }

  load_flag_ = true;
  return SUCCESS;
}
|
||||
|
||||
// Convenience overload: read the model file into memory, then delegate to the
// in-memory LoadModel. Fails when the file is missing or empty.
Status AclModel::LoadModel(const std::string &file_name, ModelType type,
                           const std::map<std::string, std::string> &options) {
  const Buffer file_contents = ModelConverter::ReadFile(file_name);
  if (file_contents.DataSize() != 0) {
    return LoadModel(file_contents, type, options);
  }
  MS_LOG(ERROR) << "Read file " << file_name << " failed.";
  return FAILED;
}
|
||||
|
||||
// Release the loaded model and the per-instance ACL resources. A warning (not
// an error) is emitted when nothing is loaded, matching repeated-unload use.
Status AclModel::UnloadModel() {
  if (!load_flag_) {
    MS_LOG(WARNING) << "No model is loaded, skip unload.";
    return SUCCESS;
  }

  // Bind this thread to the model's context before touching ACL state.
  aclError rt_ret = aclrtSetCurrentContext(context_);
  if (rt_ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Set the ascend device context failed";
    return FAILED;
  }

  Status ret = model_process_.UnLoad();
  if (ret != SUCCESS) {
    MS_LOG(ERROR) << "Unload model inner failed.";
    return FAILED;
  }

  ret = FinalizeEnv();
  if (ret != SUCCESS) {
    // NOTE(review): at this point the model is already unloaded but
    // load_flag_ remains true, so a retry would call UnLoad() again on an
    // unloaded model — confirm this is the intended recovery behavior.
    MS_LOG(ERROR) << "FinalizeEnv failed.";
    return FAILED;
  }

  MS_LOG(INFO) << "Unload model success.";
  load_flag_ = false;
  return SUCCESS;
}
|
||||
|
||||
// Training is not available on the ACL (Ascend 310 inference) backend.
Status AclModel::Train(const DataSet &, std::map<std::string, Buffer> *) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return FAILED;
}

// Evaluation is not available on the ACL (Ascend 310 inference) backend.
Status AclModel::Eval(const DataSet &, std::map<std::string, Buffer> *) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return FAILED;
}
|
||||
|
||||
// Run inference on the loaded model. `outputs` must be non-null; actual
// execution is delegated to ModelProcess after binding the ACL context to
// the calling thread.
Status AclModel::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) {
  MS_EXCEPTION_IF_NULL(outputs);
  if (!load_flag_) {
    MS_LOG(ERROR) << "No model is loaded, predict failed.";
    return FAILED;
  }

  aclError rt_ret = aclrtSetCurrentContext(context_);
  if (rt_ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Set the ascend device context failed";
    return FAILED;
  }
  return model_process_.Predict(inputs, outputs);
}
|
||||
|
||||
// Describe the loaded model's input tensors (delegated to ModelProcess).
Status AclModel::GetInputsInfo(std::vector<Tensor> *tensor_list) const {
  MS_EXCEPTION_IF_NULL(tensor_list);
  return model_process_.GetInputsInfo(tensor_list);
}

// Describe the loaded model's output tensors (delegated to ModelProcess).
Status AclModel::GetOutputsInfo(std::vector<Tensor> *tensor_list) const {
  MS_EXCEPTION_IF_NULL(tensor_list);
  return model_process_.GetOutputsInfo(tensor_list);
}
|
||||
|
||||
// Call aclInit with the given dump config path (may be empty). The constructor
// does not throw on failure; callers must check GetErrno() after construction.
AclModel::AclEnvGuard::AclEnvGuard(const std::string &cfg_file) {
  errno_ = aclInit(common::SafeCStr(cfg_file));
  if (errno_ != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Execute aclInit Failed";
    return;
  }
  MS_LOG(INFO) << "Acl init success";
}
|
||||
|
||||
// Finalize ACL when the last shared guard is destroyed.
AclModel::AclEnvGuard::~AclEnvGuard() {
  errno_ = aclFinalize();
  if (errno_ != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Finalize acl failed";
    // fix: previously fell through and also logged "Acl finalize success"
    // after reporting the failure.
    return;
  }
  MS_LOG(INFO) << "Acl finalize success";
}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,99 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "include/api/status.h"
|
||||
#include "cxx_api/model/model_impl.h"
|
||||
#include "cxx_api/model/acl/dvpp_process.h"
|
||||
#include "cxx_api/model/acl/model_process.h"
|
||||
#include "cxx_api/model/acl/model_converter.h"
|
||||
#include "cxx_api/model/acl/acl_model_options.h"
|
||||
#include "ir/tensor.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
// ModelImpl backed by Ascend CL (ACL) for 310 offline inference.
// Lifecycle: LoadModel -> Predict/Get*Info -> UnloadModel. Training and
// evaluation are unsupported (see the .cc stubs).
class AclModel : public ModelImpl {
 public:
  // NOTE(review): ctor takes uint32_t but device_id_ is stored as int32_t —
  // very large ids would wrap; confirm the intended range.
  explicit AclModel(uint32_t device_id)
      : init_flag_(false),
        load_flag_(false),
        device_type_("AscendCL"),
        device_id_(device_id),
        context_(nullptr),
        stream_(nullptr),
        acl_env_(nullptr),
        model_process_(),
        dvpp_process_(),
        model_converter_(),
        options_(nullptr) {}
  ~AclModel() = default;

  // Load from memory or file; `options` is parsed into AclModelOptions.
  Status LoadModel(const Buffer &model_data, ModelType type,
                   const std::map<std::string, std::string> &options) override;
  Status LoadModel(const std::string &file_name, ModelType type,
                   const std::map<std::string, std::string> &options) override;
  Status UnloadModel() override;

  // Unsupported on this backend; always fail.
  Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
  Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
  Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) override;

  Status GetInputsInfo(std::vector<Tensor> *tensor_list) const override;
  Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const override;

 private:
  bool init_flag_;  // ACL env/device/context/stream initialized (InitEnv)
  bool load_flag_;  // a model is currently loaded (LoadModel/UnloadModel)
  std::string device_type_;
  int32_t device_id_;
  aclrtContext context_;
  aclrtStream stream_;

  // Process-wide aclInit/aclFinalize guard, shared across AclModel instances
  // through the static weak_ptr below.
  class AclEnvGuard;
  std::shared_ptr<AclEnvGuard> acl_env_;
  static std::weak_ptr<AclEnvGuard> global_acl_env_;
  static std::mutex global_acl_env_mutex_;

  ModelProcess model_process_;      // aclmdl load/execute wrapper
  DvppProcess dvpp_process_;        // optional DVPP preprocessing pipeline
  ModelConverter model_converter_;  // MindIR/AIR -> OM compilation
  std::unique_ptr<AclModelOptions> options_;

  Status InitEnv();
  Status FinalizeEnv();
};
|
||||
|
||||
// RAII wrapper around aclInit/aclFinalize. Construction runs aclInit with the
// given config file; GetErrno() exposes the result since ctors cannot fail.
class AclModel::AclEnvGuard {
 public:
  explicit AclEnvGuard(const std::string &cfg_file);
  ~AclEnvGuard();
  aclError GetErrno() const { return errno_; }

 private:
  aclError errno_;  // last aclInit/aclFinalize return code
};

// Register this implementation under the "AscendCL" device type.
API_REG_MODEL(AscendCL, AclModel);
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cxx_api/model/acl/acl_model_options.h"
|
||||
#include <memory>
|
||||
#include "external/ge/ge_api_types.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
// Look up `key` in `options`; return its value, or an empty string when the
// key is absent.
static std::string ParseOption(const std::map<std::string, std::string> &options, const std::string &key) {
  const auto it = options.find(key);
  return (it == options.end()) ? std::string() : it->second;
}
|
||||
|
||||
// Populate option fields from the generic string map handed to LoadModel.
// Missing keys leave the corresponding field empty (see ParseOption).
AclModelOptions::AclModelOptions(const std::map<std::string, std::string> &options) {
  dump_cfg_path = ParseOption(options, kModelOptionDumpCfgPath);
  dvpp_cfg_path = ParseOption(options, kModelOptionDvppCfgPath);
  output_node = ParseOption(options, kModelOptionOutputNode);
  // to acl
  insert_op_cfg_path = ParseOption(options, kModelOptionInsertOpCfgPath);
  input_format = ParseOption(options, kModelOptionInputFormat);
  input_shape = ParseOption(options, kModelOptionInputShape);
  // NOTE(review): the three assignments below all read kModelOptionInputShape.
  // This looks like a copy-paste bug — dynamic_batch_size, dynamic_image_size
  // and dynamic_dims presumably have their own option-key constants; verify
  // against the key declarations and fix.
  dynamic_batch_size = ParseOption(options, kModelOptionInputShape);
  dynamic_image_size = ParseOption(options, kModelOptionInputShape);
  dynamic_dims = ParseOption(options, kModelOptionInputShape);
  serial_nodes_name = ParseOption(options, kModelOptionSerialInput);
  output_type = ParseOption(options, kModelOptionOutputType);
}
|
||||
|
||||
std::map<std::string, std::string> AclModelOptions::GenAclOptions() const {
|
||||
const std::map<std::string const *, std::string> acl_options_map = {
|
||||
{&insert_op_cfg_path, ge::ir_option::INSERT_OP_FILE},
|
||||
{&input_format, ge::ir_option::INPUT_FORMAT},
|
||||
{&input_shape, ge::ir_option::INPUT_SHAPE},
|
||||
{&dynamic_batch_size, ge::ir_option::DYNAMIC_BATCH_SIZE},
|
||||
{&dynamic_image_size, ge::ir_option::DYNAMIC_IMAGE_SIZE},
|
||||
{&dynamic_dims, ge::ir_option::DYNAMIC_DIMS},
|
||||
{&serial_nodes_name, ge::ir_option::INPUT_FP16_NODES},
|
||||
{&output_type, ge::ir_option::OUTPUT_TYPE},
|
||||
};
|
||||
|
||||
std::map<std::string, std::string> acl_options;
|
||||
for (auto [ms_option, acl_option_key] : acl_options_map) {
|
||||
if (ms_option == nullptr || ms_option->empty()) {
|
||||
continue;
|
||||
}
|
||||
acl_options.emplace(acl_option_key, *ms_option);
|
||||
}
|
||||
return acl_options;
|
||||
}
|
||||
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_OPTION_PARSER_H
|
||||
#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_OPTION_PARSER_H
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/status.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
// Options parsed from the generic string map passed to Model::LoadModel.
// All fields are strings; empty means "not provided".
struct AclModelOptions {
  std::string dump_cfg_path;  // aclInit dump config (only honored on first init)
  std::string dvpp_cfg_path;  // DVPP preprocessing json config
  std::string output_node;  // todo: at convert.cc::BuildGraph(), no atc options
  // build options forwarded to the ACL graph builder — see GenAclOptions().
  std::string insert_op_cfg_path;
  std::string input_format;
  std::string input_shape;
  std::string dynamic_batch_size;
  std::string dynamic_image_size;
  std::string dynamic_dims;
  std::string serial_nodes_name;
  std::string output_type;

  explicit AclModelOptions(const std::map<std::string, std::string> &options);
  ~AclModelOptions() = default;

  // Translate the non-empty build fields into GE/ATC option keys.
  std::map<std::string, std::string> GenAclOptions() const;
};
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_OPTION_PARSER_H
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,160 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H
|
||||
#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "acl/acl.h"
|
||||
#include "acl/acl_mdl.h"
|
||||
#include "acl/acl_rt.h"
|
||||
#include "acl/ops/acl_dvpp.h"
|
||||
#include "include/api/status.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
// Output pixel format for the JPEG decode stage.
struct DvppDecodePara {
  acldvppPixelFormat pixel_format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
};

// Target size for the VPC resize stage.
struct DvppResizePara {
  uint32_t output_width = 0;
  uint32_t output_height = 0;
};

enum DvppCropType {
  // crop left,top,right,bottom is given in config
  kDvppCropTypeOffset = 0,
  // crop left,top,right,bottom is calculated by image width/height and output crop width/height
  kDvppCropTypeCentre = 1,
};

// Rectangular region in pixel coordinates.
struct DvppRoiArea {
  uint32_t left = 0;
  uint32_t top = 0;
  uint32_t right = 0;
  uint32_t bottom = 0;
};

// How the crop region is determined: explicit area, or centred width/height.
struct DvppCropInfo {
  DvppCropType crop_type = kDvppCropTypeOffset;
  DvppRoiArea crop_area;     // when kDvppCropTypeOffset
  uint32_t crop_width = 0;   // when kDvppCropTypeCentre
  uint32_t crop_height = 0;  // when kDvppCropTypeCentre
};

// Crop stage parameters: source region plus output canvas size.
struct DvppCropPara {
  DvppCropInfo crop_info;
  uint32_t output_width = 0;
  uint32_t output_height = 0;
};

// Crop-and-paste stage parameters: source region, paste destination, canvas.
struct DvppCropAndPastePara {
  DvppCropInfo crop_info;
  DvppRoiArea paste_area;
  uint32_t output_width = 0;
  uint32_t output_height = 0;
};
|
||||
|
||||
// DVPP (hardware image preprocessing) pipeline: JPEG decode followed by one of
// resize / crop / crop-and-paste, configured either programmatically or via a
// json config (InitWithJsonConfig). Process() runs the configured pipeline on
// a host picture buffer and returns a device buffer owned by this object.
class DvppProcess {
 public:
  DvppProcess();
  ~DvppProcess();

  Status InitResource(aclrtStream stream);
  void Finalize();
  Status InitJpegDecodePara(const DvppDecodePara &decode_para);  // jpeg decode + (resize | crop)
  Status InitResizePara(const DvppResizePara &resize_para);      // jpeg decode + resize
  Status InitCropPara(const DvppCropPara &crop_para);            // jpeg decode + crop
  Status InitCropAndPastePara(const DvppCropAndPastePara &crop_and_paste_para);  // jpeg decode + crop&paste

  Status InitWithJsonConfig(const std::string &json_config);

  // output device buffer will be destroy by DvppProcess itself.
  Status Process(const void *pic_buffer, size_t pic_buffer_size, void **output_device_buffer, size_t *output_size);
  Status Process(const std::vector<const void *> &pic_buffer_list, const std::vector<size_t> &pic_buffer_size_list,
                 void **output_device_buffer, size_t *output_size);
  bool HasLoaded() const { return loaded_flag_; }

 private:
  bool loaded_flag_ = false;  // a json config has been successfully loaded
  uint32_t pic_width_ = 0;
  uint32_t pic_height_ = 0;

  DvppDecodePara decode_para_;
  DvppResizePara resize_para_;
  DvppCropPara crop_para_;
  DvppCropAndPastePara crop_and_paste_para_;
  // only one of the resize or crop flag can be true
  bool to_resize_flag_ = false;
  bool to_crop_flag_ = false;
  bool to_crop_and_paste_flag_ = false;

  void *input_pic_dev_buffer_ = nullptr;
  uint32_t input_pic_buffer_size_ = 0;

  uint32_t decode_output_buffer_size_ = 0;
  void *decode_output_buffer_dev_ = nullptr;
  acldvppPicDesc *decode_output_desc_ = nullptr;

  acldvppResizeConfig *resize_config_ = nullptr;
  acldvppRoiConfig *crop_area_ = nullptr;
  acldvppRoiConfig *paste_area_ = nullptr;

  acldvppPicDesc *vpc_output_desc_ = nullptr;
  void *vpc_output_buffer_dev_ = nullptr;  // vpc_output_buffer_size_ length
  uint32_t vpc_output_buffer_size_ = 0;

  void *batch_vpc_output_buffer_dev_ = nullptr;  // batch_size_ * vpc_output_buffer_size_ length
  uint32_t batch_size_ = 0;

  aclrtStream stream_ = nullptr;
  acldvppChannelDesc *dvpp_channel_desc_ = nullptr;

  uint32_t AlignmentHelper(uint32_t org_size, uint32_t alignment) const;
  uint32_t GetImageBufferSize(uint32_t stride_width, uint32_t stride_height, acldvppPixelFormat pixel_format) const;
  Status GetPicDescStride(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height);
  Status GetPicDescStrideDecode(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height);
  // NOTE(review): name looks like a typo for "InitInputBuffer" — kept as-is
  // since renaming would break the implementation file.
  Status InputInputBuffer(const void *pic_buffer, size_t pic_buffer_size);
  Status InitDecodeOutputDesc(uint32_t image_width,
                              uint32_t image_height);  // decode_output_desc_, decode_output_buffer_dev_
  Status CheckRoiAreaWidthHeight(uint32_t width, uint32_t height);
  Status CheckAndAdjustRoiArea(DvppRoiArea *area);
  Status UpdateCropArea(uint32_t image_width, uint32_t image_height);
  Status CheckResizeImageInfo(uint32_t image_width, uint32_t image_height) const;
  void DestroyDecodeDesc();

  Status InitVpcOutputDesc(uint32_t output_width, uint32_t output_height,
                           acldvppPixelFormat pixel_format);  // vpc_output_desc_, vpc_output_buffer_dev_batch_
  Status InitRoiAreaConfig(const DvppRoiArea &init_para, acldvppRoiConfig **roi_area);
  Status InitCommonCropPara(uint32_t out_width, uint32_t out_height, DvppCropInfo *crop_info);
  Status InitResizeOutputDesc();        // vpc_output_desc_, vpc_output_buffer_dev_, resize_config
  Status InitCropOutputDesc();          // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_
  Status InitCropAndPasteOutputDesc();  // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_, paste_area_
  void DestroyVpcOutputDesc();

  Status ProcessDecode();
  Status ProcessResize();
  Status ProcessCrop();
  Status ProcessCropAndPaste();
  void DestroyResource();

  Status GetJpegWidthHeight(const void *pic_buffer, size_t pic_buffer_size, uint32_t *image_width,
                            uint32_t *image_height);
};
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H
|
|
@ -0,0 +1,285 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cxx_api/model/acl/model_converter.h"
|
||||
#include <memory>
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "utils/load_onnx/anf_converter.h"
|
||||
#include "transform/graph_ir/convert.h"
|
||||
#include "transform/graph_ir/graph_runner.h"
|
||||
#include "mindspore/core/utils/ms_context.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
|
||||
#include "graph/model.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mindspore::api {
|
||||
namespace {
|
||||
// Collect every graph parameter that carries a default value (i.e. a weight),
// keyed by parameter name, for DfGraphConvertor::InitParam.
transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) {
  transform::TensorOrderMap res;
  for (auto &anf_node : anf_graph->parameters()) {
    MS_EXCEPTION_IF_NULL(anf_node);
    auto para = anf_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(para);
    if (para->has_default()) {
      auto value = para->default_param();
      MS_EXCEPTION_IF_NULL(value);
      auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
      res.emplace(para->name(), tensor);
      MS_LOG(INFO) << "Parameter " << para->name() << " has default value.";
    }
  }
  return res;
}
|
||||
|
||||
// Ensure DfGraphManager holds a GE session and a GraphRunner, creating them on
// first use. The session is configured for inference (trainFlag off).
bool CreateSessionAndGraphRunner() {
  std::shared_ptr<ge::Session> sess = transform::DfGraphManager::GetInstance().GetGeSession();
  if (sess == nullptr) {
    transform::SessionOptions options;
    options["ge.trainFlag"] = "0";
    options["ge.enablePrintOpPass"] = "0";
    sess = transform::GraphRunner::NewSession(options);
    if (sess == nullptr) {
      MS_LOG(ERROR) << "Init data graph failed, because of create Ge session failed";
      return false;
    } else {
      transform::DfGraphManager::GetInstance().SetGeSession(sess);
    }
  }

  transform::GraphRunnerOptions options;
  options.sess_ptr = sess;
  auto graph_runner = std::make_shared<transform::GraphRunner>(options);
  // NOTE(review): std::make_shared throws on failure rather than returning
  // nullptr, so this branch is effectively dead code.
  if (graph_runner == nullptr) {
    MS_LOG(ERROR) << "Create new graph runner failed";
    return false;
  } else {
    transform::DfGraphManager::GetInstance().SetGraphRunner(graph_runner);
  }

  return true;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Deserialize a MindIR buffer into an ANF FuncGraph.
// Returns nullptr (after logging) if the converter throws.
std::shared_ptr<FuncGraph> ModelConverter::ConvertMindIrToFuncGraph(const Buffer &model_data) {
  try {
    auto anf_graph =
      lite::AnfConverter::RunAnfConverter(reinterpret_cast<const char *>(model_data.Data()), model_data.DataSize());
    return anf_graph;
  } catch (std::exception &e) {
    MS_LOG(ERROR) << "Load MindIR failed.";
    return nullptr;
  }
}
|
||||
|
||||
// Convert an ANF graph into a GE (AIR) graph: sanitize parameter names, run
// DfGraphConvertor, register the resulting compute/init/broadcast/checkpoint
// graphs with DfGraphManager, make sure a GE session + GraphRunner exist, and
// return the compute graph. Returns nullptr on any failure.
transform::DfGraphPtr ModelConverter::ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph) {
  for (auto &anf_node : anf_graph->parameters()) {
    MS_EXCEPTION_IF_NULL(anf_node);
    auto para = anf_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(para);
    // normalize name: replace every ':' with '_' in the parameter name
    std::string name = para->name();
    for (auto pos = name.find(':'); pos != std::string::npos; pos = name.find(':')) {
      name = name.substr(0, pos) + "_" + name.substr(pos + 1);
      MS_LOG(INFO) << name;
    }
    para->set_name(name);
  }

  transform::DfGraphConvertor convertor(anf_graph);
  std::string net_id = "0";
  std::string init_graph = "init_subgraph." + net_id;
  std::string checkpoint_name = "save." + net_id;

  // Inference-only conversion.
  convertor.set_training(false);
  (void)convertor.ConvertAllNode().InitParam(GetParams(anf_graph)).BuildGraph();
  (void)convertor.GenerateCheckpointGraph();
  if (convertor.ErrCode() != 0) {
    transform::DfGraphManager::GetInstance().ClearGraph();
    MS_LOG(ERROR) << "Convert df graph failed, err:" << convertor.ErrCode();
    return nullptr;
  }
  (void)transform::DfGraphManager::GetInstance().AddGraph(anf_graph->ToString(), convertor.GetComputeGraph());
  (void)transform::DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph());
  (void)transform::DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph());

  transform::Status ret =
    transform::DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph());
  if (ret == transform::Status::SUCCESS) {
    transform::DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph);
  }

  // NOTE(review): mutates process environment; presumably read by GE during
  // session/graph setup — confirm it is still required.
  (void)setenv("GE_TRAIN", "0", 1);

  if (!CreateSessionAndGraphRunner()) {
    MS_LOG(ERROR) << "Create GE Session or GraphRunner failed.";
    return nullptr;
  }

  auto wrap_ptr = transform::DfGraphManager::GetInstance().GetGraphByName(anf_graph->ToString());
  if (wrap_ptr == nullptr) {
    MS_LOG(ERROR) << "Get graph form DfGraphManager failed!";
    return nullptr;
  }
  transform::DfGraphPtr &ge_graph = wrap_ptr->graph_ptr_;
  if (ge_graph == nullptr) {
    MS_LOG(ERROR) << "The export graph is null";
    return nullptr;
  }

  return ge_graph;
}
|
||||
|
||||
// Compile an AIR (GE) graph into an offline OM model image for Ascend310.
// Returns an empty Buffer on any build failure.
Buffer ModelConverter::BuildAirModel(const transform::DfGraphPtr &graph,
                                     const std::map<std::string, std::string> &acl_options) {
  ge::ModelBufferData model;
  // SOC_VERSION is mandatory for the graph builder; the other options come
  // from AclModelOptions::GenAclOptions.
  auto ge_options = acl_options;
  ge_options.emplace(ge::ir_option::SOC_VERSION, "Ascend310");
  auto ret = ge::aclgrphBuildInitialize(ge_options);
  if (ret != ge::SUCCESS) {
    MS_LOG(ERROR) << "Call aclgrphBuildInitialize fail.";
    return Buffer();
  }

  ret = ge::aclgrphBuildModel(*graph, acl_options, model);
  if (ret != ge::SUCCESS) {
    MS_LOG(ERROR) << "Call aclgrphBuildModel fail.";
    // fix: release the builder state acquired by aclgrphBuildInitialize;
    // previously this failure path returned without finalizing.
    ge::aclgrphBuildFinalize();
    return Buffer();
  }

  ge::aclgrphBuildFinalize();
  return Buffer(model.data.get(), model.length);
}
|
||||
|
||||
// One-time, process-wide registration of all op infos into OpLib. The op-info
// table lives on the Python side (mindspore._c_expression), so the embedded
// interpreter is bootstrapped here and queried through the C API.
void ModelConverter::RegAllOp() {
  static std::mutex init_mutex;
  static bool Initialized = false;

  std::lock_guard<std::mutex> lock(init_mutex);
  if (Initialized) {
    return;
  }
  Initialized = true;
  MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
  Py_Initialize();
  auto c_expression = PyImport_ImportModule("mindspore._c_expression");
  MS_EXCEPTION_IF_NULL(c_expression);
  PyObject *c_expression_dict = PyModule_GetDict(c_expression);
  MS_EXCEPTION_IF_NULL(c_expression_dict);

  PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy");
  MS_EXCEPTION_IF_NULL(op_info_loader_class);
  PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class);
  MS_EXCEPTION_IF_NULL(op_info_loader);
  PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr);
  MS_EXCEPTION_IF_NULL(op_info_loader_ins);
  // The Python call returns the address of a heap-allocated
  // std::vector<kernel::OpInfo *> encoded as a Python integer.
  auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr);
  MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul);
  auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul);
  auto all_ops_info = static_cast<std::vector<kernel::OpInfo *> *>(all_ops_info_vector_addr);
  // Ownership of each OpInfo transfers to OpLib via shared_ptr; the vector
  // itself is deleted here.
  for (auto op_info : *all_ops_info) {
    kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info));
  }
  all_ops_info->clear();
  delete all_ops_info;
  // NOTE(review): per CPython docs PyModule_GetDict and PyDict_GetItemString
  // return BORROWED references, so the Py_DECREFs on c_expression_dict and
  // op_info_loader_class below look like over-releases; meanwhile
  // op_info_loader_ins and all_ops_info_vector_addr_ul (new references) are
  // never released. Verify the reference accounting.
  Py_DECREF(op_info_loader);
  Py_DECREF(op_info_loader_class);
  Py_DECREF(c_expression_dict);
  Py_DECREF(c_expression);
}
|
||||
|
||||
// Reads an entire file into a Buffer. Returns an empty Buffer on any error
// (empty path, missing/unreadable file, size query failure, alloc failure).
Buffer ModelConverter::ReadFile(const std::string &file) {
  Buffer buffer;
  if (file.empty()) {
    MS_LOG(ERROR) << "Pointer file is nullptr";
    return buffer;
  }
  std::string realPath = file;
  // Fix: open in binary mode - model files are binary data and text-mode
  // newline translation corrupts them on platforms where the modes differ.
  std::ifstream ifs(realPath, std::ios::in | std::ios::binary);
  if (!ifs.good()) {
    MS_LOG(ERROR) << "File: " << realPath << " is not exist";
    return buffer;
  }

  if (!ifs.is_open()) {
    MS_LOG(ERROR) << "File: " << realPath << "open failed";
    return buffer;
  }

  ifs.seekg(0, std::ios::end);
  auto tell_pos = ifs.tellg();
  if (tell_pos < 0) {
    // Fix: tellg() reports failure as -1; the original cast that straight to
    // size_t, producing a huge allocation request.
    MS_LOG(ERROR) << "Get size of file failed, file: " << realPath;
    ifs.close();
    return buffer;
  }
  size_t size = static_cast<size_t>(tell_pos);
  buffer.ResizeData(size);
  if (buffer.DataSize() != size) {
    MS_LOG(ERROR) << "Malloc buf failed, file: " << realPath;
    ifs.close();
    return buffer;
  }

  ifs.seekg(0, std::ios::beg);
  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
  ifs.close();

  return buffer;
}
|
||||
|
||||
// Converts a serialized MindIR model into an OM model buffer.
// Pipeline: MindIR -> FuncGraph -> AIR graph -> OM (via BuildAirModel).
// Returns an empty Buffer when any stage fails.
Buffer ModelConverter::LoadMindIR(const Buffer &model_data) {
  auto func_graph = ConvertMindIrToFuncGraph(model_data);
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed.";
    return Buffer();
  }

  auto df_graph = ConvertFuncGraphToAIR(func_graph);
  if (df_graph == nullptr) {
    MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed.";
    return Buffer();
  }

  // Build options are optional; without attached AclModelOptions the build
  // runs with an empty option map.
  std::map<std::string, std::string> acl_options;
  if (options_ != nullptr) {
    acl_options = options_->GenAclOptions();
  }

  return BuildAirModel(df_graph, acl_options);
}
|
||||
|
||||
// Converts a serialized AscendIR (AIR) model into an OM model buffer.
// Deserializes the bytes into a ge::Model, wraps its graph as a DfGraph and
// hands it to BuildAirModel with the configured ACL options (if any).
// Returns an empty Buffer on any failure.
Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
  ge::Model load_model = ge::Model("loadmodel", "version2");
  ge::Status ret =
    ge::Model::Load(reinterpret_cast<const uint8_t *>(model_data.Data()), model_data.DataSize(), load_model);
  if (ret != ge::GRAPH_SUCCESS) {
    MS_LOG(ERROR) << "Load AscendIR failed, ret = " << ret;
    return Buffer();
  }

  transform::DfGraphPtr df_graph = std::make_shared<transform::DfGraph>(load_model.GetGraph());
  if (df_graph == nullptr) {
    MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed.";
    return Buffer();
  }

  // Options are optional; with no AclModelOptions attached the build runs
  // with an empty option map.
  std::map<std::string, std::string> acl_options;
  if (options_ != nullptr) {
    acl_options = options_->GenAclOptions();
  }

  auto om_data = BuildAirModel(df_graph, acl_options);
  return om_data;
}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H
|
||||
#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/status.h"
|
||||
#include "mindspore/core/ir/func_graph.h"
|
||||
#include "transform/graph_ir/types.h"
|
||||
#include "external/ge/ge_ir_build.h"
|
||||
#include "cxx_api/model/acl/acl_model_options.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
// Converts serialized MindIR / AscendIR models into offline (OM) model
// buffers suitable for ACL inference on Ascend 310.
class ModelConverter {
 public:
  ModelConverter() : options_(nullptr) {}

  // Convert a serialized MindIR model into an OM model buffer.
  Buffer LoadMindIR(const Buffer &model_data);
  // Convert a serialized AscendIR (AIR) model into an OM model buffer.
  Buffer LoadAscendIR(const Buffer &model_data);

  // Attach build options. Not owned; may be nullptr. NOTE(review): the
  // caller must keep `options` alive while this converter is used.
  void set_options(AclModelOptions *options) { options_ = options; }

  // Read a whole file into a Buffer (empty Buffer on failure).
  static Buffer ReadFile(const std::string &file);
  // Register all operator infos with the kernel OpLib (idempotent).
  static void RegAllOp();

 private:
  std::shared_ptr<FuncGraph> ConvertMindIrToFuncGraph(const Buffer &model_data);
  transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph);
  Buffer BuildAirModel(const transform::DfGraphPtr &graph, const std::map<std::string, std::string> &acl_options);
  AclModelOptions *options_;  // not owned
};
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H
|
|
@ -0,0 +1,440 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cxx_api/model/acl/model_process.h"
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
// Maps an ACL element type to the public api::DataType enum.
// Types with no public counterpart map to api::kInvalidDataType.
static DataType TransToApiType(aclDataType data_type) {
  static const std::map<aclDataType, api::DataType> data_type_map = {
    {ACL_FLOAT16, api::kMsFloat16}, {ACL_FLOAT, api::kMsFloat32},  {ACL_DOUBLE, api::kMsFloat64},
    {ACL_INT8, api::kMsInt8},       {ACL_INT16, api::kMsInt16},    {ACL_INT32, api::kMsInt32},
    {ACL_INT64, api::kMsInt64},     {ACL_UINT8, api::kMsUint8},    {ACL_UINT16, api::kMsUint16},
    {ACL_UINT32, api::kMsUint32},   {ACL_UINT64, api::kMsUint64},  {ACL_BOOL, api::kMsBool},
  };
  const auto iter = data_type_map.find(data_type);
  return (iter == data_type_map.end()) ? api::kInvalidDataType : iter->second;
}
|
||||
|
||||
static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<Tensor> *tensor_list) {
|
||||
MS_EXCEPTION_IF_NULL(tensor_list);
|
||||
tensor_list->clear();
|
||||
|
||||
for (size_t i = 0; i < acl_tensor_list.size(); ++i) {
|
||||
const auto &info = acl_tensor_list[i];
|
||||
Tensor tensor_desc;
|
||||
tensor_desc.SetName(info.name);
|
||||
tensor_desc.SetDataType(TransToApiType(info.data_type));
|
||||
tensor_desc.SetShape(info.dims);
|
||||
tensor_list->push_back(tensor_desc);
|
||||
}
|
||||
}
|
||||
|
||||
// Queries the loaded model's description and pre-allocates input/output
// buffers. Requires model_id_ to reference an already-loaded model.
Status ModelProcess::PreInitModelResource() {
  model_desc_ = aclmdlCreateDesc();
  // Fix: aclmdlCreateDesc can return nullptr; the original passed the
  // unchecked pointer straight into aclmdlGetDesc.
  if (model_desc_ == nullptr) {
    MS_LOG(ERROR) << "Create model desc failed";
    return FAILED;
  }
  aclError acl_ret = aclmdlGetDesc(model_desc_, model_id_);
  if (acl_ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Read model desc failed";
    return FAILED;
  }
  Status ret = InitInputsBuffer();
  if (ret != SUCCESS) {
    MS_LOG(ERROR) << "Create input buffer failed";
    return FAILED;
  }
  ret = InitOutputsBuffer();
  if (ret != SUCCESS) {
    MS_LOG(ERROR) << "Create output buffer failed";
    return FAILED;
  }
  return SUCCESS;
}
|
||||
|
||||
// Loads an OM model from disk via ACL and initializes its I/O buffers.
// On success *model_id receives the ACL model id (also cached in model_id_).
Status ModelProcess::LoadModelFromFile(const std::string &file_name, uint32_t *model_id) {
  MS_EXCEPTION_IF_NULL(model_id);
  aclError acl_ret = aclmdlLoadFromFile(file_name.c_str(), model_id);
  if (acl_ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Read model file failed, file name is " << file_name;
    return FAILED;
  }
  MS_LOG(INFO) << "Load model success " << file_name;
  model_id_ = *model_id;
  // Roll the load back if resource setup fails so no model stays half-loaded.
  if (PreInitModelResource() != SUCCESS) {
    aclmdlUnload(model_id_);
    MS_LOG(ERROR) << "Pre init model resource failed, file name is " << file_name;
    return FAILED;
  }
  return SUCCESS;
}
|
||||
|
||||
// Queries every model input from model_desc_ and records an AclTensorInfo
// (pre-allocated device buffer, byte size, dtype, shape, name) in
// input_infos_. When running on-device no buffer is allocated here
// (device_data stays null) because host memory is used directly.
Status ModelProcess::InitInputsBuffer() {
  aclError ret;
  size_t input_size = aclmdlGetNumInputs(model_desc_);
  MS_LOG(INFO) << "input_size = " << input_size;
  for (size_t i = 0; i < input_size; ++i) {
    auto buffer_size = aclmdlGetInputSizeByIndex(model_desc_, i);
    void *data_mem_buffer = nullptr;
    if (!is_run_on_device_) {  // need to copy input/output to/from device
      ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
      if (ret != ACL_ERROR_NONE) {
        MS_LOG(ERROR) << "Malloc device input buffer faild , input size " << buffer_size;
        return FAILED;
      }
    }

    aclmdlIODims dims;
    ret = aclmdlGetInputDims(model_desc_, i, &dims);
    if (ret != ACL_ERROR_NONE) {
      MS_LOG(ERROR) << "Get input shape failed";
      // Release the buffer allocated above before bailing out.
      if (!is_run_on_device_) {
        aclrtFree(data_mem_buffer);
      }
      return FAILED;
    }
    aclDataType data_type = aclmdlGetInputDataType(model_desc_, i);
    std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
    // NOTE(review): aclmdlGetInputNameByIndex returns a C string; if it can
    // ever be null, this std::string construction is undefined behavior -
    // confirm against the ACL documentation.
    std::string input_name = aclmdlGetInputNameByIndex(model_desc_, i);
    if (input_name.empty()) {
      MS_LOG(WARNING) << "Get name of input " << i << " failed.";
    }
    MS_LOG(INFO) << "Name of input " << i << " is " << input_name;
    input_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape, input_name});
  }
  MS_LOG(INFO) << "Create model inputs success";
  return SUCCESS;
}
|
||||
|
||||
// Allocates a buffer of buffer_size (device memory normally, host memory when
// running on-device), wraps it in an acl data buffer and appends it to
// `dataset`. On any failure the allocation is released before returning.
Status ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset) {
  MS_EXCEPTION_IF_NULL(data_mem_buffer);
  aclError ret;
  // Frees with the allocator matching whichever branch allocated below.
  auto free_data_buffer = [this](void *dataMemBuffer) {
    if (!is_run_on_device_) {
      aclrtFree(dataMemBuffer);
    } else {
      aclrtFreeHost(dataMemBuffer);
    }
  };

  if (!is_run_on_device_) {
    ret = aclrtMalloc(data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
    if (ret != ACL_ERROR_NONE) {
      MS_LOG(ERROR) << "Malloc device buffer failed , buffer size " << buffer_size;
      return FAILED;
    }
  } else {
    ret = aclrtMallocHost(data_mem_buffer, buffer_size);
    if (ret != ACL_ERROR_NONE) {
      // Fix: this branch allocates HOST memory; the original message claimed
      // a device-buffer failure (also corrected the "faild" typo above).
      MS_LOG(ERROR) << "Malloc host buffer failed , buffer size " << buffer_size;
      return FAILED;
    }
  }

  auto data_buffer = aclCreateDataBuffer(*data_mem_buffer, buffer_size);
  if (data_buffer == nullptr) {
    MS_LOG(ERROR) << "Create Data Buffer failed";
    free_data_buffer(*data_mem_buffer);
    return FAILED;
  }
  ret = aclmdlAddDatasetBuffer(dataset, data_buffer);
  if (ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "add data buffer failed";
    free_data_buffer(*data_mem_buffer);
    aclDestroyDataBuffer(data_buffer);
    return FAILED;
  }
  return SUCCESS;
}
|
||||
|
||||
// Queries every model output from model_desc_, allocates a matching buffer
// (device or host depending on run mode) registered in the outputs_ dataset,
// and records its metadata in output_infos_.
Status ModelProcess::InitOutputsBuffer() {
  aclError ret;
  outputs_ = aclmdlCreateDataset();
  if (outputs_ == nullptr) {
    // Fix: this dataset holds OUTPUTS; the original message said "input".
    MS_LOG(ERROR) << "Create output dataset failed";
    return FAILED;
  }
  size_t output_size = aclmdlGetNumOutputs(model_desc_);
  MS_LOG(INFO) << "output_size = " << output_size;
  for (size_t i = 0; i < output_size; ++i) {
    auto buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i);

    void *data_mem_buffer = nullptr;
    if (CreateDataBuffer(&data_mem_buffer, buffer_size, outputs_) != SUCCESS) {
      MS_LOG(ERROR) << "add output data buffer failed, buffer size " << buffer_size;
      return FAILED;
    }
    aclmdlIODims dims;
    ret = aclmdlGetOutputDims(model_desc_, i, &dims);
    if (ret != ACL_ERROR_NONE) {
      // Fix: this queries the OUTPUT shape; original message said "input".
      MS_LOG(ERROR) << "Get output shape failed";
      if (!is_run_on_device_) {
        aclrtFree(data_mem_buffer);
      } else {
        aclrtFreeHost(data_mem_buffer);
      }
      return FAILED;
    }
    aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i);
    std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
    std::string output_name = aclmdlGetOutputNameByIndex(model_desc_, i);
    if (output_name.empty()) {
      MS_LOG(WARNING) << "Get name of output " << i << " failed.";
    }
    // Fix: logs the OUTPUT name; original message said "Name of input".
    MS_LOG(INFO) << "Name of output " << i << " is " << output_name;
    output_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape, output_name});
  }
  MS_LOG(INFO) << "Create model output success";
  return SUCCESS;
}
|
||||
|
||||
void ModelProcess::DestroyInputsDataset() {
|
||||
if (inputs_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(inputs_); i++) {
|
||||
auto dataBuffer = aclmdlGetDatasetBuffer(inputs_, i);
|
||||
aclDestroyDataBuffer(dataBuffer);
|
||||
}
|
||||
aclmdlDestroyDataset(inputs_);
|
||||
inputs_ = nullptr;
|
||||
}
|
||||
|
||||
void ModelProcess::DestroyInputsDataMem() {
|
||||
if (!is_run_on_device_) {
|
||||
for (const auto &item : input_infos_) {
|
||||
aclrtFree(item.device_data);
|
||||
}
|
||||
}
|
||||
input_infos_.clear();
|
||||
}
|
||||
|
||||
// Full input teardown: free the per-input device memory first, then destroy
// the acl dataset wrappers.
void ModelProcess::DestroyInputsBuffer() {
  DestroyInputsDataMem();
  DestroyInputsDataset();
}
|
||||
|
||||
// Full output teardown: free each output buffer with the allocator matching
// the run mode (device vs. host), then destroy the acl dataset wrappers.
void ModelProcess::DestroyOutputsBuffer() {
  for (const auto &item : output_infos_) {
    if (!is_run_on_device_) {
      aclrtFree(item.device_data);
    } else {
      aclrtFreeHost(item.device_data);
    }
  }
  output_infos_.clear();

  if (outputs_ == nullptr) {
    return;
  }
  for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(outputs_); i++) {
    auto dataBuffer = aclmdlGetDatasetBuffer(outputs_, i);
    aclDestroyDataBuffer(dataBuffer);
  }
  aclmdlDestroyDataset(outputs_);
  outputs_ = nullptr;
}
|
||||
|
||||
// Unloads the model from ACL and releases the description plus all
// input/output buffers.
Status ModelProcess::UnLoad() {
  auto ret = aclmdlUnload(model_id_);
  if (ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Unload model failed";
    return FAILED;
  }
  if (model_desc_ != nullptr) {
    ret = aclmdlDestroyDesc(model_desc_);
    if (ret != ACL_ERROR_NONE) {
      // Fix: this failure is the desc teardown, not the unload itself; the
      // original logged the same "Unload model failed" message twice.
      MS_LOG(ERROR) << "Destroy model desc failed";
      return FAILED;
    }
    model_desc_ = nullptr;
  }
  DestroyInputsBuffer();
  DestroyOutputsBuffer();
  MS_LOG(INFO) << "End unload model " << model_id_;
  return SUCCESS;
}
|
||||
|
||||
// Validates the caller-supplied inputs against the model's declared inputs
// and stages them into a fresh inputs_ dataset. Two deliberate phases:
// phase 1 validates EVERY input before phase 2 copies anything, so a bad
// input never leaves a partially-copied dataset behind. On failure the
// caller is responsible for destroying inputs_ (see Predict).
Status ModelProcess::CheckAndInitInput(const std::map<std::string, Buffer> &inputs) {
  aclError ret;
  inputs_ = aclmdlCreateDataset();
  // check inputs
  if (inputs.size() != input_infos_.size()) {
    MS_LOG(ERROR) << "inputs count not match, required count " << input_infos_.size() << ", given count "
                  << inputs.size();
    return INVALID_INPUTS;
  }
  for (size_t i = 0; i < input_infos_.size(); ++i) {
    const std::string &input_name = input_infos_[i].name;
    auto iter = inputs.find(input_name);
    if (iter == inputs.end()) {
      MS_LOG(ERROR) << "Model missing input " << input_name;
      return INVALID_INPUTS;
    }

    if (iter->second.DataSize() != input_infos_[i].buffer_size) {
      MS_LOG(ERROR) << "input " << i << " data size not match, required size " << input_infos_[i].buffer_size
                    << ", given count " << iter->second.DataSize();
      return INVALID_INPUTS;
    }
  }
  // copy inputs
  for (size_t i = 0; i < input_infos_.size(); ++i) {
    const auto &info = input_infos_[i];
    auto iter = inputs.find(info.name);
    if (iter == inputs.end()) {
      MS_LOG(ERROR) << "Model missing input " << info.name;
      return INVALID_INPUTS;
    }

    const auto &input = iter->second;
    const void *data = input.Data();

    void *input_buffer;
    if (!is_run_on_device_) {
      // Off-device: copy host data into the pre-allocated device buffer.
      ret = aclrtMemcpy(info.device_data, info.buffer_size, data, input.DataSize(), ACL_MEMCPY_HOST_TO_DEVICE);
      if (ret != ACL_ERROR_NONE) {
        MS_LOG(ERROR) << "Acl memcpy input " << i << " data to device failed, buffer size " << input.DataSize();
        return FAILED;
      }
      input_buffer = info.device_data;
    } else {
      // On-device: the caller's buffer is referenced directly (zero copy).
      input_buffer = const_cast<void *>(data);
    }
    auto data_buffer = aclCreateDataBuffer(input_buffer, info.buffer_size);
    if (data_buffer == nullptr) {
      MS_LOG(ERROR) << "Create Data Buffer failed";
      return FAILED;
    }
    ret = aclmdlAddDatasetBuffer(inputs_, data_buffer);
    if (ret != ACL_ERROR_NONE) {
      MS_LOG(ERROR) << "add data buffer failed";
      aclDestroyDataBuffer(data_buffer);
      return FAILED;
    }
  }
  return SUCCESS;
}
|
||||
|
||||
// Wraps a DVPP output buffer (already in device memory) as model input
// `input_index` without copying. The buffer size must exactly match the
// model's declared size for that input.
Status ModelProcess::CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size,
                                           size_t input_index) {
  aclError ret;
  inputs_ = aclmdlCreateDataset();
  // check inputs
  if (input_index >= input_infos_.size()) {
    MS_LOG(ERROR) << "inputs count not match, required count " << input_infos_.size() << ", given index "
                  << input_index;
    return INVALID_INPUTS;
  }
  if (dvpp_outputs_buffer_dev == nullptr) {
    // Fix: report the input actually being checked; the original hard-coded
    // index 0 in both error messages regardless of input_index.
    MS_LOG(ERROR) << "input " << input_index << " cannot be null";
    return FAILED;
  }
  if (dvpp_outputs_buffer_size != input_infos_[input_index].buffer_size) {
    MS_LOG(ERROR) << "input " << input_index << " data size not match, required size "
                  << input_infos_[input_index].buffer_size << ", given count " << dvpp_outputs_buffer_size;
    return INVALID_INPUTS;
  }
  // copy inputs
  auto &info = input_infos_[input_index];
  auto data_buffer = aclCreateDataBuffer(const_cast<void *>(dvpp_outputs_buffer_dev), info.buffer_size);
  if (data_buffer == nullptr) {
    MS_LOG(ERROR) << "Create Data Buffer failed";
    return FAILED;
  }
  ret = aclmdlAddDatasetBuffer(inputs_, data_buffer);
  if (ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "add data buffer failed";
    aclDestroyDataBuffer(data_buffer);
    return FAILED;
  }
  return SUCCESS;
}
|
||||
|
||||
// Runs one inference: stage inputs, execute, collect outputs into *outputs.
// The inputs_ dataset created by CheckAndInitInput is destroyed here after
// execution, on both success and failure paths.
Status ModelProcess::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) {
  MS_EXCEPTION_IF_NULL(outputs);
  aclError acl_ret;
  Status ret = CheckAndInitInput(inputs);
  if (ret != SUCCESS) {
    MS_LOG(ERROR) << "check or init input failed";
    DestroyInputsDataset();
    return ret;  // forward status error
  }
  acl_ret = aclmdlExecute(model_id_, inputs_, outputs_);
  DestroyInputsDataset();
  if (acl_ret != ACL_ERROR_NONE) {
    MS_LOG(ERROR) << "Execute Model Failed";
    return FAILED;
  }
  ret = BuildOutputs(outputs);
  if (ret != SUCCESS) {
    // Fix: corrected the misspelled log messages ("faield" / "excute").
    MS_LOG(ERROR) << "Build outputs failed";
    return FAILED;
  }
  MS_LOG(INFO) << "Execute model success";
  return SUCCESS;
}
|
||||
|
||||
// Returns the batch dimension (dim 0) of the first model input.
// Returns 0 when no model is loaded, and 1 for rank-0 (shape-less) inputs.
size_t ModelProcess::GetBatchSize() const {
  if (input_infos_.empty()) {
    MS_LOG(ERROR) << "Model is not loaded";
    return 0;
  }
  if (input_infos_[0].dims.empty()) {
    return 1;
  }
  // NOTE(review): dims[0] is int64_t; a dynamic-batch model could report a
  // negative value here, which this cast would wrap to a huge size_t -
  // confirm whether dynamic batch is possible for callers of this method.
  return static_cast<size_t>(input_infos_[0].dims[0]);
}
|
||||
|
||||
// Copies every model output from its staging buffer into a host-side Buffer
// stored in *outputs keyed by tensor name. When running on-device the copy is
// host-to-host; otherwise device-to-host.
Status ModelProcess::BuildOutputs(std::map<std::string, Buffer> *outputs) {
  MS_EXCEPTION_IF_NULL(outputs);
  aclError ret;
  // copy outputs
  outputs->clear();
  aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST;
  for (size_t i = 0; i < output_infos_.size(); ++i) {
    const auto &info = output_infos_[i];
    // Fix: operate on the Buffer stored in the map, by reference via the
    // iterator emplace returns. The original took `auto output =
    // outputs->rbegin()->second`, i.e. a COPY of the mapped value (so the
    // resize/memcpy could act on a detached buffer), and rbegin() assumed
    // the freshly inserted name always sorts last in the map.
    auto iter = outputs->emplace(info.name, Buffer()).first;
    Buffer &output = iter->second;
    if (!output.ResizeData(info.buffer_size)) {
      MS_LOG(ERROR) << "new output data buffer failed, data size " << info.buffer_size;
      return FAILED;
    }
    ret = aclrtMemcpy(output.MutableData(), output.DataSize(), info.device_data, info.buffer_size, kind);
    if (ret != ACL_ERROR_NONE) {
      MS_LOG(ERROR) << "Memcpy output " << i << " from " << (is_run_on_device_ ? "host" : "device")
                    << " to host failed, memory size " << info.buffer_size;
      return FAILED;
    }
  }
  return SUCCESS;
}
|
||||
|
||||
// Exposes the loaded model's input tensor metadata (name/dtype/shape).
Status ModelProcess::GetInputsInfo(std::vector<Tensor> *tensor_list) const {
  ConstructTensorDesc(input_infos_, tensor_list);
  return SUCCESS;
}
|
||||
|
||||
// Exposes the loaded model's output tensor metadata (name/dtype/shape).
Status ModelProcess::GetOutputsInfo(std::vector<Tensor> *tensor_list) const {
  ConstructTensorDesc(output_infos_, tensor_list);
  return SUCCESS;
}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,93 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H
|
||||
#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "acl/acl.h"
|
||||
#include "acl/acl_mdl.h"
|
||||
#include "acl/acl_rt.h"
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
// Per-tensor bookkeeping for one model input or output.
struct AclTensorInfo {
  void *device_data;          // pre-allocated buffer (null for inputs when running on-device)
  size_t buffer_size;         // byte size reported by the model description
  aclDataType data_type;      // ACL element type
  std::vector<int64_t> dims;  // shape from the model description
  std::string name;           // tensor name from the model description
};
|
||||
|
||||
// A DVPP preprocessing result handed to the model as a zero-copy input.
struct ImagesDvppOutput {
  void *buffer_device = nullptr;  // device buffer produced by DVPP
  size_t buffer_size = 0;         // byte size of that buffer
  size_t input_index = 0;         // model input slot it feeds
};
|
||||
|
||||
// Owns one loaded ACL (OM) model: load/unload, buffer management for its
// inputs and outputs, and synchronous execution.
class ModelProcess {
 public:
  ModelProcess()
      : model_id_(0xffffffff),  // sentinel: no model loaded yet
        is_run_on_device_(false),
        model_desc_(nullptr),
        inputs_(nullptr),
        outputs_(nullptr),
        input_infos_(),
        output_infos_() {}
  ~ModelProcess() {}
  // Load an OM model file and prepare its I/O buffers; *model_id receives
  // the ACL model id.
  Status LoadModelFromFile(const std::string &file_name, uint32_t *model_id);
  // Unload the model and release description + all buffers.
  Status UnLoad();
  // Run one synchronous inference; outputs are keyed by tensor name.
  Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs);
  // Query the model description and allocate I/O buffers (called after load).
  Status PreInitModelResource();
  // Input/output tensor metadata of the loaded model.
  Status GetInputsInfo(std::vector<Tensor> *tensor_list) const;
  Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const;

  // override this method to avoid request/reply data copy
  void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; }

  // Batch dimension of the first input (0 if not loaded, 1 if rank-0).
  size_t GetBatchSize() const;
  void set_model_id(uint32_t model_id) { model_id_ = model_id; }
  uint32_t model_id() const { return model_id_; }

 private:
  Status CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
  Status CheckAndInitInput(const std::map<std::string, Buffer> &inputs);
  Status CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size,
                               size_t input_index);
  Status BuildOutputs(std::map<std::string, Buffer> *outputs);
  Status InitInputsBuffer();
  Status InitOutputsBuffer();

  void DestroyInputsDataset();
  void DestroyInputsDataMem();
  void DestroyInputsBuffer();
  void DestroyOutputsBuffer();

  uint32_t model_id_;
  // if run one device(AICPU), there is no need to alloc device memory and copy inputs to(/outputs from) device
  bool is_run_on_device_;
  aclmdlDesc *model_desc_;
  aclmdlDataset *inputs_;   // staged per-call; destroyed after each Predict
  aclmdlDataset *outputs_;  // allocated once per load; reused across calls
  std::vector<AclTensorInfo> input_infos_;
  std::vector<AclTensorInfo> output_infos_;
};
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H
|
|
@ -0,0 +1,98 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/model.h"
|
||||
#include "cxx_api/model/model_impl.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
// Loads a model from an in-memory buffer; forwards to the backend impl.
Status Model::LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options) {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->LoadModel(model_data, type, options);
}
|
||||
|
||||
// Loads a model from a file path; forwards to the backend impl.
Status Model::LoadModel(const std::string &file_name, ModelType type,
                        const std::map<std::string, std::string> &options) {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->LoadModel(file_name, type, options);
}
|
||||
|
||||
// Unloads the currently loaded model; forwards to the backend impl.
Status Model::UnloadModel() {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->UnloadModel();
}
|
||||
|
||||
// Runs training over the dataset; forwards to the backend impl.
Status Model::Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->Train(dataset, outputs);
}
|
||||
|
||||
// Runs evaluation over the dataset; forwards to the backend impl.
Status Model::Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->Eval(dataset, outputs);
}
|
||||
|
||||
// Runs inference with name-keyed inputs; forwards to the backend impl.
Status Model::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->Predict(inputs, outputs);
}
|
||||
|
||||
// Positional-input overload: resolves each input's name from the model's
// declared input order, then delegates to the name-keyed Predict.
Status Model::Predict(const std::vector<Buffer> &inputs, std::map<std::string, Buffer> *outputs) {
  std::vector<Tensor> tensor_list;
  auto ret = GetInputsInfo(&tensor_list);
  if (ret != SUCCESS) {
    MS_LOG(ERROR) << "GetInputsInfo failed.";
    return ret;
  }

  if (inputs.size() != tensor_list.size()) {
    MS_LOG(ERROR) << "Model need " << tensor_list.size() << " inputs, but given " << inputs.size();
    return FAILED;
  }

  // Pair inputs[i] with the i-th declared input name.
  std::map<std::string, Buffer> inputs_with_map;
  size_t index = 0;
  for (const auto &tensor : tensor_list) {
    inputs_with_map.emplace(tensor.Name(), inputs[index]);
    ++index;
  }

  return Predict(inputs_with_map, outputs);
}
|
||||
|
||||
// Input tensor metadata of the loaded model; forwards to the backend impl.
Status Model::GetInputsInfo(std::vector<Tensor> *tensor_list) const {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->GetInputsInfo(tensor_list);
}
|
||||
|
||||
// Output tensor metadata of the loaded model; forwards to the backend impl.
Status Model::GetOutputsInfo(std::vector<Tensor> *tensor_list) const {
  MS_EXCEPTION_IF_NULL(impl_);
  return impl_->GetOutputsInfo(tensor_list);
}
|
||||
|
||||
// Constructs a Model backed by the impl registered for `device_type`
// (see API_REG_MODEL); throws if no backend is registered for that type.
Model::Model(const std::string &device_type, uint32_t device_id)
    : impl_(ModelFactory::Instance().Create(device_type, device_id)) {
  if (impl_ == nullptr) {
    MS_LOG(EXCEPTION) << "Create session type " << device_type << " failed";
  }
}
|
||||
|
||||
// NOTE(review): not implemented yet - impl_ is never assigned from
// `network`, so this constructor always throws (the `todo` marks the
// pending implementation).
Model::Model(NetWork network, const std::string &device_type, uint32_t device_id) {
  // todo
  if (impl_ == nullptr) {
    MS_LOG(EXCEPTION) << "Create session type " << device_type << " failed";
  }
}
|
||||
|
||||
// Out-of-line destructor so impl_'s type needs to be complete only here.
Model::~Model() {}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,93 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "include/api/model.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
// Backend-neutral model interface. Each device backend implements this and
// registers a creator via API_REG_MODEL; the public Model class forwards
// every call to one of these.
class ModelImpl {
 public:
  ModelImpl() = default;
  virtual ~ModelImpl() = default;

  // Load a serialized model from memory, or from a file path.
  virtual Status LoadModel(const Buffer &model_data, ModelType type,
                           const std::map<std::string, std::string> &options) = 0;
  virtual Status LoadModel(const std::string &file_name, ModelType type,
                           const std::map<std::string, std::string> &options) = 0;
  virtual Status UnloadModel() = 0;

  virtual Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) = 0;
  virtual Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) = 0;
  virtual Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) = 0;

  // Input/output tensor metadata of the loaded model.
  virtual Status GetInputsInfo(std::vector<Tensor> *tensor_list) const = 0;
  virtual Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const = 0;
};
|
||||
|
||||
using ModelCreator = std::function<std::shared_ptr<ModelImpl>(uint32_t device_id)>;
|
||||
class ModelFactory {
|
||||
public:
|
||||
ModelFactory(const ModelFactory &) = delete;
|
||||
void operator=(const ModelFactory &) = delete;
|
||||
|
||||
static ModelFactory &Instance() {
|
||||
static ModelFactory instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void Register(const std::string &device_name, ModelCreator &&model_creator) {
|
||||
if (model_creators_.find(device_name) == model_creators_.end()) {
|
||||
(void)model_creators_.emplace(device_name, model_creator);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ModelImpl> Create(const std::string &device_name, uint32_t device_id) {
|
||||
auto iter = model_creators_.find(device_name);
|
||||
if (model_creators_.end() != iter) {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
return (iter->second)(device_id);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
ModelFactory() = default;
|
||||
~ModelFactory() = default;
|
||||
std::map<std::string, ModelCreator> model_creators_;
|
||||
};
|
||||
|
||||
class ModelRegistrar {
|
||||
public:
|
||||
ModelRegistrar(const std::string &device_name, ModelCreator model_creator) {
|
||||
ModelFactory::Instance().Register(device_name, std::move(model_creator));
|
||||
}
|
||||
~ModelRegistrar() = default;
|
||||
};
|
||||
|
||||
#define API_REG_MODEL(DEVICE_NAME, MODEL_CLASS) \
|
||||
static const ModelRegistrar g_api_model_registrar__##DEVICE_NAME##_##_reg( \
|
||||
#DEVICE_NAME, [](uint32_t device_id) { return std::make_shared<MODEL_CLASS>(device_id); });
|
||||
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/ops/ops.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
Conv2D::Conv2D(int out_channel, const std::vector<int> &kernel_size, int mode, const std::string &pad_mode,
|
||||
const std::vector<int> &pad, const std::vector<int> &stride, const std::vector<int> &dilation, int group)
|
||||
: OpCell("Conv2D"),
|
||||
out_channel(out_channel),
|
||||
kernel_size(kernel_size),
|
||||
mode(mode),
|
||||
pad_mode(pad_mode),
|
||||
pad(pad),
|
||||
stride(stride),
|
||||
dilation(dilation),
|
||||
group(group) {}
|
||||
|
||||
Output Conv2D::operator()(const Input &input1, const Input &input2) const {
|
||||
return CellBase::operator()({input1, input2})[0];
|
||||
}
|
||||
|
||||
std::vector<Output> Conv2D::Construct(const std::vector<Input> &inputs) {
|
||||
return {Output(shared_from_this(), inputs, 1)};
|
||||
}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/serialization.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Status Serialization::SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return FAILED;
|
||||
}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,226 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/types.h"
|
||||
#include <numeric>
|
||||
#include "securec/include/securec.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class DataImpl {
|
||||
public:
|
||||
DataImpl() : data_() {}
|
||||
~DataImpl() = default;
|
||||
DataImpl(const void *data, size_t data_len) { SetData(data, data_len); }
|
||||
|
||||
const void *Data() const { return data_.data(); }
|
||||
void *MutableData() { return data_.data(); }
|
||||
size_t DataSize() const { return data_.size(); }
|
||||
|
||||
bool ResizeData(size_t data_len) {
|
||||
data_.resize(data_len);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SetData(const void *data, size_t data_len) {
|
||||
ResizeData(data_len);
|
||||
if (DataSize() != data_len) {
|
||||
MS_LOG(ERROR) << "Set data failed, tensor current data size " << DataSize() << " not match data len " << data_len;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (data == nullptr) {
|
||||
return data_len == 0;
|
||||
}
|
||||
|
||||
if (MutableData() == nullptr) {
|
||||
MS_LOG(ERROR) << "Set data failed, data len " << data_len;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto ret = memcpy_s(MutableData(), DataSize(), data, data_len);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "Set data memcpy_s failed, ret = " << ret;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<uint8_t> data_;
|
||||
};
|
||||
|
||||
class Buffer::Impl : public DataImpl {
|
||||
public:
|
||||
Impl() : DataImpl() {}
|
||||
~Impl() = default;
|
||||
Impl(const void *data, size_t data_len) : DataImpl(data, data_len) {}
|
||||
};
|
||||
|
||||
class Tensor::Impl : public DataImpl {
|
||||
public:
|
||||
Impl() : DataImpl(), name_(), type_(DataType::kMsUnknown), shape_() {}
|
||||
~Impl() = default;
|
||||
Impl(const std::string &name, api::DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len)
|
||||
: DataImpl(data, data_len), name_(name), type_(type), shape_(shape) {}
|
||||
|
||||
const std::string &Name() const { return name_; }
|
||||
void SetName(const std::string &name) { name_ = name; }
|
||||
|
||||
api::DataType DataType() const { return type_; }
|
||||
void SetDataType(api::DataType type) { type_ = type; }
|
||||
|
||||
void SetShape(const std::vector<int64_t> &shape) { shape_ = shape; }
|
||||
const std::vector<int64_t> &Shape() const { return shape_; }
|
||||
|
||||
int64_t ElementNum() const {
|
||||
std::vector<int64_t> shapex = Shape();
|
||||
return std::accumulate(shapex.begin(), shapex.end(), 1LL, std::multiplies<int64_t>());
|
||||
}
|
||||
|
||||
static int GetTypeSize(api::DataType type) {
|
||||
static const std::map<api::DataType, size_t> type_size_map = {
|
||||
{kMsBool, sizeof(bool)}, {kMsFloat64, sizeof(double)}, {kMsInt8, sizeof(int8_t)},
|
||||
{kMsUint8, sizeof(uint8_t)}, {kMsInt16, sizeof(int16_t)}, {kMsUint16, sizeof(uint16_t)},
|
||||
{kMsInt32, sizeof(int32_t)}, {kMsUint32, sizeof(uint32_t)}, {kMsInt64, sizeof(int64_t)},
|
||||
{kMsUint64, sizeof(uint64_t)}, {kMsFloat16, sizeof(uint16_t)}, {kMsFloat32, sizeof(float)},
|
||||
};
|
||||
auto it = type_size_map.find(type);
|
||||
if (it != type_size_map.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
MS_LOG(WARNING) << "Cannot find data type " << type;
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
api::DataType type_;
|
||||
std::vector<int64_t> shape_;
|
||||
};
|
||||
|
||||
Tensor::Tensor() : impl_(std::make_shared<Impl>()) {}
|
||||
Tensor::Tensor(const std::string &name, api::DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len)
|
||||
: impl_(std::make_shared<Impl>(name, type, shape, data, data_len)) {}
|
||||
Tensor::~Tensor() = default;
|
||||
|
||||
Tensor Tensor::Clone() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
Tensor ret;
|
||||
ret.impl_ = std::make_shared<Impl>(*impl_);
|
||||
return ret;
|
||||
}
|
||||
|
||||
const std::string &Tensor::Name() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->Name();
|
||||
}
|
||||
|
||||
void Tensor::SetName(const std::string &name) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
impl_->SetName(name);
|
||||
}
|
||||
|
||||
DataType Tensor::DataType() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->DataType();
|
||||
}
|
||||
|
||||
void Tensor::SetDataType(api::DataType type) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
impl_->SetDataType(type);
|
||||
}
|
||||
|
||||
const std::vector<int64_t> &Tensor::Shape() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->Shape();
|
||||
}
|
||||
|
||||
void Tensor::SetShape(const std::vector<int64_t> &shape) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
impl_->SetShape(shape);
|
||||
}
|
||||
|
||||
const void *Tensor::Data() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->Data();
|
||||
}
|
||||
|
||||
void *Tensor::MutableData() {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->MutableData();
|
||||
}
|
||||
|
||||
size_t Tensor::DataSize() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->DataSize();
|
||||
}
|
||||
|
||||
bool Tensor::ResizeData(size_t data_len) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->ResizeData(data_len);
|
||||
}
|
||||
|
||||
bool Tensor::SetData(const void *data, size_t data_len) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->SetData(data, data_len);
|
||||
}
|
||||
|
||||
int64_t Tensor::ElementNum() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->ElementNum();
|
||||
}
|
||||
|
||||
int Tensor::GetTypeSize(api::DataType type) { return Impl::GetTypeSize(type); }
|
||||
|
||||
Buffer::Buffer() : impl_(std::make_shared<Impl>()) {}
|
||||
Buffer::Buffer(const void *data, size_t data_len) : impl_(std::make_shared<Impl>(data, data_len)) {}
|
||||
Buffer::~Buffer() = default;
|
||||
|
||||
Buffer Buffer::Clone() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
Buffer ret;
|
||||
ret.impl_ = std::make_shared<Impl>(*impl_);
|
||||
return ret;
|
||||
}
|
||||
|
||||
const void *Buffer::Data() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->Data();
|
||||
}
|
||||
|
||||
void *Buffer::MutableData() {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->MutableData();
|
||||
}
|
||||
|
||||
size_t Buffer::DataSize() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->DataSize();
|
||||
}
|
||||
|
||||
bool Buffer::ResizeData(size_t data_len) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->ResizeData(data_len);
|
||||
}
|
||||
|
||||
bool Buffer::SetData(const void *data, size_t data_len) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->SetData(data, data_len);
|
||||
}
|
||||
} // namespace mindspore::api
|
|
@ -1,5 +1,6 @@
|
|||
if (ENABLE_GE OR ENABLE_D)
|
||||
if (ENABLE_GE OR ENABLE_D OR ENABLE_ACL)
|
||||
file(GLOB_RECURSE _TRANSFORM_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
list(REMOVE_ITEM _TRANSFORM_SRC_LIST "graph_ir/op_declare/hcom_ops_declare.cc")
|
||||
set_property(SOURCE ${_TRANSFORM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_GE_ADPT)
|
||||
add_library(_mindspore_transform_graph_ir_obj OBJECT ${_TRANSFORM_SRC_LIST})
|
||||
|
||||
|
|
|
@ -1579,7 +1579,6 @@ OperatorPtr DfGraphConvertor::ConvertParameter(const AnfNodePtr node) {
|
|||
// build index for parameter using name
|
||||
std::string name = std::static_pointer_cast<Parameter>(node)->name();
|
||||
params_[name] = node;
|
||||
|
||||
std::ostringstream ss;
|
||||
ss << "op" << node.get();
|
||||
op_draw_name_[node.get()] = ss.str();
|
||||
|
|
|
@ -76,8 +76,6 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
|
||||
list(APPEND SERVING_SRC "main.cc" ${hw_proto_srcs} ${hw_grpc_srcs} ${CORE_SRC_LIST})
|
||||
|
||||
option(ENABLE_ACL "enable acl" OFF)
|
||||
|
||||
if (ENABLE_ACL)
|
||||
if (DEFINED ENV{ASCEND_CUSTOM_PATH})
|
||||
set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH})
|
||||
|
@ -85,9 +83,11 @@ if (ENABLE_ACL)
|
|||
set(ASCEND_PATH /usr/local/Ascend)
|
||||
endif ()
|
||||
set(ACL_LIB_DIR ${ASCEND_PATH}/acllib/)
|
||||
MESSAGE("acl lib dir " ${ACL_LIB_DIR})
|
||||
set(ATLAS_ACL_LIB_DIR ${ASCEND_PATH}/ascend-toolkit/latest/acllib)
|
||||
MESSAGE("hisi acl lib dir " ${ACL_LIB_DIR} " ,atlas acl lib dir " ${ATLAS_ACL_LIB_DIR})
|
||||
|
||||
include_directories(${ACL_LIB_DIR}/include/)
|
||||
include_directories(${ATLAS_ACL_LIB_DIR}/include/)
|
||||
file(GLOB_RECURSE ACL_SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "acl/*.cc")
|
||||
list(APPEND SERVING_SRC ${ACL_SESSION_SRC_LIST})
|
||||
endif ()
|
||||
|
@ -112,10 +112,13 @@ endif ()
|
|||
if (ENABLE_ACL)
|
||||
add_compile_definitions(ENABLE_ACL)
|
||||
add_compile_definitions(ENABLE_DVPP_INTERFACE)
|
||||
set(ALC_LIB_SO ${ACL_LIB_DIR}/lib64/libruntime.so ${ACL_LIB_DIR}/lib64/libascendcl.so
|
||||
${ACL_LIB_DIR}/lib64/libacl_retr.so ${ACL_LIB_DIR}/lib64/libacl_cblas.so
|
||||
${ACL_LIB_DIR}/lib64/libacl_dvpp.so)
|
||||
target_link_libraries(ms_serving ${ALC_LIB_SO})
|
||||
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
||||
find_library(acl_retr libacl_retr.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
||||
find_library(acl_cblas libacl_cblas.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
||||
find_library(acl_dvpp libacl_dvpp.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
||||
find_library(acl_runtime libruntime.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
||||
|
||||
target_link_libraries(ms_serving ${acl} ${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime})
|
||||
target_link_libraries(ms_serving jpeg_turbo::jpeg securec)
|
||||
else ()
|
||||
target_link_libraries(ms_serving inference mindspore_gvar)
|
||||
|
|
6
setup.py
6
setup.py
|
@ -130,7 +130,11 @@ package_data = {
|
|||
'lib/*.so*',
|
||||
'lib/*.a',
|
||||
'.commit_id',
|
||||
'ms_serving'
|
||||
'ms_serving',
|
||||
'include/*',
|
||||
'include/*/*',
|
||||
'include/*/*/*',
|
||||
'include/*/*/*/*'
|
||||
]
|
||||
}
|
||||
|
||||
|
|
|
@ -56,6 +56,7 @@ if(ENABLE_MINDDATA)
|
|||
./utils/*.cc
|
||||
./vm/*.cc
|
||||
./ps/*.cc
|
||||
./cxx_api/*.cc
|
||||
)
|
||||
|
||||
if(NOT ENABLE_PYTHON)
|
||||
|
@ -176,7 +177,7 @@ if (USE_GLOG)
|
|||
target_link_libraries(ut_tests PRIVATE mindspore::glog)
|
||||
endif()
|
||||
|
||||
target_link_libraries(ut_tests PRIVATE mindspore securec graph)
|
||||
target_link_libraries(ut_tests PRIVATE mindspore mindspore_shared_lib securec graph)
|
||||
|
||||
# link grpc
|
||||
if (EXISTS ${grpc_ROOT}/lib64)
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
#include "common/common_test.h"
|
||||
#include "include/api/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestCxxApiTypes : public UT::Common {
|
||||
public:
|
||||
TestCxxApiTypes() = default;
|
||||
};
|
||||
|
||||
TEST_F(TestCxxApiTypes, test_tensor_set_name_SUCCESS) {
|
||||
std::string tensor_name_before = "TEST1";
|
||||
std::string tensor_name_after = "TEST2";
|
||||
api::Tensor tensor1(tensor_name_before, api::DataType::kMsFloat32, {}, nullptr, 0);
|
||||
api::Tensor tensor2 = tensor1;
|
||||
api::Tensor tensor3 = tensor1.Clone();
|
||||
|
||||
// name
|
||||
ASSERT_EQ(tensor1.Name(), tensor_name_before);
|
||||
ASSERT_EQ(tensor2.Name(), tensor_name_before);
|
||||
ASSERT_EQ(tensor3.Name(), tensor_name_before);
|
||||
|
||||
tensor1.SetName(tensor_name_after);
|
||||
ASSERT_EQ(tensor1.Name(), tensor_name_after);
|
||||
ASSERT_EQ(tensor2.Name(), tensor_name_after);
|
||||
ASSERT_EQ(tensor3.Name(), tensor_name_before);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiTypes, test_tensor_set_dtype_SUCCESS) {
|
||||
api::Tensor tensor1("", api::DataType::kMsFloat32, {}, nullptr, 0);
|
||||
api::Tensor tensor2 = tensor1;
|
||||
api::Tensor tensor3 = tensor1.Clone();
|
||||
|
||||
// dtype
|
||||
ASSERT_EQ(tensor1.DataType(), api::DataType::kMsFloat32);
|
||||
ASSERT_EQ(tensor2.DataType(), api::DataType::kMsFloat32);
|
||||
ASSERT_EQ(tensor3.DataType(), api::DataType::kMsFloat32);
|
||||
|
||||
tensor1.SetDataType(api::DataType::kMsUint32);
|
||||
ASSERT_EQ(tensor1.DataType(), api::DataType::kMsUint32);
|
||||
ASSERT_EQ(tensor2.DataType(), api::DataType::kMsUint32);
|
||||
ASSERT_EQ(tensor3.DataType(), api::DataType::kMsFloat32);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiTypes, test_tensor_set_shape_SUCCESS) {
|
||||
std::vector<int64_t> shape = {3, 4, 5, 6};
|
||||
api::Tensor tensor1("", api::DataType::kMsFloat32, {}, nullptr, 0);
|
||||
api::Tensor tensor2 = tensor1;
|
||||
api::Tensor tensor3 = tensor1.Clone();
|
||||
|
||||
// shape
|
||||
ASSERT_EQ(tensor1.Shape(), std::vector<int64_t>());
|
||||
ASSERT_EQ(tensor2.Shape(), std::vector<int64_t>());
|
||||
ASSERT_EQ(tensor3.Shape(), std::vector<int64_t>());
|
||||
|
||||
tensor1.SetShape(shape);
|
||||
ASSERT_EQ(tensor1.Shape(), shape);
|
||||
ASSERT_EQ(tensor2.Shape(), shape);
|
||||
ASSERT_EQ(tensor3.Shape(), std::vector<int64_t>());
|
||||
}
|
||||
|
||||
|
||||
TEST_F(TestCxxApiTypes, test_tensor_util_SUCCESS) {
|
||||
std::vector<int64_t> shape = {3, 4, 5, 6};
|
||||
std::vector<uint32_t> data(3 * 4 * 5 * 6, 123);
|
||||
api::Tensor tensor1("", api::DataType::kMsFloat32, shape, data.data(), data.size() * sizeof(uint32_t));
|
||||
|
||||
// data
|
||||
ASSERT_EQ(api::Tensor::GetTypeSize(api::DataType::kMsUint32), sizeof(uint32_t));
|
||||
ASSERT_EQ(tensor1.ElementNum(), 3 * 4 * 5 * 6);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiTypes, test_tensor_data_ref_and_copy_SUCCESS) {
|
||||
std::vector<int64_t> shape = {3, 4, 5, 6};
|
||||
std::vector<uint32_t> data(3 * 4 * 5 * 6, 123);
|
||||
api::Tensor tensor1("", api::DataType::kMsFloat32, shape, data.data(), data.size() * sizeof(uint32_t));
|
||||
api::Tensor tensor2 = tensor1;
|
||||
api::Tensor tensor3 = tensor1.Clone();
|
||||
|
||||
// data
|
||||
ASSERT_EQ(tensor1.DataSize(), tensor2.DataSize());
|
||||
ASSERT_EQ(tensor1.DataSize(), tensor3.DataSize());
|
||||
ASSERT_EQ(tensor1.Data(), tensor2.MutableData());
|
||||
ASSERT_NE(tensor1.Data(), tensor3.Data());
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiTypes, test_tensor_resize_data_SUCCESS) {
|
||||
std::vector<int64_t> shape = {3, 4, 5, 6};
|
||||
std::vector<uint32_t> data(3 * 4 * 5 * 6, 123);
|
||||
api::Tensor tensor1("", api::DataType::kMsFloat32, shape, data.data(), data.size() * sizeof(uint32_t));
|
||||
|
||||
// data
|
||||
ASSERT_EQ(tensor1.ResizeData(0), true);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiTypes, test_tensor_set_data_wrong_data_size_FAILED) {
|
||||
std::vector<int64_t> shape = {3, 4, 5, 6};
|
||||
std::vector<uint32_t> data(3 * 4 * 5 * 6, 123);
|
||||
api::Tensor tensor1("", api::DataType::kMsFloat32, shape, data.data(), data.size() * sizeof(uint32_t));
|
||||
|
||||
// data
|
||||
ASSERT_EQ(tensor1.SetData(nullptr, 1), false);
|
||||
ASSERT_EQ(tensor1.SetData(data.data(), 0), false);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiTypes, test_tensor_set_data_SUCCESS) {
|
||||
std::vector<int64_t> shape = {3, 4, 5, 6};
|
||||
std::vector<uint32_t> data(3 * 4 * 5 * 6, 123);
|
||||
api::Tensor tensor1("", api::DataType::kMsFloat32, shape, data.data(), data.size() * sizeof(uint32_t));
|
||||
|
||||
// data
|
||||
ASSERT_EQ(tensor1.SetData(nullptr, 0), true);
|
||||
ASSERT_EQ(tensor1.SetData(data.data(), data.size() * sizeof(uint32_t)), true);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiTypes, test_buffer_data_ref_and_copy_SUCCESS) {
|
||||
std::vector<uint32_t> data(3 * 4 * 5 * 6, 123);
|
||||
api::Buffer buffer1(data.data(), data.size() * sizeof(uint32_t));
|
||||
api::Buffer buffer2 = buffer1;
|
||||
api::Buffer buffer3 = buffer1.Clone();
|
||||
|
||||
// data
|
||||
ASSERT_EQ(buffer1.DataSize(), buffer2.DataSize());
|
||||
ASSERT_EQ(buffer1.DataSize(), buffer3.DataSize());
|
||||
ASSERT_EQ(buffer1.Data(), buffer2.MutableData());
|
||||
ASSERT_NE(buffer1.Data(), buffer3.Data());
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiTypes, test_buffer_resize_data_SUCCESS) {
|
||||
std::vector<uint32_t> data(3 * 4 * 5 * 6, 123);
|
||||
api::Buffer buffer1(data.data(), data.size() * sizeof(uint32_t));
|
||||
|
||||
// data
|
||||
ASSERT_EQ(buffer1.ResizeData(0), true);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiTypes, test_buffer_set_data_wrong_data_size_FAILED) {
|
||||
std::vector<uint32_t> data(3 * 4 * 5 * 6, 123);
|
||||
api::Buffer buffer1(data.data(), data.size() * sizeof(uint32_t));
|
||||
|
||||
// data
|
||||
ASSERT_EQ(buffer1.SetData(nullptr, 1), false);
|
||||
ASSERT_EQ(buffer1.SetData(data.data(), 0), false);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiTypes, test_buffer_set_data_SUCCESS) {
|
||||
std::vector<uint32_t> data(3 * 4 * 5 * 6, 123);
|
||||
api::Buffer buffer1(data.data(), data.size() * sizeof(uint32_t));
|
||||
|
||||
// data
|
||||
ASSERT_EQ(buffer1.SetData(nullptr, 0), true);
|
||||
ASSERT_EQ(buffer1.SetData(data.data(), data.size() * sizeof(uint32_t)), true);
|
||||
}
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue