From ab248e923cd41aa29fc154915d478208183ed57d Mon Sep 17 00:00:00 2001
From: ckey_Dou
Date: Wed, 24 Nov 2021 16:42:08 +0800
Subject: [PATCH] On Ascend the update of dynamic input for graph is done in
 'LoadInputData', which is replaced by DataPrepareActor::PrepareData on
 GPU(MindRT). This PR does the update of dynamic input in PrepareData and will
 work for all platforms when they switch to MindRT.

---
 .../framework/actor/data_prepare_actor.cc | 29 +++++++++++++++++++
 .../framework/actor/data_prepare_actor.h  |  2 ++
 2 files changed, 31 insertions(+)

diff --git a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc
index 4b2611bac7f..4c4e9c20e46 100644
--- a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc
+++ b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
 
+#include <algorithm>
+
 #include "runtime/framework/actor/data_prepare_actor.h"
 #include "runtime/framework/actor/memory_manager_actor.h"
 #include "runtime/framework/actor/kernel_actor.h"
@@ -120,6 +122,28 @@ void DataPrepareActor::Init() {
   }
 }
 
+void DataPrepareActor::UpdateDynamicShape(const AnfNodePtr &input_node, const TensorPtr &input_tensor) {
+  MS_EXCEPTION_IF_NULL(input_node);
+  if (input_tensor == nullptr) {
+    return;
+  }
+
+  if (!input_node->isa<Parameter>()) {
+    return;
+  }
+
+  auto input_param = input_node->cast<ParameterPtr>();
+  MS_EXCEPTION_IF_NULL(input_param);
+  if (!input_param->has_dynamic_shape()) {
+    return;
+  }
+
+  auto shape = input_tensor->shape();
+  std::vector<size_t> shape_tmp;
+  std::transform(shape.begin(), shape.end(), std::back_inserter(shape_tmp), IntToSize);
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp}, input_node.get());
+}
+
 void DataPrepareActor::PrepareData(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                    OpContext<DeviceTensor> *const context) {
   MS_EXCEPTION_IF_NULL(context);
@@ -229,6 +253,9 @@ void 
DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vector<TensorPtr>> &input_tensors,
       if (!IsHostQueueDSActor(input_node, graph, graph_compiler_info_->origin_parameters_order_, strategy_)) {
         continue;
       }
+
+      UpdateDynamicShape(input_node, input_tensor);
+
       auto tensor_position = host_data_source_actor_->FetchNodePosition(input_node);
       if (tensor_position >= host_tensors.size()) {
         std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
@@ -277,6 +304,8 @@ void DataPrepareActor::PrepareDataForStepMode(const std::vector<std::vector<TensorPtr>> &input_tensors,
+
+      UpdateDynamicShape(input_node, input_tensor);
+
       auto tensor_position = host_data_source_actor_->FetchNodePosition(input_node);
       if (tensor_position >= host_tensors.size()) {
diff --git a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h
index 56b253af0c3..897c16e92c1 100644
--- a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h
+++ b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h
@@ -68,6 +68,8 @@ class DataPrepareActor : public DebugAwareActor {
  private:
   friend class GraphScheduler;
 
+  void UpdateDynamicShape(const AnfNodePtr &input_node, const TensorPtr &input_tensor);
+
   void PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                        OpContext<DeviceTensor> *const context);
   void PrepareDataForHostTensorQueue(const std::vector<std::vector<TensorPtr>> &input_tensors,