!49130 support DynamicRNN op with PyNative dynamic shape on Ascend device

Merge pull request !49130 from hanhuifeng/acl_dynamic_rnn
commit 9382f35fa4 by i-robot, 2023-02-22 06:34:22 +00:00, committed by Gitee
7 changed files with 115 additions and 27 deletions


@@ -230,18 +230,25 @@ void AscendKernelExecutor::PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr
static const std::set<std::string> place_holder_nodes = {kDynamicRNNOpName, kDynamicGRUV2OpName};
auto iter = place_holder_nodes.find(op_name);
if (iter != place_holder_nodes.end()) {
auto none_index = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrPlaceHolderIndex);
// Remove seq_length
auto input_num = common::AnfAlgo::GetInputTensorNum(node);
std::vector<AnfNodePtr> new_inputs = {common::AnfAlgo::GetCNodePrimitiveNode(node)};
for (size_t i = 0; i < input_num; ++i) {
auto item = std::find(none_index.begin(), none_index.end(), i);
if (item == none_index.end()) {
auto input_node = common::AnfAlgo::GetInputNode(node, i);
new_inputs.emplace_back(input_node);
// keep placeholder for acl_kernel
auto is_acl_kernel = AnfAlgo::GetKernelType(node) == KernelType::ACL_KERNEL;
if (!is_acl_kernel) {
auto none_index = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrPlaceHolderIndex);
// Remove seq_length
auto input_num = common::AnfAlgo::GetInputTensorNum(node);
std::vector<AnfNodePtr> new_inputs = {common::AnfAlgo::GetCNodePrimitiveNode(node)};
for (size_t i = 0; i < input_num; ++i) {
auto item = std::find(none_index.begin(), none_index.end(), i);
if (item == none_index.end()) {
auto input_node = common::AnfAlgo::GetInputNode(node, i);
new_inputs.emplace_back(input_node);
}
}
(void)node->set_inputs(new_inputs);
// update attr
common::AnfAlgo::EraseNodeAttr(kAttrPlaceHolderIndex, node);
MS_LOG(DEBUG) << "Remove placeholder input and kAttrPlaceHolderIndex for " << op_name;
}
(void)node->set_inputs(new_inputs);
}
// Save the nop_op that needs to be memcpy
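
The hunk above changes PreprocessBeforeRunSingleOpGraph so that placeholder inputs are only stripped (and kAttrPlaceHolderIndex erased) for non-ACL kernels; when the node runs as an ACL kernel, the seq_length placeholder stays in the input list. A minimal standalone C++ sketch of that filtering decision, with illustrative names rather than the MindSpore API:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Rebuild an op's input list, dropping placeholder inputs (e.g. DynamicRNN's
// optional seq_length) unless the node runs as an ACL kernel, which expects
// the full input list including placeholders.
std::vector<std::string> FilterPlaceholders(const std::vector<std::string> &inputs,
                                            const std::vector<int64_t> &none_index,
                                            bool is_acl_kernel) {
  if (is_acl_kernel) {
    return inputs;  // keep placeholder for acl_kernel
  }
  std::vector<std::string> new_inputs;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (std::find(none_index.begin(), none_index.end(), static_cast<int64_t>(i)) == none_index.end()) {
      new_inputs.emplace_back(inputs[i]);
    }
  }
  return new_inputs;
}

int main() {
  // DynamicRNN inputs; index 3 (seq_length) is the placeholder.
  const std::vector<std::string> inputs = {"x", "w", "b", "seq_length", "init_h", "init_c"};
  for (const auto &name : FilterPlaceholders(inputs, {3}, /*is_acl_kernel=*/false)) {
    std::cout << name << ' ';  // prints: x w b init_h init_c
  }
  std::cout << '\n';
}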


@@ -269,6 +269,16 @@ bool AclKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
}
MS_LOG(DEBUG) << "Start aclopCompileAndExecute of node: " << node->fullname_with_scope() << " op_type_:" << op_type_;
if (op_desc_ptr->input_tensor_desc().size() != op_desc_ptr->input_tensor_data().size()) {
MS_LOG(ERROR) << "For input, the size of tensor_desc and tensor_data is inconsistent! node: "
<< node->fullname_with_scope();
return false;
}
if (op_desc_ptr->output_tensor_desc().size() != op_desc_ptr->output_tensor_data().size()) {
MS_LOG(ERROR) << "For output, the size of tensor_desc and tensor_data is inconsistent! node: "
<< node->fullname_with_scope();
return false;
}
bool ret = aclopCompileAndExecute(const_cast<char *>(op_type_.c_str()), op_desc_ptr->input_tensor_desc().size(),
op_desc_ptr->input_tensor_desc().data(), op_desc_ptr->input_tensor_data().data(),
op_desc_ptr->output_tensor_desc().size(), op_desc_ptr->output_tensor_desc().data(),
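
The added checks make Launch fail fast when the input or output descriptor list and data-buffer list disagree in length, instead of handing mismatched arrays to aclopCompileAndExecute, which indexes both in lockstep. A self-contained sketch of the same guard pattern (names are illustrative, not the real AclKernelMod):

#include <cstddef>
#include <iostream>

// Verify that descriptor and data-buffer arrays are the same length before a
// launch; returning false aborts the launch rather than risking out-of-bounds
// access inside the ACL runtime.
bool CheckDescDataConsistent(std::size_t desc_size, std::size_t data_size, const char *tag) {
  if (desc_size != data_size) {
    std::cerr << "For " << tag << ", the size of tensor_desc (" << desc_size
              << ") and tensor_data (" << data_size << ") is inconsistent!\n";
    return false;
  }
  return true;
}

int main() {
  // A mismatched pair, e.g. a placeholder input that got a descriptor but no buffer.
  if (!CheckDescDataConsistent(4, 3, "input")) {
    return 1;
  }
  return 0;
}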


@@ -183,7 +183,7 @@ void AclOpDesc::AddDataBuf(const std::vector<AddressPtr> &inputs, const std::vec
MS_EXCEPTION_IF_NULL(node);
const auto &input_names = AclUtils::GetOpInputAnchorNames(node);
input_tensor_data_.clear();
input_tensor_data_.resize(input_names.size(), aclCreateDataBuffer(nullptr, 0));
input_tensor_data_.resize(input_names.size(), nullptr);
for (size_t i = 0; i < inputs.size(); i++) {
auto idx = AclUtils::GetInputKernelIdxByGraphIdx(node, i);
if (idx < 0) {
@@ -194,7 +194,9 @@
<< ", node:" << node->fullname_with_scope();
}
if (input_size_list[idx] == kSizeMax) {
CreateNullAclTensor(idx, true);
if (input_tensor_desc_[idx] != nullptr || common::AnfAlgo::IsNoneInput(node, i)) {
CreateNullAclTensor(idx, true);
}
continue;
}
input_tensor_data_[idx] = CreateDataBuf(inputs[i], input_size_list[idx]);
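
Two behaviors change in AddDataBuf: the data-buffer list is pre-filled with nullptr rather than ready-made empty ACL buffers, and a null ACL tensor is created only when the slot already has a tensor descriptor or the graph marks the input as None. A standalone sketch of that slot handling, using stand-in types instead of the real acl handles:

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative stand-in for an ACL data-buffer handle (not the real acl type).
struct DataBuf {
  void *data = nullptr;
  std::size_t size = 0;
};

// Fill one data-buffer slot: a null buffer is attached only for genuine
// placeholders, i.e. slots that already have a tensor descriptor or whose
// graph input is None; all other empty slots stay nullptr.
void FillSlot(std::vector<DataBuf *> *bufs, std::size_t idx, bool has_desc, bool is_none_input) {
  if (has_desc || is_none_input) {
    (*bufs)[idx] = new DataBuf{};  // analogous to aclCreateDataBuffer(nullptr, 0)
  }
}

int main() {
  std::vector<DataBuf *> bufs(6, nullptr);  // resize(input_names.size(), nullptr)
  FillSlot(&bufs, 3, /*has_desc=*/true, /*is_none_input=*/true);  // seq_length slot
  std::cout << (bufs[3] != nullptr) << '\n';  // 1: the placeholder got a null buffer
  delete bufs[3];
}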


@@ -318,6 +318,10 @@ bool MallocForKernelInput(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
auto input_size = runtime_info->GetInputSize();
for (size_t i = 0; i < input_size; ++i) {
if (common::AnfAlgo::IsNoneInput(node, i)) {
MS_LOG(DEBUG) << "Input [" << i << "] of " << node->fullname_with_scope() << " is None.";
continue;
}
auto input_address = runtime_info->GetInputDeviceAddress(i);
kernel_mod->set_input_user_data(input_address->user_data().get(), i);
MS_EXCEPTION_IF_NULL(input_address);
@@ -372,11 +376,18 @@ bool MallocForKernelOutput(const std::shared_ptr<OpRuntimeInfo> &runtime_info, c
return true;
}
kernel::AddressPtrList CreateKernelInputAddress(const std::shared_ptr<OpRuntimeInfo> &runtime_info) {
kernel::AddressPtrList CreateKernelInputAddress(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(runtime_info);
auto input_size = runtime_info->GetInputSize();
kernel::AddressPtrList inputs;
for (size_t i = 0; i < input_size; ++i) {
if (common::AnfAlgo::IsNoneInput(node, i)) {
(void)inputs.emplace_back(std::make_shared<kernel::Address>());
MS_LOG(DEBUG) << "Input[" << i << "]:"
<< " is None Input";
continue;
}
auto device_address = runtime_info->GetInputDeviceAddress(i);
MS_EXCEPTION_IF_NULL(device_address);
(void)inputs.emplace_back(
@@ -561,7 +572,7 @@ void LaunchKernelsDynamic(const KernelGraphPtr &graph, const device::DeviceConte
if (!MallocForKernelInput(runtime_info, device_context, node)) {
MS_LOG(EXCEPTION) << "Malloc for kernel input failed, Memory isn't enough, node:" << node->fullname_with_scope();
}
auto inputs = CreateKernelInputAddress(runtime_info);
auto inputs = CreateKernelInputAddress(runtime_info, node);
InferNodeRealShape(node);
@@ -610,7 +621,7 @@ void LaunchKernels(const KernelGraphPtr &graph, const device::DeviceContext *dev
if (!MallocForKernelInput(runtime_info, device_context, node)) {
MS_LOG(EXCEPTION) << "Malloc for kernel input failed, Memory isn't enough, node:" << node->fullname_with_scope();
}
auto inputs = CreateKernelInputAddress(runtime_info);
auto inputs = CreateKernelInputAddress(runtime_info, node);
if (is_dynamic_shape) {
InferNodeRealShape(node);
ResizeNodeInput(node);
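
MallocForKernelInput now skips None inputs entirely, and CreateKernelInputAddress gains the node argument so it can substitute an empty Address for each None input, keeping the positional mapping between graph inputs and kernel inputs intact. A minimal sketch of the address-list construction (illustrative types, not MindSpore's kernel::Address):

#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Illustrative Address type (not MindSpore's kernel::Address).
struct Address {
  void *addr = nullptr;
  std::size_t size = 0;
};
using AddressPtr = std::shared_ptr<Address>;

// A None input gets a fresh empty Address so every kernel input keeps its
// positional slot; real inputs pass their device address through unchanged.
std::vector<AddressPtr> CreateInputAddresses(const std::vector<AddressPtr> &device_addresses,
                                             const std::vector<bool> &is_none_input) {
  std::vector<AddressPtr> inputs;
  for (std::size_t i = 0; i < device_addresses.size(); ++i) {
    if (is_none_input[i]) {
      (void)inputs.emplace_back(std::make_shared<Address>());  // placeholder slot
      continue;
    }
    (void)inputs.emplace_back(device_addresses[i]);
  }
  return inputs;
}

int main() {
  auto x = std::make_shared<Address>();
  auto inputs = CreateInputAddresses({x, nullptr, x}, {false, true, false});
  std::cout << inputs.size() << '\n';  // 3: the None input still occupies slot 1
}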


@@ -1562,17 +1562,17 @@ bool AnfAlgo::IsNonTaskOp(const CNodePtr &node) {
}
bool AnfAlgo::IsNoneInput(const AnfNodePtr &node, size_t index) {
auto op_name = GetCNodeName(node);
constexpr auto none_placeholder_index = 3;
if (op_name == kDynamicRNNOpName && index == none_placeholder_index) {
return true;
MS_EXCEPTION_IF_NULL(node);
static std::set<std::string> node_set = {kDynamicRNNOpName, kDynamicGRUV2OpName};
auto cnode_name = common::AnfAlgo::GetCNodeName(node);
if (node_set.find(cnode_name) == node_set.end()) {
return false;
}
if (op_name == kDynamicGRUV2OpName) {
auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrPlaceHolderIndex);
auto item = std::find(none_index.begin(), none_index.end(), index);
if (item != none_index.end()) {
return true;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (common::AnfAlgo::HasNodeAttr(kAttrPlaceHolderIndex, cnode)) {
auto none_index = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrPlaceHolderIndex);
return find(none_index.begin(), none_index.end(), index) != none_index.end();
}
return false;
}
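
The rewritten IsNoneInput drops the hard-coded "DynamicRNN input 3" special case: both DynamicRNN and DynamicGRUV2 are looked up in a shared set, and the node's kAttrPlaceHolderIndex attribute, when present, decides which input indices count as None. A standalone sketch of the same lookup, where a plain map stands in for the node's attribute storage:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// Generalized None-input check: consult the per-op placeholder-index
// attribute instead of hard-coding one index for one op.
bool IsNoneInput(const std::string &op_name, std::size_t index,
                 const std::map<std::string, std::vector<int64_t>> &placeholder_index_attr) {
  static const std::set<std::string> node_set = {"DynamicRNN", "DynamicGRUV2"};
  if (node_set.find(op_name) == node_set.end()) {
    return false;
  }
  auto it = placeholder_index_attr.find(op_name);
  if (it == placeholder_index_attr.end()) {
    return false;  // attribute absent, e.g. already erased on the non-ACL path
  }
  const auto &none_index = it->second;
  return std::find(none_index.begin(), none_index.end(), static_cast<int64_t>(index)) != none_index.end();
}

int main() {
  const std::map<std::string, std::vector<int64_t>> attrs = {{"DynamicRNN", {3}}};
  std::cout << IsNoneInput("DynamicRNN", 3, attrs) << ' '  // 1
            << IsNoneInput("DynamicRNN", 0, attrs) << ' '  // 0
            << IsNoneInput("MatMul", 3, attrs) << '\n';    // 0
}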


@@ -57,6 +57,7 @@ constexpr int64_t kDynamicRnnShapeB = 1;
constexpr int64_t kDynamicRnnShapeH = 3;
constexpr int64_t kDynamicRnnShapeC = 3;
constexpr int64_t kDynRnnNum4 = 4;
constexpr int64_t kDynRnnInputNum = 6;
abstract::TupleShapePtr DynamicRNNInferDynamicShape(const std::vector<AbstractBasePtr> &input_args) {
const int64_t y_shape_num = 3;
@@ -121,6 +122,7 @@ void DynamicRNNShapeCheck(const PrimitivePtr &primitive, const std::vector<Abstr
abstract::TupleShapePtr DynamicRNNInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kDynRnnInputNum, primitive->name());
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx0]->BuildShape())[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx1]->BuildShape())[kShape];
std::vector<ValuePtr> placeholder_index = {MakeValue((int64_t)3)};
@@ -153,6 +155,7 @@ abstract::TupleShapePtr DynamicRNNInferShape(const PrimitivePtr &primitive,
}
TuplePtr DynamicRNNInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kDynRnnInputNum, primitive->name());
auto op_name = primitive->name();
auto x_dtype = input_args[kDynRnnIdx0]->BuildType();
auto w_dtype = input_args[kDynRnnIdx1]->BuildType();
@@ -186,8 +189,6 @@ MIND_API_OPERATOR_IMPL(DynamicRNN, BaseOperator);
AbstractBasePtr DynamicRNNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 6;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
auto type = DynamicRNNInferType(primitive, input_args);
auto shape = DynamicRNNInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
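
The infer functions now share kDynRnnInputNum instead of each declaring a local input_num constant, and they check kGreaterEqual because DynamicRNN's six declared inputs (x, w, b, seq_length, init_h, init_c) include the optional seq_length placeholder. A small sketch of that arity check (illustrative, not CheckAndConvertUtils):

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>

// Shared input-count constant and "at least N" arity check, so every infer
// function validates against the same bound rather than a repeated literal.
constexpr int64_t kDynRnnInputNum = 6;

void CheckInputArgsGE(std::size_t actual, int64_t expected_min, const std::string &prim_name) {
  if (static_cast<int64_t>(actual) < expected_min) {
    throw std::invalid_argument("For " + prim_name + ", the number of inputs must be >= " +
                                std::to_string(expected_min) + ", but got " + std::to_string(actual));
  }
}

int main() {
  CheckInputArgsGE(6, kDynRnnInputNum, "DynamicRNN");  // passes: all six inputs present
  return 0;
}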


@@ -0,0 +1,57 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")


class Net(Cell):
    """DynamicRNN network."""

    def __init__(self):
        super(Net, self).__init__()
        self.op = P.DynamicRNN()

    def construct(self, x, w, b, init_h, init_c):
        x = self.op(x, w, b, None, init_h, init_c)
        return x


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_rnn_acl_dynamic_shape():
    """
    Feature: Test ACL call with PyNative mode and dynamic shape.
    Description: The first dimension of the input x is dynamic.
    Expectation: The DynamicRNN output is computed and printed without error.
    """
np.random.seed(1024)
x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
w = Tensor(np.random.rand(96, 128).astype(np.float16))
b = Tensor(np.random.rand(128).astype(np.float16))
init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
dynamic_rnn = Net()
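    # set_inputs declares the first dimension of x as dynamic (shape=[None, 16, 64]),
    # exercising the PyNative dynamic-shape ACL path under test.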
dynamic_rnn.set_inputs(Tensor(shape=[None, 16, 64], dtype=mindspore.float16), w, b, init_h, init_c)
output = dynamic_rnn(x, w, b, init_h, init_c)
print(output)