!22068 DynamicRNN supports scenarios that hidden_size is not multiples of 16

Merge pull request !22068 from yuchaojie/op_select2
This commit is contained in:
i-robot 2021-08-25 01:32:26 +00:00 committed by Gitee
commit abc9d8e6fe
15 changed files with 352 additions and 86 deletions

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -33,6 +33,11 @@ namespace mindspore {
namespace opt { namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace { namespace {
bool NeedInsertTransData(const std::vector<size_t> &origin_shape, const std::string &format) {
return kCommonFormatSet.find(format) == kCommonFormatSet.end() &&
(origin_shape.size() > 1 || format == kOpFormat_ND_RNN_BIAS);
}
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) { const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
std::vector<AnfNodePtr> trans_inputs; std::vector<AnfNodePtr> trans_inputs;
@ -50,14 +55,15 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
void SetTransNodeAttr(const CNodePtr &trans_node) { void SetTransNodeAttr(const CNodePtr &trans_node) {
MS_EXCEPTION_IF_NULL(trans_node); MS_EXCEPTION_IF_NULL(trans_node);
if (AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName) { auto trans_opname = AnfAlgo::GetCNodeName(trans_node);
if (trans_opname == kTransDataOpName || trans_opname == kTransDataRNNOpName) {
std::string input_format = AnfAlgo::GetInputFormat(trans_node, 0); std::string input_format = AnfAlgo::GetInputFormat(trans_node, 0);
std::string output_format = AnfAlgo::GetOutputFormat(trans_node, 0); std::string output_format = AnfAlgo::GetOutputFormat(trans_node, 0);
if (input_format == kOpFormat_DEFAULT) { if (input_format == kOpFormat_DEFAULT) {
input_format = kOpFormat_NCHW; input_format = AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName ? kOpFormat_NCHW : kOpFormat_ND;
} }
if (output_format == kOpFormat_DEFAULT) { if (output_format == kOpFormat_DEFAULT) {
output_format = kOpFormat_NCHW; output_format = AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName ? kOpFormat_NCHW : kOpFormat_ND;
} }
AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(input_format), trans_node); AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(input_format), trans_node);
AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(output_format), trans_node); AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(output_format), trans_node);
@ -115,7 +121,7 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
} }
std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index); std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
std::string dest_format = AnfAlgo::GetInputFormat(node, index); std::string dest_format = AnfAlgo::GetInputFormat(node, index);
if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { if (NeedInsertTransData(origin_shape, dest_format)) {
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index) MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
<< " To DefaultFormat , index: " << index; << " To DefaultFormat , index: " << index;
auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true); auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
@ -136,7 +142,7 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
MS_LOG(EXCEPTION) << "Got the hw format " << output_format << "when insert the transdata node " MS_LOG(EXCEPTION) << "Got the hw format " << output_format << "when insert the transdata node "
<< node->DebugString() << " trace: " << trace::DumpSourceLines(node); << node->DebugString() << " trace: " << trace::DumpSourceLines(node);
} }
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { if (NeedInsertTransData(origin_shape, output_format)) {
MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default , index :0"; MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default , index :0";
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false); return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
} }
@ -164,7 +170,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
} }
auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx); auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) { if (NeedInsertTransData(origin_shape, output_format)) {
auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false); auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) { if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0); kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
@ -193,11 +199,14 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
: AnfAlgo::GetOutputReshapeType(node, insert_index); : AnfAlgo::GetOutputReshapeType(node, insert_index);
auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index) auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
: AnfAlgo::GetOutputInferShape(input_node, insert_index); : AnfAlgo::GetOutputInferShape(input_node, insert_index);
bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size()) std::string spec_format = is_insert_input ? dst_format : input_format;
: trans::IsNeedPadding(input_format, input_node_out_shape.size()); bool need_padding = trans::IsNeedPadding(spec_format, input_node_out_shape.size());
std::string trans_opname = (spec_format == kOpFormat_FRACTAL_ZN_RNN || spec_format == kOpFormat_ND_RNN_BIAS)
? prim::kPrimTransDataRNN->name()
: prim::kPrimTransData->name();
if (!need_padding) { if (!need_padding) {
// don't need padding insert transdata only // don't need padding insert transdata only
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::kPrimTransData->name()); trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, trans_opname);
trans_node = trans_data; trans_node = trans_data;
} else if (is_insert_input) { } else if (is_insert_input) {
// if need padding & is input need insert a transdata // if need padding & is input need insert a transdata
@ -205,16 +214,20 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
auto padding_shape = trans::PaddingShape(input_node_out_shape, AnfAlgo::GetInputFormat(node, insert_index), auto padding_shape = trans::PaddingShape(input_node_out_shape, AnfAlgo::GetInputFormat(node, insert_index),
AnfAlgo::GetInputReshapeType(node, insert_index)); AnfAlgo::GetInputReshapeType(node, insert_index));
auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::kPrimTransData->name()); trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, trans_opname);
trans_node = trans_data; trans_node = trans_data;
trans_data->set_abstract(input_node->abstract()); trans_data->set_abstract(input_node->abstract());
} else { } else {
// if need padding & is output need insert a transdata // if need padding & is output need insert a transdata
// node -> transdata[padding shape] -> reshape[ori_shape] // node -> transdata[padding shape] -> reshape[ori_shape]
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::kPrimTransData->name()); trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, trans_opname);
auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape); auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape);
trans_node = reshape_node; trans_node = reshape_node;
} }
if (trans_opname == prim::kPrimTransDataRNN->name()) {
AnfAlgo::CopyNodeAttr(kAttrHiddenSize, node, trans_data);
AnfAlgo::CopyNodeAttr(kAttrInputSize, node, trans_data);
}
// refresh the transdata's format to ori format & dst format // refresh the transdata's format to ori format & dst format
RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis); RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis);
if (!is_insert_input) { if (!is_insert_input) {

View File

@ -459,7 +459,7 @@ const AnfNodePtr DynamicGRUV2GradFission::Process(const FuncGraphPtr &func_graph
return nullptr; return nullptr;
} }
if (AnfAlgo::IsDynamicShape(node)) { if (AnfAlgo::IsDynamicShape(node)) {
MS_LOG(INFO) << "DynamicGRUV2Grad is dynamic shape, can not optimizer."; MS_LOG(INFO) << "DynamicGRUV2Grad is dynamic shape, can not do fission.";
return nullptr; return nullptr;
} }

View File

@ -518,6 +518,10 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
<< (kDynamicRNNGradInputNum + 1) << " inputs"; << (kDynamicRNNGradInputNum + 1) << " inputs";
return nullptr; return nullptr;
} }
if (AnfAlgo::IsDynamicShape(node)) {
MS_LOG(INFO) << "DynamicRnnGrad is dynamic shape, can not do fission.";
return nullptr;
}
std::vector<AnfNodePtr> new_outputs; std::vector<AnfNodePtr> new_outputs;
auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs); auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs);

