!30229 Commit code in 1.6.

Merge pull request !30229 from gaoyong10/runtime_second12
This commit is contained in:
i-robot 2022-02-19 08:14:41 +00:00 committed by Gitee
commit 3215493944
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 70 additions and 33 deletions

View File

@ -119,6 +119,20 @@ bool CPUDeviceAddress::SyncHostToDevice(const ShapeVector &, size_t size, TypeId
MS_LOG(WARNING) << "Please check whether need sync data, host size: " << size << ", device size: " << size_;
return true;
}
// If the value of host is a scalar type, then the host addr is a temporary address, which will be released after
// the sync ends. Therefore, if the value is less than 16, it needs to be copied.
const size_t kCopySize = 16;
if (size <= kCopySize) {
auto ret = memcpy_s(ptr_, size, host_ptr, size);
if (ret != EOK) {
MS_LOG(ERROR) << "Failed to copy tensor!";
return false;
} else {
return true;
}
}
// Use the tensor host ptr to set the device ptr.
if (from_mem_pool_) {
CPUMemoryPool::GetInstance().FreeTensorMem(ptr_);

View File

@ -80,7 +80,7 @@ bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph,
// In control flow, only the parameters of the root funcgraph are in the host data source.
const auto &front_node = graph->GetFrontAnfByBackendAnf(node);
bool is_host = ((front_node == nullptr) || host_parameters.empty() ||
bool is_host = ((front_node == nullptr) ||
find(host_parameters.begin(), host_parameters.end(), front_node) != host_parameters.end());
// Judge whether node is internal parameter.

View File

@ -129,6 +129,28 @@ std::vector<ControlActorPtr> CollectActors(const ControlActorSetPtr &control_act
return actors;
}
std::vector<AnfNodePtr> FetchAllMonadNodeByNode(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return {};
}
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad)) {
return {node};
}
std::vector<AnfNodePtr> results;
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (const auto &input : cnode->inputs()) {
MS_EXCEPTION_IF_NULL(input);
const auto &result = FetchAllMonadNodeByNode(input);
results.insert(results.end(), result.begin(), result.end());
}
}
return results;
}
} // namespace
ControlActorSetPtr ControlNodeScheduler::Build(const GraphCompilerInfo &graph_compiler_info,
@ -557,7 +579,7 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
if (!parser->IsNeedStackControlNode(switch_actor->node_)) {
for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) {
LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i},
parser);
graph_compiler_info);
}
} else {
// If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
@ -566,7 +588,7 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
MS_EXCEPTION_IF_NULL(actor);
auto stack_actor = dynamic_cast<StackActor *>(actor);
MS_EXCEPTION_IF_NULL(stack_actor);
LinkArrowFromStackActor(stack_actor, switch_actor.get(), parser);
LinkArrowFromStackActor(stack_actor, switch_actor.get(), graph_compiler_info);
}
}
@ -575,7 +597,7 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
if (!parser->IsNeedStackControlNode(gather_actor->node_)) {
for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) {
LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i},
parser);
graph_compiler_info);
}
} else {
// If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
@ -584,13 +606,13 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
MS_EXCEPTION_IF_NULL(actor);
auto stack_actor = dynamic_cast<StackActor *>(actor);
MS_EXCEPTION_IF_NULL(stack_actor);
LinkArrowFromStackActor(stack_actor, gather_actor.get(), parser);
LinkArrowFromStackActor(stack_actor, gather_actor.get(), graph_compiler_info);
}
}
for (auto &entrance_actor : control_actor_set->entrance_actors_) {
for (const auto &call_node : entrance_actor->call_nodes_) {
LinkArrowbyFormalParameter(entrance_actor.get(), call_node, {entrance_actor->node_, 0}, parser);
LinkArrowbyFormalParameter(entrance_actor.get(), call_node, {entrance_actor->node_, 0}, graph_compiler_info);
}
}
@ -602,28 +624,31 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
auto actor = FetchActor(stack_actor_name);
if (actor == nullptr) {
for (size_t i = 0; i < exit_actor->formal_parameters_.size(); ++i) {
LinkArrowbyFormalParameter(exit_actor.get(), exit_actor->formal_parameters_[i], {exit_actor->node_, i}, parser);
LinkArrowbyFormalParameter(exit_actor.get(), exit_actor->formal_parameters_[i], {exit_actor->node_, i},
graph_compiler_info);
}
} else {
// If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
auto stack_actor = dynamic_cast<StackActor *>(actor);
MS_EXCEPTION_IF_NULL(stack_actor);
LinkArrowFromStackActor(stack_actor, exit_actor.get(), parser);
LinkArrowFromStackActor(stack_actor, exit_actor.get(), graph_compiler_info);
}
}
for (auto &stack_actor : control_actor_set->stack_actors_) {
for (size_t i = 0; i < stack_actor->formal_parameters_.size(); ++i) {
LinkArrowbyFormalParameter(stack_actor.get(), stack_actor->formal_parameters_[i], {stack_actor->node_, i},
parser);
graph_compiler_info);
}
}
}
void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor, ControlActor *const to_actor,
const ControlNodeParserPtr &parser) {
const GraphCompilerInfo &graph_compiler_info) {
MS_EXCEPTION_IF_NULL(stack_actor);
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
const auto &parser = graph_compiler_info.control_node_parser_;
for (size_t to_index = 0; to_index < to_actor->formal_parameters_.size(); ++to_index) {
const auto &formal_parameter = to_actor->formal_parameters_[to_index];
@ -638,7 +663,7 @@ void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor
(!AnfAlgo::IsCallNode(from_node)) && (!AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) &&
to_actor->GetAID().Name().find(
parser->FetchGroupNameByKernelGraph(parser->FetchKernelGraphByFrontNode(from_node))) != std::string::npos) {
LinkArrowByKernel(from_node, to_actor, formal_parameter, {to_actor->node_, to_index}, parser);
LinkArrowByKernel(from_node, to_actor, formal_parameter, {to_actor->node_, to_index}, graph_compiler_info);
continue;
}
@ -660,16 +685,18 @@ void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor
void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_actor,
const KernelWithIndex &from_node_with_index,
const KernelWithIndex &to_node_with_index,
const ControlNodeParserPtr &parser) {
const GraphCompilerInfo &graph_compiler_info) {
const auto &from_node = from_node_with_index.first;
MS_EXCEPTION_IF_NULL(from_node);
if (from_node->isa<ValueNode>()) {
LinkArrowByValueNode(from_node, to_actor, from_node_with_index.second, to_node_with_index.second);
} else if (from_node->isa<Parameter>()) {
LinkArrowByParameter(from_node, to_actor, from_node_with_index, to_node_with_index, parser);
LinkArrowByParameter(from_node, to_actor, from_node_with_index, to_node_with_index,
graph_compiler_info.control_node_parser_);
} else if (AnfAlgo::IsCallNode(from_node)) {
// Link arrow by call node.
LinkArrowByCallNode(from_node, to_actor, from_node_with_index, to_node_with_index, parser);
LinkArrowByCallNode(from_node, to_actor, from_node_with_index, to_node_with_index,
graph_compiler_info.control_node_parser_);
} else if (AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) {
// Link arrow from switch actor.
@ -678,14 +705,7 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
MS_EXCEPTION_IF_NULL(actor);
const auto &switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
const auto &abstract = from_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
if (abstract->isa<abstract::AbstractFunction>()) {
LinkPartialArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second);
} else {
LinkDataArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second);
}
} else if (AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) {
// Link arrow from gather actor
const auto &actor_name = GetActorName(from_node);
@ -699,7 +719,7 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
LinkPartialArrow(gather_actor, to_actor, from_node_with_index.second, to_node_with_index.second);
} else if (from_node->isa<CNode>()) {
// Link arrow by kernel.
LinkArrowByKernel(from_node, to_actor, from_node_with_index, to_node_with_index, parser);
LinkArrowByKernel(from_node, to_actor, from_node_with_index, to_node_with_index, graph_compiler_info);
}
}
@ -828,12 +848,13 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont
void ControlNodeScheduler::LinkArrowByKernel(const AnfNodePtr &kernel, ControlActor *const to_actor,
const KernelWithIndex &from_node_with_index,
const KernelWithIndex &to_node_with_index,
const ControlNodeParserPtr &parser) {
MS_EXCEPTION_IF_NULL(parser);
const GraphCompilerInfo &graph_compiler_info) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(to_actor);
const auto &from_node = from_node_with_index.first;
MS_EXCEPTION_IF_NULL(from_node);
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
const auto &parser = graph_compiler_info.control_node_parser_;
const auto &graph = parser->FetchKernelGraphByFrontNode(from_node);
MS_EXCEPTION_IF_NULL(graph);
const auto &group_name = parser->FetchGroupNameByKernelGraph(graph);
@ -844,7 +865,7 @@ void ControlNodeScheduler::LinkArrowByKernel(const AnfNodePtr &kernel, ControlAc
const auto &kernel_with_index = parser->FetchBackendNodeByFrontNode(from_node_with_index);
MS_EXCEPTION_IF_NULL(kernel_with_index.first);
auto type = FetchKernelTransformType(kernel_with_index.first, graph, {});
auto from_actor = FetchActor(type, "", kernel_with_index.first, graph);
auto from_actor = FetchActor(type, graph_compiler_info.name_, kernel_with_index.first, graph);
MS_EXCEPTION_IF_NULL(from_actor);
if (!AnfAlgo::OutputAddrExist(kernel_with_index.first, kernel_with_index.second, false)) {
MS_LOG(EXCEPTION) << "Invalid output index:" << kernel_with_index.second
@ -926,10 +947,10 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
const auto &inputs = cnode->inputs();
for (const auto &input : inputs) {
MS_EXCEPTION_IF_NULL(input);
if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimUpdateState) ||
AnfAlgo::CheckPrimitiveType(input, prim::kPrimDepend) ||
AnfAlgo::CheckPrimitiveType(input, prim::kPrimLoad)) {
LinkControlArrowByAutoMonad(from_actor, input, parser);
std::vector<AnfNodePtr> monad_nodes = FetchAllMonadNodeByNode(input);
for (const auto &monad_node : monad_nodes) {
MS_EXCEPTION_IF_NULL(monad_node);
LinkControlArrowByAutoMonad(from_actor, monad_node, parser);
}
}
}

