!18362 Link value node for control flow and organize control node parser.
Merge pull request !18362 from gaoyong10/new_runtime5
commit b324a49078

@@ -94,8 +94,8 @@ void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph,
     if (parameters[i]->isa<Parameter>()) {
       // Input node is a parameter from host data source actor.
       std::vector<AnfNodePtr> invalid_inputs;
-      std::vector<AnfNodePtr> front_inputs = ControlNodeParser::FetchInputNodeByParameter(
-        parameters[i], origin_parameters_order, &invalid_inputs, func_graph_to_parameters);
+      std::vector<AnfNodePtr> front_inputs =
+        FetchInputNodeByParameter(parameters[i], origin_parameters_order, &invalid_inputs, func_graph_to_parameters);
 
       for (const auto &front_input : front_inputs) {
         const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(front_input, 0);

@@ -34,6 +34,8 @@
 namespace mindspore {
 namespace runtime {
+constexpr size_t kReturnInputPos = 1;
+
 // Gather actor is the entrance of sub funcgraph. Graph input is sent to it and sent to other actors by gather actor.
 class GatherActor : public OpActor<DeviceTensor> {
  public:

@@ -15,6 +15,7 @@
  */
 
 #include "runtime/framework/actor/switch_actor.h"
+#include "runtime/framework/actor/output_actor.h"
 #include "runtime/framework/actor/memory_manager_actor.h"
 #include "mindrt/include/async/async.h"
 #include "abstract/utils.h"

@@ -57,7 +58,6 @@ void SwitchActor::Initialize() {
   } else {
     InitSwitchLayer();
   }
-  input_datas_num_ = input_nodes_.size();
 }
 
 void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {

@@ -68,8 +68,18 @@ void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
   // [0] ValueNode<Primitive> kPartial.
   // [1] ValueNode<FuncGraphPtr>.
   // [2..] Inputs.
-  auto node_inputs = cnode->inputs();
-  branch_func_graph_[branch_id] = GetValueNode<FuncGraphPtr>(node_inputs[kPartialFuncGraphPos]);
+  const auto &node_inputs = cnode->inputs();
+  if (node_inputs.size() <= kPartialFuncGraphPos) {
+    MS_LOG(EXCEPTION) << "Invalid Partial node:" << AnfAlgo::GetNodeDebugString(cnode);
+  }
+
+  const auto &func_graph = GetValueNode<FuncGraphPtr>(node_inputs[kPartialFuncGraphPos]);
+  if (func_graph->output()->isa<ValueNode>()) {
+    AddInput(func_graph->output(), branch_id);
+    return;
+  }
+
+  branch_func_graph_[branch_id] = func_graph;
   for (size_t j = kPartialInputStartPos; j < node_inputs.size(); ++j) {
     AddInput(node_inputs[j], branch_id);
   }

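For orientation, a standalone sketch of the Partial-node layout that InitPartial walks above. The index values are inferred from this diff (GetFuncGraphFromPartial reads partial_inputs[1], and the real arguments start right after), so treat them as assumptions rather than the project's own definitions:

#include <cstddef>

// A Partial CNode as described by the comments in InitPartial:
//   inputs[0]   ValueNode<Primitive> kPartial
//   inputs[1]   ValueNode<FuncGraphPtr>  -> kPartialFuncGraphPos
//   inputs[2..] real arguments           -> kPartialInputStartPos
constexpr std::size_t kPartialFuncGraphPos = 1;
constexpr std::size_t kPartialInputStartPos = 2;
static_assert(kPartialInputStartPos == kPartialFuncGraphPos + 1, "arguments follow the funcgraph slot");

int main() { return 0; }
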
@@ -93,8 +103,11 @@ void SwitchActor::InitSwitch() {
   branch_inputs_pos_.resize(kSwitchPartialNum);
   branch_func_graph_.resize(kSwitchPartialNum);
   output_branch_arrows_.resize(kSwitchPartialNum);
-  input_nodes_.push_back(inputs[kSwitchCondPos]);
+  output_branch_result_arrows_.resize(kSwitchPartialNum);
+  output_branch_control_arrows_.resize(kSwitchPartialNum);
 
+  input_nodes_.push_back(inputs[kSwitchCondPos]);
+  input_datas_num_++;
   // Init the two branches of switch node.
   InitPartial(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
   InitPartial(inputs[kSwitchTrueBranchPos], static_cast<size_t>(true));

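The two InitPartial calls above rely on static_cast<size_t>(bool) mapping false to 0 and true to 1, so the false branch always occupies slot 0. A minimal self-contained illustration (plain strings stand in for the branch data, by assumption):

#include <cstddef>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> branch_func_graph(2);
  // Mirrors InitPartial(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false)) etc.
  branch_func_graph[static_cast<std::size_t>(false)] = "false-branch funcgraph";  // slot 0
  branch_func_graph[static_cast<std::size_t>(true)] = "true-branch funcgraph";    // slot 1
  return branch_func_graph[0] == "false-branch funcgraph" ? 0 : 1;
}
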
@@ -118,6 +131,8 @@ void SwitchActor::InitSwitchLayer() {
   branch_inputs_pos_.resize(branch_nodes.size() - 1);
   branch_func_graph_.resize(branch_nodes.size() - 1);
   output_branch_arrows_.resize(branch_nodes.size() - 1);
+  output_branch_result_arrows_.resize(branch_nodes.size() - 1);
+  output_branch_control_arrows_.resize(branch_nodes.size() - 1);
 
   // Parse all branches.
   for (size_t i = 1; i < branch_nodes.size(); ++i) {

@@ -147,14 +162,23 @@ size_t SwitchActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
 
 void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
+  branch_total_inputs_[branch].push_back(node);
+
+  if (node->isa<ValueNode>() && (!HasAbstractMonad(node))) {
+    device_tensor_store_keys_.push_back({input_nodes_.size(), node.get()});
+    branch_inputs_pos_[branch].push_back(input_nodes_.size());
+    input_nodes_.push_back(node);
+    return;
+  }
+
   // Switch actor only receives parameter, updatestate node output is U, need to be skipped.
   if (IsPersistentDeviceTensor(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
     return;
   }
 
   auto iter = find(input_nodes_.begin(), input_nodes_.end(), node);
   if (iter == input_nodes_.end()) {
     branch_inputs_pos_[branch].push_back(input_nodes_.size());
     input_nodes_.push_back(node);
+    ++input_datas_num_;
   } else {
     branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
   }

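A reduced model of the slot-sharing scheme AddInput implements above: a node used by both branches occupies a single slot in input_nodes_, and each branch records only the position. Strings stand in for AnfNodePtr, by assumption:

#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> input_nodes;
  std::vector<std::vector<std::size_t>> branch_inputs_pos(2);
  auto add_input = [&](const std::string &node, std::size_t branch) {
    auto iter = std::find(input_nodes.begin(), input_nodes.end(), node);
    if (iter == input_nodes.end()) {
      branch_inputs_pos[branch].push_back(input_nodes.size());
      input_nodes.push_back(node);  // new slot, as in the ++input_datas_num_ path
    } else {
      branch_inputs_pos[branch].push_back(iter - input_nodes.begin());  // reuse slot
    }
  };
  add_input("x", 0);
  add_input("x", 1);  // the second branch reuses slot 0 rather than allocating one
  return (input_nodes.size() == 1 && branch_inputs_pos[1][0] == 0) ? 0 : 1;
}
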
@@ -208,8 +232,7 @@ bool SwitchActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
 
 void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
-
-  input_device_tensors_.resize(input_datas_num_);
+  input_device_tensors_.resize(input_nodes_.size());
   auto data_iter = input_op_datas_.find(context->sequential_num_);
   if (data_iter != input_op_datas_.end()) {
     for (auto &input_data : data_iter->second) {

@@ -218,6 +241,18 @@ void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
     }
   }
   data_iter->second.clear();
+
+  for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
+    auto device_tensor =
+      DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
+    if (device_tensor == nullptr) {
+      std::string error_info =
+        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
+        ", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
+      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
+    }
+    input_device_tensors_[device_tensor_store_key.first] = device_tensor;
+  }
 }
 
 void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {

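FetchInputDeviceTensor now fills two kinds of slots: those fed by incoming messages and those fed from the device tensor store via the (position, key) pairs that AddInput records. A reduced, self-contained model of that second path (a std::map stands in for DeviceTensorStore, by assumption):

#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Pairs recorded at AddInput time: (slot in input_device_tensors_, store key).
  std::vector<std::pair<std::size_t, std::string>> device_tensor_store_keys = {{1, "weight"}};
  std::map<std::string, int> store = {{"weight", 42}};  // stand-in for DeviceTensorStore
  std::vector<const int *> input_device_tensors(2, nullptr);
  for (const auto &key : device_tensor_store_keys) {
    auto iter = store.find(key.second);
    if (iter == store.end()) {
      return 1;  // corresponds to SET_OPCONTEXT_FAIL_RET_WITH_ERROR in the diff
    }
    input_device_tensors[key.first] = &iter->second;
  }
  return input_device_tensors[1] != nullptr ? 0 : 1;
}
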
@@ -237,6 +272,28 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
     data->data_ = input_device_tensors_[data_arrow->from_output_index_];
     Async(data_arrow->to_op_id_, &OpActor::RunOpData, data.get(), context);
   }
+
+  // Send result.
+  auto &output_branch_result_arrow = output_branch_result_arrows_[index];
+  for (size_t i = 0; i < output_branch_result_arrow.size(); ++i) {
+    auto &result_arrow = output_branch_result_arrow[i];
+    MS_EXCEPTION_IF_NULL(result_arrow);
+    size_t from_index = result_arrow->from_output_index_;
+    for (const auto &backend_node : front_to_backend_parameter_[from_index]) {
+      if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second).get() ==
+          input_device_tensors_[from_index]) {
+        Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, backend_node.second,
+              result_arrow->to_input_index_, context);
+        break;
+      }
+    }
+  }
+
+  // Send output control.
+  auto source_aid = const_cast<AID *>(&GetAID());
+  for (auto &output_control : output_branch_control_arrows_[index]) {
+    Async(output_control, &OpActor::RunOpControl, source_aid, context);
+  }
 }
 
 void SwitchActor::EraseInput(OpContext<DeviceTensor> *context) {

@@ -262,5 +319,37 @@ void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
   Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensors_, device_context_, context);
 }
 
+void SwitchActor::FetchInputNode(const std::vector<AnfNodePtr> &origin_parameters_order,
+                                 const FrontToBackendNodeWithContext &front_to_backend_parameters,
+                                 const std::unordered_map<AnfNodePtr, AnfNodePtr> &front_to_backend_kernel) {
+  front_to_backend_parameter_.resize(input_nodes_.size());
+
+  for (size_t i = 0; i < input_nodes_.size(); ++i) {
+    const auto &input_node = input_nodes_[i];
+    if (input_node->isa<ValueNode>()) {
+      front_to_backend_parameter_[i].push_back({input_node, 0});
+    } else if (input_node->isa<Parameter>()) {
+      if (front_to_backend_parameters.find(input_node) != front_to_backend_parameters.end()) {
+        const auto backend_node = front_to_backend_parameters.at(input_node).first;
+        front_to_backend_parameter_[i].push_back({backend_node, 0});
+      }
+    } else if (input_node->isa<CNode>()) {
+      if (IsCallNode(input_node)) {
+        const auto func_graphs = FetchFuncGraphbyCallNode(input_node->cast<CNodePtr>());
+        for (const auto func_graph : func_graphs) {
+          if (func_graph->output()->isa<ValueNode>()) {
+            front_to_backend_parameter_[i].push_back({func_graph->output(), 0});
+          }
+        }
+      } else {
+        const auto &kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
+        if (front_to_backend_kernel.find(input_node) != front_to_backend_kernel.end()) {
+          front_to_backend_parameter_[i].emplace_back(kernel_with_index);
+        }
+      }
+    }
+  }
+}
+
 } // namespace runtime
 } // namespace mindspore

@@ -24,6 +24,7 @@
 #include <unordered_map>
 #include "runtime/framework/actor/actor_common.h"
 #include "runtime/framework/device_tensor_store.h"
+#include "runtime/framework/control_node_parser.h"
 #include "mindrt/include/actor/switch_actor.h"
 #include "runtime/hardware/device_context.h"
 

@@ -56,7 +57,8 @@ constexpr size_t kMakeTupleInputStartPos = 1;
 // 5. Free Memory
 class SwitchActor : public SwitchActorBase<DeviceTensor> {
  public:
-  SwitchActor(const std::string &name, const CNodePtr &node) : SwitchActorBase(name), node_(node) {}
+  SwitchActor(const std::string &name, DeviceContext *device_context, const CNodePtr &node)
+      : SwitchActorBase(name), device_context_(device_context), node_(node) {}
   ~SwitchActor() override = default;
 
   void Init() override;

@@ -91,19 +93,35 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
   void EraseInput(OpContext<DeviceTensor> *context);
   void SendMemoryFreeReq(OpContext<DeviceTensor> *context);
 
+  // Collect all the backend inputs of switch actor.
+  void FetchInputNode(const std::vector<AnfNodePtr> &origin_parameters_order,
+                      const FrontToBackendNodeWithContext &front_to_backend_parameters,
+                      const std::unordered_map<AnfNodePtr, AnfNodePtr> &front_to_backend_kernel);
   // All inputs of the switch actor, excluding weight and tensor.
   // Used to receive input data, the first input is the condition of switch.
   std::vector<AnfNodePtr> input_nodes_;
   // The position of the branch output in the input_nodes_.
   std::vector<std::vector<size_t>> branch_inputs_pos_;
 
+  // Control arrows of different branches.
+  std::vector<std::vector<AID>> output_branch_control_arrows_;
+  // Result arrows of different branches.
+  std::vector<std::vector<DataArrowPtr>> output_branch_result_arrows_;
+
+  // When the output is a value node from switch actor, the actor needs to send the anfnode to the output actor,
+  // so all the nodes that may send the device tensor to switch actor are recorded.
+  std::vector<std::vector<KernelWithIndex>> front_to_backend_parameter_;
+
+  std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;
   std::vector<FuncGraphPtr> branch_func_graph_;
 
+  // Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
+  std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
+
   std::vector<DeviceTensor *> input_device_tensors_;
 
   // Save the DeviceContext of input_nodes_, which is used to release the DeviceTensor.
-  DeviceContext *device_context_;
+  const DeviceContext *device_context_;
 
   // The id of memory manager actor. Send message to it for alloc and free memory.
   const AID memory_manager_aid_;

@@ -16,6 +16,7 @@
 
 #include "runtime/framework/control_node_parser.h"
+#include "runtime/framework/actor/switch_actor.h"
 #include "ir/tensor.h"
 
 namespace mindspore {
 namespace runtime {

@@ -43,6 +44,12 @@ bool CheckValidFuncGraphInput(const AnfNodePtr &node) {
   return (!IsPersistentDeviceTensor(node)) && (!HasAbstractMonad(node));
 }
 
+// Get the funcgraph in partial node.
+FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node) {
+  const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
+  return GetValueNode<FuncGraphPtr>(partial_inputs[1]);
+}
+
 // Get the corresponding relationship between funcgraph and parameters in the switch node.
 void FetchParameterBySwitchNode(const AnfNodePtr &switch_node, FuncGraphToParameter *graph_to_real_parameters) {
   const auto &switch_cnode = switch_node->cast<CNodePtr>();

@@ -53,7 +60,7 @@ void FetchParameterBySwitchNode(const AnfNodePtr &switch_node, FuncGraphToParame
 
   for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
     const auto &partial_node = switch_inputs[i];
-    const auto &func_graph = ControlNodeParser::GetFuncGraphFromPartial(partial_node);
+    const auto &func_graph = GetFuncGraphFromPartial(partial_node);
     std::vector<AnfNodePtr> parameters;
     const auto &partial_inputs = partial_node->cast<CNodePtr>()->inputs();
     for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {

@@ -78,7 +85,7 @@ void FetchParameterBySwitchLayerNode(const AnfNodePtr &switch_layer_node, const
   auto tuple_inputs = switch_layer_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
   for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
     if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) {
-      const auto &func_graph = ControlNodeParser::GetFuncGraphFromPartial(tuple_inputs[i]);
+      const auto &func_graph = GetFuncGraphFromPartial(tuple_inputs[i]);
       std::vector<AnfNodePtr> parameters;
       const auto &partial_inputs = tuple_inputs[i]->cast<CNodePtr>()->inputs();
       for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {

@@ -104,9 +111,46 @@ void FetchParameterBySwitchLayerNode(const AnfNodePtr &switch_layer_node, const
     }
   }
 }
+
+// Create a device tensor for the front node.
+// Get the output format and select kernel build info from the backend node corresponding to the front node to
+// create the device address.
+void CreateDeviceTensorForValueNode(const AnfNodePtr &front_node, const AnfNodePtr &backend_node,
+                                    const DeviceContext *device_context) {
+  MS_EXCEPTION_IF_NULL(device_context);
+
+  const auto &node_value = front_node->cast<ValueNodePtr>()->value();
+  if (!node_value->isa<tensor::Tensor>()) {
+    return;
+  }
+
+  size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(backend_node, 0);
+  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(backend_node, 0);
+  if (output_type_id == kTypeUnknown) {
+    output_type_id = AnfAlgo::GetOutputInferDataType(backend_node, 0);
+  }
+
+  if (front_node->kernel_info() == nullptr) {
+    front_node->set_kernel_info(std::make_shared<device::KernelInfo>());
+  }
+
+  // Get the select kernel build info.
+  auto kernel_info = static_cast<device::KernelInfo *>(backend_node->kernel_info());
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  auto build_info = kernel_info->GetMutableSelectKernelBuildInfo();
+  MS_EXCEPTION_IF_NULL(build_info);
+  AnfAlgo::SetSelectKernelBuildInfo(build_info, front_node.get());
+
+  // Create device tensor.
+  std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0);
+  device::DeviceAddressPtr address =
+    device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
+  MS_EXCEPTION_IF_NULL(address);
+  AnfAlgo::SetOutputAddr(address, 0, front_node.get());
+}
 }  // namespace
 
-bool ControlNodeParser::IsCallNode(const AnfNodePtr &node) {
+bool IsCallNode(const AnfNodePtr &node) {
   if (!node->isa<CNode>()) {
     return false;
   }

@@ -115,12 +159,7 @@ bool IsCallNode(const AnfNodePtr &node) {
   return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
 }
 
-FuncGraphPtr ControlNodeParser::GetFuncGraphFromPartial(const AnfNodePtr &node) {
-  const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
-  return GetValueNode<FuncGraphPtr>(partial_inputs[1]);
-}
-
-std::vector<FuncGraphPtr> ControlNodeParser::FetchFuncGraphbyCallNode(const CNodePtr &node) {
+std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const CNodePtr &node) {
   std::vector<FuncGraphPtr> func_graphs;
   const auto &call_inputs = node->inputs();
 

@@ -155,6 +194,182 @@ std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const CNodePtr &node) {
   return func_graphs;
 }
 
+FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node) {
+  auto front_node = GetFrontNodeByBackendNode(node);
+
+  // If the front node is nullptr, we can check its inputs.
+  if (front_node == nullptr) {
+    if (node->isa<CNode>()) {
+      const auto &cnode = node->cast<CNodePtr>();
+      const auto &inputs = cnode->inputs();
+      for (size_t i = 1; i < inputs.size(); ++i) {
+        const auto &func_graph = FetchFuncGraphByNode(inputs[i]);
+        if (func_graph != nullptr) {
+          return func_graph;
+        }
+      }
+    } else {
+      return nullptr;
+    }
+  }
+
+  const auto &func_graph = front_node->func_graph();
+  return func_graph;
+}
+
+AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node) {
+  if (backend_node->func_graph() == nullptr) {
+    return nullptr;
+  }
+  auto kernel_graph = dynamic_cast<KernelGraph *>(backend_node->func_graph().get());
+  if (kernel_graph == nullptr) {
+    return nullptr;
+  }
+  return kernel_graph->GetFrontAnfByBackendAnf(backend_node);
+}
+
+FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) {
+  auto front_node = GetFrontNodeByBackendNode(backend_node);
+  if (front_node == nullptr) {
+    return nullptr;
+  }
+  return front_node->func_graph();
+}
+
+std::vector<AnfNodePtr> FetchInputNodeByParameter(const AnfNodePtr &parameter,
+                                                  const std::vector<AnfNodePtr> &host_ds_parameters,
+                                                  std::vector<AnfNodePtr> *invalid_inputs,
+                                                  const FuncGraphToParameter &graph_to_real_parameters) {
+  std::vector<AnfNodePtr> input_nodes;
+
+  // If the node has been collected, skip it.
+  if (find((*invalid_inputs).begin(), (*invalid_inputs).end(), parameter) != (*invalid_inputs).end()) {
+    return input_nodes;
+  }
+
+  // Record the node which has been collected.
+  (*invalid_inputs).emplace_back(parameter);
+
+  // If the parameter node is a parameter of host data source actor, return it.
+  if (find(host_ds_parameters.begin(), host_ds_parameters.end(), parameter) != host_ds_parameters.end()) {
+    input_nodes.emplace_back(parameter);
+    return input_nodes;
+  }
+
+  // Check the parameter which send to its funcgraph.
+  const auto &func_graph = parameter->func_graph();
+  if (graph_to_real_parameters.find(func_graph) == graph_to_real_parameters.end()) {
+    return input_nodes;
+  }
+
+  std::vector<AnfNodePtr> self_inputs;
+  for (const auto &input : func_graph->get_inputs()) {
+    // Monad input need not send to funcgraph.
+    if (!HasAbstractMonad(input)) {
+      self_inputs.emplace_back(input);
+    }
+  }
+
+  size_t pos = find(self_inputs.begin(), self_inputs.end(), parameter) - self_inputs.begin();
+  for (const auto parameters : graph_to_real_parameters.at(func_graph)) {
+    const auto input = parameters[pos];
+    if (input->isa<CNode>()) {
+      input_nodes.emplace_back(input);
+    } else if (input->isa<Parameter>()) {
+      // If input is a parameter, you need to find its input recursively.
+      auto inputs = FetchInputNodeByParameter(input, host_ds_parameters, invalid_inputs, graph_to_real_parameters);
+      input_nodes.insert(input_nodes.end(), inputs.begin(), inputs.end());
+    }
+  }
+  return input_nodes;
+}
+
+void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
+                              const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph) {
+  FetchFrontToBackendParameterMap(graphs, device_contexts, control_nodes);
+
+  FetchFuncGraphToParameterMap(control_nodes);
+
+  FetchHostParameterToWeightMap(control_nodes);
+
+  FetchFrontValueNode(graphs, device_contexts);
+
+  // Get inputs of control node which come from the host actor.
+  control_node_parameters_ = FetchControlNodeParameter(control_nodes);
+
+  front_output_nodes_ = FetchAllBranchOutputs(root_graph);
+}
+
+void ControlNodeParser::FetchValueNodeInSwitchNode(const AnfNodePtr &switch_node,
+                                                   std::vector<AnfNodePtr> *value_nodes) {
+  const auto &cnode = switch_node->cast<CNodePtr>();
+  const auto &inputs = cnode->inputs();
+  if (inputs.size() != kSwitchInputNum) {
+    MS_LOG(EXCEPTION) << "Invalid switch node input num:" << inputs.size();
+  }
+
+  for (const auto &input : inputs) {
+    if (input->isa<ValueNode>()) {
+      const auto &node_value = input->cast<ValueNodePtr>()->value();
+      if (node_value->isa<tensor::Tensor>()) {
+        (*value_nodes).emplace_back(input);
+      }
+    } else if (IsCallNode(input)) {
+      // If input is a call not, should check the switch node in its input.
+      const auto &call_node = input->cast<CNodePtr>();
+      const auto &call_inputs = call_node->inputs();
+      if (call_inputs.empty() || (!AnfAlgo::CheckPrimitiveType(call_inputs[0], prim::kPrimSwitch))) {
+        continue;
+      }
+      FetchValueNodeInSwitchNode(call_inputs[0], value_nodes);
+    } else if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimPartial)) {
+      const auto &partial_node = input->cast<CNodePtr>();
+      const auto &partial_inputs = partial_node->inputs();
+      if (partial_inputs.size() <= kPartialFuncGraphPos) {
+        MS_LOG(EXCEPTION) << "Invalid partial node input num:" << partial_inputs.size();
+      }
+
+      // if input is a partial node, get the value node in its funcgraph.
+      const auto &func_graph = GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
+      if (func_graph->output()->isa<ValueNode>()) {
+        (*value_nodes).emplace_back(func_graph->output());
+      }
+    }
+  }
+}
+
+void ControlNodeParser::FetchFrontValueNode(const std::vector<KernelGraphPtr> &graphs,
+                                            const std::vector<DeviceContext *> &device_contexts) {
+  for (size_t index = 0; index < graphs.size(); ++index) {
+    const auto &graph = graphs[index];
+    MS_EXCEPTION_IF_NULL(graph);
+    auto execution_order = graph->execution_order();
+
+    for (const auto &parameter : graph->input_nodes()) {
+      const auto &front_node = graph->GetFrontAnfByBackendAnf(parameter);
+      const auto &internal_node = graph->GetFrontNodeByInternalParameter(parameter);
+
+      MS_EXCEPTION_IF_NULL(parameter);
+      if (IsInternalParameter(parameter, graph)) {
+        auto front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter);
+        MS_EXCEPTION_IF_NULL(front_node_with_index.first);
+        const auto &front_output_with_index =
+          AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false);
+        auto front_output_node = front_output_with_index.first;
+        MS_EXCEPTION_IF_NULL(front_output_node);
+        if (AnfAlgo::CheckPrimitiveType(front_output_node, prim::kPrimSwitch)) {
+          std::vector<AnfNodePtr> value_nodes;
+          FetchValueNodeInSwitchNode(front_output_node, &value_nodes);
+          for (const auto value_node : value_nodes) {
+            CreateDeviceTensorForValueNode(value_node, parameter, device_contexts[index]);
+            front_value_nodes_.push_back({value_node, device_contexts[index]});
+          }
+        }
+      }
+    }
+  }
+}
+
 AnfNodePtr ControlNodeParser::GetCallNodeInputByPos(const AnfNodePtr &call_node, const FuncGraphPtr &func_graph,
                                                     const size_t pos) {
   const auto &cnode = call_node->cast<CNodePtr>();

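FetchInputNodeByParameter above recurses through funcgraph parameters, and the invalid_inputs list is what keeps mutually-fed parameters from looping forever. A reduced model of that guard (strings stand in for AnfNodePtr, by assumption):

#include <algorithm>
#include <string>
#include <vector>

// First visit records the parameter; any later visit returns empty immediately.
std::vector<std::string> FetchInputs(const std::string &parameter, std::vector<std::string> *invalid_inputs) {
  std::vector<std::string> input_nodes;
  if (std::find(invalid_inputs->begin(), invalid_inputs->end(), parameter) != invalid_inputs->end()) {
    return input_nodes;  // already collected: the recursion stops here
  }
  invalid_inputs->push_back(parameter);
  // The real code now maps the parameter to the arguments each call site passes,
  // recursing for arguments that are themselves parameters.
  return input_nodes;
}

int main() {
  std::vector<std::string> invalid_inputs;
  FetchInputs("p", &invalid_inputs);
  return FetchInputs("p", &invalid_inputs).empty() ? 0 : 1;  // second visit is a no-op
}
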
@@ -401,8 +616,7 @@ std::vector<AnfNodePtr> ControlNodeParser::FetchAllBranchOutputs(const FuncGraph
 
 void ControlNodeParser::FetchFrontToBackendParameterMap(const std::vector<KernelGraphPtr> &graphs,
                                                         const std::vector<DeviceContext *> &device_contexts,
-                                                        const std::vector<AnfNodePtr> &control_nodes,
-                                                        FrontToBackendNodeWithContext *front_to_backend_parameter) {
+                                                        const std::vector<AnfNodePtr> &control_nodes) {
   if (graphs.size() != device_contexts.size()) {
     MS_LOG(EXCEPTION) << "Graph num is not equal to device context num.";
   }

@@ -414,8 +628,8 @@ void ControlNodeParser::FetchFrontToBackendParameterMap(const std::vector<Kernel
     for (const auto &parameter : graph->parameters()) {
       auto front_node = graph->GetFrontAnfByBackendAnf(parameter);
       if (front_node != nullptr && front_node->isa<Parameter>() &&
-          (*front_to_backend_parameter).find(front_node) == (*front_to_backend_parameter).end()) {
-        (*front_to_backend_parameter)[front_node] = {parameter, device_context};
+          front_to_backend_parameters_.find(front_node) == front_to_backend_parameters_.end()) {
+        front_to_backend_parameters_[front_node] = {parameter, device_context};
       }
     }
   }

@@ -431,15 +645,14 @@ void ControlNodeParser::FetchFrontToBackendParameterMap(const std::vector<Kernel
   for (const auto &front_pair : front_to_front_parameter) {
     std::set<AnfNodePtr> invalid_node;
     const auto &backend_node = FetchBackendNodeByFrontNode(front_pair.first, front_to_front_parameter,
-                                                           *front_to_backend_parameter, &invalid_node);
+                                                           front_to_backend_parameters_, &invalid_node);
     if (backend_node.first != nullptr) {
-      (*front_to_backend_parameter)[front_pair.first] = backend_node;
+      front_to_backend_parameters_[front_pair.first] = backend_node;
     }
   }
 }
 
-void ControlNodeParser::FetchHostParameterToWeightMap(const std::vector<AnfNodePtr> &control_nodes,
-                                                      HostParameterToWeight *host_parameter_to_weights) {
+void ControlNodeParser::FetchHostParameterToWeightMap(const std::vector<AnfNodePtr> &control_nodes) {
   std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> front_to_front_parameter;
 
   FetchFrontToFrontParameterMap(control_nodes, &front_to_front_parameter);

@@ -447,54 +660,11 @@ void ControlNodeParser::FetchHostParameterToWeightMap(const std::vector<AnfNodeP
   for (const auto &pair : front_to_front_parameter) {
     std::vector<AnfNodePtr> dest_nodes;
     FetchWeightbyHostParameter(pair.first, &dest_nodes, front_to_front_parameter);
-    (*host_parameter_to_weights)[pair.first] = dest_nodes;
+    host_parameter_to_weights_[pair.first] = dest_nodes;
   }
 }
 
-FuncGraphPtr ControlNodeParser::FetchFuncGraphByNode(const AnfNodePtr &node) {
-  auto front_node = GetFrontNodeByBackendNode(node);
-
-  // If the front node is nullptr, we can check its inputs.
-  if (front_node == nullptr) {
-    if (node->isa<CNode>()) {
-      const auto &cnode = node->cast<CNodePtr>();
-      const auto &inputs = cnode->inputs();
-      for (size_t i = 1; i < inputs.size(); ++i) {
-        const auto &func_graph = FetchFuncGraphByNode(inputs[i]);
-        if (func_graph != nullptr) {
-          return func_graph;
-        }
-      }
-    } else {
-      return nullptr;
-    }
-  }
-
-  const auto &func_graph = front_node->func_graph();
-  return func_graph;
-}
-
-AnfNodePtr ControlNodeParser::GetFrontNodeByBackendNode(const AnfNodePtr &backend_node) {
-  if (backend_node->func_graph() == nullptr) {
-    return nullptr;
-  }
-  auto kernel_graph = dynamic_cast<KernelGraph *>(backend_node->func_graph().get());
-  if (kernel_graph == nullptr) {
-    return nullptr;
-  }
-  return kernel_graph->GetFrontAnfByBackendAnf(backend_node);
-}
-
-FuncGraphPtr ControlNodeParser::GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) {
-  auto front_node = GetFrontNodeByBackendNode(backend_node);
-  if (front_node == nullptr) {
-    return nullptr;
-  }
-  return front_node->func_graph();
-}
-
-void ControlNodeParser::FetchFuncGraphToParameterMap(const std::vector<AnfNodePtr> &control_nodes,
-                                                     FuncGraphToParameter *graph_to_real_parameters) {
+void ControlNodeParser::FetchFuncGraphToParameterMap(const std::vector<AnfNodePtr> &control_nodes) {
   for (const auto &control_node : control_nodes) {
     const auto &cnode = control_node->cast<CNodePtr>();
     const auto &inputs = cnode->inputs();

@@ -508,10 +678,10 @@ void ControlNodeParser::FetchFuncGraphToParameterMap(const std::vector<AnfNodePt
 
       if (AnfAlgo::CheckPrimitiveType(switch_cnode, prim::kPrimSwitch)) {
         // Switch node.
-        FetchParameterBySwitchNode(inputs[0], graph_to_real_parameters);
+        FetchParameterBySwitchNode(inputs[0], &func_graph_to_parameters_);
       } else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
         // Switchlayer node.
-        FetchParameterBySwitchLayerNode(inputs[0], inputs, graph_to_real_parameters);
+        FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_);
       } else {
         MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString();
       }

@@ -524,57 +694,9 @@ void ControlNodeParser::FetchFuncGraphToParameterMap(const std::vector<AnfNodePt
         parameters.emplace_back(inputs[i]);
       }
     }
-    (*graph_to_real_parameters)[func_graph].emplace_back(parameters);
+    func_graph_to_parameters_[func_graph].emplace_back(parameters);
   }
 }
 
-std::vector<AnfNodePtr> ControlNodeParser::FetchInputNodeByParameter(
-  const AnfNodePtr &parameter, const std::vector<AnfNodePtr> &host_ds_parameters,
-  std::vector<AnfNodePtr> *invalid_inputs,
-  const std::unordered_map<FuncGraphPtr, std::vector<std::vector<AnfNodePtr>>> &graph_to_real_parameters) {
-  std::vector<AnfNodePtr> input_nodes;
-
-  // If the node has been collected, skip it.
-  if (find((*invalid_inputs).begin(), (*invalid_inputs).end(), parameter) != (*invalid_inputs).end()) {
-    return input_nodes;
-  }
-
-  // Record the node which has been collected.
-  (*invalid_inputs).emplace_back(parameter);
-
-  // If the parameter node is a parameter of host data source actor, return it.
-  if (find(host_ds_parameters.begin(), host_ds_parameters.end(), parameter) != host_ds_parameters.end()) {
-    input_nodes.emplace_back(parameter);
-    return input_nodes;
-  }
-
-  // Check the parameter which send to its funcgraph.
-  const auto &func_graph = parameter->func_graph();
-  if (graph_to_real_parameters.find(func_graph) == graph_to_real_parameters.end()) {
-    return input_nodes;
-  }
-
-  std::vector<AnfNodePtr> self_inputs;
-  for (const auto &input : func_graph->get_inputs()) {
-    // Monad input need not send to funcgraph.
-    if (!HasAbstractMonad(input)) {
-      self_inputs.emplace_back(input);
-    }
-  }
-
-  size_t pos = find(self_inputs.begin(), self_inputs.end(), parameter) - self_inputs.begin();
-  for (const auto parameters : graph_to_real_parameters.at(func_graph)) {
-    const auto input = parameters[pos];
-    if (input->isa<CNode>()) {
-      input_nodes.emplace_back(input);
-    } else if (input->isa<Parameter>()) {
-      // If input is a parameter, you need to find its input recursively.
-      auto inputs = FetchInputNodeByParameter(input, host_ds_parameters, invalid_inputs, graph_to_real_parameters);
-      input_nodes.insert(input_nodes.end(), inputs.begin(), inputs.end());
-    }
-  }
-  return input_nodes;
-}
 } // namespace runtime
 } // namespace mindspore

@@ -40,71 +40,86 @@ constexpr int kSubBranchStartID = 1;
 using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>;
 using FuncGraphToParameter = std::unordered_map<FuncGraphPtr, std::vector<std::vector<AnfNodePtr>>>;
 using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
-// ControlNodeParser is a series of tool functions used to parse control nodes.
+using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>;
+
+// Check whether node is a call node, there are two types of call nodes:
+// 1. First input of node is a cnode.
+// 2. First input of node is a funcgraph value node.
+bool IsCallNode(const AnfNodePtr &node);
+
+FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node);
+
+// Get front node by backend node.
+AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node);
+
+// Get the funcgraph to which the node belongs.
+FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
+
+// Find all funcgraphs that the call node will call.
+std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const CNodePtr &node);
+
+// Fetch all backend input nodes by parameter for gather actor.
+std::vector<AnfNodePtr> FetchInputNodeByParameter(const AnfNodePtr &parameter,
+                                                  const std::vector<AnfNodePtr> &host_ds_parameters,
+                                                  std::vector<AnfNodePtr> *invalid_inputs,
+                                                  const FuncGraphToParameter &graph_to_real_parameters);
+
+// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
 class ControlNodeParser {
  public:
-  // Check whether node is a call node, there are two types of call nodes:
-  // 1. First input of node is a cnode.
-  // 2. First input of node is a funcgraph value node.
-  static bool IsCallNode(const AnfNodePtr &node);
+  // Parse the control node and put the results of the parsing into member variables.
+  void Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
+             const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph);
+
+  std::vector<AnfNodePtr> GetControlNodeParameter() { return control_node_parameters_; }
+
+  // Get the output of funcgraph, usually there is only one output node, In the control flow, there are
+  // multiple branch outputs, there will be multiple output nodes.
+  std::vector<AnfNodePtr> FetchAllBranchOutputs(const FuncGraphPtr &func_graph);
+
+ private:
+  friend class GraphScheduler;
+
+  // Collect all front value nodes. In the control flow, when the input of the switch actor is the value node, these
+  // value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for
+  // them separately during initialization.
+  // The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to.
+  void FetchFrontValueNode(const std::vector<KernelGraphPtr> &graphs,
+                           const std::vector<DeviceContext *> &device_contexts);
+
+  // Find all value nodes in the switch recursively.
+  void FetchValueNodeInSwitchNode(const AnfNodePtr &switch_node, std::vector<AnfNodePtr> *value_nodes);
 
   // Fetch all the relationships between front parameters and backend parameters.The front parameters
   // include two parts:
   // 1. The parameter from kernel graph.
   // 2. The parameter from control nodes.
-  static void FetchFrontToBackendParameterMap(const std::vector<KernelGraphPtr> &graphs,
-                                              const std::vector<DeviceContext *> &device_contexts,
-                                              const std::vector<AnfNodePtr> &control_nodes,
-                                              FrontToBackendNodeWithContext *front_to_backend_parameter);
+  void FetchFrontToBackendParameterMap(const std::vector<KernelGraphPtr> &graphs,
+                                       const std::vector<DeviceContext *> &device_contexts,
+                                       const std::vector<AnfNodePtr> &control_nodes);
 
   // Get inputs of control node which come from the host actor. These inputs generally come from the partial
   // nodes and call nodes of the root funcgraph.
-  static std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes);
-
-  // Get the output of funcgraph, usually there is only one output node, In the control flow, there are
-  // multiple branch outputs, there will be multiple output nodes.
-  static std::vector<AnfNodePtr> FetchAllBranchOutputs(const FuncGraphPtr &func_graph);
-
-  // Find all funcgraphs that the call node will call.
-  static std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const CNodePtr &node);
-
-  // Get the funcgraph to which the node belongs.
-  static FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
-  static FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node);
-
-  // Get front node by backend node.
-  static AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node);
-  // Get the funcgraph in partial node.
-  static FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node);
+  std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes);
 
   // Get all the input parameters of funcgraph. The call of funcgraph is realized through the call node,
   // and the input of the call node is the input parameter of the corresponding funcgraph.
-  static void FetchFuncGraphToParameterMap(const std::vector<AnfNodePtr> &control_nodes,
-                                           FuncGraphToParameter *graph_to_real_parameters);
+  void FetchFuncGraphToParameterMap(const std::vector<AnfNodePtr> &control_nodes);
 
   // Get all the front weight parameters related to the weight in the host parameter.
-  static void FetchHostParameterToWeightMap(const std::vector<AnfNodePtr> &control_nodes,
-                                            HostParameterToWeight *host_parameter_to_weights);
+  void FetchHostParameterToWeightMap(const std::vector<AnfNodePtr> &control_nodes);
 
-  // Fetch all backend input nodes by parameter for gather actor.
-  static std::vector<AnfNodePtr> FetchInputNodeByParameter(const AnfNodePtr &parameter,
-                                                           const std::vector<AnfNodePtr> &host_ds_parameters,
-                                                           std::vector<AnfNodePtr> *invalid_inputs,
-                                                           const FuncGraphToParameter &graph_to_real_parameters);
-
- private:
   // Get the pos input of call node to funcgraph.
   AnfNodePtr GetCallNodeInputByPos(const AnfNodePtr &call_node, const FuncGraphPtr &func_graph, const size_t pos);
 
   // Find the output of the funcgraph, if the output is a call node, return the output of the funcgraph
   // called by the call node.
-  static std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph,
-                                                      std::vector<AnfNodePtr> *call_nodes);
+  std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *call_nodes);
 
   // Find the corresponding backend parameter for the front_node. If the front_node does not have the corresponding
   // backend parameter, then recursively find the backend parameters of other front parameters corresponding to the
   // front_node.
-  static std::pair<AnfNodePtr, DeviceContext *> FetchBackendNodeByFrontNode(
+  std::pair<AnfNodePtr, DeviceContext *> FetchBackendNodeByFrontNode(
     const AnfNodePtr &front_node,
     const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &front_to_front_parameter,
     const std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>> &front_to_backend_parameter,

@@ -115,10 +130,27 @@ class ControlNodeParser {
   // 1. The parameter is used as the input of the call node,
   // 2. The parameter is used as the input of the partial and will be input to the funcgraph of the partial in the
   // subsequent call node.
-  static void FetchFrontToFrontParameterMap(
-    const std::vector<AnfNodePtr> &control_nodes,
-    std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter);
+  void FetchFrontToFrontParameterMap(const std::vector<AnfNodePtr> &control_nodes,
+                                     std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter);
+
+  // The front to backend parameters is used to build and link the host data source actor in the control flow scenario.
+  FrontToBackendNodeWithContext front_to_backend_parameters_;
+  // The funcgraph to parameters map records the input parameters of funcgraph and is used to initialize
+  // the input node of gather.
+  FuncGraphToParameter func_graph_to_parameters_;
+  // host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph.
+  // When initializing the weights, all related weights need to be recorded as the same device tensor.
+  HostParameterToWeight host_parameter_to_weights_;
+  // The front value node saves all value nodes that are not in the kernel graph. These nodes are generally the
+  // input of the control node.
+  NodeWithDeviceContext front_value_nodes_;
+  // The front output_node is used to link the output actor in multi-branch output scenario.
+  std::vector<AnfNodePtr> front_output_nodes_;
+  // Parameters of control node which come from the host actor.
+  std::vector<AnfNodePtr> control_node_parameters_;
 };
 
+using ControlNodeParserPtr = std::shared_ptr<ControlNodeParser>;
 } // namespace runtime
 } // namespace mindspore

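The header restructuring above is the "organize control node parser" half of the commit: helpers that used to be static utilities now run once in Parse and cache their results in members that GraphScheduler (a friend) reads. A reduced, self-contained model of that design, with stand-in types:

#include <memory>
#include <string>
#include <vector>

// Stand-in for ControlNodeParser: parse once, then hand out cached results.
class ParserModel {
 public:
  void Parse(const std::vector<std::string> &control_nodes) {
    control_node_parameters_ = control_nodes;  // the real Parse runs several Fetch* passes
  }
  std::vector<std::string> GetControlNodeParameter() { return control_node_parameters_; }

 private:
  std::vector<std::string> control_node_parameters_;
};
using ParserModelPtr = std::shared_ptr<ParserModel>;

int main() {
  ParserModelPtr parser = std::make_shared<ParserModel>();
  parser->Parse({"call_node_parameter"});
  return parser->GetControlNodeParameter().size() == 1 ? 0 : 1;
}
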
@@ -496,30 +496,9 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
     }
   }
 
-  // 3.Fill host tensors for non weighted parameters which belongs to control node.
-  std::vector<AnfNodePtr> control_node_parameters =
-    ControlNodeParser::FetchControlNodeParameter(graph_compiler_info.control_nodes_);
-  const auto &tensors = input_tensors.back();
-  const auto &parameters = graph_compiler_info.origin_parameters_order_;
-  for (size_t j = 0; j < control_node_parameters.size(); ++j) {
-    const auto &input_node = control_node_parameters[j];
-    const auto &input_tensor = tensors[j];
-    MS_EXCEPTION_IF_NULL(input_node);
-    if (IsPersistentDeviceTensor(input_node)) {
-      const auto &iter = graph_compiler_info.front_to_backend_parameters_.find(input_node);
-      if (iter == graph_compiler_info.front_to_backend_parameters_.end()) {
-        MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:"
-                          << AnfAlgo::GetNodeDebugString(input_node);
-      }
-      const auto &node_with_context = iter->second;
-      PrepareDataForControlWeightNode(node_with_context.first, input_node, input_tensor, node_with_context.second,
-                                      graph_compiler_info.host_parameter_to_weights_);
-    } else if (find(parameters.begin(), parameters.end(), input_node) != parameters.end()) {
-      MS_EXCEPTION_IF_NULL(host_data_source_actor);
-      PrepareDataForHostDataSourceActor(host_data_source_actor->data_node_position_map_, input_node, input_tensor,
-                                        &host_tensors);
-    }
-  }
+  // 3.Prepare the data which belongs to control node.
+  PrepareDataForControlNode(graph_compiler_info.control_node_parser_, graph_compiler_info.origin_parameters_order_,
+                            input_tensors.back(), host_data_source_actor->data_node_position_map_, &host_tensors);
 
   // 4.Prepare the data of host tensor queue(non weighted parameters of graph).
   if (host_data_source_actor != nullptr) {

@@ -529,6 +508,39 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
     }
   }
 
+void GraphScheduler::PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser,
+                                               const std::vector<AnfNodePtr> &origin_parameters,
+                                               const std::vector<TensorPtr> &tensors,
+                                               const std::unordered_map<AnfNodePtr, size_t> &data_node_position_map,
+                                               std::vector<TensorPtr> *host_tensors) {
+  const auto &control_node_parameters = control_node_parser->GetControlNodeParameter();
+
+  for (size_t j = 0; j < control_node_parameters.size(); ++j) {
+    const auto &input_node = control_node_parameters[j];
+    const auto &input_tensor = tensors[j];
+    MS_EXCEPTION_IF_NULL(input_node);
+    if (IsPersistentDeviceTensor(input_node)) {
+      const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters_;
+      const auto &iter = front_to_backend_parameters.find(input_node);
+      if (iter == front_to_backend_parameters.end()) {
+        MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:"
+                          << AnfAlgo::GetNodeDebugString(input_node);
+      }
+      const auto &node_with_context = iter->second;
+      PrepareDataForControlWeightNode(node_with_context.first, input_node, input_tensor, node_with_context.second,
+                                      control_node_parser->host_parameter_to_weights_);
+    } else if (find(origin_parameters.begin(), origin_parameters.end(), input_node) != origin_parameters.end()) {
+      PrepareDataForHostDataSourceActor(data_node_position_map, input_node, input_tensor, host_tensors);
+    }
+  }
+
+  for (const auto &value_node_with_context : control_node_parser->front_value_nodes_) {
+    if (AnfAlgo::OutputAddrExist(value_node_with_context.first, 0)) {
+      PrepareDataForValueNode(value_node_with_context.first->cast<ValueNodePtr>(), value_node_with_context.second);
+    }
+  }
+}
+
 bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strategy,
                          const std::vector<TensorPtr> *input_tensors) {
   MS_EXCEPTION_IF_NULL(actor_set);

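PrepareDataForControlNode, added above, splits control-node inputs into two paths: persistent weights and host-data-source parameters. A reduced model of that dispatch (strings and a name-prefix predicate stand in for AnfNodePtr and IsPersistentDeviceTensor, by assumption):

#include <algorithm>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> control_node_parameters = {"weight_w", "param_x"};
  std::vector<std::string> origin_parameters = {"param_x"};
  int weight_path = 0, host_path = 0;
  for (const auto &input_node : control_node_parameters) {
    if (input_node.rfind("weight_", 0) == 0) {  // stand-in for IsPersistentDeviceTensor
      ++weight_path;                            // PrepareDataForControlWeightNode path
    } else if (std::find(origin_parameters.begin(), origin_parameters.end(), input_node) !=
               origin_parameters.end()) {
      ++host_path;                              // PrepareDataForHostDataSourceActor path
    }
  }
  return (weight_path == 1 && host_path == 1) ? 0 : 1;
}
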
@@ -773,12 +785,11 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
     }
   }
 
-  const auto &front_to_backend_parameter = graph_compiler_info.front_to_backend_parameters_;
+  const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;
 
   // Initialize the parameter in the control node, first get all the front parameters in the control node, then find
   // the corresponding backend parameter from the map, and insert it into the host data source actor
-  std::vector<AnfNodePtr> control_node_parameters =
-    ControlNodeParser::FetchControlNodeParameter(graph_compiler_info.control_nodes_);
+  std::vector<AnfNodePtr> control_node_parameters = graph_compiler_info.control_node_parser_->GetControlNodeParameter();
   for (const auto parameter : control_node_parameters) {
     if (IsPersistentDeviceTensor(parameter)) {
       continue;

@@ -926,13 +937,23 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const GraphC
 
 std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
   std::vector<SwitchActorPtr> switch_actors;
+  std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_backend_kernel;
+  for (const auto &pair : front_node_to_actor_) {
+    front_to_backend_kernel[pair.first] = pair.second->kernel_;
+  }
+
   for (const auto &control_node : graph_compiler_info.control_nodes_) {
     if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
         AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
       auto actor_name = control_node->fullname_with_scope();
-      auto switch_actor = std::make_shared<SwitchActor>(actor_name, control_node->cast<CNodePtr>());
+      auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
+                                                        control_node->cast<CNodePtr>());
       switch_actor->Initialize();
+
+      // Fetch all the input nodes of switch actor.
+      switch_actor->FetchInputNode(graph_compiler_info.origin_parameters_order_,
+                                   graph_compiler_info.control_node_parser_->front_to_backend_parameters_,
+                                   front_to_backend_kernel);
       InsertActor(switch_actor.get());
       switch_actors.emplace_back(switch_actor);
     }

@@ -958,6 +979,13 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
       continue;
     }
 
+    const auto &cnode = control_node->cast<CNodePtr>();
+    const auto inputs = cnode->inputs();
+    // If the output of funcgraph is a value node, no need to create gather actor.
+    if (inputs[kReturnInputPos]->isa<ValueNode>()) {
+      continue;
+    }
+
     auto func_graph = control_node->func_graph();
     auto actor_name = func_graph->ToString();
     std::vector<AnfNodePtr> parameters;

@@ -977,8 +1005,9 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
       auto gather_actor =
         std::make_shared<GatherActor>(actor_name, parameters, loop_count_actor->GetAID(), output_actor->GetAID());
       gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.origin_parameters_order_,
-                                          graph_compiler_info.front_to_backend_parameters_,
-                                          graph_compiler_info.func_graph_to_parameters_, front_to_backend_kernel);
+                                          graph_compiler_info.control_node_parser_->front_to_backend_parameters_,
+                                          graph_compiler_info.control_node_parser_->func_graph_to_parameters_,
+                                          front_to_backend_kernel);
       InsertActor(gather_actor.get());
       gather_actors.emplace_back(gather_actor);
     }

@@ -994,7 +1023,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
   MS_EXCEPTION_IF_NULL(graph);
 
   auto from_kernel = from_kernel_with_output_idx.first;
-  auto front_node = ControlNodeParser::GetFrontNodeByBackendNode(from_kernel);
+  auto front_node = GetFrontNodeByBackendNode(from_kernel);
   if (IsDeviceQueueDSActor(from_kernel)) {
     // Link the data arrows of device queue data source actor.
     std::string actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());

@@ -1002,7 +1031,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
     LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
   } else if (front_node != nullptr && IsGatherActor(front_node, actor_name_to_actor_)) {
     // Link the data arrows of gather actor.
-    auto func_graph = ControlNodeParser::GetFuncgraphByBackendNode(from_kernel);
+    auto func_graph = GetFuncgraphByBackendNode(from_kernel);
     if (func_graph == nullptr) {
       MS_LOG(EXCEPTION) << "Cannot find funcgraph of node:" << AnfAlgo::GetNodeDebugString(from_kernel);
     }

@@ -1064,8 +1093,11 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
     auto from_kernel_with_output_idx = KernelWithIndex(from_actor->data_kernel_, actor_pair.second);
     LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
   } else if (IsSwitchActor(front_output_node)) {
-    const auto &from_actor = dynamic_cast<SwitchActor *>(FetchActor(front_output_node->fullname_with_scope()));
-    MS_LOG(ERROR) << "Need link to switch actor:" << from_actor->GetAID();
+    const auto &actor_name = front_output_node->fullname_with_scope();
+    const auto &actor = FetchActor(actor_name);
+    MS_EXCEPTION_IF_NULL(actor);
+    auto switch_actor = dynamic_cast<SwitchActor *>(actor);
+    LinkDataArrowForSwitchActor(switch_actor, to_actor, to_kernel_with_input_idx.second);
   } else if (IsKernelActor(front_output_node)) {
     auto from_actor = dynamic_cast<KernelActor *>(actor_pair.first);
     auto from_kernel_with_output_idx = KernelWithIndex(from_actor->kernel_, actor_pair.second);

@@ -1512,6 +1544,46 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
   }
 }
 
+void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
+                                                         const ActorSet *actor_set) {
+  const auto &to_actor = actor_set->output_actor_;
+  const auto &loop_count_actor = actor_set->loop_count_actor_;
+  const auto &switch_actors = actor_set->switch_actors_;
+  if (to_actor == nullptr || loop_count_actor == nullptr) {
+    return;
+  }
+
+  for (const auto &from_actor : switch_actors) {
+    MS_EXCEPTION_IF_NULL(from_actor);
+    auto origin_output_with_index = KernelWithIndex(from_actor->node_, 0);
+    const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
+    if (iter == graph_compiler_info.origin_outputs_order_.end()) {
+      continue;
+    }
+
+    // If the switch actor is in the output list, the output of switch actor should be sent to the output actor.
+    // And need to link a control arrow to the loop count actor.
+    for (const auto pos : iter->second.second) {
+      to_actor->device_contexts_[pos] = from_actor->device_context_;
+    }
+
+    for (size_t i = 0; i < from_actor->branch_inputs_pos_.size(); ++i) {
+      const auto &input_pos = from_actor->branch_inputs_pos_[i];
+      if (input_pos.empty()) {
+        MS_LOG(EXCEPTION) << "Invalid input num in switch actor:" << from_actor->GetAID();
+      }
+
+      for (const auto pos : iter->second.second) {
+        auto op_arrow = std::make_shared<DataArrow>(input_pos[0], to_actor->GetAID(), pos);
+        from_actor->output_branch_result_arrows_[i].emplace_back(op_arrow);
+      }
+
+      from_actor->output_branch_control_arrows_[i].emplace_back(loop_count_actor->GetAID());
+    }
+    loop_count_actor->branch_id_to_input_controls_num_[kMainBranchID]++;
+  }
+}
+
 void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors) {
   const size_t kNeedUpdateDeviceTensorStoreNum = 2;
   for (auto &kernel_actor : auto_monad_actors) {

@@ -1606,9 +1678,12 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
 
   LinkBranchArrowForGatherActor(graph_compiler_info, actor_set);
 
-  LinkControlArrowForGatherActor(graph_compiler_info, actor_set);
+  LinkControlArrowForGatherActor(&(actor_set->gather_actors_), actor_set->loop_count_actor_.get(),
+                                 graph_compiler_info.graphs_);
 
   LinkOutputResultArrowForGatherActor(graph_compiler_info, actor_set);
+
+  LinkOutputResultArrowForSwitchActor(graph_compiler_info, actor_set);
 }
 
 void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, KernelActor *to_actor,

@@ -1621,7 +1696,7 @@ void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, Kernel
   MS_EXCEPTION_IF_NULL(from_kernel);
   auto to_input_index = to_kernel_with_input_idx.second;
 
-  auto front_node = ControlNodeParser::GetFrontNodeByBackendNode(from_kernel);
+  auto front_node = GetFrontNodeByBackendNode(from_kernel);
   if (front_node == nullptr) {
     MS_LOG(EXCEPTION) << "Cannot find front node of node:" << AnfAlgo::GetNodeDebugString(from_kernel);
   }

@ -1637,7 +1712,8 @@ void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, Kernel
|
|||
void GraphScheduler::LinkDataArrowByCallInput(const GraphCompilerInfo &graph_compiler_info, const AnfNodePtr &call_node,
|
||||
OpActor<DeviceTensor> *to_actor, const size_t to_index) {
|
||||
// Fetch all the funcgraph that call node would call.
|
||||
std::vector<FuncGraphPtr> func_graphs = ControlNodeParser::FetchFuncGraphbyCallNode(call_node->cast<CNodePtr>());
|
||||
const auto cnode = call_node->cast<CNodePtr>();
|
||||
std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(cnode);
|
||||
|
||||
// Collect the output of each funcgraph.
|
||||
for (const auto &func_graph : func_graphs) {
|
||||
|
@@ -1677,19 +1753,58 @@ void GraphScheduler::LinkDataArrowByCallInput(const GraphCompilerInfo &graph_com
      const size_t pos = iter - gather_actor->data_nodes_.begin();
      auto op_arrow = std::make_shared<DataArrow>(pos, to_actor->GetAID(), to_index);
      gather_actor->output_data_arrows_.emplace_back(op_arrow);
    } else if (output_with_index.first->isa<ValueNode>()) {
      // If the output is a value node, it needs to be sent by the switch actor.
      const auto &call_inputs = cnode->inputs();
      if (AnfAlgo::CheckPrimitiveType(call_inputs[0], prim::kPrimSwitch)) {
        const auto &actor_name = call_inputs[0]->fullname_with_scope();
        const auto &actor = FetchActor(actor_name);
        MS_EXCEPTION_IF_NULL(actor);
        auto switch_actor = dynamic_cast<SwitchActor *>(actor);
        MS_EXCEPTION_IF_NULL(switch_actor);

        // Add an output for each branch of the switch.
        for (size_t i = 0; i < switch_actor->branch_inputs_pos_.size(); ++i) {
          if (switch_actor->branch_inputs_pos_[i].empty()) {
            MS_LOG(EXCEPTION) << "No input for switch actor:" << actor_name << " branch:" << i;
          }

          const auto from_index = switch_actor->branch_inputs_pos_[i][0];
          auto op_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
          switch_actor->output_branch_arrows_[i].emplace_back(op_arrow);
        }
      } else {
        MS_LOG(EXCEPTION) << "Invalid input for call node:" << AnfAlgo::GetNodeDebugString(call_node);
      }
    } else {
      MS_LOG(EXCEPTION) << "Output of func graph is not a parameter or kernel, func graph:" << func_graph->ToString();
      MS_LOG(EXCEPTION) << "Output of func graph is not a parameter or kernel, func graph:" << func_graph->ToString()
                        << " output node:" << AnfAlgo::GetNodeDebugString(output_with_index.first);
    }
  }
}

void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, KernelActor *to_actor,
                                                 const size_t to_index) {
  MS_EXCEPTION_IF_NULL(from_actor);

  for (size_t i = 0; i < from_actor->output_branch_arrows_.size(); ++i) {
    if (from_actor->branch_inputs_pos_[i].empty()) {
      MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i;
    }
    const auto from_index = from_actor->branch_inputs_pos_[i][0];
    auto op_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
    from_actor->output_branch_arrows_[i].emplace_back(op_arrow);
  }
  to_actor->input_datas_num_++;
}

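Note the final to_actor->input_datas_num_++: every data arrow linked into an actor raises the count of inputs it must collect before it can run. A toy sketch of that counting contract, assuming a simplified actor model (the MiniActor type is invented for illustration):

#include <cassert>
#include <cstddef>
#include <vector>

// Invented stand-in: an actor that fires once all expected inputs have arrived.
struct MiniActor {
  size_t input_datas_num_ = 0;  // raised once per linked data arrow
  std::vector<int> received_;

  void RunOpData(int value) {
    received_.push_back(value);
    if (received_.size() == input_datas_num_) {
      // All declared inputs are present; the real actor would launch its kernel here.
      received_.clear();
    }
  }
};

int main() {
  MiniActor actor;
  actor.input_datas_num_++;  // linking the switch output adds one expected input
  actor.input_datas_num_++;  // a second producer adds another

  actor.RunOpData(1);
  assert(actor.received_.size() == 1);  // still waiting for the second input
  actor.RunOpData(2);
  assert(actor.received_.empty());      // both arrived, the actor "ran"
  return 0;
}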
void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info,
                                                const AnfNodePtr &input_node, OpActor<DeviceTensor> *to_actor,
                                                const size_t to_index) {
  const auto &parameters = graph_compiler_info.origin_parameters_order_;
  const auto &front_to_backend_parameter = graph_compiler_info.front_to_backend_parameters_;
  const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;

  if (ControlNodeParser::IsCallNode(input_node)) {
  if (IsCallNode(input_node)) {
    // The actor input is a call node.
    LinkDataArrowByCallInput(graph_compiler_info, input_node, to_actor, to_index);
  } else if (IsGatherActor(input_node, actor_name_to_actor_)) {

@@ -1745,20 +1860,19 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_
  const auto &inputs = actor->input_nodes_;
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto input = inputs[i];
    if (input->isa<ValueNode>()) {
      continue;
    }
    LinkDataArrowByControlNode(graph_compiler_info, input, actor, i);
  }

  // Link switch output.
  for (size_t i = 0; i < actor->branch_func_graph_.size(); ++i) {
    auto func_graph = actor->branch_func_graph_[i];
    if (func_graph == nullptr) {
    if (func_graph == nullptr || func_graph->output()->isa<ValueNode>()) {
      continue;
    }

    if (func_graph->output()->isa<ValueNode>()) {
      actor->AddInput(func_graph->output(), 0);
    }

    auto gather_name = func_graph->ToString();
    if (actor_name_to_actor_.find(gather_name) == actor_name_to_actor_.end()) {
      MS_LOG(EXCEPTION) << "Cannot find gather actor for funcgraph:" << gather_name

@@ -1773,13 +1887,15 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_
  }
}

void GraphScheduler::LinkControlArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info,
                                                    const ActorSet *actor_set) {
  MS_EXCEPTION_IF_NULL(actor_set);
void GraphScheduler::LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors, LoopCountActor *to_actor,
                                                    const std::vector<KernelGraphPtr> &graphs) {
  if (from_actors == nullptr || to_actor == nullptr) {
    return;
  }

  // Link control arrows to kernel actors.
  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &kernel_graph = graph_compiler_info.graphs_[i];
  for (size_t i = 0; i < graphs.size(); ++i) {
    const auto &kernel_graph = graphs[i];
    MS_EXCEPTION_IF_NULL(kernel_graph);
    const auto &func_graph = kernel_graph->GetFuncGraph();
    if (func_graph == nullptr) {

@@ -1807,12 +1923,11 @@ void GraphScheduler::LinkControlArrowForGatherActor(const GraphCompilerInfo &gra
  }

  // Link control arrows to the loop count actor.
  for (auto &from_actor : actor_set->gather_actors_) {
  for (auto &from_actor : *from_actors) {
    MS_EXCEPTION_IF_NULL(from_actor);

    // If the gather actor has no output, add an output control arrow to the loop count actor.
    if (from_actor->output_data_arrows_.size() == 0 && from_actor->output_control_arrows_.size() == 0) {
      const auto &to_actor = actor_set->loop_count_actor_;
      auto to_aid = to_actor->GetAID();
      from_actor->output_control_arrows_.emplace_back(to_aid);
      to_actor->branch_id_to_input_controls_num_[kMainBranchID]++;

@@ -1831,7 +1946,7 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap
  const auto &output_actor = actor_set->output_actor_.get();

  // If there is only one branch output, set the branch id of the loop count to 0; no need to send the branch id.
  auto outputs = graph_compiler_info.front_output_nodes_;
  auto outputs = graph_compiler_info.control_node_parser_->front_output_nodes_;
  if (outputs.size() == 1) {
    return;
  }

@@ -1868,7 +1983,7 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap
  MS_EXCEPTION_IF_NULL(kernel_actor);
  if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0)) {
    MS_EXCEPTION_IF_NULL(kernel_actor->kernel_);
    const auto &sub_func_graph = ControlNodeParser::FetchFuncGraphByNode(kernel_actor->kernel_);
    const auto &sub_func_graph = FetchFuncGraphByNode(kernel_actor->kernel_);
    if (sub_func_graph == nullptr) {
      MS_LOG(EXCEPTION) << "Cannot get funcgraph from kernel:" << kernel_actor->kernel_->fullname_with_scope();
    }

@@ -1896,6 +2011,18 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap
    gather_actor->branch_id_ = branch_id;
    loop_count_actor->branch_id_to_input_controls_num_[branch_id] = graph_to_control_num[sub_func_graph];
  }

  // If a switch actor is linked to the output actor, it also links a control arrow to the loop count actor,
  // and this needs to be recorded.
  for (const auto &from_actor : actor_set->switch_actors_) {
    MS_EXCEPTION_IF_NULL(from_actor);
    auto origin_output_with_index = KernelWithIndex(from_actor->node_, 0);
    const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
    if (iter == graph_compiler_info.origin_outputs_order_.end()) {
      continue;
    }
    loop_count_actor->branch_id_to_input_controls_num_[iter->second.first]++;
  }
}
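The branch_id_to_input_controls_num_ map is pure bookkeeping: for each branch id it records how many control arrows the loop count actor must see before that branch's step is complete. A minimal sketch of that invariant, with invented types for illustration:

#include <cassert>
#include <cstddef>
#include <unordered_map>

int main() {
  // branch id -> number of control arrows expected by the loop count actor
  std::unordered_map<int, size_t> branch_id_to_input_controls_num;
  constexpr int kMainBranchID = 0;  // assumed default branch id

  // Linking phase: every actor that will send a control arrow bumps the count.
  branch_id_to_input_controls_num[kMainBranchID]++;  // e.g. a gather actor with no outputs
  branch_id_to_input_controls_num[kMainBranchID]++;  // e.g. a switch actor in the output list

  // Run phase: the loop count actor counts arriving control messages and
  // finishes the step once every expected arrow for the taken branch is in.
  size_t arrived = 0;
  arrived++;  // first control message
  arrived++;  // second control message
  assert(arrived == branch_id_to_input_controls_num[kMainBranchID]);
  return 0;
}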

void GraphScheduler::LinkOutputResultArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info,

@@ -2061,6 +2188,15 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
      }
    }
  }

  // In control flow, there may be value nodes that are not in any kernel graph and need to be placed
  // in the tensor store separately.
  for (const auto &value_node : graph_compiler_info.control_node_parser_->front_value_nodes_) {
    MS_EXCEPTION_IF_NULL(value_node.first);
    auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node.first, 0, false);
    DeviceTensorStore::GetInstance().Insert(value_node.first.get(), device_tensor);
    UpdateRefCount(device_tensor.get(), true);
  }
}
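A compact sketch of what the store insertion amounts to: a process-wide map keyed by the front node, holding the device tensor so it outlives any single kernel graph. All types here are simplified stand-ins for the runtime classes, not the real API:

#include <cassert>
#include <memory>
#include <unordered_map>

// Simplified stand-ins for AnfNode / DeviceTensor.
struct Node {};
struct DeviceTensor {
  bool ref_count_pinned = false;
};

class MiniDeviceTensorStore {
 public:
  static MiniDeviceTensorStore &GetInstance() {
    static MiniDeviceTensorStore instance;
    return instance;
  }
  void Insert(const Node *key, std::shared_ptr<DeviceTensor> value) { store_[key] = std::move(value); }
  std::shared_ptr<DeviceTensor> Fetch(const Node *key) const {
    auto iter = store_.find(key);
    return iter == store_.end() ? nullptr : iter->second;
  }

 private:
  std::unordered_map<const Node *, std::shared_ptr<DeviceTensor>> store_;
};

int main() {
  Node front_value_node;
  auto device_tensor = std::make_shared<DeviceTensor>();
  MiniDeviceTensorStore::GetInstance().Insert(&front_value_node, device_tensor);
  device_tensor->ref_count_pinned = true;  // analogous to UpdateRefCount(..., true)
  assert(MiniDeviceTensorStore::GetInstance().Fetch(&front_value_node) == device_tensor);
  return 0;
}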

HostTensorQueue *GraphScheduler::FetchHostQueue(const ActorInfo &actor_info) const {

@@ -67,33 +67,24 @@ enum class GraphExecutionStrategy {
// The tensors mask is used to distinguish input tensor's type.
// The input tensor is used to link graphs in the dynamic build scenario.
// The control node is used to link graphs in the control flow scenario.
// The control node parser is used to parse the edge info in control nodes.
// The origin parameters order is used to correspond to the input args.
// The origin outputs order is used to correspond to the output args.
// The front to backend parameters is used to build and link the host data source actor in the control flow scenario.
// The funcgraph to parameters map records the input parameters of each funcgraph and is used to initialize
// the input nodes of the gather actor.
// The host parameter to weights map records the weights in the subgraphs corresponding to the nodes in the root
// funcgraph. When initializing the weights, all related weights need to be recorded as the same device tensor.
// The front output nodes are used to link the output actor in the multi-branch output scenario.
struct GraphCompilerInfo {
  GraphCompilerInfo(
    const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
    const std::vector<std::vector<int64_t> *> &tensors_mask, const std::vector<std::vector<TensorPtr> *> &input_tensors,
    const std::vector<AnfNodePtr> &control_nodes, const std::vector<AnfNodePtr> &origin_parameters_order,
    const KernelMapPosition &origin_outputs_order, const FrontToBackendNodeWithContext &front_to_backend_parameters,
    const FuncGraphToParameter &func_graph_to_parameters, const HostParameterToWeight &host_parameter_to_weights,
    const std::vector<AnfNodePtr> &front_output_nodes, const size_t outputs_num, const std::string &name)
  GraphCompilerInfo(const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
                    const std::vector<std::vector<int64_t> *> &tensors_mask,
                    const std::vector<std::vector<TensorPtr> *> &input_tensors,
                    const std::vector<AnfNodePtr> &control_nodes,
                    const std::vector<AnfNodePtr> &origin_parameters_order, const ControlNodeParserPtr &parser,
                    const KernelMapPosition &origin_outputs_order, const size_t outputs_num, const std::string &name)
      : graphs_(graphs),
        device_contexts_(device_contexts),
        tensors_mask_(tensors_mask),
        input_tensors_(input_tensors),
        control_nodes_(control_nodes),
        control_node_parser_(parser),
        origin_parameters_order_(origin_parameters_order),
        origin_outputs_order_(origin_outputs_order),
        front_to_backend_parameters_(front_to_backend_parameters),
        func_graph_to_parameters_(func_graph_to_parameters),
        host_parameter_to_weights_(host_parameter_to_weights),
        front_output_nodes_(front_output_nodes),
        outputs_num_(outputs_num),
        name_(name) {}
  std::vector<KernelGraphPtr> graphs_;

@@ -101,12 +92,9 @@ struct GraphCompilerInfo {
  std::vector<std::vector<int64_t> *> tensors_mask_;
  std::vector<std::vector<TensorPtr> *> input_tensors_;
  std::vector<AnfNodePtr> control_nodes_;
  ControlNodeParserPtr control_node_parser_;
  std::vector<AnfNodePtr> origin_parameters_order_;
  KernelMapPosition origin_outputs_order_;
  FrontToBackendNodeWithContext front_to_backend_parameters_;
  FuncGraphToParameter func_graph_to_parameters_;
  HostParameterToWeight host_parameter_to_weights_;
  std::vector<AnfNodePtr> front_output_nodes_;
  size_t outputs_num_;
  std::string name_;
};
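The refactor above replaces several separately-parsed members (front-to-backend parameters, funcgraph parameters, host weights, front outputs) with one shared ControlNodeParserPtr, so every consumer reads the same parse result. A stripped-down sketch of that aggregation, with all types invented for illustration:

#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Invented stand-in: the parser owns every table derived from the control nodes.
struct MiniControlNodeParser {
  std::map<std::string, std::string> front_to_backend_parameters_;
  std::vector<std::string> front_output_nodes_;

  void Parse(const std::vector<std::string> &control_nodes) {
    for (const auto &node : control_nodes) {
      front_to_backend_parameters_[node] = node + "_backend";
    }
    front_output_nodes_ = control_nodes;
  }
};
using MiniControlNodeParserPtr = std::shared_ptr<MiniControlNodeParser>;

// The compiler info now carries one pointer instead of a copy of each table.
struct MiniGraphCompilerInfo {
  MiniControlNodeParserPtr control_node_parser_;
};

int main() {
  auto parser = std::make_shared<MiniControlNodeParser>();
  parser->Parse({"switch_node", "call_node"});

  MiniGraphCompilerInfo info{parser};
  // Consumers read through the parser, so they all share one parse result.
  assert(info.control_node_parser_->front_to_backend_parameters_.size() == 2);
  return 0;
}

One shared pointer also keeps the non-control-flow path cheap: a default-constructed parser simply leaves every table empty.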

@@ -245,17 +233,27 @@ class GraphScheduler {
  // connected.
  void LinkDataArrowByCallInput(const GraphCompilerInfo &graph_compiler_info, const AnfNodePtr &call_node,
                                OpActor<DeviceTensor> *to_actor, const size_t to_index);
  void LinkControlArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
  void LinkDataArrowForSwitchActor(SwitchActor *from_actor, KernelActor *to_actor, const size_t to_index);
  void LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors, LoopCountActor *to_actor,
                                      const std::vector<KernelGraphPtr> &graphs);
  // In control flow, there are scenarios with multi-branch outputs, and the gather actor needs to
  // send the branch id to the loop count actor.
  void LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
  void LinkOutputResultArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
  void LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);

  // The processing of linking actors dynamically.
  // Analyze the necessary input data of the current actor, then generate and cache the op arrow
  // between the current actor and the prev actor; this method executes before calling Schedule.
  void PrepareForDynamiclyLink(ActorSet *actor_set, const CNodePtr &kernel, const AID &aid,
                               const std::vector<TensorPtr> *input_tensors);

  void PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser,
                                 const std::vector<AnfNodePtr> &origin_parameters,
                                 const std::vector<TensorPtr> &tensors,
                                 const std::unordered_map<AnfNodePtr, size_t> &data_node_position_map,
                                 std::vector<TensorPtr> *host_tensors);

  // Link to the prev actor dynamically, and send a message to the prev actor to add the
  // new DataArrow and send the output data back; this method must execute after calling Schedule.
  void LinkDataArrowForKernelActorDynamicly(const ActorSet *actor_set);

@@ -454,8 +454,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
  std::vector<tensor::TensorPtr> input_tensor;

  // Get the inputs of control nodes, which come from the host actor.
  std::vector<AnfNodePtr> control_node_parameters =
    ControlNodeParser::FetchControlNodeParameter(graph_compiler_info.control_nodes_);
  const auto &control_node_parameters = graph_compiler_info.control_node_parser_->GetControlNodeParameter();
  for (const auto &parameter : control_node_parameters) {
    PushTensor(args, origin_parameters, parameter, &input_tensor);
  }
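PushTensor is assumed here to look up the parameter's position among the origin parameters and append the matching runtime argument to the input-tensor list; a toy version under that assumption, with stand-in types and invented names:

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

using Param = std::string;  // stand-in for AnfNodePtr
using Tensor = double;      // stand-in for tensor::TensorPtr

// Hypothetical sketch of the assumed PushTensor behavior.
void PushTensorSketch(const std::vector<Tensor> &args, const std::vector<Param> &origin_parameters,
                      const Param &parameter, std::vector<Tensor> *out) {
  auto it = std::find(origin_parameters.begin(), origin_parameters.end(), parameter);
  assert(it != origin_parameters.end());
  out->push_back(args[static_cast<size_t>(it - origin_parameters.begin())]);
}

int main() {
  const std::vector<Param> origin = {"x", "y", "cond"};
  const std::vector<Tensor> args = {1.0, 2.0, 0.0};
  std::vector<Tensor> input_tensor;
  // Control-node parameters (e.g. the switch condition) are fed from the host side.
  for (const Param &p : std::vector<Param>{"cond"}) {
    PushTensorSketch(args, origin, p, &input_tensor);
  }
  assert(input_tensor.size() == 1 && input_tensor[0] == 0.0);
  return 0;
}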

@@ -555,10 +554,13 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
    name.append("_").append(std::to_string(graph_id_to_context.first));
  }

  auto parser = std::make_shared<ControlNodeParser>();
  parser->Parse(control_nodes_, graphs, device_contexts, root_graph);

  // Get all the outputs. In control flow, there may be multiple branch outputs.
  runtime::KernelMapPosition outputs_order;
  size_t outputs_num = 0;
  const auto &all_branch_output = ControlNodeParser::FetchAllBranchOutputs(root_graph);
  const auto &all_branch_output = parser->FetchAllBranchOutputs(root_graph);
  for (int j = 0; j < SizeToInt(all_branch_output.size()); ++j) {
    // In general, there is only one output branch, and the branch id is 0 at this time. In control flow,
    // there are multi-branch output scenarios. Different branches may have different weight nodes. When output

@@ -578,28 +580,10 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
    }
  }

  // Fetch all the relationships between front parameters and backend parameters, which will be used to
  // build and link actors.
  FrontToBackendNodeWithContext front_to_backend_parameter;
  ControlNodeParser::FetchFrontToBackendParameterMap(graphs, device_contexts, control_nodes_,
                                                     &front_to_backend_parameter);

  // The funcgraph to parameters map records the input parameters of each funcgraph and is used to initialize
  // the input nodes of the gather actor.
  FuncGraphToParameter func_graph_to_parameters;
  ControlNodeParser::FetchFuncGraphToParameterMap(control_nodes_, &func_graph_to_parameters);

  // The host parameter to weights map records the weights in the subgraphs corresponding to the nodes in the
  // root funcgraph. When initializing the weights, all related weights need to be recorded as the same device tensor.
  HostParameterToWeight host_parameter_to_weights;
  ControlNodeParser::FetchHostParameterToWeightMap(control_nodes_, &host_parameter_to_weights);

  std::vector<std::vector<int64_t> *> tensors_mask;
  std::vector<std::vector<tensor::TensorPtr> *> input_tensors;
  return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask, input_tensors, control_nodes_,
                                             root_graph->parameters(), outputs_order, front_to_backend_parameter,
                                             func_graph_to_parameters, host_parameter_to_weights, all_branch_output,
                                             outputs_num, name);
                                             root_graph->parameters(), parser, outputs_order, outputs_num, name);
}

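In the single-op overload below, a default-constructed parser is handed over without calling Parse; the control-flow linking stages then see empty tables and fall through. A tiny sketch of that "empty parser means no-op" behavior, with invented types for illustration:

#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Invented stand-in: an unparsed parser reports no control-node parameters.
struct MiniControlNodeParser {
  std::vector<int> control_node_parameters_;
  const std::vector<int> &GetControlNodeParameter() const { return control_node_parameters_; }
};

int main() {
  auto parser = std::make_shared<MiniControlNodeParser>();  // never Parse()d
  size_t pushed = 0;
  for (int parameter : parser->GetControlNodeParameter()) {
    (void)parameter;
    ++pushed;  // would call PushTensor in the real RunGraph
  }
  assert(pushed == 0);  // the control-flow path degenerates to a no-op
  return 0;
}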
std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(

@@ -629,10 +613,10 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
  std::vector<std::vector<int64_t> *> tensors_mask_list(1, const_cast<std::vector<int64_t> *>(tensors_mask));
  std::vector<std::vector<TensorPtr> *> input_tensors_list(1,
                                                           const_cast<std::vector<tensor::TensorPtr> *>(input_tensors));
  return std::make_unique<GraphCompilerInfo>(
    graphs, device_contexts, tensors_mask_list, input_tensors_list, std::vector<AnfNodePtr>(),
    std::vector<AnfNodePtr>(), outputs_order, FrontToBackendNodeWithContext(), FuncGraphToParameter(),
    HostParameterToWeight(), std::vector<AnfNodePtr>(), outputs_order.size(), actor_info);
  auto parser = std::make_shared<ControlNodeParser>();
  return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask_list, input_tensors_list,
                                             std::vector<AnfNodePtr>(), std::vector<AnfNodePtr>(), parser,
                                             outputs_order, outputs_order.size(), actor_info);
}

void MindRTBackend::RunGraph(const ActorInfo &actor_info, OpRunInfo *op_run_info,

@@ -41,9 +41,7 @@ using ActorInfo = runtime::ActorInfo;
using GraphCompiler = runtime::GraphCompiler;
using GraphCompilerInfo = runtime::GraphCompilerInfo;
using ControlNodeParser = runtime::ControlNodeParser;
using FrontToBackendNodeWithContext = runtime::FrontToBackendNodeWithContext;
using FuncGraphToParameter = runtime::FuncGraphToParameter;
using HostParameterToWeight = runtime::HostParameterToWeight;
using ControlNodeParserPtr = runtime::ControlNodeParserPtr;

enum SwitchCondStatus {
  kCondOk = 0,