Fix the bug in linking the copy actor to the exit actor

This commit is contained in:
limingqi107 2022-06-20 17:28:02 +08:00
parent 3146f49feb
commit 1d9398ac81
3 changed files with 19 additions and 12 deletions

View File

@ -36,8 +36,9 @@ using mindspore::device::DeviceContext;
// -> OnMemoryAllocFinish -> Copy -> SendMemoryFreeReq -> SendOutput.
class CopyActor : public MemoryAwareActor {
public:
CopyActor(const std::string &name, const AID &memory_manager_aid)
CopyActor(const std::string &name, AnfNode *from_kernel, const AID &memory_manager_aid)
: MemoryAwareActor(name, KernelTransformType::kCopyActor, nullptr, memory_manager_aid),
from_kernel_(from_kernel),
output_(nullptr),
is_need_update_output_size_(false) {}
~CopyActor() override = default;
@ -64,6 +65,9 @@ class CopyActor : public MemoryAwareActor {
// Fetch the device tensor for copy.
void FetchDeviceTensor(OpContext<DeviceTensor> *const context);
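// The kernel node that is the source of the copy (non-owning raw pointer); it is nullptr when the source actor is not a kernel actor.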
AnfNode *from_kernel_;
// The input device tensor is saved from the input data or fetched by device_tensor_store_keys_.
std::vector<DeviceTensor *> input_device_tensor_;
// The output device tensor is saved from the output or fetched by device_tensor_store_keys_.

View File

@ -1102,19 +1102,15 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
}
}
// Link copy actor to exit actor.
for (const auto &copy_actor : actor_set->copy_actors_) {
if (copy_actor->output_data_arrows_.size() != 0 || copy_actor->output_control_arrows_.size() != 0) {
if ((!copy_actor->output_data_arrows_.empty()) || (!copy_actor->output_control_arrows_.empty())) {
continue;
}
if (copy_actor->input_control_arrow_aids_.empty()) {
MS_LOG(EXCEPTION) << "Invalid copy actor:" << copy_actor->GetAID();
if (copy_actor->from_kernel_ == nullptr) {
MS_LOG(EXCEPTION) << "Invalid copy actor:" << copy_actor->GetAID().Name();
}
auto from_actor = FetchActor(copy_actor->input_control_arrow_aids_[0].first.Name());
MS_EXCEPTION_IF_NULL(from_actor);
auto kernel_actor = dynamic_cast<KernelActor *>(from_actor);
MS_EXCEPTION_IF_NULL(kernel_actor);
MS_EXCEPTION_IF_NULL(kernel_actor->kernel_);
auto graph = kernel_actor->kernel_->func_graph();
auto graph = copy_actor->from_kernel_->func_graph();
MS_EXCEPTION_IF_NULL(graph);
auto kernel_graph = std::dynamic_pointer_cast<KernelGraph>(graph);
MS_EXCEPTION_IF_NULL(kernel_graph);
@ -1123,6 +1119,7 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
MS_EXCEPTION_IF_NULL(exit_actor);
SchedulerHelper::AddControlArrow(copy_actor.get(), exit_actor);
}
LinkControlArrowByKernelGraphGroup(graph_compiler_info);
}

View File

@ -1443,7 +1443,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor,
// Link between from actor and copy actor.
if (copy_actor == nullptr) {
// Create the copy actor.
auto copy_actor_shared_ptr = std::make_shared<CopyActor>(name, memory_manager_aid_);
auto copy_actor_shared_ptr = std::make_shared<CopyActor>(name, from_kernel.get(), memory_manager_aid_);
(void)copy_actors_.emplace_back(copy_actor_shared_ptr);
copy_actor = copy_actor_shared_ptr.get();
MS_EXCEPTION_IF_NULL(copy_actor);
@ -2065,7 +2065,13 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ab
if (FetchActor(name) != nullptr) {
continue;
}
auto copy_actor = std::make_shared<CopyActor>(name, memory_manager_aid_);
AnfNode *from_kernel = nullptr;
if (auto_monad_actor->type() == KernelTransformType::kKernelActor) {
auto kernel_actor = dynamic_cast<KernelActor *>(auto_monad_actor);
MS_EXCEPTION_IF_NULL(kernel_actor);
from_kernel = kernel_actor->kernel().get();
}
auto copy_actor = std::make_shared<CopyActor>(name, from_kernel, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(copy_actor);
(void)copy_actors_.emplace_back(copy_actor);
InsertActor(copy_actor.get());