View File

@ -51,6 +51,7 @@ void SetAttrForInputNode(const AnfNodePtr &node, int64_t groups) {
void SetAttrForConvInput(const CNodePtr &cnode) { void SetAttrForConvInput(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
auto groups = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrGroups); auto groups = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrGroups);
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), cnode);
if (groups > 1) { if (groups > 1) {
SetAttrForInputNode(cnode->input(kConvFilterInputIndex), groups); SetAttrForInputNode(cnode->input(kConvFilterInputIndex), groups);
} }

View File

@ -857,8 +857,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
if (trans::IsNeedPadding(format, infer_shape.size())) { if (trans::IsNeedPadding(format, infer_shape.size())) {
infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx)); infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx));
} }
auto input_node_index = GetPrevNodeOutput(node, input_idx); return trans::TransShapeToDevice(infer_shape, format, node, input_idx, false);
return trans::TransShapeToDevice(infer_shape, format, input_node_index.first, input_node_index.second);
} }
std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
@ -2055,8 +2054,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputRealDeviceShapeIfExist(const An
auto max_shape = GetInputMaxShape(anf_node, index); auto max_shape = GetInputMaxShape(anf_node, index);
std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize); std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize);
auto format = GetInputFormat(anf_node, index); auto format = GetInputFormat(anf_node, index);
auto input_node_index = GetPrevNodeOutput(anf_node, index); trans::TransShapeToDevice(device_shape, format, anf_node, index, false);
trans::TransShapeToDevice(device_shape, format, input_node_index.first, input_node_index.second);
} }
return device_shape; return device_shape;
} }

View File

@ -626,8 +626,8 @@ std::vector<size_t> FracNZDeviceShape(const std::vector<size_t> &shape) {
return shape; return shape;
} }
std::vector<size_t> device_shape; std::vector<size_t> device_shape;
if (shape.size() < 2) { if (shape.size() < kShape2dDims) {
MS_LOG(EXCEPTION) << "Format FRACTAL_NZ is not support shape " << shape.size(); MS_LOG(EXCEPTION) << "Format FRACTAL_NZ don't support shape with " << shape.size() << " dims";
} else { } else {
(void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape)); (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
} }
@ -646,8 +646,8 @@ std::vector<int64_t> FracNZDeviceDynamicShape(const std::vector<int64_t> &shape)
// For [1] and [1024] shape we can trait it as NZ shape // For [1] and [1024] shape we can trait it as NZ shape
return shape; return shape;
} }
if (shape.size() < 2) { if (shape.size() < kShape2dDims) {
MS_LOG(EXCEPTION) << "Format FRACTAL_NZ is not support shape " << shape.size(); MS_LOG(EXCEPTION) << "Format FRACTAL_NZ don't support shape with " << shape.size() << " dims";
} else { } else {
(void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape)); (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
} }
@ -695,6 +695,108 @@ std::vector<int64_t> FracNZLSTMDeviceDynamicShape(const std::vector<int64_t> &sh
device_shape.push_back(kCubeSize); device_shape.push_back(kCubeSize);
return device_shape; return device_shape;
} }
std::vector<size_t> FracZNRNNDeviceShape(const std::vector<size_t> &shape,
const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16}) {
if (shape.size() < kShape2dDims) {
MS_LOG(EXCEPTION) << "Format FRACTAL_ZN_RNN don't support shape with " << shape.size() << " dims";
}
size_t input_size = LongToSize(input_hidden_size[0]);
size_t hidden_size = LongToSize(input_hidden_size[1]);
auto dim_last1 = shape[shape.size() - 1];
auto dim_last2 = shape[shape.size() - 2];
if (dim_last1 % hidden_size != 0) {
MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
}
size_t n_num = dim_last1 / hidden_size;
const size_t NUM16 = 16;
const size_t C0 = kCubeSize;
std::vector<size_t> device_shape = shape;
if (dim_last2 == input_size || dim_last2 == hidden_size) {
device_shape[shape.size() - 2] = DivCeil(dim_last2, NUM16);
} else if (dim_last2 == input_size + hidden_size) {
device_shape[shape.size() - 2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
} else {
MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
}
device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0);
device_shape.push_back(NUM16);
device_shape.push_back(C0);
return device_shape;
}
std::vector<int64_t> FracZNRNNDeviceDynamicShape(const std::vector<int64_t> &shape,
const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16}) {
if (shape.size() < kShape2dDims) {
MS_LOG(EXCEPTION) << "Format FRACTAL_NZ_RNN don't support shape with " << shape.size() << " dims";
}
int64_t input_size = input_hidden_size[0];
int64_t hidden_size = input_hidden_size[1];
auto dim_last1 = shape[shape.size() - 1];
auto dim_last2 = shape[shape.size() - 2];
const int64_t NUM16 = 16;
const int64_t C0 = SizeToLong(kCubeSize);
std::vector<int64_t> device_shape = shape;
if (dim_last2 == Shape::SHP_ANY) {
device_shape[shape.size() - 2] = Shape::SHP_ANY;
} else if (dim_last2 == input_size || dim_last2 == hidden_size) {
device_shape[shape.size() - 2] = DivCeil(dim_last2, NUM16);
} else if (dim_last2 == input_size + hidden_size) {
device_shape[shape.size() - 2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
} else {
MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
}
if (dim_last1 == Shape::SHP_ANY) {
device_shape[shape.size() - 1] = Shape::SHP_ANY;
} else {
if (dim_last1 % hidden_size != 0) {
MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
}
int64_t n_num = shape[shape.size() - 1] / hidden_size;
device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0);
}
device_shape.push_back(NUM16);
device_shape.push_back(C0);
return device_shape;
}
std::vector<size_t> NDRNNBiasDeviceShape(const std::vector<size_t> &shape, const int64_t hidden_size = 16) {
if (shape.empty()) {
MS_LOG(EXCEPTION) << "Format ND_RNN_BIAS don't support empty shape.";
}
size_t hid_size = LongToSize(hidden_size);
// cppcheck-suppress *
if (shape[shape.size() - 1] % hid_size != 0) {
MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hid_size;
}
size_t n_num = shape[shape.size() - 1] / hid_size;
const size_t C0 = kCubeSize;
std::vector<size_t> device_shape = shape;
device_shape[shape.size() - 1] = n_num * DivCeil(hid_size, C0) * C0;
return device_shape;
}
std::vector<int64_t> NDRNNBiasDeviceDynamicShape(const std::vector<int64_t> &shape, const int64_t hidden_size = 16) {
if (shape.empty()) {
MS_LOG(EXCEPTION) << "Format ND_RNN_BIAS don't support empty shape.";
}
const int64_t C0 = SizeToLong(kCubeSize);
std::vector<int64_t> device_shape = shape;
// cppcheck-suppress *
auto dim_last1 = shape[shape.size() - 1];
if (dim_last1 == Shape::SHP_ANY) {
device_shape[shape.size() - 1] = Shape::SHP_ANY;
} else {
if (dim_last1 % hidden_size != 0) {
MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
}
int64_t n_num = shape[shape.size() - 1] / hidden_size;
device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0) * C0;
}
return device_shape;
}
} // namespace } // namespace
int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) { int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
@ -723,6 +825,22 @@ int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
return 1; return 1;
} }
std::vector<int64_t> GetAttrInputAndHiddenSize(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
if (!node->isa<CNode>()) {
return input_hidden_size;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode) || !AnfAlgo::HasNodeAttr(kAttrInputSize, cnode)) {
MS_LOG(EXCEPTION) << "Node with format FRACTAL_ZN_RNN or ND_RNN_BIAS should have hidden_size or input_size attr.";
}
input_hidden_size[0] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrInputSize);
input_hidden_size[1] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrHiddenSize);
return input_hidden_size;
}
bool IsNeedPadding(const std::string &format, const size_t shape_size) { bool IsNeedPadding(const std::string &format, const size_t shape_size) {
if (shape_size == 0) { if (shape_size == 0) {
return false; return false;
@ -820,7 +938,7 @@ void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5
} }
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format, std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
const int64_t groups) { const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>; using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape}, const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
{kOpFormat_NHWC, NhwcDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape},
@ -843,6 +961,12 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
if (groups > 1 && format == kOpFormat_FRAC_Z) { if (groups > 1 && format == kOpFormat_FRAC_Z) {
return FracZDeviceShapeWithGroups(shape, groups); return FracZDeviceShapeWithGroups(shape, groups);
} }
if (format == kOpFormat_FRACTAL_ZN_RNN) {
return FracZNRNNDeviceShape(shape, input_hidden_size);
}
if (format == kOpFormat_ND_RNN_BIAS) {
return NDRNNBiasDeviceShape(shape, input_hidden_size[1]);
}
auto temp_shape = shape; auto temp_shape = shape;
if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM && if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) { shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
@ -860,7 +984,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
} }
std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format, std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format,
const int64_t groups) { const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
using DeviceShapeTransfer = std::function<std::vector<int64_t>(const std::vector<int64_t> &)>; using DeviceShapeTransfer = std::function<std::vector<int64_t>(const std::vector<int64_t> &)>;
const std::map<std::string, DeviceShapeTransfer> device_shape_map{ const std::map<std::string, DeviceShapeTransfer> device_shape_map{
{kOpFormat_NCHW, NchwDeviceDynamicShape}, {kOpFormat_NCHW, NchwDeviceDynamicShape},
@ -884,6 +1008,12 @@ std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const
if (groups > 1 && format == kOpFormat_FRAC_Z) { if (groups > 1 && format == kOpFormat_FRAC_Z) {
return FracZDeviceShapeWithGroups(shape, groups); return FracZDeviceShapeWithGroups(shape, groups);
} }
if (format == kOpFormat_FRACTAL_ZN_RNN) {
return FracZNRNNDeviceDynamicShape(shape, input_hidden_size);
}
if (format == kOpFormat_ND_RNN_BIAS) {
return NDRNNBiasDeviceDynamicShape(shape, input_hidden_size[1]);
}
auto temp_shape = shape; auto temp_shape = shape;
if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM && if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) { shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {

View File

@ -31,7 +31,10 @@
namespace mindspore { namespace mindspore {
namespace trans { namespace trans {
constexpr int64_t kAlign16 = 16;
enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims }; enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims };
enum Axis5D : int { enum Axis5D : int {
N_ncdhw = 0, N_ncdhw = 0,
C_ncdhw, C_ncdhw,
@ -66,23 +69,30 @@ struct FormatArgs {
}; };
int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index); int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index);
std::vector<int64_t> GetAttrInputAndHiddenSize(const AnfNodePtr &node);
void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec); void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec); void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec);
ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index); ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
bool IsNeedPadding(const std::string &format, const size_t shape_size); bool IsNeedPadding(const std::string &format, const size_t shape_size);
int64_t GetNodeGroups(const AnfNodePtr &node); int64_t GetNodeGroups(const AnfNodePtr &node);
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format, std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
const int64_t groups = 1); const int64_t groups = 1,
const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16});
std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format, std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format,
const int64_t groups = 1); const int64_t groups = 1,
const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16});
template <typename T> template <typename T>
std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, const AnfNodePtr &node, std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, const AnfNodePtr &node,
const size_t index) { const size_t index, bool is_output = true) {
int64_t groups = 1; int64_t groups = 1;
if (format == kOpFormat_FRAC_Z) { if (format == kOpFormat_FRAC_Z) {
groups = GetAttrGroups(node, index); groups = GetAttrGroups(node, index);
} }
return TransShapeToDevice(shape, format, groups); std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
if (format == kOpFormat_FRACTAL_ZN_RNN || format == kOpFormat_ND_RNN_BIAS) {
input_hidden_size = GetAttrInputAndHiddenSize(node);
}
return TransShapeToDevice(shape, format, groups, input_hidden_size);
} }
bool TransDataType(const TypeIdArgs &args, void *result); bool TransDataType(const TypeIdArgs &args, void *result);
bool TransFormat(const FormatArgs &args, void *result, int64_t groups = 1); bool TransFormat(const FormatArgs &args, void *result, int64_t groups = 1);

