!31050 Dynamic shape for control flow.

Merge pull request !31050 from gaoyong10/dynamic_shape_01
i-robot 2022-03-10 13:25:05 +00:00 committed by Gitee
commit 687f40a32f
11 changed files with 196 additions and 10 deletions


@@ -120,8 +120,10 @@ void CPUDeviceContext::FreeMemory(void *const ptr) const {
DeviceAddressPtr CPUDeviceContext::CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
                                                       TypeId type_id, const ShapeVector &shape) const {
  return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id, device_context_key_.device_name_,
                                            device_context_key_.device_id_);
  auto device_address = std::make_shared<CPUDeviceAddress>(
    device_ptr, device_size, format, type_id, device_context_key_.device_name_, device_context_key_.device_id_);
  device_address->set_host_shape(shape);
  return device_address;
}
void CPUDeviceContext::OptimizeGraph(const KernelGraphPtr &graph) const {


@@ -234,8 +234,10 @@ void GPUDeviceContext::FreeMemory(void *const ptr) const {
DeviceAddressPtr GPUDeviceContext::CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
                                                       TypeId type_id, const ShapeVector &shape) const {
  return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id, device_context_key_.device_name_,
                                            device_context_key_.device_id_);
  auto device_address = std::make_shared<GPUDeviceAddress>(
    device_ptr, device_size, format, type_id, device_context_key_.device_name_, device_context_key_.device_id_);
  device_address->set_host_shape(shape);
  return device_address;
}
void GPUDeviceContext::OptimizeGraph(const KernelGraphPtr &graph) const {
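
Taken together, the CPU and GPU hunks make the same change: CreateDeviceAddress no longer discards the shape argument, but records it on the newly created device address so that later consumers can read it back through host_shape(). The following is a minimal, self-contained sketch of that producer/consumer pattern; the class and function names are simplified stand-ins, not the real MindSpore types.

// Sketch only: stand-in types showing why the shape is stored at creation time.
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

using ShapeVector = std::vector<int64_t>;

class FakeDeviceAddress {  // stand-in for CPUDeviceAddress / GPUDeviceAddress
 public:
  void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; }
  const ShapeVector &host_shape() const { return host_shape_; }

 private:
  ShapeVector host_shape_;
};

// Producer: mirrors the new behavior of CreateDeviceAddress, which keeps the shape it receives.
std::shared_ptr<FakeDeviceAddress> CreateDeviceAddress(const ShapeVector &shape) {
  auto device_address = std::make_shared<FakeDeviceAddress>();
  device_address->set_host_shape(shape);
  return device_address;
}

int main() {
  // Consumer: a later stage (e.g. a control actor) recovers the runtime shape from the address.
  auto addr = CreateDeviceAddress({8, 16});
  std::cout << addr->host_shape()[0] << "x" << addr->host_shape()[1] << std::endl;  // prints 8x16
  return 0;
}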


@@ -26,6 +26,7 @@ ControlActor::ControlActor(const std::string &name, KernelTransformType type, co
    (void)input_partials_.emplace_back(std::make_shared<OpPartial>());
  }
  input_device_tensors_.resize(parameters.size());
  backend_parameters_.resize(parameters.size());
}

void ControlActor::Init() {
@@ -111,6 +112,7 @@ void ControlActor::Run(OpContext<DeviceTensor> *const context) {
    SendMemoryFreeReq(context);
    EraseInput(context);
    UpdateDynamicShapeInParameter();
    SendOutput(context);
  } catch (const std::exception &e) {
    MsException::Instance().SetException();
@@ -435,5 +437,21 @@ void ControlActor::SendOutput(OpContext<DeviceTensor> *const context) {
                          IntToSize(partial_arrow->to_input_index_), context);
  }
}

void ControlActor::UpdateDynamicShapeInParameter() {
  for (size_t i = 0; i < backend_parameters_.size(); ++i) {
    if (backend_parameters_[i].empty() || input_device_tensors_[i] == nullptr) {
      continue;
    }

    auto shape = input_device_tensors_[i]->host_shape();
    std::vector<size_t> shape_tmp;
    std::transform(shape.begin(), shape.end(), std::back_inserter(shape_tmp), IntToSize);

    for (const auto &parameter : backend_parameters_[i]) {
      common::AnfAlgo::SetOutputInferTypeAndShape({input_device_tensors_[i]->type_id()}, {shape_tmp}, parameter.get());
    }
  }
}
} // namespace runtime
} // namespace mindspore


@@ -99,6 +99,9 @@ class ControlActor : public MemoryAwareActor {
  bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
  void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
                        const AnfNodePtr &output_node, OpContext<DeviceTensor> *const context) override;

  // Update the dynamic shape in backend parameters.
  void UpdateDynamicShapeInParameter();
  void SendOutput(OpContext<DeviceTensor> *const context) override;
  void EraseInput(const OpContext<DeviceTensor> *context) override;
@@ -159,6 +162,10 @@ class ControlActor : public MemoryAwareActor {
  std::map<size_t, std::set<DeviceTensorPtr>> ref_formal_parameter_device_tensors_;
  std::map<size_t, std::set<DeviceTensorPtr>> ref_node_formal_parameter_device_tensors_;

  // Backend parameters in the kernel graph. In the dynamic shape scene, when parameters are passed between kernel
  // graphs, the shape in the backend parameters needs to be updated.
  std::vector<std::vector<AnfNodePtr>> backend_parameters_;
  // Local node of the control actor, such as the return node for an exit actor or the switch node for a switch actor.
  AnfNodePtr node_;
};
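
For orientation, the comment above describes a round trip that spans several files in this diff: ControlNodeScheduler::AddFormalParameterDeviceTensor (further down) records each dynamic-shape backend parameter under its input index, and ControlActor::UpdateDynamicShapeInParameter (above) later writes the runtime shape of the matching input device tensor back onto those parameters. The condensed sketch below only illustrates that data movement and uses stand-in types rather than the real AnfNodePtr and DeviceTensor classes.

// Condensed sketch of the backend_parameters_ round trip (stand-in types, not MindSpore's).
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

struct FakeParameter { std::vector<int64_t> inferred_shape; };  // stand-in for a backend Parameter node
struct FakeDeviceTensor { std::vector<int64_t> host_shape; };   // stand-in for DeviceTensor

struct ControlActorSketch {
  std::vector<std::vector<std::shared_ptr<FakeParameter>>> backend_parameters_;
  std::vector<std::shared_ptr<FakeDeviceTensor>> input_device_tensors_;

  // Scheduler side: remember that input `from_index` feeds a dynamic-shape backend parameter.
  void AddDynamicShapeParameter(size_t from_index, const std::shared_ptr<FakeParameter> &node) {
    backend_parameters_[from_index].push_back(node);
  }

  // Actor side: once the inputs arrive, push each runtime host shape onto the recorded parameters.
  void UpdateDynamicShapeInParameter() {
    for (size_t i = 0; i < backend_parameters_.size(); ++i) {
      if (backend_parameters_[i].empty() || input_device_tensors_[i] == nullptr) {
        continue;
      }
      for (const auto &parameter : backend_parameters_[i]) {
        parameter->inferred_shape = input_device_tensors_[i]->host_shape;
      }
    }
  }
};

int main() {
  ControlActorSketch actor;
  actor.backend_parameters_.resize(1);
  actor.input_device_tensors_.resize(1);

  auto param = std::make_shared<FakeParameter>();
  actor.AddDynamicShapeParameter(0, param);                               // graph-transform time
  actor.input_device_tensors_[0] = std::make_shared<FakeDeviceTensor>();  // run time
  actor.input_device_tensors_[0]->host_shape = {5};

  actor.UpdateDynamicShapeInParameter();
  std::cout << param->inferred_shape[0] << std::endl;  // prints 5: the runtime shape reached the parameter
  return 0;
}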


@@ -78,6 +78,7 @@ void EntranceActor::Run(OpContext<DeviceTensor> *const context) {
  SendMemoryFreeReq(context);
  EraseInput(context);
  UpdateDynamicShapeInParameter();
  SendOutput(context);
}


@@ -171,14 +171,23 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
      continue;
    }
    MS_EXCEPTION_IF_NULL(device_contexts_[i]);
    auto host_shape = input_device_tensor->host_shape();
    if (common::AnfAlgo::IsDynamicShape(node_with_index.first)) {
      // If the output has a dynamic shape, the shape inferred on the kernel should be used.
      MS_LOG(DEBUG) << "Update dynamic shape in kernel output:" << node_with_index.first->DebugString()
                    << " for actor:" << GetAID();
      auto shape_tmp = common::AnfAlgo::GetOutputInferShape(node_with_index.first, node_with_index.second);
      host_shape.clear();
      std::transform(shape_tmp.begin(), shape_tmp.end(), std::back_inserter(host_shape), IntToSize);
    }

    // Create a new device tensor to take over the input device tensors, which are the outputs of the kernel graphs.
    auto new_device_tensor =
      device_contexts_[i]->CreateDeviceAddress(nullptr, input_device_tensor->GetSize(), input_device_tensor->format(),
                                               input_device_tensor->type_id(), input_device_tensor->host_shape());
                                               input_device_tensor->type_id(), host_shape);
    MS_EXCEPTION_IF_NULL(new_device_tensor);
    (void)created_device_tensors_.emplace_back(new_device_tensor);
    (void)new_device_tensors.emplace_back(new_device_tensor.get());
    new_device_tensor->SetNodeIndex(node_with_index.first, node_with_index.second);
    new_device_tensor->set_from_persistent_mem(input_device_tensor->from_persistent_mem());
    // The device address created by the actor uses the dynamic ref count.
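
The ExitActor hunk handles the reverse direction at graph exits: when the kernel graph output has a dynamic shape, the device address copied out of the graph must carry the shape freshly inferred on the kernel rather than the possibly stale host shape of the input tensor. The small sketch below shows only that shape selection, with the AnfAlgo queries replaced by plain parameters (stand-ins, not the real API).

// Sketch of the shape choice made in ExitActor::CopyDeviceAddress (simplified, stand-in inputs).
#include <cstdint>
#include <iostream>
#include <vector>

using ShapeVector = std::vector<int64_t>;

ShapeVector ChooseHostShapeForCopy(bool output_is_dynamic, const ShapeVector &input_host_shape,
                                   const ShapeVector &kernel_inferred_shape) {
  if (output_is_dynamic) {
    // Dynamic shape: the kernel's freshly inferred output shape is the source of truth.
    return kernel_inferred_shape;
  }
  // Static shape: the shape already stored on the input device tensor is still valid.
  return input_host_shape;
}

int main() {
  // E.g. Unique() produced 3 elements at run time even though the input tensor held 8.
  auto shape = ChooseHostShapeForCopy(true, /*input_host_shape=*/{8}, /*kernel_inferred_shape=*/{3});
  std::cout << shape[0] << std::endl;  // prints 3
  return 0;
}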


@@ -76,10 +76,6 @@ class StackActor : public ControlActor {
  size_t input_stack_data_num_{0};
  size_t input_stack_partials_num_{0};
  size_t input_stack_controls_num_{0};
  // The backend parameter is used to save the backend node corresponding to the device tensor in the stack.
  // When these device tensors are used as output, they need to be placed in the node of the result arrow,
  // so these nodes need to be saved.
  std::vector<KernelWithIndex> backend_parameters_;
};
using StackActorPtr = std::shared_ptr<StackActor>;


@@ -647,6 +647,45 @@ void ControlNodeScheduler::Link(ActorSet *const actor_set, const GraphCompilerIn
  LinkControlArrowForKernelActor(actor_set, graph_compiler_info);
  LinkControlArrowForLoopCountActor(actor_set, graph_compiler_info);
  LinkControlArrowForCustomActor(actor_set, graph_compiler_info);
}

void ControlNodeScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
                                                          const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(actor_set);
  const auto &parser = graph_compiler_info.control_node_parser_;
  MS_EXCEPTION_IF_NULL(parser);
  for (auto &custom_actor : actor_set->custom_actors_) {
    MS_EXCEPTION_IF_NULL(custom_actor);
    const auto &kernel = custom_actor->kernel().lock();
    MS_EXCEPTION_IF_NULL(kernel);
    const auto &graph = kernel->func_graph();
    MS_EXCEPTION_IF_NULL(graph);
    if (custom_actor->output_data_arrows().empty() && custom_actor->output_control_arrows().empty()) {
      const auto &actor_name = graph->ToString() + kExitActorNameSuffix;
      auto actor = FetchActor(actor_name);
      MS_EXCEPTION_IF_NULL(actor);
      LinkControlArrow(custom_actor.get(), actor);
    }
    if (custom_actor->input_control_arrow_aids().empty() && custom_actor->input_data_arrow_aids().empty()) {
      const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
      MS_EXCEPTION_IF_NULL(kernel_graph);
      AbstractActor *from_actor = nullptr;
      if (parser->IsCallInputKernelGraph(kernel_graph)) {
        const auto &actor_name = kernel_graph->ToString() + kStackActorNameSuffix;
        from_actor = FetchActor(actor_name);
      } else {
        const auto &func_graph = parser->FetchFuncGraphByKernelGraph(kernel_graph);
        MS_EXCEPTION_IF_NULL(func_graph);
        const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
        from_actor = FetchActor(actor_name);
      }
      MS_EXCEPTION_IF_NULL(from_actor);
      LinkControlArrow(from_actor, custom_actor.get());
    }
  }
}

void ControlNodeScheduler::ClearActorData(const ControlActorSet *control_actor_set) {
@@ -1555,6 +1594,21 @@ void ControlNodeScheduler::AddFormalParameterDeviceTensor(ControlActor *const fr
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(input_node);
  MS_EXCEPTION_IF_NULL(graph);

  // Collect backend parameters with dynamic shapes.
  auto base_shape = input_node->Shape();
  if (input_node->isa<Parameter>() && base_shape != nullptr && base_shape->isa<abstract::Shape>()) {
    if (AnfUtils::IsShapeDynamic(base_shape->cast<abstract::ShapePtr>())) {
      if (from_index >= from_actor->backend_parameters_.size()) {
        MS_LOG(EXCEPTION) << "Invalid from index:" << from_index << " for actor:" << from_actor->GetAID()
                          << " vector size:" << from_actor->backend_parameters_.size();
      }
      MS_LOG(INFO) << "Add dynamic shape backend parameter:" << input_node->DebugString() << " index:" << from_index
                   << " for actor:" << from_actor->GetAID();
      from_actor->backend_parameters_[from_index].emplace_back(input_node);
    }
  }

  if (!HasAbstractRef(input_node)) {
    return;
  }


@@ -95,6 +95,7 @@ class ControlNodeScheduler {
  void LinkControlArrowForLoopCountActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);
  void LinkDataArrowForOutputActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
  void LinkControlArrowForKernelActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
  void LinkControlArrowForCustomActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
  void LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node,
                                   const ControlNodeParserPtr &parser);


@@ -1585,6 +1585,14 @@ void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
      no_depend_custom_actors.erase(std::dynamic_pointer_cast<CustomActor>(to_iter->second));
    }
  }

  // In control flow, no-input custom actors are linked to entrance actors by the control node scheduler, so skip them here.
  const auto &parser = graph_compiler_info.control_node_parser_;
  MS_EXCEPTION_IF_NULL(parser);
  if (parser->IsInited()) {
    return;
  }

  for (const auto &custom_actor : no_depend_custom_actors) {
    auto kernel = custom_actor->kernel().lock();
    MS_EXCEPTION_IF_NULL(kernel);


@@ -0,0 +1,88 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype


class UniqueIf(nn.Cell):
    def __init__(self):
        super(UniqueIf, self).__init__()
        self.unique = P.Unique()
        self.shape = P.DynamicShape()

    def construct(self, x, index):
        x_unique = self.unique(x)[0]
        if index > 3:
            x_unique = x_unique + 2
        else:
            x_unique = x_unique - 3
        return self.shape(x_unique)


class UniqueWhile(nn.Cell):
    def __init__(self):
        super(UniqueWhile, self).__init__()
        self.unique = P.Unique()
        self.shape = P.DynamicShape()
        self.mod = P.Mod()

    def construct(self, x, y, index):
        while index < 3:
            x = self.mod(x, y[index])
            x = self.unique(x)[0]
            index = index + 1
        return self.shape(x)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_unique_if():
    """
    Feature: Dynamic shape for control flow.
    Description: If scene.
    Expectation: No exception.
    """
    x = Tensor(np.array([4, 5, 1, 2, 3, 3, 4, 5]).astype(np.int32))
    index = Tensor([0], mstype.int32)
    context.set_context(mode=context.GRAPH_MODE)
    net = UniqueIf()
    x_shape = net(x, index)
    assert x_shape == 5


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_unique_while():
    """
    Feature: Dynamic shape for control flow.
    Description: While scene.
    Expectation: No exception.
    """
    x = Tensor(np.array([12406268, 4962722, 720966, 75948, 6776, 960, 67, 8]).astype(np.int32))
    y = Tensor(np.array([957, 67, 7]).astype(np.int32))
    index = Tensor([0], mstype.int32)
    context.set_context(mode=context.GRAPH_MODE)
    net = UniqueWhile()
    x_shape = net(x, y, index)
    assert x_shape == 3