diff --git a/mindspore/ccsrc/backend/common/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/common/session/anf_runtime_algorithm.cc index 22c30d39577..ca7e9296a4d 100644 --- a/mindspore/ccsrc/backend/common/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/common/session/anf_runtime_algorithm.cc @@ -551,7 +551,7 @@ std::vector AnfRuntimeAlgorithm::GetOutputDeviceShapeForTbeBuild(const } // if format is default_format or NC1KHKWHWC0,device shape = original shape - if (trans::IsNeedPadding(format, infer_shape.size())) { + if (trans::IsNeedPadding(format, infer_shape)) { infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx), node); } auto dtype = GetOutputDeviceDataType(node, output_idx); @@ -570,7 +570,7 @@ ShapeVector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, si } // if format is default_format or NC1KHKWHWC0,device shape = original shape - if (trans::IsNeedPadding(format, infer_shape.size())) { + if (trans::IsNeedPadding(format, infer_shape)) { infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx), node); } auto dtype = GetOutputDeviceDataType(node, output_idx); @@ -591,7 +591,7 @@ std::vector AnfRuntimeAlgorithm::GetInputDeviceShapeForTbeBuild(const A } // if format is default_format or NC1KHKWHWC0,device shape = original shape - if (trans::IsNeedPadding(format, infer_shape.size())) { + if (trans::IsNeedPadding(format, infer_shape)) { infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx), node); } auto dtype = GetInputDeviceDataType(node, input_idx); @@ -605,7 +605,7 @@ std::vector AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr & return infer_shape; } // if format is default_format or NC1KHKWHWC0,device shape = original shape - if (trans::IsNeedPadding(format, infer_shape.size())) { + if (trans::IsNeedPadding(format, infer_shape)) { infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx), node); } auto dtype = GetInputDeviceDataType(node, input_idx); diff --git a/mindspore/ccsrc/include/common/utils/anfalgo.h b/mindspore/ccsrc/include/common/utils/anfalgo.h index f3d057d8bb5..87808481675 100644 --- a/mindspore/ccsrc/include/common/utils/anfalgo.h +++ b/mindspore/ccsrc/include/common/utils/anfalgo.h @@ -177,6 +177,11 @@ class COMMON_EXPORT AnfAlgo { static bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr); static bool IsNodeOutputDynamicShape(const AnfNodePtr &node); static bool IsDynamicShape(const AnfNodePtr &node); + static bool IsDynamicRankNode(const AnfNodePtr &node); + static bool IsNodeInputDynamicRank(const CNodePtr &anf_node_ptr); + static bool IsNodeOutputDynamicRank(const AnfNodePtr &node); + static bool IsInputAnchorDynamicRank(const AnfNodePtr &node, size_t idx); + static bool IsOutputAnchorDynamicRank(const AnfNodePtr &node, size_t idx); static bool HasDynamicShapeFlag(const PrimitivePtr &prim); static bool IsCondControlKernel(const CNodePtr &node); static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr); diff --git a/mindspore/ccsrc/include/common/utils/utils.h b/mindspore/ccsrc/include/common/utils/utils.h index ec4a7617c5f..d6c51b86885 100644 --- a/mindspore/ccsrc/include/common/utils/utils.h +++ b/mindspore/ccsrc/include/common/utils/utils.h @@ -1031,6 +1031,9 @@ constexpr auto kAttrKsizes = "ksizes"; constexpr auto kAttrIsKernelDynamicImpl = "is_kernel_dynamic_impl"; constexpr auto kAttrIsKernelDynamicShape = 
"is_kernel_dynamic_shape"; constexpr auto kAttrIsDynamicShape = "is_dynamic_shape"; +constexpr auto kAttrIsDynamicRank = "is_dynamic_rank"; +constexpr auto kAttrInputIsDynamicRank = "input_is_dynamic_rank"; +constexpr auto kAttrOutputIsDynamicRank = "output_is_dynamic_rank"; constexpr auto kAttrInputIsDynamicShape = "input_is_dynamic_shape"; constexpr auto kAttrOutputIsDynamicShape = "output_is_dynamic_shape"; constexpr auto kAttrPynativeNextOpName = "next_op"; @@ -1323,6 +1326,7 @@ COMMON_EXPORT bool IsOneOfComputeDepend(const std::string &name); COMMON_EXPORT bool IsOneOfHWSpecialFormat(const std::string &format); COMMON_EXPORT bool IsOneOfFormat(const std::string &format); COMMON_EXPORT bool IsOneOfServerFormatC04(const std::string &format); +COMMON_EXPORT bool IsOneOfDynRankNeedPadShape(const std::string &format); // The map between kernel's output and input ref relationship. // Key is the output index while the value is input index which will be used as the reference of output. diff --git a/mindspore/ccsrc/kernel/kernel_build_info.cc b/mindspore/ccsrc/kernel/kernel_build_info.cc index 0e28b3fc6d2..64aa943781d 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.cc +++ b/mindspore/ccsrc/kernel/kernel_build_info.cc @@ -178,11 +178,20 @@ void KernelBuildInfo::SetInputsDeviceType(const std::vector &inputs_devi void KernelBuildInfo::SetOutputFormat(const std::string &format, size_t index) { if (index >= outputs_format_.size()) { - MS_LOG(EXCEPTION) << "The index [" << index << "] is exceed the number of output"; + MS_LOG(EXCEPTION) << "The index [" << index + << "] is exceed the length of output formats list, total size:" << outputs_format_.size(); } outputs_format_[index] = format; } +void KernelBuildInfo::SetInputFormat(const std::string &format, size_t index) { + if (index >= inputs_format_.size()) { + MS_LOG(EXCEPTION) << "The index [" << index + << "] is exceed the length of input formats list, total size:" << inputs_format_.size(); + } + inputs_format_[index] = format; +} + void KernelBuildInfo::SetOutputsFormat(const std::vector &outputs_format) { outputs_format_ = outputs_format; } diff --git a/mindspore/ccsrc/kernel/kernel_build_info.h b/mindspore/ccsrc/kernel/kernel_build_info.h index 392e66ebd04..7f1ccd737ec 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.h +++ b/mindspore/ccsrc/kernel/kernel_build_info.h @@ -118,6 +118,8 @@ class BACKEND_EXPORT KernelBuildInfo { void SetOutputFormat(const std::string &format, size_t index); + void SetInputFormat(const std::string &format, size_t index); + void SetOutputDeviceType(const TypeId &output_device_type, size_t index); void SetInputsFormat(const std::vector &inputs_format); diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_select_ascend.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_select_ascend.cc index 275441f9e47..4b98f2b9941 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_select_ascend.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_select_ascend.cc @@ -183,6 +183,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { bool is_init = false; bool need_change_nd = false; bool is_5d_input = false; + bool is_dyn_rank = common::AnfAlgo::IsDynamicRankNode(cnode); size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode); for (size_t index = 0; index < input_num; ++index) { auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); @@ -194,7 +195,11 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { if (priority_matched_format != 
pre_output_format && pre_output_format != kOpFormat_DEFAULT) { priority_matched_format = kOpFormat_DEFAULT; } - auto input_shape_size = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size(); + const auto &prev_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); + if (IsDynamicRank(prev_shape)) { + is_dyn_rank = true; + } + auto input_shape_size = prev_shape.size(); if (input_shape_size == k5dSize) { is_5d_input = true; } @@ -206,6 +211,9 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { if (is_5d_input && priority_matched_format != kOpFormat_FRAC_NZ) { priority_matched_format = kOpFormat_NDC1HWC0; } + if (is_dyn_rank) { + priority_matched_format = kOpFormat_ND; + } common::AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode); return priority_matched_format; } diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.cc index 4d9dbbc4e7c..76e7738279b 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include "utils/ms_context.h" #include "backend/common/optimizer/helper.h" #include "include/backend/anf_runtime_algorithm.h" @@ -93,8 +94,36 @@ inline void GetRangeByShape(const AnfNodePtr &anf_node, const ShapeVector &shape } } +ShapeVector TbeDynamicShapeUtil::UpdateShape(const AnfNodePtr &node, const std::string &format, + const ShapeVector &shape, size_t index, bool is_input, + bool *is_change_nd) { + MS_EXCEPTION_IF_NULL(node); + const std::set op_names = {kTransDataOpName}; + if (!node->isa() || op_names.find(common::AnfAlgo::GetCNodeName(node)) == op_names.end()) { + return shape; + } + std::string sp_format = format; + auto kernel_info = dynamic_cast(node->kernel_info()); + if (kernel_info->select_kernel_build_info() != nullptr) { + auto in_format = AnfAlgo::GetInputFormat(node, 0); + auto out_format = AnfAlgo::GetOutputFormat(node, 0); + sp_format = IsOneOfHWSpecialFormat(in_format) ? in_format : out_format; + } + + const auto &pad_idx = + is_input ? 
AnfAlgo::GetInputReshapeType(node, index) : AnfAlgo::GetOutputReshapeType(node, index); + if (format == kOpFormat_NCHW && shape.size() < kDim4 && IsOneOfDynRankNeedPadShape(sp_format)) { + if (is_change_nd) { + *is_change_nd = true; + } + return trans::PaddingShape(shape, sp_format, pad_idx); + } + return shape; +} + RangePair TbeDynamicShapeUtil::GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index, - const std::string &def_format, const TypeId &type) { + const std::string &def_format, const std::string &ori_format, + const TypeId &type) { MS_EXCEPTION_IF_NULL(anf_node); auto kernel_info = dynamic_cast(anf_node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); @@ -110,13 +139,17 @@ RangePair TbeDynamicShapeUtil::GetInputDynamicRange(const AnfNodePtr &anf_node, auto prev_node = common::AnfAlgo::GetPrevNodeOutput(anf_node, index); MS_EXCEPTION_IF_NULL(prev_node.first); auto shape = common::AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second); + if (anf_node->isa()) { + shape = UpdateShape(anf_node, ori_format, shape, index, true); + } GetRangeByShape(anf_node, shape, &ret); return shapeRangeTransfer.GetRealRange(ret, format, data_type, reshape_type); } RangePair TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index, - const std::string &def_format, const TypeId &type) { + const std::string &def_format, const std::string &ori_format, + const TypeId &type) { MS_EXCEPTION_IF_NULL(anf_node); auto kernel_info = dynamic_cast(anf_node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); @@ -130,6 +163,9 @@ RangePair TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node, RangePair ret; auto shape = common::AnfAlgo::GetOutputInferShape(anf_node, index); + if (anf_node->isa()) { + shape = UpdateShape(anf_node, ori_format, shape, index, false); + } GetRangeByShape(anf_node, shape, &ret); return shapeRangeTransfer.GetRealRange(ret, format, data_type, reshape_type); diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h index 51151aa7944..4e377dd4a01 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h @@ -32,15 +32,17 @@ class TbeDynamicShapeUtil { public: TbeDynamicShapeUtil() = default; ~TbeDynamicShapeUtil() = default; + static ShapeVector UpdateShape(const AnfNodePtr &node, const std::string &format, const ShapeVector &shape, + size_t index, bool is_input, bool *is_change_nd = nullptr); static bool GetDynamicShapeAttr(const CNodePtr &cnode); static bool GetDynamicShapeAttr(const AnfNodePtr &anf_node); static std::shared_ptr FindOp(const std::string &op_name, const AnfNodePtr &anf_node); static std::shared_ptr FindOp(const std::string &op_name, const CNodePtr &cnode); static std::shared_ptr FindOp(const CNodePtr &cnode); static RangePair GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index, const std::string &def_format, - const TypeId &type); + const std::string &ori_format, const TypeId &type); static RangePair GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index, const std::string &def_format, - const TypeId &type); + const std::string &ori_format, const TypeId &type); }; } // namespace tbe } // namespace kernel diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/single_tbe_json_creator.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/single_tbe_json_creator.cc index 
c7da1242a34..a1632c688ed 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/single_tbe_json_creator.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/single_tbe_json_creator.cc @@ -221,10 +221,12 @@ void SingleTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t r infer_shape = trans::PaddingShape(infer_shape, format, AnfAlgo::GetInputReshapeType(anf_node, real_input_index)); (*input_desc)[kJCValue] = infer_shape[1]; } + shape = tbe::TbeDynamicShapeUtil::UpdateShape(anf_node, (*input_desc)[kJOriFormat], shape, real_input_index, true); (*input_desc)[kJShape] = shape; (*input_desc)[kJFormat] = format; (*input_desc)[kJValid] = true; - (*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format, d_type); + (*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format, + (*input_desc)[kJOriFormat], d_type); GenInputConstValue(anf_node, real_input_index, input_desc); } @@ -236,8 +238,8 @@ void SingleTbeJsonCreator::GenOutputDescJson(const AnfNodePtr &anf_node, size_t auto type_str = GetJsonValue(*output_desc, kJDtype); auto d_type = tbe::DtypeToTypeId(type_str); (*output_desc)[kJValid] = true; - (*output_desc)[kJRange] = - tbe::TbeDynamicShapeUtil::GetOutputDynamicRange(anf_node, node_out_idx, (*output_desc)[kJFormat], d_type); + (*output_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetOutputDynamicRange( + anf_node, node_out_idx, (*output_desc)[kJFormat], (*output_desc)[kJOriFormat], d_type); } bool SingleTbeJsonCreator::AssignInputsJson(const AnfNodePtr &anf_node, const std::vector &inputs_desc, @@ -398,6 +400,7 @@ void SelectTbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t node_o (*output_desc)[kJFormat] = format; (*output_desc)[kJOriFormat] = def_format; (*output_desc)[kJOriShape] = ori_shape; + shape = tbe::TbeDynamicShapeUtil::UpdateShape(anf_node, (*output_desc)[kJOriFormat], shape, node_out_idx, false); (*output_desc)[kJShape] = shape; (*output_desc)[kJOutputIndex] = desc_output_idx; } @@ -422,13 +425,15 @@ void SelectTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t r if (common::AnfAlgo::GetCNodeName(anf_node) == kMaxPool3DGradGradDOpName) { (*input_desc)[kJOriFormat] = kOpFormat_NDHWC; } + shape = tbe::TbeDynamicShapeUtil::UpdateShape(anf_node, (*input_desc)[kJOriFormat], shape, real_input_index, true); (*input_desc)[kJShape] = shape; (*input_desc)[kJFormat] = format; (*input_desc)[kJValid] = true; auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, real_input_index); MS_EXCEPTION_IF_NULL(input_node_with_index.first); if (!input_node_with_index.first->isa()) { - (*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format, d_type); + (*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format, + (*input_desc)[kJOriFormat], d_type); } } bool SelectTbeJsonCreator::AttrsJsonPostProcessing(const AnfNodePtr &anf_node, const OpInfoPtr &op_info_ptr, diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/tbe_json_creator.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/tbe_json_creator.cc index 9140844870a..016eed9c275 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/tbe_json_creator.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/tbe_json_creator.cc @@ -487,6 +487,7 @@ void TbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t 
node_out_idx (*output_desc)[kJFormat] = format; (*output_desc)[kJOriFormat] = def_format; (*output_desc)[kJOriShape] = ori_shape; + shape = tbe::TbeDynamicShapeUtil::UpdateShape(anf_node, (*output_desc)[kJOriFormat], shape, node_out_idx, false); (*output_desc)[kJShape] = shape; (*output_desc)[kJName] = output_desc_name; // !! Note: output_index, only node's output use it diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.cc index e317538ac6f..17fc6c9ae74 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.cc @@ -83,7 +83,7 @@ std::vector HostCheck::GetFinalInferShape(const AnfNodePtr &node, size_ return infer_shape; } - if (trans::IsNeedPadding(format, infer_shape.size())) { + if (trans::IsNeedPadding(format, infer_shape)) { auto reshape_type = is_output ? AnfAlgo::GetOutputReshapeType(node, index) : AnfAlgo::GetInputReshapeType(node, index); infer_shape = trans::PaddingShape(infer_shape, format, reshape_type, node); diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_helper.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_helper.cc index 94c9cd63c04..60bc96ca606 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_helper.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_helper.cc @@ -36,6 +36,23 @@ namespace mindspore { namespace opt { using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; namespace { +struct CreateNodeArgs { + FuncGraphPtr func_graph{nullptr}; + AnfNodePtr node{nullptr}; + AnfNodePtr input_node{nullptr}; + AnfNodePtr orig_node{nullptr}; + KernelSelectPtr kernel_select{nullptr}; + std::string trans_opname; + std::string input_format; + std::string dst_format; + std::string spec_format; + std::string reshape_type; + TypeId type_id; + ShapeVector out_shape; + bool is_dynamic_shape; + bool need_padding; +}; + std::string GetTransOpName(const std::string &spec_format) { std::string trans_opname = (spec_format == kOpFormat_FRACTAL_ZN_RNN || spec_format == kOpFormat_ND_RNN_BIAS) ? 
prim::kPrimTransDataRNN->name() @@ -83,6 +100,61 @@ CNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inp return reshape; } +AnfNodePtr CreateTransDataWithOutReshape(const CreateNodeArgs &args) { + // don't need padding insert TransData only + auto trans_data = NewTransOpNode(args.func_graph, args.input_node, args.orig_node, args.kernel_select, + args.need_padding, args.trans_opname); + RefreshKernelBuildInfo(args.kernel_select, args.input_format, args.dst_format, trans_data, args.reshape_type, + args.type_id); + return trans_data; +} + +AnfNodePtr CreateTransDataWithReshape(const CreateNodeArgs &args) { + AnfNodePtr trans_node = nullptr; + CNodePtr trans_data = nullptr; + if (!args.need_padding) { + // don't need padding insert transdata only + trans_data = NewTransOpNode(args.func_graph, args.input_node, args.orig_node, args.kernel_select, args.need_padding, + args.trans_opname); + trans_node = trans_data; + RefreshKernelBuildInfo(args.kernel_select, args.input_format, args.dst_format, trans_data, args.reshape_type, + args.type_id); + } else if (args.spec_format == args.dst_format) { + // if need padding & default to special format + // ori_shape -> reshape[padding shape] -> transdata[device shape] + auto padding_shape = trans::PaddingShape(args.out_shape, args.dst_format, args.reshape_type, args.node); + std::vector padding_axis; + if (std::count(padding_shape.begin(), padding_shape.end(), -1) > 1) { + padding_axis = trans::StringToAxisVector(args.out_shape, args.dst_format, args.reshape_type, args.node); + } + abstract::ShapePtr pad_shape_ptr = std::make_shared(padding_shape); + auto reshape_node = CreateReshapeNode(args.func_graph, args.input_node, args.orig_node, args.kernel_select, + pad_shape_ptr, args.is_dynamic_shape, padding_axis); + trans_data = NewTransOpNode(args.func_graph, reshape_node, args.orig_node, args.kernel_select, args.need_padding, + args.trans_opname); + trans_node = trans_data; + trans_data->set_abstract(args.input_node->abstract()); + RefreshKernelBuildInfo(args.kernel_select, args.input_format, args.dst_format, trans_data, args.reshape_type, + args.type_id); + } else { + // if need padding & special to default format + // device shape -> transdata[padding shape] -> reshape[ori_shape] + trans_data = NewTransOpNode(args.func_graph, args.input_node, args.orig_node, args.kernel_select, args.need_padding, + args.trans_opname); + RefreshKernelBuildInfo(args.kernel_select, args.input_format, args.dst_format, trans_data, args.reshape_type, + args.type_id); + abstract::ShapePtr pad_shape_ptr = std::make_shared(args.out_shape); + std::vector padding_axis; + if (std::count(args.out_shape.begin(), args.out_shape.end(), -1) > 1) { + padding_axis = trans::StringToAxisVector(args.out_shape, args.dst_format, args.reshape_type, args.node); + } + auto reshape_node = CreateReshapeNode(args.func_graph, trans_data, args.orig_node, args.kernel_select, + pad_shape_ptr, args.is_dynamic_shape, padding_axis); + trans_node = reshape_node; + } + return trans_node; +} + void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(trans_node); MS_EXCEPTION_IF_NULL(node); @@ -136,6 +208,9 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & MS_EXCEPTION_IF_NULL(node_with_index.first); auto real_input = node_with_index.first; if (real_input->isa() || real_input->isa()) { + MS_LOG(DEBUG) + << "ValueNode or Parameter has no inputs, try to insert for ValueNode or Parameter at out anchor, node: " + << 
node->fullname_with_scope(); input_node = InsertTransOpForOutput(func_graph, input_node, input_node, kernel_select); MS_EXCEPTION_IF_NULL(input_node); common::AnfAlgo::SetNodeInput(node, input_node, index); @@ -143,8 +218,8 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & ShapeVector origin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, index); std::string dest_format = AnfAlgo::GetInputFormat(node, index); if (NeedInsertTransData(origin_shape, dest_format)) { - MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index) - << " To DefaultFormat , index: " << index; + MS_LOG(DEBUG) << "Need insert TransData change format from [" << dest_format + << "] to [DefaultFormat], input index:" << index << ", node: " << node->fullname_with_scope(); auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true); if (real_input->isa()) { SetGroupAttr(real_input->cast(), input_node, transdata, dest_format); @@ -165,7 +240,8 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An << node->DebugString() << trace::DumpSourceLines(node); } if (NeedInsertTransData(origin_shape, output_format)) { - MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default , index :0"; + MS_LOG(DEBUG) << "Inserted TransData change format from [" << output_format + << "] to [DefaultFormat], single output index :0"; return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false); } return node; @@ -243,61 +319,32 @@ AnfNodePtr AddTransOpNodeToGraphWithFormat(const FuncGraphPtr &func_graph, const auto out_shape_ptr = input_node_out_shape->cast(); MS_EXCEPTION_IF_NULL(out_shape_ptr); ShapeVector out_shape = out_shape_ptr->shape(); - + auto is_dyn_rank = out_shape_ptr->IsDimUnknown(); auto is_dynamic_shape = out_shape_ptr->IsDynamic(); - bool need_padding = trans::IsNeedPadding(spec_format, out_shape.size()); + bool need_padding = trans::IsNeedPadding(spec_format, out_shape); std::string trans_opname = GetTransOpName(spec_format); bool is_insert_output = node == input_node; auto orig_node = GetOriginNode(is_insert_output, node); - - AnfNodePtr trans_node = nullptr; - CNodePtr trans_data = nullptr; - if (!need_padding) { - // don't need padding insert transdata only - trans_data = NewTransOpNode(func_graph, input_node, orig_node, kernel_select, need_padding, trans_opname); - trans_node = trans_data; - RefreshKernelBuildInfo(kernel_select, input_format, dst_format, trans_data, reshape_type, type_id); - } else if (spec_format == dst_format) { - // if need padding & default to special format - // ori_shape -> reshape[padding shape] -> transdata[device shape] - - auto padding_shape = trans::PaddingShape(out_shape, dst_format, reshape_type, node); - std::vector padding_axis; - if (std::count(padding_shape.begin(), padding_shape.end(), -1) > 1) { - padding_axis = trans::StringToAxisVector(out_shape, dst_format, reshape_type, node); - } - abstract::ShapePtr pad_shape_ptr = std::make_shared(padding_shape); - auto reshape_node = CreateReshapeNode(func_graph, input_node, orig_node, kernel_select, pad_shape_ptr, - is_dynamic_shape, padding_axis); - trans_data = NewTransOpNode(func_graph, reshape_node, orig_node, kernel_select, need_padding, trans_opname); - trans_node = trans_data; - trans_data->set_abstract(input_node->abstract()); - RefreshKernelBuildInfo(kernel_select, input_format, dst_format, trans_data, reshape_type, type_id); + AnfNodePtr trans_data = nullptr; + CreateNodeArgs 
args = {func_graph, node, input_node, orig_node, kernel_select, + trans_opname, input_format, dst_format, spec_format, reshape_type, + type_id, out_shape, is_dynamic_shape, need_padding}; + if (is_dyn_rank) { + trans_data = CreateTransDataWithOutReshape(args); } else { - // if need padding & special to default format - // device shape -> transdata[padding shape] -> reshape[ori_shape] - trans_data = NewTransOpNode(func_graph, input_node, orig_node, kernel_select, need_padding, trans_opname); - RefreshKernelBuildInfo(kernel_select, input_format, dst_format, trans_data, reshape_type, type_id); - abstract::ShapePtr pad_shape_ptr = std::make_shared(out_shape); - std::vector padding_axis; - if (std::count(out_shape.begin(), out_shape.end(), -1) > 1) { - padding_axis = trans::StringToAxisVector(out_shape, dst_format, reshape_type, node); - } - auto reshape_node = CreateReshapeNode(func_graph, trans_data, orig_node, kernel_select, pad_shape_ptr, - is_dynamic_shape, padding_axis); - - trans_node = reshape_node; + trans_data = CreateTransDataWithReshape(args); } + if (spec_format == kOpFormat_FRAC_Z && groups != 1 && !common::AnfAlgo::HasNodeAttr(kAttrFracZGroup, trans_data->cast())) { common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), trans_data); common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), trans_data); } if (is_insert_output) { - ReFreshInferShape(trans_node, node); + ReFreshInferShape(trans_data, node); } - return trans_node; + return trans_data; } void RefreshKernelBuildInfo(const KernelSelectPtr &kernel_select, const std::string &input_format, @@ -368,15 +415,17 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, auto out_shape_base = AnfAlgo::GetOutputDetailShape(input, 0); MS_EXCEPTION_IF_NULL(out_shape_base); ShapeVector out_shape; + bool is_dyn_rank = false; bool is_dynamic_shape = false; if (out_shape_base->isa()) { auto out_shape_ptr = out_shape_base->cast(); MS_EXCEPTION_IF_NULL(out_shape_ptr); out_shape = out_shape_ptr->shape(); is_dynamic_shape = out_shape_ptr->IsDynamic(); + is_dyn_rank = out_shape_ptr->IsDimUnknown(); } - if (need_padding) { + if (need_padding && !is_dyn_rank) { // if need padding we should set the transdata node's shape to the padding shape auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); @@ -495,6 +544,7 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt MS_EXCEPTION_IF_NULL(cnode); std::vector new_inputs = {common::AnfAlgo::GetCNodePrimitiveNode(cnode)}; size_t in_num = common::AnfAlgo::GetInputNum(cnode); // include monads. + MS_LOG(DEBUG) << "Try to insert TransData at input anchor for node: " << cnode->fullname_with_scope(); for (size_t input_index = 0; input_index < in_num; ++input_index) { // Monad inputs keep unchanged from GetTransInputNodePtr(). 
AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select); diff --git a/mindspore/ccsrc/runtime/device/ms_device_shape_transfer.cc b/mindspore/ccsrc/runtime/device/ms_device_shape_transfer.cc index bf76197ce64..9bbf542bb31 100644 --- a/mindspore/ccsrc/runtime/device/ms_device_shape_transfer.cc +++ b/mindspore/ccsrc/runtime/device/ms_device_shape_transfer.cc @@ -247,13 +247,16 @@ void StringToAxisVector5D(const std::string &reshape_type_str, std::vector PaddingShape(const std::vector &shape, const std::string &form MS_LOG(DEBUG) << "Start padding shape for node: [" << node->fullname_with_scope() << "], format: " << format << ", detail info: " << node->DebugString(); } - std::vector host_shape; + if (IsOneOf3DFormat(format)) { if (shape.size() >= kDim5) { return shape; } - host_shape = PaddingShapeTo5d(shape, pad_index); - } else { - host_shape = PaddingShapeTo4d(shape, pad_index); + if (shape.size() == 1 && shape[0] == abstract::Shape::kShapeRankAny) { + return {-1, -1, -1, -1, -1}; + } + return PaddingShapeTo5d(shape, pad_index); } - return host_shape; + + if (shape.size() == 1 && shape[0] == abstract::Shape::kShapeRankAny) { + return {-1, -1, -1, -1}; + } + return PaddingShapeTo4d(shape, pad_index); } /** diff --git a/mindspore/ccsrc/utils/anfalgo.cc b/mindspore/ccsrc/utils/anfalgo.cc index 75003e44a03..a1e34ac10b2 100644 --- a/mindspore/ccsrc/utils/anfalgo.cc +++ b/mindspore/ccsrc/utils/anfalgo.cc @@ -1382,6 +1382,93 @@ bool AnfAlgo::HasDynamicShapeFlag(const PrimitivePtr &prim) { return get_bool_attr(prim, kAttrInputIsDynamicShape) || get_bool_attr(prim, kAttrOutputIsDynamicShape); } +bool IsNodeDynamicRank(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(DEBUG) << "Node is not a cnode"; + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto in_dyn_rank = AnfAlgo::IsNodeInputDynamicRank(cnode); + auto out_dyn_rank = AnfAlgo::IsNodeOutputDynamicRank(cnode); + if (in_dyn_rank && !AnfAlgo::HasNodeAttr(kAttrInputIsDynamicRank, cnode)) { + AnfAlgo::SetNodeAttrSafely(kAttrInputIsDynamicRank, MakeValue(true), cnode); + MS_LOG(DEBUG) << "Set input dynamic rank attr for node:" << cnode->fullname_with_scope(); + } + if (out_dyn_rank && !AnfAlgo::HasNodeAttr(kAttrOutputIsDynamicRank, cnode)) { + AnfAlgo::SetNodeAttrSafely(kAttrOutputIsDynamicRank, MakeValue(true), cnode); + MS_LOG(DEBUG) << "Set output dynamic rank attr for node:" << cnode->fullname_with_scope(); + } + return in_dyn_rank || out_dyn_rank; +} + +bool AnfAlgo::IsDynamicRankNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + return IsOutputAnchorDynamicRank(node, 0); + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if ((!HasNodeAttr(kAttrInputIsDynamicRank, cnode)) && (!HasNodeAttr(kAttrOutputIsDynamicRank, cnode))) { + auto ret = IsNodeDynamicRank(node); + MS_LOG(DEBUG) << "The Node:" << node->fullname_with_scope() << " is dynamic rank: [" << ret << "]"; + return ret; + } + return GetBooleanAttr(node, kAttrInputIsDynamicRank) || GetBooleanAttr(node, kAttrOutputIsDynamicRank) || + GetBooleanAttr(node, kAttrIsDynamicRank); +} + +bool AnfAlgo::IsInputAnchorDynamicRank(const AnfNodePtr &node, size_t idx) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(EXCEPTION) << "Only cnode has inputs, node: " << node->fullname_with_scope(); + } + const auto &in_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, idx); + if (mindspore::IsDynamicRank(in_shape)) { + return true; + 
} + return false; +} + +bool AnfAlgo::IsOutputAnchorDynamicRank(const AnfNodePtr &node, size_t idx) { + MS_EXCEPTION_IF_NULL(node); + const auto &out_shape = common::AnfAlgo::GetOutputInferShape(node, idx); + if (mindspore::IsDynamicRank(out_shape)) { + return true; + } + return false; +} + +bool AnfAlgo::IsNodeInputDynamicRank(const CNodePtr &anf_node_ptr) { + MS_EXCEPTION_IF_NULL(anf_node_ptr); + const auto &inputs = anf_node_ptr->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + const auto &input = inputs[i]; + MS_EXCEPTION_IF_NULL(input); + if (IsNodeOutputDynamicRank(input)) { + return true; + } + } + return false; +} + +bool AnfAlgo::IsNodeOutputDynamicRank(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto base_shape = node->Shape(); + if (base_shape == nullptr) { + MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope(); + return false; + } + if (base_shape->isa()) { + auto b_ptr = base_shape->cast(); + if (b_ptr->IsDimUnknown()) { + return true; + } + } + return base_shape->IsDimUnknown(); +} + bool AnfAlgo::IsDynamicShape(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { diff --git a/mindspore/ccsrc/utils/utils.cc b/mindspore/ccsrc/utils/utils.cc index 20cbd01b687..d31dac4e636 100644 --- a/mindspore/ccsrc/utils/utils.cc +++ b/mindspore/ccsrc/utils/utils.cc @@ -211,4 +211,11 @@ bool IsOneOfServerFormatC04(const std::string &format) { static const std::set kServerFormatC04List = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; return kServerFormatC04List.find(format) != kServerFormatC04List.end(); } + +bool IsOneOfDynRankNeedPadShape(const std::string &format) { + const std::set kOpFormats = {kOpFormat_NC1HWC0, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z, + kOpFormat_NDC1HWC0, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, + kOpFormat_FRACTAL_Z_3D, kOpFormat_FRACTAL_Z_C04, kOpFormat_NCDHW}; + return kOpFormats.find(format) != kOpFormats.end(); +} } // namespace mindspore diff --git a/mindspore/core/abstract/dshape.h b/mindspore/core/abstract/dshape.h index 96e72429b08..d86f22205fc 100644 --- a/mindspore/core/abstract/dshape.h +++ b/mindspore/core/abstract/dshape.h @@ -244,6 +244,7 @@ class MS_CORE_API DynamicSequenceShape : public BaseShape { // element's shape BaseShapePtr element_shape_{nullptr}; }; +using DynamicSequenceShapePtr = std::shared_ptr; GVAR_DEF(std::shared_ptr, kDynamicSequenceShape, std::make_shared()); /// \brief SequequeShape defines base class of multiple-shape classes. diff --git a/tests/st/ops/ascend/test_dynamic_rank/test_batch_norm_dynamic_rank.py b/tests/st/ops/ascend/test_dynamic_rank/test_batch_norm_dynamic_rank.py new file mode 100644 index 00000000000..22c5c9f9d95 --- /dev/null +++ b/tests/st/ops/ascend/test_dynamic_rank/test_batch_norm_dynamic_rank.py @@ -0,0 +1,111 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +import pytest +import mindspore.ops.operations as ops +from mindspore.nn import Cell +from mindspore import Tensor +from mindspore import context + +np.random.seed(3) + + +class MSDynRankNet(Cell): + def __init__(self, is_training=True): + super(MSDynRankNet, self).__init__() + self.is_training = is_training + self.batch_norm = ops.BatchNorm(is_training=self.is_training) + self.reduce_mean = ops.ReduceMean(keep_dims=False) + self.relu = ops.ReLU() + + def construct(self, input_x, scale, offset, mean, variance, indices): + unique_indices = self.relu(indices) + reduced_in = self.reduce_mean(input_x, unique_indices) + reduced_scale = self.reduce_mean(scale, unique_indices) + reduced_offset = self.reduce_mean(offset, unique_indices) + reduced_mean = self.reduce_mean(mean, unique_indices) + reduced_variance = self.reduce_mean(variance, unique_indices) + out, _, _, _, _ = self.batch_norm(reduced_in, reduced_scale, reduced_offset, reduced_mean, reduced_variance) + return out + + +class NetFactory: + def __init__(self, x, scale, offset, mean, variance, indices, dtype=np.float32, is_training=False): + super(NetFactory, self).__init__() + self.x = x + self.scale = scale + self.offset = offset + self.mean = mean + self.variance = variance + self.indices = indices + self.dtype = dtype + self.is_training = is_training + self.nh2nc = [0, 3, 1, 2] + self.nc2nh = [0, 2, 3, 1] + + def mindspore_case(self): + ms_x = Tensor(self.x) + ms_indices = Tensor(self.indices) + ms_scale = Tensor(self.scale) + ms_offset = Tensor(self.offset) + ms_mean = Tensor(self.mean) + ms_variance = Tensor(self.variance) + + ms_dyn_x = Tensor(shape=[None for _ in ms_x.shape], dtype=ms_x.dtype) + ms_dyn_scale = Tensor(shape=[None for _ in ms_scale.shape], dtype=ms_scale.dtype) + ms_dyn_offset = Tensor(shape=[None for _ in ms_offset.shape], dtype=ms_offset.dtype) + ms_dyn_mean = Tensor(shape=[None for _ in ms_mean.shape], dtype=ms_mean.dtype) + ms_dyn_variance = Tensor(shape=[None for _ in ms_variance.shape], dtype=ms_variance.dtype) + ms_dyn_indices = Tensor(shape=[None], dtype=ms_indices.dtype) + + ms_net = MSDynRankNet(is_training=self.is_training) + ms_net.set_inputs(ms_dyn_x, ms_dyn_scale, ms_dyn_offset, ms_dyn_mean, ms_dyn_variance, ms_dyn_indices) + + ms_out = ms_net(ms_x, ms_scale, ms_offset, ms_mean, ms_variance, ms_indices) + return ms_out.asnumpy() + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_batch_norm_dynamic_rank(): + """ + Feature: test batch norm dynamic rank + Description: test batch norm dynamic rank with input tensor's type float32 + Expectation: none. 
+ """ + input_x = np.random.randn(3, 3, 4, 3, 3).astype(np.float32) + scale_ = np.ones((4, 4)).astype(np.float32) + offset_ = np.ones((4, 4)).astype(np.float32) + mean_ = np.ones((4, 4)).astype(np.float32) + variance_ = np.ones((4, 4)).astype(np.float32) + indices_ = np.unique(np.random.randint(1, 2, (1,)).astype(np.int32)) + + # graph mode + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + graph_mode_net = NetFactory(input_x, scale=scale_, offset=offset_, mean=mean_, variance=variance_, indices=indices_, + dtype=np.float32) + graph_mode_out = graph_mode_net.mindspore_case() + + # pynative mode + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + pynative_mode_net = NetFactory(input_x, scale=scale_, offset=offset_, mean=mean_, variance=variance_, + indices=indices_, + dtype=np.float32) + pynative_mode_out = pynative_mode_net.mindspore_case() + + assert np.allclose(pynative_mode_out, graph_mode_out, 1e-4, 1e-4) diff --git a/tests/st/ops/ascend/test_dynamic_rank/test_biasadd_dynamic_rank.py b/tests/st/ops/ascend/test_dynamic_rank/test_biasadd_dynamic_rank.py new file mode 100644 index 00000000000..0d103ba0fd2 --- /dev/null +++ b/tests/st/ops/ascend/test_dynamic_rank/test_biasadd_dynamic_rank.py @@ -0,0 +1,99 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+import pytest
+import numpy as np
+import torch
+from mindspore import Tensor
+from mindspore import context
+from mindspore.nn import Cell
+from mindspore.ops import operations as P
+
+np.random.seed(3)
+context.set_context(mode=context.GRAPH_MODE)
+
+
+class MSBiasAddDynRankNet(Cell):
+    def __init__(self, data_format="NCHW"):
+        super(MSBiasAddDynRankNet, self).__init__()
+        self.reduce_sum = P.ReduceSum(keep_dims=False)
+        self.bias_add = P.BiasAdd(data_format=data_format)
+        self.relu = P.ReLU()
+
+    def construct(self, input_a, input_b, indices):
+        relu_indices = self.relu(indices)
+        reduce_a = self.reduce_sum(input_a, relu_indices)
+        out = self.bias_add(reduce_a, input_b)
+        return out
+
+
+class TorchAddNet(torch.nn.Module):
+    def __init__(self):
+        super(TorchAddNet, self).__init__()
+        self.keep_dims = False
+
+    def forward(self, input_a, input_b, indices):
+        relu_indices = torch.relu(indices)
+        reduce_a = torch.sum(input_a, relu_indices.tolist(), keepdim=self.keep_dims)
+        out = torch.add(reduce_a, input_b)
+        return out
+
+
+class BiasAddOpFactory:
+    def __init__(self, in_shape, indices, dtype=np.float32, data_format="NCHW"):
+        super(BiasAddOpFactory, self).__init__()
+        self.dtype = dtype
+        self.input_x = np.random.randn(*in_shape).astype(self.dtype)
+        self.data_format = data_format
+        self.indices = indices
+        self.input_b = np.random.randn(in_shape[-1]).astype(self.dtype)
+        self.loss = 1e-4
+
+    def ms_bias_add_forward(self):
+        a = Tensor(self.input_x)
+        b = Tensor(self.input_b)
+        indices = Tensor(self.indices)
+
+        dyn_a = Tensor(shape=[None for _ in a.shape], dtype=a.dtype)
+        dyn_b = Tensor(shape=[None for _ in b.shape], dtype=b.dtype)
+        dyn_indices = Tensor(shape=[None], dtype=indices.dtype)
+
+        ms_net = MSBiasAddDynRankNet(data_format=self.data_format)
+        ms_net.set_inputs(dyn_a, dyn_b, dyn_indices)
+        out = ms_net(a, b, indices)
+        return out.asnumpy()
+
+    def torch_bias_add_forward(self):
+        torch_net = TorchAddNet()
+        out = torch_net(torch.from_numpy(self.input_x), torch.from_numpy(self.input_b), torch.from_numpy(self.indices))
+        return out.detach().numpy()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_bias_add_dyn_rank():
+    """
+    Feature: test bias add dynamic rank
+    Description: test bias add dynamic rank with input tensor's type float32
+    Expectation: the MindSpore result matches the PyTorch reference within 1e-4.
+    """
+    in_shape = (16, 16, 16, 16, 16)
+    indices_np = np.unique(np.random.randint(0, 2, size=3).astype(np.int32))
+    factory = BiasAddOpFactory(in_shape=in_shape, indices=indices_np, dtype=np.float32, data_format="NCHW")
+    ms_out = factory.ms_bias_add_forward()
+    torch_out = factory.torch_bias_add_forward()
+
+    assert np.allclose(ms_out, torch_out, factory.loss, factory.loss)
diff --git a/tests/st/ops/ascend/test_dynamic_rank/test_reduce_op_dynamic_rank.py b/tests/st/ops/ascend/test_dynamic_rank/test_reduce_op_dynamic_rank.py
new file mode 100644
index 00000000000..3ca9da6b7f3
--- /dev/null
+++ b/tests/st/ops/ascend/test_dynamic_rank/test_reduce_op_dynamic_rank.py
@@ -0,0 +1,98 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import pytest
+import numpy as np
+import torch
+from mindspore import Tensor
+from mindspore import context
+from mindspore.nn import Cell
+from mindspore.ops import operations as P
+
+np.random.seed(3)
+context.set_context(mode=context.GRAPH_MODE)
+
+
+class MSReduceSumNet(Cell):
+    def __init__(self, keep_dims=False):
+        super(MSReduceSumNet, self).__init__()
+        self.reduce_sum = P.ReduceSum(keep_dims=keep_dims)
+        self.reduce = P.ReduceSum(keep_dims=False)
+
+    def construct(self, x, indices, axis):
+        x = self.reduce(x, axis)
+        return self.reduce_sum(x, indices)
+
+
+class TorchReduceSumNet(torch.nn.Module):
+    def __init__(self, keep_dims=False):
+        super(TorchReduceSumNet, self).__init__()
+        self.keep_dims = keep_dims
+
+    def forward(self, input_x, indices, axis):
+        x = torch.sum(input_x, axis.tolist(), False)
+        out = torch.sum(x, indices.tolist(), self.keep_dims)
+        return out
+
+
+class ReduceOpFactory:
+    def __init__(self, input_x, indices, axis, keep_dims, dtype=np.float32, loss=1e-4):
+        super(ReduceOpFactory, self).__init__()
+        self.out_grad = None
+        self.input_x = input_x
+        self.indices = indices
+        self.axis = axis
+        self.keep_dims = keep_dims
+        self.dtype = dtype
+        self.loss = loss
+
+    def ms_reduce_sum_forward(self):
+        net = MSReduceSumNet(self.keep_dims)
+        in_tensor = Tensor(self.input_x)
+        in_indices = Tensor(self.indices)
+        in_axis = Tensor(self.axis)
+
+        in_tensor_dyn = Tensor(shape=[None for _ in in_tensor.shape], dtype=in_tensor.dtype)
+        in_indices_dyn = Tensor(shape=[None for _ in in_indices.shape], dtype=in_indices.dtype)
+        in_axis_dyn = Tensor(shape=[None for _ in in_axis.shape], dtype=in_axis.dtype)
+
+        net.set_inputs(in_tensor_dyn, in_indices_dyn, in_axis_dyn)
+        out = net(in_tensor, in_indices, in_axis)
+        return out.asnumpy()
+
+    def torch_reduce_sum_forward(self):
+        net = TorchReduceSumNet(self.keep_dims)
+        out = net(torch.from_numpy(self.input_x.astype(self.dtype)), self.indices, self.axis)
+        return out.detach().numpy()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_reduce_sum_dyn_rank():
+    """
+    Feature: test reduce sum dynamic rank
+    Description: test reduce sum dynamic rank with input tensor's type float32
+    Expectation: the MindSpore result matches the PyTorch reference within 1e-4.
+    """
+    dtype = np.float32
+    x = np.random.randn(22, 20, 28, 36, 24, 23).astype(dtype)
+    indices = np.array([0, -1])
+    axis = np.unique(np.random.randint(0, 2, size=5).astype(np.int32))
+    factory = ReduceOpFactory(x, indices, axis, keep_dims=True, dtype=dtype, loss=1e-4)
+
+    ms_data = factory.ms_reduce_sum_forward()
+    torch_data = factory.torch_reduce_sum_forward()
+    assert np.allclose(torch_data, ms_data, factory.loss, factory.loss)
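All of the new ST cases induce dynamic rank the same way: a reduce op receives an axis/indices tensor whose length is only known at runtime, so the rank of its output, and of every node behind it, is unknown during kernel selection, which is what drives the new dynamic-rank handling in format selection and TransData/padding above. Below is a minimal sketch of that pattern (illustrative only, not part of the patch); it assumes an Ascend environment with this change applied, and the cell, shapes, and names are made up for the example.

```python
# Minimal sketch of the dynamic-rank pattern used by the new ST cases (illustrative only).
# The rank of ReduceSum's output depends on the runtime length of `axis`, so the graph
# compiler must treat the ReLU input as an unknown-rank shape when selecting kernels.
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn, ops


class DynRankSketch(nn.Cell):
    def __init__(self):
        super().__init__()
        self.reduce_sum = ops.ReduceSum(keep_dims=False)
        self.relu = ops.ReLU()

    def construct(self, x, axis):
        # Output rank is unknown at compile time: it depends on how many axes are reduced.
        return self.relu(self.reduce_sum(x, axis))


ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
net = DynRankSketch()
# Dynamic-shape placeholders: the input ranks are fixed, but the axis length is not,
# which makes every shape after the reduce dynamic-rank.
dyn_x = Tensor(shape=[None, None, None, None], dtype=ms.float32)
dyn_axis = Tensor(shape=[None], dtype=ms.int32)
net.set_inputs(dyn_x, dyn_axis)
out = net(Tensor(np.ones((2, 3, 4, 5), np.float32)), Tensor(np.array([0, 2], np.int32)))
print(out.shape)
```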