!24083 refactor acl vm && acl vm ignore monad input/output

Merge pull request !24083 from zhoufeng/xiu-ba-ge-2
This commit is contained in:
i-robot 2021-09-26 01:20:04 +00:00 committed by Gitee
commit 46fc24b172
11 changed files with 711 additions and 408 deletions

View File

@@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cxx_api/model/acl/acl_model_multi.h"
#include <vector>
#include <utility>
@@ -23,16 +22,11 @@
#include <numeric>
#include <deque>
#include <functional>
#include "backend/session/session_basic.h"
#include "backend/session/session_factory.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
#include "cxx_api/factory.h"
#include "vm/backend.h"
#include "vm/transform.h"
#include "acl/acl_rt.h"
#include "mindspore/core/load_mindir/infer_mindir.h"
#include "debug/trace.h"
#include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h"
#include "cxx_api/model/acl/acl_vm/acl_vm.h"
namespace mindspore {
API_FACTORY_REG(ModelImpl, Ascend310, AclModelMulti);
@@ -46,366 +40,6 @@ std::map<DataType, size_t> kDtypeMap = {
{DataType::kNumberTypeUInt8, sizeof(uint8_t)}, {DataType::kNumberTypeUInt16, sizeof(uint16_t)},
{DataType::kNumberTypeUInt32, sizeof(uint32_t)}, {DataType::kNumberTypeUInt64, sizeof(uint64_t)}};
class MSTensorRef : public BaseRef {
public:
static VectorRef Convert(const std::vector<MSTensor> &tensors) {
VectorRef res;
std::transform(tensors.begin(), tensors.end(), std::back_inserter(res),
[](const MSTensor &t) { return MSTensorRef(t); });
return res;
}
static std::vector<MSTensor> Convert(const BaseRef &args) {
std::vector<MSTensor> res;
if (utils::isa<VectorRef>(args)) {
VectorRef args_vec = utils::cast<VectorRef>(args);
res = ConvertTuple(args_vec);
} else if (utils::isa<MSTensorRef>(args)) {
auto wrapper = utils::cast<MSTensorRef>(args);
res.push_back(wrapper.ms_tensor_);
} else {
MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString() << " must be MSTensorRef or VectorRef{MSTensorRef...}";
}
return res;
}
MS_DECLARE_PARENT(MSTensorRef, BaseRef);
explicit MSTensorRef(const MSTensor &tensor) : ms_tensor_(tensor) {}
~MSTensorRef() override = default;
const MSTensor &GetTensor() const { return ms_tensor_; }
std::shared_ptr<Base> copy() const override {
MSTensor *tensor = ms_tensor_.Clone();
auto res = std::make_shared<MSTensorRef>(static_cast<const MSTensor &>(*tensor));
MSTensor::DestroyTensorPtr(tensor);
return res;
}
uint32_t type() const override { return tid(); }
std::string ToString() const override { return ms_tensor_.Name(); }
bool operator==(const BaseRef &other) const override {
if (!utils::isa<MSTensorRef>(other)) {
return false;
}
auto other_ms_tensor = utils::cast<MSTensorRef>(other).ms_tensor_;
auto this_ms_tensor = ms_tensor_;
return (this_ms_tensor.Name() == other_ms_tensor.Name()) && (this_ms_tensor.Shape() == other_ms_tensor.Shape()) &&
(this_ms_tensor.MutableData() == other_ms_tensor.MutableData()) &&
(this_ms_tensor.DataSize() == other_ms_tensor.DataSize()) &&
(this_ms_tensor.DataType() == other_ms_tensor.DataType());
}
private:
static std::vector<MSTensor> ConvertTuple(const VectorRef &args) {
std::vector<MSTensor> outs;
for (size_t i = 0; i < args.size(); ++i) {
const auto &item = args[i];
if (utils::isa<VectorRef>(item)) {
VectorRef item_vec = utils::cast<VectorRef>(item);
auto ret = ConvertTuple(item_vec);
outs.insert(outs.end(), ret.begin(), ret.end());
} else if (utils::isa<MSTensorRef>(item)) {
auto wrapper = utils::cast<MSTensorRef>(item);
outs.push_back(wrapper.ms_tensor_);
} else {
MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString()
<< " must be MSTensorRef or VectorRef{MSTensorRef...}";
}
}
return outs;
}
MSTensor ms_tensor_;
};
class MultiGraphAclSession : public session::SessionBasic {
public:
MultiGraphAclSession() = default;
~MultiGraphAclSession() override = default;
void Init(uint32_t device_id) override;
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
void RunGraph(GraphId graph_id, const std::vector<MSTensor> &inputs, VectorRef *outputs);
void SetOptions(const std::shared_ptr<AclModelOptions> &options) { options_ = options; }
private:
VectorRef ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors);
VectorRef ConstructOutputRefByTupleNode(const CNodePtr &tuple_node, std::deque<MSTensor> *out_tensors);
std::map<GraphId, GraphCell> graphs_ = {};
std::map<GraphId, KernelGraphPtr> kernel_graphs_ = {};
std::shared_ptr<AclModelOptions> options_ = nullptr;
};
void MultiGraphAclSession::Init(uint32_t device_id) { InitExecutor(kDavinciMultiGraphInferenceDevice, device_id); }
GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
class FirstGraphModeGuard {
public:
explicit FirstGraphModeGuard(const std::shared_ptr<AclModelOptions> &options) : options_(options) {
if (options_ != nullptr) {
options_->SetFirstGraph(true);
}
}
~FirstGraphModeGuard() {
if (options_ != nullptr) {
options_->SetFirstGraph(false);
}
}
private:
std::shared_ptr<AclModelOptions> options_;
};
MS_LOG(INFO) << "Start MultiGraph Compile.";
// construct kernel graph
auto kernel_graph = SessionBasic::ConstructKernelGraph(lst, outputs, false);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>("310_multi_graph_pm");
pm->AddPass(std::make_shared<opt::InsertPlaceholderForDynamicRNN>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
// convert to om data
ModelConverter model_converter_;
model_converter_.set_options(options_);
FirstGraphModeGuard guard(options_);
auto om_data = model_converter_.LoadMindIR(kernel_graph);
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
MS_LOG(ERROR) << "Load MindIR failed.";
return kMCFailed;
}
// load
std::shared_ptr<Graph> graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
MS_EXCEPTION_IF_NULL(graph);
auto graph_cell = GraphCell(graph);
auto ret = graph_cell.Load(options_->GetDeviceID());
if (ret != kSuccess) {
MS_LOG(EXCEPTION) << "Load failed.";
}
graphs_[kernel_graph->graph_id()] = graph_cell;
kernel_graphs_[kernel_graph->graph_id()] = kernel_graph;
MS_LOG(INFO) << "Mulit graph compile success, graph id " << kernel_graph->graph_id();
return kernel_graph->graph_id();
}
void MultiGraphAclSession::RunGraph(GraphId graph_id, const std::vector<MSTensor> &inputs, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
MS_LOG(INFO) << "Start run graph " << graph_id;
auto iter = graphs_.find(graph_id);
if (iter == graphs_.end()) {
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " not found.";
}
std::vector<MSTensor> out_tensors;
auto ret = iter->second.Run(inputs, &out_tensors);
if (ret != kSuccess) {
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " run failed.";
}
std::deque<MSTensor> out_tensors_deque(out_tensors.begin(), out_tensors.end());
(*outputs) = ConstructOutputRef(graph_id, &out_tensors_deque);
}
VectorRef MultiGraphAclSession::ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors) {
MS_EXCEPTION_IF_NULL(out_tensors);
VectorRef outs;
auto out_nodes = kernel_graphs_[graph_id]->outputs();
for (auto &out : out_nodes) {
if (out_tensors->empty()) {
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << out->DebugString();
}
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(out, 0);
auto &anf_node = item_with_index.first;
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
} else {
outs.emplace_back(MSTensorRef(out_tensors->front()));
out_tensors->pop_front();
}
}
if (!out_tensors->empty()) {
MS_LOG(EXCEPTION) << "Number of output size " << outs.size() << " but " << out_tensors->size()
<< " MSTensor remained.";
}
return outs;
}
VectorRef MultiGraphAclSession::ConstructOutputRefByTupleNode(const CNodePtr &tuple_node,
std::deque<MSTensor> *out_tensors) {
MS_EXCEPTION_IF_NULL(out_tensors);
VectorRef outs;
for (size_t i = 1; i < tuple_node->inputs().size(); ++i) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(tuple_node->input(i), 0);
auto &anf_node = item_with_index.first;
if (out_tensors->empty()) {
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << anf_node->DebugString();
}
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
} else {
outs.emplace_back(MSTensorRef(out_tensors->front()));
out_tensors->pop_front();
}
}
return outs;
}
class AclBackend : public compile::MsBackend {
public:
AclBackend(const std::string &name, const std::string &target, const std::shared_ptr<AclModelOptions> &options)
: MsBackend(name, target, options->GetDeviceID()) {
auto session = std::dynamic_pointer_cast<MultiGraphAclSession>(MsBackend::target_sess_);
MS_EXCEPTION_IF_NULL(session);
session->SetOptions(options);
}
~AclBackend() override = default;
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) override {
std::vector<MSTensor> inputs;
for (const auto &arg : args) {
if (!utils::isa<MSTensorRef>(arg)) {
MS_LOG(EXCEPTION) << "Invalid item " << arg.ToString();
}
auto wrapper = utils::cast<MSTensorRef>(arg);
inputs.emplace_back(wrapper.GetTensor());
}
VectorRef outputs;
MS_EXCEPTION_IF_NULL(target_sess_);
auto exec_sess = std::dynamic_pointer_cast<MultiGraphAclSession>(target_sess_);
MS_EXCEPTION_IF_NULL(exec_sess);
exec_sess->RunGraph(g, inputs, &outputs);
return outputs;
}
bool GetCond(const BaseRef &c, bool *value) override {
MS_EXCEPTION_IF_NULL(value);
if (!utils::isa<MSTensorRef>(c)) {
MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef.";
return false;
}
auto wrapper = utils::cast<MSTensorRef>(c);
if (wrapper.GetTensor().DataType() != DataType::kNumberTypeBool) {
MS_LOG(ERROR) << "Invalid data type " << wrapper.GetTensor().DataType() << " must be bool.";
return false;
}
auto data = wrapper.GetTensor().Data();
if (data == nullptr) {
return false;
}
(*value) = *reinterpret_cast<const bool *>(data.get());
return true;
}
bool GetIndex(const BaseRef &c, int64_t *value) override {
MS_EXCEPTION_IF_NULL(value);
if (!utils::isa<MSTensorRef>(c)) {
MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef.";
return false;
}
auto wrapper = utils::cast<MSTensorRef>(c);
if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt32) {
auto data = wrapper.GetTensor().Data();
if (data == nullptr) {
return false;
}
auto value_int32 = *reinterpret_cast<const int32_t *>(data.get());
(*value) = static_cast<int64_t>(value_int32);
return true;
} else if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt64) {
auto data = wrapper.GetTensor().Data();
if (data == nullptr) {
return false;
}
(*value) = *reinterpret_cast<const int64_t *>(data.get());
return true;
} else {
MS_LOG(ERROR) << "Index must be Int type.";
return false;
}
}
};
class AclCompileGraph : public compile::CompileGraph {
public:
explicit AclCompileGraph(const std::shared_ptr<compile::MsBackend> &backend,
const std::vector<PrimitivePtr> &cut_list)
: CompileGraph(backend, cut_list) {}
~AclCompileGraph() override = default;
void AddInst(const compile::Instruction &inst, const MSTensorRef &arg) {
VectorRef args;
args.push_back(arg);
compile::CompileGraph::AddInst(inst, args);
}
int64_t Ref(const AnfNodePtr &node) override {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
if (IsValueNode<FuncGraph>(node)) {
MS_LOG(DEBUG) << "Push graph.";
compile::CompileGraph::AddInst(compile::Instruction::kGraph, GetValueNode(node));
} else {
MS_LOG(DEBUG) << "Push.";
if (IsValueNode<Primitive>(node)) {
MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info());
} else if (IsValueNode<tensor::Tensor>(node)) {
auto tensor_node = std::dynamic_pointer_cast<tensor::Tensor>(node->cast<ValueNodePtr>()->value());
MS_EXCEPTION_IF_NULL(tensor_node);
std::string name = "";
std::vector<int64_t> shape = tensor_node->shape_c();
DataType type = static_cast<DataType>(tensor_node->data_type_c());
auto mstensor_node = MSTensor::CreateRefTensor(name, type, shape, tensor_node->data_c(), tensor_node->Size());
MSTensorRef mstensor_ref(*mstensor_node);
AddInst(compile::Instruction::kPush, mstensor_ref);
MSTensor::DestroyTensorPtr(mstensor_node);
} else {
compile::CompileGraph::AddInst(compile::Instruction::kPush, GetValueNode(node));
}
}
Push(node);
}
MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
<< ", return: " << slots_[node] - height_;
return slots_[node] - height_;
}
};
class AclCompileGraphs : public compile::CompileGraphs {
public:
explicit AclCompileGraphs(const std::shared_ptr<compile::MsBackend> &backend,
const std::vector<PrimitivePtr> &cut_list)
: CompileGraphs(backend, cut_list) {
MS_EXCEPTION_IF_NULL(backend);
MS_LOG(DEBUG) << "Start vm: " << backend->name();
transform_ = std::make_shared<AclCompileGraph>(backend, cut_list);
Reset();
}
~AclCompileGraphs() override = default;
void Compile(const FuncGraphPtr &graph) override {
MS_LOG(DEBUG) << "Start";
mapping_[graph] = SizeToLong(insts_.size());
if (transform_ != nullptr) {
auto insts = transform_->Run(graph, false);
if (!insts.empty()) {
(void)insts_.insert(insts_.end(), insts.begin(), insts.end());
}
}
MS_LOG(DEBUG) << "End";
}
};
std::shared_ptr<compile::MsBackend> CreateBackend(const std::shared_ptr<AclModelOptions> &options) {
MS_EXCEPTION_IF_NULL(options);
return std::make_shared<AclBackend>(kMsConvert, kDavinciMultiGraphInferenceDevice, options);
@@ -610,8 +244,4 @@ std::vector<MSTensor> AclModelMulti::GetOutputs() {
return outputs_;
}
namespace session {
MS_REG_SESSION(kDavinciMultiGraphInferenceDevice, MultiGraphAclSession);
} // namespace session
} // namespace mindspore

View File

@@ -0,0 +1,155 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cxx_api/model/acl/acl_vm/acl_multi_graph_session.h"
#include <memory>
#include <deque>
#include <vector>
#include "backend/session/session_factory.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
#include "cxx_api/model/acl/model_converter.h"
#include "cxx_api/model/acl/acl_model_options.h"
#include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h"
#include "cxx_api/graph/graph_data.h"
namespace mindspore::session {
void MultiGraphAclSession::Init(uint32_t device_id) { InitExecutor(kDavinciMultiGraphInferenceDevice, device_id); }
GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
class FirstGraphModeGuard {
public:
explicit FirstGraphModeGuard(const std::shared_ptr<AclModelOptions> &options) : options_(options) {
if (options_ != nullptr) {
options_->SetFirstGraph(true);
}
}
~FirstGraphModeGuard() {
if (options_ != nullptr) {
options_->SetFirstGraph(false);
}
}
private:
std::shared_ptr<AclModelOptions> options_;
};
MS_LOG(INFO) << "Start MultiGraph Compile.";
// construct kernel graph
auto kernel_graph = SessionBasic::ConstructKernelGraph(lst, outputs, false);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>("310_multi_graph_pm");
pm->AddPass(std::make_shared<opt::InsertPlaceholderForDynamicRNN>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
// convert to om data
ModelConverter model_converter_;
model_converter_.set_options(options_);
FirstGraphModeGuard guard(options_);
auto om_data = model_converter_.LoadMindIR(kernel_graph);
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
MS_LOG(ERROR) << "Load MindIR failed.";
return kMCFailed;
}
// load
std::shared_ptr<Graph> graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
MS_EXCEPTION_IF_NULL(graph);
auto graph_cell = GraphCell(graph);
auto ret = graph_cell.Load(options_->GetDeviceID());
if (ret != kSuccess) {
MS_LOG(EXCEPTION) << "Load failed.";
}
graphs_[kernel_graph->graph_id()] = graph_cell;
kernel_graphs_[kernel_graph->graph_id()] = kernel_graph;
MS_LOG(INFO) << "Multi graph compile success, graph id " << kernel_graph->graph_id();
return kernel_graph->graph_id();
}
void MultiGraphAclSession::RunGraph(GraphId graph_id, const std::vector<MSTensor> &inputs, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
MS_LOG(INFO) << "Start run graph " << graph_id;
auto iter = graphs_.find(graph_id);
if (iter == graphs_.end()) {
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " not found.";
}
std::vector<MSTensor> out_tensors;
auto ret = iter->second.Run(inputs, &out_tensors);
if (ret != kSuccess) {
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " run failed.";
}
std::deque<MSTensor> out_tensors_deque(out_tensors.begin(), out_tensors.end());
(*outputs) = ConstructOutputRef(graph_id, &out_tensors_deque);
}
VectorRef MultiGraphAclSession::ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors) {
MS_EXCEPTION_IF_NULL(out_tensors);
VectorRef outs;
auto out_nodes = kernel_graphs_[graph_id]->outputs();
for (auto &out : out_nodes) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(
out, 0, false, std::vector<PrimitivePtr>{prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimStateSetItem});
auto &anf_node = item_with_index.first;
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
} else if (AnfAlgo::IsRealKernel(anf_node)) {
if (out_tensors->empty()) {
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << out->DebugString()
<< ", visited: " << anf_node->DebugString();
}
outs.emplace_back(MSTensorRef(out_tensors->front()));
out_tensors->pop_front();
}
}
if (!out_tensors->empty()) {
MS_LOG(EXCEPTION) << "Number of output size " << outs.size() << " but " << out_tensors->size()
<< " MSTensor remained.";
}
return outs;
}
VectorRef MultiGraphAclSession::ConstructOutputRefByTupleNode(const CNodePtr &tuple_node,
std::deque<MSTensor> *out_tensors) {
MS_EXCEPTION_IF_NULL(out_tensors);
VectorRef outs;
for (size_t i = 1; i < tuple_node->inputs().size(); ++i) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(
tuple_node->input(i), 0, false,
std::vector<PrimitivePtr>{prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimStateSetItem});
auto &anf_node = item_with_index.first;
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
} else if (AnfAlgo::IsRealKernel(anf_node)) {
if (out_tensors->empty()) {
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << tuple_node->input(i)->DebugString()
<< ", visited: " << anf_node->DebugString();
}
outs.emplace_back(MSTensorRef(out_tensors->front()));
out_tensors->pop_front();
}
}
return outs;
}
MS_REG_SESSION(kDavinciMultiGraphInferenceDevice, MultiGraphAclSession);
} // namespace mindspore::session
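To make the tuple reconstruction above concrete, a hedged illustration (not part of the diff) for a graph whose output node is MakeTuple(a, MakeTuple(b, c)) and whose run returned the flat tensors Ta, Tb, Tc:
// Illustration only: ConstructOutputRef consumes the deque front-to-back and mirrors the tuple nesting.
//   out_tensors (flat)  : [Ta, Tb, Tc]
//   returned VectorRef  : { MSTensorRef(Ta), VectorRef{ MSTensorRef(Tb), MSTensorRef(Tc) } }
// Monad outputs (UpdateState/StateSetItem) are skipped by VisitKernelWithReturnType and consume no tensor.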

View File

@@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_MULTI_GRAPH_SESSION_H
#define MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_MULTI_GRAPH_SESSION_H
#include <deque>
#include <vector>
#include <map>
#include <memory>
#include "include/api/types.h"
#include "include/api/cell.h"
#include "backend/session/session_basic.h"
namespace mindspore {
class AclModelOptions;
namespace session {
class MultiGraphAclSession : public session::SessionBasic {
public:
MultiGraphAclSession() = default;
~MultiGraphAclSession() override = default;
void Init(uint32_t device_id) override;
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
void RunGraph(GraphId graph_id, const std::vector<MSTensor> &inputs, VectorRef *outputs);
void SetOptions(const std::shared_ptr<AclModelOptions> &options) { options_ = options; }
private:
VectorRef ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors);
VectorRef ConstructOutputRefByTupleNode(const CNodePtr &tuple_node, std::deque<MSTensor> *out_tensors);
std::map<GraphId, GraphCell> graphs_ = {};
std::map<GraphId, KernelGraphPtr> kernel_graphs_ = {};
std::shared_ptr<AclModelOptions> options_ = nullptr;
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_MULTI_GRAPH_SESSION_H

View File

@@ -0,0 +1,295 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cxx_api/model/acl/acl_vm/acl_vm.h"
#include <memory>
#include <string>
#include <vector>
#include "cxx_api/model/acl/acl_model_options.h"
#include "cxx_api/model/acl/acl_vm/acl_multi_graph_session.h"
#include "debug/trace.h"
namespace mindspore {
namespace {
inline bool IsMonadNode(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimStateSetItem) || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
return true;
}
if (HasAbstractMonad(node)) {
return true;
}
return false;
}
} // namespace
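IsMonadNode is the predicate behind the "ignore monad input/output" half of this change; every compile step below consults it before allocating VM slots. A minimal sketch of the filtering pattern it enables (illustrative only, assuming <algorithm> is included and `inputs` is a CNode's input list):
// Keep only real data inputs; UpdateState/StateSetItem and abstract-monad nodes get no VM slot.
std::vector<AnfNodePtr> real_inputs;
std::copy_if(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs),
             [](const AnfNodePtr &n) { return !IsMonadNode(n); });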
AclBackend::AclBackend(const std::string &name, const std::string &target,
const std::shared_ptr<AclModelOptions> &options)
: MsBackend(name, target, options->GetDeviceID()) {
auto session = std::dynamic_pointer_cast<session::MultiGraphAclSession>(MsBackend::target_sess_);
MS_EXCEPTION_IF_NULL(session);
session->SetOptions(options);
}
VectorRef AclBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
std::vector<MSTensor> inputs;
for (const auto &arg : args) {
if (!utils::isa<MSTensorRef>(arg)) {
MS_LOG(EXCEPTION) << "Invalid item " << arg.ToString();
}
auto wrapper = utils::cast<MSTensorRef>(arg);
inputs.emplace_back(wrapper.GetTensor());
}
VectorRef outputs;
MS_EXCEPTION_IF_NULL(target_sess_);
auto exec_sess = std::dynamic_pointer_cast<session::MultiGraphAclSession>(target_sess_);
MS_EXCEPTION_IF_NULL(exec_sess);
exec_sess->RunGraph(g, inputs, &outputs);
return outputs;
}
bool AclBackend::GetCond(const BaseRef &c, bool *value) {
MS_EXCEPTION_IF_NULL(value);
if (!utils::isa<MSTensorRef>(c)) {
MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef.";
return false;
}
auto wrapper = utils::cast<MSTensorRef>(c);
if (wrapper.GetTensor().DataType() != DataType::kNumberTypeBool) {
MS_LOG(ERROR) << "Invalid data type " << wrapper.GetTensor().DataType() << " must be bool.";
return false;
}
auto data = wrapper.GetTensor().Data();
if (data == nullptr) {
return false;
}
(*value) = *reinterpret_cast<const bool *>(data.get());
return true;
}
bool AclBackend::GetIndex(const BaseRef &c, int64_t *value) {
MS_EXCEPTION_IF_NULL(value);
if (!utils::isa<MSTensorRef>(c)) {
MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef.";
return false;
}
auto wrapper = utils::cast<MSTensorRef>(c);
if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt32) {
auto data = wrapper.GetTensor().Data();
if (data == nullptr) {
return false;
}
auto value_int32 = *reinterpret_cast<const int32_t *>(data.get());
(*value) = static_cast<int64_t>(value_int32);
return true;
} else if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt64) {
auto data = wrapper.GetTensor().Data();
if (data == nullptr) {
return false;
}
(*value) = *reinterpret_cast<const int64_t *>(data.get());
return true;
} else {
MS_LOG(ERROR) << "Index must be Int type.";
return false;
}
}
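For orientation, the final VM invokes these two hooks when it executes switch and switch_layer instructions; a hedged sketch of the calling pattern (`backend`, `cond_ref`, and `index_ref` are hypothetical stand-ins for the VM's state):
bool cond = false;
if (backend->GetCond(cond_ref, &cond)) {
  // Take the true or false branch of a switch.
}
int64_t branch = 0;
if (backend->GetIndex(index_ref, &branch)) {
  // Jump to the branch-th case of a switch_layer.
}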
AclCompileGraph::AclCompileGraph(const std::shared_ptr<compile::MsBackend> &backend,
const std::vector<PrimitivePtr> &cut_list)
: CompileGraph(backend, cut_list) {}
void AclCompileGraph::AddInst(const compile::Instruction &inst, const MSTensorRef &arg) {
VectorRef args;
args.push_back(arg);
compile::CompileGraph::AddInst(inst, args);
}
int64_t AclCompileGraph::Ref(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
if (IsValueNode<FuncGraph>(node)) {
MS_LOG(DEBUG) << "Push graph.";
compile::CompileGraph::AddInst(compile::Instruction::kGraph, GetValueNode(node));
} else {
MS_LOG(DEBUG) << "Push.";
if (IsValueNode<Primitive>(node)) {
MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info());
} else if (IsValueNode<tensor::Tensor>(node)) {
auto tensor_node = std::dynamic_pointer_cast<tensor::Tensor>(node->cast<ValueNodePtr>()->value());
MS_EXCEPTION_IF_NULL(tensor_node);
std::string name = "";
std::vector<int64_t> shape = tensor_node->shape_c();
DataType type = static_cast<DataType>(tensor_node->data_type_c());
auto mstensor_node = MSTensor::CreateRefTensor(name, type, shape, tensor_node->data_c(), tensor_node->Size());
MSTensorRef mstensor_ref(*mstensor_node);
AddInst(compile::Instruction::kPush, mstensor_ref);
MSTensor::DestroyTensorPtr(mstensor_node);
} else {
compile::CompileGraph::AddInst(compile::Instruction::kPush, GetValueNode(node));
}
}
Push(node);
} else if (auto const_parameter = dyn_cast<Parameter>(node);
slots_.count(node) == 0 && const_parameter != nullptr && const_parameter->has_default()) {
auto value = const_parameter->default_param();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
auto tensor_node = std::dynamic_pointer_cast<tensor::Tensor>(value);
MS_EXCEPTION_IF_NULL(tensor_node);
std::vector<int64_t> shape = tensor_node->shape_c();
DataType type = static_cast<DataType>(tensor_node->data_type_c());
auto mstensor_node =
MSTensor::CreateRefTensor(const_parameter->name(), type, shape, tensor_node->data_c(), tensor_node->Size());
MSTensorRef mstensor_ref(*mstensor_node);
AddInst(compile::Instruction::kPush, mstensor_ref);
MSTensor::DestroyTensorPtr(mstensor_node);
} else {
compile::CompileGraph::AddInst(compile::Instruction::kPush, value);
}
Push(node);
}
MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
<< ", return: " << slots_[node] - height_;
return slots_[node] - height_;
}
void AclCompileGraph::AddExternal(const compile::LinConvertResult &result) {
VectorRef args;
args.push_back(result.run);
args.push_back(result.simu_run);
size_t size = result.inputs.size();
for (size_t i = 0; i < size; ++i) {
const auto &input = result.inputs[i];
MS_EXCEPTION_IF_NULL(input);
if (auto parameter = dyn_cast<Parameter>(input); parameter != nullptr && parameter->has_default()) {
MS_LOG(DEBUG) << parameter->DebugString() << " has a default value, will not be pushed as an input.";
continue;
}
if (IsMonadNode(input)) {
MS_LOG(DEBUG) << input->DebugString() << " is a monad node, will not be pushed as an input.";
continue;
}
args.emplace_back(Ref(input));
}
compile::CompileGraph::AddInst(compile::Instruction::kExternal, args);
size_t out_count = 0;
for (auto &out : result.outputs) {
if (IsMonadNode(out)) {
continue;
}
++out_count;
Push(out);
}
MS_LOG(DEBUG) << "Args size " << args.size() << " out size " << out_count;
}
void AclCompileGraph::AddInput(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (IsMonadNode(node)) {
return;
}
if (slots_.count(node) == 0) {
MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true);
(void)Ref(node);
return;
}
compile::CompileGraph::AddInst(compile::Instruction::kInput, Ref(node));
set_height(height_ + 1);
}
void AclCompileGraph::AddPartial(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto inputs = node->inputs();
VectorRef args;
if (inputs.size() <= 1) {
MS_LOG(EXCEPTION) << "The node:" << node->DebugString() << "do not have two input.";
}
auto fn = inputs[1];
if (!IsValueNode<FuncGraph>(fn)) {
MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph";
}
for (size_t i = 1; i < inputs.size(); i++) {
if (IsMonadNode(inputs[i])) {
continue;
}
args.emplace_back(Ref(inputs[i]));
}
compile::CompileGraph::AddInst(compile::Instruction::kPartial, args);
}
int64_t AclCompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto inputs = node->inputs();
AnfNodePtr fn = inputs[0];
(void)Ref(fn);
size_t size = inputs.size();
size_t non_monad_size = size;
for (size_t i = size - 1; i > 0; --i) {
if (IsMonadNode(inputs[i])) {
--non_monad_size;
continue;
}
AddInput(inputs[i]);
}
if (node == graph->output()) {
AddTailCall(fn, non_monad_size);
return RET_BREAK;
}
MS_LOG(DEBUG) << "Call:" << Ref(fn) << ", " << height_ << ", " << (non_monad_size - 1);
compile::CompileGraph::AddInst(compile::Instruction::kCall, Ref(fn));
Ret(static_cast<int64_t>(non_monad_size - 1));
for (size_t i = size - 1; i > 0; i--) {
const auto iter = slots_.find(inputs[i]);
if (iter != slots_.end() && iter->second >= height_) {
slots_.erase(inputs[i]);
}
}
return RET_SUCCESS;
}
void AclCompileGraph::PushParameters(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> parameters = func_graph->parameters();
for (size_t i = parameters.size(); i != 0; i--) {
MS_EXCEPTION_IF_NULL(parameters[i - 1]);
auto param = parameters[i - 1]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
if (param->has_default()) {
MS_LOG(DEBUG) << "Parameter " << (i - 1) << ": " << param->DebugString() << " has default value, skip.";
continue;
}
if (IsMonadNode(param)) {
MS_LOG(DEBUG) << "Parameter " << (i - 1) << ": " << param->DebugString() << " has monad type, skip.";
continue;
}
Push(param);
MS_LOG(DEBUG) << "Push parameter " << (i - 1) << ": " << param->DebugString(true);
}
}
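A worked example of the loop above, under assumed inputs: for parameters (x, u, w) where u carries a monad abstract and w has a default value, the loop walks from the last parameter to the first, skips w (default) and u (monad), and pushes only x, so defaults and monads never occupy VM input slots.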
AclCompileGraphs::AclCompileGraphs(const std::shared_ptr<compile::MsBackend> &backend,
const std::vector<PrimitivePtr> &cut_list)
: CompileGraphs(backend, cut_list) {
MS_EXCEPTION_IF_NULL(backend);
MS_LOG(DEBUG) << "Start vm: " << backend->name();
transform_ = std::make_shared<AclCompileGraph>(backend, cut_list);
Reset();
}
} // namespace mindspore

View File

@@ -0,0 +1,62 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_VM_H
#define MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_VM_H
#include <vector>
#include <memory>
#include <string>
#include "vm/transform.h"
#include "vm/backend.h"
#include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h"
namespace mindspore {
class AclModelOptions;
class AclBackend : public compile::MsBackend {
public:
AclBackend(const std::string &name, const std::string &target, const std::shared_ptr<AclModelOptions> &options);
~AclBackend() override = default;
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) override;
bool GetCond(const BaseRef &c, bool *value) override;
bool GetIndex(const BaseRef &c, int64_t *value) override;
};
class AclCompileGraph : public compile::CompileGraph {
public:
explicit AclCompileGraph(const std::shared_ptr<compile::MsBackend> &backend,
const std::vector<PrimitivePtr> &cut_list);
~AclCompileGraph() override = default;
int64_t Ref(const AnfNodePtr &node) override;
void AddExternal(const compile::LinConvertResult &result) override;
void AddInput(const AnfNodePtr &node) override;
void AddPartial(const CNodePtr &node) override;
int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node) override;
void PushParameters(const FuncGraphPtr &func_graph) override;
private:
void AddInst(const compile::Instruction &inst, const MSTensorRef &arg);
};
class AclCompileGraphs : public compile::CompileGraphs {
public:
explicit AclCompileGraphs(const std::shared_ptr<compile::MsBackend> &backend,
const std::vector<PrimitivePtr> &cut_list);
~AclCompileGraphs() override = default;
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_VM_H
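Putting the pieces together, a hedged sketch of how AclModelMulti is expected to wire these classes up, mirroring the CreateBackend helper kept in the first file; the cut list compile::GetMsNonlinearOps() and the Eval call are assumptions, not part of this diff:
auto backend = std::make_shared<AclBackend>(kMsConvert, kDavinciMultiGraphInferenceDevice, options);
AclCompileGraphs compiler(backend, compile::GetMsNonlinearOps());  // assumed control-flow cut list
compile::FinalVMPtr vm = compiler.CompileAndLink(func_graph);      // compile every sub-graph, then link
BaseRef out = vm->Eval(MSTensorRef::Convert(inputs));              // monad inputs are never pushed
std::vector<MSTensor> out_tensors = MSTensorRef::Convert(out);     // flatten nested results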

View File

@@ -0,0 +1,78 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h"
#include <algorithm>
namespace mindspore {
VectorRef MSTensorRef::Convert(const std::vector<MSTensor> &tensors) {
VectorRef res;
std::transform(tensors.begin(), tensors.end(), std::back_inserter(res),
[](const MSTensor &t) { return MSTensorRef(t); });
return res;
}
std::vector<MSTensor> MSTensorRef::Convert(const BaseRef &args) {
std::vector<MSTensor> res;
if (utils::isa<VectorRef>(args)) {
VectorRef args_vec = utils::cast<VectorRef>(args);
res = ConvertTuple(args_vec);
} else if (utils::isa<MSTensorRef>(args)) {
auto wrapper = utils::cast<MSTensorRef>(args);
res.push_back(wrapper.ms_tensor_);
} else {
MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString() << " must be MSTensorRef or VectorRef{MSTensorRef...}";
}
return res;
}
std::shared_ptr<Base> MSTensorRef::copy() const {
MSTensor *tensor = ms_tensor_.Clone();
auto res = std::make_shared<MSTensorRef>(static_cast<const MSTensor &>(*tensor));
MSTensor::DestroyTensorPtr(tensor);
return res;
}
bool MSTensorRef::operator==(const BaseRef &other) const {
if (!utils::isa<MSTensorRef>(other)) {
return false;
}
auto other_ms_tensor = utils::cast<MSTensorRef>(other).ms_tensor_;
auto this_ms_tensor = ms_tensor_;
return (this_ms_tensor.Name() == other_ms_tensor.Name()) && (this_ms_tensor.Shape() == other_ms_tensor.Shape()) &&
(this_ms_tensor.MutableData() == other_ms_tensor.MutableData()) &&
(this_ms_tensor.DataSize() == other_ms_tensor.DataSize()) &&
(this_ms_tensor.DataType() == other_ms_tensor.DataType());
}
std::vector<MSTensor> MSTensorRef::ConvertTuple(const VectorRef &args) {
std::vector<MSTensor> outs;
for (size_t i = 0; i < args.size(); ++i) {
const auto &item = args[i];
if (utils::isa<VectorRef>(item)) {
VectorRef item_vec = utils::cast<VectorRef>(item);
auto ret = ConvertTuple(item_vec);
outs.insert(outs.end(), ret.begin(), ret.end());
} else if (utils::isa<MSTensorRef>(item)) {
auto wrapper = utils::cast<MSTensorRef>(item);
outs.push_back(wrapper.ms_tensor_);
} else {
MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString() << " must be MSTensorRef or VectorRef{MSTensorRef...}";
}
}
return outs;
}
} // namespace mindspore
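The two Convert overloads are intended to round-trip; a hedged usage sketch (`tensors` stands for any populated std::vector<MSTensor>):
std::vector<MSTensor> tensors = /* e.g. model inputs */;
VectorRef packed = MSTensorRef::Convert(tensors);           // wrap each MSTensor in an MSTensorRef
std::vector<MSTensor> flat = MSTensorRef::Convert(packed);  // flatten back; nested VectorRefs recurse via ConvertTuple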

View File

@@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_ACL_VM_MS_TENSOR_REF_H
#define MINDSPORE_CCSRC_CXX_API_ACL_VM_MS_TENSOR_REF_H
#include <memory>
#include <string>
#include <vector>
#include "include/api/types.h"
#include "mindspore/core/base/base_ref.h"
namespace mindspore {
class MSTensorRef : public BaseRef {
public:
MS_DECLARE_PARENT(MSTensorRef, BaseRef);
static VectorRef Convert(const std::vector<MSTensor> &tensors);
static std::vector<MSTensor> Convert(const BaseRef &args);
explicit MSTensorRef(const MSTensor &tensor) : ms_tensor_(tensor) {}
~MSTensorRef() override = default;
const MSTensor &GetTensor() const { return ms_tensor_; }
std::shared_ptr<Base> copy() const override;
uint32_t type() const override { return tid(); }
std::string ToString() const override { return ms_tensor_.Name(); }
bool operator==(const BaseRef &other) const override;
private:
static std::vector<MSTensor> ConvertTuple(const VectorRef &args);
MSTensor ms_tensor_;
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_ACL_VM_MS_TENSOR_REF_H

View File

@@ -900,7 +900,10 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
return *this;
}
UpdateDataOpDesc(it, op);
if (HasAbstractMonad(it)) {
MS_LOG(INFO) << it->DebugString() << " is a monad parameter, skip.";
continue;
}
MS_LOG(INFO) << "add input " << it->ToString() << ", index " << index;
(void)std::static_pointer_cast<Data>(op)->set_attr_index(index++);
inputs.push_back(*op);
@@ -945,13 +948,17 @@ void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr
MS_LOG(ERROR) << "Update data op descriptor failed! Invalid node.";
return;
}
auto normal_shape_ptr = dyn_cast<abstract::Shape>(node->Shape());
std::vector<int64_t> shape;
if (normal_shape_ptr == nullptr) {
if (auto normal_shape_ptr = dyn_cast<abstract::Shape>(node->Shape()); normal_shape_ptr != nullptr) {
shape = normal_shape_ptr->shape();
} else if (auto no_shape_ptr = dyn_cast<abstract::NoShape>(node->Shape()); no_shape_ptr != nullptr) {
shape = {};
} else {
MS_LOG(INFO) << "Invalid shape to update data op descriptor.";
return;
}
shape = normal_shape_ptr->shape();
if (node->Type() == nullptr) {
MS_LOG(INFO) << "Invalid type to update data op descriptor.";
return;
@@ -1887,7 +1894,6 @@ OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) {
auto ge_tensor = const_op->get_attr_value();
auto ge_desc = ge_tensor.GetTensorDesc();
(void)const_op->update_output_desc_y(ge_desc);
op_cache_[node.get()] = op;
return op_cache_[node.get()];
}

View File

@@ -145,22 +145,6 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
}
}
void CompileGraph::PushInputs(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> parameters = graph->parameters();
for (size_t i = parameters.size(); i != 0; i--) {
MS_EXCEPTION_IF_NULL(parameters[i - 1]);
auto param = parameters[i - 1]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
if (param->has_default()) {
MS_LOG(DEBUG) << "Parameter " << (i - 1) << ": " << param->DebugString() << " has default value, skip.";
continue;
}
Push(param);
MS_LOG(DEBUG) << "Push parameter " << (i - 1) << ": " << param->DebugString(true);
}
}
int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) {
MS_EXCEPTION_IF_NULL(segment);
MS_LOG(DEBUG) << "LinConvert start";
@@ -273,15 +257,11 @@ bool CompileGraph::Compile(const FuncGraphPtr &graph) {
return true;
}
InstSet CompileGraph::Run(const FuncGraphPtr &graph, bool push_weight) {
InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
Reset();
if (push_weight) {
PushParameters(graph);
} else {
PushInputs(graph);
}
int64_t param_height = height_;
MS_EXCEPTION_IF_NULL(graph->get_return());

View File

@@ -55,7 +55,7 @@ class CompileGraph {
virtual ~CompileGraph() = default;
InstSet Run(const FuncGraphPtr &func_graph, bool push_weight = true);
InstSet Run(const FuncGraphPtr &func_graph);
bool IsCut(const AnfNodePtr &node);
void Push(const AnfNodePtr &node);
void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
@@ -77,22 +77,21 @@ class CompileGraph {
}
protected:
void PushParameters(const FuncGraphPtr &func_graph);
void PushInputs(const FuncGraphPtr &graph);
virtual void PushParameters(const FuncGraphPtr &func_graph);
bool Compile(const FuncGraphPtr &func_graph);
int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = "");
int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
virtual int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
void AddPadStack(int64_t param_height);
void AddTailCall(const AnfNodePtr &fn, size_t size);
void AddPartial(const CNodePtr &node);
virtual void AddPartial(const CNodePtr &node);
void AddMakeTuple(const CNodePtr &node);
void AddSwitch(const CNodePtr &node);
void AddSwitchLayer(const CNodePtr &node);
void AddReturn(const CNodePtr &node);
void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim);
void AddInput(const AnfNodePtr &node);
void AddExternal(const LinConvertResult &result);
virtual void AddInput(const AnfNodePtr &node);
virtual void AddExternal(const LinConvertResult &result);
void AddInst(const Instruction &inst, const int64_t &arg);
void AddInst(const Instruction &inst, const ValuePtr &arg);
void AddInst(const Instruction &inst, const VectorRef &args);
@@ -122,7 +121,7 @@ class CompileGraphs {
mapping_.clear();
}
virtual void Compile(const FuncGraphPtr &func_graph);
void Compile(const FuncGraphPtr &func_graph);
FinalVMPtr Link();
FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);

View File

@@ -89,7 +89,7 @@ void FinalVM::Pop(int64_t n) {
void FinalVM::MoveStack(int64_t nitems, int64_t height) {
if (nitems > height || height > sp_) {
MS_LOG(EXCEPTION) << "MoveStack arg error: nitems=" << nitems << " height=" << height;
MS_LOG(EXCEPTION) << "MoveStack arg error: nitems=" << nitems << " height=" << height << " sp=" << sp_;
}
int64_t n = height - nitems;
int64_t src = sp_ - height;
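A worked example of the strengthened check: with sp_ = 10, height = 4 and nitems = 2, the guard passes, n = height - nitems = 2 and src = sp_ - height = 6, so the top two items slide down over the two beneath them and sp_ shrinks by 2. A bad call such as nitems = 5 with height = 4 now reports nitems, height and sp_ together, which is the point of the amended message.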