View File

@ -119,6 +119,7 @@ constexpr auto kApplyProximalAdagradOpName = "ApplyProximalAdagrad ";
constexpr auto kApplyProximalGradientDescentOpName = "ApplyProximalGradientDescent"; constexpr auto kApplyProximalGradientDescentOpName = "ApplyProximalGradientDescent";
constexpr auto kApplyRMSPropOpName = "ApplyRMSProp"; constexpr auto kApplyRMSPropOpName = "ApplyRMSProp";
constexpr auto kTransDataOpName = "TransData"; constexpr auto kTransDataOpName = "TransData";
constexpr auto kTransDataRNNOpName = "TransDataRNN";
constexpr auto kStackInitOpName = "StackInit"; constexpr auto kStackInitOpName = "StackInit";
constexpr auto kStackPushOpName = "StackPush"; constexpr auto kStackPushOpName = "StackPush";
constexpr auto kStackPopOpName = "StackPop"; constexpr auto kStackPopOpName = "StackPop";
@ -460,6 +461,8 @@ constexpr auto kAttrRecursiveEnd = "recursive_end";
constexpr auto kAttrRecursive = "recursive"; constexpr auto kAttrRecursive = "recursive";
constexpr auto kAttrMultiCallEnd = "multicall_end"; constexpr auto kAttrMultiCallEnd = "multicall_end";
constexpr auto kAttrProfilingIterEnd = "PROFILING_ITER_END"; constexpr auto kAttrProfilingIterEnd = "PROFILING_ITER_END";
constexpr auto kAttrHiddenSize = "hidden_size";
constexpr auto kAttrInputSize = "input_size";
// primal attr key name // primal attr key name
constexpr auto kPrimalAttrForwardNodeName = "forward_node_name"; constexpr auto kPrimalAttrForwardNodeName = "forward_node_name";
@ -566,19 +569,34 @@ constexpr auto kOpFormat_DHWCN = "DHWCN";
constexpr auto kOpFormat_NDC1HWC0 = "NDC1HWC0"; constexpr auto kOpFormat_NDC1HWC0 = "NDC1HWC0";
constexpr auto kOpFormat_FRACTAL_Z_3D = "FRACTAL_Z_3D"; constexpr auto kOpFormat_FRACTAL_Z_3D = "FRACTAL_Z_3D";
constexpr auto kOpFormat_FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM"; constexpr auto kOpFormat_FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM";
constexpr auto kOpFormat_FRACTAL_ZN_RNN = "FRACTAL_ZN_RNN";
constexpr auto kOpFormat_ND_RNN_BIAS = "ND_RNN_BIAS";
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT,
kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NC1KHKWHWC0,
kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_ND,
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_NCHW,
kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, kOpFormat_NHWC,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_HWCN,
kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NC1HWC0,
kOpFormat_NDC1HWC0, kOpFormat_NCDHW, kOpFormat_FRAC_Z,
kOpFormat_FRACTAL_Z_3D, kOpFormat_DHWNC, kOpFormat_C1HWNCoC0,
kOpFormat_FRAC_NZ,
kOpFormat_NC1HWC0_C04,
kOpFormat_FRACTAL_Z_C04,
kOpFormat_NDHWC,
kOpFormat_FRACTAL_ZN_LSTM,
kOpFormat_FRACTAL_ZN_RNN,
kOpFormat_ND_RNN_BIAS,
kOpFormat_NDC1HWC0,
kOpFormat_NCDHW,
kOpFormat_FRACTAL_Z_3D,
kOpFormat_DHWNC,
kOpFormat_DHWCN}; kOpFormat_DHWCN};
const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN, const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
kOpFormat_NCDHW}; kOpFormat_NCDHW};
const std::set<std::string> kOptOperatorSet = {kMomentumOpName, const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
kApplyMomentumOpName, kApplyMomentumOpName,
kApplyAdadeltaOpName, kApplyAdadeltaOpName,
@ -625,8 +643,9 @@ const std::set<std::string> kOpNotSupportMultiThreadExecList = {kAvgPoolOpName,
kBatchNorm, kBatchNormGradOpName}; kBatchNorm, kBatchNormGradOpName};
const std::set<std::string> kHWSpecialFormatSet = { const std::set<std::string> kHWSpecialFormatSet = {
kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z}; kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM,
kOpFormat_FRACTAL_ZN_RNN, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z};
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
@ -637,7 +656,8 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccid
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D, const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC}; kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
const std::set<std::string> kNoPaddingFormatSet = {kOpFormat_ChannelLast, kOpFormat_FRAC_NZ}; const std::set<std::string> kNoPaddingFormatSet = {kOpFormat_ChannelLast, kOpFormat_FRAC_NZ, kOpFormat_FRACTAL_ZN_RNN,
kOpFormat_ND_RNN_BIAS};
const std::set<std::string> DynamicShapeConstInputToAttr = { const std::set<std::string> DynamicShapeConstInputToAttr = {
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceMinOpName, kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceMinOpName,

View File

@ -205,6 +205,7 @@ inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>(kTile);
inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2"); inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2");
inline const PrimitivePtr kPrimTransData = std::make_shared<Primitive>("TransData"); inline const PrimitivePtr kPrimTransData = std::make_shared<Primitive>("TransData");
inline const PrimitivePtr kPrimTransDataRNN = std::make_shared<Primitive>("TransDataRNN");
inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask"); inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad"); inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue"); inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");

