read pb model
This commit is contained in:
parent
b7bc8e31f6
commit
59c6689512
|
@ -63,7 +63,7 @@ class MS_API MSTensor {
|
|||
// return A pointer points to data in MSTensor.
|
||||
virtual void *MutableData() const = 0;
|
||||
};
|
||||
using MultiTensor = std::vector<std::vector<std::shared_ptr<inference::MSTensor>>>;
|
||||
using MultiTensor = std::vector<std::shared_ptr<inference::MSTensor>>;
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_MS_TENSOR_H_
|
||||
|
|
|
@ -216,6 +216,7 @@ class CNode : public AnfNode {
|
|||
void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }
|
||||
|
||||
std::string fullname_with_scope() override;
|
||||
void set_fullname_with_scope(const std::string full_name) { fullname_with_scope_ = full_name; }
|
||||
std::string DebugString(int recursive_level = 1) const override;
|
||||
std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); }
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ if (ENABLE_D)
|
|||
file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"ascend_session.cc"
|
||||
"ascend_control_parser.cc"
|
||||
"ascend_inference_session.cc"
|
||||
)
|
||||
list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST})
|
||||
endif ()
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "session/ascend_inference_session.h"
|
||||
#include "operator/ops.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "device/kernel_runtime.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "common/utils.h"
|
||||
#include "common/trans.h"
|
||||
#include "kernel/tbe/tbe_python_funcs.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "utils/base_ref_extends.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
std::vector<tensor::TensorPtr> inputs(inputs_const);
|
||||
auto input_nodes = kernel_graph->inputs();
|
||||
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
size_t no_weight_input = 0;
|
||||
for (size_t i = 0; i < input_nodes.size(); ++i) {
|
||||
tensor::TensorPtr tensor = nullptr;
|
||||
if (!input_nodes[i]->isa<Parameter>()) {
|
||||
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter";
|
||||
continue;
|
||||
}
|
||||
auto pk_node = input_nodes[i]->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(pk_node);
|
||||
if (AnfAlgo::IsParameterWeight(pk_node)) {
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(pk_node->default_param());
|
||||
MS_EXCEPTION_IF_NULL(param_value);
|
||||
auto py_param = param_value->value();
|
||||
MS_EXCEPTION_IF_NULL(py_param);
|
||||
py::array py_array = py_param.cast<py::array>();
|
||||
tensor = std::make_shared<tensor::Tensor>(py_array);
|
||||
} else {
|
||||
tensor = inputs[no_weight_input++];
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
if (AnfAlgo::OutputAddrExist(pk_node, 0)) {
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
|
||||
bool need_sync = false;
|
||||
if (ms_context->enable_pynative_infer()) {
|
||||
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
|
||||
need_sync = true;
|
||||
}
|
||||
} else {
|
||||
if (tensor->is_dirty()) {
|
||||
need_sync = true;
|
||||
} else if (tensor->device_address() != device_address) {
|
||||
(void)tensor->data_sync();
|
||||
need_sync = true;
|
||||
}
|
||||
}
|
||||
if (need_sync) {
|
||||
if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) {
|
||||
tensor->set_device_address(device_address);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
|
||||
LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
||||
tensor->data_c(false))) {
|
||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
tensor->set_dirty(false);
|
||||
}
|
||||
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H
|
||||
#define MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <stack>
|
||||
#include <map>
|
||||
#include <tuple>
|
||||
#include <set>
|
||||
#include "session/ascend_session.h"
|
||||
#include "session/kernel_graph.h"
|
||||
#include "kernel/kernel.h"
|
||||
#include "session/session_factory.h"
|
||||
#include "session/ascend_control_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
class AscendInferenceSession : public AscendSession {
|
||||
public:
|
||||
AscendInferenceSession() = default;
|
||||
~AscendInferenceSession() = default;
|
||||
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
||||
};
|
||||
MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession);
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H
|
|
@ -124,7 +124,7 @@ MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_p
|
|||
});
|
||||
if (has_error) {
|
||||
MS_LOG(ERROR) << "Init Tensor failed, returning empty result";
|
||||
std::vector<std::vector<std::shared_ptr<inference::MSTensor>>> multiTensor;
|
||||
std::vector<std::shared_ptr<inference::MSTensor>> multiTensor;
|
||||
return multiTensor;
|
||||
}
|
||||
VectorRef outputs;
|
||||
|
@ -135,6 +135,9 @@ MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_p
|
|||
|
||||
int Session::Init(const std::string &device, uint32_t device_id) {
|
||||
RegAllOp();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
ms_context->set_execution_mode(kGraphMode);
|
||||
ms_context->set_device_target(kAscendDevice);
|
||||
session_impl_ = session::SessionFactory::Get().Create(device);
|
||||
if (session_impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
|
||||
|
|
|
@ -610,6 +610,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
|||
auto new_cnode = CreateNewCNode(cnode, graph.get());
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
new_cnode->set_abstract(cnode->abstract());
|
||||
new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
|
||||
new_cnode->set_scope(cnode->scope());
|
||||
graph->FrontBackendlMapAdd(node, new_cnode);
|
||||
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
|
||||
|
|
|
@ -21,20 +21,27 @@
|
|||
#include "ir/tensor.h"
|
||||
|
||||
namespace mindspore {
|
||||
std::vector<std::shared_ptr<inference::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref) {
|
||||
std::vector<std::shared_ptr<inference::MSTensor>> msTensors;
|
||||
if (utils::isa<VectorRef>(base_ref)) {
|
||||
auto ref_list = utils::cast<VectorRef>(base_ref);
|
||||
void IterateFindTensor(std::vector<std::shared_ptr<inference::MSTensor>> *msTensors, const VectorRef &ref_list) {
|
||||
for (size_t i = 0; i < ref_list.size(); ++i) {
|
||||
if (utils::isa<tensor::Tensor>(ref_list[i])) {
|
||||
if (utils::isa<tensor::TensorPtr>(ref_list[i])) {
|
||||
auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(ref_list[i]);
|
||||
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||
auto tensor = new inference::Tensor(tensor_ptr);
|
||||
msTensors.emplace_back(std::shared_ptr<inference::MSTensor>(tensor));
|
||||
msTensors->emplace_back(std::shared_ptr<inference::MSTensor>(tensor));
|
||||
} else if (utils::isa<VectorRef>(ref_list[i])) {
|
||||
auto ref_iter = utils::cast<VectorRef>(ref_list[i]);
|
||||
IterateFindTensor(msTensors, ref_iter);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The output is not a tensor!";
|
||||
MS_LOG(EXCEPTION) << "The output is not a tensor";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<inference::MSTensor>> TransformVectorRefToMultiTensor(const VectorRef &base_ref) {
|
||||
std::vector<std::shared_ptr<inference::MSTensor>> msTensors;
|
||||
if (utils::isa<VectorRef>(base_ref)) {
|
||||
auto ref_list = utils::cast<VectorRef>(base_ref);
|
||||
IterateFindTensor(&msTensors, ref_list);
|
||||
} else if (utils::isa<tensor::Tensor>(base_ref)) {
|
||||
auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref);
|
||||
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||
|
@ -45,14 +52,4 @@ std::vector<std::shared_ptr<inference::MSTensor>> TransformBaseRefToMSTensor(con
|
|||
}
|
||||
return msTensors;
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::shared_ptr<inference::MSTensor>>> TransformVectorRefToMultiTensor(
|
||||
const VectorRef &vector_ref) {
|
||||
std::vector<std::vector<std::shared_ptr<inference::MSTensor>>> multiTensor;
|
||||
for (size_t i = 0; i < vector_ref.size(); ++i) {
|
||||
auto tensors = TransformBaseRefToMSTensor(vector_ref[i]);
|
||||
multiTensor.emplace_back(tensors);
|
||||
}
|
||||
return multiTensor;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,9 +22,6 @@
|
|||
#ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
|
||||
#define MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
|
||||
namespace mindspore {
|
||||
std::vector<std::shared_ptr<inference::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref);
|
||||
|
||||
std::vector<std::vector<std::shared_ptr<inference::MSTensor>>> TransformVectorRefToMultiTensor(
|
||||
const VectorRef &vector_ref);
|
||||
std::vector<std::shared_ptr<inference::MSTensor>> TransformVectorRefToMultiTensor(const VectorRef &base_ref);
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
|
||||
|
|
|
@ -41,6 +41,7 @@ const int kPynativeMode = 1;
|
|||
const char kCPUDevice[] = "CPU";
|
||||
const char kGPUDevice[] = "GPU";
|
||||
const char kAscendDevice[] = "Ascend";
|
||||
const char kDavinciInferenceDevice[] = "AscendInference";
|
||||
const char kDavinciDevice[] = "Davinci";
|
||||
const char KNpuLog[] = "_npu_log";
|
||||
const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice};
|
||||
|
|
|
@ -96,8 +96,6 @@ std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const std::string &file
|
|||
ReadOnnxFromBinary(modelFile, &model_);
|
||||
MSANFModelParser model_parser;
|
||||
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
|
||||
MS_EXCEPTION_IF_NULL(dstgraph_ptr);
|
||||
TestFuncGraphBuild(dstgraph_ptr);
|
||||
return dstgraph_ptr;
|
||||
}
|
||||
|
||||
|
@ -111,33 +109,7 @@ std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const char *buf, const
|
|||
}
|
||||
MSANFModelParser model_parser;
|
||||
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
|
||||
MS_EXCEPTION_IF_NULL(dstgraph_ptr);
|
||||
TestFuncGraphBuild(dstgraph_ptr);
|
||||
return dstgraph_ptr;
|
||||
}
|
||||
|
||||
int AnfConverter::TestFuncGraphBuild(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto node_return = graph->get_return();
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(node_return);
|
||||
MS_LOG(INFO) << "node_list size is : " << node_list.size();
|
||||
for (auto &node : node_list) {
|
||||
if (node->isa<CNode>()) {
|
||||
auto node_CN = node->cast<CNodePtr>();
|
||||
MS_LOG(INFO) << "CN node: " << node_CN->input(0)->ToString() << ", input size :" << node_CN->size();
|
||||
} else if (node->isa<Parameter>()) {
|
||||
auto node_Para = node->cast<ParameterPtr>();
|
||||
if (node_Para->has_default()) {
|
||||
MS_LOG(INFO) << "Parameter node: " << node_Para->name() << "has default value!";
|
||||
} else {
|
||||
MS_LOG(INFO) << "Parameter node: " << node_Para->name();
|
||||
}
|
||||
} else if (node->isa<ValueNode>()) {
|
||||
auto node_Value = node->cast<ValueNodePtr>();
|
||||
MS_LOG(INFO) << "Value node: " << node_Value->ToString();
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,7 +26,6 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
class AnfConverter {
|
||||
public:
|
||||
static int TestFuncGraphBuild(const FuncGraphPtr &graph);
|
||||
static std::shared_ptr<FuncGraph> RunAnfConverter(const std::string &file_path);
|
||||
static std::shared_ptr<FuncGraph> RunAnfConverter(const char *buf, const size_t buf_size);
|
||||
|
||||
|
|
|
@ -14,16 +14,17 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "utils/load_onnx/anf_model_parser.h"
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "utils/load_onnx/anf_model_parser.h"
|
||||
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "operator/ops.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
|
@ -33,6 +34,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
static constexpr char kConstantValueNode[] = "Constant";
|
||||
static constexpr char kCNodeShapeAttr[] = "shape";
|
||||
static constexpr char kCNodeShape1Attr[] = "shape1";
|
||||
static constexpr char kCNodeShape2Attr[] = "shape2";
|
||||
enum ParseForm : int {
|
||||
FORM_PARSE_TYPE = 0,
|
||||
FORM_PARSE_SCALAR = 1,
|
||||
|
@ -56,14 +59,15 @@ static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
|
|||
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \
|
||||
const onnx::TensorProto &attr_tensor) { \
|
||||
MS_EXCEPTION_IF_NULL(prim); \
|
||||
std::vector<valuetype> attr_value_vec; \
|
||||
std::vector<ValuePtr> attr_value_vec; \
|
||||
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
|
||||
attr_value_vec.push_back(static_cast<valuetype>(attr_tensor.type##_data(i))); \
|
||||
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \
|
||||
attr_value_vec.push_back(MakeValue<valuetype>(value)); \
|
||||
} \
|
||||
if (attr_value_vec.size() == 1) { \
|
||||
prim->AddAttr(attr_name, MakeValue<valuetype>(attr_value_vec[0])); \
|
||||
prim->AddAttr(attr_name, attr_value_vec[0]); \
|
||||
} else { \
|
||||
prim->AddAttr(attr_name, MakeValue<std::vector<valuetype>>(attr_value_vec)); \
|
||||
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \
|
||||
} \
|
||||
}
|
||||
|
||||
|
@ -247,17 +251,12 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node
|
|||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c(true));
|
||||
memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
|
||||
if (attr_tensor_type == onnx::TensorProto_DataType_FLOAT) {
|
||||
auto *data_valuennode = reinterpret_cast<float *>(tensor_info->data_c());
|
||||
MS_EXCEPTION_IF_NULL(data_valuennode);
|
||||
auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data_valuennode));
|
||||
auto new_value_node = NewValueNode(MakeValue(tensor_info));
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
auto tensor_abstract = tensor_info->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(tensor_abstract);
|
||||
new_value_node->set_abstract(tensor_abstract);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
} else {
|
||||
auto *data_valuenode = reinterpret_cast<int32 *>(tensor_info->data_c());
|
||||
MS_EXCEPTION_IF_NULL(data_valuenode);
|
||||
auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data_valuenode));
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -315,7 +314,9 @@ bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_n
|
|||
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
auto new_value_node = std::make_shared<ValueNode>(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||
auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||
abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
|
||||
new_value_node->set_abstract(abs_type);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
return true;
|
||||
}
|
||||
|
@ -361,31 +362,45 @@ AbstractBasePtr MSANFModelParser::GetAbstractForCNode(const onnx::AttributeProto
|
|||
tensor::TensorPtr tensor_info =
|
||||
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
return tensor_info->ToAbstract();
|
||||
auto abstract = tensor_info->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
return abstract;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto,
|
||||
const onnx::GraphProto &importProto, const bool &ret_flag) {
|
||||
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::NodeProto &node_proto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
if (!node_proto.has_op_type()) {
|
||||
MS_LOG(ERROR) << "Get CNode op_type failed!";
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
const std::string &node_name = node_proto.output(0);
|
||||
const std::string &fullname_with_scope = node_proto.domain();
|
||||
const std::string &node_type = node_proto.op_type();
|
||||
PrimitivePtr prim = std::make_shared<Primitive>(node_type);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
prim->set_instance_name(node_type);
|
||||
|
||||
AbstractBasePtr abstract;
|
||||
AbstractBasePtr abstract = nullptr;
|
||||
AbstractBasePtr abstract_first = nullptr;
|
||||
AbstractBasePtr abstract_second = nullptr;
|
||||
for (int i = 0; i < node_proto.attribute_size(); ++i) {
|
||||
const onnx::AttributeProto &attr_proto = node_proto.attribute(i);
|
||||
if (attr_proto.name() == kCNodeShapeAttr) {
|
||||
abstract = GetAbstractForCNode(attr_proto);
|
||||
continue;
|
||||
}
|
||||
if (attr_proto.name() == kCNodeShape1Attr) {
|
||||
abstract_first = GetAbstractForCNode(attr_proto);
|
||||
continue;
|
||||
}
|
||||
if (attr_proto.name() == kCNodeShape2Attr) {
|
||||
abstract_second = GetAbstractForCNode(attr_proto);
|
||||
continue;
|
||||
}
|
||||
if (!GetAttrValueForCNode(prim, attr_proto)) {
|
||||
MS_LOG(ERROR) << "Get CNode attr failed!";
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -396,16 +411,64 @@ bool MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGrap
|
|||
const std::string &input_name = node_proto.input(i);
|
||||
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
|
||||
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
inputs.push_back(anfnode_build_map_[input_name]);
|
||||
}
|
||||
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
if (node_type == "LayerNorm") {
|
||||
AbstractBasePtrList elem;
|
||||
elem.push_back(abstract);
|
||||
elem.push_back(abstract_first);
|
||||
elem.push_back(abstract_second);
|
||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
} else if (node_type == "ArgMaxWithValue") {
|
||||
AbstractBasePtrList elem;
|
||||
elem.push_back(abstract);
|
||||
elem.push_back(abstract_first);
|
||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
} else if (nullptr == abstract) {
|
||||
AbstractBasePtrList elem;
|
||||
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
|
||||
elem.push_back(cnode_ptr->input(index)->abstract());
|
||||
}
|
||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
} else {
|
||||
cnode_ptr->set_abstract(abstract);
|
||||
if (ret_flag) {
|
||||
}
|
||||
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
|
||||
anfnode_build_map_[node_name] = cnode_ptr;
|
||||
return cnode_ptr;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const CNodePtr &cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
if (importProto.output_size() > 1) {
|
||||
inputs.clear();
|
||||
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
AbstractBasePtrList elem;
|
||||
for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
|
||||
const onnx::ValueInfoProto &output_node = importProto.output(out_size);
|
||||
const std::string &out_tuple = output_node.name();
|
||||
inputs.push_back(anfnode_build_map_[out_tuple]);
|
||||
elem.push_back(anfnode_build_map_[out_tuple]->abstract());
|
||||
}
|
||||
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
|
||||
maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
inputs.clear();
|
||||
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
||||
inputs.push_back(maketuple_ptr);
|
||||
auto return_node = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
outputFuncGraph->set_return(return_node);
|
||||
MS_LOG(INFO) << "Construct funcgraph finined, all success.";
|
||||
} else {
|
||||
const onnx::ValueInfoProto &output_node = importProto.output(0);
|
||||
const ::onnx::TypeProto &output_typeproto = output_node.type();
|
||||
const onnx::TypeProto &output_typeproto = output_node.type();
|
||||
int output_type = output_typeproto.tensor_type().elem_type();
|
||||
std::vector<int> output_shape;
|
||||
for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) {
|
||||
|
@ -417,20 +480,19 @@ bool MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGrap
|
|||
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
||||
inputs.push_back(cnode_ptr);
|
||||
auto return_node = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
return_node->set_abstract(tensor_return->ToAbstract());
|
||||
outputFuncGraph->set_return(return_node);
|
||||
MS_LOG(INFO) << "Construct funcgraph finined, all success!";
|
||||
}
|
||||
anfnode_build_map_[node_name] = cnode_ptr;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
bool return_flag = false;
|
||||
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
|
||||
CNodePtr cnode_ptr = nullptr;
|
||||
for (int i = 0; i < importProto.node_size(); ++i) {
|
||||
return_flag = (i == importProto.node_size() - 1) ? true : return_flag;
|
||||
const onnx::NodeProto &node_proto = importProto.node(i);
|
||||
const std::string &node_type = node_proto.op_type();
|
||||
if (node_type == kConstantValueNode) {
|
||||
|
@ -440,11 +502,14 @@ bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
|||
}
|
||||
continue;
|
||||
}
|
||||
if (!BuildCNodeForFuncGraph(outputFuncGraph, node_proto, importProto, return_flag)) {
|
||||
cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
|
||||
if (cnode_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -472,12 +537,12 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const onnx::ModelProto &mode
|
|||
producer_name_ = model_proto.producer_name();
|
||||
MS_LOG(INFO) << "producer_name :" << producer_name_;
|
||||
|
||||
if (!model_proto.has_producer_version()) {
|
||||
if (!model_proto.has_model_version()) {
|
||||
MS_LOG(ERROR) << "Parse model producer version from pb file failed!";
|
||||
return false;
|
||||
}
|
||||
producer_version_ = model_proto.producer_version();
|
||||
MS_LOG(INFO) << "producer_version : " << producer_version_;
|
||||
model_version_ = model_proto.model_version();
|
||||
MS_LOG(INFO) << "producer_version : " << model_version_;
|
||||
|
||||
if (!model_proto.has_ir_version()) {
|
||||
MS_LOG(ERROR) << "Parse model version from pb file failed!";
|
||||
|
@ -485,14 +550,6 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const onnx::ModelProto &mode
|
|||
}
|
||||
ir_version_ = model_proto.ir_version();
|
||||
MS_LOG(INFO) << "ir_version :" << ir_version_;
|
||||
|
||||
const onnx::OperatorSetIdProto &opset_proto = model_proto.opset_import(0);
|
||||
if (!opset_proto.has_version()) {
|
||||
MS_LOG(ERROR) << "Parse opset version from pb file failed!";
|
||||
return false;
|
||||
}
|
||||
opset_version_ = opset_proto.version();
|
||||
MS_LOG(INFO) << "opset_version : " << opset_version_;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -501,7 +558,6 @@ FuncGraphPtr MSANFModelParser::Parse(const onnx::ModelProto &model_proto) {
|
|||
MS_EXCEPTION_IF_NULL(dstGraph);
|
||||
if (!MSANFParseModelConfigureInfo(model_proto)) {
|
||||
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
||||
return nullptr;
|
||||
}
|
||||
const onnx::GraphProto &graphBuild = model_proto.graph();
|
||||
if (!BuildFuncGraph(dstGraph, graphBuild)) {
|
||||
|
|
|
@ -29,6 +29,7 @@ namespace lite {
|
|||
using int32 = int32_t;
|
||||
using int64 = int64_t;
|
||||
using uint64 = uint64_t;
|
||||
using float16 = Eigen::half;
|
||||
class MSANFModelParser {
|
||||
public:
|
||||
MSANFModelParser() = default;
|
||||
|
@ -38,17 +39,17 @@ class MSANFModelParser {
|
|||
bool MSANFParseModelConfigureInfo(const onnx::ModelProto &model_proto);
|
||||
|
||||
std::string GetProducerName() { return producer_name_; }
|
||||
std::string GetProducerVersion() { return producer_version_; }
|
||||
int GetProducerVersion() { return model_version_; }
|
||||
int GetIrVersion() { return ir_version_; }
|
||||
int GetOpsetVersion() { return opset_version_; }
|
||||
|
||||
private:
|
||||
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
|
||||
bool BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto,
|
||||
const onnx::GraphProto &importProto, const bool &ret_flag);
|
||||
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto);
|
||||
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const CNodePtr &cnode_ptr);
|
||||
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
|
||||
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
|
@ -63,15 +64,13 @@ class MSANFModelParser {
|
|||
bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
AbstractBasePtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
|
||||
|
||||
std::string producer_name_;
|
||||
std::string producer_version_;
|
||||
int ir_version_{};
|
||||
int opset_version_{};
|
||||
int model_version_;
|
||||
int ir_version_;
|
||||
std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
|
||||
std::map<std::string, onnx::TensorProto> default_para_map_;
|
||||
|
||||
AbstractBasePtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue