diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
index 95ea2527aa7..8b2d8206fb3 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -33,6 +33,11 @@ namespace mindspore {
 namespace opt {
 using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
 namespace {
+bool NeedInsertTransData(const std::vector<size_t> &origin_shape, const std::string &format) {
+  return kCommonFormatSet.find(format) == kCommonFormatSet.end() &&
+         (origin_shape.size() > 1 || format == kOpFormat_ND_RNN_BIAS);
+}
+
 AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                              const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
   std::vector<AnfNodePtr> trans_inputs;
@@ -50,14 +55,15 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
 void SetTransNodeAttr(const CNodePtr &trans_node) {
   MS_EXCEPTION_IF_NULL(trans_node);
-  if (AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName) {
+  auto trans_opname = AnfAlgo::GetCNodeName(trans_node);
+  if (trans_opname == kTransDataOpName || trans_opname == kTransDataRNNOpName) {
     std::string input_format = AnfAlgo::GetInputFormat(trans_node, 0);
     std::string output_format = AnfAlgo::GetOutputFormat(trans_node, 0);
     if (input_format == kOpFormat_DEFAULT) {
-      input_format = kOpFormat_NCHW;
+      input_format = AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName ? kOpFormat_NCHW : kOpFormat_ND;
     }
     if (output_format == kOpFormat_DEFAULT) {
-      output_format = kOpFormat_NCHW;
+      output_format = AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName ? kOpFormat_NCHW : kOpFormat_ND;
     }
     AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(input_format), trans_node);
     AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(output_format), trans_node);
@@ -115,7 +121,7 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
   }
   std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
   std::string dest_format = AnfAlgo::GetInputFormat(node, index);
-  if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
+  if (NeedInsertTransData(origin_shape, dest_format)) {
     MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
                   << " To DefaultFormat , index: " << index;
     auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
@@ -136,7 +142,7 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
     MS_LOG(EXCEPTION) << "Got the hw format " << output_format << "when insert the transdata node "
                       << node->DebugString() << " trace: " << trace::DumpSourceLines(node);
   }
-  if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
+  if (NeedInsertTransData(origin_shape, output_format)) {
     MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default , index :0";
     return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
   }
@@ -164,7 +170,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
     }
     auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
     std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
-    if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) {
+    if (NeedInsertTransData(origin_shape, output_format)) {
       auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
       if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
         kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
@@ -193,11 +199,14 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
                                : AnfAlgo::GetOutputReshapeType(node, insert_index);
   auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
                                               : AnfAlgo::GetOutputInferShape(input_node, insert_index);
-  bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
-                                      : trans::IsNeedPadding(input_format, input_node_out_shape.size());
+  std::string spec_format = is_insert_input ? dst_format : input_format;
+  bool need_padding = trans::IsNeedPadding(spec_format, input_node_out_shape.size());
+  std::string trans_opname = (spec_format == kOpFormat_FRACTAL_ZN_RNN || spec_format == kOpFormat_ND_RNN_BIAS)
+                               ? prim::kPrimTransDataRNN->name()
+                               : prim::kPrimTransData->name();
   if (!need_padding) {
     // don't need padding insert transdata only
-    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::kPrimTransData->name());
+    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, trans_opname);
     trans_node = trans_data;
   } else if (is_insert_input) {
     // if need padding & is input need insert a transdata
@@ -205,16 +214,20 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     auto padding_shape = trans::PaddingShape(input_node_out_shape, AnfAlgo::GetInputFormat(node, insert_index),
                                              AnfAlgo::GetInputReshapeType(node, insert_index));
     auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
-    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::kPrimTransData->name());
+    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, trans_opname);
     trans_node = trans_data;
     trans_data->set_abstract(input_node->abstract());
   } else {
     // if need padding & is output need insert a transdata
     // node -> transdata[padding shape] -> reshape[ori_shape]
-    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::kPrimTransData->name());
+    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, trans_opname);
     auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape);
     trans_node = reshape_node;
   }
+  if (trans_opname == prim::kPrimTransDataRNN->name()) {
+    AnfAlgo::CopyNodeAttr(kAttrHiddenSize, node, trans_data);
+    AnfAlgo::CopyNodeAttr(kAttrInputSize, node, trans_data);
+  }
   // refresh the transdata's format to ori format & dst format
   RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis);
   if (!is_insert_input) {
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.cc
index e64934a1ce2..159723fc0f0 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.cc
@@ -459,7 +459,7 @@ const AnfNodePtr DynamicGRUV2GradFission::Process(const FuncGraphPtr &func_graph
     return nullptr;
   }
   if (AnfAlgo::IsDynamicShape(node)) {
-    MS_LOG(INFO) << "DynamicGRUV2Grad is dynamic shape, can not optimizer.";
+    MS_LOG(INFO) << "DynamicGRUV2Grad is dynamic shape, can not do fission.";
     return nullptr;
   }
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
index 344ea7fd66c..da869500598 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc
@@ -518,6 +518,10 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
                   << (kDynamicRNNGradInputNum + 1) << " inputs";
     return nullptr;
   }
+  if (AnfAlgo::IsDynamicShape(node)) {
+    MS_LOG(INFO) << "DynamicRnnGrad is dynamic shape, can not do fission.";
+    return nullptr;
+  }
   std::vector<AnfNodePtr> new_outputs;
   auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs);
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.cc
index 36ca813ab85..7add56cecef 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.cc
@@ -51,6 +51,7 @@ void SetAttrForInputNode(const AnfNodePtr &node, int64_t groups) {
 void SetAttrForConvInput(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   auto groups = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrGroups);
+  AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), cnode);
   if (groups > 1) {
     SetAttrForInputNode(cnode->input(kConvFilterInputIndex), groups);
   }
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index 214c049ddea..61b1ba4699b 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -857,8 +857,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
   if (trans::IsNeedPadding(format, infer_shape.size())) {
     infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx));
   }
-  auto input_node_index = GetPrevNodeOutput(node, input_idx);
-  return trans::TransShapeToDevice(infer_shape, format, input_node_index.first, input_node_index.second);
+  return trans::TransShapeToDevice(infer_shape, format, node, input_idx, false);
 }

 std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
@@ -2055,8 +2054,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputRealDeviceShapeIfExist(const An
     auto max_shape = GetInputMaxShape(anf_node, index);
     std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize);
     auto format = GetInputFormat(anf_node, index);
-    auto input_node_index = GetPrevNodeOutput(anf_node, index);
-    trans::TransShapeToDevice(device_shape, format, input_node_index.first, input_node_index.second);
+    trans::TransShapeToDevice(device_shape, format, anf_node, index, false);
   }
   return device_shape;
 }
diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc
index 21858105cc3..430f20ad7ee 100644
--- a/mindspore/ccsrc/common/trans.cc
+++ b/mindspore/ccsrc/common/trans.cc
@@ -626,8 +626,8 @@ std::vector<size_t> FracNZDeviceShape(const std::vector<size_t> &shape) {
     return shape;
   }
   std::vector<size_t> device_shape;
-  if (shape.size() < 2) {
-    MS_LOG(EXCEPTION) << "Format FRACTAL_NZ is not support shape " << shape.size();
+  if (shape.size() < kShape2dDims) {
+    MS_LOG(EXCEPTION) << "Format FRACTAL_NZ don't support shape with " << shape.size() << " dims";
   } else {
     (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
   }
@@ -646,8 +646,8 @@ std::vector<int64_t> FracNZDeviceDynamicShape(const std::vector<int64_t> &shape)
     // For [1] and [1024] shape we can trait it as NZ shape
     return shape;
   }
-  if (shape.size() < 2) {
-    MS_LOG(EXCEPTION) << "Format FRACTAL_NZ is not support shape " << shape.size();
+  if (shape.size() < kShape2dDims) {
+    MS_LOG(EXCEPTION) << "Format FRACTAL_NZ don't support shape with " << shape.size() << " dims";
   } else {
     (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
   }
@@ -695,6 +695,108 @@ std::vector<int64_t> FracNZLSTMDeviceDynamicShape(const std::vector<int64_t> &sh
   device_shape.push_back(kCubeSize);
   return device_shape;
 }
+
+std::vector<size_t> FracZNRNNDeviceShape(const std::vector<size_t> &shape,
+                                         const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16}) {
+  if (shape.size() < kShape2dDims) {
+    MS_LOG(EXCEPTION) << "Format FRACTAL_ZN_RNN don't support shape with " << shape.size() << " dims";
+  }
+  size_t input_size = LongToSize(input_hidden_size[0]);
+  size_t hidden_size = LongToSize(input_hidden_size[1]);
+  auto dim_last1 = shape[shape.size() - 1];
+  auto dim_last2 = shape[shape.size() - 2];
+  if (dim_last1 % hidden_size != 0) {
+    MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
+  }
+  size_t n_num = dim_last1 / hidden_size;
+  const size_t NUM16 = 16;
+  const size_t C0 = kCubeSize;
+
+  std::vector<size_t> device_shape = shape;
+  if (dim_last2 == input_size || dim_last2 == hidden_size) {
+    device_shape[shape.size() - 2] = DivCeil(dim_last2, NUM16);
+  } else if (dim_last2 == input_size + hidden_size) {
+    device_shape[shape.size() - 2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
+  } else {
+    MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
+  }
+  device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0);
+  device_shape.push_back(NUM16);
+  device_shape.push_back(C0);
+  return device_shape;
+}
+
+std::vector<int64_t> FracZNRNNDeviceDynamicShape(const std::vector<int64_t> &shape,
+                                                 const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16}) {
+  if (shape.size() < kShape2dDims) {
+    MS_LOG(EXCEPTION) << "Format FRACTAL_NZ_RNN don't support shape with " << shape.size() << " dims";
+  }
+  int64_t input_size = input_hidden_size[0];
+  int64_t hidden_size = input_hidden_size[1];
+  auto dim_last1 = shape[shape.size() - 1];
+  auto dim_last2 = shape[shape.size() - 2];
+  const int64_t NUM16 = 16;
+  const int64_t C0 = SizeToLong(kCubeSize);
+
+  std::vector<int64_t> device_shape = shape;
+  if (dim_last2 == Shape::SHP_ANY) {
+    device_shape[shape.size() - 2] = Shape::SHP_ANY;
+  } else if (dim_last2 == input_size || dim_last2 == hidden_size) {
+    device_shape[shape.size() - 2] = DivCeil(dim_last2, NUM16);
+  } else if (dim_last2 == input_size + hidden_size) {
+    device_shape[shape.size() - 2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
+  } else {
+    MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
+  }
+  if (dim_last1 == Shape::SHP_ANY) {
+    device_shape[shape.size() - 1] = Shape::SHP_ANY;
+  } else {
+    if (dim_last1 % hidden_size != 0) {
+      MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
+    }
+    int64_t n_num = shape[shape.size() - 1] / hidden_size;
+    device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0);
+  }
+  device_shape.push_back(NUM16);
+  device_shape.push_back(C0);
+  return device_shape;
+}
+
+std::vector<size_t> NDRNNBiasDeviceShape(const std::vector<size_t> &shape, const int64_t hidden_size = 16) {
+  if (shape.empty()) {
+    MS_LOG(EXCEPTION) << "Format ND_RNN_BIAS don't support empty shape.";
+  }
+  size_t hid_size = LongToSize(hidden_size);
+  // cppcheck-suppress *
+  if (shape[shape.size() - 1] % hid_size != 0) {
+    MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hid_size;
+  }
+  size_t n_num = shape[shape.size() - 1] / hid_size;
+  const size_t C0 = kCubeSize;
+  std::vector<size_t> device_shape = shape;
+  device_shape[shape.size() - 1] = n_num * DivCeil(hid_size, C0) * C0;
+  return device_shape;
+}
+
+std::vector<int64_t> NDRNNBiasDeviceDynamicShape(const std::vector<int64_t> &shape, const int64_t hidden_size = 16) {
+  if (shape.empty()) {
+    MS_LOG(EXCEPTION) << "Format ND_RNN_BIAS don't support empty shape.";
+  }
+  const int64_t C0 = SizeToLong(kCubeSize);
+  std::vector<int64_t> device_shape = shape;
+  // cppcheck-suppress *
+  auto dim_last1 = shape[shape.size() - 1];
+  if (dim_last1 == Shape::SHP_ANY) {
+    device_shape[shape.size() - 1] = Shape::SHP_ANY;
+  } else {
+    if (dim_last1 % hidden_size != 0) {
+      MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
+    }
+    int64_t n_num = shape[shape.size() - 1] / hidden_size;
+    device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0) * C0;
+  }
+  return device_shape;
+}
 }  // namespace

 int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
@@ -723,6 +825,22 @@ int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
   return 1;
 }

+std::vector<int64_t> GetAttrInputAndHiddenSize(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
+  if (!node->isa<CNode>()) {
+    return input_hidden_size;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (!AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode) || !AnfAlgo::HasNodeAttr(kAttrInputSize, cnode)) {
+    MS_LOG(EXCEPTION) << "Node with format FRACTAL_ZN_RNN or ND_RNN_BIAS should have hidden_size or input_size attr.";
+  }
+  input_hidden_size[0] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrInputSize);
+  input_hidden_size[1] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrHiddenSize);
+  return input_hidden_size;
+}
+
 bool IsNeedPadding(const std::string &format, const size_t shape_size) {
   if (shape_size == 0) {
     return false;
   }
@@ -820,7 +938,7 @@ void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5
 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
-                                       const int64_t groups) {
+                                       const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
   using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
   const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
                                                                     {kOpFormat_NHWC, NhwcDeviceShape},
@@ -843,6 +961,12 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
   if (groups > 1 && format == kOpFormat_FRAC_Z) {
     return FracZDeviceShapeWithGroups(shape, groups);
   }
+  if (format == kOpFormat_FRACTAL_ZN_RNN) {
+    return FracZNRNNDeviceShape(shape, input_hidden_size);
+  }
+  if (format == kOpFormat_ND_RNN_BIAS) {
+    return NDRNNBiasDeviceShape(shape, input_hidden_size[1]);
+  }
   auto temp_shape = shape;
   if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
       shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
@@ -860,7 +984,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
 }

 std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format,
-                                        const int64_t groups) {
+                                        const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
   using DeviceShapeTransfer = std::function<std::vector<int64_t>(const std::vector<int64_t> &)>;
   const std::map<std::string, DeviceShapeTransfer> device_shape_map{
     {kOpFormat_NCHW, NchwDeviceDynamicShape},
@@ -884,6 +1008,12 @@ std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const
   if (groups > 1 && format == kOpFormat_FRAC_Z) {
     return FracZDeviceShapeWithGroups(shape, groups);
   }
+  if (format == kOpFormat_FRACTAL_ZN_RNN) {
+    return FracZNRNNDeviceDynamicShape(shape, input_hidden_size);
+  }
+  if (format == kOpFormat_ND_RNN_BIAS) {
+    return NDRNNBiasDeviceDynamicShape(shape, input_hidden_size[1]);
+  }
   auto temp_shape = shape;
   if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
       shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
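As a sanity check on the FRACTAL_ZN_RNN and ND_RNN_BIAS shape rules introduced above, the arithmetic can be reproduced with a small standalone sketch (plain Python, not part of the patch; the function names below are illustrative only). With input_size = 13 and hidden_size = 30, a weight of host shape (43, 120) maps to device shape (3, 8, 16, 16), and a bias of host shape (120,) maps to (128,), which is what the new unit tests at the end of this diff expect.

```python
from math import ceil

C0 = 16  # cube size used by the Ascend fractal formats

def fractal_zn_rnn_device_shape(host_shape, input_size, hidden_size):
    """Mirrors FracZNRNNDeviceShape: (..., I+H, n*H) -> (..., ceil(I/16)+ceil(H/16), n*ceil(H/16), 16, 16)."""
    *batch, dim_last2, dim_last1 = host_shape
    assert dim_last1 % hidden_size == 0, "last dim must be a multiple of hidden_size"
    n_num = dim_last1 // hidden_size
    if dim_last2 in (input_size, hidden_size):
        first = ceil(dim_last2 / C0)
    elif dim_last2 == input_size + hidden_size:
        first = ceil(input_size / C0) + ceil(hidden_size / C0)
    else:
        raise ValueError("second-last dim must be input_size, hidden_size or their sum")
    return (*batch, first, n_num * ceil(hidden_size / C0), C0, C0)

def nd_rnn_bias_device_shape(host_shape, hidden_size):
    """Mirrors NDRNNBiasDeviceShape: pad each per-gate hidden slice up to a multiple of 16."""
    *rest, dim_last1 = host_shape
    assert dim_last1 % hidden_size == 0, "last dim must be a multiple of hidden_size"
    n_num = dim_last1 // hidden_size
    return (*rest, n_num * ceil(hidden_size / C0) * C0)

# DynamicRNN weight (input_size + hidden_size, 4 * hidden_size) with input_size=13, hidden_size=30:
assert fractal_zn_rnn_device_shape((43, 120), 13, 30) == (3, 8, 16, 16)
# DynamicRNN bias (4 * hidden_size,):
assert nd_rnn_bias_device_shape((120,), 30) == (128,)
```

The per-node input_size/hidden_size needed for this math is what GetAttrInputAndHiddenSize reads from the kAttrInputSize/kAttrHiddenSize attributes, falling back to 16 (kAlign16) when the node does not carry them.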
diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h
index dbc99c93d9d..2722a1b5263 100644
--- a/mindspore/ccsrc/common/trans.h
+++ b/mindspore/ccsrc/common/trans.h
@@ -31,7 +31,10 @@ namespace mindspore {
 namespace trans {
+constexpr int64_t kAlign16 = 16;
+
 enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims };
+
 enum Axis5D : int {
   N_ncdhw = 0,
   C_ncdhw,
@@ -66,23 +69,30 @@ struct FormatArgs {
 };

 int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index);
+std::vector<int64_t> GetAttrInputAndHiddenSize(const AnfNodePtr &node);
 void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
 void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec);
 ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
 bool IsNeedPadding(const std::string &format, const size_t shape_size);
 int64_t GetNodeGroups(const AnfNodePtr &node);
 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
-                                       const int64_t groups = 1);
+                                       const int64_t groups = 1,
+                                       const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16});
 std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format,
-                                        const int64_t groups = 1);
+                                        const int64_t groups = 1,
+                                        const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16});
 template <typename T>
 std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, const AnfNodePtr &node,
-                                  const size_t index) {
+                                  const size_t index, bool is_output = true) {
   int64_t groups = 1;
   if (format == kOpFormat_FRAC_Z) {
     groups = GetAttrGroups(node, index);
   }
-  return TransShapeToDevice(shape, format, groups);
+  std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
+  if (format == kOpFormat_FRACTAL_ZN_RNN || format == kOpFormat_ND_RNN_BIAS) {
+    input_hidden_size = GetAttrInputAndHiddenSize(node);
+  }
+  return TransShapeToDevice(shape, format, groups, input_hidden_size);
 }
 bool TransDataType(const TypeIdArgs &args, void *result);
 bool TransFormat(const FormatArgs &args, void *result, int64_t groups = 1);
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 57729bf395d..1168c52a4de 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -119,6 +119,7 @@ constexpr auto kApplyProximalAdagradOpName = "ApplyProximalAdagrad";
 constexpr auto kApplyProximalGradientDescentOpName = "ApplyProximalGradientDescent";
 constexpr auto kApplyRMSPropOpName = "ApplyRMSProp";
 constexpr auto kTransDataOpName = "TransData";
+constexpr auto kTransDataRNNOpName = "TransDataRNN";
 constexpr auto kStackInitOpName = "StackInit";
 constexpr auto kStackPushOpName = "StackPush";
 constexpr auto kStackPopOpName = "StackPop";
@@ -460,6 +461,8 @@ constexpr auto kAttrRecursiveEnd = "recursive_end";
 constexpr auto kAttrRecursive = "recursive";
 constexpr auto kAttrMultiCallEnd = "multicall_end";
 constexpr auto kAttrProfilingIterEnd = "PROFILING_ITER_END";
+constexpr auto kAttrHiddenSize = "hidden_size";
+constexpr auto kAttrInputSize = "input_size";

 // primal attr key name
 constexpr auto kPrimalAttrForwardNodeName = "forward_node_name";
@@ -566,19 +569,34 @@ constexpr auto kOpFormat_DHWCN = "DHWCN";
 constexpr auto kOpFormat_NDC1HWC0 = "NDC1HWC0";
 constexpr auto kOpFormat_FRACTAL_Z_3D = "FRACTAL_Z_3D";
 constexpr auto kOpFormat_FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM";
+constexpr auto kOpFormat_FRACTAL_ZN_RNN = "FRACTAL_ZN_RNN";
+constexpr auto kOpFormat_ND_RNN_BIAS = "ND_RNN_BIAS";

-const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0,
-                                             kOpFormat_ND, kOpFormat_NCHW,
-                                             kOpFormat_NHWC, kOpFormat_HWCN,
-                                             kOpFormat_NC1HWC0, kOpFormat_FRAC_Z,
-                                             kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
-                                             kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04,
-                                             kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM,
-                                             kOpFormat_NDC1HWC0, kOpFormat_NCDHW,
-                                             kOpFormat_FRACTAL_Z_3D, kOpFormat_DHWNC,
+const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT,
+                                             kOpFormat_NC1KHKWHWC0,
+                                             kOpFormat_ND,
+                                             kOpFormat_NCHW,
+                                             kOpFormat_NHWC,
+                                             kOpFormat_HWCN,
+                                             kOpFormat_NC1HWC0,
+                                             kOpFormat_FRAC_Z,
+                                             kOpFormat_C1HWNCoC0,
+                                             kOpFormat_FRAC_NZ,
+                                             kOpFormat_NC1HWC0_C04,
+                                             kOpFormat_FRACTAL_Z_C04,
+                                             kOpFormat_NDHWC,
+                                             kOpFormat_FRACTAL_ZN_LSTM,
+                                             kOpFormat_FRACTAL_ZN_RNN,
+                                             kOpFormat_ND_RNN_BIAS,
+                                             kOpFormat_NDC1HWC0,
+                                             kOpFormat_NCDHW,
+                                             kOpFormat_FRACTAL_Z_3D,
+                                             kOpFormat_DHWNC,
                                              kOpFormat_DHWCN};
+
 const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
                                                         kOpFormat_NCDHW};
+
 const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
                                                kApplyMomentumOpName,
                                                kApplyAdadeltaOpName,
@@ -625,8 +643,9 @@ const std::set<std::string> kOpNotSupportMultiThreadExecList = {kAvgPoolOpName, kBatchNorm,
                                                                 kBatchNormGradOpName};

 const std::set<std::string> kHWSpecialFormatSet = {
-  kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0,
-  kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z};
+  kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ,
+  kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM,
+  kOpFormat_FRACTAL_ZN_RNN, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z};

 const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};

@@ -637,7 +656,8 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccid
 const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
                                             kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};

-const std::set<std::string> kNoPaddingFormatSet = {kOpFormat_ChannelLast, kOpFormat_FRAC_NZ};
+const std::set<std::string> kNoPaddingFormatSet = {kOpFormat_ChannelLast, kOpFormat_FRAC_NZ, kOpFormat_FRACTAL_ZN_RNN,
+                                                   kOpFormat_ND_RNN_BIAS};

 const std::set<std::string> DynamicShapeConstInputToAttr = {
   kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceMinOpName,
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 802acd0500f..f0644b29435 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -205,6 +205,7 @@ inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>(kTile);
 inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
 inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2");
 inline const PrimitivePtr kPrimTransData = std::make_shared<Primitive>("TransData");
+inline const PrimitivePtr kPrimTransDataRNN = std::make_shared<Primitive>("TransDataRNN");
 inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
 inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
 inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index c57c1760bfa..14ee3fc22ea 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -92,6 +92,7 @@ from .tensor_add import _tensor_add_tbe
 from .tensor_add_ds import _tensor_add_ds_tbe
 from .trans_data import _trans_data_tbe
 from .trans_data_ds import _trans_data_ds_tbe
+from .trans_data_rnn import _trans_data_rnn_tbe
 from .top_k import _top_k_tbe
 from .matmul import _matmul_tbe
 from .matmul_ds import _matmul_ds_tbe
diff --git a/mindspore/ops/_op_impl/tbe/dynamic_rnn.py b/mindspore/ops/_op_impl/tbe/dynamic_rnn.py
index b4a45bd6456..95531553772 100644
--- a/mindspore/ops/_op_impl/tbe/dynamic_rnn.py
+++ b/mindspore/ops/_op_impl/tbe/dynamic_rnn.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -53,22 +53,22 @@ dynamic_rnn_op_info = TBERegOp("DynamicRNN") \
     .output(5, "f", False, "required", "all") \
     .output(6, "o", False, "required", "all") \
     .output(7, "tanhc", False, "required", "all") \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.None_Default,
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNRNN, DataType.F32_ND_RNNBIAS, DataType.None_Default,
                   DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
                   DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
                   DataType.F32_FracNZ, DataType.F32_FracNZ) \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F16_Default, DataType.None_Default,
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNRNN, DataType.F16_ND_RNNBIAS, DataType.None_Default,
                   DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.F16_FracNZ) \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.I32_Default,
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNRNN, DataType.F32_ND_RNNBIAS, DataType.I32_Default,
                   DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
                   DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
                   DataType.F32_FracNZ, DataType.F32_FracNZ) \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F16_Default, DataType.I32_Default,
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNRNN, DataType.F16_ND_RNNBIAS, DataType.I32_Default,
                   DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
diff --git a/mindspore/ops/_op_impl/tbe/trans_data_rnn.py b/mindspore/ops/_op_impl/tbe/trans_data_rnn.py
new file mode 100644
index 00000000000..c82f208ca60
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/trans_data_rnn.py
@@ -0,0 +1,44 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""TransDataRNN op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+trans_data_rnn_op_info = TBERegOp("TransDataRNN") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("trans_data_rnn.so") \
+    .compute_cost(10) \
+    .kernel_name("trans_data_rnn") \
+    .partial_flag(True) \
+    .attr("src_format", "required", "str", "FRACTAL_ZN_RNN, ND_RNN_BIAS") \
+    .attr("dst_format", "required", "str", "FRACTAL_ZN_RNN, ND_RNN_BIAS") \
+    .attr("input_size", "required", "int", "all") \
+    .attr("hidden_size", "required", "int", "all") \
+    .input(0, "src", False, "required", "all") \
+    .output(0, "dst", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_FracZNRNN) \
+    .dtype_format(DataType.F16_FracZNRNN, DataType.F16_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_ND_RNNBIAS) \
+    .dtype_format(DataType.F16_ND_RNNBIAS, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_ND_RNNBIAS) \
+    .dtype_format(DataType.F32_ND_RNNBIAS, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(trans_data_rnn_op_info)
+def _trans_data_rnn_tbe():
+    """TransDataRNN TBE register"""
+    return
diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py
index 4251ef4e80e..476d2f3dac9 100644
--- a/mindspore/ops/op_info_register.py
+++ b/mindspore/ops/op_info_register.py
@@ -721,6 +721,8 @@ class DataType:
     F16_NDC1HWC0 = ("float16", "NDC1HWC0")
     F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
     F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
+    F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN")
+    F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS")
     F16_ChannelLast = ("float16", "ChannelLast")

     F32_None = ("float32", "")
@@ -738,6 +740,8 @@
     F32_NDC1HWC0 = ("float32", "NDC1HWC0")
     F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
     F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
+    F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN")
+    F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS")
     F32_ChannelLast = ("float32", "ChannelLast")

     F64_None = ("float64", "")
@@ -878,6 +882,8 @@
     F16_NDC1HWC0 = ("float16", "NDC1HWC0")
     F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
     F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
+    F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN")
+    F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS")
     F16_ChannelLast = ("float16", "ChannelLast")

     F32_None = ("float32", "")
@@ -895,6 +901,8 @@
     F32_NDC1HWC0 = ("float32", "NDC1HWC0")
     F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
     F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
+    F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN")
+    F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS")
     F32_ChannelLast = ("float32", "ChannelLast")

     F64_None = ("float64", "")
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index dd19b1adef3..3c45d1e2446 100755
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -7404,9 +7404,6 @@ class DynamicRNN(PrimitiveWithInfer):
     are learnable weights between the output and the input in the formula. For instance,
     :math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`.

-    Note:
-        The `hidden_size` in shape of inputs must be multiple of 16.
-
     Args:
         cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'.
             Only 'LSTM' is currently supported.
@@ -7534,6 +7531,8 @@
         validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)
         self.placeholder_index = [3]
         self.add_prim_attr("placeholder_index", self.placeholder_index)
+        self.add_prim_attr("input_size", input_size)
+        self.add_prim_attr("hidden_size", hidden_size)
         y_shape = (num_step, batch_size, hidden_size)
         return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape
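The removed docstring note is the user-visible effect of this change: with the new FRACTAL_ZN_RNN and ND_RNN_BIAS formats, DynamicRNN no longer requires `hidden_size` to be a multiple of 16. A minimal usage sketch follows (illustrative only, assuming an Ascend backend; shapes follow the docstring above, and the chosen sizes are hypothetical):

```python
import numpy as np
from mindspore import Tensor, ops

# hidden_size = 30 is deliberately not a multiple of 16
num_step, batch_size, input_size, hidden_size = 2, 8, 13, 30

x = Tensor(np.random.rand(num_step, batch_size, input_size).astype(np.float16))
w = Tensor(np.random.rand(input_size + hidden_size, 4 * hidden_size).astype(np.float16))  # selected as FRACTAL_ZN_RNN per the TBE registration above
b = Tensor(np.random.rand(4 * hidden_size).astype(np.float16))                            # selected as ND_RNN_BIAS
init_h = Tensor(np.random.rand(1, batch_size, hidden_size).astype(np.float16))
init_c = Tensor(np.random.rand(1, batch_size, hidden_size).astype(np.float16))

dynamic_rnn = ops.DynamicRNN()
y, h, c, i, j, f, o, tanhc = dynamic_rnn(x, w, b, None, init_h, init_c)
print(y.shape)  # (2, 8, 30)
```

The `input_size`/`hidden_size` prim attrs added in the hunk above are what GetAttrInputAndHiddenSize reads and what AddTransOpNodeToGraph copies onto the inserted TransDataRNN nodes.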
diff --git a/tests/ut/cpp/common/trans_test.cc b/tests/ut/cpp/common/trans_test.cc
index 559933fd6f6..1fb3ee78b1f 100644
--- a/tests/ut/cpp/common/trans_test.cc
+++ b/tests/ut/cpp/common/trans_test.cc
@@ -30,17 +30,13 @@ class FormatTransTest : public UT::Common {
 };

 TEST_F(FormatTransTest, nchw_to_hwcn) {
-  uint16_t data[2*2*2*2] = {12581,14220,14937,14302,
-                            15004,14951,14694,14564,
-                            14069,14554,10507,14787,
-                            13016,15263,14872,10838};
-  uint16_t res[2*2*2*2] = {12581,14069,15004,13016,
-                           14220,14554,14951,15263,
-                           14937,10507,14694,14872,
-                           14302,14787,14564,10838};
+  uint16_t data[2 * 2 * 2 * 2] = {12581, 14220, 14937, 14302, 15004, 14951, 14694, 14564,
+                                  14069, 14554, 10507, 14787, 13016, 15263, 14872, 10838};
+  uint16_t res[2 * 2 * 2 * 2] = {12581, 14069, 15004, 13016, 14220, 14554, 14951, 15263,
+                                 14937, 10507, 14694, 14872, 14302, 14787, 14564, 10838};
   size_t device_size = 32;
   auto trans_tmp = std::vector<uint8_t>(device_size);
-  FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_HWCN,
+  FormatArgs format_args{data,         device_size,  kOpFormat_NCHW, kOpFormat_HWCN,
                          {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
   EXPECT_EQ(trans::TransFormat(format_args, trans_tmp.data()), true);
   for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) {
@@ -49,19 +45,15 @@ TEST_F(FormatTransTest, nchw_to_hwcn) {
 }

 TEST_F(FormatTransTest, hwcn_to_nchw) {
-  uint16_t data[2*2*2*2] = {12581,14069,15004,13016,
-                            14220,14554,14951,15263,
-                            14937,10507,14694,14872,
-                            14302,14787,14564,10838};
+  uint16_t data[2 * 2 * 2 * 2] = {12581, 14069, 15004, 13016, 14220, 14554, 14951, 15263,
+                                  14937, 10507, 14694, 14872, 14302, 14787, 14564, 10838};

-  uint16_t res[2*2*2*2] = {12581,14220,14937,14302,
-                           15004,14951,14694,14564,
-                           14069,14554,10507,14787,
-                           13016,15263,14872,10838};
+  uint16_t res[2 * 2 * 2 * 2] = {12581, 14220, 14937, 14302, 15004, 14951, 14694, 14564,
+                                 14069, 14554, 10507, 14787, 13016, 15263, 14872, 10838};
   size_t device_size = 32;
   auto trans_tmp = std::vector<uint8_t>(device_size);
-  FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_HWCN,
+  FormatArgs format_args{data,         device_size,  kOpFormat_NCHW, kOpFormat_HWCN,
                          {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
   EXPECT_EQ(trans::TransFormatFromDeviceToHost(format_args, trans_tmp.data()), true);
   for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) {
@@ -70,44 +62,89 @@ TEST_F(FormatTransTest, hwcn_to_nchw) {
 }

 TEST_F(FormatTransTest, nchw_to_nhwc) {
-  uint16_t data[2*2*2*2] = {11750,13778,15007,15321,
-                            15163,13446,15063,14467,
-                            15056,13284,15219,14797,
-                            12684,14288,14855,14799};
-  uint16_t res[2*2*2*2] = {11750,15163,13778,13446,
-                           15007,15063,15321,14467,
-                           15056,12684,13284,14288,
-                           15219,14855,14797,14799};
+  uint16_t data[2 * 2 * 2 * 2] = {11750, 13778, 15007, 15321, 15163, 13446, 15063, 14467,
+                                  15056, 13284, 15219, 14797, 12684, 14288, 14855, 14799};
+  uint16_t res[2 * 2 * 2 * 2] = {11750, 15163, 13778, 13446, 15007, 15063, 15321, 14467,
+                                 15056, 12684, 13284, 14288, 15219, 14855, 14797, 14799};
   size_t device_size = 32;
   auto trans_tmp = std::vector<uint8_t>(device_size);
-  FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_NHWC,
+  FormatArgs format_args{data,         device_size,  kOpFormat_NCHW, kOpFormat_NHWC,
                          {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
   EXPECT_EQ(trans::TransFormat(format_args, trans_tmp.data()), true);
   for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) {
     EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]);
   }
 }
+
 TEST_F(FormatTransTest, nhwc_to_nchw) {
-  uint16_t data[2*2*2*2] = {11750,15163,13778,13446,
-                            15007,15063,15321,14467,
-                            15056,12684,13284,14288,
-                            15219,14855,14797,14799};
-  uint16_t res[2*2*2*2] = {11750,13778,15007,15321,
-                           15163,13446,15063,14467,
-                           15056,13284,15219,14797,
-                           12684,14288,14855,14799};
+  uint16_t data[2 * 2 * 2 * 2] = {11750, 15163, 13778, 13446, 15007, 15063, 15321, 14467,
+                                  15056, 12684, 13284, 14288, 15219, 14855, 14797, 14799};
+  uint16_t res[2 * 2 * 2 * 2] = {11750, 13778, 15007, 15321, 15163, 13446, 15063, 14467,
+                                 15056, 13284, 15219, 14797, 12684, 14288, 14855, 14799};
   size_t device_size = 32;
   auto trans_tmp = std::vector<uint8_t>(device_size);
-  FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_NHWC,
+  FormatArgs format_args{data,         device_size,  kOpFormat_NCHW, kOpFormat_NHWC,
                          {2, 2, 2, 2}, {2, 2, 2, 2}, kNumberTypeFloat16};
   EXPECT_EQ(trans::TransFormatFromDeviceToHost(format_args, trans_tmp.data()), true);
   for (size_t i = 0; i < sizeof(res) / sizeof(res[0]); i++) {
     EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]);
   }
 }
+
+class ShapeTransTest : public UT::Common {
+ public:
+  ShapeTransTest() = default;
+  void SetUp() override {}
+  void TearDown() override {}
+};
+
+TEST_F(ShapeTransTest, fraczn_rnn_device_shape) {
+  std::vector<size_t> host_shape = {43, 120};
+  std::string format = kOpFormat_FRACTAL_ZN_RNN;
+  std::vector<int64_t> input_hidden_size = {13, 30};
+  auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
+  const std::vector<size_t> expect_shape = {3, 8, 16, 16};
+  EXPECT_EQ(trans_shape.size(), expect_shape.size());
+  for (size_t i = 0; i < expect_shape.size(); i++) {
+    EXPECT_EQ(trans_shape[i], expect_shape[i]);
+  }
+}
+
+TEST_F(ShapeTransTest, nd_rnn_bias_device_shape) {
+  std::vector<size_t> host_shape = {120};
+  std::string format = kOpFormat_ND_RNN_BIAS;
+  std::vector<int64_t> input_hidden_size = {13, 30};
+  auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
+  std::vector<size_t> expect_shape = {128};
+  EXPECT_EQ(trans_shape.size(), expect_shape.size());
+  for (size_t i = 0; i < expect_shape.size(); i++) {
+    EXPECT_EQ(trans_shape[i], expect_shape[i]);
+  }
+}
+
+TEST_F(ShapeTransTest, fraczn_rnn_dynamic_device_shape) {
+  std::vector<int64_t> host_shape = {-1, -1};
+  std::string format = kOpFormat_FRACTAL_ZN_RNN;
+  std::vector<int64_t> input_hidden_size = {13, 30};
+  auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
+  const std::vector<int64_t> expect_shape = {-1, -1, 16, 16};
+  EXPECT_EQ(trans_shape.size(), expect_shape.size());
+  for (size_t i = 0; i < expect_shape.size(); i++) {
+    EXPECT_EQ(trans_shape[i], expect_shape[i]);
+  }
+}
+
+TEST_F(ShapeTransTest, nd_rnn_bias_dynamic_device_shape) {
+  std::vector<int64_t> host_shape = {-1};
+  std::string format = kOpFormat_ND_RNN_BIAS;
+  std::vector<int64_t> input_hidden_size = {13, 30};
+  auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
+  std::vector<int64_t> expect_shape = {-1};
+  EXPECT_EQ(trans_shape.size(), expect_shape.size());
+  for (size_t i = 0; i < expect_shape.size(); i++) {
+    EXPECT_EQ(trans_shape[i], expect_shape[i]);
+  }
+}
 }  // namespace trans
 }  // namespace mindspore
-
-
-