diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index c2db7c398aa..6f8b135f59e 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -411,6 +411,11 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+  if (kernel_graph->is_dynamic_shape()) {
+    MS_LOG(WARNING) << "Dynamic shape skip fusion";
+    return;
+  }
   bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
   if (save_graphs) {
     std::string file_name = "hwopt_d_ub_fusion_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc
index f63b91d02a6..1f5594502e6 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc
@@ -54,7 +54,11 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
     auto imm = std::make_shared<Int64Imm>(output_idx);
     idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
     auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
-    AnfAlgo::SetOutputTypeAndDetailShape({origin_type}, {origin_shape}, getitem.get());
+    auto abs = cnode->abstract()->cast<abstract::AbstractTuplePtr>();
+    MS_EXCEPTION_IF_NULL(abs);
+    auto abs_i = abs->elements()[output_idx];
+    MS_EXCEPTION_IF_NULL(abs_i);
+    getitem->set_abstract(abs_i);
     const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
     const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
     if (origin_type != device_type) {
diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc
index 4968ea8c6e6..7cb9b77eb0d 100644
--- a/mindspore/ccsrc/backend/optimizer/common/helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc
@@ -492,9 +492,11 @@ CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
   CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
   MS_EXCEPTION_IF_NULL(tuple_getitem);
   tuple_getitem->set_scope(node->scope());
-  auto origin_shape = AnfAlgo::GetOutputDetailShape(node, output_idx);
-  TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx);
-  AnfAlgo::SetOutputTypeAndDetailShape({origin_type}, {origin_shape}, tuple_getitem.get());
+  auto abs = node->abstract()->cast<abstract::AbstractTuplePtr>();
+  MS_EXCEPTION_IF_NULL(abs);
+  auto abs_i = abs->elements()[output_idx];
+  MS_EXCEPTION_IF_NULL(abs_i);
+  tuple_getitem->set_abstract(abs_i);
   return tuple_getitem;
 }
diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc
index fb1f0f00a51..0d1e3bd7494 100644
--- a/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc
@@ -165,10 +165,6 @@ bool AiCpuDynamicKernel::UpdateExtInfo() {
 }
 
 bool AiCpuDynamicKernel::UpdateOutputShapeFromExtInfo() {
-  if (input_num_ == 0) {
-    MS_LOG(WARNING) << "input num is 0";
-    return true;
-  }
   MS_LOG(INFO) << "UpdateOutputShapeFromExtInfo start";
   auto ret = rtMemcpy(ext_info_handler_->GetExtInfo(), ext_info_handler_->GetExtInfoLen(), ext_info_addr_dev_,
                       ext_info_size_, RT_MEMCPY_DEVICE_TO_HOST);
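Note on the backend hunks above: fusion is skipped outright for dynamic-shape graphs, and TupleGetItem nodes now inherit the element abstract of their producer instead of a rebuilt static type/shape pair, so min/max shape bounds survive these passes. Throughout the patch, an axis is dynamic when its extent is recorded as -1. A minimal plain-Python sketch of that convention (helper name is mine, not from the diff):

```python
# Minimal sketch of the dynamic-shape convention used throughout this patch:
# an axis is dynamic when its extent is recorded as -1.
def is_dynamic(shapes):
    """True if any shape in `shapes` contains a -1 placeholder axis."""
    return any(-1 in shape for shape in shapes)

assert is_dynamic([[32, -1], [32, 8, 8]])      # one dynamic axis
assert not is_dynamic([[32, 4], [32, 8, 8]])   # fully static
```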
diff --git a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc
index 8a94a7d0760..16f8f848d67 100644
--- a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc
+++ b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc
@@ -73,9 +73,6 @@ void DynamicKernel::RebuildDependTensor() {
 }
 
 void DynamicKernel::InferShape() {
-  if (!is_input_dynamic_shape_ && is_output_dynamic_shape_ && !have_depends()) {
-    return;
-  }
   auto cnode = cnode_ptr_.lock();
   MS_EXCEPTION_IF_NULL(cnode);
   MS_LOG(INFO) << "InferShape start, node:" << cnode->fullname_with_scope();
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 564691dddad..4d6838a0327 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -602,7 +602,7 @@ const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
 
 const std::set<std::string> kComputeDepend = {kUniqueOpName,       kComputeAccidentalHitsOpName, kSubAndFilterOpName,
                                               kPadAndShiftOpName,  kCTCGreedyDecoderOpName,      kDropoutGenMaskOpName,
-                                              kMaskedSelectOpName, kDynamicStitchOpName};
+                                              kMaskedSelectOpName, kDynamicStitchOpName,         kGetNextOpName};
 
 const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW,  kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
                                             kOpFormat_NDHWC,  kOpFormat_DHWCN,    kOpFormat_DHWNC};
diff --git a/mindspore/core/abstract/utils.cc b/mindspore/core/abstract/utils.cc
index 5cb9f381a83..d04cbda2864 100644
--- a/mindspore/core/abstract/utils.cc
+++ b/mindspore/core/abstract/utils.cc
@@ -361,13 +361,11 @@ AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type) {
   ShapeVector min_shape_vec;
   ShapeVector max_shape_vec;
 
-  if (shape->IsDynamic()) {
-    if (!shape->min_shape().empty()) {
-      min_shape_vec = shape->min_shape();
-    }
-    if (!shape->max_shape().empty()) {
-      max_shape_vec = shape->max_shape();
-    }
+  if (!shape->min_shape().empty()) {
+    min_shape_vec = shape->min_shape();
+  }
+  if (!shape->max_shape().empty()) {
+    max_shape_vec = shape->max_shape();
   }
 
   auto ret_shape = std::make_shared<Shape>(ret_vec, min_shape_vec, max_shape_vec);
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index fad1f174d34..9dfd38cc9ee 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -32,6 +32,7 @@ inline const std::unordered_map<std::string, ValuePtr> kSideEffectPropagate = {
   {mindspore::GRAPH_FLAG_SIDE_EFFECT_PROPAGATE, kValueOne},
 };
 
+constexpr auto kGetNext = "GetNext";
 constexpr auto kGather = "Gather";
 // Arithmetic
 constexpr auto kScalarAdd = "ScalarAdd";
@@ -92,6 +93,9 @@ constexpr auto kDropoutGrad = "DropoutGrad";
 constexpr auto kConv2DTranspose = "Conv2DTranspose";
 
 // Here list all primitives used in backend or some special primitives used by core.
+// GetNext
+inline const PrimitivePtr kPrimGetNext = std::make_shared<Primitive>(kGetNext);
+
 // Arithmetic
 inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>(kScalarAdd);
 inline const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>(kScalarSub);
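With `kPrimGetNext` registered in core_ops.h, passes can match GetNext by primitive rather than by name string, and the C++ infer (getnext.cc, next file) reads everything from primitive attributes. For orientation, a hedged sketch of how those attributes look from the Python side; the queue name and min/max bounds here are made-up illustrative values (in practice `_DataWrapper`, further down in this patch, attaches them):

```python
# Hedged sketch: the attribute layout consumed by the new C++ GetNext infer.
# "queue_0" and the bound values are illustrative, not from the diff.
from mindspore import dtype as mstype
from mindspore.ops import operations as P

types = [mstype.float32, mstype.int32]
shapes = [[32, -1], [32, -1, -1, 3]]          # -1 marks a dynamic axis
get_next = P.GetNext(types, shapes, len(types), "queue_0")
# With dynamic axes present, min_shapes/max_shapes bound each -1;
# GetnextInferShape uses them to build a Shape with min/max vectors.
get_next.add_prim_attr("min_shapes", [[32, 1], [32, 1, 1, 3]])
get_next.add_prim_attr("max_shapes", [[32, 9], [32, 9, 9, 3]])
```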
diff --git a/mindspore/core/ops/getnext.cc b/mindspore/core/ops/getnext.cc
new file mode 100644
index 00000000000..09364888bd3
--- /dev/null
+++ b/mindspore/core/ops/getnext.cc
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+#include "ops/getnext.h"
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "utils/tensor_construct_utils.h"
+#include "abstract/primitive_infer_map.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+void GetShapeVector(const ValuePtr &shape_attr, std::vector<std::vector<int64_t>> *shape_vec) {
+  if (shape_attr == nullptr) {
+    return;
+  }
+  std::vector<ValuePtr> shape = shape_attr->isa<ValueTuple>() ? shape_attr->cast<ValueTuplePtr>()->value()
+                                                              : shape_attr->cast<ValueListPtr>()->value();
+  for (ValuePtr shape_elements : shape) {
+    std::vector<ValuePtr> shape_elements_list = shape_elements->isa<ValueTuple>()
+                                                  ? shape_elements->cast<ValueTuplePtr>()->value()
+                                                  : shape_elements->cast<ValueListPtr>()->value();
+    std::vector<int64_t> shape_vec_item;
+    (void)std::transform(std::begin(shape_elements_list), std::end(shape_elements_list),
+                         std::back_inserter(shape_vec_item),
+                         [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
+    shape_vec->push_back(shape_vec_item);
+  }
+}
+
+bool IsDynamic(const std::vector<std::vector<int64_t>> &shape) {
+  for (auto shape_vec : shape) {
+    if (std::find(shape_vec.begin(), shape_vec.end(), -1) != shape_vec.end()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+abstract::AbstractBasePtr GetnextInferShape(const PrimitivePtr &primitive,
+                                            const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto types = GetValue<std::vector<TypePtr>>(primitive->GetAttr("types"));
+  ValuePtr shape_attr = primitive->GetAttr("shapes");
+  ValuePtr min_shape_attr = primitive->GetAttr("min_shapes");
+  ValuePtr max_shape_attr = primitive->GetAttr("max_shapes");
+
+  std::vector<std::vector<int64_t>> shape;
+  std::vector<std::vector<int64_t>> min_shape;
+  std::vector<std::vector<int64_t>> max_shape;
+
+  GetShapeVector(shape_attr, &shape);
+  GetShapeVector(min_shape_attr, &min_shape);
+  GetShapeVector(max_shape_attr, &max_shape);
+
+  bool is_dynamic = IsDynamic(shape);
+
+  AbstractBasePtrList output;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    auto ret_shape = !min_shape.empty() && !max_shape.empty() && is_dynamic
+                       ? std::make_shared<abstract::Shape>(shape[i], min_shape[i], max_shape[i])
+                       : std::make_shared<abstract::Shape>(shape[i]);
+    auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, types[i]);
+    auto tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
+    output.push_back(tensor);
+  }
+  return std::make_shared<abstract::AbstractTuple>(output);
+}
+}  // namespace
+
+AbstractBasePtr GetNextInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                             const std::vector<AbstractBasePtr> &input_args) {
+  return GetnextInferShape(primitive, input_args);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(GetNext, prim::kPrimGetNext, GetNextInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
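The infer above keys entirely off the attrs: when `min_shapes`/`max_shapes` are present and any output shape carries a -1, every output gets a bounded Shape. A hedged plain-Python analogue of that selection logic (no MindSpore needed; function name is mine), useful for sanity-checking attr layouts:

```python
# Hedged plain-Python analogue of GetnextInferShape's shape selection.
# Note the dynamic check is global: once any output has a -1 axis and
# bounds are supplied, all outputs get (shape, min, max) triples.
def infer_output_shapes(shapes, min_shapes=None, max_shapes=None):
    is_dynamic = any(-1 in s for s in shapes)        # mirrors IsDynamic
    out = []
    for i, shape in enumerate(shapes):
        if min_shapes and max_shapes and is_dynamic:
            out.append((shape, min_shapes[i], max_shapes[i]))
        else:
            out.append((shape, None, None))          # plain static Shape
    return out

print(infer_output_shapes([[32, -1], [32, 8]],
                          min_shapes=[[32, 1], [32, 8]],
                          max_shapes=[[32, 9], [32, 8]]))
# [([32, -1], [32, 1], [32, 9]), ([32, 8], [32, 8], [32, 8])]
```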
diff --git a/mindspore/core/ops/getnext.h b/mindspore/core/ops/getnext.h
new file mode 100644
index 00000000000..78acd30f76f
--- /dev/null
+++ b/mindspore/core/ops/getnext.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_GETNEXT_H_
+#define MINDSPORE_CORE_OPS_GETNEXT_H_
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "ops/primitive_c.h"
+#include "abstract/abstract_value.h"
+#include "utils/check_convert_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameGetNext = prim::kGetNext;
+class GetNext : public PrimitiveC {
+ public:
+  GetNext() : PrimitiveC(prim::kPrimGetNext->name()) {}
+  ~GetNext() = default;
+  MS_DECLARE_PARENT(GetNext, PrimitiveC);
+  void Init() {}
+};
+AbstractBasePtr GetNextInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                             const std::vector<AbstractBasePtr> &input_args);
+using PrimGetNextPtr = std::shared_ptr<GetNext>;
+}  // namespace ops
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_OPS_GETNEXT_H_
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 6be71d0e038..53ac3b027f5 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -3498,7 +3498,7 @@ class FastGeLU(PrimitiveWithInfer):
         return input_x
 
 
-class GetNext(PrimitiveWithInfer):
+class GetNext(Primitive):
     """
     Returns the next element in the dataset queue.
@@ -3545,12 +3545,6 @@ class GetNext(PrimitiveWithInfer):
         validator.check("types length", len(types), "shapes length", len(shapes), Rel.EQ, self.name)
         validator.check_value_type("output_num", output_num, [int], self.name)
 
-    def infer_shape(self):
-        return tuple(self.shapes)
-
-    def infer_dtype(self):
-        return tuple(self.types)
-
 
 class PReLU(PrimitiveWithInfer):
     r"""
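The Python GetNext drops `infer_shape`/`infer_dtype` because static tuples cannot express -1 axes with min/max bounds; inference now lives in the C++ `GetNextInfer`. A hedged, illustrative sketch of the two styles (class names are mine, not real MindSpore operators):

```python
# Illustrative contrast of the old and new primitive styles; these classes
# are examples of the pattern only, not registered MindSpore operators.
from mindspore.ops import Primitive, PrimitiveWithInfer, prim_attr_register

class StaticNextLike(PrimitiveWithInfer):     # old style: Python-side infer
    @prim_attr_register
    def __init__(self, types, shapes):
        pass  # prim_attr_register exposes args as self.types / self.shapes

    def infer_shape(self):
        return tuple(self.shapes)             # fixed shapes; no -1 bounds

    def infer_dtype(self):
        return tuple(self.types)

class DynamicNextLike(Primitive):             # new style: C++ infer reads attrs
    @prim_attr_register
    def __init__(self, types, shapes):
        pass  # attrs registered here; the C++ eval impl consumes them
```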
""" - def __init__(self, network, dataset_types, dataset_shapes, queue_name): + def __init__(self, network, dataset_types, dataset_shapes, queue_name, min_shapes=None, max_shapes=None): super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) # Also copy the flag in `network` construct flags = getattr(network.__class__.construct, "_mindspore_flags", {}) self.info = (dataset_types, dataset_shapes) self.add_flags(**flags) self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) + if min_shapes is not None and max_shapes is not None: + Validator.check_value_type("min_shapes", min_shapes, [list, tuple]) + Validator.check_value_type("max_shapes", max_shapes, [list, tuple]) + self.get_next.add_prim_attr("min_shapes", min_shapes) + self.get_next.add_prim_attr("max_shapes", max_shapes) self.network = network def construct(self): @@ -74,9 +79,24 @@ class _DataWrapper(nn.Cell): return self.network(*outputs) -def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name): +def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name, + min_shapes=None, max_shapes=None): if not isinstance(network, _DataWrapper): - network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) + network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name, min_shapes, max_shapes) + return network + +def has_dynamic_shape(dataset_shapes): + for shape in dataset_shapes: + if -1 in shape: + return True + return False + +def _generate_network_with_dataset(network, dataset_helper, queue_name): + dataset_types, dataset_shapes = dataset_helper.types_shapes() + (min_shapes, max_shapes) = (None, None) if not has_dynamic_shape(dataset_shapes) \ + else dataset_helper.dynamic_min_max_shapes() + network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, + queue_name, min_shapes, max_shapes) return network @@ -151,8 +171,7 @@ def connect_network_with_dataset(network, dataset_helper): not context.get_context("enable_ge") and \ context.get_context("device_target") in ("Ascend", "GPU"): dataset.__me_inited__ = True - dataset_types, dataset_shapes = dataset_helper.types_shapes() - network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name) + network = _generate_network_with_dataset(network, dataset_helper, queue_name) if hasattr(dataset_iter, "sink_size") and \ dataset_iter.sink_size == 1 and \ @@ -286,7 +305,7 @@ class _DatasetIter: self.release = dataset.__transfer_dataset__.release self.continue_send = dataset.__transfer_dataset__.continue_send self.get_data_info = dataset.__transfer_dataset__.get_data_info - self.dynamic_min_max_shapes = dataset.__transfer_dataset__.dynamic_min_max_shapes + self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) def __iter__(self): diff --git a/tests/st/dynamic_shape/test_getnext_dynamic_pipeline.py b/tests/st/dynamic_shape/test_getnext_dynamic_pipeline.py new file mode 100644 index 00000000000..9552b039382 --- /dev/null +++ b/tests/st/dynamic_shape/test_getnext_dynamic_pipeline.py @@ -0,0 +1,81 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
diff --git a/tests/st/dynamic_shape/test_getnext_dynamic_pipeline.py b/tests/st/dynamic_shape/test_getnext_dynamic_pipeline.py
new file mode 100644
index 00000000000..9552b039382
--- /dev/null
+++ b/tests/st/dynamic_shape/test_getnext_dynamic_pipeline.py
@@ -0,0 +1,81 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import numpy as np
+import pytest
+from mindspore import nn, context
+from mindspore import ops as P
+from mindspore.train import DatasetHelper, connect_network_with_dataset
+import mindspore.dataset as ds
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+def _exec_preprocess(network, is_train, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1, dataset_helper=None):
+    if dataset_sink_mode and not is_train:
+        dataset.__loop_size__ = 1
+
+    if dataset_helper is None:
+        dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
+
+    if dataset_sink_mode:
+        network = connect_network_with_dataset(network, dataset_helper)
+
+    network.set_train(is_train)
+
+    return dataset_helper, network
+
+
+def _eval_dataset_sink_process(network, valid_dataset):
+    dataset_helper, eval_network = _exec_preprocess(network, is_train=False, dataset=valid_dataset,
+                                                    dataset_sink_mode=True)
+    for inputs1, inputs2 in zip(dataset_helper, valid_dataset.create_dict_iterator()):
+        outputs = eval_network(*inputs1)
+        for elem1, (_, elem2) in zip(outputs, inputs2.items()):
+            assert elem1.shape == elem2.shape
+
+
+def dataset_generator():
+    for i in range(1, 10):
+        yield (
+            np.ones((32, i), dtype=np.float32), np.zeros((32, i, i, 3), dtype=np.int32),
+            np.ones((32,), dtype=np.float32),
+            np.ones((32, i, 8), dtype=np.float32), np.ones((32, 8, 8), dtype=np.float32))
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.relu = P.ReLU()
+
+    def construct(self, x1, x2, x3, x4, x5):
+        x1 = self.relu(x1)
+        x1 = self.relu(x1)
+
+        x2 = self.relu(x2)
+
+        x3 = self.relu(x3)
+        x3 = self.relu(x3)
+
+        x4 = self.relu(x4)
+
+        x5 = self.relu(x5)
+        return x1, x2, x3, x4, x5
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_getnext_dynamic_pipeline():
+    network = Net()
+    dataset = ds.GeneratorDataset(dataset_generator, ["data1", "data2", "data3", "data4", "data5"])
+    dataset.set_dynamic_columns(columns={"data1": [32, None], "data2": [32, None, None, 3],
+                                         "data3": [32], "data4": [32, None, 8], "data5": [32, 8, 8]})
+    _eval_dataset_sink_process(network, dataset)
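A closing note on the test's column declarations: only axes that actually vary across generator steps are marked None. The plain-numpy illustration below prints the per-step shapes, showing why data3 and data5 stay fully static:

```python
# Plain-numpy illustration of the generator's per-step output shapes: only
# the i-dependent axes of data1/data2/data4 change, so only they get None
# axes in set_dynamic_columns; data3/data5 keep fixed shapes.
import numpy as np

for i in (1, 2):
    step = (np.ones((32, i), np.float32), np.zeros((32, i, i, 3), np.int32),
            np.ones((32,), np.float32), np.ones((32, i, 8), np.float32),
            np.ones((32, 8, 8), np.float32))
    print([a.shape for a in step])
# i=1: [(32, 1), (32, 1, 1, 3), (32,), (32, 1, 8), (32, 8, 8)]
# i=2: [(32, 2), (32, 2, 2, 3), (32,), (32, 2, 8), (32, 8, 8)]
```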