View File

@ -92,6 +92,7 @@ from .tensor_add import _tensor_add_tbe
from .tensor_add_ds import _tensor_add_ds_tbe from .tensor_add_ds import _tensor_add_ds_tbe
from .trans_data import _trans_data_tbe from .trans_data import _trans_data_tbe
from .trans_data_ds import _trans_data_ds_tbe from .trans_data_ds import _trans_data_ds_tbe
from .trans_data_rnn import _trans_data_rnn_tbe
from .top_k import _top_k_tbe from .top_k import _top_k_tbe
from .matmul import _matmul_tbe from .matmul import _matmul_tbe
from .matmul_ds import _matmul_ds_tbe from .matmul_ds import _matmul_ds_tbe

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -53,22 +53,22 @@ dynamic_rnn_op_info = TBERegOp("DynamicRNN") \
.output(5, "f", False, "required", "all") \ .output(5, "f", False, "required", "all") \
.output(6, "o", False, "required", "all") \ .output(6, "o", False, "required", "all") \
.output(7, "tanhc", False, "required", "all") \ .output(7, "tanhc", False, "required", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.None_Default, .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNRNN, DataType.F32_ND_RNNBIAS, DataType.None_Default,
DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ) \ DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F16_Default, DataType.None_Default, .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNRNN, DataType.F16_ND_RNNBIAS, DataType.None_Default,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ) \ DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.I32_Default, .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNRNN, DataType.F32_ND_RNNBIAS, DataType.I32_Default,
DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ) \ DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F16_Default, DataType.I32_Default, .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNRNN, DataType.F16_ND_RNNBIAS, DataType.I32_Default,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,

View File

@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TransDataRNN op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
trans_data_rnn_op_info = TBERegOp("TransDataRNN") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("trans_data_rnn.so") \
.compute_cost(10) \
.kernel_name("trans_data_rnn") \
.partial_flag(True) \
.attr("src_format", "required", "str", "FRACTAL_ZN_RNN, ND_RNN_BIAS") \
.attr("dst_format", "required", "str", "FRACTAL_ZN_RNN, ND_RNN_BIAS") \
.attr("input_size", "required", "int", "all") \
.attr("hidden_size", "required", "int", "all") \
.input(0, "src", False, "required", "all") \
.output(0, "dst", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_FracZNRNN) \
.dtype_format(DataType.F16_FracZNRNN, DataType.F16_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_ND_RNNBIAS) \
.dtype_format(DataType.F16_ND_RNNBIAS, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_ND_RNNBIAS) \
.dtype_format(DataType.F32_ND_RNNBIAS, DataType.F32_Default) \
.get_op_info()
@op_info_register(trans_data_rnn_op_info)
def _trans_data_rnn_tbe():
"""TransDataRNN TBE register"""
return

View File

@ -721,6 +721,8 @@ class DataType:
F16_NDC1HWC0 = ("float16", "NDC1HWC0") F16_NDC1HWC0 = ("float16", "NDC1HWC0")
F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D") F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM") F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN")
F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS")
F16_ChannelLast = ("float16", "ChannelLast") F16_ChannelLast = ("float16", "ChannelLast")
F32_None = ("float32", "") F32_None = ("float32", "")
@ -738,6 +740,8 @@ class DataType:
F32_NDC1HWC0 = ("float32", "NDC1HWC0") F32_NDC1HWC0 = ("float32", "NDC1HWC0")
F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D") F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM") F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN")
F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS")
F32_ChannelLast = ("float32", "ChannelLast") F32_ChannelLast = ("float32", "ChannelLast")
F64_None = ("float64", "") F64_None = ("float64", "")
@ -878,6 +882,8 @@ class DataType:
F16_NDC1HWC0 = ("float16", "NDC1HWC0") F16_NDC1HWC0 = ("float16", "NDC1HWC0")
F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D") F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM") F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN")
F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS")
F16_ChannelLast = ("float16", "ChannelLast") F16_ChannelLast = ("float16", "ChannelLast")
F32_None = ("float32", "") F32_None = ("float32", "")
@ -895,6 +901,8 @@ class DataType:
F32_NDC1HWC0 = ("float32", "NDC1HWC0") F32_NDC1HWC0 = ("float32", "NDC1HWC0")
F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D") F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM") F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN")
F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS")
F32_ChannelLast = ("float32", "ChannelLast") F32_ChannelLast = ("float32", "ChannelLast")
F64_None = ("float64", "") F64_None = ("float64", "")

View File

@ -7404,9 +7404,6 @@ class DynamicRNN(PrimitiveWithInfer):
are learnable weights between the output and the input in the formula. For instance, are learnable weights between the output and the input in the formula. For instance,
:math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`. :math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`.
Note:
The `hidden_size` in shape of inputs must be multiple of 16.
Args: Args:
cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'. cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'.
Only 'LSTM' is currently supported. Only 'LSTM' is currently supported.
@ -7534,6 +7531,8 @@ class DynamicRNN(PrimitiveWithInfer):
validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name) validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)
self.placeholder_index = [3] self.placeholder_index = [3]
self.add_prim_attr("placeholder_index", self.placeholder_index) self.add_prim_attr("placeholder_index", self.placeholder_index)
self.add_prim_attr("input_size", input_size)
self.add_prim_attr("hidden_size", hidden_size)
y_shape = (num_step, batch_size, hidden_size) y_shape = (num_step, batch_size, hidden_size)
return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape

View File

@ -30,14 +30,10 @@ class FormatTransTest : public UT::Common {
}; };
TEST_F(FormatTransTest, nchw_to_hwcn) { TEST_F(FormatTransTest, nchw_to_hwcn) {
uint16_t data[2*2*2*2] = {12581,14220,14937,14302, uint16_t data[2 * 2 * 2 * 2] = {12581, 14220, 14937, 14302, 15004, 14951, 14694, 14564,
15004,14951,14694,14564, 14069, 14554, 10507, 14787, 13016, 15263, 14872, 10838};
14069,14554,10507,14787, uint16_t res[2 * 2 * 2 * 2] = {12581, 14069, 15004, 13016, 14220, 14554, 14951, 15263,
13016,15263,14872,10838}; 14937, 10507, 14694, 14872, 14302, 14787, 14564, 10838};
uint16_t res[2*2*2*2] = {12581,14069,15004,13016,
14220,14554,14951,15263,
14937,10507,14694,14872,
14302,14787,14564,10838};
size_t device_size = 32; size_t device_size = 32;
auto trans_tmp = std::vector<uint8_t>(device_size); auto trans_tmp = std::vector<uint8_t>(device_size);
FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_HWCN, FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_HWCN,
@ -49,15 +45,11 @@ TEST_F(FormatTransTest, nchw_to_hwcn) {
} }
TEST_F(FormatTransTest, hwcn_to_nchw) { TEST_F(FormatTransTest, hwcn_to_nchw) {
uint16_t data[2*2*2*2] = {12581,14069,15004,13016, uint16_t data[2 * 2 * 2 * 2] = {12581, 14069, 15004, 13016, 14220, 14554, 14951, 15263,
14220,14554,14951,15263, 14937, 10507, 14694, 14872, 14302, 14787, 14564, 10838};
14937,10507,14694,14872,
14302,14787,14564,10838};
uint16_t res[2*2*2*2] = {12581,14220,14937,14302, uint16_t res[2 * 2 * 2 * 2] = {12581, 14220, 14937, 14302, 15004, 14951, 14694, 14564,
15004,14951,14694,14564, 14069, 14554, 10507, 14787, 13016, 15263, 14872, 10838};
14069,14554,10507,14787,
13016,15263,14872,10838};
size_t device_size = 32; size_t device_size = 32;
auto trans_tmp = std::vector<uint8_t>(device_size); auto trans_tmp = std::vector<uint8_t>(device_size);
@ -70,14 +62,10 @@ TEST_F(FormatTransTest, hwcn_to_nchw) {
} }
TEST_F(FormatTransTest, nchw_to_nhwc) { TEST_F(FormatTransTest, nchw_to_nhwc) {
uint16_t data[2*2*2*2] = {11750,13778,15007,15321, uint16_t data[2 * 2 * 2 * 2] = {11750, 13778, 15007, 15321, 15163, 13446, 15063, 14467,
15163,13446,15063,14467, 15056, 13284, 15219, 14797, 12684, 14288, 14855, 14799};
15056,13284,15219,14797, uint16_t res[2 * 2 * 2 * 2] = {11750, 15163, 13778, 13446, 15007, 15063, 15321, 14467,
12684,14288,14855,14799}; 15056, 12684, 13284, 14288, 15219, 14855, 14797, 14799};
uint16_t res[2*2*2*2] = {11750,15163,13778,13446,
15007,15063,15321,14467,
15056,12684,13284,14288,
15219,14855,14797,14799};
size_t device_size = 32; size_t device_size = 32;
auto trans_tmp = std::vector<uint8_t>(device_size); auto trans_tmp = std::vector<uint8_t>(device_size);
FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_NHWC, FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_NHWC,
@ -87,15 +75,12 @@ TEST_F(FormatTransTest, nchw_to_nhwc) {
EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]); EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]);
} }
} }
TEST_F(FormatTransTest, nhwc_to_nchw) { TEST_F(FormatTransTest, nhwc_to_nchw) {
uint16_t data[2*2*2*2] = {11750,15163,13778,13446, uint16_t data[2 * 2 * 2 * 2] = {11750, 15163, 13778, 13446, 15007, 15063, 15321, 14467,
15007,15063,15321,14467, 15056, 12684, 13284, 14288, 15219, 14855, 14797, 14799};
15056,12684,13284,14288, uint16_t res[2 * 2 * 2 * 2] = {11750, 13778, 15007, 15321, 15163, 13446, 15063, 14467,
15219,14855,14797,14799}; 15056, 13284, 15219, 14797, 12684, 14288, 14855, 14799};
uint16_t res[2*2*2*2] = {11750,13778,15007,15321,
15163,13446,15063,14467,
15056,13284,15219,14797,
12684,14288,14855,14799};
size_t device_size = 32; size_t device_size = 32;
auto trans_tmp = std::vector<uint8_t>(device_size); auto trans_tmp = std::vector<uint8_t>(device_size);
@ -106,8 +91,60 @@ TEST_F(FormatTransTest, nhwc_to_nchw) {
EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]); EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]);
} }
} }
class ShapeTransTest : public UT::Common {
public:
ShapeTransTest() = default;
void SetUp() override {}
void TearDown() override {}
};
TEST_F(ShapeTransTest, fraczn_rnn_device_shape) {
std::vector<size_t> host_shape = {43, 120};
std::string format = kOpFormat_FRACTAL_ZN_RNN;
std::vector<int64_t> input_hidden_size = {13, 30};
auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
const std::vector<size_t> expect_shape = {3, 8, 16, 16};
EXPECT_EQ(trans_shape.size(), expect_shape.size());
for (size_t i = 0; i < expect_shape.size(); i++) {
EXPECT_EQ(trans_shape[i], expect_shape[i]);
}
}
TEST_F(ShapeTransTest, nd_rnn_bias_device_shape) {
std::vector<size_t> host_shape = {120};
std::string format = kOpFormat_ND_RNN_BIAS;
std::vector<int64_t> input_hidden_size = {13, 30};
auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
std::vector<size_t> expect_shape = {128};
EXPECT_EQ(trans_shape.size(), expect_shape.size());
for (size_t i = 0; i < expect_shape.size(); i++) {
EXPECT_EQ(trans_shape[i], expect_shape[i]);
}
}
TEST_F(ShapeTransTest, fraczn_rnn_dynamic_device_shape) {
std::vector<int64_t> host_shape = {-1, -1};
std::string format = kOpFormat_FRACTAL_ZN_RNN;
std::vector<int64_t> input_hidden_size = {13, 30};
auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
const std::vector<int64_t> expect_shape = {-1, -1, 16, 16};
EXPECT_EQ(trans_shape.size(), expect_shape.size());
for (size_t i = 0; i < expect_shape.size(); i++) {
EXPECT_EQ(trans_shape[i], expect_shape[i]);
}
}
TEST_F(ShapeTransTest, nd_rnn_bias_dynamic_device_shape) {
std::vector<int64_t> host_shape = {-1};
std::string format = kOpFormat_ND_RNN_BIAS;
std::vector<int64_t> input_hidden_size = {13, 30};
auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
std::vector<int64_t> expect_shape = {-1};
EXPECT_EQ(trans_shape.size(), expect_shape.size());
for (size_t i = 0; i < expect_shape.size(); i++) {
EXPECT_EQ(trans_shape[i], expect_shape[i]);
}
}
} // namespace trans } // namespace trans
} // namespace mindspore } // namespace mindspore