forked from mindspore-Ecosystem/mindspore
!30229 Commit code in 1.6.
Merge pull request !30229 from gaoyong10/runtime_second12
This commit is contained in:
commit
3215493944
|
@ -119,6 +119,20 @@ bool CPUDeviceAddress::SyncHostToDevice(const ShapeVector &, size_t size, TypeId
|
|||
MS_LOG(WARNING) << "Please check whether need sync data, host size: " << size << ", device size: " << size_;
|
||||
return true;
|
||||
}
|
||||
|
||||
// If the value of host is a scalar type, then the host addr is a temporary address, which will be released after
|
||||
// the sync ends. Therefore, if the value is less than 16, it needs to be copied.
|
||||
const size_t kCopySize = 16;
|
||||
if (size <= kCopySize) {
|
||||
auto ret = memcpy_s(ptr_, size, host_ptr, size);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "Failed to copy tensor!";
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Use the tensor host ptr to set the device ptr.
|
||||
if (from_mem_pool_) {
|
||||
CPUMemoryPool::GetInstance().FreeTensorMem(ptr_);
|
||||
|
|
|
@ -80,7 +80,7 @@ bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph,
|
|||
|
||||
// In control flow, only the parameters of the root funcgraph are in the host data source.
|
||||
const auto &front_node = graph->GetFrontAnfByBackendAnf(node);
|
||||
bool is_host = ((front_node == nullptr) || host_parameters.empty() ||
|
||||
bool is_host = ((front_node == nullptr) ||
|
||||
find(host_parameters.begin(), host_parameters.end(), front_node) != host_parameters.end());
|
||||
|
||||
// Judge whether node is internal parameter.
|
||||
|
|
|
@ -129,6 +129,28 @@ std::vector<ControlActorPtr> CollectActors(const ControlActorSetPtr &control_act
|
|||
|
||||
return actors;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FetchAllMonadNodeByNode(const AnfNodePtr &node) {
|
||||
if (!node->isa<CNode>()) {
|
||||
return {};
|
||||
}
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState) ||
|
||||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad)) {
|
||||
return {node};
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> results;
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (const auto &input : cnode->inputs()) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
const auto &result = FetchAllMonadNodeByNode(input);
|
||||
results.insert(results.end(), result.begin(), result.end());
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ControlActorSetPtr ControlNodeScheduler::Build(const GraphCompilerInfo &graph_compiler_info,
|
||||
|
@ -557,7 +579,7 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
|
|||
if (!parser->IsNeedStackControlNode(switch_actor->node_)) {
|
||||
for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) {
|
||||
LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i},
|
||||
parser);
|
||||
graph_compiler_info);
|
||||
}
|
||||
} else {
|
||||
// If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
|
||||
|
@ -566,7 +588,7 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
|
|||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto stack_actor = dynamic_cast<StackActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(stack_actor);
|
||||
LinkArrowFromStackActor(stack_actor, switch_actor.get(), parser);
|
||||
LinkArrowFromStackActor(stack_actor, switch_actor.get(), graph_compiler_info);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -575,7 +597,7 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
|
|||
if (!parser->IsNeedStackControlNode(gather_actor->node_)) {
|
||||
for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) {
|
||||
LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i},
|
||||
parser);
|
||||
graph_compiler_info);
|
||||
}
|
||||
} else {
|
||||
// If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
|
||||
|
@ -584,13 +606,13 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
|
|||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto stack_actor = dynamic_cast<StackActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(stack_actor);
|
||||
LinkArrowFromStackActor(stack_actor, gather_actor.get(), parser);
|
||||
LinkArrowFromStackActor(stack_actor, gather_actor.get(), graph_compiler_info);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &entrance_actor : control_actor_set->entrance_actors_) {
|
||||
for (const auto &call_node : entrance_actor->call_nodes_) {
|
||||
LinkArrowbyFormalParameter(entrance_actor.get(), call_node, {entrance_actor->node_, 0}, parser);
|
||||
LinkArrowbyFormalParameter(entrance_actor.get(), call_node, {entrance_actor->node_, 0}, graph_compiler_info);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -602,28 +624,31 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
|
|||
auto actor = FetchActor(stack_actor_name);
|
||||
if (actor == nullptr) {
|
||||
for (size_t i = 0; i < exit_actor->formal_parameters_.size(); ++i) {
|
||||
LinkArrowbyFormalParameter(exit_actor.get(), exit_actor->formal_parameters_[i], {exit_actor->node_, i}, parser);
|
||||
LinkArrowbyFormalParameter(exit_actor.get(), exit_actor->formal_parameters_[i], {exit_actor->node_, i},
|
||||
graph_compiler_info);
|
||||
}
|
||||
} else {
|
||||
// If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
|
||||
auto stack_actor = dynamic_cast<StackActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(stack_actor);
|
||||
LinkArrowFromStackActor(stack_actor, exit_actor.get(), parser);
|
||||
LinkArrowFromStackActor(stack_actor, exit_actor.get(), graph_compiler_info);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &stack_actor : control_actor_set->stack_actors_) {
|
||||
for (size_t i = 0; i < stack_actor->formal_parameters_.size(); ++i) {
|
||||
LinkArrowbyFormalParameter(stack_actor.get(), stack_actor->formal_parameters_[i], {stack_actor->node_, i},
|
||||
parser);
|
||||
graph_compiler_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor, ControlActor *const to_actor,
|
||||
const ControlNodeParserPtr &parser) {
|
||||
const GraphCompilerInfo &graph_compiler_info) {
|
||||
MS_EXCEPTION_IF_NULL(stack_actor);
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
|
||||
const auto &parser = graph_compiler_info.control_node_parser_;
|
||||
|
||||
for (size_t to_index = 0; to_index < to_actor->formal_parameters_.size(); ++to_index) {
|
||||
const auto &formal_parameter = to_actor->formal_parameters_[to_index];
|
||||
|
@ -638,7 +663,7 @@ void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor
|
|||
(!AnfAlgo::IsCallNode(from_node)) && (!AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) &&
|
||||
to_actor->GetAID().Name().find(
|
||||
parser->FetchGroupNameByKernelGraph(parser->FetchKernelGraphByFrontNode(from_node))) != std::string::npos) {
|
||||
LinkArrowByKernel(from_node, to_actor, formal_parameter, {to_actor->node_, to_index}, parser);
|
||||
LinkArrowByKernel(from_node, to_actor, formal_parameter, {to_actor->node_, to_index}, graph_compiler_info);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -660,16 +685,18 @@ void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor
|
|||
void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_actor,
|
||||
const KernelWithIndex &from_node_with_index,
|
||||
const KernelWithIndex &to_node_with_index,
|
||||
const ControlNodeParserPtr &parser) {
|
||||
const GraphCompilerInfo &graph_compiler_info) {
|
||||
const auto &from_node = from_node_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(from_node);
|
||||
if (from_node->isa<ValueNode>()) {
|
||||
LinkArrowByValueNode(from_node, to_actor, from_node_with_index.second, to_node_with_index.second);
|
||||
} else if (from_node->isa<Parameter>()) {
|
||||
LinkArrowByParameter(from_node, to_actor, from_node_with_index, to_node_with_index, parser);
|
||||
LinkArrowByParameter(from_node, to_actor, from_node_with_index, to_node_with_index,
|
||||
graph_compiler_info.control_node_parser_);
|
||||
} else if (AnfAlgo::IsCallNode(from_node)) {
|
||||
// Link arrow by call node.
|
||||
LinkArrowByCallNode(from_node, to_actor, from_node_with_index, to_node_with_index, parser);
|
||||
LinkArrowByCallNode(from_node, to_actor, from_node_with_index, to_node_with_index,
|
||||
graph_compiler_info.control_node_parser_);
|
||||
} else if (AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) ||
|
||||
AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) {
|
||||
// Link arrow from switch actor.
|
||||
|
@ -678,14 +705,7 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
|
|||
MS_EXCEPTION_IF_NULL(actor);
|
||||
const auto &switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||
|
||||
const auto &abstract = from_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
if (abstract->isa<abstract::AbstractFunction>()) {
|
||||
LinkPartialArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second);
|
||||
} else {
|
||||
LinkDataArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second);
|
||||
}
|
||||
} else if (AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) {
|
||||
// Link arrow from gather actor
|
||||
const auto &actor_name = GetActorName(from_node);
|
||||
|
@ -699,7 +719,7 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
|
|||
LinkPartialArrow(gather_actor, to_actor, from_node_with_index.second, to_node_with_index.second);
|
||||
} else if (from_node->isa<CNode>()) {
|
||||
// Link arrow by kernel.
|
||||
LinkArrowByKernel(from_node, to_actor, from_node_with_index, to_node_with_index, parser);
|
||||
LinkArrowByKernel(from_node, to_actor, from_node_with_index, to_node_with_index, graph_compiler_info);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -828,12 +848,13 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont
|
|||
void ControlNodeScheduler::LinkArrowByKernel(const AnfNodePtr &kernel, ControlActor *const to_actor,
|
||||
const KernelWithIndex &from_node_with_index,
|
||||
const KernelWithIndex &to_node_with_index,
|
||||
const ControlNodeParserPtr &parser) {
|
||||
MS_EXCEPTION_IF_NULL(parser);
|
||||
const GraphCompilerInfo &graph_compiler_info) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
const auto &from_node = from_node_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(from_node);
|
||||
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
|
||||
const auto &parser = graph_compiler_info.control_node_parser_;
|
||||
const auto &graph = parser->FetchKernelGraphByFrontNode(from_node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
const auto &group_name = parser->FetchGroupNameByKernelGraph(graph);
|
||||
|
@ -844,7 +865,7 @@ void ControlNodeScheduler::LinkArrowByKernel(const AnfNodePtr &kernel, ControlAc
|
|||
const auto &kernel_with_index = parser->FetchBackendNodeByFrontNode(from_node_with_index);
|
||||
MS_EXCEPTION_IF_NULL(kernel_with_index.first);
|
||||
auto type = FetchKernelTransformType(kernel_with_index.first, graph, {});
|
||||
auto from_actor = FetchActor(type, "", kernel_with_index.first, graph);
|
||||
auto from_actor = FetchActor(type, graph_compiler_info.name_, kernel_with_index.first, graph);
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
if (!AnfAlgo::OutputAddrExist(kernel_with_index.first, kernel_with_index.second, false)) {
|
||||
MS_LOG(EXCEPTION) << "Invalid output index:" << kernel_with_index.second
|
||||
|
@ -926,10 +947,10 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
|
|||
const auto &inputs = cnode->inputs();
|
||||
for (const auto &input : inputs) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimUpdateState) ||
|
||||
AnfAlgo::CheckPrimitiveType(input, prim::kPrimDepend) ||
|
||||
AnfAlgo::CheckPrimitiveType(input, prim::kPrimLoad)) {
|
||||
LinkControlArrowByAutoMonad(from_actor, input, parser);
|
||||
std::vector<AnfNodePtr> monad_nodes = FetchAllMonadNodeByNode(input);
|
||||
for (const auto &monad_node : monad_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(monad_node);
|
||||
LinkControlArrowByAutoMonad(from_actor, monad_node, parser);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,20 +63,22 @@ class ControlNodeScheduler {
|
|||
// Link all arrows between control actors.
|
||||
void LinkArrowForControlActor(ControlActorSet *const control_actor_set, const GraphCompilerInfo &graph_compiler_info);
|
||||
void LinkArrowbyFormalParameter(ControlActor *const to_actor, const KernelWithIndex &from_node_with_index,
|
||||
const KernelWithIndex &to_node_with_index, const ControlNodeParserPtr &parser);
|
||||
const KernelWithIndex &to_node_with_index,
|
||||
const GraphCompilerInfo &graph_compiler_info);
|
||||
void LinkArrowByCallNode(const AnfNodePtr &call_node, ControlActor *const to_actor,
|
||||
const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index,
|
||||
const ControlNodeParserPtr &parser);
|
||||
void LinkArrowByKernel(const AnfNodePtr &kernel, ControlActor *const to_actor,
|
||||
const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index,
|
||||
const ControlNodeParserPtr &parser);
|
||||
const GraphCompilerInfo &graph_compiler_info);
|
||||
void LinkArrowByParameter(const AnfNodePtr ¶meter, ControlActor *const to_actor,
|
||||
const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index,
|
||||
const ControlNodeParserPtr &parser);
|
||||
void LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor, size_t from_index,
|
||||
size_t to_index);
|
||||
// Link arrow from stack actor to control actor.
|
||||
void LinkArrowFromStackActor(StackActor *stack_actor, ControlActor *to_actor, const ControlNodeParserPtr &parser);
|
||||
void LinkArrowFromStackActor(StackActor *stack_actor, ControlActor *to_actor,
|
||||
const GraphCompilerInfo &graph_compiler_info);
|
||||
|
||||
// Link data arrow between control actor and actor in frame, including kernel actor, output actor, data source actor.
|
||||
void LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info);
|
||||
|
|
Loading…
Reference in New Issue