View File

@ -63,20 +63,22 @@ class ControlNodeScheduler {
// Link all arrows between control actors.
void LinkArrowForControlActor(ControlActorSet *const control_actor_set, const GraphCompilerInfo &graph_compiler_info);
void LinkArrowbyFormalParameter(ControlActor *const to_actor, const KernelWithIndex &from_node_with_index,
const KernelWithIndex &to_node_with_index, const ControlNodeParserPtr &parser);
const KernelWithIndex &to_node_with_index,
const GraphCompilerInfo &graph_compiler_info);
void LinkArrowByCallNode(const AnfNodePtr &call_node, ControlActor *const to_actor,
const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index,
const ControlNodeParserPtr &parser);
void LinkArrowByKernel(const AnfNodePtr &kernel, ControlActor *const to_actor,
const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index,
const ControlNodeParserPtr &parser);
const GraphCompilerInfo &graph_compiler_info);
void LinkArrowByParameter(const AnfNodePtr &parameter, ControlActor *const to_actor,
const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index,
const ControlNodeParserPtr &parser);
void LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor, size_t from_index,
size_t to_index);
// Link arrow from stack actor to control actor.
void LinkArrowFromStackActor(StackActor *stack_actor, ControlActor *to_actor, const ControlNodeParserPtr &parser);
void LinkArrowFromStackActor(StackActor *stack_actor, ControlActor *to_actor,
const GraphCompilerInfo &graph_compiler_info);
// Link data arrow between control actor and actor in frame, including kernel actor, output actor, data source actor.
void LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info);