diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc
index b61263600fe..1bbae1ddfd7 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc
@@ -48,7 +48,8 @@ void KernelActor::Init() {
     const auto &input_device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel_, i, false);
     MS_EXCEPTION_IF_NULL(input_device_tensor);
     (void)real_input_data_infos_.emplace_back(
-      std::make_pair(input_device_tensor->format(), input_device_tensor->host_shape()));
+      std::make_shared<InputDataInfo>(input_device_tensor->format(), input_device_tensor->host_shape(),
+                                      input_device_tensor->GetSize(), input_device_tensor->type_id()));
   }
 
   // Init the device tensors and kernel launch info.
@@ -311,8 +312,9 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
     SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is of range.");
   }
   auto &real_input_info = real_input_data_infos_[input_data->index_];
+  MS_EXCEPTION_IF_NULL(real_input_info);
   if ((input_data->data_->DeviceType() == device_contexts_[0]->GetDeviceAddressType()) &&
-      (input_data->data_->format() == real_input_info.first)) {
+      (input_data->data_->format() == real_input_info->format_)) {
     return;
   }
 
@@ -320,14 +322,15 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
     SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is of range.");
   }
   if (copy_input_device_tensors_[input_data->index_] == nullptr) {
-    copy_input_device_tensors_[input_data->index_] =
-      device_contexts_[0]->CreateDeviceAddress(nullptr, input_data->data_->GetSize(), real_input_info.first,
-                                               input_data->data_->type_id(), real_input_info.second);
+    copy_input_device_tensors_[input_data->index_] = device_contexts_[0]->CreateDeviceAddress(
+      nullptr, real_input_info->size_, real_input_info->format_, real_input_info->type_id_, real_input_info->shape_);
   }
   auto &new_device_tensor = copy_input_device_tensors_[input_data->index_];
   MS_EXCEPTION_IF_NULL(new_device_tensor);
   // Dynamic shape need update size.
-  new_device_tensor->SetSize(input_data->data_->GetSize());
+  if (AnfUtils::IsShapeDynamic(real_input_info->shape_)) {
+    new_device_tensor->SetSize(input_data->data_->GetSize());
+  }
 
   // Update the input device tensor.
   input_device_tensors_[input_data->index_] = new_device_tensor.get();
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.h
index d4be291896e..e6132b70215 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.h
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.h
@@ -39,6 +39,15 @@ using mindspore::kernel::Address;
 using mindspore::kernel::KernelLaunchInfo;
 using mindspore::tensor::TensorPtr;
 
+struct InputDataInfo {
+  InputDataInfo(const std::string &format, const ShapeVector &shape, size_t size, TypeId type_id)
+      : format_(format), shape_(shape), size_(size), type_id_(type_id) {}
+  std::string format_;
+  ShapeVector shape_;
+  size_t size_;
+  TypeId type_id_;
+};
+
 // The kernel actor is used to receive the device tensors and control info to luanch kernel.
 // The processing flow is RunOpData/RunOpControl -> CheckRunningCondition -> SendMemoryAllocReq
 // -> OnMemoryAllocFinish -> LaunchKernel -> SendMemoryFreeReq -> SendOutput.
@@ -129,8 +138,8 @@ class KernelActor : public DebugAwareActor {
   // The received input device type and format may be different from the formal parameter in the control flow scenarios,
   // so it needs to be copied from the input data to real data that kernel launch needs.
   std::vector<DeviceTensorPtr> copy_input_device_tensors_;
-  // Real data info that kernel launch needs, used to check the consistency of received input data.
-  std::vector<std::pair<std::string, ShapeVector>> real_input_data_infos_;
+  // Real data info that kernel launch needs, used to check the consistency of received input data.
+  std::vector<std::shared_ptr<InputDataInfo>> real_input_data_infos_;
 
   // The device tensors for memory alloc and free.
   // output + workspace
diff --git a/mindspore/core/utils/anf_utils.cc b/mindspore/core/utils/anf_utils.cc
index e2991ab00e2..04feb04b9e5 100644
--- a/mindspore/core/utils/anf_utils.cc
+++ b/mindspore/core/utils/anf_utils.cc
@@ -142,6 +142,10 @@ bool AnfUtils::IsShapeDynamic(const std::vector<size_t> &shape) {
   return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
 }
 
+bool AnfUtils::IsShapeDynamic(const std::vector<int64_t> &shape) {
+  return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
+}
+
 bool AnfUtils::IsNodeOutputDynamicShape(const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto base_shape = node->Shape();
diff --git a/mindspore/core/utils/anf_utils.h b/mindspore/core/utils/anf_utils.h
index 4a57a26bcfe..f8d75fbb0b9 100644
--- a/mindspore/core/utils/anf_utils.h
+++ b/mindspore/core/utils/anf_utils.h
@@ -65,6 +65,7 @@ class MS_CORE_API AnfUtils {
   static bool IsDimUnknown(const abstract::ShapePtr &shape);
   static bool IsShapeDynamic(const abstract::ShapePtr &shape);
   static bool IsShapeDynamic(const std::vector<size_t> &shape);
+  static bool IsShapeDynamic(const std::vector<int64_t> &shape);
   static bool IsNodeOutputDynamicShape(const CNodePtr &node);
   static bool IsDimUnknown(const AnfNodePtr &node);
   // check whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too
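
Standalone sketch (not part of the patch above): it restates the InputDataInfo struct and the new int64_t IsShapeDynamic overload from this diff to show how the cached per-input info drives the copy path, but all surrounding types are simplified stand-ins (ShapeVector defined locally, TypeId aliased to int); the main() only mirrors the intent of KernelActor::Init caching and the SetSize gating in CopyInputDeviceTensor, not the real MindSpore signatures.

// Illustrative sketch only -- simplified stand-ins, not the MindSpore types.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

using ShapeVector = std::vector<int64_t>;  // stand-in for the MindSpore ShapeVector alias
using TypeId = int;                        // stand-in for mindspore::TypeId

// Mirrors the InputDataInfo struct added to kernel_actor.h.
struct InputDataInfo {
  InputDataInfo(const std::string &format, const ShapeVector &shape, size_t size, TypeId type_id)
      : format_(format), shape_(shape), size_(size), type_id_(type_id) {}
  std::string format_;
  ShapeVector shape_;
  size_t size_;
  TypeId type_id_;
};

// Same predicate as the new AnfUtils::IsShapeDynamic(const std::vector<int64_t> &) overload:
// any negative dimension (e.g. -1) marks the shape as dynamic.
bool IsShapeDynamic(const ShapeVector &shape) {
  return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
}

int main() {
  // Cached once per input at KernelActor::Init() time (hypothetical example values).
  auto static_info = std::make_shared<InputDataInfo>("NCHW", ShapeVector{1, 3, 224, 224},
                                                     1 * 3 * 224 * 224 * sizeof(float), 0);
  auto dynamic_info = std::make_shared<InputDataInfo>("NCHW", ShapeVector{-1, 3, 224, 224}, 0, 0);

  // In CopyInputDeviceTensor the copy target is now allocated from this cached info, and its
  // size is refreshed from the incoming data only when the cached shape is dynamic.
  for (const auto &info : {static_info, dynamic_info}) {
    std::cout << info->format_ << " size=" << info->size_
              << " dynamic=" << std::boolalpha << IsShapeDynamic(info->shape_) << "\n";
  }
  return 0;
}

Caching the format, shape, size and type at Init() means the copy device address no longer depends on the size/type of whatever data arrives at runtime, which is the consistency problem the patch addresses for control-flow inputs.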