diff --git a/mindspore/ccsrc/plugin/device/cpu/hal/hardware/cpu_device_context.cc b/mindspore/ccsrc/plugin/device/cpu/hal/hardware/cpu_device_context.cc
index 5204882321d..8526205ff8b 100644
--- a/mindspore/ccsrc/plugin/device/cpu/hal/hardware/cpu_device_context.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/hal/hardware/cpu_device_context.cc
@@ -120,8 +120,10 @@ void CPUDeviceContext::FreeMemory(void *const ptr) const {
 DeviceAddressPtr CPUDeviceContext::CreateDeviceAddress(void *const device_ptr, size_t device_size,
                                                        const string &format, TypeId type_id,
                                                        const ShapeVector &shape) const {
-  return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id, device_context_key_.device_name_,
-                                            device_context_key_.device_id_);
+  auto device_address = std::make_shared<CPUDeviceAddress>(
+    device_ptr, device_size, format, type_id, device_context_key_.device_name_, device_context_key_.device_id_);
+  device_address->set_host_shape(shape);
+  return device_address;
 }
 
 void CPUDeviceContext::OptimizeGraph(const KernelGraphPtr &graph) const {
diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc
index d2ba4736ab9..a40efcd5134 100644
--- a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc
@@ -234,8 +234,10 @@ void GPUDeviceContext::FreeMemory(void *const ptr) const {
 DeviceAddressPtr GPUDeviceContext::CreateDeviceAddress(void *const device_ptr, size_t device_size,
                                                        const string &format, TypeId type_id,
                                                        const ShapeVector &shape) const {
-  return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id, device_context_key_.device_name_,
-                                            device_context_key_.device_id_);
+  auto device_address = std::make_shared<GPUDeviceAddress>(
+    device_ptr, device_size, format, type_id, device_context_key_.device_name_, device_context_key_.device_id_);
+  device_address->set_host_shape(shape);
+  return device_address;
 }
 
 void GPUDeviceContext::OptimizeGraph(const KernelGraphPtr &graph) const {
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/control_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/control_actor.cc
index 4f94682dab9..989d4a9a0d2 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/control_actor.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/control_actor.cc
@@ -26,6 +26,7 @@ ControlActor::ControlActor(const std::string &name, KernelTransformType type, co
     (void)input_partials_.emplace_back(std::make_shared<OpPartial>());
   }
   input_device_tensors_.resize(parameters.size());
+  backend_parameters_.resize(parameters.size());
 }
 
 void ControlActor::Init() {
@@ -111,6 +112,7 @@ void ControlActor::Run(OpContext<DeviceTensor> *const context) {
     SendMemoryFreeReq(context);
 
     EraseInput(context);
+    UpdateDynamicShapeInParameter();
     SendOutput(context);
   } catch (const std::exception &e) {
     MsException::Instance().SetException();
@@ -435,5 +437,21 @@ void ControlActor::SendOutput(OpContext<DeviceTensor> *const context) {
                           IntToSize(partial_arrow->to_input_index_), context);
   }
 }
+
+void ControlActor::UpdateDynamicShapeInParameter() {
+  for (size_t i = 0; i < backend_parameters_.size(); ++i) {
+    if (backend_parameters_[i].empty() || input_device_tensors_[i] == nullptr) {
+      continue;
+    }
+
+    auto shape = input_device_tensors_[i]->host_shape();
+    std::vector<size_t> shape_tmp;
+    std::transform(shape.begin(), shape.end(), std::back_inserter(shape_tmp), IntToSize);
+
+    for (const auto &parameter : backend_parameters_[i]) {
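+      // Refresh the parameter's inferred type and shape from the runtime device tensor so that the
+      // kernel graph which consumes this backend parameter sees the actual dynamic shape.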
+      common::AnfAlgo::SetOutputInferTypeAndShape({input_device_tensors_[i]->type_id()}, {shape_tmp}, parameter.get());
+    }
+  }
+}
 }  // namespace runtime
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/control_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/control_actor.h
index 059e882e5d2..065500d980c 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/control_actor.h
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/control_actor.h
@@ -99,6 +99,9 @@ class ControlActor : public MemoryAwareActor {
   bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
   void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
                         const AnfNodePtr &output_node, OpContext<DeviceTensor> *const context) override;
+
+  // Update the dynamic shape in backend parameters.
+  void UpdateDynamicShapeInParameter();
   void SendOutput(OpContext<DeviceTensor> *const context) override;
   void EraseInput(const OpContext<DeviceTensor> *context) override;
 
@@ -159,6 +162,10 @@ class ControlActor : public MemoryAwareActor {
   std::map<size_t, std::set<DeviceTensorPtr>> ref_formal_parameter_device_tensors_;
   std::map<size_t, std::set<DeviceTensorPtr>> ref_node_formal_parameter_device_tensors_;
 
+  // Backend parameters in the kernel graph. In dynamic shape scenes, when parameters are passed between the kernel
+  // graphs, the shape in the backend parameters needs to be updated.
+  std::vector<std::vector<AnfNodePtr>> backend_parameters_;
+
   // local node for control actor, such as return node for exit actor, switch node for switch actor.
   AnfNodePtr node_;
 };
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/entrance_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/entrance_actor.cc
index 87962a6964d..91180444bf5 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/entrance_actor.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/entrance_actor.cc
@@ -78,6 +78,7 @@ void EntranceActor::Run(OpContext<DeviceTensor> *const context) {
   SendMemoryFreeReq(context);
 
   EraseInput(context);
+  UpdateDynamicShapeInParameter();
   SendOutput(context);
 }
 
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/exit_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/exit_actor.cc
index a5c4f6c5d57..bbd242f6716 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/exit_actor.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/exit_actor.cc
@@ -171,14 +171,23 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
       continue;
     }
     MS_EXCEPTION_IF_NULL(device_contexts_[i]);
+
+    auto host_shape = input_device_tensor->host_shape();
+    if (common::AnfAlgo::IsDynamicShape(node_with_index.first)) {
+      // If there is a dynamic shape, the shape in the kernel should be used.
+      MS_LOG(DEBUG) << "Update dynamic shape in kernel output:" << node_with_index.first->DebugString()
+                    << " for actor:" << GetAID();
+      auto shape_tmp = common::AnfAlgo::GetOutputInferShape(node_with_index.first, node_with_index.second);
+      host_shape.clear();
+      std::transform(shape_tmp.begin(), shape_tmp.end(), std::back_inserter(host_shape), IntToSize);
+    }
     // Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs.
     auto new_device_tensor =
       device_contexts_[i]->CreateDeviceAddress(nullptr, input_device_tensor->GetSize(), input_device_tensor->format(),
-                                               input_device_tensor->type_id(), input_device_tensor->host_shape());
+                                               input_device_tensor->type_id(), host_shape);
     MS_EXCEPTION_IF_NULL(new_device_tensor);
     (void)created_device_tensors_.emplace_back(new_device_tensor);
     (void)new_device_tensors.emplace_back(new_device_tensor.get());
-    new_device_tensor->SetNodeIndex(node_with_index.first, node_with_index.second);
     new_device_tensor->set_from_persistent_mem(input_device_tensor->from_persistent_mem());
     // The device address which is created by actor uses the dynamic ref count.
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/stack_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/stack_actor.h
index 533b08db442..7260ee584cd 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/stack_actor.h
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/stack_actor.h
@@ -76,10 +76,6 @@ class StackActor : public ControlActor {
   size_t input_stack_data_num_{0};
   size_t input_stack_partials_num_{0};
   size_t input_stack_controls_num_{0};
-  // The backend parameter is used to save the backend node corresponding to the device tensor in the stack.
-  // When these device tensors are used as output, they need to be placed in the node of the result arrow,
-  // so these nodes need to be saved.
-  std::vector<AnfNodePtr> backend_parameters_;
 };
 
 using StackActorPtr = std::shared_ptr<StackActor>;
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.cc b/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.cc
index c18abece49e..c71e0d4e0b7 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.cc
@@ -647,6 +647,45 @@ void ControlNodeScheduler::Link(ActorSet *const actor_set, const GraphCompilerIn
   LinkControlArrowForKernelActor(actor_set, graph_compiler_info);
 
   LinkControlArrowForLoopCountActor(actor_set, graph_compiler_info);
+
+  LinkControlArrowForCustomActor(actor_set, graph_compiler_info);
+}
+
+void ControlNodeScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
+                                                          const GraphCompilerInfo &graph_compiler_info) {
+  MS_EXCEPTION_IF_NULL(actor_set);
+  const auto &parser = graph_compiler_info.control_node_parser_;
+  MS_EXCEPTION_IF_NULL(parser);
+
+  for (auto &custom_actor : actor_set->custom_actors_) {
+    MS_EXCEPTION_IF_NULL(custom_actor);
+    const auto &kernel = custom_actor->kernel().lock();
+    MS_EXCEPTION_IF_NULL(kernel);
+    const auto &graph = kernel->func_graph();
+    MS_EXCEPTION_IF_NULL(graph);
+    if (custom_actor->output_data_arrows().empty() && custom_actor->output_control_arrows().empty()) {
+      const auto &actor_name = graph->ToString() + kExitActorNameSuffix;
+      auto actor = FetchActor(actor_name);
+      MS_EXCEPTION_IF_NULL(actor);
+      LinkControlArrow(custom_actor.get(), actor);
+    }
+    if (custom_actor->input_control_arrow_aids().empty() && custom_actor->input_data_arrow_aids().empty()) {
+      const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
+      MS_EXCEPTION_IF_NULL(kernel_graph);
+      AbstractActor *from_actor = nullptr;
+      if (parser->IsCallInputKernelGraph(kernel_graph)) {
+        const auto &actor_name = kernel_graph->ToString() + kStackActorNameSuffix;
+        from_actor = FetchActor(actor_name);
+      } else {
+        const auto &func_graph = parser->FetchFuncGraphByKernelGraph(kernel_graph);
+        MS_EXCEPTION_IF_NULL(func_graph);
+        const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
+        from_actor = FetchActor(actor_name);
+      }
+      MS_EXCEPTION_IF_NULL(from_actor);
+      LinkControlArrow(from_actor, custom_actor.get());
+    }
+  }
 }
 
 void ControlNodeScheduler::ClearActorData(const ControlActorSet *control_actor_set) {
@@ -1555,6 +1594,21 @@ void ControlNodeScheduler::AddFormalParameterDeviceTensor(ControlActor *const fr
   MS_EXCEPTION_IF_NULL(from_actor);
   MS_EXCEPTION_IF_NULL(input_node);
   MS_EXCEPTION_IF_NULL(graph);
+
+  // Collect backend parameters with dynamic shapes.
+  auto base_shape = input_node->Shape();
+  if (input_node->isa<Parameter>() && base_shape != nullptr && base_shape->isa<abstract::Shape>()) {
+    if (AnfUtils::IsShapeDynamic(base_shape->cast<abstract::ShapePtr>())) {
+      if (from_index >= from_actor->backend_parameters_.size()) {
+        MS_LOG(EXCEPTION) << "Invalid from index:" << from_index << " for actor:" << from_actor->GetAID()
+                          << " vector size:" << from_actor->backend_parameters_.size();
+      }
+      MS_LOG(INFO) << "Add dynamic shape backend parameter:" << input_node->DebugString() << " index:" << from_index
+                   << " for actor:" << from_actor->GetAID();
+      from_actor->backend_parameters_[from_index].emplace_back(input_node);
+    }
+  }
+
   if (!HasAbstractRef(input_node)) {
     return;
   }
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.h b/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.h
index 20e75831027..482c29c13b2 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.h
+++ b/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.h
@@ -95,6 +95,7 @@ class ControlNodeScheduler {
   void LinkControlArrowForLoopCountActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);
   void LinkDataArrowForOutputActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
   void LinkControlArrowForKernelActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
+  void LinkControlArrowForCustomActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
   void LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node,
                                    const ControlNodeParserPtr &parser);
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc
index fb284c05451..8289dbfff5c 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc
@@ -1585,6 +1585,14 @@ void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
       no_depend_custom_actors.erase(std::dynamic_pointer_cast<CustomActor>(to_iter->second));
     }
   }
+
+  // In control flow, custom actors without inputs should be linked to entrance actors by the control node scheduler.
+  const auto &parser = graph_compiler_info.control_node_parser_;
+  MS_EXCEPTION_IF_NULL(parser);
+  if (parser->IsInited()) {
+    return;
+  }
+
   for (const auto &custom_actor : no_depend_custom_actors) {
     auto kernel = custom_actor->kernel().lock();
     MS_EXCEPTION_IF_NULL(kernel);
diff --git a/tests/st/dynamic_shape/test_dynamic_shape_with_control_flow.py b/tests/st/dynamic_shape/test_dynamic_shape_with_control_flow.py
new file mode 100644
index 00000000000..15cfa4fd0a2
--- /dev/null
+++ b/tests/st/dynamic_shape/test_dynamic_shape_with_control_flow.py
@@ -0,0 +1,88 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
+
+
+class UniqueIf(nn.Cell):
+    def __init__(self):
+        super(UniqueIf, self).__init__()
+        self.unique = P.Unique()
+        self.shape = P.DynamicShape()
+
+    def construct(self, x, index):
+        x_unique = self.unique(x)[0]
+        if index > 3:
+            x_unique = x_unique + 2
+        else:
+            x_unique = x_unique - 3
+        return self.shape(x_unique)
+
+
+class UniqueWhile(nn.Cell):
+    def __init__(self):
+        super(UniqueWhile, self).__init__()
+        self.unique = P.Unique()
+        self.shape = P.DynamicShape()
+        self.mod = P.Mod()
+
+    def construct(self, x, y, index):
+        while index < 3:
+            x = self.mod(x, y[index])
+            x = self.unique(x)[0]
+            index = index + 1
+        return self.shape(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_if():
+    """
+    Feature: Dynamic shape for control flow.
+    Description: If scene.
+    Expectation: No exception.
+    """
+    x = Tensor(np.array([4, 5, 1, 2, 3, 3, 4, 5]).astype(np.int32))
+    index = Tensor([0], mstype.int32)
+    context.set_context(mode=context.GRAPH_MODE)
+    net = UniqueIf()
+    x_shape = net(x, index)
+    assert x_shape == 5
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_while():
+    """
+    Feature: Dynamic shape for control flow.
+    Description: While scene.
+    Expectation: No exception.
+    """
+    x = Tensor(np.array([12406268, 4962722, 720966, 75948, 6776, 960, 67, 8]).astype(np.int32))
+    y = Tensor(np.array([957, 67, 7]).astype(np.int32))
+    index = Tensor([0], mstype.int32)
+    context.set_context(mode=context.GRAPH_MODE)
+    net = UniqueWhile()
+    x_shape = net(x, y, index)
+    assert x_shape == 3
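For context, the graphs in these tests are dynamically shaped because Unique drops duplicate values, so the length of its output depends on the input data and is only known at run time; this is exactly the shape information that the control-flow actors above have to propagate between kernel graphs. The snippet below is a minimal standalone sketch of that behaviour, not part of the patch; the UniqueShape cell is illustrative only and follows the same pattern as the tests.

# Standalone sketch (assumes the same GPU/CPU graph-mode setup as the tests above).
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class UniqueShape(nn.Cell):
    def __init__(self):
        super(UniqueShape, self).__init__()
        self.unique = P.Unique()
        self.shape = P.DynamicShape()

    def construct(self, x):
        # The shape of unique(x)[0] depends on how many duplicates x contains.
        return self.shape(self.unique(x)[0])


context.set_context(mode=context.GRAPH_MODE)
net = UniqueShape()
print(net(Tensor(np.array([4, 5, 1, 2, 3, 3, 4, 5]).astype(np.int32))))  # expected: [5]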