!26212 Control flow support untail call.

Merge pull request !26212 from gaoyong10/runtime_second8
This commit is contained in:
i-robot 2021-11-15 10:53:39 +00:00 committed by Gitee
commit 390b3c2efa
15 changed files with 430 additions and 92 deletions

View File

@ -64,7 +64,9 @@ void DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) {
}
if (actor->output_data_arrows().size() != actor->output_data_nodes().size()) {
MS_LOG(EXCEPTION) << "The size of output data arrows is not equal to the output nodes.";
MS_LOG(EXCEPTION) << "The size of output data arrows is not equal to the output nodes, arrow num:"
<< actor->output_data_arrows().size() << " node num:" << actor->output_data_nodes().size()
<< " for actor:" << actor->GetAID();
}
if (actor->output_data_arrows().size() > 0) {
ofs << "\t\toutput_data_arrows:" << actor->output_data_arrows().size() << "\n ";
@ -178,22 +180,30 @@ void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) {
}
void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) {
const auto &output_data_arrows = actor->output_data_arrows();
if (output_data_arrows.size() > 0) {
ofs << "\t\t\toutput_data_arrows:" << output_data_arrows.size() << "\n ";
for (const auto &data_arrow : output_data_arrows) {
MS_EXCEPTION_IF_NULL(data_arrow);
ofs << "\t\t\t\tfrom_output_index:" << data_arrow->from_output_index_
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
<< "\n";
MS_EXCEPTION_IF_NULL(actor);
DumpAbstractActor(actor, ofs);
const auto &local_partials = actor->local_partials();
if (local_partials.size() > 0) {
ofs << "\t\t\tlocal partial num:" << local_partials.size() << "\n ";
for (const auto &local_partial : local_partials) {
MS_EXCEPTION_IF_NULL(local_partial.second.first);
ofs << "\t\t\t\tlocal partial index:" << local_partial.first
<< "\tgraph:" << local_partial.second.first->ToString()
<< "\tparameter num:" << local_partial.second.second.size() << "\n";
}
}
const auto &output_control_arrows = actor->output_control_arrows();
if (output_control_arrows.size() > 0) {
ofs << "\t\t\toutput_control_arrows:" << output_control_arrows.size() << "\n ";
for (const auto &aid : output_control_arrows) {
ofs << "\t\t\t\tto_actor_name:" << aid.Name() << "\n";
if (actor->input_partial_arrow_aids().size() > 0) {
ofs << "\t\tinput_partial_arrow_actor:" << actor->input_partial_arrow_aids().size() << "\n ";
for (const auto &input_partial_arrow_aid : actor->input_partial_arrow_aids()) {
ofs << "\t\t\tfrom_actor_name:" << input_partial_arrow_aid.Name() << "\n";
}
}
if (actor->input_branch_id_arrow_aids().size() > 0) {
ofs << "\t\tinput_branch_id_arrow_actor:" << actor->input_branch_id_arrow_aids().size() << "\n ";
for (const auto &input_branch_id_arrow_aid : actor->input_branch_id_arrow_aids()) {
ofs << "\t\t\tfrom_actor_name:" << input_branch_id_arrow_aid.Name() << "\n";
}
}
@ -219,13 +229,13 @@ void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) {
void DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\t\ttactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
DumpControlActor(actor, ofs);
}
void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
DumpControlActor(actor, ofs);
const auto &output_data_with_branch_id_arrows = actor->output_data_with_branch_id_arrows();
@ -242,13 +252,13 @@ void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) {
void DumpEntranceActor(const EntranceActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
DumpControlActor(actor, ofs);
}
void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
DumpControlActor(actor, ofs);
const auto &output_branch_data_arrows = actor->output_branch_data_arrows();
@ -291,40 +301,40 @@ void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) {
void DumpStackActor(const StackActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
DumpControlActor(actor, ofs);
}
void DumpSwitchActors(const std::vector<SwitchActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Switch actors:" << actors.size() << "]\n";
ofs << "\n\n[Switch actors:" << actors.size() << "]\n";
for (const auto &switch_actor : actors) {
DumpSwitchActor(switch_actor.get(), ofs);
}
}
void DumpGatherActors(const std::vector<GatherActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Gather actors:" << actors.size() << "]\n";
ofs << "\n\n[Gather actors:" << actors.size() << "]\n";
for (const auto &gather_actor : actors) {
DumpGatherActor(gather_actor.get(), ofs);
}
}
void DumpEntranceActors(const std::vector<EntranceActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Entrance actors:" << actors.size() << "]\n";
ofs << "\n\n[Entrance actors:" << actors.size() << "]\n";
for (const auto &entrance_actor : actors) {
DumpEntranceActor(entrance_actor.get(), ofs);
}
}
void DumpExitActors(const std::vector<ExitActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Exit actors:" << actors.size() << "]\n";
ofs << "\n\n[Exit actors:" << actors.size() << "]\n";
for (const auto &exit_actor : actors) {
DumpExitActor(exit_actor.get(), ofs);
}
}
void DumpStackActors(const std::vector<StackActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Stack actors:" << actors.size() << "]\n";
ofs << "\n\n[Stack actors:" << actors.size() << "]\n";
for (const auto &stack_actor : actors) {
DumpStackActor(stack_actor.get(), ofs);
}

