Fix codedex and pclint.
This commit is contained in:
parent
75ab763696
commit
3f3661fedc
|
@ -38,7 +38,7 @@ void GatherActor::Init() {
|
|||
|
||||
auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
|
||||
output_data_.emplace_back(data.get());
|
||||
output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(std::move(data));
|
||||
output_data_by_output_index_[IntToSize(data_arrow->from_output_index_)].emplace_back(std::move(data));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -113,7 +113,7 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
|
|||
// 2.Send output result.
|
||||
for (const auto &result_arrow : output_result_arrows_) {
|
||||
MS_EXCEPTION_IF_NULL(result_arrow);
|
||||
size_t from_index = result_arrow->from_output_index_;
|
||||
size_t from_index = IntToSize(result_arrow->from_output_index_);
|
||||
const auto &front_node = data_nodes_[from_index].first;
|
||||
for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) {
|
||||
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second, false).get() ==
|
||||
|
|
|
@ -85,7 +85,7 @@ class GatherActor : public OpActor<DeviceTensor> {
|
|||
// The device tensors for launch.
|
||||
std::vector<DeviceTensor *> input_device_tensors_;
|
||||
// The branch if for current step.
|
||||
int input_branch_id_;
|
||||
int input_branch_id_{kInvalidBranchID};
|
||||
|
||||
// Input data.
|
||||
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;
|
||||
|
|
|
@ -208,7 +208,7 @@ void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t b
|
|||
branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
|
||||
return;
|
||||
}
|
||||
device_tensor_store_keys_.push_back({input_nodes_.size(), node.get()});
|
||||
device_tensor_store_keys_.emplace_back(input_nodes_.size(), node.get());
|
||||
branch_inputs_pos_[branch].push_back(input_nodes_.size());
|
||||
input_nodes_.push_back(node_with_index);
|
||||
return;
|
||||
|
@ -245,7 +245,7 @@ void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
|
|||
} else if (IsCallNode(real_input.first)) {
|
||||
std::vector<AnfNodePtr> call_nodes;
|
||||
const auto call_output_num = FetchOutputSizebyCallNode(real_input.first, &call_nodes);
|
||||
if (call_output_num <= 0) {
|
||||
if (call_output_num == 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid output num for call input:" << AnfAlgo::GetNodeDebugString(real_input.first);
|
||||
}
|
||||
for (size_t i = 0; i < call_output_num; ++i) {
|
||||
|
@ -268,7 +268,7 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
|
|||
input_branch_ids_[context->sequential_num_].empty()) {
|
||||
MS_LOG(ERROR) << "Invalid branch id for actor:" + GetAID().Name();
|
||||
}
|
||||
size_t branch_id = input_branch_ids_[context->sequential_num_].top();
|
||||
auto branch_id = input_branch_ids_[context->sequential_num_].top();
|
||||
input_branch_ids_[context->sequential_num_].pop();
|
||||
if (branch_id_to_index_.find(branch_id) == branch_id_to_index_.end()) {
|
||||
MS_LOG(ERROR) << "Invalid branch id for switch actor:" + GetAID().Name() +
|
||||
|
@ -278,9 +278,8 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
|
||||
DeviceTensor *device_tensor = input_device_tensors_[0];
|
||||
if (device_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Index of switch actor is empty:" + GetAID().Name();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
|
||||
auto inputs = node_->inputs();
|
||||
TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0);
|
||||
size_t size = abstract::TypeIdSize(type_id);
|
||||
|
@ -291,7 +290,10 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
|
|||
int64_t index = 0;
|
||||
char buf[kMaxSwitchCondSize] = {0};
|
||||
ShapeVector host_shape;
|
||||
device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf));
|
||||
if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf))) {
|
||||
MS_LOG(ERROR) << GetAID().Name() + " get index from device address failed, type id:" + std::to_string(type_id) +
|
||||
", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
|
||||
}
|
||||
|
||||
if (type_id == TypeId::kNumberTypeInt32) {
|
||||
index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]);
|
||||
|
@ -306,7 +308,7 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
|
|||
|
||||
// SwitchLayer node support negative index range [-size, -1].
|
||||
if (index < 0) {
|
||||
index += branch_func_graph_.size();
|
||||
index += SizeToInt(branch_func_graph_.size());
|
||||
}
|
||||
return static_cast<size_t>(index);
|
||||
}
|
||||
|
@ -403,7 +405,7 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
|
|||
" total:" + std::to_string(branch_inputs_pos_[index].size()) + " actor:" + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
size_t from_index = branch_inputs_pos_[index][result_arrow->from_output_index_];
|
||||
size_t from_index = branch_inputs_pos_[index][IntToSize(result_arrow->from_output_index_)];
|
||||
|
||||
MS_LOG(DEBUG) << "Switch actor:" << GetAID() << " send result addr:" << input_device_tensors_[from_index];
|
||||
bool is_send = false;
|
||||
|
|
|
@ -556,7 +556,7 @@ bool IsSubCallNode(const AnfNodePtr &node) {
|
|||
|
||||
std::vector<KernelWithIndex> FetchAllRealInputNodeByParameter(const KernelWithIndex &node) {
|
||||
std::vector<KernelWithIndex> parameters;
|
||||
const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second);
|
||||
const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, SizeToInt(node.second));
|
||||
const auto &real_node = real_node_with_index.first;
|
||||
if (real_node->isa<Parameter>()) {
|
||||
if (!HasAbstractRef(real_node) && !HasAbstractMonad(real_node)) {
|
||||
|
@ -749,7 +749,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
|
|||
}
|
||||
}
|
||||
|
||||
FetchFrontToBackendParameter(graphs, device_contexts, control_nodes, real_to_formal_front_parameters,
|
||||
FetchFrontToBackendParameter(graphs, device_contexts, real_to_formal_front_parameters,
|
||||
formal_to_real_front_parameters);
|
||||
|
||||
FetchFuncGraphToParameter(control_nodes);
|
||||
|
@ -924,14 +924,14 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
<< AnfAlgo::GetNodeDebugString(parameters[i - kCallInputStartPos])
|
||||
<< ", used the default format";
|
||||
CreateDeviceTensorForFrontParameter(inputs[i], device_contexts[0]);
|
||||
front_value_nodes_.push_back({inputs[i], device_contexts[0]});
|
||||
front_value_nodes_.emplace_back(inputs[i], device_contexts[0]);
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &backend_node = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].first;
|
||||
const auto &device_context = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].second;
|
||||
CreateDeviceTensorForValueNode(inputs[i], backend_node, device_context);
|
||||
front_value_nodes_.push_back({inputs[i], device_context});
|
||||
front_value_nodes_.emplace_back(inputs[i], device_context);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -940,18 +940,15 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
for (size_t index = 0; index < graphs.size(); ++index) {
|
||||
const auto &graph = graphs[index];
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto execution_order = graph->execution_order();
|
||||
|
||||
for (const auto ¶meter : graph->input_nodes()) {
|
||||
const auto &front_node = graph->GetFrontAnfByBackendAnf(parameter);
|
||||
const auto &internal_node = graph->GetFrontNodeByInternalParameter(parameter);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
|
||||
if (IsInternalParameter(parameter, graph)) {
|
||||
auto front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter);
|
||||
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
|
||||
const auto &front_output_with_index =
|
||||
AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false);
|
||||
const auto &front_output_with_index = AnfAlgo::VisitKernelWithReturnType(
|
||||
front_node_with_index.first, SizeToInt(front_node_with_index.second), false);
|
||||
auto front_output_node = front_output_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(front_output_node);
|
||||
if (AnfAlgo::CheckPrimitiveType(front_output_node, prim::kPrimSwitch)) {
|
||||
|
@ -959,7 +956,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
FetchValueNodeBySwitchNode(front_output_node, &value_nodes);
|
||||
for (const auto value_node : value_nodes) {
|
||||
CreateDeviceTensorForValueNode(value_node, parameter, device_contexts[index]);
|
||||
front_value_nodes_.push_back({value_node, device_contexts[index]});
|
||||
front_value_nodes_.emplace_back(value_node, device_contexts[index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -975,7 +972,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
if (output->isa<ValueNode>() && GetFrontValueNodeDeviceContext(output) == nullptr) {
|
||||
const auto &device_context = call_node_to_backend_parameter.second.second;
|
||||
CreateDeviceTensorForValueNode(output, call_node_to_backend_parameter.second.first, device_context);
|
||||
front_value_nodes_.push_back({output, device_context});
|
||||
front_value_nodes_.emplace_back(output, device_context);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1054,7 +1051,7 @@ void ControlNodeParser::FetchFrontToFrontParameter(
|
|||
if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch) ||
|
||||
AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
|
||||
std::vector<AnfNodePtr> call_inputs;
|
||||
call_inputs.assign(inputs.begin() + kCallInputStartPos, inputs.end());
|
||||
call_inputs.assign(inputs.begin() + SizeToInt(kCallInputStartPos), inputs.end());
|
||||
switch_input_parse(inputs[0], call_inputs);
|
||||
} else if (IsCallNode(inputs[0])) {
|
||||
continue;
|
||||
|
@ -1095,7 +1092,7 @@ std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::
|
|||
if (backend_iter == front_to_backend_parameters_.end()) {
|
||||
CreateDeviceTensorForFrontParameter(parameter, device_context);
|
||||
front_to_backend_parameters_[parameter] = {parameter, device_context};
|
||||
front_parameters_.push_back({parameter, device_context});
|
||||
front_parameters_.emplace_back(parameter, device_context);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1215,7 +1212,6 @@ std::vector<KernelWithIndex> FetchParameterbyKernelGraph(const KernelGraphPtr &g
|
|||
|
||||
void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts,
|
||||
const std::vector<AnfNodePtr> &control_nodes,
|
||||
const RealToFormalNode &real_to_formal_front_parameters,
|
||||
const RealToFormalNode &formal_to_real_front_parameters) {
|
||||
if (graphs.size() != device_contexts.size()) {
|
||||
|
@ -1356,12 +1352,12 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
|
|||
std::set<AnfNodePtr> *switch_nodes,
|
||||
std::set<KernelWithIndex> *results) {
|
||||
if (front_output->isa<ValueNode>()) {
|
||||
(*results).insert({front_output, 0});
|
||||
(*results).emplace(front_output, 0);
|
||||
|
||||
const auto &iter = formal_to_real_parameters_.find(front_output);
|
||||
if (iter != formal_to_real_parameters_.end()) {
|
||||
for (const auto &node : iter->second) {
|
||||
(*results).insert(node);
|
||||
(*results).emplace(node);
|
||||
}
|
||||
}
|
||||
} else if (front_output->isa<Parameter>()) {
|
||||
|
@ -1370,7 +1366,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
|
|||
|
||||
if (iter != formal_to_real_parameters_.end()) {
|
||||
for (const auto &node : iter->second) {
|
||||
(*results).insert(node);
|
||||
(*results).emplace(node);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Cannot find backend node for front parameter:" << AnfAlgo::GetNodeDebugString(front_output);
|
||||
|
@ -1402,7 +1398,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
|
|||
const auto iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(front_output, 0));
|
||||
|
||||
if (iter != front_to_backend_kernels_.end()) {
|
||||
(*results).insert(iter->second.first);
|
||||
(*results).emplace(iter->second.first);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Cannot find backend node for front kernel:" << AnfAlgo::GetNodeDebugString(front_output);
|
||||
}
|
||||
|
@ -1429,7 +1425,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(
|
|||
MS_LOG(WARNING) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first);
|
||||
continue;
|
||||
}
|
||||
formal_to_real_parameters_[formal_parameter].push_back({iter->second.first, 0});
|
||||
formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first, 0);
|
||||
} else {
|
||||
const auto iter = front_to_backend_kernels_.find(node_with_index);
|
||||
if (iter == front_to_backend_kernels_.end()) {
|
||||
|
@ -1439,7 +1435,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(
|
|||
}
|
||||
}
|
||||
} else if (real_parameter->isa<ValueNode>()) {
|
||||
formal_to_real_parameters_[formal_parameter].push_back({real_parameter, 0});
|
||||
formal_to_real_parameters_[formal_parameter].emplace_back(real_parameter, 0);
|
||||
} else if (IsCallNode(real_parameter)) {
|
||||
const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter);
|
||||
for (const auto func_graph : func_graphs) {
|
||||
|
@ -1512,7 +1508,7 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr>
|
|||
auto front_node = graph->GetFrontAnfByBackendAnf(value_node);
|
||||
|
||||
if (front_node != nullptr) {
|
||||
formal_to_real_parameters_[front_node].push_back({value_node, 0});
|
||||
formal_to_real_parameters_[front_node].emplace_back(value_node, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1521,7 +1517,7 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr>
|
|||
for (const auto &front_weight : host_parameter_to_weight.second) {
|
||||
const auto &iter = front_to_backend_parameters_.find(host_parameter_to_weight.first);
|
||||
if (iter != front_to_backend_parameters_.end()) {
|
||||
formal_to_real_parameters_[front_weight].push_back({iter->second.first, 0});
|
||||
formal_to_real_parameters_[front_weight].emplace_back(iter->second.first, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1552,10 +1548,10 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr>
|
|||
}
|
||||
}
|
||||
for (const auto parameter_pair : front_to_backend_parameters) {
|
||||
formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0});
|
||||
formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0);
|
||||
}
|
||||
for (const auto parameter_pair : front_to_backend_parameters_) {
|
||||
formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0});
|
||||
formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -150,7 +150,6 @@ class ControlNodeParser {
|
|||
// 2. The parameter from control nodes.
|
||||
void FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts,
|
||||
const std::vector<AnfNodePtr> &control_nodes,
|
||||
const RealToFormalNode &real_to_formal_front_parameters,
|
||||
const RealToFormalNode &formal_to_real_front_parameters);
|
||||
// Get the relationship between the front and backend of the executable kernel in all kernel graphs.
|
||||
|
|
|
@ -1262,7 +1262,7 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
|
|||
if (HasAbstractMonad(parameter) || HasAbstractRef(parameter)) {
|
||||
continue;
|
||||
}
|
||||
parameters.push_back({parameter, 0});
|
||||
parameters.emplace_back(parameter, 0);
|
||||
}
|
||||
|
||||
const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
|
||||
|
@ -1291,7 +1291,7 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
|
|||
if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]))) {
|
||||
continue;
|
||||
}
|
||||
parameters.push_back({inputs[i], 0});
|
||||
parameters.emplace_back(inputs[i], 0);
|
||||
}
|
||||
|
||||
auto func_graph = control_node->func_graph();
|
||||
|
@ -1335,7 +1335,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
|
|||
if (from_kernel->isa<Parameter>() && graph_compiler_info.control_node_parser_->IsCallInputKernelGraph(graph)) {
|
||||
const auto &kernel_with_index = GetFrontNodeByKernelGraph(from_kernel, graph);
|
||||
const auto &real_front_node_with_index =
|
||||
AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, kernel_with_index.second);
|
||||
AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, SizeToInt(kernel_with_index.second));
|
||||
if (HasAbstractRef(real_front_node_with_index.first)) {
|
||||
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
|
||||
real_front_node_with_index.first.get());
|
||||
|
@ -1966,6 +1966,7 @@ void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo
|
|||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||
|
||||
// Set branch index into switch actor.
|
||||
size_t branch_index = switch_actor->branch_id_to_index_.size();
|
||||
|
@ -2071,7 +2072,6 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
|
|||
void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes) {
|
||||
for (const auto &node : control_nodes) {
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
const auto &from_func_graph = node->func_graph();
|
||||
auto inputs = cnode->inputs();
|
||||
|
||||
// Before link data arrow, parameters of the call node in switch-call need to be add to the switch actor.
|
||||
|
@ -2079,6 +2079,8 @@ void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr
|
|||
auto actor = FetchActor(inputs[0]->DebugString());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||
|
||||
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
|
||||
if (HasAbstractMonad(inputs[i])) {
|
||||
continue;
|
||||
|
@ -2107,6 +2109,7 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
|
|||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto gather_actor = dynamic_cast<GatherActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(gather_actor);
|
||||
|
||||
const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -2123,8 +2126,8 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
|
|||
continue;
|
||||
}
|
||||
|
||||
gather_actor->device_tensor_store_keys_.push_back(
|
||||
{i - kCallInputStartPos - persist_input_num, inputs[i].get()});
|
||||
gather_actor->device_tensor_store_keys_.emplace_back(i - kCallInputStartPos - persist_input_num,
|
||||
inputs[i].get());
|
||||
gather_actor->device_contexts_[i - kCallInputStartPos - persist_input_num] =
|
||||
graph_compiler_info.control_node_parser_->GetFrontValueNodeDeviceContext(inputs[i]);
|
||||
} else if ((inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]->cast<ParameterPtr>())) ||
|
||||
|
@ -2150,7 +2153,9 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto actor = FetchActor(func_graph->get_return()->DebugString());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
LinkDataArrowForSwitchActor(graph_compiler_info, dynamic_cast<SwitchActor *>(actor));
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||
LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor);
|
||||
}
|
||||
|
||||
// Link arrow for gather actor for call input kernel graph.
|
||||
|
@ -2160,6 +2165,7 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
|
|||
auto actor = FetchActor(kernel_graph->ToString());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto gather_actor = dynamic_cast<GatherActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(gather_actor);
|
||||
|
||||
for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) {
|
||||
const auto &input_with_index = gather_actor->data_nodes_[i];
|
||||
|
@ -2286,8 +2292,10 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c
|
|||
const auto &actor_name = backend_node.first->fullname_with_scope();
|
||||
const auto &actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto op_arrow = std::make_shared<DataArrow>(backend_node.second, to_actor->GetAID(), to_index);
|
||||
auto from_actor = dynamic_cast<KernelActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
|
||||
auto op_arrow = std::make_shared<DataArrow>(backend_node.second, to_actor->GetAID(), to_index);
|
||||
from_actor->output_data_arrows_.emplace_back(op_arrow);
|
||||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, backend_node.second, false);
|
||||
UpdateRefCount(device_tensor.get(), true);
|
||||
|
@ -2416,6 +2424,8 @@ void GraphScheduler::LinkControlArrowForGatherActor(std::vector<KernelActorPtr>
|
|||
actor = FetchActor(func_graph->get_return()->DebugString());
|
||||
if (actor != nullptr) {
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||
|
||||
kernel_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
|
||||
switch_actor->input_controls_num_++;
|
||||
}
|
||||
|
@ -2458,6 +2468,7 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr>
|
|||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto gather_actor = dynamic_cast<GatherActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(gather_actor);
|
||||
gather_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
|
||||
switch_actor->input_controls_num_++;
|
||||
}
|
||||
|
@ -2470,6 +2481,7 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr>
|
|||
const auto &actor = FetchActor(actor_name);
|
||||
if (actor != nullptr) {
|
||||
const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(gather_actor);
|
||||
switch_actor->output_branch_control_arrows_[i].emplace_back(gather_actor->GetAID());
|
||||
gather_actor->input_controls_num_++;
|
||||
}
|
||||
|
@ -2495,6 +2507,7 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr>
|
|||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||
|
||||
size_t branch_index = switch_actor->branch_id_to_index_.size();
|
||||
if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) {
|
||||
|
@ -2517,6 +2530,8 @@ void GraphScheduler::LinkBranchArrowForSwitchActor(const GraphCompilerInfo &grap
|
|||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||
|
||||
for (size_t i = 0; i < switch_actor->branch_func_graph_.size(); ++i) {
|
||||
const auto &func_graph = switch_actor->branch_func_graph_[i];
|
||||
if (func_graph == nullptr) {
|
||||
|
@ -2556,6 +2571,7 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap
|
|||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto gather_actor = dynamic_cast<GatherActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(gather_actor);
|
||||
gather_actor->output_branch_arrows_.emplace_back(gather_actor->switch_aid_);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue