Fix the bug in linking the copy actor to the exit actor
This commit is contained in:
parent
3146f49feb
commit
1d9398ac81
|
@ -36,8 +36,9 @@ using mindspore::device::DeviceContext;
|
|||
// -> OnMemoryAllocFinish -> Copy -> SendMemoryFreeReq -> SendOutput.
|
||||
class CopyActor : public MemoryAwareActor {
|
||||
public:
|
||||
CopyActor(const std::string &name, const AID &memory_manager_aid)
|
||||
CopyActor(const std::string &name, AnfNode *from_kernel, const AID &memory_manager_aid)
|
||||
: MemoryAwareActor(name, KernelTransformType::kCopyActor, nullptr, memory_manager_aid),
|
||||
from_kernel_(from_kernel),
|
||||
output_(nullptr),
|
||||
is_need_update_output_size_(false) {}
|
||||
~CopyActor() override = default;
|
||||
|
@ -64,6 +65,9 @@ class CopyActor : public MemoryAwareActor {
|
|||
// Fetch the device tensor for copy.
|
||||
void FetchDeviceTensor(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// The copy source.
|
||||
AnfNode *from_kernel_;
|
||||
|
||||
// The input device tensor is saved from the input data or fetched by device_tensor_store_keys_.
|
||||
std::vector<DeviceTensor *> input_device_tensor_;
|
||||
// The output device tensor is saved from the output or fetched by device_tensor_store_keys_.
|
||||
|
|
|
@ -1102,19 +1102,15 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
|
|||
}
|
||||
}
|
||||
|
||||
// Link copy actor to exit actor.
|
||||
for (const auto &copy_actor : actor_set->copy_actors_) {
|
||||
if (copy_actor->output_data_arrows_.size() != 0 || copy_actor->output_control_arrows_.size() != 0) {
|
||||
if ((!copy_actor->output_data_arrows_.empty()) || (!copy_actor->output_control_arrows_.empty())) {
|
||||
continue;
|
||||
}
|
||||
if (copy_actor->input_control_arrow_aids_.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid copy actor:" << copy_actor->GetAID();
|
||||
if (copy_actor->from_kernel_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Invalid copy actor:" << copy_actor->GetAID().Name();
|
||||
}
|
||||
auto from_actor = FetchActor(copy_actor->input_control_arrow_aids_[0].first.Name());
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
auto kernel_actor = dynamic_cast<KernelActor *>(from_actor);
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor);
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor->kernel_);
|
||||
auto graph = kernel_actor->kernel_->func_graph();
|
||||
auto graph = copy_actor->from_kernel_->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto kernel_graph = std::dynamic_pointer_cast<KernelGraph>(graph);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
@ -1123,6 +1119,7 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
|
|||
MS_EXCEPTION_IF_NULL(exit_actor);
|
||||
SchedulerHelper::AddControlArrow(copy_actor.get(), exit_actor);
|
||||
}
|
||||
|
||||
LinkControlArrowByKernelGraphGroup(graph_compiler_info);
|
||||
}
|
||||
|
||||
|
|
|
@ -1443,7 +1443,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor,
|
|||
// Link between from actor and copy actor.
|
||||
if (copy_actor == nullptr) {
|
||||
// Create the copy actor.
|
||||
auto copy_actor_shared_ptr = std::make_shared<CopyActor>(name, memory_manager_aid_);
|
||||
auto copy_actor_shared_ptr = std::make_shared<CopyActor>(name, from_kernel.get(), memory_manager_aid_);
|
||||
(void)copy_actors_.emplace_back(copy_actor_shared_ptr);
|
||||
copy_actor = copy_actor_shared_ptr.get();
|
||||
MS_EXCEPTION_IF_NULL(copy_actor);
|
||||
|
@ -2065,7 +2065,13 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ab
|
|||
if (FetchActor(name) != nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto copy_actor = std::make_shared<CopyActor>(name, memory_manager_aid_);
|
||||
AnfNode *from_kernel = nullptr;
|
||||
if (auto_monad_actor->type() == KernelTransformType::kKernelActor) {
|
||||
auto kernel_actor = dynamic_cast<KernelActor *>(auto_monad_actor);
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor);
|
||||
from_kernel = kernel_actor->kernel().get();
|
||||
}
|
||||
auto copy_actor = std::make_shared<CopyActor>(name, from_kernel, memory_manager_aid_);
|
||||
MS_EXCEPTION_IF_NULL(copy_actor);
|
||||
(void)copy_actors_.emplace_back(copy_actor);
|
||||
InsertActor(copy_actor.get());
|
||||
|
|
Loading…
Reference in New Issue