forked from mindspore-Ecosystem/mindspore
!22068 DynamicRNN supports scenarios where hidden_size is not a multiple of 16
Merge pull request !22068 from yuchaojie/op_select2
commit abc9d8e6fe
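The core of the change is a new pair of device-shape rules: weights laid out as FRACTAL_ZN_RNN and biases laid out as ND_RNN_BIAS are padded per 16-element cube, so hidden_size no longer has to be a multiple of 16 on the host side. The following sketch (Python, not part of the commit; function names are illustrative) mirrors the arithmetic of the FracZNRNNDeviceShape and NDRNNBiasDeviceShape helpers added to trans.cc, using the values exercised by the new ShapeTransTest cases:

def div_ceil(a, b):
    return (a + b - 1) // b

def frac_zn_rnn_device_shape(shape, input_size, hidden_size, c0=16):
    # Mirrors FracZNRNNDeviceShape: tile both trailing dims up to 16 x C0 cubes.
    dim_last2, dim_last1 = shape[-2], shape[-1]
    if dim_last1 % hidden_size != 0:
        raise ValueError("last dim must be a multiple of hidden_size")
    n_num = dim_last1 // hidden_size
    if dim_last2 in (input_size, hidden_size):
        first = div_ceil(dim_last2, 16)
    elif dim_last2 == input_size + hidden_size:
        first = div_ceil(input_size, 16) + div_ceil(hidden_size, 16)
    else:
        raise ValueError("invalid second-last dim")
    return shape[:-2] + [first, n_num * div_ceil(hidden_size, c0), 16, c0]

def nd_rnn_bias_device_shape(shape, hidden_size, c0=16):
    # Mirrors NDRNNBiasDeviceShape: round each hidden_size segment up to whole C0 blocks.
    n_num = shape[-1] // hidden_size
    return shape[:-1] + [n_num * div_ceil(hidden_size, c0) * c0]

# Values from the new ShapeTransTest cases: input_size = 13, hidden_size = 30.
assert frac_zn_rnn_device_shape([43, 120], 13, 30) == [3, 8, 16, 16]
assert nd_rnn_bias_device_shape([120], 30) == [128]

With input_size 13 and hidden_size 30, a {43, 120} weight maps to device shape {3, 8, 16, 16} and a {120} bias maps to {128}, which is exactly what the new unit tests at the end of this diff expect.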
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -33,6 +33,11 @@ namespace mindspore {
 namespace opt {
 using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
 namespace {
+bool NeedInsertTransData(const std::vector<size_t> &origin_shape, const std::string &format) {
+  return kCommonFormatSet.find(format) == kCommonFormatSet.end() &&
+         (origin_shape.size() > 1 || format == kOpFormat_ND_RNN_BIAS);
+}
+
 AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                              const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
   std::vector<AnfNodePtr> trans_inputs;
@@ -50,14 +55,15 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
 
 void SetTransNodeAttr(const CNodePtr &trans_node) {
   MS_EXCEPTION_IF_NULL(trans_node);
-  if (AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName) {
+  auto trans_opname = AnfAlgo::GetCNodeName(trans_node);
+  if (trans_opname == kTransDataOpName || trans_opname == kTransDataRNNOpName) {
     std::string input_format = AnfAlgo::GetInputFormat(trans_node, 0);
     std::string output_format = AnfAlgo::GetOutputFormat(trans_node, 0);
     if (input_format == kOpFormat_DEFAULT) {
-      input_format = kOpFormat_NCHW;
+      input_format = AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName ? kOpFormat_NCHW : kOpFormat_ND;
     }
     if (output_format == kOpFormat_DEFAULT) {
-      output_format = kOpFormat_NCHW;
+      output_format = AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName ? kOpFormat_NCHW : kOpFormat_ND;
     }
     AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(input_format), trans_node);
     AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(output_format), trans_node);
@@ -115,7 +121,7 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
   }
   std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
   std::string dest_format = AnfAlgo::GetInputFormat(node, index);
-  if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
+  if (NeedInsertTransData(origin_shape, dest_format)) {
     MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
                   << " To DefaultFormat , index: " << index;
     auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
@@ -136,7 +142,7 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
     MS_LOG(EXCEPTION) << "Got the hw format " << output_format << "when insert the transdata node "
                       << node->DebugString() << " trace: " << trace::DumpSourceLines(node);
   }
-  if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
+  if (NeedInsertTransData(origin_shape, output_format)) {
     MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default , index :0";
     return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
   }
@@ -164,7 +170,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
     }
     auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
     std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
-    if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) {
+    if (NeedInsertTransData(origin_shape, output_format)) {
       auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
       if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
         kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
@@ -193,11 +199,14 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
                                                 : AnfAlgo::GetOutputReshapeType(node, insert_index);
   auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
                                               : AnfAlgo::GetOutputInferShape(input_node, insert_index);
-  bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
-                                      : trans::IsNeedPadding(input_format, input_node_out_shape.size());
+  std::string spec_format = is_insert_input ? dst_format : input_format;
+  bool need_padding = trans::IsNeedPadding(spec_format, input_node_out_shape.size());
+  std::string trans_opname = (spec_format == kOpFormat_FRACTAL_ZN_RNN || spec_format == kOpFormat_ND_RNN_BIAS)
+                               ? prim::kPrimTransDataRNN->name()
+                               : prim::kPrimTransData->name();
   if (!need_padding) {
     // don't need padding insert transdata only
-    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::kPrimTransData->name());
+    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, trans_opname);
     trans_node = trans_data;
   } else if (is_insert_input) {
     // if need padding & is input need insert a transdata
@@ -205,16 +214,20 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     auto padding_shape = trans::PaddingShape(input_node_out_shape, AnfAlgo::GetInputFormat(node, insert_index),
                                              AnfAlgo::GetInputReshapeType(node, insert_index));
     auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
-    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::kPrimTransData->name());
+    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, trans_opname);
     trans_node = trans_data;
     trans_data->set_abstract(input_node->abstract());
   } else {
     // if need padding & is output need insert a transdata
     // node -> transdata[padding shape] -> reshape[ori_shape]
-    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::kPrimTransData->name());
+    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, trans_opname);
     auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape);
     trans_node = reshape_node;
   }
+  if (trans_opname == prim::kPrimTransDataRNN->name()) {
+    AnfAlgo::CopyNodeAttr(kAttrHiddenSize, node, trans_data);
+    AnfAlgo::CopyNodeAttr(kAttrInputSize, node, trans_data);
+  }
   // refresh the transdata's format to ori format & dst format
   RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis);
   if (!is_insert_input) {
@@ -459,7 +459,7 @@ const AnfNodePtr DynamicGRUV2GradFission::Process(const FuncGraphPtr &func_graph
     return nullptr;
   }
   if (AnfAlgo::IsDynamicShape(node)) {
-    MS_LOG(INFO) << "DynamicGRUV2Grad is dynamic shape, can not optimizer.";
+    MS_LOG(INFO) << "DynamicGRUV2Grad is dynamic shape, can not do fission.";
     return nullptr;
   }
 
@@ -518,6 +518,10 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
                   << (kDynamicRNNGradInputNum + 1) << " inputs";
     return nullptr;
   }
+  if (AnfAlgo::IsDynamicShape(node)) {
+    MS_LOG(INFO) << "DynamicRnnGrad is dynamic shape, can not do fission.";
+    return nullptr;
+  }
   std::vector<AnfNodePtr> new_outputs;
   auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs);
 
@@ -51,6 +51,7 @@ void SetAttrForInputNode(const AnfNodePtr &node, int64_t groups) {
 void SetAttrForConvInput(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   auto groups = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrGroups);
+  AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), cnode);
   if (groups > 1) {
     SetAttrForInputNode(cnode->input(kConvFilterInputIndex), groups);
   }
@@ -857,8 +857,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
   if (trans::IsNeedPadding(format, infer_shape.size())) {
     infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx));
   }
-  auto input_node_index = GetPrevNodeOutput(node, input_idx);
-  return trans::TransShapeToDevice(infer_shape, format, input_node_index.first, input_node_index.second);
+  return trans::TransShapeToDevice(infer_shape, format, node, input_idx, false);
 }
 
 std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
@@ -2055,8 +2054,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputRealDeviceShapeIfExist(const An
     auto max_shape = GetInputMaxShape(anf_node, index);
     std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize);
     auto format = GetInputFormat(anf_node, index);
-    auto input_node_index = GetPrevNodeOutput(anf_node, index);
-    trans::TransShapeToDevice(device_shape, format, input_node_index.first, input_node_index.second);
+    trans::TransShapeToDevice(device_shape, format, anf_node, index, false);
   }
   return device_shape;
 }
@@ -626,8 +626,8 @@ std::vector<size_t> FracNZDeviceShape(const std::vector<size_t> &shape) {
     return shape;
   }
   std::vector<size_t> device_shape;
-  if (shape.size() < 2) {
-    MS_LOG(EXCEPTION) << "Format FRACTAL_NZ is not support shape " << shape.size();
+  if (shape.size() < kShape2dDims) {
+    MS_LOG(EXCEPTION) << "Format FRACTAL_NZ don't support shape with " << shape.size() << " dims";
   } else {
     (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
   }
@@ -646,8 +646,8 @@ std::vector<int64_t> FracNZDeviceDynamicShape(const std::vector<int64_t> &shape)
     // For [1] and [1024] shape we can trait it as NZ shape
     return shape;
   }
-  if (shape.size() < 2) {
-    MS_LOG(EXCEPTION) << "Format FRACTAL_NZ is not support shape " << shape.size();
+  if (shape.size() < kShape2dDims) {
+    MS_LOG(EXCEPTION) << "Format FRACTAL_NZ don't support shape with " << shape.size() << " dims";
   } else {
     (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
   }
@@ -695,6 +695,108 @@ std::vector<int64_t> FracNZLSTMDeviceDynamicShape(const std::vector<int64_t> &sh
   device_shape.push_back(kCubeSize);
   return device_shape;
 }
+
+std::vector<size_t> FracZNRNNDeviceShape(const std::vector<size_t> &shape,
+                                         const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16}) {
+  if (shape.size() < kShape2dDims) {
+    MS_LOG(EXCEPTION) << "Format FRACTAL_ZN_RNN don't support shape with " << shape.size() << " dims";
+  }
+  size_t input_size = LongToSize(input_hidden_size[0]);
+  size_t hidden_size = LongToSize(input_hidden_size[1]);
+  auto dim_last1 = shape[shape.size() - 1];
+  auto dim_last2 = shape[shape.size() - 2];
+  if (dim_last1 % hidden_size != 0) {
+    MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
+  }
+  size_t n_num = dim_last1 / hidden_size;
+  const size_t NUM16 = 16;
+  const size_t C0 = kCubeSize;
+
+  std::vector<size_t> device_shape = shape;
+  if (dim_last2 == input_size || dim_last2 == hidden_size) {
+    device_shape[shape.size() - 2] = DivCeil(dim_last2, NUM16);
+  } else if (dim_last2 == input_size + hidden_size) {
+    device_shape[shape.size() - 2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
+  } else {
+    MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
+  }
+  device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0);
+  device_shape.push_back(NUM16);
+  device_shape.push_back(C0);
+  return device_shape;
+}
+
+std::vector<int64_t> FracZNRNNDeviceDynamicShape(const std::vector<int64_t> &shape,
+                                                 const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16}) {
+  if (shape.size() < kShape2dDims) {
+    MS_LOG(EXCEPTION) << "Format FRACTAL_NZ_RNN don't support shape with " << shape.size() << " dims";
+  }
+  int64_t input_size = input_hidden_size[0];
+  int64_t hidden_size = input_hidden_size[1];
+  auto dim_last1 = shape[shape.size() - 1];
+  auto dim_last2 = shape[shape.size() - 2];
+  const int64_t NUM16 = 16;
+  const int64_t C0 = SizeToLong(kCubeSize);
+
+  std::vector<int64_t> device_shape = shape;
+  if (dim_last2 == Shape::SHP_ANY) {
+    device_shape[shape.size() - 2] = Shape::SHP_ANY;
+  } else if (dim_last2 == input_size || dim_last2 == hidden_size) {
+    device_shape[shape.size() - 2] = DivCeil(dim_last2, NUM16);
+  } else if (dim_last2 == input_size + hidden_size) {
+    device_shape[shape.size() - 2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
+  } else {
+    MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
+  }
+  if (dim_last1 == Shape::SHP_ANY) {
+    device_shape[shape.size() - 1] = Shape::SHP_ANY;
+  } else {
+    if (dim_last1 % hidden_size != 0) {
+      MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
+    }
+    int64_t n_num = shape[shape.size() - 1] / hidden_size;
+    device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0);
+  }
+  device_shape.push_back(NUM16);
+  device_shape.push_back(C0);
+  return device_shape;
+}
+
+std::vector<size_t> NDRNNBiasDeviceShape(const std::vector<size_t> &shape, const int64_t hidden_size = 16) {
+  if (shape.empty()) {
+    MS_LOG(EXCEPTION) << "Format ND_RNN_BIAS don't support empty shape.";
+  }
+  size_t hid_size = LongToSize(hidden_size);
+  // cppcheck-suppress *
+  if (shape[shape.size() - 1] % hid_size != 0) {
+    MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hid_size;
+  }
+  size_t n_num = shape[shape.size() - 1] / hid_size;
+  const size_t C0 = kCubeSize;
+  std::vector<size_t> device_shape = shape;
+  device_shape[shape.size() - 1] = n_num * DivCeil(hid_size, C0) * C0;
+  return device_shape;
+}
+
+std::vector<int64_t> NDRNNBiasDeviceDynamicShape(const std::vector<int64_t> &shape, const int64_t hidden_size = 16) {
+  if (shape.empty()) {
+    MS_LOG(EXCEPTION) << "Format ND_RNN_BIAS don't support empty shape.";
+  }
+  const int64_t C0 = SizeToLong(kCubeSize);
+  std::vector<int64_t> device_shape = shape;
+  // cppcheck-suppress *
+  auto dim_last1 = shape[shape.size() - 1];
+  if (dim_last1 == Shape::SHP_ANY) {
+    device_shape[shape.size() - 1] = Shape::SHP_ANY;
+  } else {
+    if (dim_last1 % hidden_size != 0) {
+      MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
+    }
+    int64_t n_num = shape[shape.size() - 1] / hidden_size;
+    device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0) * C0;
+  }
+  return device_shape;
+}
 }  // namespace
 
 int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
@@ -723,6 +825,22 @@ int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
   return 1;
 }
 
+std::vector<int64_t> GetAttrInputAndHiddenSize(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
+  if (!node->isa<CNode>()) {
+    return input_hidden_size;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (!AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode) || !AnfAlgo::HasNodeAttr(kAttrInputSize, cnode)) {
+    MS_LOG(EXCEPTION) << "Node with format FRACTAL_ZN_RNN or ND_RNN_BIAS should have hidden_size or input_size attr.";
+  }
+  input_hidden_size[0] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrInputSize);
+  input_hidden_size[1] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrHiddenSize);
+  return input_hidden_size;
+}
+
 bool IsNeedPadding(const std::string &format, const size_t shape_size) {
   if (shape_size == 0) {
     return false;
@@ -820,7 +938,7 @@ void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5
 }
 
 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
-                                       const int64_t groups) {
+                                       const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
   using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
   const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
                                                                     {kOpFormat_NHWC, NhwcDeviceShape},
@@ -843,6 +961,12 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
   if (groups > 1 && format == kOpFormat_FRAC_Z) {
     return FracZDeviceShapeWithGroups(shape, groups);
   }
+  if (format == kOpFormat_FRACTAL_ZN_RNN) {
+    return FracZNRNNDeviceShape(shape, input_hidden_size);
+  }
+  if (format == kOpFormat_ND_RNN_BIAS) {
+    return NDRNNBiasDeviceShape(shape, input_hidden_size[1]);
+  }
   auto temp_shape = shape;
   if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
       shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
@@ -860,7 +984,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
 }
 
 std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format,
-                                        const int64_t groups) {
+                                        const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
   using DeviceShapeTransfer = std::function<std::vector<int64_t>(const std::vector<int64_t> &)>;
   const std::map<std::string, DeviceShapeTransfer> device_shape_map{
     {kOpFormat_NCHW, NchwDeviceDynamicShape},
@@ -884,6 +1008,12 @@ std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const
   if (groups > 1 && format == kOpFormat_FRAC_Z) {
     return FracZDeviceShapeWithGroups(shape, groups);
   }
+  if (format == kOpFormat_FRACTAL_ZN_RNN) {
+    return FracZNRNNDeviceDynamicShape(shape, input_hidden_size);
+  }
+  if (format == kOpFormat_ND_RNN_BIAS) {
+    return NDRNNBiasDeviceDynamicShape(shape, input_hidden_size[1]);
+  }
   auto temp_shape = shape;
   if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
       shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
@@ -31,7 +31,10 @@
 
 namespace mindspore {
 namespace trans {
+constexpr int64_t kAlign16 = 16;
+
 enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims };
+
 enum Axis5D : int {
   N_ncdhw = 0,
   C_ncdhw,
@@ -66,23 +69,30 @@ struct FormatArgs {
 };
 
 int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index);
+std::vector<int64_t> GetAttrInputAndHiddenSize(const AnfNodePtr &node);
 void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
 void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec);
 ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
 bool IsNeedPadding(const std::string &format, const size_t shape_size);
 int64_t GetNodeGroups(const AnfNodePtr &node);
 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
-                                       const int64_t groups = 1);
+                                       const int64_t groups = 1,
+                                       const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16});
 std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format,
-                                        const int64_t groups = 1);
+                                        const int64_t groups = 1,
+                                        const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16});
 template <typename T>
 std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, const AnfNodePtr &node,
-                                  const size_t index) {
+                                  const size_t index, bool is_output = true) {
   int64_t groups = 1;
   if (format == kOpFormat_FRAC_Z) {
     groups = GetAttrGroups(node, index);
   }
-  return TransShapeToDevice(shape, format, groups);
+  std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
+  if (format == kOpFormat_FRACTAL_ZN_RNN || format == kOpFormat_ND_RNN_BIAS) {
+    input_hidden_size = GetAttrInputAndHiddenSize(node);
+  }
+  return TransShapeToDevice(shape, format, groups, input_hidden_size);
 }
 bool TransDataType(const TypeIdArgs &args, void *result);
 bool TransFormat(const FormatArgs &args, void *result, int64_t groups = 1);
@@ -119,6 +119,7 @@ constexpr auto kApplyProximalAdagradOpName = "ApplyProximalAdagrad ";
 constexpr auto kApplyProximalGradientDescentOpName = "ApplyProximalGradientDescent";
 constexpr auto kApplyRMSPropOpName = "ApplyRMSProp";
 constexpr auto kTransDataOpName = "TransData";
+constexpr auto kTransDataRNNOpName = "TransDataRNN";
 constexpr auto kStackInitOpName = "StackInit";
 constexpr auto kStackPushOpName = "StackPush";
 constexpr auto kStackPopOpName = "StackPop";
@@ -460,6 +461,8 @@ constexpr auto kAttrRecursiveEnd = "recursive_end";
 constexpr auto kAttrRecursive = "recursive";
 constexpr auto kAttrMultiCallEnd = "multicall_end";
 constexpr auto kAttrProfilingIterEnd = "PROFILING_ITER_END";
+constexpr auto kAttrHiddenSize = "hidden_size";
+constexpr auto kAttrInputSize = "input_size";
 
 // primal attr key name
 constexpr auto kPrimalAttrForwardNodeName = "forward_node_name";
@@ -566,19 +569,34 @@ constexpr auto kOpFormat_DHWCN = "DHWCN";
 constexpr auto kOpFormat_NDC1HWC0 = "NDC1HWC0";
 constexpr auto kOpFormat_FRACTAL_Z_3D = "FRACTAL_Z_3D";
 constexpr auto kOpFormat_FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM";
+constexpr auto kOpFormat_FRACTAL_ZN_RNN = "FRACTAL_ZN_RNN";
+constexpr auto kOpFormat_ND_RNN_BIAS = "ND_RNN_BIAS";
 
-const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0,
-                                             kOpFormat_ND, kOpFormat_NCHW,
-                                             kOpFormat_NHWC, kOpFormat_HWCN,
-                                             kOpFormat_NC1HWC0, kOpFormat_FRAC_Z,
-                                             kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
-                                             kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04,
-                                             kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM,
-                                             kOpFormat_NDC1HWC0, kOpFormat_NCDHW,
-                                             kOpFormat_FRACTAL_Z_3D, kOpFormat_DHWNC,
+const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT,
+                                             kOpFormat_NC1KHKWHWC0,
+                                             kOpFormat_ND,
+                                             kOpFormat_NCHW,
+                                             kOpFormat_NHWC,
+                                             kOpFormat_HWCN,
+                                             kOpFormat_NC1HWC0,
+                                             kOpFormat_FRAC_Z,
+                                             kOpFormat_C1HWNCoC0,
+                                             kOpFormat_FRAC_NZ,
+                                             kOpFormat_NC1HWC0_C04,
+                                             kOpFormat_FRACTAL_Z_C04,
+                                             kOpFormat_NDHWC,
+                                             kOpFormat_FRACTAL_ZN_LSTM,
+                                             kOpFormat_FRACTAL_ZN_RNN,
+                                             kOpFormat_ND_RNN_BIAS,
+                                             kOpFormat_NDC1HWC0,
+                                             kOpFormat_NCDHW,
+                                             kOpFormat_FRACTAL_Z_3D,
+                                             kOpFormat_DHWNC,
                                              kOpFormat_DHWCN};
 
 const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
                                                         kOpFormat_NCDHW};
 
 const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
                                                kApplyMomentumOpName,
                                                kApplyAdadeltaOpName,
@@ -625,8 +643,9 @@ const std::set<std::string> kOpNotSupportMultiThreadExecList = {kAvgPoolOpName,
                                                                 kBatchNorm, kBatchNormGradOpName};
 
 const std::set<std::string> kHWSpecialFormatSet = {
-  kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0,
-  kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z};
+  kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ,
+  kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM,
+  kOpFormat_FRACTAL_ZN_RNN, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z};
 
 const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
 
@@ -637,7 +656,8 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccid
 const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
                                             kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
 
-const std::set<std::string> kNoPaddingFormatSet = {kOpFormat_ChannelLast, kOpFormat_FRAC_NZ};
+const std::set<std::string> kNoPaddingFormatSet = {kOpFormat_ChannelLast, kOpFormat_FRAC_NZ, kOpFormat_FRACTAL_ZN_RNN,
+                                                   kOpFormat_ND_RNN_BIAS};
 
 const std::set<std::string> DynamicShapeConstInputToAttr = {
   kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceMinOpName,
@@ -205,6 +205,7 @@ inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>(kTile);
 inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
 inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2");
 inline const PrimitivePtr kPrimTransData = std::make_shared<Primitive>("TransData");
+inline const PrimitivePtr kPrimTransDataRNN = std::make_shared<Primitive>("TransDataRNN");
 inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
 inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
 inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
@@ -92,6 +92,7 @@ from .tensor_add import _tensor_add_tbe
 from .tensor_add_ds import _tensor_add_ds_tbe
 from .trans_data import _trans_data_tbe
 from .trans_data_ds import _trans_data_ds_tbe
+from .trans_data_rnn import _trans_data_rnn_tbe
 from .top_k import _top_k_tbe
 from .matmul import _matmul_tbe
 from .matmul_ds import _matmul_ds_tbe
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -53,22 +53,22 @@ dynamic_rnn_op_info = TBERegOp("DynamicRNN") \
     .output(5, "f", False, "required", "all") \
     .output(6, "o", False, "required", "all") \
     .output(7, "tanhc", False, "required", "all") \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.None_Default,
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNRNN, DataType.F32_ND_RNNBIAS, DataType.None_Default,
                   DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
                   DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
                   DataType.F32_FracNZ, DataType.F32_FracNZ) \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F16_Default, DataType.None_Default,
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNRNN, DataType.F16_ND_RNNBIAS, DataType.None_Default,
                   DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.F16_FracNZ) \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.I32_Default,
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNRNN, DataType.F32_ND_RNNBIAS, DataType.I32_Default,
                   DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
                   DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
                   DataType.F32_FracNZ, DataType.F32_FracNZ) \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F16_Default, DataType.I32_Default,
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNRNN, DataType.F16_ND_RNNBIAS, DataType.I32_Default,
                   DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
@@ -0,0 +1,44 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""TransDataRNN op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+trans_data_rnn_op_info = TBERegOp("TransDataRNN") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("trans_data_rnn.so") \
+    .compute_cost(10) \
+    .kernel_name("trans_data_rnn") \
+    .partial_flag(True) \
+    .attr("src_format", "required", "str", "FRACTAL_ZN_RNN, ND_RNN_BIAS") \
+    .attr("dst_format", "required", "str", "FRACTAL_ZN_RNN, ND_RNN_BIAS") \
+    .attr("input_size", "required", "int", "all") \
+    .attr("hidden_size", "required", "int", "all") \
+    .input(0, "src", False, "required", "all") \
+    .output(0, "dst", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_FracZNRNN) \
+    .dtype_format(DataType.F16_FracZNRNN, DataType.F16_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_ND_RNNBIAS) \
+    .dtype_format(DataType.F16_ND_RNNBIAS, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_ND_RNNBIAS) \
+    .dtype_format(DataType.F32_ND_RNNBIAS, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(trans_data_rnn_op_info)
+def _trans_data_rnn_tbe():
+    """TransDataRNN TBE register"""
+    return
@@ -721,6 +721,8 @@ class DataType:
     F16_NDC1HWC0 = ("float16", "NDC1HWC0")
     F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
     F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
+    F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN")
+    F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS")
    F16_ChannelLast = ("float16", "ChannelLast")
 
     F32_None = ("float32", "")
@@ -738,6 +740,8 @@ class DataType:
     F32_NDC1HWC0 = ("float32", "NDC1HWC0")
     F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
     F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
+    F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN")
+    F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS")
     F32_ChannelLast = ("float32", "ChannelLast")
 
     F64_None = ("float64", "")
@@ -878,6 +882,8 @@ class DataType:
     F16_NDC1HWC0 = ("float16", "NDC1HWC0")
     F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
     F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
+    F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN")
+    F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS")
     F16_ChannelLast = ("float16", "ChannelLast")
 
     F32_None = ("float32", "")
@@ -895,6 +901,8 @@ class DataType:
     F32_NDC1HWC0 = ("float32", "NDC1HWC0")
     F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
     F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
+    F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN")
+    F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS")
     F32_ChannelLast = ("float32", "ChannelLast")
 
     F64_None = ("float64", "")
@@ -7404,9 +7404,6 @@ class DynamicRNN(PrimitiveWithInfer):
     are learnable weights between the output and the input in the formula. For instance,
     :math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`.
 
-    Note:
-        The `hidden_size` in shape of inputs must be multiple of 16.
-
     Args:
         cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'.
             Only 'LSTM' is currently supported.
@@ -7534,6 +7531,8 @@ class DynamicRNN(PrimitiveWithInfer):
         validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)
         self.placeholder_index = [3]
         self.add_prim_attr("placeholder_index", self.placeholder_index)
+        self.add_prim_attr("input_size", input_size)
+        self.add_prim_attr("hidden_size", hidden_size)
         y_shape = (num_step, batch_size, hidden_size)
         return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape
 
@@ -30,14 +30,10 @@ class FormatTransTest : public UT::Common {
 };
 
 TEST_F(FormatTransTest, nchw_to_hwcn) {
-  uint16_t data[2*2*2*2] = {12581,14220,14937,14302,
-                            15004,14951,14694,14564,
-                            14069,14554,10507,14787,
-                            13016,15263,14872,10838};
-  uint16_t res[2*2*2*2] = {12581,14069,15004,13016,
-                           14220,14554,14951,15263,
-                           14937,10507,14694,14872,
-                           14302,14787,14564,10838};
+  uint16_t data[2 * 2 * 2 * 2] = {12581, 14220, 14937, 14302, 15004, 14951, 14694, 14564,
+                                  14069, 14554, 10507, 14787, 13016, 15263, 14872, 10838};
+  uint16_t res[2 * 2 * 2 * 2] = {12581, 14069, 15004, 13016, 14220, 14554, 14951, 15263,
+                                 14937, 10507, 14694, 14872, 14302, 14787, 14564, 10838};
   size_t device_size = 32;
   auto trans_tmp = std::vector<uint8_t>(device_size);
   FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_HWCN,
@@ -49,15 +45,11 @@ TEST_F(FormatTransTest, nchw_to_hwcn) {
 }
 
 TEST_F(FormatTransTest, hwcn_to_nchw) {
-  uint16_t data[2*2*2*2] = {12581,14069,15004,13016,
-                            14220,14554,14951,15263,
-                            14937,10507,14694,14872,
-                            14302,14787,14564,10838};
-
-  uint16_t res[2*2*2*2] = {12581,14220,14937,14302,
-                           15004,14951,14694,14564,
-                           14069,14554,10507,14787,
-                           13016,15263,14872,10838};
+  uint16_t data[2 * 2 * 2 * 2] = {12581, 14069, 15004, 13016, 14220, 14554, 14951, 15263,
+                                  14937, 10507, 14694, 14872, 14302, 14787, 14564, 10838};
+  uint16_t res[2 * 2 * 2 * 2] = {12581, 14220, 14937, 14302, 15004, 14951, 14694, 14564,
+                                 14069, 14554, 10507, 14787, 13016, 15263, 14872, 10838};
 
   size_t device_size = 32;
   auto trans_tmp = std::vector<uint8_t>(device_size);
@@ -70,14 +62,10 @@ TEST_F(FormatTransTest, hwcn_to_nchw) {
 }
 
 TEST_F(FormatTransTest, nchw_to_nhwc) {
-  uint16_t data[2*2*2*2] = {11750,13778,15007,15321,
-                            15163,13446,15063,14467,
-                            15056,13284,15219,14797,
-                            12684,14288,14855,14799};
-  uint16_t res[2*2*2*2] = {11750,15163,13778,13446,
-                           15007,15063,15321,14467,
-                           15056,12684,13284,14288,
-                           15219,14855,14797,14799};
+  uint16_t data[2 * 2 * 2 * 2] = {11750, 13778, 15007, 15321, 15163, 13446, 15063, 14467,
+                                  15056, 13284, 15219, 14797, 12684, 14288, 14855, 14799};
+  uint16_t res[2 * 2 * 2 * 2] = {11750, 15163, 13778, 13446, 15007, 15063, 15321, 14467,
+                                 15056, 12684, 13284, 14288, 15219, 14855, 14797, 14799};
   size_t device_size = 32;
   auto trans_tmp = std::vector<uint8_t>(device_size);
   FormatArgs format_args{data, device_size, kOpFormat_NCHW, kOpFormat_NHWC,
@@ -87,15 +75,12 @@ TEST_F(FormatTransTest, nchw_to_nhwc) {
     EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]);
   }
 }
 
 TEST_F(FormatTransTest, nhwc_to_nchw) {
-  uint16_t data[2*2*2*2] = {11750,15163,13778,13446,
-                            15007,15063,15321,14467,
-                            15056,12684,13284,14288,
-                            15219,14855,14797,14799};
-  uint16_t res[2*2*2*2] = {11750,13778,15007,15321,
-                           15163,13446,15063,14467,
-                           15056,13284,15219,14797,
-                           12684,14288,14855,14799};
+  uint16_t data[2 * 2 * 2 * 2] = {11750, 15163, 13778, 13446, 15007, 15063, 15321, 14467,
+                                  15056, 12684, 13284, 14288, 15219, 14855, 14797, 14799};
+  uint16_t res[2 * 2 * 2 * 2] = {11750, 13778, 15007, 15321, 15163, 13446, 15063, 14467,
+                                 15056, 13284, 15219, 14797, 12684, 14288, 14855, 14799};
 
   size_t device_size = 32;
   auto trans_tmp = std::vector<uint8_t>(device_size);
@@ -106,8 +91,60 @@ TEST_F(FormatTransTest, nhwc_to_nchw) {
     EXPECT_EQ((reinterpret_cast<uint16_t *>(trans_tmp.data()))[i], res[i]);
   }
 }
 
+class ShapeTransTest : public UT::Common {
+ public:
+  ShapeTransTest() = default;
+  void SetUp() override {}
+  void TearDown() override {}
+};
+
+TEST_F(ShapeTransTest, fraczn_rnn_device_shape) {
+  std::vector<size_t> host_shape = {43, 120};
+  std::string format = kOpFormat_FRACTAL_ZN_RNN;
+  std::vector<int64_t> input_hidden_size = {13, 30};
+  auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
+  const std::vector<size_t> expect_shape = {3, 8, 16, 16};
+  EXPECT_EQ(trans_shape.size(), expect_shape.size());
+  for (size_t i = 0; i < expect_shape.size(); i++) {
+    EXPECT_EQ(trans_shape[i], expect_shape[i]);
+  }
+}
+
+TEST_F(ShapeTransTest, nd_rnn_bias_device_shape) {
+  std::vector<size_t> host_shape = {120};
+  std::string format = kOpFormat_ND_RNN_BIAS;
+  std::vector<int64_t> input_hidden_size = {13, 30};
+  auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
+  std::vector<size_t> expect_shape = {128};
+  EXPECT_EQ(trans_shape.size(), expect_shape.size());
+  for (size_t i = 0; i < expect_shape.size(); i++) {
+    EXPECT_EQ(trans_shape[i], expect_shape[i]);
+  }
+}
+
+TEST_F(ShapeTransTest, fraczn_rnn_dynamic_device_shape) {
+  std::vector<int64_t> host_shape = {-1, -1};
+  std::string format = kOpFormat_FRACTAL_ZN_RNN;
+  std::vector<int64_t> input_hidden_size = {13, 30};
+  auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
+  const std::vector<int64_t> expect_shape = {-1, -1, 16, 16};
+  EXPECT_EQ(trans_shape.size(), expect_shape.size());
+  for (size_t i = 0; i < expect_shape.size(); i++) {
+    EXPECT_EQ(trans_shape[i], expect_shape[i]);
+  }
+}
+
+TEST_F(ShapeTransTest, nd_rnn_bias_dynamic_device_shape) {
+  std::vector<int64_t> host_shape = {-1};
+  std::string format = kOpFormat_ND_RNN_BIAS;
+  std::vector<int64_t> input_hidden_size = {13, 30};
+  auto trans_shape = trans::TransShapeToDevice(host_shape, format, 1, input_hidden_size);
+  std::vector<int64_t> expect_shape = {-1};
+  EXPECT_EQ(trans_shape.size(), expect_shape.size());
+  for (size_t i = 0; i < expect_shape.size(); i++) {
+    EXPECT_EQ(trans_shape[i], expect_shape[i]);
+  }
+}
 } // namespace trans
 } // namespace mindspore
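For orientation, this is roughly the Python-side scenario the commit enables: a DynamicRNN whose hidden_size is not a multiple of 16. It is a minimal sketch, not part of the commit; shapes follow the documented DynamicRNN input layout, dtypes are illustrative, and running it requires an Ascend backend.

import numpy as np
from mindspore import Tensor, ops

num_step, batch_size, input_size, hidden_size = 2, 16, 64, 30  # hidden_size deliberately not 16-aligned
x = Tensor(np.random.rand(num_step, batch_size, input_size).astype(np.float16))
w = Tensor(np.random.rand(input_size + hidden_size, 4 * hidden_size).astype(np.float16))
b = Tensor(np.random.rand(4 * hidden_size).astype(np.float16))
init_h = Tensor(np.random.rand(1, batch_size, hidden_size).astype(np.float16))
init_c = Tensor(np.random.rand(1, batch_size, hidden_size).astype(np.float16))

dynamic_rnn = ops.DynamicRNN()
# Eight outputs (y, h, c, i, j, f, o, tanhc), each of shape (num_step, batch_size, hidden_size).
outputs = dynamic_rnn(x, w, b, None, init_h, init_c)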