mindspore c++ interface
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
756594d000
commit
3204ecb7d6
6
build.sh
6
build.sh
|
@ -23,7 +23,7 @@ usage()
|
|||
{
|
||||
echo "Usage:"
|
||||
echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\"
|
||||
echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|ascend|cpu|acl] \\"
|
||||
echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|ascend|cpu|ascend310] \\"
|
||||
echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I arm64|arm32|x86_64] [-K] \\"
|
||||
echo " [-B on|off] [-E] [-l on|off] [-n full|lite|off] [-T on|off] \\"
|
||||
echo " [-A [cpp|java|object-c] [-C on|off] [-o on|off] [-S on|off] [-k on|off] [-W sse|neon|avx|off] \\"
|
||||
|
@ -45,7 +45,7 @@ usage()
|
|||
echo " -i Enable increment building, default off"
|
||||
echo " -L Enable load ANF-IR as input of 'infer', default off"
|
||||
echo " -j[n] Set the threads when building (Default: -j8)"
|
||||
echo " -e Use cpu, gpu, ascend or acl"
|
||||
echo " -e Use cpu, gpu, ascend or ascend310"
|
||||
echo " -P Enable dump anf graph to file in ProtoBuffer format, default on"
|
||||
echo " -D Enable dumping of function graph ir, default on"
|
||||
echo " -z Compile dataset & mindrecord, default on"
|
||||
|
@ -224,7 +224,7 @@ checkopts()
|
|||
ENABLE_D="on"
|
||||
ENABLE_CPU="on"
|
||||
ENABLE_SERVING="on"
|
||||
elif [[ "X$OPTARG" == "Xacl" ]]; then
|
||||
elif [[ "X$OPTARG" == "Xascend310" ]]; then
|
||||
ENABLE_SERVING="on"
|
||||
ENABLE_ACL="on"
|
||||
elif [[ "X$OPTARG" == "Xcpu" ]]; then
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <memory>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
|
@ -34,6 +35,7 @@ class MS_API CellBase {
|
|||
virtual ~CellBase() = default;
|
||||
virtual std::vector<Output> Construct(const std::vector<Input> &inputs) { return {}; }
|
||||
virtual std::shared_ptr<CellBase> Clone() const = 0;
|
||||
virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { return SUCCESS; }
|
||||
std::vector<Output> operator()(const std::vector<Input> &inputs) const;
|
||||
};
|
||||
|
||||
|
@ -41,9 +43,7 @@ template <class T>
|
|||
class MS_API Cell : public CellBase {
|
||||
public:
|
||||
virtual ~Cell() = default;
|
||||
std::shared_ptr<CellBase> Clone() const override {
|
||||
return std::make_shared<T>(static_cast<const T&>(*this));
|
||||
}
|
||||
std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); }
|
||||
};
|
||||
|
||||
class MS_API ParameterCell final : public Cell<ParameterCell> {
|
||||
|
@ -84,9 +84,33 @@ class MS_API OpCell : public OpCellBase, public std::enable_shared_from_this<T>
|
|||
public:
|
||||
explicit OpCell(const std::string &name) : OpCellBase(name) {}
|
||||
~OpCell() override = default;
|
||||
std::shared_ptr<CellBase> Clone() const override {
|
||||
return std::make_shared<T>(static_cast<const T&>(*this));
|
||||
}
|
||||
std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); }
|
||||
};
|
||||
|
||||
class MS_API GraphCell final : public Cell<GraphCell> {
|
||||
public:
|
||||
class GraphImpl;
|
||||
|
||||
GraphCell() = default;
|
||||
~GraphCell() override = default;
|
||||
|
||||
explicit GraphCell(const Graph &);
|
||||
explicit GraphCell(Graph &&);
|
||||
explicit GraphCell(const std::shared_ptr<Graph> &);
|
||||
|
||||
const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
|
||||
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
|
||||
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
|
||||
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
|
||||
|
||||
private:
|
||||
friend class ModelImpl;
|
||||
Status Load();
|
||||
|
||||
std::shared_ptr<Graph> graph_;
|
||||
std::shared_ptr<GraphImpl> executor_;
|
||||
};
|
||||
|
||||
class MS_API InputAndOutput {
|
||||
|
@ -96,7 +120,7 @@ class MS_API InputAndOutput {
|
|||
|
||||
// no explicit
|
||||
InputAndOutput(const Tensor &); // NOLINT(runtime/explicit)
|
||||
InputAndOutput(Tensor &&); // NOLINT(runtime/explicit)
|
||||
InputAndOutput(Tensor &&); // NOLINT(runtime/explicit)
|
||||
|
||||
InputAndOutput(const std::shared_ptr<CellBase> &, const std::vector<InputAndOutput> &, int32_t index);
|
||||
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||
#define MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "include/api/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
class MS_API Context {
|
||||
public:
|
||||
static Context &Instance();
|
||||
const std::string &GetDeviceTarget() const;
|
||||
Context &SetDeviceTarget(const std::string &device_target);
|
||||
uint32_t GetDeviceID() const;
|
||||
Context &SetDeviceID(uint32_t device_id);
|
||||
|
||||
private:
|
||||
Context();
|
||||
~Context();
|
||||
class ContextImpl;
|
||||
std::shared_ptr<ContextImpl> impl_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_GRAPH_H
|
||||
#define MINDSPORE_INCLUDE_API_GRAPH_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
class MS_API Graph {
|
||||
public:
|
||||
class GraphData;
|
||||
explicit Graph(const std::shared_ptr<GraphData> &graph_data);
|
||||
explicit Graph(std::shared_ptr<GraphData> &&graph_data);
|
||||
|
||||
enum ModelType ModelType() const;
|
||||
|
||||
private:
|
||||
friend class GraphCell;
|
||||
friend class ModelImpl;
|
||||
std::shared_ptr<GraphData> graph_data_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_GRAPH_H
|
|
@ -22,42 +22,39 @@
|
|||
#include <memory>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/graph.h"
|
||||
#include "include/api/cell.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
class ModelImpl;
|
||||
// todo: minddata c++ interface
|
||||
class DataSet {};
|
||||
class NetWork {};
|
||||
|
||||
class MS_API Model {
|
||||
public:
|
||||
Model(const std::string &device_type, uint32_t device_id);
|
||||
Model(NetWork network, const std::string &device_type, uint32_t device_id);
|
||||
explicit Model(const std::vector<Output> &network);
|
||||
explicit Model(const GraphCell &graph);
|
||||
~Model();
|
||||
Model(const Model &) = delete;
|
||||
void operator=(const Model &) = delete;
|
||||
|
||||
Status LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options);
|
||||
Status LoadModel(const std::string &file_name, ModelType type, const std::map<std::string, std::string> &options);
|
||||
Status UnloadModel();
|
||||
Status Build(const std::map<std::string, std::string> &options);
|
||||
|
||||
Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs);
|
||||
Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs);
|
||||
Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs);
|
||||
Status Predict(const std::vector<Buffer> &inputs, std::map<std::string, Buffer> *outputs);
|
||||
Status Train(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs);
|
||||
Status Eval(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs);
|
||||
Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
|
||||
|
||||
Status GetInputsInfo(std::vector<Tensor> *tensor_list) const;
|
||||
Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const;
|
||||
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
|
||||
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
|
||||
|
||||
static bool CheckModelSupport(const std::string& device_type, ModelType model_type);
|
||||
static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
|
||||
|
||||
private:
|
||||
std::shared_ptr<ModelImpl> impl_;
|
||||
};
|
||||
|
||||
extern MS_API const char* kDeviceTypeAscendCL;
|
||||
extern MS_API const char* kDeviceTypeAscendMS;
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_MODEL_H
|
||||
|
|
|
@ -23,11 +23,13 @@
|
|||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
class MS_API Serialization {
|
||||
public:
|
||||
static Graph LoadModel(const std::string &file, ModelType model_type);
|
||||
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
|
||||
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
||||
|
|
|
@ -102,6 +102,9 @@ class MS_API Buffer {
|
|||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
extern MS_API const char *kDeviceTypeAscend310;
|
||||
extern MS_API const char *kDeviceTypeAscend910;
|
||||
|
||||
constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path";
|
||||
constexpr auto kModelOptionDvppCfgPath = "mindspore.option.dvpp_config_file_path";
|
||||
constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file
|
||||
|
|
|
@ -6,22 +6,35 @@ set(LOAD_MINDIR_SRC
|
|||
file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc")
|
||||
|
||||
if (ENABLE_ACL)
|
||||
file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc" "model/model_converter_utils/*.cc")
|
||||
elseif (ENABLE_D)
|
||||
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc")
|
||||
add_compile_definitions(ENABLE_ACL)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/graphengine/src/ge)
|
||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge)
|
||||
file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"model/acl/*.cc"
|
||||
"model/model_converter_utils/*.cc"
|
||||
"graph/acl/*.cc"
|
||||
)
|
||||
|
||||
endif ()
|
||||
if (ENABLE_D)
|
||||
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc" "graph/ms/*.cc")
|
||||
endif ()
|
||||
|
||||
set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cell.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc
|
||||
${API_MS_INFER_SRC}
|
||||
${API_ACL_SRC}
|
||||
${API_OPS_SRC}
|
||||
${LOAD_MINDIR_SRC})
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/context.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cell.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/python_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph/graph.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph/graph_data.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc
|
||||
${API_MS_INFER_SRC}
|
||||
${API_ACL_SRC}
|
||||
${API_OPS_SRC}
|
||||
${LOAD_MINDIR_SRC})
|
||||
|
||||
add_library(mindspore_shared_lib SHARED ${MSLIB_SRC})
|
||||
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}")
|
||||
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore)
|
||||
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
|
||||
-Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf)
|
||||
|
@ -69,5 +82,6 @@ endif ()
|
|||
if (ENABLE_D)
|
||||
find_library(adump_server libadump_server.a ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server})
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE mindspore_core hccl_adapter)
|
||||
endif ()
|
||||
|
||||
|
|
|
@ -14,6 +14,9 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/cell.h"
|
||||
#include "include/api/context.h"
|
||||
#include "cxx_api/factory.h"
|
||||
#include "cxx_api/graph/graph_impl.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }
|
||||
|
@ -51,6 +54,52 @@ ParameterCell &ParameterCell::operator=(Tensor &&tensor) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
GraphCell::GraphCell(const Graph &graph)
|
||||
: graph_(std::make_shared<Graph>(graph)),
|
||||
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
|
||||
MS_EXCEPTION_IF_NULL(graph_);
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
executor_->SetGraph(graph_);
|
||||
}
|
||||
|
||||
GraphCell::GraphCell(const std::shared_ptr<Graph> &graph)
|
||||
: graph_(graph),
|
||||
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
|
||||
MS_EXCEPTION_IF_NULL(graph_);
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
executor_->SetGraph(graph_);
|
||||
}
|
||||
|
||||
GraphCell::GraphCell(Graph &&graph)
|
||||
: graph_(std::make_shared<Graph>(graph)),
|
||||
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
|
||||
MS_EXCEPTION_IF_NULL(graph_);
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
executor_->SetGraph(graph_);
|
||||
}
|
||||
|
||||
Status GraphCell::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
return executor_->Run(inputs, outputs);
|
||||
}
|
||||
|
||||
Status GraphCell::Load() {
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
return executor_->Load();
|
||||
}
|
||||
|
||||
Status GraphCell::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
return executor_->GetInputsInfo(names, shapes, data_types, mem_sizes);
|
||||
}
|
||||
|
||||
Status GraphCell::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
return executor_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
|
||||
}
|
||||
|
||||
InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}
|
||||
|
||||
InputAndOutput::InputAndOutput(const Tensor &tensor)
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/context.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class Context::ContextImpl {
|
||||
public:
|
||||
ContextImpl() : device_target_("NotSet"), device_id_(0) {}
|
||||
const std::string &GetDeviceTarget() const { return device_target_; }
|
||||
void SetDeviceTarget(std::string_view device_target) { device_target_ = device_target; }
|
||||
uint32_t GetDeviceID() const { return device_id_; }
|
||||
void SetDeviceID(uint32_t device_id) { device_id_ = device_id; }
|
||||
|
||||
private:
|
||||
std::string device_target_;
|
||||
uint32_t device_id_;
|
||||
};
|
||||
|
||||
Context &Context::Instance() {
|
||||
static Context context;
|
||||
return context;
|
||||
}
|
||||
|
||||
const std::string &Context::GetDeviceTarget() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->GetDeviceTarget();
|
||||
}
|
||||
|
||||
Context &Context::SetDeviceTarget(const std::string &device_target) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
impl_->SetDeviceTarget(device_target);
|
||||
return *this;
|
||||
}
|
||||
|
||||
uint32_t Context::GetDeviceID() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->GetDeviceID();
|
||||
}
|
||||
|
||||
Context &Context::SetDeviceID(uint32_t device_id) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
impl_->SetDeviceID(device_id);
|
||||
return *this;
|
||||
}
|
||||
|
||||
Context::Context() : impl_(std::make_shared<Context::ContextImpl>()) { MS_EXCEPTION_IF_NULL(impl_); }
|
||||
|
||||
Context::~Context() {}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_FACTORY_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_FACTORY_H
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
template <class T>
|
||||
class Factory {
|
||||
using U = std::function<std::shared_ptr<T>()>;
|
||||
|
||||
public:
|
||||
Factory(const Factory &) = delete;
|
||||
void operator=(const Factory &) = delete;
|
||||
|
||||
static Factory &Instance() {
|
||||
static Factory instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void Register(const std::string &device_name, U &&creator) {
|
||||
if (creators_.find(device_name) == creators_.end()) {
|
||||
(void)creators_.emplace(device_name, creator);
|
||||
}
|
||||
}
|
||||
|
||||
bool CheckModelSupport(const std::string &device_name) {
|
||||
return std::any_of(creators_.begin(), creators_.end(),
|
||||
[&device_name](const std::pair<std::string, U> &item) { return item.first == device_name; });
|
||||
}
|
||||
|
||||
std::shared_ptr<T> Create(const std::string &device_name) {
|
||||
auto iter = creators_.find(device_name);
|
||||
if (creators_.end() != iter) {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
return (iter->second)();
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Unsupported device target " << device_name;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
Factory() = default;
|
||||
~Factory() = default;
|
||||
std::map<std::string, U> creators_;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class Registrar {
|
||||
using U = std::function<std::shared_ptr<T>()>;
|
||||
|
||||
public:
|
||||
Registrar(const std::string &device_name, U creator) {
|
||||
Factory<T>::Instance().Register(device_name, std::move(creator));
|
||||
}
|
||||
~Registrar() = default;
|
||||
};
|
||||
|
||||
#define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \
|
||||
static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \
|
||||
#DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); });
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H
|
|
@ -0,0 +1,266 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cxx_api/graph/acl/acl_graph_impl.h"
|
||||
#include "include/api/context.h"
|
||||
#include "cxx_api/model/acl/model_converter.h"
|
||||
#include "cxx_api/python_utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl);
|
||||
std::weak_ptr<AclGraphImpl::AclEnvGuard> AclGraphImpl::global_acl_env_;
|
||||
std::mutex AclGraphImpl::global_acl_env_mutex_;
|
||||
|
||||
AclGraphImpl::AclGraphImpl()
|
||||
: init_flag_(false),
|
||||
load_flag_(false),
|
||||
device_type_("AscendCL"),
|
||||
device_id_(Context::Instance().GetDeviceID()),
|
||||
context_(nullptr),
|
||||
acl_env_(nullptr) {}
|
||||
|
||||
AclGraphImpl::~AclGraphImpl() { (void)FinalizeEnv(); }
|
||||
|
||||
Status AclGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
Status ret = Load();
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Prepare model resource failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return model_process_.PredictFromHost(inputs, outputs);
|
||||
}
|
||||
|
||||
Status AclGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
|
||||
Status ret = Load();
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Prepare model resource failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return model_process_.GetInputsInfo(names, shapes, data_types, mem_sizes);
|
||||
}
|
||||
|
||||
Status AclGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
|
||||
Status ret = Load();
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Prepare model resource failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return model_process_.GetOutputsInfo(names, shapes, data_types, mem_sizes);
|
||||
}
|
||||
|
||||
Status AclGraphImpl::LoadAclModel(Buffer om_data) {
|
||||
MS_LOG(INFO) << "Start load acl model.";
|
||||
// acl load model
|
||||
uint32_t acl_model_id;
|
||||
auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id);
|
||||
if (acl_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// acl init model resource
|
||||
model_process_.set_model_id(acl_model_id);
|
||||
Status ret = model_process_.PreInitModelResource();
|
||||
if (ret != SUCCESS) {
|
||||
(void)aclmdlUnload(acl_model_id);
|
||||
MS_LOG(ERROR) << "Pre init model resource failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Load acl model success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AclGraphImpl::InitEnv() {
|
||||
if (init_flag_) {
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
aclError ret;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(global_acl_env_mutex_);
|
||||
acl_env_ = global_acl_env_.lock();
|
||||
if (acl_env_ != nullptr) {
|
||||
MS_LOG(INFO) << "Acl has been initialized, skip.";
|
||||
} else {
|
||||
acl_env_ = std::make_shared<AclEnvGuard>("");
|
||||
if (acl_env_->GetErrno() != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Execute aclInit Failed";
|
||||
return FAILED;
|
||||
}
|
||||
global_acl_env_ = acl_env_;
|
||||
MS_LOG(INFO) << "Acl init success";
|
||||
}
|
||||
}
|
||||
|
||||
ret = aclrtSetDevice(device_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "Open device " << device_id_ << " success";
|
||||
|
||||
ret = aclrtCreateContext(&context_, device_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Acl create context failed";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "Create context success";
|
||||
|
||||
aclrtRunMode run_mode;
|
||||
ret = aclrtGetRunMode(&run_mode);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Acl get run mode failed";
|
||||
return FAILED;
|
||||
}
|
||||
bool is_device = (run_mode == ACL_DEVICE);
|
||||
model_process_.SetIsDevice(is_device);
|
||||
MS_LOG(INFO) << "Get run mode success is device input/output " << is_device;
|
||||
|
||||
MS_LOG(INFO) << "Init acl success, device id " << device_id_;
|
||||
init_flag_ = true;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AclGraphImpl::FinalizeEnv() {
|
||||
if (!init_flag_) {
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
aclError rt_ret = aclrtSetCurrentContext(context_);
|
||||
if (rt_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Set the ascend device context failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Status ret = model_process_.UnLoad();
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Unload model inner failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (context_ != nullptr) {
|
||||
rt_ret = aclrtDestroyContext(context_);
|
||||
if (rt_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Destroy context failed";
|
||||
}
|
||||
context_ = nullptr;
|
||||
}
|
||||
MS_LOG(INFO) << "End to destroy context";
|
||||
|
||||
rt_ret = aclrtResetDevice(device_id_);
|
||||
if (rt_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Reset device " << device_id_ << " failed";
|
||||
}
|
||||
MS_LOG(INFO) << "End to reset device " << device_id_;
|
||||
|
||||
init_flag_ = false;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AclGraphImpl::Load() {
|
||||
// check graph type
|
||||
if (graph_->ModelType() != ModelType::kOM) {
|
||||
Status ret = ConvertToOM();
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Load Failed.";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
const auto &graph_data = GraphImpl::MutableGraphData();
|
||||
MS_EXCEPTION_IF_NULL(graph_data);
|
||||
auto om_data = graph_data->GetOMData();
|
||||
|
||||
// init
|
||||
Status ret = InitEnv();
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "InitEnv failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// load model
|
||||
if (!load_flag_) {
|
||||
ret = LoadAclModel(om_data);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Load acl model failed.";
|
||||
return ret;
|
||||
}
|
||||
load_flag_ = true;
|
||||
}
|
||||
|
||||
aclError rt_ret = aclrtSetCurrentContext(context_);
|
||||
if (rt_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Set the ascend device context failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AclGraphImpl::ConvertToOM() {
|
||||
MS_LOG(INFO) << "Start convert to om model.";
|
||||
RegAllOpFromPython();
|
||||
if (graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid graph_ is null.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
auto &graph_data = GraphImpl::MutableGraphData();
|
||||
MS_EXCEPTION_IF_NULL(graph_data);
|
||||
if (graph_->ModelType() == ModelType::kOM) {
|
||||
MS_LOG(INFO) << "This model has been built, skip.";
|
||||
return SUCCESS;
|
||||
} else if (graph_->ModelType() == ModelType::kMindIR) {
|
||||
auto func_graph = graph_data->GetFuncGraph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
ModelConverter model_converter;
|
||||
Buffer om_data = model_converter.LoadMindIR(func_graph);
|
||||
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
|
||||
MS_LOG(ERROR) << "Convert MindIR to OM failed.";
|
||||
return FAILED;
|
||||
}
|
||||
graph_data = std::make_shared<Graph::GraphData>(om_data, ModelType::kOM);
|
||||
MS_LOG(INFO) << "Convert MindIR to OM success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
AclGraphImpl::AclEnvGuard::AclEnvGuard(std::string_view cfg_file) {
|
||||
errno_ = aclInit(cfg_file.data());
|
||||
if (errno_ != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Execute aclInit Failed";
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Acl init success";
|
||||
}
|
||||
|
||||
AclGraphImpl::AclEnvGuard::~AclEnvGuard() {
|
||||
errno_ = aclFinalize();
|
||||
if (errno_ != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Finalize acl failed";
|
||||
}
|
||||
MS_LOG(INFO) << "Acl finalize success";
|
||||
}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "include/api/graph.h"
|
||||
#include "cxx_api/graph/acl/model_process.h"
|
||||
#include "cxx_api/graph/graph_impl.h"
|
||||
#include "cxx_api/factory.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class AclGraphImpl : public GraphCell::GraphImpl {
|
||||
public:
|
||||
AclGraphImpl();
|
||||
~AclGraphImpl() override;
|
||||
|
||||
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
|
||||
Status Load() override;
|
||||
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
|
||||
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
|
||||
|
||||
private:
|
||||
class AclEnvGuard;
|
||||
|
||||
Status ConvertToOM();
|
||||
Status InitEnv();
|
||||
Status FinalizeEnv();
|
||||
Status LoadAclModel(Buffer om_data);
|
||||
|
||||
bool init_flag_;
|
||||
bool load_flag_;
|
||||
std::string device_type_;
|
||||
int32_t device_id_;
|
||||
aclrtContext context_;
|
||||
|
||||
std::shared_ptr<AclEnvGuard> acl_env_;
|
||||
static std::weak_ptr<AclEnvGuard> global_acl_env_;
|
||||
static std::mutex global_acl_env_mutex_;
|
||||
|
||||
ModelProcess model_process_;
|
||||
};
|
||||
|
||||
class AclGraphImpl::AclEnvGuard {
|
||||
public:
|
||||
explicit AclEnvGuard(std::string_view cfg_file);
|
||||
~AclEnvGuard();
|
||||
aclError GetErrno() const { return errno_; }
|
||||
|
||||
private:
|
||||
aclError errno_;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cxx_api/model/acl/model_process.h"
|
||||
#include "cxx_api/graph/acl/model_process.h"
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include "utils/utils.h"
|
||||
|
@ -35,17 +35,33 @@ static DataType TransToApiType(aclDataType data_type) {
|
|||
}
|
||||
}
|
||||
|
||||
static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<Tensor> *tensor_list) {
|
||||
MS_EXCEPTION_IF_NULL(tensor_list);
|
||||
tensor_list->clear();
|
||||
template <class T>
|
||||
inline static void ClearIfNotNull(T *vec) {
|
||||
if (vec != nullptr) {
|
||||
vec->clear();
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, class U = std::vector<T>>
|
||||
inline static void PushbackIfNotNull(U *vec, T &&item) {
|
||||
if (vec != nullptr) {
|
||||
vec->emplace_back(item);
|
||||
}
|
||||
}
|
||||
|
||||
static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<std::string> *names,
|
||||
std::vector<std::vector<int64_t>> *shapes, std::vector<DataType> *data_types,
|
||||
std::vector<size_t> *mem_sizes) {
|
||||
ClearIfNotNull(names);
|
||||
ClearIfNotNull(shapes);
|
||||
ClearIfNotNull(data_types);
|
||||
ClearIfNotNull(mem_sizes);
|
||||
for (size_t i = 0; i < acl_tensor_list.size(); ++i) {
|
||||
const auto &info = acl_tensor_list[i];
|
||||
Tensor tensor_desc;
|
||||
tensor_desc.SetName(info.name);
|
||||
tensor_desc.SetDataType(TransToApiType(info.data_type));
|
||||
tensor_desc.SetShape(info.dims);
|
||||
tensor_list->push_back(tensor_desc);
|
||||
PushbackIfNotNull(names, info.name);
|
||||
PushbackIfNotNull(shapes, info.dims);
|
||||
PushbackIfNotNull(data_types, TransToApiType(info.data_type));
|
||||
PushbackIfNotNull(mem_sizes, info.buffer_size);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -272,7 +288,7 @@ Status ModelProcess::UnLoad() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ModelProcess::CheckAndInitInput(const std::map<std::string, Buffer> &inputs) {
|
||||
Status ModelProcess::CheckAndInitInput(const std::vector<Buffer> &inputs) {
|
||||
aclError ret;
|
||||
inputs_ = aclmdlCreateDataset();
|
||||
// check inputs
|
||||
|
@ -282,29 +298,16 @@ Status ModelProcess::CheckAndInitInput(const std::map<std::string, Buffer> &inpu
|
|||
return INVALID_INPUTS;
|
||||
}
|
||||
for (size_t i = 0; i < input_infos_.size(); ++i) {
|
||||
const std::string &input_name = input_infos_[i].name;
|
||||
auto iter = inputs.find(input_name);
|
||||
if (iter == inputs.end()) {
|
||||
MS_LOG(ERROR) << "Model missing input " << input_name;
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
|
||||
if (iter->second.DataSize() != input_infos_[i].buffer_size) {
|
||||
if (inputs[i].DataSize() != input_infos_[i].buffer_size) {
|
||||
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << input_infos_[i].buffer_size
|
||||
<< ", given count " << iter->second.DataSize();
|
||||
<< ", given count " << inputs[i].DataSize();
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
}
|
||||
// copy inputs
|
||||
for (size_t i = 0; i < input_infos_.size(); ++i) {
|
||||
const auto &info = input_infos_[i];
|
||||
auto iter = inputs.find(info.name);
|
||||
if (iter == inputs.end()) {
|
||||
MS_LOG(ERROR) << "Model missing input " << info.name;
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
|
||||
const auto &input = iter->second;
|
||||
const auto &input = inputs[i];
|
||||
const void *data = input.Data();
|
||||
|
||||
void *input_buffer = nullptr;
|
||||
|
@ -333,42 +336,7 @@ Status ModelProcess::CheckAndInitInput(const std::map<std::string, Buffer> &inpu
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ModelProcess::CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size,
|
||||
size_t input_index) {
|
||||
aclError ret;
|
||||
inputs_ = aclmdlCreateDataset();
|
||||
// check inputs
|
||||
if (input_index >= input_infos_.size()) {
|
||||
MS_LOG(ERROR) << "inputs count not match, required count " << input_infos_.size() << ", given index "
|
||||
<< input_index;
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
if (dvpp_outputs_buffer_dev == nullptr) {
|
||||
MS_LOG(ERROR) << "input " << 0 << " cannot be null";
|
||||
return FAILED;
|
||||
}
|
||||
if (dvpp_outputs_buffer_size != input_infos_[input_index].buffer_size) {
|
||||
MS_LOG(ERROR) << "input " << 0 << " data size not match, required size " << input_infos_[input_index].buffer_size
|
||||
<< ", given count " << dvpp_outputs_buffer_size;
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
// copy inputs
|
||||
auto &info = input_infos_[input_index];
|
||||
auto data_buffer = aclCreateDataBuffer(const_cast<void *>(dvpp_outputs_buffer_dev), info.buffer_size);
|
||||
if (data_buffer == nullptr) {
|
||||
MS_LOG(ERROR) << "Create Data Buffer failed";
|
||||
return FAILED;
|
||||
}
|
||||
ret = aclmdlAddDatasetBuffer(inputs_, data_buffer);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "add data buffer failed";
|
||||
aclDestroyDataBuffer(data_buffer);
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ModelProcess::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) {
|
||||
Status ModelProcess::PredictFromHost(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
aclError acl_ret;
|
||||
Status ret = CheckAndInitInput(inputs);
|
||||
|
@ -392,18 +360,7 @@ Status ModelProcess::Predict(const std::map<std::string, Buffer> &inputs, std::m
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
size_t ModelProcess::GetBatchSize() const {
|
||||
if (input_infos_.empty()) {
|
||||
MS_LOG(ERROR) << "Model is not loaded";
|
||||
return 0;
|
||||
}
|
||||
if (input_infos_[0].dims.empty()) {
|
||||
return 1;
|
||||
}
|
||||
return static_cast<size_t>(input_infos_[0].dims[0]);
|
||||
}
|
||||
|
||||
Status ModelProcess::BuildOutputs(std::map<std::string, Buffer> *outputs) {
|
||||
Status ModelProcess::BuildOutputs(std::vector<Buffer> *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
aclError ret;
|
||||
// copy outputs
|
||||
|
@ -411,14 +368,13 @@ Status ModelProcess::BuildOutputs(std::map<std::string, Buffer> *outputs) {
|
|||
aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST;
|
||||
for (size_t i = 0; i < output_infos_.size(); ++i) {
|
||||
const auto &info = output_infos_[i];
|
||||
// todo
|
||||
outputs->emplace(info.name, Buffer());
|
||||
auto output = outputs->rbegin()->second;
|
||||
if (!output.ResizeData(info.buffer_size)) {
|
||||
outputs->emplace_back(Buffer());
|
||||
auto output = outputs->rbegin();
|
||||
if (!output->ResizeData(info.buffer_size)) {
|
||||
MS_LOG(ERROR) << "new output data buffer failed, data size " << info.buffer_size;
|
||||
return FAILED;
|
||||
}
|
||||
ret = aclrtMemcpy(output.MutableData(), output.DataSize(), info.device_data, info.buffer_size, kind);
|
||||
ret = aclrtMemcpy(output->MutableData(), output->DataSize(), info.device_data, info.buffer_size, kind);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Memcpy output " << i << " from " << (is_run_on_device_ ? "host" : "device")
|
||||
<< " to host failed, memory size " << info.buffer_size;
|
||||
|
@ -428,13 +384,15 @@ Status ModelProcess::BuildOutputs(std::map<std::string, Buffer> *outputs) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ModelProcess::GetInputsInfo(std::vector<Tensor> *tensor_list) const {
|
||||
ConstructTensorDesc(input_infos_, tensor_list);
|
||||
Status ModelProcess::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
|
||||
ConstructTensorDesc(input_infos_, names, shapes, data_types, mem_sizes);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ModelProcess::GetOutputsInfo(std::vector<Tensor> *tensor_list) const {
|
||||
ConstructTensorDesc(output_infos_, tensor_list);
|
||||
Status ModelProcess::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
|
||||
ConstructTensorDesc(output_infos_, names, shapes, data_types, mem_sizes);
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace mindspore::api
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H
|
||||
#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H
|
||||
#ifndef MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H
|
||||
#define MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
@ -34,12 +34,6 @@ struct AclTensorInfo {
|
|||
std::string name;
|
||||
};
|
||||
|
||||
struct ImagesDvppOutput {
|
||||
void *buffer_device = nullptr;
|
||||
size_t buffer_size = 0;
|
||||
size_t input_index = 0;
|
||||
};
|
||||
|
||||
class ModelProcess {
|
||||
public:
|
||||
ModelProcess()
|
||||
|
@ -53,24 +47,23 @@ class ModelProcess {
|
|||
~ModelProcess() {}
|
||||
Status LoadModelFromFile(const std::string &file_name, uint32_t *model_id);
|
||||
Status UnLoad();
|
||||
Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs);
|
||||
Status PredictFromHost(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
|
||||
Status PreInitModelResource();
|
||||
Status GetInputsInfo(std::vector<Tensor> *tensor_list) const;
|
||||
Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const;
|
||||
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
|
||||
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
|
||||
|
||||
// override this method to avoid request/reply data copy
|
||||
void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; }
|
||||
|
||||
size_t GetBatchSize() const;
|
||||
void set_model_id(uint32_t model_id) { model_id_ = model_id; }
|
||||
uint32_t model_id() const { return model_id_; }
|
||||
|
||||
private:
|
||||
Status CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
|
||||
Status CheckAndInitInput(const std::map<std::string, Buffer> &inputs);
|
||||
Status CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size,
|
||||
size_t input_index);
|
||||
Status BuildOutputs(std::map<std::string, Buffer> *outputs);
|
||||
Status CheckAndInitInput(const std::vector<Buffer> &inputs);
|
||||
Status BuildOutputs(std::vector<Buffer> *outputs);
|
||||
Status InitInputsBuffer();
|
||||
Status InitOutputsBuffer();
|
||||
|
||||
|
@ -90,4 +83,4 @@ class ModelProcess {
|
|||
};
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H
|
||||
#endif // MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/graph.h"
|
||||
#include "cxx_api/graph/graph_data.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_data) {}
|
||||
|
||||
Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {}
|
||||
|
||||
ModelType Graph::ModelType() const {
|
||||
MS_EXCEPTION_IF_NULL(graph_data_);
|
||||
return graph_data_->ModelType();
|
||||
}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cxx_api/graph/graph_data.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#ifdef ENABLE_ACL
|
||||
#include "framework/common/helper/model_helper.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore::api {
|
||||
Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type)
|
||||
: func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) {
|
||||
if (model_type != ModelType::kMindIR) {
|
||||
MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type;
|
||||
}
|
||||
func_graph_ = func_graph;
|
||||
model_type_ = model_type;
|
||||
}
|
||||
|
||||
Graph::GraphData::GraphData(Buffer om_data, enum ModelType model_type)
|
||||
: func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) {
|
||||
if (model_type != ModelType::kOM) {
|
||||
MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ACL
|
||||
// check om
|
||||
ge::ModelHelper helper;
|
||||
ge::ModelData model_data;
|
||||
model_data.model_data = om_data.MutableData();
|
||||
model_data.model_len = om_data.DataSize();
|
||||
ge::Status ret = helper.LoadModel(model_data);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Invalid input data cannot parse to om.";
|
||||
}
|
||||
|
||||
om_data_ = om_data;
|
||||
model_type_ = model_type;
|
||||
#else
|
||||
MS_LOG(EXCEPTION) << "Unsupported ModelType OM.";
|
||||
#endif
|
||||
}
|
||||
|
||||
FuncGraphPtr Graph::GraphData::GetFuncGraph() const {
|
||||
if (model_type_ != ModelType::kMindIR) {
|
||||
MS_LOG(ERROR) << "Invalid ModelType " << model_type_;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return func_graph_;
|
||||
}
|
||||
|
||||
Buffer Graph::GraphData::GetOMData() const {
|
||||
if (model_type_ != ModelType::kOM) {
|
||||
MS_LOG(ERROR) << "Invalid ModelType " << model_type_;
|
||||
return Buffer();
|
||||
}
|
||||
|
||||
return om_data_;
|
||||
}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/graph.h"
|
||||
#include "include/api/types.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class Graph::GraphData {
|
||||
public:
|
||||
GraphData();
|
||||
|
||||
explicit GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type = kMindIR);
|
||||
|
||||
GraphData(Buffer om_data, enum ModelType model_type);
|
||||
|
||||
enum ModelType ModelType() const { return model_type_; }
|
||||
|
||||
FuncGraphPtr GetFuncGraph() const;
|
||||
|
||||
Buffer GetOMData() const;
|
||||
|
||||
private:
|
||||
FuncGraphPtr func_graph_;
|
||||
Buffer om_data_;
|
||||
enum ModelType model_type_;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "include/api/cell.h"
|
||||
#include "include/api/graph.h"
|
||||
#include "cxx_api/graph/graph_data.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class GraphCell::GraphImpl {
|
||||
public:
|
||||
GraphImpl() = default;
|
||||
virtual ~GraphImpl() = default;
|
||||
|
||||
std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }
|
||||
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
|
||||
|
||||
virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) = 0;
|
||||
virtual Status Load() = 0;
|
||||
|
||||
virtual Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0;
|
||||
virtual Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<Graph> graph_;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
|
|
@ -0,0 +1,334 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cxx_api/graph/ms/ms_graph_impl.h"
|
||||
#include <algorithm>
|
||||
#include "include/api/context.h"
|
||||
#include "cxx_api/factory.h"
|
||||
#include "cxx_api/python_utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/context/context_extends.h"
|
||||
#include "mindspore/core/base/base_ref_utils.h"
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "backend/session/executor_manager.h"
|
||||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, MsGraphImpl);
|
||||
|
||||
static DataType TransTypeId2InferDataType(TypeId type_id) {
|
||||
const std::map<TypeId, api::DataType> id2type_map{
|
||||
{TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool},
|
||||
{TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
|
||||
{TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16},
|
||||
{TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32},
|
||||
{TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64},
|
||||
{TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16},
|
||||
{TypeId::kNumberTypeFloat32, api::kMsFloat32},
|
||||
};
|
||||
|
||||
// cppcheck-suppress stlIfFind
|
||||
if (auto it = id2type_map.find(type_id); it != id2type_map.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
MS_LOG(WARNING) << "Unsupported data id " << type_id;
|
||||
return api::kMsUnknown;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline static void ClearIfNotNull(T *vec) {
|
||||
if (vec != nullptr) {
|
||||
vec->clear();
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, class U = std::vector<T>>
|
||||
inline static void PushbackIfNotNull(U *vec, T &&item) {
|
||||
if (vec != nullptr) {
|
||||
vec->emplace_back(item);
|
||||
}
|
||||
}
|
||||
|
||||
MsGraphImpl::MsGraphImpl()
|
||||
: session_impl_(nullptr),
|
||||
graph_id_(0),
|
||||
device_type_("Ascend"),
|
||||
device_id_(Context::Instance().GetDeviceID()),
|
||||
context_(nullptr),
|
||||
inputs_(),
|
||||
outputs_(),
|
||||
input_names_(),
|
||||
output_names_(),
|
||||
load_flag_(false) {}
|
||||
|
||||
MsGraphImpl::~MsGraphImpl() { (void)FinalizeEnv(); }
|
||||
|
||||
Status MsGraphImpl::InitEnv() {
|
||||
RegAllOpFromPython();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
|
||||
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
|
||||
if (!context::OpenTsd(ms_context)) {
|
||||
MS_LOG(ERROR) << "Session init OpenTsd failed!";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice);
|
||||
if (session_impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice
|
||||
<< " is available.";
|
||||
return FAILED;
|
||||
}
|
||||
session_impl_->Init(device_id_);
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsGraphImpl::FinalizeEnv() {
|
||||
MS_LOG_INFO << "Start finalize env";
|
||||
pybind11::gil_scoped_acquire acquire;
|
||||
session::ExecutorManager::Instance().Clear();
|
||||
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
return FAILED;
|
||||
}
|
||||
if (!context::CloseTsd(ms_context)) {
|
||||
MS_LOG(ERROR) << "CloseTsd failed!";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "End finalize env";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
|
||||
MS_ASSERT(session_impl_ != nullptr);
|
||||
try {
|
||||
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
||||
pybind11::gil_scoped_release gil_release;
|
||||
return SUCCESS;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "CompileGraph failed: " << e.what();
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<tensor::TensorPtr> MsGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
|
||||
try {
|
||||
VectorRef outputs;
|
||||
session_impl_->RunGraph(graph_id_, inputs, &outputs);
|
||||
return TransformVectorRefToMultiTensor(outputs);
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "RunGraph failed: " << e.what();
|
||||
return std::vector<tensor::TensorPtr>();
|
||||
}
|
||||
}
|
||||
|
||||
Status MsGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const {
|
||||
MS_ASSERT(session_impl_ != nullptr);
|
||||
std::string error_msg;
|
||||
if (!session_impl_->CheckModelInputs(graph_id_, inputs, &error_msg)) {
|
||||
return Status(INVALID_INPUTS, error_msg);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
|
||||
MS_EXCEPTION_IF_NULL(reply);
|
||||
if (context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "rtCtx is nullptr";
|
||||
return FAILED;
|
||||
}
|
||||
rtError_t rt_ret = rtCtxSetCurrent(context_);
|
||||
if (rt_ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Set Ascend rtCtx failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
vector<tensor::TensorPtr> inputs;
|
||||
for (size_t i = 0; i < request.size(); i++) {
|
||||
auto &item = request[i];
|
||||
auto input = inputs_[i];
|
||||
if (input->Size() != item.DataSize()) {
|
||||
MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size "
|
||||
<< input->Size();
|
||||
return FAILED;
|
||||
}
|
||||
auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize());
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Tensor copy failed";
|
||||
return FAILED;
|
||||
}
|
||||
inputs.push_back(input);
|
||||
}
|
||||
vector<tensor::TensorPtr> outputs = RunGraph(inputs);
|
||||
if (outputs.empty()) {
|
||||
MS_LOG(ERROR) << "Execute Model Failed";
|
||||
return FAILED;
|
||||
}
|
||||
reply->clear();
|
||||
std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply),
|
||||
[](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); });
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
|
||||
if (!load_flag_) {
|
||||
Status ret = Load();
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "PrepareModel failed.";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
ClearIfNotNull(names);
|
||||
ClearIfNotNull(shapes);
|
||||
ClearIfNotNull(data_types);
|
||||
ClearIfNotNull(mem_sizes);
|
||||
for (size_t i = 0; i < inputs_.size(); i++) {
|
||||
auto &tensor = inputs_[i];
|
||||
PushbackIfNotNull(names, input_names_[i]);
|
||||
PushbackIfNotNull(shapes, tensor->shape());
|
||||
PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type()));
|
||||
PushbackIfNotNull(mem_sizes, tensor->DataSize());
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
|
||||
if (!load_flag_) {
|
||||
Status ret = Load();
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "PrepareModel failed.";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
ClearIfNotNull(names);
|
||||
ClearIfNotNull(shapes);
|
||||
ClearIfNotNull(data_types);
|
||||
ClearIfNotNull(mem_sizes);
|
||||
for (size_t i = 0; i < outputs_.size(); i++) {
|
||||
auto &tensor = outputs_[i];
|
||||
PushbackIfNotNull(names, output_names_[i]);
|
||||
PushbackIfNotNull(shapes, tensor->shape());
|
||||
PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type()));
|
||||
PushbackIfNotNull(mem_sizes, tensor->DataSize());
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsGraphImpl::Load() {
|
||||
// check graph type
|
||||
if (graph_->ModelType() != ModelType::kMindIR) {
|
||||
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
|
||||
const auto &graph_data = GraphImpl::MutableGraphData();
|
||||
MS_EXCEPTION_IF_NULL(graph_data);
|
||||
auto func_graph = graph_data->GetFuncGraph();
|
||||
|
||||
// init
|
||||
Status ret = InitEnv();
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "InitEnv failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// load model
|
||||
if (!load_flag_) {
|
||||
ret = CompileGraph(func_graph);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Compile graph model failed";
|
||||
return FAILED;
|
||||
}
|
||||
session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_);
|
||||
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_);
|
||||
if (inputs_.empty() || inputs_.size() != input_names_.size()) {
|
||||
MS_LOG_ERROR << "Get model inputs info failed";
|
||||
return FAILED;
|
||||
}
|
||||
if (outputs_.empty() || outputs_.size() != output_names_.size()) {
|
||||
MS_LOG_ERROR << "Get model outputs info failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// save d context
|
||||
rtError_t rt_ret = rtCtxGetCurrent(&context_);
|
||||
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "the ascend device context is null";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Load model success";
|
||||
load_flag_ = true;
|
||||
}
|
||||
|
||||
rtError_t rt_ret = rtCtxSetCurrent(context_);
|
||||
if (rt_ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Set the ascend device context failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
if (!load_flag_) {
|
||||
Status ret = Load();
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "PrepareModel failed.";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
if (inputs.size() != inputs_.size()) {
|
||||
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size();
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < inputs_.size(); ++i) {
|
||||
if (inputs[i].DataSize() != inputs_[i]->Size()) {
|
||||
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count "
|
||||
<< inputs[i].DataSize();
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
}
|
||||
if (ExecuteModel(inputs, outputs) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Execute Model Failed";
|
||||
return FAILED;
|
||||
}
|
||||
if (outputs_.size() != outputs->size()) {
|
||||
MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info "
|
||||
<< outputs_.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/graph.h"
|
||||
#include "cxx_api/graph/graph_impl.h"
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "ir/anf.h"
|
||||
#include "cxx_api/model/model_impl.h"
|
||||
#include "runtime/context.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class MsGraphImpl : public GraphCell::GraphImpl {
|
||||
public:
|
||||
MsGraphImpl();
|
||||
~MsGraphImpl() override;
|
||||
|
||||
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
|
||||
Status Load() override;
|
||||
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
|
||||
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
|
||||
|
||||
private:
|
||||
Status InitEnv();
|
||||
Status FinalizeEnv();
|
||||
Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
|
||||
Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
|
||||
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
|
||||
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
|
||||
|
||||
std::shared_ptr<session::SessionBasic> session_impl_;
|
||||
uint32_t graph_id_;
|
||||
std::string device_type_;
|
||||
uint32_t device_id_;
|
||||
rtContext_t context_;
|
||||
std::vector<tensor::TensorPtr> inputs_;
|
||||
std::vector<tensor::TensorPtr> outputs_;
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<std::string> output_names_;
|
||||
bool load_flag_;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
|
|
@ -16,216 +16,57 @@
|
|||
|
||||
#include "cxx_api/model/acl/acl_model.h"
|
||||
#include <memory>
|
||||
#include "utils/context/context_extends.h"
|
||||
#include "cxx_api/factory.h"
|
||||
#include "cxx_api/python_utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
std::weak_ptr<AclModel::AclEnvGuard> AclModel::global_acl_env_;
|
||||
std::mutex AclModel::global_acl_env_mutex_;
|
||||
API_FACTORY_REG(ModelImpl, Ascend310, AclModel);
|
||||
|
||||
Status AclModel::InitEnv() {
|
||||
if (init_flag_) {
|
||||
Status AclModel::Build(const std::map<std::string, std::string> &options_map) {
|
||||
MS_LOG(INFO) << "Start build model.";
|
||||
MS_EXCEPTION_IF_NULL(graph_);
|
||||
RegAllOpFromPython();
|
||||
std::unique_ptr<AclModelOptions> options = std::make_unique<AclModelOptions>(options_map);
|
||||
std::string options_str = GenerateOptionsStr(options_map);
|
||||
MS_EXCEPTION_IF_NULL(options);
|
||||
if (graph_cell_ != nullptr && options_str == options_str_) {
|
||||
MS_LOG(INFO) << "This model has been built, skip.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(options_);
|
||||
aclError ret;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(global_acl_env_mutex_);
|
||||
acl_env_ = global_acl_env_.lock();
|
||||
if (acl_env_ != nullptr) {
|
||||
if (options_->dump_cfg_path.empty()) {
|
||||
MS_LOG(INFO) << "Acl has been initialized, skip.";
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Acl has been initialized, skip, so dump config will be ignored.";
|
||||
}
|
||||
} else {
|
||||
acl_env_ = std::make_shared<AclEnvGuard>(options_->dump_cfg_path);
|
||||
if (acl_env_->GetErrno() != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Execute aclInit Failed";
|
||||
return FAILED;
|
||||
}
|
||||
global_acl_env_ = acl_env_;
|
||||
MS_LOG(INFO) << "Acl init success";
|
||||
if (graph_cell_ == nullptr && graph_->ModelType() == ModelType::kOM) {
|
||||
graph_cell_ = std::make_shared<GraphCell>(graph_);
|
||||
MS_EXCEPTION_IF_NULL(graph_cell_);
|
||||
if (!options_map.empty()) {
|
||||
MS_LOG(WARNING) << "All build options will be ignored.";
|
||||
}
|
||||
}
|
||||
|
||||
ret = aclrtSetDevice(device_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "Open device " << device_id_ << " success";
|
||||
|
||||
ret = aclrtCreateContext(&context_, device_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Acl create context failed";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "Create context success";
|
||||
|
||||
ret = aclrtSetCurrentContext(context_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Acl set current context failed";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "Set context success";
|
||||
|
||||
ret = aclrtCreateStream(&stream_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Acl create stream failed";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "Create stream success";
|
||||
|
||||
aclrtRunMode run_mode;
|
||||
ret = aclrtGetRunMode(&run_mode);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Acl get run mode failed";
|
||||
return FAILED;
|
||||
}
|
||||
bool is_device = (run_mode == ACL_DEVICE);
|
||||
model_process_.SetIsDevice(is_device);
|
||||
MS_LOG(INFO) << "Get run mode success is device input/output " << is_device;
|
||||
|
||||
if (dvpp_process_.InitResource(stream_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "DVPP init resource failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Init acl success, device id " << device_id_;
|
||||
init_flag_ = true;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AclModel::FinalizeEnv() {
|
||||
if (!init_flag_) {
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
dvpp_process_.Finalize();
|
||||
aclError ret;
|
||||
if (stream_ != nullptr) {
|
||||
ret = aclrtDestroyStream(stream_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Destroy stream failed";
|
||||
}
|
||||
stream_ = nullptr;
|
||||
}
|
||||
MS_LOG(INFO) << "End to destroy stream";
|
||||
if (context_ != nullptr) {
|
||||
ret = aclrtDestroyContext(context_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Destroy context failed";
|
||||
}
|
||||
context_ = nullptr;
|
||||
}
|
||||
MS_LOG(INFO) << "End to destroy context";
|
||||
|
||||
ret = aclrtResetDevice(device_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Reset devie " << device_id_ << " failed";
|
||||
}
|
||||
MS_LOG(INFO) << "End to reset device " << device_id_;
|
||||
|
||||
init_flag_ = false;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AclModel::LoadModel(const Buffer &model_data, ModelType type,
|
||||
const std::map<std::string, std::string> &options) {
|
||||
if (load_flag_) {
|
||||
MS_LOG(ERROR) << "Model has been loaded.";
|
||||
auto func_graph = ModelImpl::GetFuncGraph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
model_converter_.set_options(options.get());
|
||||
auto om_data = model_converter_.LoadMindIR(func_graph);
|
||||
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
|
||||
MS_LOG(ERROR) << "Load MindIR failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
options_ = std::make_unique<AclModelOptions>(options);
|
||||
MS_EXCEPTION_IF_NULL(options_);
|
||||
|
||||
Status ret = InitEnv();
|
||||
auto graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto graph_cell = std::make_shared<GraphCell>(graph);
|
||||
MS_EXCEPTION_IF_NULL(graph_cell);
|
||||
auto ret = ModelImpl::Load(graph_cell);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "InitEnv failed.";
|
||||
return FAILED;
|
||||
MS_LOG(ERROR) << "Load failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
Buffer om_data;
|
||||
if (type == ModelType::kMindIR) {
|
||||
model_converter_.set_options(options_.get());
|
||||
om_data = model_converter_.LoadMindIR(model_data);
|
||||
} else if (type == ModelType::kAIR) {
|
||||
model_converter_.set_options(options_.get());
|
||||
om_data = model_converter_.LoadAscendIR(model_data);
|
||||
} else if (type == ModelType::kOM) {
|
||||
om_data = model_data;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported model type " << type;
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// acl load model
|
||||
uint32_t acl_model_id;
|
||||
auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id);
|
||||
if (acl_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// acl init model resource
|
||||
model_process_.set_model_id(acl_model_id);
|
||||
ret = model_process_.PreInitModelResource();
|
||||
if (ret != SUCCESS) {
|
||||
(void)aclmdlUnload(acl_model_id);
|
||||
MS_LOG(ERROR) << "Pre init model resource failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// acl init dvpp
|
||||
ret = dvpp_process_.InitWithJsonConfig(options_->dvpp_cfg_path);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "DVPP config file parse error.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
load_flag_ = true;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AclModel::LoadModel(const std::string &file_name, ModelType type,
|
||||
const std::map<std::string, std::string> &options) {
|
||||
Buffer model_data = ModelConverter::ReadFile(file_name);
|
||||
if (model_data.DataSize() == 0) {
|
||||
MS_LOG(ERROR) << "Read file " << file_name << " failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return LoadModel(model_data, type, options);
|
||||
}
|
||||
|
||||
Status AclModel::UnloadModel() {
|
||||
if (!load_flag_) {
|
||||
MS_LOG(WARNING) << "No model is loaded, skip unload.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
aclError rt_ret = aclrtSetCurrentContext(context_);
|
||||
if (rt_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Set the ascend device context failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Status ret = model_process_.UnLoad();
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Unload model inner failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
ret = FinalizeEnv();
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "FinalizeEnv failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Unload model success.";
|
||||
load_flag_ = false;
|
||||
// save result
|
||||
graph_cell_ = graph_cell;
|
||||
options_ = std::move(options);
|
||||
options_str_ = options_str;
|
||||
MS_LOG(INFO) << "Build model success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -239,45 +80,49 @@ Status AclModel::Eval(const DataSet &, std::map<std::string, Buffer> *) {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
Status AclModel::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) {
|
||||
Status AclModel::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
if (!load_flag_) {
|
||||
MS_LOG(ERROR) << "No model is loaded, predict failed.";
|
||||
if (graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid data, graph_ is null.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
aclError rt_ret = aclrtSetCurrentContext(context_);
|
||||
if (rt_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Set the ascend device context failed";
|
||||
if (graph_cell_ == nullptr) {
|
||||
MS_LOG(WARNING) << "Model has not been built, it will be built with default options";
|
||||
Status ret = Build({});
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Build model failed.";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(graph_cell_);
|
||||
Status ret = graph_cell_->Run(inputs, outputs);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Run graph failed.";
|
||||
return FAILED;
|
||||
}
|
||||
return model_process_.Predict(inputs, outputs);
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AclModel::GetInputsInfo(std::vector<Tensor> *tensor_list) const {
|
||||
MS_EXCEPTION_IF_NULL(tensor_list);
|
||||
return model_process_.GetInputsInfo(tensor_list);
|
||||
Status AclModel::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
|
||||
MS_EXCEPTION_IF_NULL(graph_cell_);
|
||||
return graph_cell_->GetInputsInfo(names, shapes, data_types, mem_sizes);
|
||||
}
|
||||
|
||||
Status AclModel::GetOutputsInfo(std::vector<Tensor> *tensor_list) const {
|
||||
MS_EXCEPTION_IF_NULL(tensor_list);
|
||||
return model_process_.GetOutputsInfo(tensor_list);
|
||||
Status AclModel::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
|
||||
MS_EXCEPTION_IF_NULL(graph_cell_);
|
||||
return graph_cell_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
|
||||
}
|
||||
|
||||
AclModel::AclEnvGuard::AclEnvGuard(const std::string &cfg_file) {
|
||||
errno_ = aclInit(common::SafeCStr(cfg_file));
|
||||
if (errno_ != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Execute aclInit Failed";
|
||||
return;
|
||||
std::string AclModel::GenerateOptionsStr(const std::map<std::string, std::string> &options) {
|
||||
std::string ret;
|
||||
for (auto &[key, value] : options) {
|
||||
ret += key + "^" + value + "^^";
|
||||
}
|
||||
MS_LOG(INFO) << "Acl init success";
|
||||
}
|
||||
|
||||
AclModel::AclEnvGuard::~AclEnvGuard() {
|
||||
errno_ = aclFinalize();
|
||||
if (errno_ != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Finalize acl failed";
|
||||
}
|
||||
MS_LOG(INFO) << "Acl finalize success";
|
||||
return ret;
|
||||
}
|
||||
} // namespace mindspore::api
|
||||
|
|
|
@ -23,77 +23,38 @@
|
|||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "include/api/cell.h"
|
||||
#include "include/api/status.h"
|
||||
#include "cxx_api/model/model_impl.h"
|
||||
#include "cxx_api/model/acl/dvpp_process.h"
|
||||
#include "cxx_api/model/acl/model_process.h"
|
||||
#include "cxx_api/model/acl/model_converter.h"
|
||||
#include "cxx_api/model/acl/acl_model_options.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class AclModel : public ModelImpl {
|
||||
public:
|
||||
explicit AclModel(uint32_t device_id)
|
||||
: init_flag_(false),
|
||||
load_flag_(false),
|
||||
device_type_("AscendCL"),
|
||||
device_id_(device_id),
|
||||
context_(nullptr),
|
||||
stream_(nullptr),
|
||||
acl_env_(nullptr),
|
||||
model_process_(),
|
||||
dvpp_process_(),
|
||||
model_converter_(),
|
||||
options_(nullptr) {}
|
||||
AclModel() : model_converter_(), options_(nullptr), options_str_() {}
|
||||
~AclModel() = default;
|
||||
|
||||
Status LoadModel(const Buffer &model_data, ModelType type,
|
||||
const std::map<std::string, std::string> &options) override;
|
||||
Status LoadModel(const std::string &file_name, ModelType type,
|
||||
const std::map<std::string, std::string> &options) override;
|
||||
Status UnloadModel() override;
|
||||
Status Build(const std::map<std::string, std::string> &options_map) override;
|
||||
|
||||
Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
|
||||
Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
|
||||
Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) override;
|
||||
Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
|
||||
|
||||
Status GetInputsInfo(std::vector<Tensor> *tensor_list) const override;
|
||||
Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const override;
|
||||
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
|
||||
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
|
||||
|
||||
private:
|
||||
bool init_flag_;
|
||||
bool load_flag_;
|
||||
std::string device_type_;
|
||||
int32_t device_id_;
|
||||
aclrtContext context_;
|
||||
aclrtStream stream_;
|
||||
static std::string GenerateOptionsStr(const std::map<std::string, std::string> &options);
|
||||
|
||||
class AclEnvGuard;
|
||||
std::shared_ptr<AclEnvGuard> acl_env_;
|
||||
static std::weak_ptr<AclEnvGuard> global_acl_env_;
|
||||
static std::mutex global_acl_env_mutex_;
|
||||
|
||||
ModelProcess model_process_;
|
||||
DvppProcess dvpp_process_;
|
||||
std::shared_ptr<GraphCell> graph_cell_;
|
||||
ModelConverter model_converter_;
|
||||
std::unique_ptr<AclModelOptions> options_;
|
||||
|
||||
Status InitEnv();
|
||||
Status FinalizeEnv();
|
||||
std::string options_str_;
|
||||
};
|
||||
|
||||
class AclModel::AclEnvGuard {
|
||||
public:
|
||||
explicit AclEnvGuard(const std::string &cfg_file);
|
||||
~AclEnvGuard();
|
||||
aclError GetErrno() const { return errno_; }
|
||||
|
||||
private:
|
||||
aclError errno_;
|
||||
};
|
||||
|
||||
API_REG_MODEL(AscendCL, AclModel);
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,160 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H
|
||||
#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "acl/acl.h"
|
||||
#include "acl/acl_mdl.h"
|
||||
#include "acl/acl_rt.h"
|
||||
#include "acl/ops/acl_dvpp.h"
|
||||
#include "include/api/status.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
struct DvppDecodePara {
|
||||
acldvppPixelFormat pixel_format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
|
||||
};
|
||||
|
||||
struct DvppResizePara {
|
||||
uint32_t output_width = 0;
|
||||
uint32_t output_height = 0;
|
||||
};
|
||||
|
||||
enum DvppCropType {
|
||||
// crop left,top,right,bottom is given in config
|
||||
kDvppCropTypeOffset = 0,
|
||||
// crop left,top,right,bottom is calculated by image width/height and output crop width/height
|
||||
kDvppCropTypeCentre = 1,
|
||||
};
|
||||
|
||||
struct DvppRoiArea {
|
||||
uint32_t left = 0;
|
||||
uint32_t top = 0;
|
||||
uint32_t right = 0;
|
||||
uint32_t bottom = 0;
|
||||
};
|
||||
|
||||
struct DvppCropInfo {
|
||||
DvppCropType crop_type = kDvppCropTypeOffset;
|
||||
DvppRoiArea crop_area; // when kDvppCropTypeOffset
|
||||
uint32_t crop_width = 0; // when kDvppCropTypeCentre
|
||||
uint32_t crop_height = 0; // when kDvppCropTypeCentre
|
||||
};
|
||||
|
||||
struct DvppCropPara {
|
||||
DvppCropInfo crop_info;
|
||||
uint32_t output_width = 0;
|
||||
uint32_t output_height = 0;
|
||||
};
|
||||
|
||||
struct DvppCropAndPastePara {
|
||||
DvppCropInfo crop_info;
|
||||
DvppRoiArea paste_area;
|
||||
uint32_t output_width = 0;
|
||||
uint32_t output_height = 0;
|
||||
};
|
||||
|
||||
class DvppProcess {
|
||||
public:
|
||||
DvppProcess();
|
||||
~DvppProcess();
|
||||
|
||||
Status InitResource(aclrtStream stream);
|
||||
void Finalize();
|
||||
Status InitJpegDecodePara(const DvppDecodePara &decode_para); // jpeg decode + (resize | crop)
|
||||
Status InitResizePara(const DvppResizePara &resize_para); // jpeg decode + resize
|
||||
Status InitCropPara(const DvppCropPara &crop_para); // jpeg decode + crop
|
||||
Status InitCropAndPastePara(const DvppCropAndPastePara &crop_and_paste_para); // jpeg decode + crop&paste
|
||||
|
||||
Status InitWithJsonConfig(const std::string &json_config);
|
||||
|
||||
// output device buffer will be destroy by DvppProcess itself.
|
||||
Status Process(const void *pic_buffer, size_t pic_buffer_size, void **output_device_buffer, size_t *output_size);
|
||||
Status Process(const std::vector<const void *> &pic_buffer_list, const std::vector<size_t> &pic_buffer_size_list,
|
||||
void **output_device_buffer, size_t *output_size);
|
||||
bool HasLoaded() const { return loaded_flag_; }
|
||||
|
||||
private:
|
||||
bool loaded_flag_ = false;
|
||||
uint32_t pic_width_ = 0;
|
||||
uint32_t pic_height_ = 0;
|
||||
|
||||
DvppDecodePara decode_para_;
|
||||
DvppResizePara resize_para_;
|
||||
DvppCropPara crop_para_;
|
||||
DvppCropAndPastePara crop_and_paste_para_;
|
||||
// only one of the resize or crop flag can be true
|
||||
bool to_resize_flag_ = false;
|
||||
bool to_crop_flag_ = false;
|
||||
bool to_crop_and_paste_flag_ = false;
|
||||
|
||||
void *input_pic_dev_buffer_ = nullptr;
|
||||
uint32_t input_pic_buffer_size_ = 0;
|
||||
|
||||
uint32_t decode_output_buffer_size_ = 0;
|
||||
void *decode_output_buffer_dev_ = nullptr;
|
||||
acldvppPicDesc *decode_output_desc_ = nullptr;
|
||||
|
||||
acldvppResizeConfig *resize_config_ = nullptr;
|
||||
acldvppRoiConfig *crop_area_ = nullptr;
|
||||
acldvppRoiConfig *paste_area_ = nullptr;
|
||||
|
||||
acldvppPicDesc *vpc_output_desc_ = nullptr;
|
||||
void *vpc_output_buffer_dev_ = nullptr; // vpc_output_buffer_size_ length
|
||||
uint32_t vpc_output_buffer_size_ = 0;
|
||||
|
||||
void *batch_vpc_output_buffer_dev_ = nullptr; // batch_size_ * vpc_output_buffer_size_ length
|
||||
uint32_t batch_size_ = 0;
|
||||
|
||||
aclrtStream stream_ = nullptr;
|
||||
acldvppChannelDesc *dvpp_channel_desc_ = nullptr;
|
||||
|
||||
uint32_t AlignmentHelper(uint32_t org_size, uint32_t alignment) const;
|
||||
uint32_t GetImageBufferSize(uint32_t stride_width, uint32_t stride_height, acldvppPixelFormat pixel_format) const;
|
||||
Status GetPicDescStride(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height);
|
||||
Status GetPicDescStrideDecode(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height);
|
||||
Status InputInputBuffer(const void *pic_buffer, size_t pic_buffer_size);
|
||||
Status InitDecodeOutputDesc(uint32_t image_width,
|
||||
uint32_t image_height); // decode_output_desc_, decode_output_buffer_dev_
|
||||
Status CheckRoiAreaWidthHeight(uint32_t width, uint32_t height);
|
||||
Status CheckAndAdjustRoiArea(DvppRoiArea *area);
|
||||
Status UpdateCropArea(uint32_t image_width, uint32_t image_height);
|
||||
Status CheckResizeImageInfo(uint32_t image_width, uint32_t image_height) const;
|
||||
void DestroyDecodeDesc();
|
||||
|
||||
Status InitVpcOutputDesc(uint32_t output_width, uint32_t output_height,
|
||||
acldvppPixelFormat pixel_format); // vpc_output_desc_, vpc_output_buffer_dev_batch_
|
||||
Status InitRoiAreaConfig(const DvppRoiArea &init_para, acldvppRoiConfig **roi_area);
|
||||
Status InitCommonCropPara(uint32_t out_width, uint32_t out_height, DvppCropInfo *crop_info);
|
||||
Status InitResizeOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, resize_config
|
||||
Status InitCropOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_
|
||||
Status InitCropAndPasteOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_, paste_area_
|
||||
void DestroyVpcOutputDesc();
|
||||
|
||||
Status ProcessDecode();
|
||||
Status ProcessResize();
|
||||
Status ProcessCrop();
|
||||
Status ProcessCropAndPaste();
|
||||
void DestroyResource();
|
||||
|
||||
Status GetJpegWidthHeight(const void *pic_buffer, size_t pic_buffer_size, uint32_t *image_width,
|
||||
uint32_t *image_height);
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H
|
|
@ -16,17 +16,13 @@
|
|||
|
||||
#include "cxx_api/model/acl/model_converter.h"
|
||||
#include <memory>
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "transform/graph_ir/convert.h"
|
||||
#include "transform/graph_ir/graph_runner.h"
|
||||
#include "core/load_mindir/load_model.h"
|
||||
#include "mindspore/core/utils/ms_context.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
|
||||
#include "include/api/serialization.h"
|
||||
#include "graph/model.h"
|
||||
#include "cxx_api/model/model_converter_utils/multi_process.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
#include "cxx_api/python_utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
namespace {
|
||||
|
@ -74,19 +70,8 @@ bool CreateSessionAndGraphRunner() {
|
|||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::shared_ptr<FuncGraph> ModelConverter::ConvertMindIrToFuncGraph(const Buffer &model_data) {
|
||||
try {
|
||||
auto anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data.Data()), model_data.DataSize());
|
||||
return anf_graph;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Load MindIR failed.";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
transform::DfGraphPtr ModelConverter::ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph) {
|
||||
for (auto &anf_node : anf_graph->parameters()) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
|
@ -166,88 +151,31 @@ Buffer ModelConverter::BuildAirModel(const transform::DfGraphPtr &graph,
|
|||
return Buffer(model.data.get(), model.length);
|
||||
}
|
||||
|
||||
void ModelConverter::RegAllOp() {
|
||||
static std::mutex init_mutex;
|
||||
static bool Initialized = false;
|
||||
|
||||
std::lock_guard<std::mutex> lock(init_mutex);
|
||||
if (Initialized) {
|
||||
return;
|
||||
}
|
||||
Initialized = true;
|
||||
MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
Py_Initialize();
|
||||
auto c_expression = PyImport_ImportModule("mindspore._c_expression");
|
||||
MS_EXCEPTION_IF_NULL(c_expression);
|
||||
PyObject *c_expression_dict = PyModule_GetDict(c_expression);
|
||||
MS_EXCEPTION_IF_NULL(c_expression_dict);
|
||||
|
||||
PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy");
|
||||
MS_EXCEPTION_IF_NULL(op_info_loader_class);
|
||||
PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class);
|
||||
MS_EXCEPTION_IF_NULL(op_info_loader);
|
||||
PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr);
|
||||
MS_EXCEPTION_IF_NULL(op_info_loader_ins);
|
||||
auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr);
|
||||
MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul);
|
||||
auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul);
|
||||
auto all_ops_info = static_cast<std::vector<kernel::OpInfo *> *>(all_ops_info_vector_addr);
|
||||
for (auto op_info : *all_ops_info) {
|
||||
kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info));
|
||||
}
|
||||
all_ops_info->clear();
|
||||
delete all_ops_info;
|
||||
Py_DECREF(op_info_loader);
|
||||
Py_DECREF(op_info_loader_class);
|
||||
Py_DECREF(c_expression_dict);
|
||||
Py_DECREF(c_expression);
|
||||
}
|
||||
|
||||
Buffer ModelConverter::ReadFile(const std::string &file) {
|
||||
Buffer buffer;
|
||||
if (file.empty()) {
|
||||
MS_LOG(ERROR) << "Pointer file is nullptr";
|
||||
return buffer;
|
||||
}
|
||||
std::string realPath = file;
|
||||
std::ifstream ifs(realPath);
|
||||
if (!ifs.good()) {
|
||||
MS_LOG(ERROR) << "File: " << realPath << " is not exist";
|
||||
return buffer;
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
MS_LOG(ERROR) << "File: " << realPath << "open failed";
|
||||
return buffer;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
buffer.ResizeData(size);
|
||||
if (buffer.DataSize() != size) {
|
||||
MS_LOG(ERROR) << "Malloc buf failed, file: " << realPath;
|
||||
ifs.close();
|
||||
return buffer;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
|
||||
ifs.close();
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
Buffer ModelConverter::LoadMindIR(const Buffer &model_data) {
|
||||
if (Py_IsInitialized() == 0) {
|
||||
Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
|
||||
if (!PythonIsInited()) {
|
||||
MS_LOG_INFO << "Call LoadMindIRInner directly";
|
||||
return LoadMindIRInner(model_data);
|
||||
return LoadMindIRInner(func_graph);
|
||||
}
|
||||
MultiProcess multi_process;
|
||||
Buffer buffer_ret;
|
||||
auto parent_process = [&model_data, &buffer_ret](MultiProcess *multi_process) -> Status {
|
||||
auto parent_process = [&func_graph, &buffer_ret, this](MultiProcess *multi_process) -> Status {
|
||||
MS_EXCEPTION_IF_NULL(multi_process);
|
||||
auto df_graph = ConvertFuncGraphToAIR(func_graph);
|
||||
if (df_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed.";
|
||||
return FAILED;
|
||||
}
|
||||
ge::Model model;
|
||||
ge::Buffer model_data;
|
||||
model.SetGraph(*df_graph);
|
||||
auto ge_ret = model.Save(model_data);
|
||||
if (ge_ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Save ge model to buffer failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// send original model to child
|
||||
auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize());
|
||||
auto status = multi_process->SendMsg(model_data.data(), model_data.size());
|
||||
if (!status.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Send original model to child process failed";
|
||||
return FAILED;
|
||||
|
@ -277,7 +205,7 @@ Buffer ModelConverter::LoadMindIR(const Buffer &model_data) {
|
|||
MS_LOG_ERROR << "Receive original model from parent process failed";
|
||||
return FAILED;
|
||||
}
|
||||
Buffer model_result = LoadMindIRInner(model);
|
||||
Buffer model_result = LoadAscendIRInner(model);
|
||||
if (model_result.DataSize() == 0) {
|
||||
MS_LOG_ERROR << "Convert model from MindIR to OM failed";
|
||||
return FAILED;
|
||||
|
@ -300,7 +228,7 @@ Buffer ModelConverter::LoadMindIR(const Buffer &model_data) {
|
|||
}
|
||||
|
||||
Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
|
||||
if (Py_IsInitialized() == 0) {
|
||||
if (!PythonIsInited()) {
|
||||
MS_LOG_INFO << "Call LoadAscendIRInner directly";
|
||||
return LoadAscendIRInner(model_data);
|
||||
}
|
||||
|
@ -361,10 +289,8 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
|
|||
return buffer_ret;
|
||||
}
|
||||
|
||||
Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) {
|
||||
RegAllOp();
|
||||
Py_Initialize();
|
||||
auto func_graph = ConvertMindIrToFuncGraph(model_data);
|
||||
Buffer ModelConverter::LoadMindIRInner(const FuncGraphPtr &func_graph) {
|
||||
RegAllOpFromPython();
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed.";
|
||||
return Buffer();
|
||||
|
@ -386,7 +312,7 @@ Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) {
|
|||
}
|
||||
|
||||
Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) {
|
||||
RegAllOp();
|
||||
RegAllOpFromPython();
|
||||
ge::Model load_model = ge::Model("loadmodel", "version2");
|
||||
ge::Status ret =
|
||||
ge::Model::Load(reinterpret_cast<const uint8_t *>(model_data.Data()), model_data.DataSize(), load_model);
|
||||
|
|
|
@ -32,21 +32,17 @@ class ModelConverter {
|
|||
public:
|
||||
ModelConverter() : options_(nullptr) {}
|
||||
|
||||
Buffer LoadMindIR(const Buffer &model_data);
|
||||
Buffer LoadMindIR(const FuncGraphPtr &func_graph);
|
||||
Buffer LoadAscendIR(const Buffer &model_data);
|
||||
|
||||
void set_options(AclModelOptions *options) { options_ = options; }
|
||||
|
||||
static Buffer ReadFile(const std::string &file);
|
||||
static void RegAllOp();
|
||||
|
||||
private:
|
||||
std::shared_ptr<FuncGraph> ConvertMindIrToFuncGraph(const Buffer &model_data);
|
||||
transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph);
|
||||
Buffer BuildAirModel(const transform::DfGraphPtr &graph, const std::map<std::string, std::string> &acl_options);
|
||||
AclModelOptions *options_;
|
||||
|
||||
Buffer LoadMindIRInner(const Buffer &model_data);
|
||||
Buffer LoadMindIRInner(const FuncGraphPtr &func_graph);
|
||||
Buffer LoadAscendIRInner(const Buffer &model_data);
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
|
|
|
@ -14,93 +14,59 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/context.h"
|
||||
#include "cxx_api/model/model_impl.h"
|
||||
#include "cxx_api/factory.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
const char *kDeviceTypeAscendCL = "AscendCL";
|
||||
const char *kDeviceTypeAscendMS = "AscendMS";
|
||||
|
||||
Status Model::LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options) {
|
||||
Status Model::Build(const std::map<std::string, std::string> &options) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->LoadModel(model_data, type, options);
|
||||
return impl_->Build(options);
|
||||
}
|
||||
|
||||
Status Model::LoadModel(const std::string &file_name, ModelType type,
|
||||
const std::map<std::string, std::string> &options) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->LoadModel(file_name, type, options);
|
||||
}
|
||||
|
||||
Status Model::UnloadModel() {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->UnloadModel();
|
||||
}
|
||||
|
||||
Status Model::Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) {
|
||||
Status Model::Train(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->Train(dataset, outputs);
|
||||
}
|
||||
|
||||
Status Model::Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) {
|
||||
Status Model::Eval(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->Eval(dataset, outputs);
|
||||
}
|
||||
|
||||
Status Model::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) {
|
||||
Status Model::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->Predict(inputs, outputs);
|
||||
}
|
||||
|
||||
Status Model::Predict(const std::vector<Buffer> &inputs, std::map<std::string, Buffer> *outputs) {
|
||||
std::vector<Tensor> tensor_list;
|
||||
auto ret = GetInputsInfo(&tensor_list);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "GetInputsInfo failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (inputs.size() != tensor_list.size()) {
|
||||
MS_LOG(ERROR) << "Model need " << tensor_list.size() << " inputs, but given " << inputs.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
std::map<std::string, Buffer> inputs_with_map;
|
||||
for (size_t i = 0; i < tensor_list.size(); ++i) {
|
||||
inputs_with_map.emplace(tensor_list[i].Name(), inputs[i]);
|
||||
}
|
||||
|
||||
return Predict(inputs_with_map, outputs);
|
||||
}
|
||||
|
||||
Status Model::GetInputsInfo(std::vector<Tensor> *tensor_list) const {
|
||||
Status Model::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->GetInputsInfo(tensor_list);
|
||||
return impl_->GetInputsInfo(names, shapes, data_types, mem_sizes);
|
||||
}
|
||||
|
||||
Status Model::GetOutputsInfo(std::vector<Tensor> *tensor_list) const {
|
||||
Status Model::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->GetOutputsInfo(tensor_list);
|
||||
return impl_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
|
||||
}
|
||||
|
||||
Model::Model(const std::string &device_type, uint32_t device_id)
|
||||
: impl_(ModelFactory::Instance().Create(device_type, device_id)) {
|
||||
Model::Model(const GraphCell &graph_cell)
|
||||
: impl_(Factory<ModelImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Create session type " << device_type << " failed";
|
||||
MS_LOG(EXCEPTION) << "Create session type " << Context::Instance().GetDeviceTarget() << " failed";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(graph_cell.GetGraph());
|
||||
impl_->SetGraph(std::make_shared<Graph>(*graph_cell.GetGraph()));
|
||||
}
|
||||
|
||||
Model::Model(NetWork network, const std::string &device_type, uint32_t device_id) {
|
||||
// todo
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Create session type " << device_type << " failed";
|
||||
}
|
||||
}
|
||||
Model::Model(const std::vector<Output> &network) { MS_LOG(EXCEPTION) << "Unsupported feature."; }
|
||||
|
||||
Model::~Model() {}
|
||||
|
||||
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
|
||||
return ModelFactory::Instance().CheckModelSupport(device_type, model_type);
|
||||
bool Model::CheckModelSupport(const std::string &device_type, ModelType) {
|
||||
return Factory<ModelImpl>::Instance().CheckModelSupport(device_type);
|
||||
}
|
||||
|
||||
} // namespace mindspore::api
|
||||
|
|
|
@ -22,7 +22,10 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/graph.h"
|
||||
#include "cxx_api/graph/graph_data.h"
|
||||
#include "utils/utils.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class ModelImpl {
|
||||
|
@ -30,70 +33,39 @@ class ModelImpl {
|
|||
ModelImpl() = default;
|
||||
virtual ~ModelImpl() = default;
|
||||
|
||||
virtual Status LoadModel(const Buffer &model_data, ModelType type,
|
||||
const std::map<std::string, std::string> &options) = 0;
|
||||
virtual Status LoadModel(const std::string &file_name, ModelType type,
|
||||
const std::map<std::string, std::string> &options) = 0;
|
||||
virtual Status UnloadModel() = 0;
|
||||
virtual Status Build(const std::map<std::string, std::string> &options) = 0;
|
||||
|
||||
virtual Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) = 0;
|
||||
virtual Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) = 0;
|
||||
virtual Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) = 0;
|
||||
virtual Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) = 0;
|
||||
|
||||
virtual Status GetInputsInfo(std::vector<Tensor> *tensor_list) const = 0;
|
||||
virtual Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const = 0;
|
||||
};
|
||||
virtual Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const = 0;
|
||||
virtual Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const = 0;
|
||||
|
||||
using ModelCreator = std::function<std::shared_ptr<ModelImpl>(uint32_t device_id)>;
|
||||
class ModelFactory {
|
||||
public:
|
||||
ModelFactory(const ModelFactory &) = delete;
|
||||
void operator=(const ModelFactory &) = delete;
|
||||
|
||||
static ModelFactory &Instance() {
|
||||
static ModelFactory instance;
|
||||
return instance;
|
||||
protected:
|
||||
Status Load(const std::shared_ptr<GraphCell> &graph_cell) {
|
||||
MS_EXCEPTION_IF_NULL(graph_cell);
|
||||
return graph_cell->Load();
|
||||
}
|
||||
|
||||
void Register(const std::string &device_name, ModelCreator &&model_creator) {
|
||||
if (model_creators_.find(device_name) == model_creators_.end()) {
|
||||
(void)model_creators_.emplace(device_name, model_creator);
|
||||
FuncGraphPtr GetFuncGraph() const {
|
||||
if (graph_->ModelType() != ModelType::kMindIR) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto graph_data = graph_->graph_data_;
|
||||
MS_EXCEPTION_IF_NULL(graph_data);
|
||||
return graph_data->GetFuncGraph();
|
||||
}
|
||||
|
||||
std::shared_ptr<ModelImpl> Create(const std::string &device_name, uint32_t device_id) {
|
||||
auto iter = model_creators_.find(device_name);
|
||||
if (model_creators_.end() != iter) {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
return (iter->second)(device_id);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool CheckModelSupport(const std::string &device_type, ModelType /*model_type*/) {
|
||||
return std::any_of(
|
||||
model_creators_.begin(), model_creators_.end(),
|
||||
[&device_type](const std::pair<std::string, ModelCreator> &item) { return item.first == device_type; });
|
||||
}
|
||||
std::shared_ptr<Graph> graph_;
|
||||
|
||||
private:
|
||||
ModelFactory() = default;
|
||||
~ModelFactory() = default;
|
||||
std::map<std::string, ModelCreator> model_creators_;
|
||||
friend class Model;
|
||||
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
|
||||
};
|
||||
|
||||
class ModelRegistrar {
|
||||
public:
|
||||
ModelRegistrar(const std::string &device_name, ModelCreator model_creator) {
|
||||
ModelFactory::Instance().Register(device_name, std::move(model_creator));
|
||||
}
|
||||
~ModelRegistrar() = default;
|
||||
};
|
||||
|
||||
#define API_REG_MODEL(DEVICE_NAME, MODEL_CLASS) \
|
||||
static const ModelRegistrar g_api_model_registrar__##DEVICE_NAME##_##_reg( \
|
||||
kDeviceType##DEVICE_NAME, [](uint32_t device_id) { return std::make_shared<MODEL_CLASS>(device_id); });
|
||||
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H
|
||||
|
|
|
@ -16,164 +16,33 @@
|
|||
|
||||
#include "cxx_api/model/ms/ms_model.h"
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
|
||||
#include "load_mindir/load_model.h"
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "backend/session/executor_manager.h"
|
||||
#include "base/base_ref_utils.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
#include "utils/context/context_extends.h"
|
||||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/embed.h"
|
||||
|
||||
#ifdef ENABLE_D
|
||||
#include "utils/ms_context.h"
|
||||
#endif
|
||||
#include "cxx_api/factory.h"
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
MsModel::MsModel(uint32_t device_id) : device_id_(device_id) {}
|
||||
MsModel::~MsModel() = default;
|
||||
API_FACTORY_REG(ModelImpl, Ascend910, MsModel);
|
||||
|
||||
TypeId TransInferDataType2TypeId(DataType data_type) {
|
||||
const std::map<api::DataType, TypeId> type2id_map{
|
||||
{api::kMsUnknown, TypeId::kNumberTypeBegin}, {api::kMsBool, TypeId::kNumberTypeBool},
|
||||
{api::kMsInt8, TypeId::kNumberTypeInt8}, {api::kMsUint8, TypeId::kNumberTypeUInt8},
|
||||
{api::kMsInt16, TypeId::kNumberTypeInt16}, {api::kMsUint16, TypeId::kNumberTypeUInt16},
|
||||
{api::kMsInt32, TypeId::kNumberTypeInt32}, {api::kMsUint32, TypeId::kNumberTypeUInt32},
|
||||
{api::kMsInt64, TypeId::kNumberTypeInt64}, {api::kMsUint64, TypeId::kNumberTypeUInt64},
|
||||
{api::kMsFloat16, TypeId::kNumberTypeFloat16}, {api::kMsFloat32, TypeId::kNumberTypeFloat32},
|
||||
{api::kMsFloat64, TypeId::kNumberTypeFloat64},
|
||||
};
|
||||
auto it = type2id_map.find(data_type);
|
||||
if (it == type2id_map.end()) {
|
||||
MS_LOG_WARNING << "Unsupported MSI data type " << data_type;
|
||||
return TypeId::kNumberTypeBegin;
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
Status MsModel::Build(const std::map<std::string, std::string> &) {
|
||||
MS_LOG(INFO) << "Start build model.";
|
||||
MS_EXCEPTION_IF_NULL(graph_);
|
||||
|
||||
DataType TransTypeId2InferDataType(TypeId type_id) {
|
||||
const std::map<TypeId, api::DataType> id2type_map{
|
||||
{TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool},
|
||||
{TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
|
||||
{TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16},
|
||||
{TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32},
|
||||
{TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64},
|
||||
{TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16},
|
||||
{TypeId::kNumberTypeFloat32, api::kMsFloat32},
|
||||
};
|
||||
auto it = id2type_map.find(type_id);
|
||||
if (it == id2type_map.end()) {
|
||||
MS_LOG_WARNING << "Unsupported data id " << type_id;
|
||||
return api::kMsUnknown;
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
auto func_graph = ModelImpl::GetFuncGraph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
Buffer MsModel::ReadFile(const std::string &file) {
|
||||
if (file.empty()) {
|
||||
MS_LOG(ERROR) << "file is nullptr";
|
||||
return Buffer();
|
||||
}
|
||||
std::ifstream ifs(file);
|
||||
if (!ifs.good()) {
|
||||
MS_LOG(ERROR) << "file: " << file << " is not exist";
|
||||
return Buffer();
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
MS_LOG(ERROR) << "file: " << file << "open failed";
|
||||
return Buffer();
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
Buffer buffer;
|
||||
buffer.ResizeData(size);
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(static_cast<char *>(buffer.MutableData()), size);
|
||||
ifs.close();
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
Status MsModel::LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options) {
|
||||
auto status = InitEnv({});
|
||||
if (status != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Init env failed";
|
||||
return FAILED;
|
||||
}
|
||||
std::shared_ptr<FuncGraph> anf_graph;
|
||||
Py_Initialize();
|
||||
try {
|
||||
anf_graph = ConvertStreamToFuncGraph(static_cast<const char *>(model_data.Data()), model_data.DataSize());
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference LoadModel failed";
|
||||
return FAILED;
|
||||
}
|
||||
Status ret = CompileGraph(anf_graph);
|
||||
auto graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(func_graph, ModelType::kMindIR));
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto graph_cell = std::make_shared<GraphCell>(graph);
|
||||
MS_EXCEPTION_IF_NULL(graph_cell);
|
||||
auto ret = ModelImpl::Load(graph_cell);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Compile graph model failed";
|
||||
return FAILED;
|
||||
MS_LOG(ERROR) << "Load failed.";
|
||||
return ret;
|
||||
}
|
||||
session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_);
|
||||
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_);
|
||||
if (inputs_.empty() || inputs_.size() != input_names_.size()) {
|
||||
MS_LOG_ERROR << "Get model inputs info failed";
|
||||
return FAILED;
|
||||
}
|
||||
if (outputs_.empty() || outputs_.size() != output_names_.size()) {
|
||||
MS_LOG_ERROR << "Get model outputs info failed";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "Load model success";
|
||||
|
||||
#ifdef ENABLE_D
|
||||
// set d context
|
||||
rtError_t rt_ret = rtCtxGetCurrent(&context_);
|
||||
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "the ascend device context is null";
|
||||
return FAILED;
|
||||
}
|
||||
#endif
|
||||
load_flag_ = true;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::LoadModel(const std::string &file_name, ModelType type,
|
||||
const std::map<std::string, std::string> &options) {
|
||||
auto graphBuf = ReadFile(file_name);
|
||||
if (graphBuf.DataSize() == 0) {
|
||||
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
|
||||
return FAILED;
|
||||
}
|
||||
auto status = LoadModel(graphBuf, type, options);
|
||||
if (status != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::UnloadModel() {
|
||||
if (!load_flag_) {
|
||||
MS_LOG_ERROR << "Model has not been loaded";
|
||||
return FAILED;
|
||||
}
|
||||
FinalizeEnv();
|
||||
load_flag_ = false;
|
||||
// save result
|
||||
graph_cell_ = graph_cell;
|
||||
MS_LOG(INFO) << "Build model success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -187,231 +56,42 @@ Status MsModel::Eval(const DataSet &, std::map<std::string, Buffer> *) {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
Status MsModel::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) {
|
||||
Status MsModel::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
if (!load_flag_) {
|
||||
MS_LOG(ERROR) << "No model is loaded, predict failed.";
|
||||
if (graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid data, graph_ is null.";
|
||||
return FAILED;
|
||||
}
|
||||
if (inputs.size() != inputs_.size()) {
|
||||
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size();
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
std::vector<Buffer> request;
|
||||
std::vector<Buffer> reply;
|
||||
for (size_t i = 0; i < inputs_.size(); ++i) {
|
||||
const auto &input_name = input_names_[i];
|
||||
auto iter = inputs.find(input_name);
|
||||
if (iter == inputs.end()) {
|
||||
MS_LOG(ERROR) << "Model missing input " << input_name;
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
|
||||
if (iter->second.DataSize() != inputs_[i]->Size()) {
|
||||
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count "
|
||||
<< iter->second.DataSize();
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
request.push_back(iter->second);
|
||||
}
|
||||
if (ExecuteModel(request, &reply) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Execute Model Failed";
|
||||
return FAILED;
|
||||
}
|
||||
if (outputs_.size() != reply.size()) {
|
||||
MS_LOG(ERROR) << "Predict output size " << reply.size() << " not match output size got from model info "
|
||||
<< outputs_.size();
|
||||
return FAILED;
|
||||
}
|
||||
outputs->clear();
|
||||
for (size_t i = 0; i < reply.size(); i++) {
|
||||
outputs->emplace(output_names_[i], reply[i]);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
|
||||
MS_EXCEPTION_IF_NULL(reply);
|
||||
#ifdef ENABLE_D
|
||||
if (context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "rtCtx is nullptr";
|
||||
return FAILED;
|
||||
}
|
||||
rtError_t rt_ret = rtCtxSetCurrent(context_);
|
||||
if (rt_ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "set Ascend rtCtx failed";
|
||||
return FAILED;
|
||||
}
|
||||
#endif
|
||||
vector<tensor::TensorPtr> inputs;
|
||||
for (size_t i = 0; i < request.size(); i++) {
|
||||
auto &item = request[i];
|
||||
auto input = inputs_[i];
|
||||
if (input->Size() != item.DataSize()) {
|
||||
MS_LOG(ERROR) << "Predict input " << i << " data size " << item.DataSize() << " not match model input data size "
|
||||
<< input->Size();
|
||||
return FAILED;
|
||||
}
|
||||
auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize());
|
||||
if (graph_cell_ == nullptr) {
|
||||
MS_LOG(INFO) << "Model has not been built, it will be built with default options";
|
||||
Status ret = Build({});
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Tensor copy failed";
|
||||
MS_LOG(ERROR) << "Build model failed.";
|
||||
return FAILED;
|
||||
}
|
||||
inputs.push_back(input);
|
||||
}
|
||||
vector<tensor::TensorPtr> outputs = RunGraph(inputs);
|
||||
if (outputs.empty()) {
|
||||
MS_LOG(ERROR) << "Execute Model Failed";
|
||||
|
||||
MS_EXCEPTION_IF_NULL(graph_cell_);
|
||||
Status ret = graph_cell_->Run(inputs, outputs);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Run graph failed.";
|
||||
return FAILED;
|
||||
}
|
||||
reply->clear();
|
||||
std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply),
|
||||
[](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); });
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::FinalizeEnv() {
|
||||
MS_LOG_INFO << "Start finalize env";
|
||||
py::gil_scoped_acquire acquire;
|
||||
session::ExecutorManager::Instance().Clear();
|
||||
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
return FAILED;
|
||||
}
|
||||
if (!context::CloseTsd(ms_context)) {
|
||||
MS_LOG(ERROR) << "Inference CloseTsd failed!";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG_INFO << "End finalize env";
|
||||
return SUCCESS;
|
||||
Status MsModel::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
|
||||
MS_EXCEPTION_IF_NULL(graph_cell_);
|
||||
return graph_cell_->GetInputsInfo(names, shapes, data_types, mem_sizes);
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> MsModel::LoadModel(const char *model_buf, size_t size, const std::string &device) {
|
||||
Py_Initialize();
|
||||
MS_EXCEPTION_IF_NULL(model_buf);
|
||||
try {
|
||||
auto anf_graph = ConvertStreamToFuncGraph(model_buf, size);
|
||||
return anf_graph;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference LoadModel failed: " << e.what();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void MsModel::RegAllOp() {
|
||||
static std::mutex init_mutex;
|
||||
static bool Initialized = false;
|
||||
|
||||
std::lock_guard<std::mutex> lock(init_mutex);
|
||||
if (Initialized) {
|
||||
return;
|
||||
}
|
||||
Initialized = true;
|
||||
auto ms_context_instance = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context_instance);
|
||||
ms_context_instance->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
try {
|
||||
std::shared_ptr<py::scoped_interpreter> guard;
|
||||
if (Py_IsInitialized() == 0) {
|
||||
guard = std::make_shared<py::scoped_interpreter>();
|
||||
}
|
||||
py::module c_expression = py::module::import("mindspore._c_expression");
|
||||
size_t ops_info_long = c_expression.attr("OpInfoLoaderPy")().attr("get_all_ops_info")().cast<size_t>();
|
||||
auto all_ops_info = reinterpret_cast<std::vector<kernel::OpInfo *> *>(static_cast<uintptr_t>(ops_info_long));
|
||||
for (auto op_info : *all_ops_info) {
|
||||
kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info));
|
||||
}
|
||||
all_ops_info->clear();
|
||||
delete all_ops_info;
|
||||
} catch (const std::runtime_error &ex) {
|
||||
MS_LOG_EXCEPTION << ex.what();
|
||||
}
|
||||
}
|
||||
|
||||
Status MsModel::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
|
||||
MS_ASSERT(session_impl_ != nullptr);
|
||||
try {
|
||||
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
||||
py::gil_scoped_release gil_release;
|
||||
return SUCCESS;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference CompileGraph failed: " << e.what();
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<tensor::TensorPtr> MsModel::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
|
||||
try {
|
||||
VectorRef outputs;
|
||||
session_impl_->RunGraph(graph_id_, inputs, &outputs);
|
||||
return TransformVectorRefToMultiTensor(outputs);
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference Rungraph failed: " << e.what();
|
||||
return std::vector<tensor::TensorPtr>();
|
||||
}
|
||||
}
|
||||
|
||||
Status MsModel::InitEnv(const std::unordered_map<std::string, std::string> &other_options) {
|
||||
RegAllOp();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
return FAILED;
|
||||
}
|
||||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
|
||||
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
|
||||
if (!context::OpenTsd(ms_context)) {
|
||||
MS_LOG(ERROR) << "Session init OpenTsd failed!";
|
||||
return FAILED;
|
||||
}
|
||||
session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice);
|
||||
if (session_impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice
|
||||
<< " is available.";
|
||||
return FAILED;
|
||||
}
|
||||
session_impl_->Init(device_id_);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
|
||||
MS_ASSERT(session_impl_ != nullptr);
|
||||
std::string error_msg;
|
||||
if (!session_impl_->CheckModelInputs(graph_id, inputs, &error_msg)) {
|
||||
return Status(INVALID_INPUTS, error_msg);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::GetInputsInfo(std::vector<Tensor> *tensor_list) const {
|
||||
MS_EXCEPTION_IF_NULL(tensor_list);
|
||||
tensor_list->clear();
|
||||
for (size_t i = 0; i < inputs_.size(); i++) {
|
||||
auto &tensor = inputs_[i];
|
||||
Tensor infer_tensor;
|
||||
infer_tensor.SetName(input_names_[i]);
|
||||
infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type()));
|
||||
infer_tensor.SetShape(tensor->shape());
|
||||
tensor_list->push_back(infer_tensor);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::GetOutputsInfo(std::vector<Tensor> *tensor_list) const {
|
||||
MS_EXCEPTION_IF_NULL(tensor_list);
|
||||
tensor_list->clear();
|
||||
for (size_t i = 0; i < outputs_.size(); i++) {
|
||||
auto &tensor = outputs_[i];
|
||||
Tensor infer_tensor;
|
||||
infer_tensor.SetName(output_names_[i]);
|
||||
infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type()));
|
||||
infer_tensor.SetShape(tensor->shape());
|
||||
tensor_list->push_back(infer_tensor);
|
||||
}
|
||||
return SUCCESS;
|
||||
Status MsModel::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
|
||||
MS_EXCEPTION_IF_NULL(graph_cell_);
|
||||
return graph_cell_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
|
||||
}
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,49 +36,23 @@ namespace mindspore {
|
|||
namespace api {
|
||||
class MsModel : public ModelImpl {
|
||||
public:
|
||||
explicit MsModel(uint32_t device_id);
|
||||
~MsModel();
|
||||
MsModel() {}
|
||||
~MsModel() = default;
|
||||
|
||||
Status LoadModel(const Buffer &model_data, ModelType type,
|
||||
const std::map<std::string, std::string> &options) override;
|
||||
Status LoadModel(const std::string &file_name, ModelType type,
|
||||
const std::map<std::string, std::string> &options) override;
|
||||
Status UnloadModel() override;
|
||||
Status Build(const std::map<std::string, std::string> &options_map) override;
|
||||
|
||||
Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
|
||||
Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
|
||||
Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) override;
|
||||
Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
|
||||
|
||||
Status GetInputsInfo(std::vector<Tensor> *tensor_list) const override;
|
||||
Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const override;
|
||||
|
||||
Status InitEnv(const std::unordered_map<std::string, std::string> &other_options);
|
||||
Status FinalizeEnv();
|
||||
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
|
||||
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
|
||||
uint32_t graph_id_;
|
||||
std::string device_type_;
|
||||
int32_t device_id_ = 0;
|
||||
#ifdef ENABLE_D
|
||||
rtContext_t context_ = nullptr;
|
||||
#endif
|
||||
std::vector<tensor::TensorPtr> inputs_;
|
||||
std::vector<tensor::TensorPtr> outputs_;
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<std::string> output_names_;
|
||||
bool load_flag_ = false;
|
||||
|
||||
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device);
|
||||
Buffer ReadFile(const std::string &file);
|
||||
static void RegAllOp();
|
||||
Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr);
|
||||
Status CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const;
|
||||
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
|
||||
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
|
||||
std::shared_ptr<GraphCell> graph_cell_;
|
||||
};
|
||||
|
||||
API_REG_MODEL(AscendMS, MsModel);
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cxx_api/python_utils.h"
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "mindspore/core/utils/ms_context.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mindspore::api {
|
||||
void RegAllOpFromPython() {
|
||||
static std::mutex init_mutex;
|
||||
static bool Initialized = false;
|
||||
|
||||
std::lock_guard<std::mutex> lock(init_mutex);
|
||||
if (Initialized) {
|
||||
return;
|
||||
}
|
||||
Initialized = true;
|
||||
MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
Py_Initialize();
|
||||
auto c_expression = PyImport_ImportModule("mindspore._c_expression");
|
||||
MS_EXCEPTION_IF_NULL(c_expression);
|
||||
PyObject *c_expression_dict = PyModule_GetDict(c_expression);
|
||||
MS_EXCEPTION_IF_NULL(c_expression_dict);
|
||||
|
||||
PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy");
|
||||
MS_EXCEPTION_IF_NULL(op_info_loader_class);
|
||||
PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class);
|
||||
MS_EXCEPTION_IF_NULL(op_info_loader);
|
||||
PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr);
|
||||
MS_EXCEPTION_IF_NULL(op_info_loader_ins);
|
||||
auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr);
|
||||
MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul);
|
||||
auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul);
|
||||
auto all_ops_info = static_cast<std::vector<kernel::OpInfo *> *>(all_ops_info_vector_addr);
|
||||
for (auto op_info : *all_ops_info) {
|
||||
kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info));
|
||||
}
|
||||
all_ops_info->clear();
|
||||
delete all_ops_info;
|
||||
Py_DECREF(op_info_loader);
|
||||
Py_DECREF(op_info_loader_class);
|
||||
Py_DECREF(c_expression_dict);
|
||||
Py_DECREF(c_expression);
|
||||
}
|
||||
|
||||
bool PythonIsInited() { return Py_IsInitialized() != 0; }
|
||||
} // namespace mindspore::api
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H
|
||||
#define MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
void RegAllOpFromPython();
|
||||
bool PythonIsInited();
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H
|
|
@ -14,9 +14,77 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/serialization.h"
|
||||
#include <fstream>
|
||||
#include "cxx_api/graph/graph_data.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "mindspore/core/load_mindir/load_model.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
static Buffer ReadFile(const std::string &file) {
|
||||
Buffer buffer;
|
||||
if (file.empty()) {
|
||||
MS_LOG(ERROR) << "Pointer file is nullptr";
|
||||
return buffer;
|
||||
}
|
||||
|
||||
char real_path_mem[PATH_MAX] = {0};
|
||||
char *real_path_ret = nullptr;
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
real_path_ret = _fullpath(real_path_mem, common::SafeCStr(file), PATH_MAX);
|
||||
#else
|
||||
real_path_ret = realpath(common::SafeCStr(file), real_path_mem);
|
||||
#endif
|
||||
|
||||
if (real_path_ret == nullptr) {
|
||||
MS_LOG(ERROR) << "File: " << file << " is not exist.";
|
||||
return buffer;
|
||||
}
|
||||
|
||||
std::string real_path(real_path_mem);
|
||||
std::ifstream ifs(real_path);
|
||||
if (!ifs.good()) {
|
||||
MS_LOG(ERROR) << "File: " << real_path << " is not exist";
|
||||
return buffer;
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
MS_LOG(ERROR) << "File: " << real_path << "open failed";
|
||||
return buffer;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
buffer.ResizeData(size);
|
||||
if (buffer.DataSize() != size) {
|
||||
MS_LOG(ERROR) << "Malloc buf failed, file: " << real_path;
|
||||
ifs.close();
|
||||
return buffer;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
|
||||
ifs.close();
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
|
||||
Buffer data = ReadFile(file);
|
||||
if (model_type == kMindIR) {
|
||||
FuncGraphPtr anf_graph = nullptr;
|
||||
try {
|
||||
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(data.Data()), data.DataSize());
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Load MindIR failed.";
|
||||
}
|
||||
|
||||
return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
||||
} else if (model_type == kOM) {
|
||||
return Graph(std::make_shared<Graph::GraphData>(data, kOM));
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;
|
||||
}
|
||||
|
||||
Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return FAILED;
|
||||
|
|
|
@ -19,6 +19,9 @@
|
|||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
const char *kDeviceTypeAscend310 = "Ascend310";
|
||||
const char *kDeviceTypeAscend910 = "Ascend910";
|
||||
|
||||
class DataImpl {
|
||||
public:
|
||||
DataImpl() : data_() {}
|
||||
|
|
|
@ -422,7 +422,6 @@ inline ValuePtr MakeValue(S v) {
|
|||
template <typename S, typename U = typename ImmTraits<S>::type>
|
||||
static S GetValue(const ValuePtr &value) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
|
||||
U imm = value->cast<U>();
|
||||
if (imm == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cast failed, original value: " << value->ToString() << ", type: " << value->type_name();
|
||||
|
|
|
@ -1,5 +1,10 @@
|
|||
#add flags
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare")
|
||||
|
||||
add_subdirectory("ut")
|
||||
|
||||
if (ENABLE_ACL)
|
||||
add_subdirectory(cxx_st)
|
||||
elseif (ENABLE_GPU OR ENABLE_D OR ENABLE_CPU)
|
||||
message(fatal "No need set -e xxx when compile ut")
|
||||
else ()
|
||||
add_subdirectory(ut)
|
||||
endif()
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
include_directories(${PYTHON_INCLUDE_DIRS})
|
||||
include_directories(${MS_CCSRC_PATH})
|
||||
include_directories(${CMAKE_SOURCE_DIR}/mindspore/core)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/)
|
||||
include_directories(${CMAKE_BINARY_DIR})
|
||||
include_directories(${CUDA_INCLUDE_DIRS})
|
||||
|
||||
file(GLOB_RECURSE CXX_ST_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.cc)
|
||||
add_executable(st_tests ${CXX_ST_SRC})
|
||||
target_link_libraries(st_tests PRIVATE mindspore_shared_lib mindspore::gtest)
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common_test.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
#if __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace ST {
|
||||
|
||||
void Common::SetUpTestCase() {}
|
||||
|
||||
void Common::TearDownTestCase() {}
|
||||
|
||||
void Common::SetUp() {}
|
||||
|
||||
void Common::TearDown() {}
|
||||
|
||||
} // namespace ST
|
||||
|
||||
#ifdef __cplusplus
|
||||
#if __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
|
@ -0,0 +1,76 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef TESTS_CXX_ST_COMMON_COMMON_TEST_H_
|
||||
#define TESTS_CXX_ST_COMMON_COMMON_TEST_H_
|
||||
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include "gtest/gtest.h"
|
||||
namespace ST {
|
||||
class Common : public testing::Test {
|
||||
public:
|
||||
// TestCase only enter once
|
||||
static void SetUpTestCase();
|
||||
static void TearDownTestCase();
|
||||
|
||||
// every TEST_F macro will enter one
|
||||
virtual void SetUp();
|
||||
virtual void TearDown();
|
||||
|
||||
template <typename T>
|
||||
void PrintData(std::string name, T *output_data, int size) {
|
||||
std::cout << "The " << name << " is as follows:" << std::endl;
|
||||
if (typeid(output_data[0]) == typeid(uint8_t) || typeid(output_data[0]) == typeid(int8_t)) {
|
||||
for (size_t i = 0; i < std::min(size, 100); i++) {
|
||||
std::cout << (int)output_data[i] << " ";
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < std::min(size, 100); i++) {
|
||||
std::cout << output_data[i] << " ";
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void CompareOutputData(T *output_data, T *correct_data, int size, float err_bound) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
T abs = fabs(output_data[i] - correct_data[i]);
|
||||
ASSERT_LE(abs, err_bound);
|
||||
}
|
||||
}
|
||||
|
||||
void ReadFile(const char *file, size_t *size, char **buf) {
|
||||
ASSERT_NE(nullptr, file);
|
||||
ASSERT_NE(nullptr, size);
|
||||
ASSERT_NE(nullptr, buf);
|
||||
std::string path = std::string(file);
|
||||
std::ifstream ifs(path);
|
||||
ASSERT_EQ(true, ifs.good());
|
||||
ASSERT_EQ(true, ifs.is_open());
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
*size = ifs.tellg();
|
||||
*buf = new char[*size];
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(*buf, *size);
|
||||
ifs.close();
|
||||
}
|
||||
};
|
||||
} // namespace ST
|
||||
#endif // TESTS_CXX_ST_COMMON_COMMON_TEST_H_
|
|
@ -0,0 +1,22 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
GTEST_API_ int main(int argc, char** argv) {
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
int ret = RUN_ALL_TESTS();
|
||||
return ret;
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "common/common_test.h"
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/serialization.h"
|
||||
#include "include/api/context.h"
|
||||
|
||||
using namespace mindspore::api;
|
||||
|
||||
static const char tensor_add_file[] = "/home/workspace/mindspore_dataset/tensor_add/tensor_add.mindir";
|
||||
static const std::vector<float> input_data_1 = {1, 2, 3, 4};
|
||||
static const std::vector<float> input_data_2 = {2, 3, 4, 5};
|
||||
|
||||
class TestTensorAdd : public ST::Common {
|
||||
public:
|
||||
TestTensorAdd() {}
|
||||
};
|
||||
|
||||
TEST_F(TestTensorAdd, InferMindIR) {
|
||||
Context::Instance().SetDeviceTarget(kDeviceTypeAscend310).SetDeviceID(1);
|
||||
auto graph = Serialization::LoadModel(tensor_add_file, ModelType::kMindIR);
|
||||
Model tensor_add((GraphCell(graph)));
|
||||
Status ret = tensor_add.Build({});
|
||||
ASSERT_TRUE(ret == SUCCESS);
|
||||
|
||||
// prepare input
|
||||
std::vector<Buffer> outputs;
|
||||
std::vector<Buffer> inputs;
|
||||
inputs.emplace_back(Buffer(input_data_1.data(), sizeof(float) * input_data_1.size()));
|
||||
inputs.emplace_back(Buffer(input_data_2.data(), sizeof(float) * input_data_2.size()));
|
||||
|
||||
// infer
|
||||
ret = tensor_add.Predict(inputs, &outputs);
|
||||
ASSERT_TRUE(ret == SUCCESS);
|
||||
|
||||
// print
|
||||
for (auto &buffer : outputs) {
|
||||
const float *p = reinterpret_cast<const float *>(buffer.Data());
|
||||
for (size_t i = 0; i < buffer.DataSize() / sizeof(float); ++i) {
|
||||
ASSERT_LE(std::abs(p[i] - (input_data_1[i] + input_data_2[i])), 1e-4);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue