diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_kernel_executor.cc b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_kernel_executor.cc
index 12ced52c276..4209e8b1973 100644
--- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_kernel_executor.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_kernel_executor.cc
@@ -230,18 +230,25 @@ void AscendKernelExecutor::PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr
     static const std::set<std::string> place_holder_nodes = {kDynamicRNNOpName, kDynamicGRUV2OpName};
     auto iter = place_holder_nodes.find(op_name);
     if (iter != place_holder_nodes.end()) {
-      auto none_index = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrPlaceHolderIndex);
-      // Remove seq_length
-      auto input_num = common::AnfAlgo::GetInputTensorNum(node);
-      std::vector<AnfNodePtr> new_inputs = {common::AnfAlgo::GetCNodePrimitiveNode(node)};
-      for (size_t i = 0; i < input_num; ++i) {
-        auto item = std::find(none_index.begin(), none_index.end(), i);
-        if (item == none_index.end()) {
-          auto input_node = common::AnfAlgo::GetInputNode(node, i);
-          new_inputs.emplace_back(input_node);
+      // keep placeholder for acl_kernel
+      auto is_acl_kernel = AnfAlgo::GetKernelType(node) == KernelType::ACL_KERNEL;
+      if (!is_acl_kernel) {
+        auto none_index = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrPlaceHolderIndex);
+        // Remove seq_length
+        auto input_num = common::AnfAlgo::GetInputTensorNum(node);
+        std::vector<AnfNodePtr> new_inputs = {common::AnfAlgo::GetCNodePrimitiveNode(node)};
+        for (size_t i = 0; i < input_num; ++i) {
+          auto item = std::find(none_index.begin(), none_index.end(), i);
+          if (item == none_index.end()) {
+            auto input_node = common::AnfAlgo::GetInputNode(node, i);
+            new_inputs.emplace_back(input_node);
+          }
         }
+        (void)node->set_inputs(new_inputs);
+        // update attr
+        common::AnfAlgo::EraseNodeAttr(kAttrPlaceHolderIndex, node);
+        MS_LOG(DEBUG) << "Remove placeholder input and kAttrPlaceHolderIndex for " << op_name;
       }
-      (void)node->set_inputs(new_inputs);
     }
 
     // Save the nop_op that needs to be memcpy
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/acl/acl_kernel_mod.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/acl/acl_kernel_mod.cc
index df5e014d37b..6f1103a846d 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/acl/acl_kernel_mod.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/acl/acl_kernel_mod.cc
@@ -269,6 +269,16 @@ bool AclKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
   }
   MS_LOG(DEBUG) << "Start aclopCompileAndExecute of node: " << node->fullname_with_scope()
                 << " op_type_:" << op_type_;
+  if (op_desc_ptr->input_tensor_desc().size() != op_desc_ptr->input_tensor_data().size()) {
+    MS_LOG(ERROR) << "For input, the size of tensor_desc and tensor_data is inconsistent! node: "
+                  << node->fullname_with_scope();
+    return false;
+  }
+  if (op_desc_ptr->output_tensor_desc().size() != op_desc_ptr->output_tensor_data().size()) {
+    MS_LOG(ERROR) << "For output, the size of tensor_desc and tensor_data is inconsistent! node: "
+                  << node->fullname_with_scope();
+    return false;
+  }
   bool ret = aclopCompileAndExecute(const_cast<char *>(op_type_.c_str()), op_desc_ptr->input_tensor_desc().size(),
                                     op_desc_ptr->input_tensor_desc().data(), op_desc_ptr->input_tensor_data().data(),
                                     op_desc_ptr->output_tensor_desc().size(), op_desc_ptr->output_tensor_desc().data(),
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/acl/acl_kernel_utils.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/acl/acl_kernel_utils.cc
index 1cbc5e85541..dd27a4998ad 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/acl/acl_kernel_utils.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/acl/acl_kernel_utils.cc
@@ -183,7 +183,7 @@ void AclOpDesc::AddDataBuf(const std::vector<AddressPtr> &inputs, const std::vec
   MS_EXCEPTION_IF_NULL(node);
   const auto &input_names = AclUtils::GetOpInputAnchorNames(node);
   input_tensor_data_.clear();
-  input_tensor_data_.resize(input_names.size(), aclCreateDataBuffer(nullptr, 0));
+  input_tensor_data_.resize(input_names.size(), nullptr);
   for (size_t i = 0; i < inputs.size(); i++) {
     auto idx = AclUtils::GetInputKernelIdxByGraphIdx(node, i);
     if (idx < 0) {
@@ -194,7 +194,9 @@ void AclOpDesc::AddDataBuf(const std::vector<AddressPtr> &inputs, const std::vec
                     << ", node:" << node->fullname_with_scope();
     }
     if (input_size_list[idx] == kSizeMax) {
-      CreateNullAclTensor(idx, true);
+      if (input_tensor_desc_[idx] != nullptr || common::AnfAlgo::IsNoneInput(node, i)) {
+        CreateNullAclTensor(idx, true);
+      }
       continue;
     }
     input_tensor_data_[idx] = CreateDataBuf(inputs[i], input_size_list[idx]);
diff --git a/mindspore/ccsrc/runtime/pynative/run_op_helper.cc b/mindspore/ccsrc/runtime/pynative/run_op_helper.cc
index 634f412ed3b..5a1e871ed4e 100644
--- a/mindspore/ccsrc/runtime/pynative/run_op_helper.cc
+++ b/mindspore/ccsrc/runtime/pynative/run_op_helper.cc
@@ -318,6 +318,10 @@ bool MallocForKernelInput(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
   MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
   auto input_size = runtime_info->GetInputSize();
   for (size_t i = 0; i < input_size; ++i) {
+    if (common::AnfAlgo::IsNoneInput(node, i)) {
+      MS_LOG(DEBUG) << "Input [" << i << "] of " << node->fullname_with_scope() << " is None.";
+      continue;
+    }
     auto input_address = runtime_info->GetInputDeviceAddress(i);
     kernel_mod->set_input_user_data(input_address->user_data().get(), i);
     MS_EXCEPTION_IF_NULL(input_address);
@@ -372,11 +376,18 @@ bool MallocForKernelOutput(const std::shared_ptr<OpRuntimeInfo> &runtime_info, c
   return true;
 }
 
-kernel::AddressPtrList CreateKernelInputAddress(const std::shared_ptr<OpRuntimeInfo> &runtime_info) {
+kernel::AddressPtrList CreateKernelInputAddress(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
+                                                const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(runtime_info);
   auto input_size = runtime_info->GetInputSize();
   kernel::AddressPtrList inputs;
   for (size_t i = 0; i < input_size; ++i) {
+    if (common::AnfAlgo::IsNoneInput(node, i)) {
+      (void)inputs.emplace_back(std::make_shared<kernel::Address>());
+      MS_LOG(DEBUG) << "Input[" << i << "]:"
+                    << " is None Input";
+      continue;
+    }
     auto device_address = runtime_info->GetInputDeviceAddress(i);
     MS_EXCEPTION_IF_NULL(device_address);
     (void)inputs.emplace_back(
@@ -561,7 +572,7 @@ void LaunchKernelsDynamic(const KernelGraphPtr &graph, const device::DeviceConte
   if (!MallocForKernelInput(runtime_info, device_context, node)) {
     MS_LOG(EXCEPTION) << "Malloc for kernel input failed, Memory isn't enough, node:" << node->fullname_with_scope();
   }
-  auto inputs = CreateKernelInputAddress(runtime_info);
+  auto inputs = CreateKernelInputAddress(runtime_info, node);
 
   InferNodeRealShape(node);
 
@@ -610,7 +621,7 @@ void LaunchKernels(const KernelGraphPtr &graph, const device::DeviceContext *dev
   if (!MallocForKernelInput(runtime_info, device_context, node)) {
     MS_LOG(EXCEPTION) << "Malloc for kernel input failed, Memory isn't enough, node:" << node->fullname_with_scope();
   }
-  auto inputs = CreateKernelInputAddress(runtime_info);
+  auto inputs = CreateKernelInputAddress(runtime_info, node);
   if (is_dynamic_shape) {
     InferNodeRealShape(node);
     ResizeNodeInput(node);
diff --git a/mindspore/ccsrc/utils/anfalgo.cc b/mindspore/ccsrc/utils/anfalgo.cc
index d46b1bbaf17..3588725cf24 100644
--- a/mindspore/ccsrc/utils/anfalgo.cc
+++ b/mindspore/ccsrc/utils/anfalgo.cc
@@ -1562,17 +1562,17 @@ bool AnfAlgo::IsNonTaskOp(const CNodePtr &node) {
 }
 
 bool AnfAlgo::IsNoneInput(const AnfNodePtr &node, size_t index) {
-  auto op_name = GetCNodeName(node);
-  constexpr auto none_placeholder_index = 3;
-  if (op_name == kDynamicRNNOpName && index == none_placeholder_index) {
-    return true;
+  MS_EXCEPTION_IF_NULL(node);
+  static std::set<std::string> node_set = {kDynamicRNNOpName, kDynamicGRUV2OpName};
+  auto cnode_name = common::AnfAlgo::GetCNodeName(node);
+  if (node_set.find(cnode_name) == node_set.end()) {
+    return false;
   }
-  if (op_name == kDynamicGRUV2OpName) {
-    auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrPlaceHolderIndex);
-    auto item = std::find(none_index.begin(), none_index.end(), index);
-    if (item != none_index.end()) {
-      return true;
-    }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (common::AnfAlgo::HasNodeAttr(kAttrPlaceHolderIndex, cnode)) {
+    auto none_index = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrPlaceHolderIndex);
+    return find(none_index.begin(), none_index.end(), index) != none_index.end();
   }
   return false;
 }
diff --git a/mindspore/core/ops/dynamic_rnn.cc b/mindspore/core/ops/dynamic_rnn.cc
index 6449c860cf7..16626ea0bb3 100644
--- a/mindspore/core/ops/dynamic_rnn.cc
+++ b/mindspore/core/ops/dynamic_rnn.cc
@@ -57,6 +57,7 @@ constexpr int64_t kDynamicRnnShapeB = 1;
 constexpr int64_t kDynamicRnnShapeH = 3;
 constexpr int64_t kDynamicRnnShapeC = 3;
 constexpr int64_t kDynRnnNum4 = 4;
+constexpr int64_t kDynRnnInputNum = 6;
 
 abstract::TupleShapePtr DynamicRNNInferDynamicShape(const std::vector<AbstractBasePtr> &input_args) {
   const int64_t y_shape_num = 3;
@@ -121,6 +122,7 @@ void DynamicRNNShapeCheck(const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
+  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kDynRnnInputNum, primitive->name());
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx0]->BuildShape())[kShape];
   auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx1]->BuildShape())[kShape];
   std::vector<ValuePtr> placeholder_index = {MakeValue((int64_t)3)};
@@ -153,6 +155,7 @@ abstract::TupleShapePtr DynamicRNNInferShape(const PrimitivePtr &primitive,
 }
 
 TuplePtr DynamicRNNInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kDynRnnInputNum, primitive->name());
   auto op_name = primitive->name();
   auto x_dtype = input_args[kDynRnnIdx0]->BuildType();
   auto w_dtype = input_args[kDynRnnIdx1]->BuildType();
@@ -186,8 +189,6 @@ MIND_API_OPERATOR_IMPL(DynamicRNN, BaseOperator);
 AbstractBasePtr DynamicRNNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  const int64_t input_num = 6;
-  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
   auto type = DynamicRNNInferType(primitive, input_args);
   auto shape = DynamicRNNInferShape(primitive, input_args);
   return abstract::MakeAbstract(shape, type);
diff --git a/tests/st/ops/ascend/test_acl_ops/test_dynamic_rnn.py b/tests/st/ops/ascend/test_acl_ops/test_dynamic_rnn.py
new file mode 100644
index 00000000000..063d8b89e32
--- /dev/null
+++ b/tests/st/ops/ascend/test_acl_ops/test_dynamic_rnn.py
@@ -0,0 +1,57 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore
+from mindspore import context
+from mindspore.common.tensor import Tensor
+from mindspore.nn import Cell
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+
+
+class Net(Cell):
+    "DynamicRNN network."
+
+    def __init__(self):
+        super(Net, self).__init__()
+        self.op = P.DynamicRNN()
+
+    def construct(self, x, w, b, init_h, init_c):
+        x = self.op(x, w, b, None, init_h, init_c)
+        return x
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_dynamic_rnn_acl_dynamic_shape():
+    """
+    Feature: Test acl call with pynative mode and dynamic shape.
+    Description: The first input is dynamic.
+    Expectation: print output x.
+    """
+    np.random.seed(1024)
+    x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
+    w = Tensor(np.random.rand(96, 128).astype(np.float16))
+    b = Tensor(np.random.rand(128).astype(np.float16))
+    init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
+    init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
+    dynamic_rnn = Net()
+    dynamic_rnn.set_inputs(Tensor(shape=[None, 16, 64], dtype=mindspore.float16), w, b, init_h, init_c)
+    output = dynamic_rnn(x, w, b, init_h, init_c)
+    print(output)