View File

@ -42,7 +42,8 @@ void ControlActor::Init() {
size_t ControlActor::FetchNodePosition(const KernelWithIndex &node) const {
const auto &iter = find(formal_parameters_.begin(), formal_parameters_.end(), node);
if (iter == formal_parameters_.end()) {
MS_LOG(EXCEPTION) << "Invalid formal parameter:" << node.first->DebugString() << " for actor:" << GetAID();
MS_LOG(EXCEPTION) << "Invalid formal parameter:" << node.first->DebugString() << " index:" << node.second
<< " for actor:" << GetAID();
}
return iter - formal_parameters_.begin();
}
@ -112,18 +113,13 @@ void ControlActor::FetchInput(OpContext<DeviceTensor> *const context) {
}
// Fetch input device tensor from device store.
for (auto &device_tensor_store_key : device_tensor_store_keys_) {
auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
device_contexts_[0]->GetDeviceAddressType());
if (device_tensor == nullptr) {
MS_LOG(ERROR) << GetAID() << " get device tensor store failed: " << device_tensor_store_key.second->DebugString();
for (auto &local_device_tensor : local_device_tensors_) {
MS_EXCEPTION_IF_NULL(local_device_tensor.second);
if (local_device_tensor.first >= input_device_tensors_.size()) {
MS_LOG(ERROR) << "Invalid local index:" << local_device_tensor.first
<< " current:" << local_device_tensors_.size() << " for actor:" << GetAID();
}
if (device_tensor_store_key.first >= input_device_tensors_.size()) {
MS_LOG(ERROR) << "The input index is out of range, need:" << device_tensor_store_key.first
<< " current:" << input_device_tensors_.size() << " for actor:" << GetAID();
}
input_device_tensors_[device_tensor_store_key.first] = device_tensor;
input_device_tensors_[local_device_tensor.first] = local_device_tensor.second;
}
for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {

View File

@ -54,6 +54,9 @@ class ControlActor : public AbstractActor {
const std::vector<DataArrowPtr> &output_partial_arrows() const { return output_partial_arrows_; }
const std::vector<AID> &output_branch_id_arrows() const { return output_branch_id_arrows_; }
const std::unordered_map<size_t, OpPartial> &local_partials() const { return local_partials_; }
const std::vector<AID> &input_partial_arrow_aids() const { return input_partial_arrow_aids_; }
const std::vector<AID> &input_branch_id_arrow_aids() const { return input_branch_id_arrow_aids_; }
protected:
friend class ControlNodeScheduler;
@ -87,6 +90,10 @@ class ControlActor : public AbstractActor {
// Input num.
size_t input_partials_num_{0};
// The dependent input actors.
std::vector<AID> input_partial_arrow_aids_;
std::vector<AID> input_branch_id_arrow_aids_;
// Output Arrows.
std::vector<DataArrowPtr> output_partial_arrows_;
OpPartial output_partial_;
@ -99,6 +106,8 @@ class ControlActor : public AbstractActor {
// Partial data in local. When partial is only funcgraph without real parameter, it is stored inside the actor.
std::unordered_map<size_t, OpPartial> local_partials_;
// Device tensor in control node, but not in kernel graph.
std::unordered_map<size_t, DeviceTensor *> local_device_tensors_;
// Cache output data by output index to modify the output data effectively.
std::vector<std::vector<OpData<DeviceTensor> *>> output_data_by_output_index_;

View File

@ -85,10 +85,11 @@ void ExitActor::CopyDeviceAddress() {
MS_EXCEPTION_IF_NULL(input_device_tensor);
const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex();
MS_EXCEPTION_IF_NULL(node_with_index.first);
if (!node_with_index.first->isa<CNode>()) {
if (device_contexts_[i] == nullptr) {
// If device context is empty, it means that the input is from a parameter, need not to copy a new device tensor.
new_device_tensors.emplace_back(input_device_tensor);
continue;
}
MS_EXCEPTION_IF_NULL(device_contexts_[i]);
auto new_device_tensor =
device_contexts_[i]->CreateDeviceAddress(nullptr, input_device_tensors_[i]->GetSize(),

View File

@ -21,7 +21,9 @@ namespace mindspore {
namespace runtime {
GatherActor::GatherActor(const std::string &name, const std::vector<KernelWithIndex> &parameters,
const AnfNodePtr &node)
: ControlActor(name, KernelTransformType::kGatherActor, parameters, node) {}
: ControlActor(name, KernelTransformType::kGatherActor, parameters, node) {
device_contexts_.resize(parameters.size());
}
void GatherActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);

View File

@ -25,13 +25,96 @@ StackActor::StackActor(const std::string &name, const std::vector<KernelWithInde
input_device_tensors_.resize(parameters.size());
}
bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
return false;
void StackActor::Init() {
ControlActor::Init();
for (const auto &formal_parameter : formal_parameters_) {
if (AnfAlgo::IsCallNode(formal_parameter.first)) {
break;
}
++input_parameter_data_num_;
}
input_datas_num_ = formal_parameters_.size() - input_parameter_data_num_;
}
void StackActor::FetchInput(OpContext<DeviceTensor> *const context) { MS_EXCEPTION_IF_NULL(context); }
void StackActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(input_data->data_);
auto &sequential_num = context->sequential_num_;
// The parameters from the inside of the subgraph need to be put into the stack.
if (IntToSize(input_data->index_) < input_parameter_data_num_) {
input_parameter_data_[sequential_num][input_data->index_].push(input_data->data_);
} else {
// The outputs of call nodes are placed directly in the input data.
input_op_datas_[sequential_num].emplace_back(input_data);
}
if (CheckRunningCondition(context)) {
Run(context);
}
}
void StackActor::EraseInput(const OpContext<DeviceTensor> *const context) { MS_EXCEPTION_IF_NULL(context); }
bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
if (!ControlActor::CheckRunningCondition(context)) {
return false;
}
if (input_parameter_data_num_ != 0) {
const auto &data_iter = input_parameter_data_.find(context->sequential_num_);
if (data_iter == input_parameter_data_.end()) {
return false;
}
if (data_iter->second.size() != input_parameter_data_num_) {
return false;
}
auto iter = input_branch_ids_.find(context->sequential_num_);
if (iter == input_branch_ids_.end() || iter->second.empty()) {
MS_LOG(ERROR) << "There is no branch id for actor:" << GetAID();
}
size_t branch_id_size = iter->second.size();
if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
[branch_id_size](const auto &one_stack) { return one_stack.second.size() != branch_id_size; })) {
return false;
}
}
return true;
}
void StackActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
if (input_parameter_data_num_ != 0) {
const auto &data_iter = input_parameter_data_.find(context->sequential_num_);
if (data_iter == input_parameter_data_.end()) {
MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
}
for (const auto &one_stack : data_iter->second) {
if (one_stack.first >= input_parameter_data_num_) {
MS_LOG(ERROR) << "Invalid input index:" << one_stack.first << " need:" << input_parameter_data_num_
<< " for actor:" << GetAID();
}
input_device_tensors_[one_stack.first] = one_stack.second.top();
}
}
ControlActor::FetchInput(context);
}
void StackActor::EraseInput(const OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
ControlActor::EraseInput(context);
if (input_parameter_data_num_ != 0) {
const auto &data_iter = input_parameter_data_.find(context->sequential_num_);
if (data_iter == input_parameter_data_.end()) {
MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
}
for (auto &one_stack : data_iter->second) {
if (one_stack.second.empty()) {
MS_LOG(ERROR) << "Input index:" << one_stack.first << " is null in actor:" << GetAID();
}
one_stack.second.pop();
}
}
}
} // namespace runtime
} // namespace mindspore

View File

@ -38,6 +38,11 @@ class StackActor : public ControlActor {
StackActor(const std::string &name, const std::vector<KernelWithIndex> &parameters);
~StackActor() override = default;
void Init() override;
// The input data of the stack actor needs to be pushed into the stack according to the input index, so it is
// implemented separately.
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
protected:
void FetchInput(OpContext<DeviceTensor> *const context) override;
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
@ -46,6 +51,12 @@ class StackActor : public ControlActor {
private:
friend class ControlNodeScheduler;
// The input data records that the stack actor is copied from the input nodes and needs to be stored in the
// device tensor in the stack.
std::unordered_map<int, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_parameter_data_;
// Input parameter data num represents the number of actor's input come from funcgraph itself, these inputs
// will be ranked at the front of input.
size_t input_parameter_data_num_{0};
// The backend parameter is used to save the backend node corresponding to the device tensor in the stack.
// When these device tensors are used as output, they need to be placed in the node of the result arrow,
// so these nodes need to be saved.

View File

@ -347,6 +347,60 @@ void DataPrepareActor::PrepareDataForValueNodeTensor(const ValueNodePtr &node, c
}
}
void DataPrepareActor::PrepareDataForControlValueNode(const KernelWithIndex &node_with_index,
const DeviceContext *device_context,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(node_with_index.first);
if (!node_with_index.first->isa<ValueNode>()) {
return;
}
const auto &node = node_with_index.first->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(node);
const auto &node_value = node->value();
MS_EXCEPTION_IF_NULL(node_value);
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
std::vector<TensorPtr> tensors;
// Fetch all of tensors in value node.
TensorValueToTensor(node_value, &tensors);
for (size_t i = 0; i < tensors.size(); i++) {
const auto &tensor = tensors[i];
if (tensor == nullptr) {
MS_LOG(WARNING) << "Tensor is null";
return;
}
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i, false);
MS_EXCEPTION_IF_NULL(device_tensor);
if (device_tensor->GetPtr() != nullptr) {
return;
}
MS_LOG(INFO) << "Prepare device data for control value node: " << node->DebugString() << ", output index: " << i;
tensor->set_device_address(device_tensor);
UpdateRefCount(device_tensor.get(), true);
if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy_, *context, *device_context, node->fullname_with_scope(),
device_tensor->GetSize());
}
auto host_tensor_size = LongToSize(tensor->data().nbytes());
auto host_tensor_type = tensor->data_type();
auto shape = tensor->shape();
if (!device_tensor->SyncHostToDevice(shape, host_tensor_size, host_tensor_type, tensor->data_c(),
tensor->device_info().host_format_)) {
std::string error_info = "Sync host to device failed for node:" + node->DebugString();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
}
}
// Prepare the device data for persistent device tensor of value node.
void DataPrepareActor::PrepareDataForValueNode(const ValueNodePtr &node, const DeviceContext *device_context,
OpContext<DeviceTensor> *const context) {
@ -504,9 +558,8 @@ void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeP
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(control_node_parser);
for (const auto &value_node_with_context : control_node_parser->front_value_nodes()) {
if (AnfAlgo::OutputAddrExist(value_node_with_context.first, 0)) {
PrepareDataForValueNode(value_node_with_context.first->cast<ValueNodePtr>(), value_node_with_context.second,
context);
if (AnfAlgo::OutputAddrExist(value_node_with_context.first.first, 0)) {
PrepareDataForControlValueNode(value_node_with_context.first, value_node_with_context.second, context);
}
}
@ -517,7 +570,7 @@ void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeP
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters();
const auto &iter = front_to_backend_parameters.find(input_node);
const auto &iter = front_to_backend_parameters.find({input_node, 0});
if (iter == front_to_backend_parameters.end() || iter->second.empty()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:"
<< AnfAlgo::GetNodeDebugString(input_node);

View File

@ -97,6 +97,8 @@ class DataPrepareActor : public DebugAwareActor {
const DeviceContext *device_context,
const HostParameterToWeight &host_parameter_to_weights,
OpContext<DeviceTensor> *const context);
void PrepareDataForControlValueNode(const KernelWithIndex &node_with_index, const DeviceContext *device_context,
OpContext<DeviceTensor> *const context);
const GraphCompilerInfo *graph_compiler_info_;
GraphExecutionStrategy strategy_;

View File

@ -235,7 +235,24 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cons
for (size_t i = 0; i < host_tensors.size(); ++i) {
auto &host_tensor = host_tensors[i];
auto &device_tensor = device_tensors[i];
MS_EXCEPTION_IF_NULL(host_tensor);
if (host_tensor == nullptr) {
// In the control flow, the weight device tensor needs to be sent by the data source actor, and the input of
// the data prepare actor is host tensor, the device tensor of the weight cannot be obtained, the input will
// be empty, here to check whether it is weight, if it is, get the device tensor from the device tensor store.
if (IsPersistentDeviceTensor(data_nodes_[i])) {
MS_EXCEPTION_IF_NULL(device_contexts_[i]);
auto device_store_tensor =
DeviceTensorStore::GetInstance().Fetch(data_nodes_[i].get(), device_contexts_[i]->GetDeviceAddressType());
if (device_store_tensor == nullptr) {
std::string error_info = GetAID().Name() + " failed get device tensor for: " + data_nodes_[i]->DebugString();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
device_tensor = device_store_tensor;
continue;
} else {
MS_LOG(EXCEPTION) << "Invalid host tensor for index:" << i << " node:" << data_nodes_[i]->DebugString();
}
}
MS_EXCEPTION_IF_NULL(device_tensor);
auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(host_tensor->device_address());
// Sync data from host_tensor_device_address to device_tensor.

View File

@ -24,7 +24,10 @@ namespace {
// Get all the real parameters corresponding to node.
void FetchRealParameterByNode(const KernelWithIndex &node, std::set<KernelWithIndex> *real_parameters,
std::set<KernelWithIndex> *invalid_call_nodes) {
const auto &node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second);
auto node_with_index = node;
if (!node.first->isa<ValueNode>()) {
node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second);
}
if (node_with_index.first->isa<ValueNode>() || node_with_index.first->isa<Parameter>()) {
// If node is a valuenode or parameter, the real parameter is itself.
real_parameters->emplace(node_with_index);
@ -156,6 +159,20 @@ std::vector<KernelWithIndex> FetchAllOutputWithIndex(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
std::vector<KernelWithIndex> result;
if (node->isa<ValueNode>() && IsValueNode<ValueTuple>(node)) {
const auto &value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
const auto &value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
const auto &value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
const auto tuple_value = value_tuple->value();
for (size_t i = 0; i < tuple_value.size(); ++i) {
result.emplace_back(node, i);
}
return result;
}
const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
if (AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimMakeTuple)) {
const auto &cnode = node_with_index.first->cast<CNodePtr>();
@ -178,6 +195,73 @@ std::vector<KernelWithIndex> FetchAllOutputWithIndex(const AnfNodePtr &node) {
return result;
}
// Create a device tensor for the front node.
// Get the output format and select kernel build info from the backend node corresponding to the front node to
// create the device address.
void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index, const AnfNodePtr &backend_node,
const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(backend_node);
MS_EXCEPTION_IF_NULL(device_context);
const auto &front_node = front_node_with_index.first;
MS_EXCEPTION_IF_NULL(front_node);
const auto &node_value = front_node->cast<ValueNodePtr>()->value();
if ((!node_value->isa<tensor::Tensor>()) && (!node_value->isa<ValueTuple>())) {
return;
}
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(backend_node, 0);
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(backend_node, 0);
if (output_type_id == kTypeUnknown) {
output_type_id = AnfAlgo::GetOutputInferDataType(backend_node, 0);
}
if (front_node->kernel_info() == nullptr) {
front_node->set_kernel_info(std::make_shared<device::KernelInfo>());
}
// Get the select kernel build info.
auto kernel_info = static_cast<device::KernelInfo *>(backend_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->GetMutableSelectKernelBuildInfo();
MS_EXCEPTION_IF_NULL(build_info);
AnfAlgo::SetSelectKernelBuildInfo(build_info, front_node.get());
// Create device tensor.
std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0);
device::DeviceAddressPtr address =
device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
MS_EXCEPTION_IF_NULL(address);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address;
AnfAlgo::SetOutputAddr(address, front_node_with_index.second, front_node.get());
}
// Create a device tensor for front node.
// When the condition input of the switch and switchlayer or the output of a subgraph is a parameter or value node,
// there is no corresponding backend node for this parameter, so a device tensor needs to be created for it.
void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index, const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(device_context);
const auto &node = front_node_with_index.first;
MS_EXCEPTION_IF_NULL(device_context);
TypeId type_id = AnfAlgo::GetOutputInferDataType(node, 0);
if (node->kernel_info() == nullptr) {
auto kernel_info = std::make_shared<device::KernelInfo>();
std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
builder->SetOutputsFormat({kOpFormat_DEFAULT});
builder->SetOutputsDeviceType({type_id});
kernel_info->set_select_kernel_build_info(builder->Build());
node->set_kernel_info(kernel_info);
}
size_t size = AnfAlgo::GetOutputTensorMemSize(node, 0);
// Create device tensor.
device::DeviceAddressPtr address = device_context->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id);
MS_EXCEPTION_IF_NULL(address);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(node) << " addr:" << address;
AnfAlgo::SetOutputAddr(address, front_node_with_index.second, node.get());
}
} // namespace
bool HasAbstractRef(const AnfNodePtr &node) {
@ -209,7 +293,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
<< " device context num:" << device_contexts.size();
}
if (control_nodes.size() <= 1) {
if (control_nodes.size() <= 1 || device_contexts.empty()) {
return;
}
@ -235,7 +319,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
FetchCallInputKernelGraph(graphs, device_contexts);
FetchFrontValueNode();
FetchFrontValueNode(device_contexts[0]);
FetchFrontToBackendKernel(graphs, device_contexts);
@ -250,7 +334,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
if (!IsInited()) {
if ((!IsInited()) || (!node->isa<Parameter>())) {
return false;
}
@ -385,7 +469,8 @@ void ControlNodeParser::ParseDeviceContextForControlNode(const DeviceContext *de
// will be thrown.
if (call_device_contexts.empty() || call_device_contexts.size() <= output_node.second) {
MS_LOG(EXCEPTION) << "Cannot find device context for call node:" << output_node.first->DebugString()
<< " device contexts size:" << call_device_contexts.size();
<< " device contexts size:" << call_device_contexts.size()
<< " index:" << output_node.second;
}
return_device_contexts.emplace_back(call_device_contexts[output_node.second]);
} else if (output_node.first->isa<CNode>()) {
@ -444,7 +529,9 @@ KernelWithIndex ControlNodeParser::FetchBackendNodeByFrontNode(const KernelWithI
return {};
}
void ControlNodeParser::FetchFrontValueNode() {
void ControlNodeParser::FetchFrontValueNode(DeviceContext *default_context) {
MS_EXCEPTION_IF_NULL(default_context);
for (const auto &formal_to_real_parameter : formal_to_real_parameters_) {
for (const auto &real_parameter_with_index : formal_to_real_parameter.second) {
const auto &real_parameter = real_parameter_with_index.first;
@ -452,12 +539,30 @@ void ControlNodeParser::FetchFrontValueNode() {
continue;
}
const auto &iter = front_to_backend_parameters_.find({real_parameter, 0});
const auto &iter = front_to_backend_parameters_.find(real_parameter_with_index);
if (iter != front_to_backend_parameters_.end() && (!iter->second.empty())) {
front_value_nodes_.emplace(real_parameter, iter->second.begin()->second);
front_value_nodes_.emplace(real_parameter_with_index, iter->second.begin()->second);
CreateDeviceTensorForValueNode(real_parameter_with_index, iter->second.begin()->first,
iter->second.begin()->second);
} else {
front_value_nodes_.emplace(real_parameter_with_index, default_context);
CreateDeviceTensorForFrontNode(real_parameter_with_index, default_context);
}
}
}
// If the output of funcgraph is a value node, it will eventually be sent to the kernel as a real parameter.
// These the value nodes also need to create a device address.
for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
const auto &front_node = front_to_backend_parameters.first.first;
MS_EXCEPTION_IF_NULL(front_node);
if (front_node->isa<ValueNode>() && (!front_to_backend_parameters.second.empty())) {
const auto &backend_parameter = front_to_backend_parameters.second.begin()->first;
const auto &device_context = front_to_backend_parameters.second.begin()->second;
CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context);
front_value_nodes_.emplace(front_to_backend_parameters.first, device_context);
}
}
}
void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr> &control_nodes) {
@ -657,11 +762,11 @@ void ControlNodeParser::ParseFrontToBackendParameter(const std::vector<KernelGra
FetchRealParameterByNode(front_node_with_index, &real_parameters, &invalid_call_nodes);
for (const auto real_parameter : real_parameters) {
if (real_parameter.first->isa<Parameter>() || real_parameter.first->isa<ValueNode>()) {
front_to_backend_parameters_[real_parameter.first].emplace(parameter, device_context);
front_to_backend_parameters_[real_parameter].emplace(parameter, device_context);
}
}
} else {
front_to_backend_parameters_[front_node].emplace(parameter, device_context);
front_to_backend_parameters_[{front_node, 0}].emplace(parameter, device_context);
}
}
}
@ -671,12 +776,13 @@ void ControlNodeParser::ParseFrontToBackendParameter(const std::vector<KernelGra
for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
const auto &front_parameter = front_to_backend_parameters.first;
const auto &backend_parameters = front_to_backend_parameters.second;
const auto &iter = formal_to_real_parameters_.find(front_parameter);
const auto &iter = formal_to_real_parameters_.find(front_parameter.first);
if (iter != formal_to_real_parameters_.end()) {
for (const auto &real_parameter_with_index : iter->second) {
const auto &real_parameter = real_parameter_with_index.first;
if (real_parameter->isa<Parameter>()) {
front_to_backend_parameters_[real_parameter].insert(backend_parameters.begin(), backend_parameters.end());
front_to_backend_parameters_[real_parameter_with_index].insert(backend_parameters.begin(),
backend_parameters.end());
}
}
}

View File

@ -57,11 +57,11 @@ constexpr size_t kSingleControlNode = 1;
const char kEntranceActorNameSuffix[] = "_EntranceActor";
const char kStackActorNameSuffix[] = "_StackActor";
using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, DeviceContext *>>>;
using FrontToBackendNodeWithContext = std::map<KernelWithIndex, std::set<std::pair<AnfNodePtr, DeviceContext *>>>;
using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>;
using FuncGraphToKernelGraph = std::unordered_map<FuncGraphPtr, std::vector<KernelGraphPtr>>;
using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>>;
using NodeWithDeviceContext = std::set<std::pair<AnfNodePtr, DeviceContext *>>;
using NodeWithDeviceContext = std::set<std::pair<KernelWithIndex, DeviceContext *>>;
using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
using FormalToRealParameter = std::unordered_map<AnfNodePtr, std::set<KernelWithIndex>>;
using RealToFormalParameter = std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>>;
@ -114,7 +114,7 @@ class ControlNodeParser {
// value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for
// them separately during initialization.
// The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to.
void FetchFrontValueNode();
void FetchFrontValueNode(DeviceContext *default_context);
// Create branch id for all call node in the control flow.
void CreateBranchIDForCallNode(const std::vector<AnfNodePtr> &control_nodes);

View File

@ -69,6 +69,10 @@ std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
<< " for make tuple node:" << make_tuple_cnode->DebugString();
}
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(make_tuple_inputs[j + kMakeTupleInputStartPos], 0));
} else if (node_with_index.first->isa<ValueNode>()) {
// When the value node is a value tuple, the value node will have multiple outputs, which need to be directly
// put into the vector, and the output cannot be obtained through the VisitKernelWithReturnType interface.
results.emplace_back(node_with_index.first, j);
} else {
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(node_with_index.first, j));
}
@ -202,6 +206,11 @@ std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompil
for (const auto func_graph_to_kernel_graphs : parser->func_graph_to_kernel_graphs_) {
for (const auto &kernel_graph : func_graph_to_kernel_graphs.second) {
MS_EXCEPTION_IF_NULL(kernel_graph);
// If the graph does not have kernel, it means there is no internal calculation in it, the output is parameter,
// and no exit actor is needed.
if (kernel_graph->execution_order().empty()) {
continue;
}
std::vector<KernelWithIndex> formal_parameters;
const auto &graph_outputs = kernel_graph->graph_output_map();
std::vector<const DeviceContext *> device_contexts;
@ -232,6 +241,10 @@ std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompil
const auto &actor_name = kernel_graph->ToString();
const auto &exit_actor = std::make_shared<ExitActor>(actor_name, formal_parameters, nullptr);
exit_actors.emplace_back(exit_actor);
if (exit_actor->device_contexts_.size() != device_contexts.size()) {
MS_LOG(EXCEPTION) << "Invalid device context size:" << device_contexts.size()
<< " need:" << exit_actor->device_contexts_.size() << " for actor:" << exit_actor->GetAID();
}
exit_actor->device_contexts_.swap(device_contexts);
InsertActor(exit_actor.get());
}
@ -251,13 +264,23 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
const auto &device_context = graph_with_context.second;
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(device_context);
// If the graph does not have kernel, it means there is no internal calculation in it, the output is parameter,
// and no stack actor is needed.
if (graph->execution_order().empty()) {
continue;
}
const auto &real_parameters = graph->input_nodes();
// Collect inputs of stack actor.
for (const auto &parameter : real_parameters) {
const auto &front_node_with_index = GetFrontNodeByKernelGraph(parameter, graph);
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
formal_parameters.emplace_back(front_node_with_index);
// If the input comes from inside funcgraph, put it at the front of the vector, otherwise put it at the end.
if (AnfAlgo::IsCallNode(front_node_with_index.first)) {
formal_parameters.emplace_back(front_node_with_index);
} else {
formal_parameters.insert(formal_parameters.begin(), front_node_with_index);
}
}
const auto &actor_name = graph->ToString() + kStackActorNameSuffix;
@ -347,7 +370,13 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
to_actor->local_partials_[to_node_with_index.second] = OpPartial(func_graph.get(), {});
} else {
// Link device store value node.
to_actor->device_tensor_store_keys_.emplace_back(to_node_with_index.second, from_node.get());
if (!AnfAlgo::OutputAddrExist(from_node, from_node_with_index.second)) {
MS_LOG(EXCEPTION) << "Invalid output address index:" << from_node_with_index.second
<< " for value node:" << from_node->DebugString();
}
to_actor->local_device_tensors_[to_node_with_index.second] =
AnfAlgo::GetMutableOutputAddr(from_node, from_node_with_index.second, false).get();
to_actor->local_device_tensors_[to_node_with_index.second]->SetNodeIndex(from_node, from_node_with_index.second);
}
} else if (from_node->isa<Parameter>()) {
// Link arrow from entrance actor.
@ -451,8 +480,6 @@ void ControlNodeScheduler::LinkArrowByKernel(const AnfNodePtr &kernel, ControlAc
auto device_tensor = AnfAlgo::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, false);
UpdateRefCount(device_tensor.get(), true);
device_tensor->SetNodeIndex(kernel_with_index.first, kernel_with_index.second);
kernel_actor->output_data_nodes_.emplace_back(kernel_with_index.first);
LinkDataArrow(kernel_actor, to_actor, kernel_with_index.second, to_node_with_index.second);
} else {
// Link arrow from exit actor.
@ -535,7 +562,25 @@ void ControlNodeScheduler::LinkBranchIDArrowForControlActor(ControlActorSet *con
MS_EXCEPTION_IF_NULL(actor);
auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
MS_EXCEPTION_IF_NULL(entrance_actor);
entrance_actor->output_branch_id_arrows_.emplace_back(exit_actor->GetAID());
LinkBranchIDArrow(entrance_actor, exit_actor.get());
}
// Connect the branch id arrows from the entrance actor to the stack actor.
for (auto stack_actor : control_actor_set->stack_actors_) {
MS_EXCEPTION_IF_NULL(stack_actor);
if (stack_actor->formal_parameters_.empty()) {
MS_LOG(ERROR) << "Invalid stack actor:" << stack_actor->GetAID();
}
const auto &node = stack_actor->formal_parameters_.back().first;
MS_EXCEPTION_IF_NULL(node);
const auto &func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
MS_EXCEPTION_IF_NULL(entrance_actor);
LinkBranchIDArrow(entrance_actor, stack_actor.get());
}
}
@ -587,7 +632,7 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap
// If there is a call node in the input of the graph, the parameter of the graph needs to be sent by the
// corresponding stack actor, otherwise it is sent by the entrance actor.
if (is_call_input_graph) {
auto actor = FetchActor(graph->ToString());
auto actor = FetchActor(graph->ToString() + kStackActorNameSuffix);
MS_EXCEPTION_IF_NULL(actor);
from_actor = dynamic_cast<ControlActor *>(actor);
MS_EXCEPTION_IF_NULL(from_actor);
@ -623,7 +668,8 @@ void ControlNodeScheduler::LinkDataArrowForOutputActor(ActorSet *const actor_set
auto exit_actor = dynamic_cast<ExitActor *>(actor);
MS_EXCEPTION_IF_NULL(exit_actor);
for (size_t i = 0; i < exit_actor->formal_parameters_.size(); ++i) {
LinkDataArrow(exit_actor, to_actor.get(), i, i);
LinkDataArrowForExitActor(exit_actor, to_actor.get(), i, i, 0);
to_actor->input_datas_num_++;
}
auto control_node_to_device_contexts = parser->control_node_to_device_contexts_;
@ -659,10 +705,7 @@ void ControlNodeScheduler::LinkDataArrowForHostDSActor(const GraphCompilerInfo &
const auto &iter = host_ds_actor->data_node_position_map_.find(formal_parameter.first);
if (iter != host_ds_actor->data_node_position_map_.end()) {
const auto &parameter = host_ds_actor->data_nodes()[iter->second];
LinkDataArrow(host_ds_actor, to_actor, iter->second, i);
// Set the source node to the device address.
host_ds_actor->output_data_nodes_.emplace_back(parameter);
LinkDataArrow(host_ds_actor, to_actor, iter->second, i, parameter);
if (!AnfAlgo::OutputAddrExist(parameter, 0, false)) {
MS_LOG(EXCEPTION) << "Invalid output index:" << 0 << " for parameter:" << parameter->DebugString();
}
@ -674,12 +717,13 @@ void ControlNodeScheduler::LinkDataArrowForHostDSActor(const GraphCompilerInfo &
}
// Link a data arrow from `from_actor`'s output slot `from_index` to `to_actor`'s input slot `to_index`.
// `from_kernel` is the backend node that produces the data (may be nullptr when the output has no
// backing kernel, per the default in the declaration); it is recorded in output_data_nodes_ so the
// arrow index and the node index stay in one-to-one correspondence (DumpAbstractActor relies on
// the two vectors having equal size).
void ControlNodeScheduler::LinkDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor,
                                         size_t from_index, size_t to_index, const AnfNodePtr &from_kernel) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  auto data_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
  (void)from_actor->output_data_arrows_.emplace_back(data_arrow);
  // Keep output_data_nodes_ aligned with output_data_arrows_ even when from_kernel is nullptr.
  (void)from_actor->output_data_nodes_.emplace_back(from_kernel);
  to_actor->input_datas_num_++;
  (void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());
}
@ -692,7 +736,7 @@ void ControlNodeScheduler::LinkControlArrow(AbstractActor *from_actor, AbstractA
(void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID());
}
void ControlNodeScheduler::LinkDataArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor,
void ControlNodeScheduler::LinkDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor,
size_t from_index, size_t to_index, int branch_id) {
MS_EXCEPTION_IF_NULL(exit_actor);
MS_EXCEPTION_IF_NULL(to_actor);
@ -726,6 +770,14 @@ void ControlNodeScheduler::LinkPartialArrow(ControlActor *const from_actor, Cont
auto op_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
from_actor->output_partial_arrows_.emplace_back(op_arrow);
to_actor->input_partials_num_++;
to_actor->input_partial_arrow_aids_.emplace_back(from_actor->GetAID());
}
// Wire a branch id arrow between two control actors: the sender remembers which
// actor to notify, and the receiver remembers which actor the branch id comes from.
void ControlNodeScheduler::LinkBranchIDArrow(ControlActor *const from_actor, ControlActor *const to_actor) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  // Register the receiver on the sender side, then the sender on the receiver side.
  (void)from_actor->output_branch_id_arrows_.push_back(to_actor->GetAID());
  (void)to_actor->input_branch_id_arrow_aids_.push_back(from_actor->GetAID());
}
bool ControlNodeScheduler::CheckActorValid(const ControlActorSetPtr &control_actor_set) {

View File

@ -78,12 +78,13 @@ class ControlNodeScheduler {
const FuncGraphPtr &func_graph);
void LinkPartialArrow(ControlActor *const from_actor, ControlActor *const to_actor, size_t from_index,
size_t to_index);
void LinkDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor, size_t from_index,
size_t to_index);
void LinkDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor, size_t from_index, size_t to_index,
const AnfNodePtr &from_kernel = nullptr);
void LinkBranchIDArrow(ControlActor *const from_actor, ControlActor *const to_actor);
// Since the output of exit actor has branches, it needs to be based on a dedicated interface.
void LinkControlArrowForExitActor(ExitActor *from_actor, AbstractActor *to_actor, int branch_id);
void LinkDataArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor, size_t from_index,
void LinkDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor, size_t from_index,
size_t to_index, int branch_id);
};
} // namespace runtime

