forked from mindspore-Ecosystem/mindspore
!26212 Control flow: support non-tail call.
Merge pull request !26212 from gaoyong10/runtime_second8
commit 390b3c2efa
@@ -64,7 +64,9 @@ void DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) {
}

if (actor->output_data_arrows().size() != actor->output_data_nodes().size()) {
MS_LOG(EXCEPTION) << "The size of output data arrows is not equal to the output nodes.";
MS_LOG(EXCEPTION) << "The size of output data arrows is not equal to the output nodes, arrow num:"
<< actor->output_data_arrows().size() << " node num:" << actor->output_data_nodes().size()
<< " for actor:" << actor->GetAID();
}
if (actor->output_data_arrows().size() > 0) {
ofs << "\t\toutput_data_arrows:" << actor->output_data_arrows().size() << "\n ";
@@ -178,22 +180,30 @@ void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) {
}

void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) {
const auto &output_data_arrows = actor->output_data_arrows();
if (output_data_arrows.size() > 0) {
ofs << "\t\t\toutput_data_arrows:" << output_data_arrows.size() << "\n ";
for (const auto &data_arrow : output_data_arrows) {
MS_EXCEPTION_IF_NULL(data_arrow);
ofs << "\t\t\t\tfrom_output_index:" << data_arrow->from_output_index_
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
<< "\n";
MS_EXCEPTION_IF_NULL(actor);
DumpAbstractActor(actor, ofs);
const auto &local_partials = actor->local_partials();
if (local_partials.size() > 0) {
ofs << "\t\t\tlocal partial num:" << local_partials.size() << "\n ";
for (const auto &local_partial : local_partials) {
MS_EXCEPTION_IF_NULL(local_partial.second.first);
ofs << "\t\t\t\tlocal partial index:" << local_partial.first
<< "\tgraph:" << local_partial.second.first->ToString()
<< "\tparameter num:" << local_partial.second.second.size() << "\n";
}
}

const auto &output_control_arrows = actor->output_control_arrows();
if (output_control_arrows.size() > 0) {
ofs << "\t\t\toutput_control_arrows:" << output_control_arrows.size() << "\n ";
for (const auto &aid : output_control_arrows) {
ofs << "\t\t\t\tto_actor_name:" << aid.Name() << "\n";
if (actor->input_partial_arrow_aids().size() > 0) {
ofs << "\t\tinput_partial_arrow_actor:" << actor->input_partial_arrow_aids().size() << "\n ";
for (const auto &input_partial_arrow_aid : actor->input_partial_arrow_aids()) {
ofs << "\t\t\tfrom_actor_name:" << input_partial_arrow_aid.Name() << "\n";
}
}

if (actor->input_branch_id_arrow_aids().size() > 0) {
ofs << "\t\tinput_branch_id_arrow_actor:" << actor->input_branch_id_arrow_aids().size() << "\n ";
for (const auto &input_branch_id_arrow_aid : actor->input_branch_id_arrow_aids()) {
ofs << "\t\t\tfrom_actor_name:" << input_branch_id_arrow_aid.Name() << "\n";
}
}
@@ -219,13 +229,13 @@ void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) {

void DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\t\ttactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
DumpControlActor(actor, ofs);
}

void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
DumpControlActor(actor, ofs);

const auto &output_data_with_branch_id_arrows = actor->output_data_with_branch_id_arrows();
@@ -242,13 +252,13 @@ void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) {

void DumpEntranceActor(const EntranceActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
DumpControlActor(actor, ofs);
}

void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
DumpControlActor(actor, ofs);

const auto &output_branch_data_arrows = actor->output_branch_data_arrows();
@@ -291,40 +301,40 @@ void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) {

void DumpStackActor(const StackActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
DumpControlActor(actor, ofs);
}

void DumpSwitchActors(const std::vector<SwitchActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Switch actors:" << actors.size() << "]\n";
ofs << "\n\n[Switch actors:" << actors.size() << "]\n";
for (const auto &switch_actor : actors) {
DumpSwitchActor(switch_actor.get(), ofs);
}
}

void DumpGatherActors(const std::vector<GatherActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Gather actors:" << actors.size() << "]\n";
ofs << "\n\n[Gather actors:" << actors.size() << "]\n";
for (const auto &gather_actor : actors) {
DumpGatherActor(gather_actor.get(), ofs);
}
}

void DumpEntranceActors(const std::vector<EntranceActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Entrance actors:" << actors.size() << "]\n";
ofs << "\n\n[Entrance actors:" << actors.size() << "]\n";
for (const auto &entrance_actor : actors) {
DumpEntranceActor(entrance_actor.get(), ofs);
}
}

void DumpExitActors(const std::vector<ExitActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Exit actors:" << actors.size() << "]\n";
ofs << "\n\n[Exit actors:" << actors.size() << "]\n";
for (const auto &exit_actor : actors) {
DumpExitActor(exit_actor.get(), ofs);
}
}

void DumpStackActors(const std::vector<StackActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Stack actors:" << actors.size() << "]\n";
ofs << "\n\n[Stack actors:" << actors.size() << "]\n";
for (const auto &stack_actor : actors) {
DumpStackActor(stack_actor.get(), ofs);
}
@@ -42,7 +42,8 @@ void ControlActor::Init() {
size_t ControlActor::FetchNodePosition(const KernelWithIndex &node) const {
const auto &iter = find(formal_parameters_.begin(), formal_parameters_.end(), node);
if (iter == formal_parameters_.end()) {
MS_LOG(EXCEPTION) << "Invalid formal parameter:" << node.first->DebugString() << " for actor:" << GetAID();
MS_LOG(EXCEPTION) << "Invalid formal parameter:" << node.first->DebugString() << " index:" << node.second
<< " for actor:" << GetAID();
}
return iter - formal_parameters_.begin();
}
@@ -112,18 +113,13 @@ void ControlActor::FetchInput(OpContext<DeviceTensor> *const context) {
}

// Fetch input device tensor from device store.
for (auto &device_tensor_store_key : device_tensor_store_keys_) {
auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
device_contexts_[0]->GetDeviceAddressType());
if (device_tensor == nullptr) {
MS_LOG(ERROR) << GetAID() << " get device tensor store failed: " << device_tensor_store_key.second->DebugString();
for (auto &local_device_tensor : local_device_tensors_) {
MS_EXCEPTION_IF_NULL(local_device_tensor.second);
if (local_device_tensor.first >= input_device_tensors_.size()) {
MS_LOG(ERROR) << "Invalid local index:" << local_device_tensor.first
<< " current:" << local_device_tensors_.size() << " for actor:" << GetAID();
}

if (device_tensor_store_key.first >= input_device_tensors_.size()) {
MS_LOG(ERROR) << "The input index is out of range, need:" << device_tensor_store_key.first
<< " current:" << input_device_tensors_.size() << " for actor:" << GetAID();
}
input_device_tensors_[device_tensor_store_key.first] = device_tensor;
input_device_tensors_[local_device_tensor.first] = local_device_tensor.second;
}

for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
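Note (not part of the commit): FetchNodePosition is a linear search over the formal parameters; the change above only enriches its error message with the output index. A minimal standalone sketch of that lookup, using a simplified (node name, output index) pair in place of MindSpore's KernelWithIndex:

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for KernelWithIndex (AnfNodePtr + output index).
using NodeWithIndex = std::pair<std::string, size_t>;

// Returns the position of `node` in the formal parameter list, as FetchNodePosition does.
size_t FetchNodePosition(const std::vector<NodeWithIndex> &formal_parameters, const NodeWithIndex &node) {
  const auto iter = std::find(formal_parameters.begin(), formal_parameters.end(), node);
  if (iter == formal_parameters.end()) {
    // The commit extends this message with the output index so the failing input is easier to locate.
    throw std::runtime_error("Invalid formal parameter:" + node.first + " index:" + std::to_string(node.second));
  }
  return static_cast<size_t>(iter - formal_parameters.begin());
}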
@@ -54,6 +54,9 @@ class ControlActor : public AbstractActor {

const std::vector<DataArrowPtr> &output_partial_arrows() const { return output_partial_arrows_; }
const std::vector<AID> &output_branch_id_arrows() const { return output_branch_id_arrows_; }
const std::unordered_map<size_t, OpPartial> &local_partials() const { return local_partials_; }
const std::vector<AID> &input_partial_arrow_aids() const { return input_partial_arrow_aids_; }
const std::vector<AID> &input_branch_id_arrow_aids() const { return input_branch_id_arrow_aids_; }

protected:
friend class ControlNodeScheduler;
@@ -87,6 +90,10 @@ class ControlActor : public AbstractActor {
// Input num.
size_t input_partials_num_{0};

// The dependent input actors.
std::vector<AID> input_partial_arrow_aids_;
std::vector<AID> input_branch_id_arrow_aids_;

// Output Arrows.
std::vector<DataArrowPtr> output_partial_arrows_;
OpPartial output_partial_;
@@ -99,6 +106,8 @@ class ControlActor : public AbstractActor {

// Partial data in local. When partial is only funcgraph without real parameter, it is stored inside the actor.
std::unordered_map<size_t, OpPartial> local_partials_;
// Device tensor in control node, but not in kernel graph.
std::unordered_map<size_t, DeviceTensor *> local_device_tensors_;

// Cache output data by output index to modify the output data effectively.
std::vector<std::vector<OpData<DeviceTensor> *>> output_data_by_output_index_;
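Note (illustrative sketch, simplified types, not the MindSpore classes): the new local_device_tensors_ member holds inputs that never travel through data arrows, keyed by the input index they occupy. Before launch they are merged into the actor's input vector, with a bounds check on the recorded index, mirroring the new FetchInput loop above:

#include <cstddef>
#include <unordered_map>
#include <vector>

struct DeviceTensor {};  // placeholder for the real device address type

struct ControlActorSketch {
  std::unordered_map<size_t, DeviceTensor *> local_device_tensors_;  // filled at link time
  std::vector<DeviceTensor *> input_device_tensors_;                 // filled per step

  // Copy every locally stored tensor into its input slot; a bad index is reported as an error
  // in the real code (with the actor id) rather than silently skipped.
  bool FillLocalInputs() {
    for (const auto &local : local_device_tensors_) {
      if (local.first >= input_device_tensors_.size()) {
        return false;
      }
      input_device_tensors_[local.first] = local.second;
    }
    return true;
  }
};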
@@ -85,10 +85,11 @@ void ExitActor::CopyDeviceAddress() {
MS_EXCEPTION_IF_NULL(input_device_tensor);
const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex();
MS_EXCEPTION_IF_NULL(node_with_index.first);
if (!node_with_index.first->isa<CNode>()) {
if (device_contexts_[i] == nullptr) {
// If device context is empty, it means that the input is from a parameter, need not to copy a new device tensor.
new_device_tensors.emplace_back(input_device_tensor);
continue;
}

MS_EXCEPTION_IF_NULL(device_contexts_[i]);
auto new_device_tensor =
device_contexts_[i]->CreateDeviceAddress(nullptr, input_device_tensors_[i]->GetSize(),
@@ -21,7 +21,9 @@ namespace mindspore {
namespace runtime {
GatherActor::GatherActor(const std::string &name, const std::vector<KernelWithIndex> &parameters,
const AnfNodePtr &node)
: ControlActor(name, KernelTransformType::kGatherActor, parameters, node) {}
: ControlActor(name, KernelTransformType::kGatherActor, parameters, node) {
device_contexts_.resize(parameters.size());
}

void GatherActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
@@ -25,13 +25,96 @@ StackActor::StackActor(const std::string &name, const std::vector<KernelWithInde
input_device_tensors_.resize(parameters.size());
}

bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
return false;
void StackActor::Init() {
ControlActor::Init();
for (const auto &formal_parameter : formal_parameters_) {
if (AnfAlgo::IsCallNode(formal_parameter.first)) {
break;
}
++input_parameter_data_num_;
}
input_datas_num_ = formal_parameters_.size() - input_parameter_data_num_;
}

void StackActor::FetchInput(OpContext<DeviceTensor> *const context) { MS_EXCEPTION_IF_NULL(context); }
void StackActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(input_data->data_);
auto &sequential_num = context->sequential_num_;
// The parameters from the inside of the subgraph need to be put into the stack.
if (IntToSize(input_data->index_) < input_parameter_data_num_) {
input_parameter_data_[sequential_num][input_data->index_].push(input_data->data_);
} else {
// The outputs of call nodes are placed directly in the input data.
input_op_datas_[sequential_num].emplace_back(input_data);
}
if (CheckRunningCondition(context)) {
Run(context);
}
}

void StackActor::EraseInput(const OpContext<DeviceTensor> *const context) { MS_EXCEPTION_IF_NULL(context); }
bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
if (!ControlActor::CheckRunningCondition(context)) {
return false;
}

if (input_parameter_data_num_ != 0) {
const auto &data_iter = input_parameter_data_.find(context->sequential_num_);
if (data_iter == input_parameter_data_.end()) {
return false;
}
if (data_iter->second.size() != input_parameter_data_num_) {
return false;
}

auto iter = input_branch_ids_.find(context->sequential_num_);
if (iter == input_branch_ids_.end() || iter->second.empty()) {
MS_LOG(ERROR) << "There is no branch id for actor:" << GetAID();
}
size_t branch_id_size = iter->second.size();
if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
[branch_id_size](const auto &one_stack) { return one_stack.second.size() != branch_id_size; })) {
return false;
}
}
return true;
}

void StackActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
if (input_parameter_data_num_ != 0) {
const auto &data_iter = input_parameter_data_.find(context->sequential_num_);
if (data_iter == input_parameter_data_.end()) {
MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
}
for (const auto &one_stack : data_iter->second) {
if (one_stack.first >= input_parameter_data_num_) {
MS_LOG(ERROR) << "Invalid input index:" << one_stack.first << " need:" << input_parameter_data_num_
<< " for actor:" << GetAID();
}
input_device_tensors_[one_stack.first] = one_stack.second.top();
}
}
ControlActor::FetchInput(context);
}

void StackActor::EraseInput(const OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
ControlActor::EraseInput(context);

if (input_parameter_data_num_ != 0) {
const auto &data_iter = input_parameter_data_.find(context->sequential_num_);
if (data_iter == input_parameter_data_.end()) {
MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
}

for (auto &one_stack : data_iter->second) {
if (one_stack.second.empty()) {
MS_LOG(ERROR) << "Input index:" << one_stack.first << " is null in actor:" << GetAID();
}
one_stack.second.pop();
}
}
}
} // namespace runtime
} // namespace mindspore
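Note (self-contained sketch with simplified types, not the actor framework API): the new StackActor body above is the heart of the non-tail-call support. Inputs that belong to the funcgraph itself are pushed onto per-index stacks keyed by the step's sequential number; the actor only runs once every stack is as deep as the number of recorded branch ids, and EraseInput pops one level when a call returns:

#include <cstddef>
#include <stack>
#include <unordered_map>
#include <vector>

struct DeviceTensor { int value{0}; };  // placeholder

class StackActorSketch {
 public:
  explicit StackActorSketch(size_t parameter_num) : input_parameter_data_num_(parameter_num) {}

  // Inputs with index < input_parameter_data_num_ come from the funcgraph itself and are stacked;
  // outputs of call nodes stay as plain inputs for the current step.
  void RunOpData(int sequential_num, size_t index, DeviceTensor *data) {
    if (index < input_parameter_data_num_) {
      input_parameter_data_[sequential_num][index].push(data);
    } else {
      plain_inputs_[sequential_num].push_back(data);
    }
  }

  // Ready only when every stacked input has exactly one entry per recorded branch id.
  bool CheckRunningCondition(int sequential_num, size_t branch_id_num) const {
    const auto iter = input_parameter_data_.find(sequential_num);
    if (iter == input_parameter_data_.end() || iter->second.size() != input_parameter_data_num_) {
      return false;
    }
    for (const auto &one_stack : iter->second) {
      if (one_stack.second.size() != branch_id_num) {
        return false;
      }
    }
    return true;
  }

  // Pop one level from every stack when the current call returns.
  void EraseInput(int sequential_num) {
    auto iter = input_parameter_data_.find(sequential_num);
    if (iter == input_parameter_data_.end()) {
      return;
    }
    for (auto &one_stack : iter->second) {
      if (!one_stack.second.empty()) {
        one_stack.second.pop();
      }
    }
  }

 private:
  size_t input_parameter_data_num_;
  std::unordered_map<int, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_parameter_data_;
  std::unordered_map<int, std::vector<DeviceTensor *>> plain_inputs_;
};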
@@ -38,6 +38,11 @@ class StackActor : public ControlActor {
StackActor(const std::string &name, const std::vector<KernelWithIndex> &parameters);
~StackActor() override = default;

void Init() override;
// The input data of the stack actor needs to be pushed into the stack according to the input index, so it is
// implemented separately.
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;

protected:
void FetchInput(OpContext<DeviceTensor> *const context) override;
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
@@ -46,6 +51,12 @@ class StackActor : public ControlActor {
private:
friend class ControlNodeScheduler;

// The input data records that the stack actor is copied from the input nodes and needs to be stored in the
// device tensor in the stack.
std::unordered_map<int, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_parameter_data_;
// Input parameter data num represents the number of actor's input come from funcgraph itself, these inputs
// will be ranked at the front of input.
size_t input_parameter_data_num_{0};
// The backend parameter is used to save the backend node corresponding to the device tensor in the stack.
// When these device tensors are used as output, they need to be placed in the node of the result arrow,
// so these nodes need to be saved.
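Note (hedged sketch, hypothetical predicate in place of AnfAlgo::IsCallNode): the header above splits a stack actor's formal parameters into two groups, parameters of the funcgraph itself (ranked at the front) and outputs of call nodes. Init() derives input_parameter_data_num_ by counting the leading non-call parameters:

#include <cstddef>
#include <vector>

struct FormalParameter {
  bool is_call_node{false};  // stand-in for AnfAlgo::IsCallNode(parameter.first)
};

// Count the leading parameters that come from the funcgraph itself; everything from the
// first call node onward is treated as ordinary input data (input_datas_num_).
size_t CountInputParameterData(const std::vector<FormalParameter> &formal_parameters) {
  size_t input_parameter_data_num = 0;
  for (const auto &formal_parameter : formal_parameters) {
    if (formal_parameter.is_call_node) {
      break;
    }
    ++input_parameter_data_num;
  }
  return input_parameter_data_num;
}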
@@ -347,6 +347,60 @@ void DataPrepareActor::PrepareDataForValueNodeTensor(const ValueNodePtr &node, c
}
}

void DataPrepareActor::PrepareDataForControlValueNode(const KernelWithIndex &node_with_index,
const DeviceContext *device_context,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(node_with_index.first);
if (!node_with_index.first->isa<ValueNode>()) {
return;
}

const auto &node = node_with_index.first->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(node);
const auto &node_value = node->value();
MS_EXCEPTION_IF_NULL(node_value);

if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
std::vector<TensorPtr> tensors;
// Fetch all of tensors in value node.
TensorValueToTensor(node_value, &tensors);

for (size_t i = 0; i < tensors.size(); i++) {
const auto &tensor = tensors[i];
if (tensor == nullptr) {
MS_LOG(WARNING) << "Tensor is null";
return;
}

const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i, false);
MS_EXCEPTION_IF_NULL(device_tensor);
if (device_tensor->GetPtr() != nullptr) {
return;
}

MS_LOG(INFO) << "Prepare device data for control value node: " << node->DebugString() << ", output index: " << i;
tensor->set_device_address(device_tensor);
UpdateRefCount(device_tensor.get(), true);

if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy_, *context, *device_context, node->fullname_with_scope(),
device_tensor->GetSize());
}

auto host_tensor_size = LongToSize(tensor->data().nbytes());
auto host_tensor_type = tensor->data_type();
auto shape = tensor->shape();
if (!device_tensor->SyncHostToDevice(shape, host_tensor_size, host_tensor_type, tensor->data_c(),
tensor->device_info().host_format_)) {
std::string error_info = "Sync host to device failed for node:" + node->DebugString();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
}
}

// Prepare the device data for persistent device tensor of value node.
void DataPrepareActor::PrepareDataForValueNode(const ValueNodePtr &node, const DeviceContext *device_context,
OpContext<DeviceTensor> *const context) {
@@ -504,9 +558,8 @@ void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeP
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(control_node_parser);
for (const auto &value_node_with_context : control_node_parser->front_value_nodes()) {
if (AnfAlgo::OutputAddrExist(value_node_with_context.first, 0)) {
PrepareDataForValueNode(value_node_with_context.first->cast<ValueNodePtr>(), value_node_with_context.second,
context);
if (AnfAlgo::OutputAddrExist(value_node_with_context.first.first, 0)) {
PrepareDataForControlValueNode(value_node_with_context.first, value_node_with_context.second, context);
}
}
@@ -517,7 +570,7 @@ void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeP
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters();
const auto &iter = front_to_backend_parameters.find(input_node);
const auto &iter = front_to_backend_parameters.find({input_node, 0});
if (iter == front_to_backend_parameters.end() || iter->second.empty()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:"
<< AnfAlgo::GetNodeDebugString(input_node);
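Note (hedged sketch; DeviceAddressSketch and its Allocate/SyncHostToDevice methods are made up to model the calls used above, not the real signatures): PrepareDataForControlValueNode follows the usual value-node preparation flow, skip outputs that already have device memory, allocate, then copy the host tensor to the device:

#include <cstdlib>
#include <cstring>
#include <vector>

struct DeviceAddressSketch {
  void *ptr{nullptr};
  size_t size{0};
  bool Allocate() { ptr = std::malloc(size); return ptr != nullptr; }          // leak-free cleanup omitted in the sketch
  bool SyncHostToDevice(const void *host_data, size_t host_size) {
    if (ptr == nullptr || host_size > size) { return false; }
    std::memcpy(ptr, host_data, host_size);                                    // stands in for the real device copy
    return true;
  }
};

struct HostTensorSketch { std::vector<float> data; };

// Prepare one output of a value node: do nothing if memory already exists,
// otherwise allocate and push the host data to the device side.
bool PrepareValueNodeOutput(DeviceAddressSketch *device_tensor, const HostTensorSketch &tensor) {
  if (device_tensor->ptr != nullptr) {
    return true;  // already prepared
  }
  if (!device_tensor->Allocate()) {
    return false;  // the real code reports this through SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY
  }
  return device_tensor->SyncHostToDevice(tensor.data.data(), tensor.data.size() * sizeof(float));
}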
@@ -97,6 +97,8 @@ class DataPrepareActor : public DebugAwareActor {
const DeviceContext *device_context,
const HostParameterToWeight &host_parameter_to_weights,
OpContext<DeviceTensor> *const context);
void PrepareDataForControlValueNode(const KernelWithIndex &node_with_index, const DeviceContext *device_context,
OpContext<DeviceTensor> *const context);

const GraphCompilerInfo *graph_compiler_info_;
GraphExecutionStrategy strategy_;
@@ -235,7 +235,24 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cons
for (size_t i = 0; i < host_tensors.size(); ++i) {
auto &host_tensor = host_tensors[i];
auto &device_tensor = device_tensors[i];
MS_EXCEPTION_IF_NULL(host_tensor);
if (host_tensor == nullptr) {
// In the control flow, the weight device tensor needs to be sent by the data source actor, and the input of
// the data prepare actor is host tensor, the device tensor of the weight cannot be obtained, the input will
// be empty, here to check whether it is weight, if it is, get the device tensor from the device tensor store.
if (IsPersistentDeviceTensor(data_nodes_[i])) {
MS_EXCEPTION_IF_NULL(device_contexts_[i]);
auto device_store_tensor =
DeviceTensorStore::GetInstance().Fetch(data_nodes_[i].get(), device_contexts_[i]->GetDeviceAddressType());
if (device_store_tensor == nullptr) {
std::string error_info = GetAID().Name() + " failed get device tensor for: " + data_nodes_[i]->DebugString();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
device_tensor = device_store_tensor;
continue;
} else {
MS_LOG(EXCEPTION) << "Invalid host tensor for index:" << i << " node:" << data_nodes_[i]->DebugString();
}
}
MS_EXCEPTION_IF_NULL(device_tensor);
auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(host_tensor->device_address());
// Sync data from host_tensor_device_address to device_tensor.
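Note (toy sketch; the real DeviceTensorStore is keyed by node pointer and device address type, not by name): the new branch above handles weights under control flow. When the host tensor slot is empty and the node is a persistent weight, the device tensor is fetched from the store instead of being synced from host:

#include <string>
#include <unordered_map>

struct DeviceTensor { int id{0}; };

// Toy persistent store standing in for DeviceTensorStore::GetInstance().
class TensorStoreSketch {
 public:
  void Insert(const std::string &node, DeviceTensor *tensor) { store_[node] = tensor; }
  DeviceTensor *Fetch(const std::string &node) const {
    const auto iter = store_.find(node);
    return iter == store_.end() ? nullptr : iter->second;
  }
 private:
  std::unordered_map<std::string, DeviceTensor *> store_;
};

// If the host tensor is missing and the node is a persistent weight, take the stored device tensor;
// nullptr mirrors the exception path for a genuinely invalid input.
DeviceTensor *ResolveDeviceTensor(bool host_tensor_present, bool is_persistent_weight,
                                  const std::string &node, const TensorStoreSketch &store,
                                  DeviceTensor *from_host) {
  if (host_tensor_present) {
    return from_host;
  }
  return is_persistent_weight ? store.Fetch(node) : nullptr;
}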
@@ -24,7 +24,10 @@ namespace {
// Get all the real parameters corresponding to node.
void FetchRealParameterByNode(const KernelWithIndex &node, std::set<KernelWithIndex> *real_parameters,
std::set<KernelWithIndex> *invalid_call_nodes) {
const auto &node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second);
auto node_with_index = node;
if (!node.first->isa<ValueNode>()) {
node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second);
}
if (node_with_index.first->isa<ValueNode>() || node_with_index.first->isa<Parameter>()) {
// If node is a valuenode or parameter, the real parameter is itself.
real_parameters->emplace(node_with_index);
@@ -156,6 +159,20 @@ std::vector<KernelWithIndex> FetchAllOutputWithIndex(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
std::vector<KernelWithIndex> result;

if (node->isa<ValueNode>() && IsValueNode<ValueTuple>(node)) {
const auto &value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
const auto &value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
const auto &value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
const auto tuple_value = value_tuple->value();
for (size_t i = 0; i < tuple_value.size(); ++i) {
result.emplace_back(node, i);
}
return result;
}

const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
if (AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimMakeTuple)) {
const auto &cnode = node_with_index.first->cast<CNodePtr>();
@@ -178,6 +195,73 @@ std::vector<KernelWithIndex> FetchAllOutputWithIndex(const AnfNodePtr &node) {

return result;
}

// Create a device tensor for the front node.
// Get the output format and select kernel build info from the backend node corresponding to the front node to
// create the device address.
void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index, const AnfNodePtr &backend_node,
const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(backend_node);
MS_EXCEPTION_IF_NULL(device_context);
const auto &front_node = front_node_with_index.first;
MS_EXCEPTION_IF_NULL(front_node);

const auto &node_value = front_node->cast<ValueNodePtr>()->value();
if ((!node_value->isa<tensor::Tensor>()) && (!node_value->isa<ValueTuple>())) {
return;
}

size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(backend_node, 0);
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(backend_node, 0);
if (output_type_id == kTypeUnknown) {
output_type_id = AnfAlgo::GetOutputInferDataType(backend_node, 0);
}

if (front_node->kernel_info() == nullptr) {
front_node->set_kernel_info(std::make_shared<device::KernelInfo>());
}

// Get the select kernel build info.
auto kernel_info = static_cast<device::KernelInfo *>(backend_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->GetMutableSelectKernelBuildInfo();
MS_EXCEPTION_IF_NULL(build_info);
AnfAlgo::SetSelectKernelBuildInfo(build_info, front_node.get());

// Create device tensor.
std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0);
device::DeviceAddressPtr address =
device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
MS_EXCEPTION_IF_NULL(address);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address;
AnfAlgo::SetOutputAddr(address, front_node_with_index.second, front_node.get());
}

// Create a device tensor for front node.
// When the condition input of the switch and switchlayer or the output of a subgraph is a parameter or value node,
// there is no corresponding backend node for this parameter, so a device tensor needs to be created for it.
void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index, const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(device_context);
const auto &node = front_node_with_index.first;
MS_EXCEPTION_IF_NULL(device_context);

TypeId type_id = AnfAlgo::GetOutputInferDataType(node, 0);
if (node->kernel_info() == nullptr) {
auto kernel_info = std::make_shared<device::KernelInfo>();
std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
builder->SetOutputsFormat({kOpFormat_DEFAULT});
builder->SetOutputsDeviceType({type_id});
kernel_info->set_select_kernel_build_info(builder->Build());
node->set_kernel_info(kernel_info);
}
size_t size = AnfAlgo::GetOutputTensorMemSize(node, 0);

// Create device tensor.
device::DeviceAddressPtr address = device_context->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id);
MS_EXCEPTION_IF_NULL(address);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(node) << " addr:" << address;
AnfAlgo::SetOutputAddr(address, front_node_with_index.second, node.get());
}
} // namespace

bool HasAbstractRef(const AnfNodePtr &node) {
@@ -209,7 +293,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
<< " device context num:" << device_contexts.size();
}

if (control_nodes.size() <= 1) {
if (control_nodes.size() <= 1 || device_contexts.empty()) {
return;
}
@@ -235,7 +319,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons

FetchCallInputKernelGraph(graphs, device_contexts);

FetchFrontValueNode();
FetchFrontValueNode(device_contexts[0]);

FetchFrontToBackendKernel(graphs, device_contexts);
@@ -250,7 +334,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons

bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
if (!IsInited()) {
if ((!IsInited()) || (!node->isa<Parameter>())) {
return false;
}
@@ -385,7 +469,8 @@ void ControlNodeParser::ParseDeviceContextForControlNode(const DeviceContext *de
// will be thrown.
if (call_device_contexts.empty() || call_device_contexts.size() <= output_node.second) {
MS_LOG(EXCEPTION) << "Cannot find device context for call node:" << output_node.first->DebugString()
<< " device contexts size:" << call_device_contexts.size();
<< " device contexts size:" << call_device_contexts.size()
<< " index:" << output_node.second;
}
return_device_contexts.emplace_back(call_device_contexts[output_node.second]);
} else if (output_node.first->isa<CNode>()) {
@@ -444,7 +529,9 @@ KernelWithIndex ControlNodeParser::FetchBackendNodeByFrontNode(const KernelWithI
return {};
}

void ControlNodeParser::FetchFrontValueNode() {
void ControlNodeParser::FetchFrontValueNode(DeviceContext *default_context) {
MS_EXCEPTION_IF_NULL(default_context);

for (const auto &formal_to_real_parameter : formal_to_real_parameters_) {
for (const auto &real_parameter_with_index : formal_to_real_parameter.second) {
const auto &real_parameter = real_parameter_with_index.first;
@@ -452,12 +539,30 @@ void ControlNodeParser::FetchFrontValueNode() {
continue;
}

const auto &iter = front_to_backend_parameters_.find({real_parameter, 0});
const auto &iter = front_to_backend_parameters_.find(real_parameter_with_index);
if (iter != front_to_backend_parameters_.end() && (!iter->second.empty())) {
front_value_nodes_.emplace(real_parameter, iter->second.begin()->second);
front_value_nodes_.emplace(real_parameter_with_index, iter->second.begin()->second);
CreateDeviceTensorForValueNode(real_parameter_with_index, iter->second.begin()->first,
iter->second.begin()->second);
} else {
front_value_nodes_.emplace(real_parameter_with_index, default_context);
CreateDeviceTensorForFrontNode(real_parameter_with_index, default_context);
}
}
}

// If the output of funcgraph is a value node, it will eventually be sent to the kernel as a real parameter.
// These the value nodes also need to create a device address.
for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
const auto &front_node = front_to_backend_parameters.first.first;
MS_EXCEPTION_IF_NULL(front_node);
if (front_node->isa<ValueNode>() && (!front_to_backend_parameters.second.empty())) {
const auto &backend_parameter = front_to_backend_parameters.second.begin()->first;
const auto &device_context = front_to_backend_parameters.second.begin()->second;
CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context);
front_value_nodes_.emplace(front_to_backend_parameters.first, device_context);
}
}
}

void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr> &control_nodes) {
@@ -657,11 +762,11 @@ void ControlNodeParser::ParseFrontToBackendParameter(const std::vector<KernelGra
FetchRealParameterByNode(front_node_with_index, &real_parameters, &invalid_call_nodes);
for (const auto real_parameter : real_parameters) {
if (real_parameter.first->isa<Parameter>() || real_parameter.first->isa<ValueNode>()) {
front_to_backend_parameters_[real_parameter.first].emplace(parameter, device_context);
front_to_backend_parameters_[real_parameter].emplace(parameter, device_context);
}
}
} else {
front_to_backend_parameters_[front_node].emplace(parameter, device_context);
front_to_backend_parameters_[{front_node, 0}].emplace(parameter, device_context);
}
}
}
@@ -671,12 +776,13 @@ void ControlNodeParser::ParseFrontToBackendParameter(const std::vector<KernelGra
for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
const auto &front_parameter = front_to_backend_parameters.first;
const auto &backend_parameters = front_to_backend_parameters.second;
const auto &iter = formal_to_real_parameters_.find(front_parameter);
const auto &iter = formal_to_real_parameters_.find(front_parameter.first);
if (iter != formal_to_real_parameters_.end()) {
for (const auto &real_parameter_with_index : iter->second) {
const auto &real_parameter = real_parameter_with_index.first;
if (real_parameter->isa<Parameter>()) {
front_to_backend_parameters_[real_parameter].insert(backend_parameters.begin(), backend_parameters.end());
front_to_backend_parameters_[real_parameter_with_index].insert(backend_parameters.begin(),
backend_parameters.end());
}
}
}
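Note (simplified sketch with a toy node type, not AnfNode): the first addition above special-cases value tuples. A ValueNode holding a ValueTuple contributes one (node, index) entry per tuple element instead of being resolved through VisitKernelWithReturnType:

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Toy node: either a value tuple with N elements or a single-output node.
struct NodeSketch {
  std::string name;
  bool is_value_tuple{false};
  size_t tuple_size{0};
};

using NodeWithIndex = std::pair<const NodeSketch *, size_t>;

std::vector<NodeWithIndex> FetchAllOutputWithIndex(const NodeSketch &node) {
  std::vector<NodeWithIndex> result;
  if (node.is_value_tuple) {
    // One output per tuple element, indexed 0..tuple_size-1.
    for (size_t i = 0; i < node.tuple_size; ++i) {
      result.emplace_back(&node, i);
    }
    return result;
  }
  result.emplace_back(&node, 0);  // ordinary single-output node
  return result;
}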
@@ -57,11 +57,11 @@ constexpr size_t kSingleControlNode = 1;
const char kEntranceActorNameSuffix[] = "_EntranceActor";
const char kStackActorNameSuffix[] = "_StackActor";

using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, DeviceContext *>>>;
using FrontToBackendNodeWithContext = std::map<KernelWithIndex, std::set<std::pair<AnfNodePtr, DeviceContext *>>>;
using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>;
using FuncGraphToKernelGraph = std::unordered_map<FuncGraphPtr, std::vector<KernelGraphPtr>>;
using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>>;
using NodeWithDeviceContext = std::set<std::pair<AnfNodePtr, DeviceContext *>>;
using NodeWithDeviceContext = std::set<std::pair<KernelWithIndex, DeviceContext *>>;
using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
using FormalToRealParameter = std::unordered_map<AnfNodePtr, std::set<KernelWithIndex>>;
using RealToFormalParameter = std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>>;
@@ -114,7 +114,7 @@ class ControlNodeParser {
// value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for
// them separately during initialization.
// The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to.
void FetchFrontValueNode();
void FetchFrontValueNode(DeviceContext *default_context);
// Create branch id for all call node in the control flow.
void CreateBranchIDForCallNode(const std::vector<AnfNodePtr> &control_nodes);
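Note (minimal illustration, toy node ids in place of AnfNodePtr): the typedef changes above re-key the parser's maps from a bare node pointer to KernelWithIndex, so two different outputs of the same front node no longer collide. std::map is usable here because std::pair already provides operator<, which an unordered_map keyed by a pair would not get without a custom hash:

#include <cstddef>
#include <map>
#include <string>
#include <utility>

// (node, output index) key, analogous to KernelWithIndex.
using NodeWithIndex = std::pair<std::string, size_t>;

int main() {
  std::map<NodeWithIndex, std::string> front_to_backend;
  // Output 0 and output 1 of the same front node map to different backend parameters.
  front_to_backend[{"front_value_tuple", 0}] = "backend_param_a";
  front_to_backend[{"front_value_tuple", 1}] = "backend_param_b";
  return front_to_backend.size() == 2 ? 0 : 1;
}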
@@ -69,6 +69,10 @@ std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
<< " for make tuple node:" << make_tuple_cnode->DebugString();
}
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(make_tuple_inputs[j + kMakeTupleInputStartPos], 0));
} else if (node_with_index.first->isa<ValueNode>()) {
// When the value node is a value tuple, the value node will have multiple outputs, which need to be directly
// put into the vector, and the output cannot be obtained through the VisitKernelWithReturnType interface.
results.emplace_back(node_with_index.first, j);
} else {
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(node_with_index.first, j));
}
@@ -202,6 +206,11 @@ std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompil
for (const auto func_graph_to_kernel_graphs : parser->func_graph_to_kernel_graphs_) {
for (const auto &kernel_graph : func_graph_to_kernel_graphs.second) {
MS_EXCEPTION_IF_NULL(kernel_graph);
// If the graph does not have kernel, it means there is no internal calculation in it, the output is parameter,
// and no exit actor is needed.
if (kernel_graph->execution_order().empty()) {
continue;
}
std::vector<KernelWithIndex> formal_parameters;
const auto &graph_outputs = kernel_graph->graph_output_map();
std::vector<const DeviceContext *> device_contexts;
@@ -232,6 +241,10 @@ std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompil
const auto &actor_name = kernel_graph->ToString();
const auto &exit_actor = std::make_shared<ExitActor>(actor_name, formal_parameters, nullptr);
exit_actors.emplace_back(exit_actor);
if (exit_actor->device_contexts_.size() != device_contexts.size()) {
MS_LOG(EXCEPTION) << "Invalid device context size:" << device_contexts.size()
<< " need:" << exit_actor->device_contexts_.size() << " for actor:" << exit_actor->GetAID();
}
exit_actor->device_contexts_.swap(device_contexts);
InsertActor(exit_actor.get());
}
@@ -251,13 +264,23 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
const auto &device_context = graph_with_context.second;
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(device_context);
// If the graph does not have kernel, it means there is no internal calculation in it, the output is parameter,
// and no stack actor is needed.
if (graph->execution_order().empty()) {
continue;
}
const auto &real_parameters = graph->input_nodes();

// Collect inputs of stack actor.
for (const auto &parameter : real_parameters) {
const auto &front_node_with_index = GetFrontNodeByKernelGraph(parameter, graph);
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
formal_parameters.emplace_back(front_node_with_index);
// If the input comes from inside funcgraph, put it at the front of the vector, otherwise put it at the end.
if (AnfAlgo::IsCallNode(front_node_with_index.first)) {
formal_parameters.emplace_back(front_node_with_index);
} else {
formal_parameters.insert(formal_parameters.begin(), front_node_with_index);
}
}

const auto &actor_name = graph->ToString() + kStackActorNameSuffix;
@@ -347,7 +370,13 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
to_actor->local_partials_[to_node_with_index.second] = OpPartial(func_graph.get(), {});
} else {
// Link device store value node.
to_actor->device_tensor_store_keys_.emplace_back(to_node_with_index.second, from_node.get());
if (!AnfAlgo::OutputAddrExist(from_node, from_node_with_index.second)) {
MS_LOG(EXCEPTION) << "Invalid output address index:" << from_node_with_index.second
<< " for value node:" << from_node->DebugString();
}
to_actor->local_device_tensors_[to_node_with_index.second] =
AnfAlgo::GetMutableOutputAddr(from_node, from_node_with_index.second, false).get();
to_actor->local_device_tensors_[to_node_with_index.second]->SetNodeIndex(from_node, from_node_with_index.second);
}
} else if (from_node->isa<Parameter>()) {
// Link arrow from entrance actor.
@@ -451,8 +480,6 @@ void ControlNodeScheduler::LinkArrowByKernel(const AnfNodePtr &kernel, ControlAc
auto device_tensor = AnfAlgo::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, false);
UpdateRefCount(device_tensor.get(), true);
device_tensor->SetNodeIndex(kernel_with_index.first, kernel_with_index.second);

kernel_actor->output_data_nodes_.emplace_back(kernel_with_index.first);
LinkDataArrow(kernel_actor, to_actor, kernel_with_index.second, to_node_with_index.second);
} else {
// Link arrow from exit actor.
@@ -535,7 +562,25 @@ void ControlNodeScheduler::LinkBranchIDArrowForControlActor(ControlActorSet *con
MS_EXCEPTION_IF_NULL(actor);
auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
MS_EXCEPTION_IF_NULL(entrance_actor);
entrance_actor->output_branch_id_arrows_.emplace_back(exit_actor->GetAID());
LinkBranchIDArrow(entrance_actor, exit_actor.get());
}

// Connect the branch id arrows from the entrance actor to the stack actor.
for (auto stack_actor : control_actor_set->stack_actors_) {
MS_EXCEPTION_IF_NULL(stack_actor);
if (stack_actor->formal_parameters_.empty()) {
MS_LOG(ERROR) << "Invalid stack actor:" << stack_actor->GetAID();
}
const auto &node = stack_actor->formal_parameters_.back().first;
MS_EXCEPTION_IF_NULL(node);
const auto &func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
MS_EXCEPTION_IF_NULL(entrance_actor);
LinkBranchIDArrow(entrance_actor, stack_actor.get());
}
}
@@ -587,7 +632,7 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap
// If there is a call node in the input of the graph, the parameter of the graph needs to be sent by the
// corresponding stack actor, otherwise it is sent by the entrance actor.
if (is_call_input_graph) {
auto actor = FetchActor(graph->ToString());
auto actor = FetchActor(graph->ToString() + kStackActorNameSuffix);
MS_EXCEPTION_IF_NULL(actor);
from_actor = dynamic_cast<ControlActor *>(actor);
MS_EXCEPTION_IF_NULL(from_actor);
@@ -623,7 +668,8 @@ void ControlNodeScheduler::LinkDataArrowForOutputActor(ActorSet *const actor_set
auto exit_actor = dynamic_cast<ExitActor *>(actor);
MS_EXCEPTION_IF_NULL(exit_actor);
for (size_t i = 0; i < exit_actor->formal_parameters_.size(); ++i) {
LinkDataArrow(exit_actor, to_actor.get(), i, i);
LinkDataArrowForExitActor(exit_actor, to_actor.get(), i, i, 0);
to_actor->input_datas_num_++;
}

auto control_node_to_device_contexts = parser->control_node_to_device_contexts_;
@@ -659,10 +705,7 @@ void ControlNodeScheduler::LinkDataArrowForHostDSActor(const GraphCompilerInfo &
const auto &iter = host_ds_actor->data_node_position_map_.find(formal_parameter.first);
if (iter != host_ds_actor->data_node_position_map_.end()) {
const auto &parameter = host_ds_actor->data_nodes()[iter->second];
LinkDataArrow(host_ds_actor, to_actor, iter->second, i);

// Set the source node to the device address.
host_ds_actor->output_data_nodes_.emplace_back(parameter);
LinkDataArrow(host_ds_actor, to_actor, iter->second, i, parameter);
if (!AnfAlgo::OutputAddrExist(parameter, 0, false)) {
MS_LOG(EXCEPTION) << "Invalid output index:" << 0 << " for parameter:" << parameter->DebugString();
}
@@ -674,12 +717,13 @@ void ControlNodeScheduler::LinkDataArrowForHostDSActor(const GraphCompilerInfo &
}

void ControlNodeScheduler::LinkDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor,
size_t from_index, size_t to_index) {
size_t from_index, size_t to_index, const AnfNodePtr &from_kernel) {
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor);

auto data_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
(void)from_actor->output_data_arrows_.emplace_back(data_arrow);
(void)from_actor->output_data_nodes_.emplace_back(from_kernel);
to_actor->input_datas_num_++;
(void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());
}
@@ -692,7 +736,7 @@ void ControlNodeScheduler::LinkControlArrow(AbstractActor *from_actor, AbstractA
(void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID());
}

void ControlNodeScheduler::LinkDataArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor,
void ControlNodeScheduler::LinkDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor,
size_t from_index, size_t to_index, int branch_id) {
MS_EXCEPTION_IF_NULL(exit_actor);
MS_EXCEPTION_IF_NULL(to_actor);
@@ -726,6 +770,14 @@ void ControlNodeScheduler::LinkPartialArrow(ControlActor *const from_actor, Cont
auto op_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
from_actor->output_partial_arrows_.emplace_back(op_arrow);
to_actor->input_partials_num_++;
to_actor->input_partial_arrow_aids_.emplace_back(from_actor->GetAID());
}

void ControlNodeScheduler::LinkBranchIDArrow(ControlActor *const from_actor, ControlActor *const to_actor) {
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor);
from_actor->output_branch_id_arrows_.emplace_back(to_actor->GetAID());
to_actor->input_branch_id_arrow_aids_.emplace_back(from_actor->GetAID());
}

bool ControlNodeScheduler::CheckActorValid(const ControlActorSetPtr &control_actor_set) {
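Note (toy sketch; ActorSketch and Arrow are simplified stand-ins for AbstractActor and DataArrow): LinkDataArrow now also records the source kernel node alongside each output data arrow, so a device address can later be traced back to the node that produced it, while the receiver's input counters are updated in the same call:

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

struct Arrow {
  size_t from_index;
  std::string to_actor;
  size_t to_index;
};

struct ActorSketch {
  std::string name;
  std::vector<std::shared_ptr<Arrow>> output_data_arrows;
  std::vector<std::string> output_data_nodes;  // source node per arrow; may be empty
  size_t input_datas_num{0};
  std::vector<std::string> input_data_arrow_aids;
};

// Mirrors the extended LinkDataArrow: arrow plus source node on the sender, counters on the receiver.
void LinkDataArrow(ActorSketch *from_actor, ActorSketch *to_actor, size_t from_index, size_t to_index,
                   const std::string &from_kernel = std::string()) {
  from_actor->output_data_arrows.push_back(std::make_shared<Arrow>(Arrow{from_index, to_actor->name, to_index}));
  from_actor->output_data_nodes.push_back(from_kernel);  // keeps arrows and source nodes index-aligned
  ++to_actor->input_datas_num;
  to_actor->input_data_arrow_aids.push_back(from_actor->name);
}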
@@ -78,12 +78,13 @@ class ControlNodeScheduler {
const FuncGraphPtr &func_graph);
void LinkPartialArrow(ControlActor *const from_actor, ControlActor *const to_actor, size_t from_index,
size_t to_index);
void LinkDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor, size_t from_index,
size_t to_index);
void LinkDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor, size_t from_index, size_t to_index,
const AnfNodePtr &from_kernel = nullptr);
void LinkBranchIDArrow(ControlActor *const from_actor, ControlActor *const to_actor);

// Since the output of exit actor has branches, it needs to be based on a dedicated interface.
void LinkControlArrowForExitActor(ExitActor *from_actor, AbstractActor *to_actor, int branch_id);
void LinkDataArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor, size_t from_index,
void LinkDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor, size_t from_index,
size_t to_index, int branch_id);
};
} // namespace runtime
@@ -625,10 +625,7 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
// the corresponding backend parameter from the map, and insert it into the host data source actor
const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
for (const auto &parameter : control_node_parameters) {
if (IsPersistentDeviceTensor(parameter)) {
continue;
}
auto backend_iter = front_to_backend_parameter.find(parameter);
auto backend_iter = front_to_backend_parameter.find({parameter, 0});
if (backend_iter == front_to_backend_parameter.end() || backend_iter->second.empty()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter);
}
@@ -1730,6 +1727,9 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
MS_EXCEPTION_IF_NULL(device_tensor);
if (IsPersistentDeviceTensor(input_node)) {
AddDeviceTensorStore(front_node.get(), device_tensor);
// In the control flow, the device tensor of the weight needs to be obtained according to the backend node,
// so insert the relationship between the backend node and the device tensor.
AddDeviceTensorStore(input_node.get(), device_tensor);
}

// Share the weight in the host and device, then input_node is internal parameter and front_node is weight.
@@ -1743,17 +1743,12 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
auto other_type_device_tensor = device_context->CreateDeviceAddress(
nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
AddDeviceTensorStore(front_node.get(), other_type_device_tensor);
// In the control flow, the device tensor of the weight needs to be obtained according to the backend node,
// so insert the relationship between the backend node and the device tensor.
AddDeviceTensorStore(input_node.get(), other_type_device_tensor);
}
}
}

// In control flow, there may be some value nodes that is not in the kernel graph and needs to be placed
// in the tensor store separately.
for (const auto &value_node : graph_compiler_info.control_node_parser_->front_value_nodes_) {
MS_EXCEPTION_IF_NULL(value_node.first);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node.first, 0, false);
AddDeviceTensorStore(value_node.first.get(), device_tensor);
}
}

void GraphScheduler::FetchKernelTransformTypeAndName(const AnfNodePtr &node, const KernelGraphPtr &graph,
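Note (toy sketch keyed by node name; the real DeviceTensorStore is keyed by node pointer and device type): PersistDeviceTensor now registers a weight's device tensor under both the front node and the backend node, because under control flow the data source actor looks the weight up by its backend node:

#include <string>
#include <unordered_map>

struct DeviceTensor { int id{0}; };

std::unordered_map<std::string, DeviceTensor *> device_tensor_store;

void AddDeviceTensorStore(const std::string &node, DeviceTensor *tensor) { device_tensor_store[node] = tensor; }

void PersistWeight(const std::string &front_node, const std::string &backend_node, DeviceTensor *tensor) {
  AddDeviceTensorStore(front_node, tensor);
  // Also keyed by the backend node so control-flow actors can fetch the weight without the front node.
  AddDeviceTensorStore(backend_node, tensor);
}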