forked from mindspore-Ecosystem/mindspore

!31050 Dynamic shape for control flow.

Merge pull request !31050 from gaoyong10/dynamic_shape_01

commit 687f40a32f
@@ -120,8 +120,10 @@ void CPUDeviceContext::FreeMemory(void *const ptr) const {
 
 DeviceAddressPtr CPUDeviceContext::CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
                                                        TypeId type_id, const ShapeVector &shape) const {
-  return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id, device_context_key_.device_name_,
-                                            device_context_key_.device_id_);
+  auto device_address = std::make_shared<CPUDeviceAddress>(
+    device_ptr, device_size, format, type_id, device_context_key_.device_name_, device_context_key_.device_id_);
+  device_address->set_host_shape(shape);
+  return device_address;
 }
 
 void CPUDeviceContext::OptimizeGraph(const KernelGraphPtr &graph) const {
@@ -234,8 +234,10 @@ void GPUDeviceContext::FreeMemory(void *const ptr) const {
 
 DeviceAddressPtr GPUDeviceContext::CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
                                                        TypeId type_id, const ShapeVector &shape) const {
-  return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id, device_context_key_.device_name_,
-                                            device_context_key_.device_id_);
+  auto device_address = std::make_shared<GPUDeviceAddress>(
+    device_ptr, device_size, format, type_id, device_context_key_.device_name_, device_context_key_.device_id_);
+  device_address->set_host_shape(shape);
+  return device_address;
 }
 
 void GPUDeviceContext::OptimizeGraph(const KernelGraphPtr &graph) const {
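Both backends now share the same create-then-annotate pattern: build the device address, record the host shape on it, and only then return it, where previously the shape argument was accepted but ignored. A minimal standalone sketch of that pattern, using a simplified stand-in class rather than the real CPUDeviceAddress/GPUDeviceAddress types (all names and members below are illustrative assumptions):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

using ShapeVector = std::vector<int64_t>;  // Assumed alias, mirroring MindSpore's ShapeVector.

// Simplified stand-in for a device address: device memory plus host-side metadata.
class FakeDeviceAddress {
 public:
  FakeDeviceAddress(void *ptr, size_t size) : ptr_(ptr), size_(size) {}
  void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; }
  const ShapeVector &host_shape() const { return host_shape_; }

 private:
  void *ptr_;
  size_t size_;
  ShapeVector host_shape_;  // Carried along so later actors can recover the shape.
};

// The pattern from the patch: create the address, annotate it with the host shape, return it.
std::shared_ptr<FakeDeviceAddress> CreateDeviceAddress(void *ptr, size_t size, const ShapeVector &shape) {
  auto device_address = std::make_shared<FakeDeviceAddress>(ptr, size);
  device_address->set_host_shape(shape);  // Before the patch, the shape argument was dropped here.
  return device_address;
}

int main() {
  auto addr = CreateDeviceAddress(nullptr, 256, {2, -1, 32});  // -1: dimension decided at runtime.
  std::cout << "dims: " << addr->host_shape().size() << '\n';
  return 0;
}
```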
@@ -26,6 +26,7 @@ ControlActor::ControlActor(const std::string &name, KernelTransformType type, co
     (void)input_partials_.emplace_back(std::make_shared<OpPartial>());
   }
   input_device_tensors_.resize(parameters.size());
+  backend_parameters_.resize(parameters.size());
 }
 
 void ControlActor::Init() {
@@ -111,6 +112,7 @@ void ControlActor::Run(OpContext<DeviceTensor> *const context) {
     SendMemoryFreeReq(context);
 
     EraseInput(context);
+    UpdateDynamicShapeInParameter();
     SendOutput(context);
   } catch (const std::exception &e) {
     MsException::Instance().SetException();
@@ -435,5 +437,21 @@ void ControlActor::SendOutput(OpContext<DeviceTensor> *const context) {
                             IntToSize(partial_arrow->to_input_index_), context);
   }
 }
+
+void ControlActor::UpdateDynamicShapeInParameter() {
+  for (size_t i = 0; i < backend_parameters_.size(); ++i) {
+    if (backend_parameters_[i].empty() || input_device_tensors_[i] == nullptr) {
+      continue;
+    }
+
+    auto shape = input_device_tensors_[i]->host_shape();
+    std::vector<size_t> shape_tmp;
+    std::transform(shape.begin(), shape.end(), std::back_inserter(shape_tmp), IntToSize);
+
+    for (const auto &parameter : backend_parameters_[i]) {
+      common::AnfAlgo::SetOutputInferTypeAndShape({input_device_tensors_[i]->type_id()}, {shape_tmp}, parameter.get());
+    }
+  }
+}
 }  // namespace runtime
 }  // namespace mindspore
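The heart of UpdateDynamicShapeInParameter is the shape conversion: the host shape recorded on the input device tensor is rewritten into the std::vector<size_t> form that common::AnfAlgo::SetOutputInferTypeAndShape consumes. A self-contained sketch of just that step, where IntToSize is a simplified stand-in for MindSpore's checked conversion (assumption: the real utility also rejects negative values):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

// Simplified stand-in for MindSpore's IntToSize: narrow int64_t to size_t.
size_t IntToSize(int64_t v) { return static_cast<size_t>(v); }

int main() {
  // Host shape as recorded on the input device tensor after the kernel ran,
  // so every dimension is concrete by this point.
  std::vector<int64_t> shape = {5};
  std::vector<size_t> shape_tmp;
  std::transform(shape.begin(), shape.end(), std::back_inserter(shape_tmp), IntToSize);
  for (size_t d : shape_tmp) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}
```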
@@ -99,6 +99,9 @@ class ControlActor : public MemoryAwareActor {
   bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
   void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
                         const AnfNodePtr &output_node, OpContext<DeviceTensor> *const context) override;
+
+  // Update the dynamic shape in backend parameters.
+  void UpdateDynamicShapeInParameter();
   void SendOutput(OpContext<DeviceTensor> *const context) override;
   void EraseInput(const OpContext<DeviceTensor> *context) override;
 
@@ -159,6 +162,10 @@ class ControlActor : public MemoryAwareActor {
   std::map<size_t, std::set<DeviceTensorPtr>> ref_formal_parameter_device_tensors_;
   std::map<size_t, std::set<DeviceTensorPtr>> ref_node_formal_parameter_device_tensors_;
 
+  // Backend parameters in the kernel graph. In the dynamic shape scene, when parameters are passed between the
+  // kernel graphs, the shape in the backend parameters needs to be updated.
+  std::vector<std::vector<AnfNodePtr>> backend_parameters_;
+
   // local node for control actor, such as return node for exit actor, switch node for switch actor.
   AnfNodePtr node_;
 };
@@ -78,6 +78,7 @@ void EntranceActor::Run(OpContext<DeviceTensor> *const context) {
     SendMemoryFreeReq(context);
 
     EraseInput(context);
+    UpdateDynamicShapeInParameter();
     SendOutput(context);
   }
 
@@ -171,14 +171,23 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
       continue;
     }
     MS_EXCEPTION_IF_NULL(device_contexts_[i]);
+
+    auto host_shape = input_device_tensor->host_shape();
+    if (common::AnfAlgo::IsDynamicShape(node_with_index.first)) {
+      // If there is a dynamic shape, the shape in the kernel should be used.
+      MS_LOG(DEBUG) << "Update dynamic shape in kernel output:" << node_with_index.first->DebugString()
+                    << " for actor:" << GetAID();
+      auto shape_tmp = common::AnfAlgo::GetOutputInferShape(node_with_index.first, node_with_index.second);
+      host_shape.clear();
+      std::transform(shape_tmp.begin(), shape_tmp.end(), std::back_inserter(host_shape), IntToSize);
+    }
     // Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs.
     auto new_device_tensor =
       device_contexts_[i]->CreateDeviceAddress(nullptr, input_device_tensor->GetSize(), input_device_tensor->format(),
-                                               input_device_tensor->type_id(), input_device_tensor->host_shape());
+                                               input_device_tensor->type_id(), host_shape);
     MS_EXCEPTION_IF_NULL(new_device_tensor);
     (void)created_device_tensors_.emplace_back(new_device_tensor);
     (void)new_device_tensors.emplace_back(new_device_tensor.get());
 
     new_device_tensor->SetNodeIndex(node_with_index.first, node_with_index.second);
     new_device_tensor->set_from_persistent_mem(input_device_tensor->from_persistent_mem());
     // The device address which is created by actor uses the dynamic ref count.
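The change in CopyDeviceAddress reduces to one decision: keep the shape already recorded on the input tensor unless the producing node is dynamic, in which case re-derive the shape from the node's freshly inferred output. A toy sketch of that decision in isolation (function and variable names here are illustrative, not the runtime's API):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative only: pick the host shape for the newly created device tensor.
std::vector<size_t> ChooseHostShape(bool node_is_dynamic,
                                    const std::vector<size_t> &recorded_shape,
                                    const std::vector<size_t> &inferred_shape) {
  // For static nodes the recorded shape is already right; for dynamic nodes
  // only the post-execution inferred shape is trustworthy.
  return node_is_dynamic ? inferred_shape : recorded_shape;
}

int main() {
  std::vector<size_t> recorded = {8};  // Compile-time placeholder size.
  std::vector<size_t> inferred = {5};  // Shape the kernel actually produced.
  auto shape = ChooseHostShape(true, recorded, inferred);
  std::cout << shape[0] << '\n';  // 5
  return 0;
}
```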
@@ -76,10 +76,6 @@ class StackActor : public ControlActor {
   size_t input_stack_data_num_{0};
   size_t input_stack_partials_num_{0};
   size_t input_stack_controls_num_{0};
-  // The backend parameter is used to save the backend node corresponding to the device tensor in the stack.
-  // When these device tensors are used as output, they need to be placed in the node of the result arrow,
-  // so these nodes need to be saved.
-  std::vector<KernelWithIndex> backend_parameters_;
 };
 
 using StackActorPtr = std::shared_ptr<StackActor>;
@@ -647,6 +647,45 @@ void ControlNodeScheduler::Link(ActorSet *const actor_set, const GraphCompilerIn
   LinkControlArrowForKernelActor(actor_set, graph_compiler_info);
 
   LinkControlArrowForLoopCountActor(actor_set, graph_compiler_info);
+
+  LinkControlArrowForCustomActor(actor_set, graph_compiler_info);
 }
+
+void ControlNodeScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
+                                                          const GraphCompilerInfo &graph_compiler_info) {
+  MS_EXCEPTION_IF_NULL(actor_set);
+  const auto &parser = graph_compiler_info.control_node_parser_;
+  MS_EXCEPTION_IF_NULL(parser);
+
+  for (auto &custom_actor : actor_set->custom_actors_) {
+    MS_EXCEPTION_IF_NULL(custom_actor);
+    const auto &kernel = custom_actor->kernel().lock();
+    MS_EXCEPTION_IF_NULL(kernel);
+    const auto &graph = kernel->func_graph();
+    MS_EXCEPTION_IF_NULL(graph);
+    if (custom_actor->output_data_arrows().empty() && custom_actor->output_control_arrows().empty()) {
+      const auto &actor_name = graph->ToString() + kExitActorNameSuffix;
+      auto actor = FetchActor(actor_name);
+      MS_EXCEPTION_IF_NULL(actor);
+      LinkControlArrow(custom_actor.get(), actor);
+    }
+    if (custom_actor->input_control_arrow_aids().empty() && custom_actor->input_data_arrow_aids().empty()) {
+      const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
+      MS_EXCEPTION_IF_NULL(kernel_graph);
+      AbstractActor *from_actor = nullptr;
+      if (parser->IsCallInputKernelGraph(kernel_graph)) {
+        const auto &actor_name = kernel_graph->ToString() + kStackActorNameSuffix;
+        from_actor = FetchActor(actor_name);
+      } else {
+        const auto &func_graph = parser->FetchFuncGraphByKernelGraph(kernel_graph);
+        MS_EXCEPTION_IF_NULL(func_graph);
+        const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
+        from_actor = FetchActor(actor_name);
+      }
+      MS_EXCEPTION_IF_NULL(from_actor);
+      LinkControlArrow(from_actor, custom_actor.get());
+    }
+  }
+}
 
 void ControlNodeScheduler::ClearActorData(const ControlActorSet *control_actor_set) {
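The new linking rule: a custom actor with no outgoing arrows gets a control arrow to its kernel graph's exit actor, so the graph cannot finish before the actor runs; a custom actor with no incoming arrows gets one from the stack actor of a call-input kernel graph, or else from the entrance actor of the owning func graph. A toy sketch of the from-actor choice (the name suffixes below are illustrative, not the real kStackActorNameSuffix/kEntranceActorNameSuffix constants):

```cpp
#include <iostream>
#include <string>

// Illustrative name-based lookup mirroring the fallback logic in the patch.
std::string FromActorName(bool is_call_input_kernel_graph,
                          const std::string &kernel_graph_name,
                          const std::string &func_graph_name) {
  // Call-input kernel graphs are guarded by a stack actor; everything else is
  // kicked off by the enclosing func graph's entrance actor.
  return is_call_input_kernel_graph ? kernel_graph_name + "_StackActor"
                                    : func_graph_name + "_EntranceActor";
}

int main() {
  std::cout << FromActorName(true, "kernel_graph_0", "fg_0") << '\n';   // stack actor
  std::cout << FromActorName(false, "kernel_graph_0", "fg_0") << '\n';  // entrance actor
  return 0;
}
```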
@@ -1555,6 +1594,21 @@ void ControlNodeScheduler::AddFormalParameterDeviceTensor(ControlActor *const fr
   MS_EXCEPTION_IF_NULL(from_actor);
   MS_EXCEPTION_IF_NULL(input_node);
   MS_EXCEPTION_IF_NULL(graph);
+
+  // Collect backend parameters with dynamic shapes.
+  auto base_shape = input_node->Shape();
+  if (input_node->isa<Parameter>() && base_shape != nullptr && base_shape->isa<abstract::Shape>()) {
+    if (AnfUtils::IsShapeDynamic(base_shape->cast<abstract::ShapePtr>())) {
+      if (from_index >= from_actor->backend_parameters_.size()) {
+        MS_LOG(EXCEPTION) << "Invalid from index:" << from_index << " for actor:" << from_actor->GetAID()
+                          << " vector size:" << from_actor->backend_parameters_.size();
+      }
+      MS_LOG(INFO) << "Add dynamic shape backend parameter:" << input_node->DebugString() << " index:" << from_index
+                   << " for actor:" << from_actor->GetAID();
+      from_actor->backend_parameters_[from_index].emplace_back(input_node);
+    }
+  }
+
   if (!HasAbstractRef(input_node)) {
     return;
   }
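The guard above only registers a parameter whose shape is genuinely dynamic. A minimal sketch in the spirit of AnfUtils::IsShapeDynamic, under the assumption that unknown dimensions are encoded as negative values (conventionally -1):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch only: a shape is dynamic if any dimension is still unknown.
bool IsShapeDynamicSketch(const std::vector<int64_t> &shape) {
  return std::any_of(shape.begin(), shape.end(), [](int64_t d) { return d < 0; });
}

int main() {
  std::cout << std::boolalpha
            << IsShapeDynamicSketch({4, -1}) << ' '   // true: second dim unknown
            << IsShapeDynamicSketch({4, 8}) << '\n';  // false: fully static
  return 0;
}
```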
@@ -95,6 +95,7 @@ class ControlNodeScheduler {
   void LinkControlArrowForLoopCountActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);
   void LinkDataArrowForOutputActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
   void LinkControlArrowForKernelActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
+  void LinkControlArrowForCustomActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
   void LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node,
                                    const ControlNodeParserPtr &parser);
 
@@ -1585,6 +1585,14 @@ void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
       no_depend_custom_actors.erase(std::dynamic_pointer_cast<CustomActor>(to_iter->second));
     }
   }
+
+  // In control flow, no-input actors should be linked to entrance actors instead.
+  const auto &parser = graph_compiler_info.control_node_parser_;
+  MS_EXCEPTION_IF_NULL(parser);
+  if (parser->IsInited()) {
+    return;
+  }
+
   for (const auto &custom_actor : no_depend_custom_actors) {
     auto kernel = custom_actor->kernel().lock();
     MS_EXCEPTION_IF_NULL(kernel);
@@ -0,0 +1,88 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
+
+
+class UniqueIf(nn.Cell):
+    def __init__(self):
+        super(UniqueIf, self).__init__()
+        self.unique = P.Unique()
+        self.shape = P.DynamicShape()
+
+    def construct(self, x, index):
+        x_unique = self.unique(x)[0]
+        if index > 3:
+            x_unique = x_unique + 2
+        else:
+            x_unique = x_unique - 3
+        return self.shape(x_unique)
+
+
+class UniqueWhile(nn.Cell):
+    def __init__(self):
+        super(UniqueWhile, self).__init__()
+        self.unique = P.Unique()
+        self.shape = P.DynamicShape()
+        self.mod = P.Mod()
+
+    def construct(self, x, y, index):
+        while index < 3:
+            x = self.mod(x, y[index])
+            x = self.unique(x)[0]
+            index = index + 1
+        return self.shape(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_if():
+    """
+    Feature: Dynamic shape for control flow.
+    Description: If scene.
+    Expectation: No exception.
+    """
+    x = Tensor(np.array([4, 5, 1, 2, 3, 3, 4, 5]).astype(np.int32))
+    index = Tensor([0], mstype.int32)
+    context.set_context(mode=context.GRAPH_MODE)
+    net = UniqueIf()
+    x_shape = net(x, index)
+    assert x_shape == 5
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_unique_while():
+    """
+    Feature: Dynamic shape for control flow.
+    Description: While scene.
+    Expectation: No exception.
+    """
+    x = Tensor(np.array([12406268, 4962722, 720966, 75948, 6776, 960, 67, 8]).astype(np.int32))
+    y = Tensor(np.array([957, 67, 7]).astype(np.int32))
+    index = Tensor([0], mstype.int32)
+    context.set_context(mode=context.GRAPH_MODE)
+    net = UniqueWhile()
+    x_shape = net(x, y, index)
+    assert x_shape == 3
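For reference, the expected values follow directly from the test data: Unique over [4, 5, 1, 2, 3, 3, 4, 5] yields five distinct values, and the subsequent add or subtract in the branch only shifts them without changing the count, so test_unique_if asserts a shape of 5. In test_unique_while, three rounds of Mod followed by Unique (moduli 957, 67, 7) shrink the eight inputs step by step (8 values to 6 to 5 to 3 distinct remainders), hence the asserted shape of 3. Both nets exercise exactly the patched path: a shape that is only known after Unique runs must flow through if/while control-flow actors.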