View File

@ -625,10 +625,7 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
// the corresponding backend parameter from the map, and insert it into the host data source actor
const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
for (const auto &parameter : control_node_parameters) {
if (IsPersistentDeviceTensor(parameter)) {
continue;
}
auto backend_iter = front_to_backend_parameter.find(parameter);
auto backend_iter = front_to_backend_parameter.find({parameter, 0});
if (backend_iter == front_to_backend_parameter.end() || backend_iter->second.empty()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter);
}
@ -1730,6 +1727,9 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
MS_EXCEPTION_IF_NULL(device_tensor);
if (IsPersistentDeviceTensor(input_node)) {
AddDeviceTensorStore(front_node.get(), device_tensor);
// In the control flow, the device tensor of the weight needs to be obtained according to the backend node,
// so insert the relationship between the backend node and the device tensor.
AddDeviceTensorStore(input_node.get(), device_tensor);
}
// Share the weight in the host and device, then input_node is internal parameter and front_node is weight.
@ -1743,17 +1743,12 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
auto other_type_device_tensor = device_context->CreateDeviceAddress(
nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
AddDeviceTensorStore(front_node.get(), other_type_device_tensor);
// In the control flow, the device tensor of the weight needs to be obtained according to the backend node,
// so insert the relationship between the backend node and the device tensor.
AddDeviceTensorStore(input_node.get(), other_type_device_tensor);
}
}
}
// In control flow, there may be some value nodes that is not in the kernel graph and needs to be placed
// in the tensor store separately.
for (const auto &value_node : graph_compiler_info.control_node_parser_->front_value_nodes_) {
MS_EXCEPTION_IF_NULL(value_node.first);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node.first, 0, false);
AddDeviceTensorStore(value_node.first.get(), device_tensor);
}
}
void GraphScheduler::FetchKernelTransformTypeAndName(const AnfNodePtr &node, const KernelGraphPtr &graph,