!24083 refactor acl vm && acl vm ignore monad input/output
Merge pull request !24083 from zhoufeng/xiu-ba-ge-2
This commit is contained in:
commit
46fc24b172
|
@ -13,7 +13,6 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cxx_api/model/acl/acl_model_multi.h"
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
@ -23,16 +22,11 @@
|
|||
#include <numeric>
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
|
||||
#include "cxx_api/factory.h"
|
||||
#include "vm/backend.h"
|
||||
#include "vm/transform.h"
|
||||
#include "acl/acl_rt.h"
|
||||
#include "mindspore/core/load_mindir/infer_mindir.h"
|
||||
#include "debug/trace.h"
|
||||
#include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h"
|
||||
#include "cxx_api/model/acl/acl_vm/acl_vm.h"
|
||||
|
||||
namespace mindspore {
|
||||
API_FACTORY_REG(ModelImpl, Ascend310, AclModelMulti);
|
||||
|
@ -46,366 +40,6 @@ std::map<DataType, size_t> kDtypeMap = {
|
|||
{DataType::kNumberTypeUInt8, sizeof(uint8_t)}, {DataType::kNumberTypeUInt16, sizeof(uint16_t)},
|
||||
{DataType::kNumberTypeUInt32, sizeof(uint32_t)}, {DataType::kNumberTypeUInt64, sizeof(uint64_t)}};
|
||||
|
||||
class MSTensorRef : public BaseRef {
|
||||
public:
|
||||
static VectorRef Convert(const std::vector<MSTensor> &tensors) {
|
||||
VectorRef res;
|
||||
std::transform(tensors.begin(), tensors.end(), std::back_inserter(res),
|
||||
[](const MSTensor &t) { return MSTensorRef(t); });
|
||||
return res;
|
||||
}
|
||||
|
||||
static std::vector<MSTensor> Convert(const BaseRef &args) {
|
||||
std::vector<MSTensor> res;
|
||||
if (utils::isa<VectorRef>(args)) {
|
||||
VectorRef args_vec = utils::cast<VectorRef>(args);
|
||||
res = ConvertTuple(args_vec);
|
||||
} else if (utils::isa<MSTensorRef>(args)) {
|
||||
auto wrapper = utils::cast<MSTensorRef>(args);
|
||||
res.push_back(wrapper.ms_tensor_);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString() << " must be MSTensorRef or VectorRef{MSTensorRef...}";
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
MS_DECLARE_PARENT(MSTensorRef, BaseRef);
|
||||
explicit MSTensorRef(const MSTensor &tensor) : ms_tensor_(tensor) {}
|
||||
~MSTensorRef() override = default;
|
||||
|
||||
const MSTensor &GetTensor() const { return ms_tensor_; }
|
||||
std::shared_ptr<Base> copy() const override {
|
||||
MSTensor *tensor = ms_tensor_.Clone();
|
||||
auto res = std::make_shared<MSTensorRef>(static_cast<const MSTensor &>(*tensor));
|
||||
MSTensor::DestroyTensorPtr(tensor);
|
||||
return res;
|
||||
}
|
||||
|
||||
uint32_t type() const override { return tid(); }
|
||||
std::string ToString() const override { return ms_tensor_.Name(); }
|
||||
bool operator==(const BaseRef &other) const override {
|
||||
if (!utils::isa<MSTensorRef>(other)) {
|
||||
return false;
|
||||
}
|
||||
auto other_ms_tensor = utils::cast<MSTensorRef>(other).ms_tensor_;
|
||||
auto this_ms_tensor = ms_tensor_;
|
||||
return (this_ms_tensor.Name() == other_ms_tensor.Name()) && (this_ms_tensor.Shape() == other_ms_tensor.Shape()) &&
|
||||
(this_ms_tensor.MutableData() == other_ms_tensor.MutableData()) &&
|
||||
(this_ms_tensor.DataSize() == other_ms_tensor.DataSize()) &&
|
||||
(this_ms_tensor.DataType() == other_ms_tensor.DataType());
|
||||
}
|
||||
|
||||
private:
|
||||
static std::vector<MSTensor> ConvertTuple(const VectorRef &args) {
|
||||
std::vector<MSTensor> outs;
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
const auto &item = args[i];
|
||||
if (utils::isa<VectorRef>(item)) {
|
||||
VectorRef args_vec = utils::cast<VectorRef>(args);
|
||||
auto ret = ConvertTuple(args_vec);
|
||||
outs.insert(outs.end(), ret.begin(), ret.end());
|
||||
} else if (utils::isa<MSTensorRef>(item)) {
|
||||
auto wrapper = utils::cast<MSTensorRef>(item);
|
||||
outs.push_back(wrapper.ms_tensor_);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString()
|
||||
<< " must be MSTensorRef or VectorRef{MSTensorRef...}";
|
||||
}
|
||||
}
|
||||
return outs;
|
||||
}
|
||||
|
||||
MSTensor ms_tensor_;
|
||||
};
|
||||
|
||||
class MultiGraphAclSession : public session::SessionBasic {
|
||||
public:
|
||||
MultiGraphAclSession() = default;
|
||||
~MultiGraphAclSession() override = default;
|
||||
void Init(uint32_t device_id) override;
|
||||
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
|
||||
void RunGraph(GraphId graph_id, const std::vector<MSTensor> &inputs, VectorRef *outputs);
|
||||
void SetOptions(const std::shared_ptr<AclModelOptions> &options) { options_ = options; }
|
||||
|
||||
private:
|
||||
VectorRef ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors);
|
||||
VectorRef ConstructOutputRefByTupleNode(const CNodePtr &tuple_node, std::deque<MSTensor> *out_tensors);
|
||||
|
||||
std::map<GraphId, GraphCell> graphs_ = {};
|
||||
std::map<GraphId, KernelGraphPtr> kernel_graphs_ = {};
|
||||
std::shared_ptr<AclModelOptions> options_ = nullptr;
|
||||
};
|
||||
|
||||
void MultiGraphAclSession::Init(uint32_t device_id) { InitExecutor(kDavinciMultiGraphInferenceDevice, device_id); }
|
||||
|
||||
GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
class FirstGraphModeGuard {
|
||||
public:
|
||||
explicit FirstGraphModeGuard(const std::shared_ptr<AclModelOptions> &options) : options_(options) {
|
||||
if (options_ != nullptr) {
|
||||
options_->SetFirstGraph(true);
|
||||
}
|
||||
}
|
||||
~FirstGraphModeGuard() {
|
||||
if (options_ != nullptr) {
|
||||
options_->SetFirstGraph(false);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<AclModelOptions> options_;
|
||||
};
|
||||
MS_LOG(INFO) << "Start MultiGraph Compile.";
|
||||
// construct kernel graph
|
||||
auto kernel_graph = SessionBasic::ConstructKernelGraph(lst, outputs, false);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>("310_multi_graph_pm");
|
||||
pm->AddPass(std::make_shared<opt::InsertPlaceholderForDynamicRNN>());
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
// concert to om data
|
||||
ModelConverter model_converter_;
|
||||
model_converter_.set_options(options_);
|
||||
FirstGraphModeGuard guard(options_);
|
||||
auto om_data = model_converter_.LoadMindIR(kernel_graph);
|
||||
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
|
||||
MS_LOG(ERROR) << "Load MindIR failed.";
|
||||
return kMCFailed;
|
||||
}
|
||||
// load
|
||||
std::shared_ptr<Graph> graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto graph_cell = GraphCell(graph);
|
||||
auto ret = graph_cell.Load(options_->GetDeviceID());
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(EXCEPTION) << "Load failed.";
|
||||
}
|
||||
graphs_[kernel_graph->graph_id()] = graph_cell;
|
||||
kernel_graphs_[kernel_graph->graph_id()] = kernel_graph;
|
||||
MS_LOG(INFO) << "Mulit graph compile success, graph id " << kernel_graph->graph_id();
|
||||
return kernel_graph->graph_id();
|
||||
}
|
||||
|
||||
void MultiGraphAclSession::RunGraph(GraphId graph_id, const std::vector<MSTensor> &inputs, VectorRef *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
MS_LOG(INFO) << "Start run graph " << graph_id;
|
||||
auto iter = graphs_.find(graph_id);
|
||||
if (iter == graphs_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " not found.";
|
||||
}
|
||||
std::vector<MSTensor> out_tensors;
|
||||
auto ret = iter->second.Run(inputs, &out_tensors);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " run failed.";
|
||||
}
|
||||
|
||||
std::deque<MSTensor> out_tensors_deque(out_tensors.begin(), out_tensors.end());
|
||||
(*outputs) = ConstructOutputRef(graph_id, &out_tensors_deque);
|
||||
}
|
||||
|
||||
VectorRef MultiGraphAclSession::ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(out_tensors);
|
||||
VectorRef outs;
|
||||
auto out_nodes = kernel_graphs_[graph_id]->outputs();
|
||||
for (auto &out : out_nodes) {
|
||||
if (out_tensors->empty()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << out->DebugString();
|
||||
}
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(out, 0);
|
||||
auto &anf_node = item_with_index.first;
|
||||
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
|
||||
} else {
|
||||
outs.emplace_back(MSTensorRef(out_tensors->front()));
|
||||
out_tensors->pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
if (!out_tensors->empty()) {
|
||||
MS_LOG(EXCEPTION) << "Number of output size " << outs.size() << " but " << out_tensors->size()
|
||||
<< " MSTensor remained.";
|
||||
}
|
||||
|
||||
return outs;
|
||||
}
|
||||
|
||||
VectorRef MultiGraphAclSession::ConstructOutputRefByTupleNode(const CNodePtr &tuple_node,
|
||||
std::deque<MSTensor> *out_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(out_tensors);
|
||||
VectorRef outs;
|
||||
for (size_t i = 1; i < tuple_node->inputs().size(); ++i) {
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(tuple_node->input(i), 0);
|
||||
auto &anf_node = item_with_index.first;
|
||||
if (out_tensors->empty()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << anf_node->DebugString();
|
||||
}
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
|
||||
} else {
|
||||
outs.emplace_back(MSTensorRef(out_tensors->front()));
|
||||
out_tensors->pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
return outs;
|
||||
}
|
||||
|
||||
class AclBackend : public compile::MsBackend {
|
||||
public:
|
||||
AclBackend(const std::string &name, const std::string &target, const std::shared_ptr<AclModelOptions> &options)
|
||||
: MsBackend(name, target, options->GetDeviceID()) {
|
||||
auto session = std::dynamic_pointer_cast<MultiGraphAclSession>(MsBackend::target_sess_);
|
||||
MS_EXCEPTION_IF_NULL(session);
|
||||
session->SetOptions(options);
|
||||
}
|
||||
|
||||
~AclBackend() override = default;
|
||||
|
||||
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) override {
|
||||
std::vector<MSTensor> inputs;
|
||||
for (const auto &arg : args) {
|
||||
if (!utils::isa<MSTensorRef>(arg)) {
|
||||
MS_LOG(EXCEPTION) << "Invalid item " << arg.ToString();
|
||||
}
|
||||
auto wrapper = utils::cast<MSTensorRef>(arg);
|
||||
inputs.emplace_back(wrapper.GetTensor());
|
||||
}
|
||||
|
||||
VectorRef outputs;
|
||||
MS_EXCEPTION_IF_NULL(target_sess_);
|
||||
auto exec_sess = std::dynamic_pointer_cast<MultiGraphAclSession>(target_sess_);
|
||||
MS_EXCEPTION_IF_NULL(exec_sess);
|
||||
exec_sess->RunGraph(g, inputs, &outputs);
|
||||
return outputs;
|
||||
}
|
||||
|
||||
bool GetCond(const BaseRef &c, bool *value) override {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (!utils::isa<MSTensorRef>(c)) {
|
||||
MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef.";
|
||||
return false;
|
||||
}
|
||||
auto wrapper = utils::cast<MSTensorRef>(c);
|
||||
if (wrapper.GetTensor().DataType() != DataType::kNumberTypeBool) {
|
||||
MS_LOG(ERROR) << "Invalid data type " << wrapper.GetTensor().DataType() << " must be bool.";
|
||||
return false;
|
||||
}
|
||||
auto data = wrapper.GetTensor().Data();
|
||||
if (data == nullptr) {
|
||||
return false;
|
||||
}
|
||||
(*value) = *reinterpret_cast<const bool *>(data.get());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GetIndex(const BaseRef &c, int64_t *value) override {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (!utils::isa<MSTensorRef>(c)) {
|
||||
MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto wrapper = utils::cast<MSTensorRef>(c);
|
||||
if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt32) {
|
||||
auto data = wrapper.GetTensor().Data();
|
||||
if (data == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto value_int32 = *reinterpret_cast<const int32_t *>(data.get());
|
||||
(*value) = static_cast<int64_t>(value_int32);
|
||||
return true;
|
||||
} else if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt64) {
|
||||
auto data = wrapper.GetTensor().Data();
|
||||
if (data == nullptr) {
|
||||
return false;
|
||||
}
|
||||
(*value) = *reinterpret_cast<const int64_t *>(data.get());
|
||||
return true;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Index must be Int type.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class AclCompileGraph : public compile::CompileGraph {
|
||||
public:
|
||||
explicit AclCompileGraph(const std::shared_ptr<compile::MsBackend> &backend,
|
||||
const std::vector<PrimitivePtr> &cut_list)
|
||||
: CompileGraph(backend, cut_list) {}
|
||||
~AclCompileGraph() override = default;
|
||||
|
||||
void AddInst(const compile::Instruction &inst, const MSTensorRef &arg) {
|
||||
VectorRef args;
|
||||
args.push_back(arg);
|
||||
compile::CompileGraph::AddInst(inst, args);
|
||||
}
|
||||
|
||||
int64_t Ref(const AnfNodePtr &node) override {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
|
||||
if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
MS_LOG(DEBUG) << "Push graph.";
|
||||
compile::CompileGraph::AddInst(compile::Instruction::kGraph, GetValueNode(node));
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Push.";
|
||||
if (IsValueNode<Primitive>(node)) {
|
||||
MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||
} else if (IsValueNode<tensor::Tensor>(node)) {
|
||||
auto tensor_node = std::dynamic_pointer_cast<tensor::Tensor>(node->cast<ValueNodePtr>()->value());
|
||||
MS_EXCEPTION_IF_NULL(tensor_node);
|
||||
std::string name = "";
|
||||
std::vector<int64_t> shape = tensor_node->shape_c();
|
||||
DataType type = static_cast<DataType>(tensor_node->data_type_c());
|
||||
auto mstensor_node = MSTensor::CreateRefTensor(name, type, shape, tensor_node->data_c(), tensor_node->Size());
|
||||
MSTensorRef mstensor_ref(*mstensor_node);
|
||||
AddInst(compile::Instruction::kPush, mstensor_ref);
|
||||
MSTensor::DestroyTensorPtr(mstensor_node);
|
||||
} else {
|
||||
compile::CompileGraph::AddInst(compile::Instruction::kPush, GetValueNode(node));
|
||||
}
|
||||
}
|
||||
Push(node);
|
||||
}
|
||||
MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
|
||||
<< ", return: " << slots_[node] - height_;
|
||||
return slots_[node] - height_;
|
||||
}
|
||||
};
|
||||
|
||||
class AclCompileGraphs : public compile::CompileGraphs {
|
||||
public:
|
||||
explicit AclCompileGraphs(const std::shared_ptr<compile::MsBackend> &backend,
|
||||
const std::vector<PrimitivePtr> &cut_list)
|
||||
: CompileGraphs(backend, cut_list) {
|
||||
MS_EXCEPTION_IF_NULL(backend);
|
||||
MS_LOG(DEBUG) << "Start vm: " << backend->name();
|
||||
transform_ = std::make_shared<AclCompileGraph>(backend, cut_list);
|
||||
Reset();
|
||||
}
|
||||
~AclCompileGraphs() override = default;
|
||||
void Compile(const FuncGraphPtr &graph) override {
|
||||
MS_LOG(DEBUG) << "Start";
|
||||
mapping_[graph] = SizeToLong(insts_.size());
|
||||
if (transform_ != nullptr) {
|
||||
auto insts = transform_->Run(graph, false);
|
||||
if (!insts.empty()) {
|
||||
(void)insts_.insert(insts_.end(), insts.begin(), insts.end());
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "End";
|
||||
}
|
||||
};
|
||||
|
||||
std::shared_ptr<compile::MsBackend> CreateBackend(const std::shared_ptr<AclModelOptions> &options) {
|
||||
MS_EXCEPTION_IF_NULL(options);
|
||||
return std::make_shared<AclBackend>(kMsConvert, kDavinciMultiGraphInferenceDevice, options);
|
||||
|
@ -610,8 +244,4 @@ std::vector<MSTensor> AclModelMulti::GetOutputs() {
|
|||
|
||||
return outputs_;
|
||||
}
|
||||
|
||||
namespace session {
|
||||
MS_REG_SESSION(kDavinciMultiGraphInferenceDevice, MultiGraphAclSession);
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cxx_api/model/acl/acl_vm/acl_multi_graph_session.h"
|
||||
#include <memory>
|
||||
#include <deque>
|
||||
#include <vector>
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
|
||||
#include "cxx_api/model/acl/model_converter.h"
|
||||
#include "cxx_api/model/acl/acl_model_options.h"
|
||||
#include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h"
|
||||
#include "cxx_api/graph/graph_data.h"
|
||||
|
||||
namespace mindspore::session {
|
||||
|
||||
void MultiGraphAclSession::Init(uint32_t device_id) { InitExecutor(kDavinciMultiGraphInferenceDevice, device_id); }
|
||||
|
||||
GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
class FirstGraphModeGuard {
|
||||
public:
|
||||
explicit FirstGraphModeGuard(const std::shared_ptr<AclModelOptions> &options) : options_(options) {
|
||||
if (options_ != nullptr) {
|
||||
options_->SetFirstGraph(true);
|
||||
}
|
||||
}
|
||||
~FirstGraphModeGuard() {
|
||||
if (options_ != nullptr) {
|
||||
options_->SetFirstGraph(false);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<AclModelOptions> options_;
|
||||
};
|
||||
MS_LOG(INFO) << "Start MultiGraph Compile.";
|
||||
// construct kernel graph
|
||||
auto kernel_graph = SessionBasic::ConstructKernelGraph(lst, outputs, false);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>("310_multi_graph_pm");
|
||||
pm->AddPass(std::make_shared<opt::InsertPlaceholderForDynamicRNN>());
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
// concert to om data
|
||||
ModelConverter model_converter_;
|
||||
model_converter_.set_options(options_);
|
||||
FirstGraphModeGuard guard(options_);
|
||||
auto om_data = model_converter_.LoadMindIR(kernel_graph);
|
||||
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
|
||||
MS_LOG(ERROR) << "Load MindIR failed.";
|
||||
return kMCFailed;
|
||||
}
|
||||
// load
|
||||
std::shared_ptr<Graph> graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto graph_cell = GraphCell(graph);
|
||||
auto ret = graph_cell.Load(options_->GetDeviceID());
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(EXCEPTION) << "Load failed.";
|
||||
}
|
||||
graphs_[kernel_graph->graph_id()] = graph_cell;
|
||||
kernel_graphs_[kernel_graph->graph_id()] = kernel_graph;
|
||||
MS_LOG(INFO) << "Multi graph compile success, graph id " << kernel_graph->graph_id();
|
||||
return kernel_graph->graph_id();
|
||||
}
|
||||
|
||||
void MultiGraphAclSession::RunGraph(GraphId graph_id, const std::vector<MSTensor> &inputs, VectorRef *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
MS_LOG(INFO) << "Start run graph " << graph_id;
|
||||
auto iter = graphs_.find(graph_id);
|
||||
if (iter == graphs_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " not found.";
|
||||
}
|
||||
std::vector<MSTensor> out_tensors;
|
||||
auto ret = iter->second.Run(inputs, &out_tensors);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " run failed.";
|
||||
}
|
||||
|
||||
std::deque<MSTensor> out_tensors_deque(out_tensors.begin(), out_tensors.end());
|
||||
(*outputs) = ConstructOutputRef(graph_id, &out_tensors_deque);
|
||||
}
|
||||
|
||||
VectorRef MultiGraphAclSession::ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(out_tensors);
|
||||
VectorRef outs;
|
||||
auto out_nodes = kernel_graphs_[graph_id]->outputs();
|
||||
for (auto &out : out_nodes) {
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(
|
||||
out, 0, false, std::vector<PrimitivePtr>{prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimStateSetItem});
|
||||
auto &anf_node = item_with_index.first;
|
||||
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
|
||||
} else if (AnfAlgo::IsRealKernel(anf_node)) {
|
||||
if (out_tensors->empty()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << out->DebugString()
|
||||
<< ", visited: " << anf_node->DebugString();
|
||||
}
|
||||
outs.emplace_back(MSTensorRef(out_tensors->front()));
|
||||
out_tensors->pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
if (!out_tensors->empty()) {
|
||||
MS_LOG(EXCEPTION) << "Number of output size " << outs.size() << " but " << out_tensors->size()
|
||||
<< " MSTensor remained.";
|
||||
}
|
||||
|
||||
return outs;
|
||||
}
|
||||
|
||||
VectorRef MultiGraphAclSession::ConstructOutputRefByTupleNode(const CNodePtr &tuple_node,
|
||||
std::deque<MSTensor> *out_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(out_tensors);
|
||||
VectorRef outs;
|
||||
for (size_t i = 1; i < tuple_node->inputs().size(); ++i) {
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(
|
||||
tuple_node->input(i), 0, false,
|
||||
std::vector<PrimitivePtr>{prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimStateSetItem});
|
||||
auto &anf_node = item_with_index.first;
|
||||
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
|
||||
} else if (AnfAlgo::IsRealKernel(anf_node)) {
|
||||
if (out_tensors->empty()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << tuple_node->input(i)->DebugString()
|
||||
<< ", visited: " << anf_node->DebugString();
|
||||
}
|
||||
outs.emplace_back(MSTensorRef(out_tensors->front()));
|
||||
out_tensors->pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
return outs;
|
||||
}
|
||||
MS_REG_SESSION(kDavinciMultiGraphInferenceDevice, MultiGraphAclSession);
|
||||
} // namespace mindspore::session
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_MULTI_GRAPH_SESSION_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_MULTI_GRAPH_SESSION_H
|
||||
|
||||
#include <deque>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/cell.h"
|
||||
#include "backend/session/session_basic.h"
|
||||
|
||||
namespace mindspore {
|
||||
class AclModelOptions;
|
||||
namespace session {
|
||||
class MultiGraphAclSession : public session::SessionBasic {
|
||||
public:
|
||||
MultiGraphAclSession() = default;
|
||||
~MultiGraphAclSession() override = default;
|
||||
void Init(uint32_t device_id) override;
|
||||
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
|
||||
void RunGraph(GraphId graph_id, const std::vector<MSTensor> &inputs, VectorRef *outputs);
|
||||
void SetOptions(const std::shared_ptr<AclModelOptions> &options) { options_ = options; }
|
||||
|
||||
private:
|
||||
VectorRef ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors);
|
||||
VectorRef ConstructOutputRefByTupleNode(const CNodePtr &tuple_node, std::deque<MSTensor> *out_tensors);
|
||||
|
||||
std::map<GraphId, GraphCell> graphs_ = {};
|
||||
std::map<GraphId, KernelGraphPtr> kernel_graphs_ = {};
|
||||
std::shared_ptr<AclModelOptions> options_ = nullptr;
|
||||
};
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_MULTI_GRAPH_SESSION_H
|
|
@ -0,0 +1,295 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cxx_api/model/acl/acl_vm/acl_vm.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "cxx_api/model/acl/acl_model_options.h"
|
||||
#include "cxx_api/model/acl/acl_vm/acl_multi_graph_session.h"
|
||||
#include "debug/trace.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
inline bool IsMonadNode(const AnfNodePtr &node) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimStateSetItem) || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (HasAbstractMonad(node)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
AclBackend::AclBackend(const std::string &name, const std::string &target,
|
||||
const std::shared_ptr<AclModelOptions> &options)
|
||||
: MsBackend(name, target, options->GetDeviceID()) {
|
||||
auto session = std::dynamic_pointer_cast<session::MultiGraphAclSession>(MsBackend::target_sess_);
|
||||
MS_EXCEPTION_IF_NULL(session);
|
||||
session->SetOptions(options);
|
||||
}
|
||||
|
||||
VectorRef AclBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
|
||||
std::vector<MSTensor> inputs;
|
||||
for (const auto &arg : args) {
|
||||
if (!utils::isa<MSTensorRef>(arg)) {
|
||||
MS_LOG(EXCEPTION) << "Invalid item " << arg.ToString();
|
||||
}
|
||||
auto wrapper = utils::cast<MSTensorRef>(arg);
|
||||
inputs.emplace_back(wrapper.GetTensor());
|
||||
}
|
||||
|
||||
VectorRef outputs;
|
||||
MS_EXCEPTION_IF_NULL(target_sess_);
|
||||
auto exec_sess = std::dynamic_pointer_cast<session::MultiGraphAclSession>(target_sess_);
|
||||
MS_EXCEPTION_IF_NULL(exec_sess);
|
||||
exec_sess->RunGraph(g, inputs, &outputs);
|
||||
return outputs;
|
||||
}
|
||||
|
||||
bool AclBackend::GetCond(const BaseRef &c, bool *value) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (!utils::isa<MSTensorRef>(c)) {
|
||||
MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef.";
|
||||
return false;
|
||||
}
|
||||
auto wrapper = utils::cast<MSTensorRef>(c);
|
||||
if (wrapper.GetTensor().DataType() != DataType::kNumberTypeBool) {
|
||||
MS_LOG(ERROR) << "Invalid data type " << wrapper.GetTensor().DataType() << " must be bool.";
|
||||
return false;
|
||||
}
|
||||
auto data = wrapper.GetTensor().Data();
|
||||
if (data == nullptr) {
|
||||
return false;
|
||||
}
|
||||
(*value) = *reinterpret_cast<const bool *>(data.get());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AclBackend::GetIndex(const BaseRef &c, int64_t *value) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (!utils::isa<MSTensorRef>(c)) {
|
||||
MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto wrapper = utils::cast<MSTensorRef>(c);
|
||||
if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt32) {
|
||||
auto data = wrapper.GetTensor().Data();
|
||||
if (data == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto value_int32 = *reinterpret_cast<const int32_t *>(data.get());
|
||||
(*value) = static_cast<int64_t>(value_int32);
|
||||
return true;
|
||||
} else if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt64) {
|
||||
auto data = wrapper.GetTensor().Data();
|
||||
if (data == nullptr) {
|
||||
return false;
|
||||
}
|
||||
(*value) = *reinterpret_cast<const int64_t *>(data.get());
|
||||
return true;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Index must be Int type.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
AclCompileGraph::AclCompileGraph(const std::shared_ptr<compile::MsBackend> &backend,
|
||||
const std::vector<PrimitivePtr> &cut_list)
|
||||
: CompileGraph(backend, cut_list) {}
|
||||
|
||||
void AclCompileGraph::AddInst(const compile::Instruction &inst, const MSTensorRef &arg) {
|
||||
VectorRef args;
|
||||
args.push_back(arg);
|
||||
compile::CompileGraph::AddInst(inst, args);
|
||||
}
|
||||
|
||||
int64_t AclCompileGraph::Ref(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
|
||||
if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
MS_LOG(DEBUG) << "Push graph.";
|
||||
compile::CompileGraph::AddInst(compile::Instruction::kGraph, GetValueNode(node));
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Push.";
|
||||
if (IsValueNode<Primitive>(node)) {
|
||||
MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||
} else if (IsValueNode<tensor::Tensor>(node)) {
|
||||
auto tensor_node = std::dynamic_pointer_cast<tensor::Tensor>(node->cast<ValueNodePtr>()->value());
|
||||
MS_EXCEPTION_IF_NULL(tensor_node);
|
||||
std::string name = "";
|
||||
std::vector<int64_t> shape = tensor_node->shape_c();
|
||||
DataType type = static_cast<DataType>(tensor_node->data_type_c());
|
||||
auto mstensor_node = MSTensor::CreateRefTensor(name, type, shape, tensor_node->data_c(), tensor_node->Size());
|
||||
MSTensorRef mstensor_ref(*mstensor_node);
|
||||
AddInst(compile::Instruction::kPush, mstensor_ref);
|
||||
MSTensor::DestroyTensorPtr(mstensor_node);
|
||||
} else {
|
||||
compile::CompileGraph::AddInst(compile::Instruction::kPush, GetValueNode(node));
|
||||
}
|
||||
}
|
||||
Push(node);
|
||||
} else if (auto const_parameter = dyn_cast<Parameter>(node);
|
||||
slots_.count(node) == 0 && const_parameter != nullptr && const_parameter->has_default()) {
|
||||
auto value = const_parameter->default_param();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (value->isa<tensor::Tensor>()) {
|
||||
auto tensor_node = std::dynamic_pointer_cast<tensor::Tensor>(value);
|
||||
MS_EXCEPTION_IF_NULL(tensor_node);
|
||||
std::vector<int64_t> shape = tensor_node->shape_c();
|
||||
DataType type = static_cast<DataType>(tensor_node->data_type_c());
|
||||
auto mstensor_node =
|
||||
MSTensor::CreateRefTensor(const_parameter->name(), type, shape, tensor_node->data_c(), tensor_node->Size());
|
||||
MSTensorRef mstensor_ref(*mstensor_node);
|
||||
AddInst(compile::Instruction::kPush, mstensor_ref);
|
||||
MSTensor::DestroyTensorPtr(mstensor_node);
|
||||
} else {
|
||||
compile::CompileGraph::AddInst(compile::Instruction::kPush, value);
|
||||
}
|
||||
Push(node);
|
||||
}
|
||||
MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
|
||||
<< ", return: " << slots_[node] - height_;
|
||||
return slots_[node] - height_;
|
||||
}
|
||||
|
||||
void AclCompileGraph::AddExternal(const compile::LinConvertResult &result) {
|
||||
VectorRef args;
|
||||
args.push_back(result.run);
|
||||
args.push_back(result.simu_run);
|
||||
size_t size = result.inputs.size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
const auto &input = result.inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
if (auto parameter = dyn_cast<Parameter>(input); parameter != nullptr && parameter->has_default()) {
|
||||
MS_LOG(DEBUG) << parameter->DebugString() << " has default value, will not be pushed as inputs.";
|
||||
continue;
|
||||
}
|
||||
if (IsMonadNode(input)) {
|
||||
MS_LOG(DEBUG) << input->DebugString() << " is monad node, will not be pushed as inputs.";
|
||||
continue;
|
||||
}
|
||||
args.emplace_back(Ref(input));
|
||||
}
|
||||
compile::CompileGraph::AddInst(compile::Instruction::kExternal, args);
|
||||
size_t out_count = 0;
|
||||
for (auto &out : result.outputs) {
|
||||
if (IsMonadNode(out)) {
|
||||
continue;
|
||||
}
|
||||
++out_count;
|
||||
Push(out);
|
||||
}
|
||||
MS_LOG(DEBUG) << "Args size " << args.size() << " out size " << out_count;
|
||||
}
|
||||
|
||||
void AclCompileGraph::AddInput(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsMonadNode(node)) {
|
||||
return;
|
||||
}
|
||||
if (slots_.count(node) == 0) {
|
||||
MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true);
|
||||
(void)Ref(node);
|
||||
return;
|
||||
}
|
||||
compile::CompileGraph::AddInst(compile::Instruction::kInput, Ref(node));
|
||||
set_height(height_ + 1);
|
||||
}
|
||||
|
||||
void AclCompileGraph::AddPartial(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto inputs = node->inputs();
|
||||
VectorRef args;
|
||||
if (inputs.size() <= 1) {
|
||||
MS_LOG(EXCEPTION) << "The node:" << node->DebugString() << "do not have two input.";
|
||||
}
|
||||
auto fn = inputs[1];
|
||||
if (!IsValueNode<FuncGraph>(fn)) {
|
||||
MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph";
|
||||
}
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (IsMonadNode(inputs[i])) {
|
||||
continue;
|
||||
}
|
||||
args.emplace_back(Ref(inputs[i]));
|
||||
}
|
||||
compile::CompileGraph::AddInst(compile::Instruction::kPartial, args);
|
||||
}
|
||||
|
||||
int64_t AclCompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto inputs = node->inputs();
|
||||
AnfNodePtr fn = inputs[0];
|
||||
(void)Ref(fn);
|
||||
size_t size = inputs.size();
|
||||
size_t non_monad_size = size;
|
||||
for (size_t i = size - 1; i > 0; --i) {
|
||||
if (IsMonadNode(inputs[i])) {
|
||||
--non_monad_size;
|
||||
continue;
|
||||
}
|
||||
AddInput(inputs[i]);
|
||||
}
|
||||
if (node == graph->output()) {
|
||||
AddTailCall(fn, non_monad_size);
|
||||
return RET_BREAK;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Call:" << Ref(fn) << ", " << height_ << ", " << (non_monad_size - 1);
|
||||
compile::CompileGraph::AddInst(compile::Instruction::kCall, Ref(fn));
|
||||
Ret(static_cast<int64_t>(non_monad_size - 1));
|
||||
for (size_t i = size - 1; i > 0; i--) {
|
||||
const auto iter = slots_.find(inputs[i]);
|
||||
if (iter != slots_.end() && iter->second >= height_) {
|
||||
slots_.erase(inputs[i]);
|
||||
}
|
||||
}
|
||||
return RET_SUCCESS;
|
||||
}
|
||||
|
||||
void AclCompileGraph::PushParameters(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> parameters = func_graph->parameters();
|
||||
for (size_t i = parameters.size(); i != 0; i--) {
|
||||
MS_EXCEPTION_IF_NULL(parameters[i - 1]);
|
||||
auto param = parameters[i - 1]->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
if (param->has_default()) {
|
||||
MS_LOG(DEBUG) << "Parameter " << (i - 1) << ": " << param->DebugString() << " has default value, skip.";
|
||||
continue;
|
||||
}
|
||||
if (IsMonadNode(param)) {
|
||||
MS_LOG(DEBUG) << "Parameter " << (i - 1) << ": " << param->DebugString() << " has monad type, skip.";
|
||||
continue;
|
||||
}
|
||||
Push(param);
|
||||
MS_LOG(DEBUG) << "Push parameter " << (i - 1) << ": " << param->DebugString(true);
|
||||
}
|
||||
}
|
||||
|
||||
AclCompileGraphs::AclCompileGraphs(const std::shared_ptr<compile::MsBackend> &backend,
|
||||
const std::vector<PrimitivePtr> &cut_list)
|
||||
: CompileGraphs(backend, cut_list) {
|
||||
MS_EXCEPTION_IF_NULL(backend);
|
||||
MS_LOG(DEBUG) << "Start vm: " << backend->name();
|
||||
transform_ = std::make_shared<AclCompileGraph>(backend, cut_list);
|
||||
Reset();
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_VM_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_VM_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "vm/transform.h"
|
||||
#include "vm/backend.h"
|
||||
#include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h"
|
||||
|
||||
namespace mindspore {
|
||||
class AclModelOptions;
|
||||
class AclBackend : public compile::MsBackend {
|
||||
public:
|
||||
AclBackend(const std::string &name, const std::string &target, const std::shared_ptr<AclModelOptions> &options);
|
||||
~AclBackend() override = default;
|
||||
|
||||
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) override;
|
||||
bool GetCond(const BaseRef &c, bool *value) override;
|
||||
bool GetIndex(const BaseRef &c, int64_t *value) override;
|
||||
};
|
||||
|
||||
class AclCompileGraph : public compile::CompileGraph {
|
||||
public:
|
||||
explicit AclCompileGraph(const std::shared_ptr<compile::MsBackend> &backend,
|
||||
const std::vector<PrimitivePtr> &cut_list);
|
||||
~AclCompileGraph() override = default;
|
||||
|
||||
int64_t Ref(const AnfNodePtr &node) override;
|
||||
void AddExternal(const compile::LinConvertResult &result) override;
|
||||
void AddInput(const AnfNodePtr &node) override;
|
||||
void AddPartial(const CNodePtr &node) override;
|
||||
int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node) override;
|
||||
void PushParameters(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
void AddInst(const compile::Instruction &inst, const MSTensorRef &arg);
|
||||
};
|
||||
|
||||
class AclCompileGraphs : public compile::CompileGraphs {
|
||||
public:
|
||||
explicit AclCompileGraphs(const std::shared_ptr<compile::MsBackend> &backend,
|
||||
const std::vector<PrimitivePtr> &cut_list);
|
||||
~AclCompileGraphs() override = default;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_VM_H
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h"
|
||||
#include <algorithm>
|
||||
|
||||
namespace mindspore {
|
||||
VectorRef MSTensorRef::Convert(const std::vector<MSTensor> &tensors) {
|
||||
VectorRef res;
|
||||
std::transform(tensors.begin(), tensors.end(), std::back_inserter(res),
|
||||
[](const MSTensor &t) { return MSTensorRef(t); });
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<MSTensor> MSTensorRef::Convert(const BaseRef &args) {
|
||||
std::vector<MSTensor> res;
|
||||
if (utils::isa<VectorRef>(args)) {
|
||||
VectorRef args_vec = utils::cast<VectorRef>(args);
|
||||
res = ConvertTuple(args_vec);
|
||||
} else if (utils::isa<MSTensorRef>(args)) {
|
||||
auto wrapper = utils::cast<MSTensorRef>(args);
|
||||
res.push_back(wrapper.ms_tensor_);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString() << " must be MSTensorRef or VectorRef{MSTensorRef...}";
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
std::shared_ptr<Base> MSTensorRef::copy() const {
|
||||
MSTensor *tensor = ms_tensor_.Clone();
|
||||
auto res = std::make_shared<MSTensorRef>(static_cast<const MSTensor &>(*tensor));
|
||||
MSTensor::DestroyTensorPtr(tensor);
|
||||
return res;
|
||||
}
|
||||
|
||||
bool MSTensorRef::operator==(const BaseRef &other) const {
|
||||
if (!utils::isa<MSTensorRef>(other)) {
|
||||
return false;
|
||||
}
|
||||
auto other_ms_tensor = utils::cast<MSTensorRef>(other).ms_tensor_;
|
||||
auto this_ms_tensor = ms_tensor_;
|
||||
return (this_ms_tensor.Name() == other_ms_tensor.Name()) && (this_ms_tensor.Shape() == other_ms_tensor.Shape()) &&
|
||||
(this_ms_tensor.MutableData() == other_ms_tensor.MutableData()) &&
|
||||
(this_ms_tensor.DataSize() == other_ms_tensor.DataSize()) &&
|
||||
(this_ms_tensor.DataType() == other_ms_tensor.DataType());
|
||||
}
|
||||
|
||||
std::vector<MSTensor> MSTensorRef::ConvertTuple(const VectorRef &args) {
|
||||
std::vector<MSTensor> outs;
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
const auto &item = args[i];
|
||||
if (utils::isa<VectorRef>(item)) {
|
||||
VectorRef args_vec = utils::cast<VectorRef>(args);
|
||||
auto ret = ConvertTuple(args_vec);
|
||||
outs.insert(outs.end(), ret.begin(), ret.end());
|
||||
} else if (utils::isa<MSTensorRef>(item)) {
|
||||
auto wrapper = utils::cast<MSTensorRef>(item);
|
||||
outs.push_back(wrapper.ms_tensor_);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString() << " must be MSTensorRef or VectorRef{MSTensorRef...}";
|
||||
}
|
||||
}
|
||||
return outs;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_ACL_VM_MS_TENSOR_REF_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_ACL_VM_MS_TENSOR_REF_H
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/api/types.h"
|
||||
#include "mindspore/core/base/base_ref.h"
|
||||
|
||||
namespace mindspore {
|
||||
class MSTensorRef : public BaseRef {
|
||||
public:
|
||||
MS_DECLARE_PARENT(MSTensorRef, BaseRef);
|
||||
|
||||
static VectorRef Convert(const std::vector<MSTensor> &tensors);
|
||||
static std::vector<MSTensor> Convert(const BaseRef &args);
|
||||
|
||||
explicit MSTensorRef(const MSTensor &tensor) : ms_tensor_(tensor) {}
|
||||
~MSTensorRef() override = default;
|
||||
|
||||
const MSTensor &GetTensor() const { return ms_tensor_; }
|
||||
std::shared_ptr<Base> copy() const override;
|
||||
|
||||
uint32_t type() const override { return tid(); }
|
||||
std::string ToString() const override { return ms_tensor_.Name(); }
|
||||
bool operator==(const BaseRef &other) const override;
|
||||
|
||||
private:
|
||||
static std::vector<MSTensor> ConvertTuple(const VectorRef &args);
|
||||
|
||||
MSTensor ms_tensor_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_ACL_VM_MS_TENSOR_REF_H
|
|
@ -900,7 +900,10 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
|||
return *this;
|
||||
}
|
||||
UpdateDataOpDesc(it, op);
|
||||
|
||||
if (HasAbstractMonad(it)) {
|
||||
MS_LOG(INFO) << it->DebugString() << " is a monad parameter, skip.";
|
||||
continue;
|
||||
}
|
||||
MS_LOG(INFO) << "add input " << it->ToString() << ", index " << index;
|
||||
(void)std::static_pointer_cast<Data>(op)->set_attr_index(index++);
|
||||
inputs.push_back(*op);
|
||||
|
@ -945,13 +948,17 @@ void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr
|
|||
MS_LOG(ERROR) << "Update data op descriptor failed! Invalid node.";
|
||||
return;
|
||||
}
|
||||
auto normal_shape_ptr = dyn_cast<abstract::Shape>(node->Shape());
|
||||
|
||||
std::vector<int64_t> shape;
|
||||
if (normal_shape_ptr == nullptr) {
|
||||
if (auto normal_shape_ptr = dyn_cast<abstract::Shape>(node->Shape()); normal_shape_ptr != nullptr) {
|
||||
shape = normal_shape_ptr->shape();
|
||||
} else if (auto no_shape_ptr = dyn_cast<abstract::NoShape>(node->Shape()); no_shape_ptr != nullptr) {
|
||||
shape = {};
|
||||
} else {
|
||||
MS_LOG(INFO) << "Invalid shape to update data op descriptor.";
|
||||
return;
|
||||
}
|
||||
shape = normal_shape_ptr->shape();
|
||||
|
||||
if (node->Type() == nullptr) {
|
||||
MS_LOG(INFO) << "Invalid type to update data op descriptor.";
|
||||
return;
|
||||
|
@ -1887,7 +1894,6 @@ OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) {
|
|||
auto ge_tensor = const_op->get_attr_value();
|
||||
auto ge_desc = ge_tensor.GetTensorDesc();
|
||||
(void)const_op->update_output_desc_y(ge_desc);
|
||||
|
||||
op_cache_[node.get()] = op;
|
||||
return op_cache_[node.get()];
|
||||
}
|
||||
|
|
|
@ -145,22 +145,6 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
|
|||
}
|
||||
}
|
||||
|
||||
void CompileGraph::PushInputs(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> parameters = graph->parameters();
|
||||
for (size_t i = parameters.size(); i != 0; i--) {
|
||||
MS_EXCEPTION_IF_NULL(parameters[i - 1]);
|
||||
auto param = parameters[i - 1]->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
if (param->has_default()) {
|
||||
MS_LOG(DEBUG) << "Parameter " << (i - 1) << ": " << param->DebugString() << " has default value, skip.";
|
||||
continue;
|
||||
}
|
||||
Push(param);
|
||||
MS_LOG(DEBUG) << "Push parameter " << (i - 1) << ": " << param->DebugString(true);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) {
|
||||
MS_EXCEPTION_IF_NULL(segment);
|
||||
MS_LOG(DEBUG) << "LinConvert start";
|
||||
|
@ -273,15 +257,11 @@ bool CompileGraph::Compile(const FuncGraphPtr &graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
InstSet CompileGraph::Run(const FuncGraphPtr &graph, bool push_weight) {
|
||||
InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
||||
Reset();
|
||||
if (push_weight) {
|
||||
PushParameters(graph);
|
||||
} else {
|
||||
PushInputs(graph);
|
||||
}
|
||||
PushParameters(graph);
|
||||
|
||||
int64_t param_height = height_;
|
||||
MS_EXCEPTION_IF_NULL(graph->get_return());
|
||||
|
|
|
@ -55,7 +55,7 @@ class CompileGraph {
|
|||
|
||||
virtual ~CompileGraph() = default;
|
||||
|
||||
InstSet Run(const FuncGraphPtr &func_graph, bool push_weight = true);
|
||||
InstSet Run(const FuncGraphPtr &func_graph);
|
||||
bool IsCut(const AnfNodePtr &node);
|
||||
void Push(const AnfNodePtr &node);
|
||||
void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
|
||||
|
@ -77,22 +77,21 @@ class CompileGraph {
|
|||
}
|
||||
|
||||
protected:
|
||||
void PushParameters(const FuncGraphPtr &func_graph);
|
||||
void PushInputs(const FuncGraphPtr &graph);
|
||||
virtual void PushParameters(const FuncGraphPtr &func_graph);
|
||||
bool Compile(const FuncGraphPtr &func_graph);
|
||||
int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = "");
|
||||
int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
|
||||
int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
|
||||
virtual int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
|
||||
void AddPadStack(int64_t param_height);
|
||||
void AddTailCall(const AnfNodePtr &fn, size_t size);
|
||||
void AddPartial(const CNodePtr &node);
|
||||
virtual void AddPartial(const CNodePtr &node);
|
||||
void AddMakeTuple(const CNodePtr &node);
|
||||
void AddSwitch(const CNodePtr &node);
|
||||
void AddSwitchLayer(const CNodePtr &node);
|
||||
void AddReturn(const CNodePtr &node);
|
||||
void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim);
|
||||
void AddInput(const AnfNodePtr &node);
|
||||
void AddExternal(const LinConvertResult &result);
|
||||
virtual void AddInput(const AnfNodePtr &node);
|
||||
virtual void AddExternal(const LinConvertResult &result);
|
||||
void AddInst(const Instruction &inst, const int64_t &arg);
|
||||
void AddInst(const Instruction &inst, const ValuePtr &arg);
|
||||
void AddInst(const Instruction &inst, const VectorRef &args);
|
||||
|
@ -122,7 +121,7 @@ class CompileGraphs {
|
|||
mapping_.clear();
|
||||
}
|
||||
|
||||
virtual void Compile(const FuncGraphPtr &func_graph);
|
||||
void Compile(const FuncGraphPtr &func_graph);
|
||||
FinalVMPtr Link();
|
||||
FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
|
||||
|
||||
|
|
|
@ -89,7 +89,7 @@ void FinalVM::Pop(int64_t n) {
|
|||
|
||||
void FinalVM::MoveStack(int64_t nitems, int64_t height) {
|
||||
if (nitems > height || height > sp_) {
|
||||
MS_LOG(EXCEPTION) << "MoveStack arg error: nitems=" << nitems << " height=" << height;
|
||||
MS_LOG(EXCEPTION) << "MoveStack arg error: nitems=" << nitems << " height=" << height << " sp=" << sp_;
|
||||
}
|
||||
int64_t n = height - nitems;
|
||||
int64_t src = sp_ - height;
|
||||
|
|
Loading…
Reference in New Issue