!18361 code check clean

Merge pull request !18361 from yuchaojie/code-clean
i-robot 2021-06-16 21:07:01 +08:00 committed by Gitee
commit 1c991331b9
44 changed files with 298 additions and 309 deletions
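The hunks below all follow the same code-check cleanup: literal input indices (e.g. cnode->input(1)) and per-file constants such as kInputIndex1/kInputIndex2, kElementIndex1, and DIM0..DIM4 are replaced with shared named constants (kIndex0.., kDim0.., kCubeSize), a few string attribute keys become named constants (kAttrAxis, kAttrShape), and several MS_EXCEPTION_IF_NULL checks are added. A minimal, self-contained sketch of the pattern follows; the constant definitions and the FakeCNode helper are assumptions for illustration only, not the actual MindSpore sources.

// Sketch of the cleanup pattern (illustration only): magic index numbers and
// per-file constexpr declarations give way to shared named constants, which
// in the real code are assumed to come from a common header.
#include <cstddef>
#include <iostream>
#include <vector>

// Assumed equivalents of the shared index/size constants used in the diff.
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kCubeSize = 16;

// Stand-in for a CNode: input(0) is the primitive, real operands start at 1.
struct FakeCNode {
  std::vector<int> inputs;
  int input(size_t i) const { return inputs.at(i); }
};

int main() {
  FakeCNode cnode{{/*primitive*/ -1, /*operand 1*/ 10, /*operand 2*/ 20}};

  // Before: magic numbers repeated in every fusion pass.
  int old_first = cnode.input(1);
  int old_second = cnode.input(2);

  // After: the same lookups through the shared named constants.
  int new_first = cnode.input(kIndex1);
  int new_second = cnode.input(kIndex2);

  // The diff also rewrites hard-coded round-up arithmetic:
  // (x + 15) / 16 * 16  ->  (x + kCubeSize - 1) / kCubeSize * kCubeSize,
  // i.e. round x up to the next multiple of the cube size (18 -> 32).
  size_t padded = (18 + kCubeSize - 1) / kCubeSize * kCubeSize;

  std::cout << (old_first == new_first) << (old_second == new_second) << " "
            << padded << std::endl;  // the renaming does not change behavior
  return 0;
}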

View File

@ -34,7 +34,7 @@ void BatchMatmulFusedMulAddFusionPass::MatchBatchMatmulFusedMulAdd(const CNodePt
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
auto batch_matmul = cnode->input(2);
auto batch_matmul = cnode->input(kIndex2);
MS_EXCEPTION_IF_NULL(batch_matmul);
if (batch_matmul->isa<CNode>() && AnfAlgo::CheckPrimitiveType(batch_matmul, prim::kPrimBatchMatMul)) {
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[batch_matmul].size())};

View File

@ -31,8 +31,6 @@ namespace opt {
namespace {
constexpr size_t kEltwiseInputSize = 2;
constexpr size_t kEltwiseOutputSize = 2;
constexpr size_t kInputIndex1 = 1;
constexpr size_t kInputIndex2 = 2;
bool CheckEltwiseInputAndOutputSize(const AnfNodePtr &node) {
if (AnfAlgo::GetInputTensorNum(node) == kEltwiseInputSize) {
return true;
@ -54,7 +52,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
MS_EXCEPTION_IF_NULL(relu_input);
auto add = relu_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add);
auto tuple_getitem = add->input(1);
auto tuple_getitem = add->input(kIndex1);
std::vector<int64_t> add_output_used_num;
add_output_used_num.emplace_back(SizeToLong(manager->node_users()[add].size()));
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(add_output_used_num), add);
@ -62,7 +60,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) {
auto getitem = tuple_getitem->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem);
auto bnupdate = getitem->input(kInputIndex1);
auto bnupdate = getitem->input(kIndex1);
MS_EXCEPTION_IF_NULL(bnupdate);
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
@ -73,7 +71,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
}
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
auto input2 = out_getitem_ptr->input(kInputIndex2);
auto input2 = out_getitem_ptr->input(kIndex2);
auto output_idx = GetValue<int64_t>(GetValueNode(input2));
output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
}
@ -98,7 +96,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) {
auto eltwise_input = cnode->input(1);
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimAdd)) {
MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);

View File

@ -27,7 +27,6 @@
namespace mindspore {
namespace opt {
constexpr size_t INPUT2 = 2;
void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr &cnode, const AnfNodePtr &eltwise_input,
const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
@ -38,7 +37,7 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr
MS_EXCEPTION_IF_NULL(eltwise_input);
auto getitem = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem);
auto bnupdate = getitem->input(1);
auto bnupdate = getitem->input(kIndex1);
MS_EXCEPTION_IF_NULL(bnupdate);
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
@ -49,7 +48,7 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr
}
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
auto input2 = out_getitem_ptr->input(INPUT2);
auto input2 = out_getitem_ptr->input(kIndex2);
auto output_idx = GetValue<int64_t>(GetValueNode(input2));
output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
}

View File

@ -26,15 +26,12 @@
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kInputIndex2 = 2;
}
void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltwise(
const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);
@ -45,7 +42,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
MS_EXCEPTION_IF_NULL(manager);
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
auto double_in_eltwise_input = input_cnode->input(kInputIndex2);
auto double_in_eltwise_input = input_cnode->input(kIndex2);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
std::vector<int64_t> conv2d_bp_output_used_num;
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input)) {
@ -59,7 +56,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
conv2d_bp_output_used_num.emplace_back(SizeToLong(manager->node_users()[double_in_eltwise_input].size()));
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv2d_bp_output_used_num), double_in_eltwise_input);
} else {
auto double_in_eltwise_input_1 = input_cnode->input(1);
auto double_in_eltwise_input_1 = input_cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input_1);
if (!double_in_eltwise_input_1->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input_1)) {
return;

View File

@ -32,7 +32,7 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {

View File

@ -34,7 +34,7 @@ void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const sess
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
auto conv = cnode->input(1);
auto conv = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(conv);
if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) {
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[conv].size())};

View File

@ -31,7 +31,7 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);
@ -40,7 +40,7 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
}
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
auto double_in_eltwise_input = input_cnode->input(1);
auto double_in_eltwise_input = input_cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) {

View File

@ -31,12 +31,12 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
auto eltwise_input = cnode->input(kIndex1);
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
eltwise_input = input_cnode->input(1);
eltwise_input = input_cnode->input(kIndex1);
if (record.size() == MAX_ELTWISE_NUM) {
break;
}

View File

@ -37,7 +37,7 @@ void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnod
MS_EXCEPTION_IF_NULL(manager);
if (is_order) {
// DepthwiseConvolution--->Elemwise
auto depthwise_conv = cnode->input(1);
auto depthwise_conv = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(depthwise_conv);
if (cnode->isa<CNode>() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) {
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[depthwise_conv].size())};
@ -48,7 +48,7 @@ void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnod
}
} else {
// Elemwise-->DepthwiseConvolution
auto relu = cnode->input(1);
auto relu = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(relu);
if (cnode->isa<CNode>() && (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) {
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[relu].size())};
@ -73,7 +73,7 @@ void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::Ker
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
auto eltwise_input = cnode->input(1);
auto eltwise_input = cnode->input(kIndex1);
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) {
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
}

View File

@ -31,7 +31,7 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);
@ -40,7 +40,7 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne
}
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
eltwise_input = input_cnode->input(1);
eltwise_input = input_cnode->input(kIndex1);
}
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);

View File

@ -34,7 +34,7 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
auto matmul = cnode->input(1);
auto matmul = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(matmul);
if (matmul->isa<CNode>() && (AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) ||
AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimBatchMatMul))) {

View File

@ -56,7 +56,7 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
auto eltwise_input = cnode->input(1);
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) {
MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);

View File

@ -33,7 +33,7 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (CheckMultiOutputEltWiseNode(kernel_graph, eltwise_input)) {
std::vector<int64_t> output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())};
@ -41,7 +41,7 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
(void)record.insert(eltwise_input);
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
eltwise_input = input_cnode->input(1);
eltwise_input = input_cnode->input(kIndex1);
} else {
return;
}
@ -52,7 +52,7 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
}
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
eltwise_input = input_cnode->input(1);
eltwise_input = input_cnode->input(kIndex1);
}
if (record.size() != MULTI_ELTWISE_SIZE) {
return;

View File

@ -32,12 +32,12 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
auto eltwise_input = cnode->input(kIndex1);
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
eltwise_input = input_cnode->input(1);
eltwise_input = input_cnode->input(kIndex1);
if (record.size() == MAX_ELTWISE_NUM) {
break;
}
@ -52,13 +52,13 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
(void)record.insert(eltwise_input);
auto previous_input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(previous_input_cnode);
auto previous_eltwise_input = previous_input_cnode->input(1);
auto previous_eltwise_input = previous_input_cnode->input(kIndex1);
auto previous_size = record.size();
while (CheckEltWiseNode(kernel_graph, previous_eltwise_input)) {
(void)record.insert(previous_eltwise_input);
auto previous_node = previous_eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(previous_node);
previous_eltwise_input = previous_node->input(1);
previous_eltwise_input = previous_node->input(kIndex1);
if (record.size() - previous_size == MAX_ELTWISE_NUM) {
break;
}

View File

@ -31,12 +31,12 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
auto eltwise_input = cnode->input(kIndex1);
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
eltwise_input = input_cnode->input(1);
eltwise_input = input_cnode->input(kIndex1);
if (record.size() == MAX_ELTWISE_NUM) {
break;
}
@ -51,13 +51,13 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
(void)record.insert(eltwise_input);
auto previous_input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(previous_input_cnode);
auto previous_eltwise_input = previous_input_cnode->input(1);
auto previous_eltwise_input = previous_input_cnode->input(kIndex1);
auto previous_size = record.size();
while (CheckEltWiseNode(kernel_graph, previous_eltwise_input)) {
(void)record.insert(previous_eltwise_input);
auto previous_node = previous_eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(previous_node);
previous_eltwise_input = previous_node->input(1);
previous_eltwise_input = previous_node->input(kIndex1);
if (record.size() - previous_size == MAX_ELTWISE_NUM) {
break;
}

View File

@ -34,12 +34,12 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::unordered_set<AnfNodePtr> record{cnode};
auto write_input = cnode->input(1);
auto write_input = cnode->input(kIndex1);
if (CheckEltWiseNode(kernel_graph, write_input)) {
(void)record.insert(write_input);
auto input_cnode = write_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
write_input = input_cnode->input(1);
write_input = input_cnode->input(kIndex1);
}
MS_EXCEPTION_IF_NULL(write_input);
if (!write_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(write_input) ||
@ -53,7 +53,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
conv_cnode->inputs().size() >= CONV_DOUBLE_IN_INPUT_SIZE &&
conv_cnode->inputs().size() <= CONV_QUART_IN_INPUT_SIZE) {
(void)record.insert(write_input);
auto conv_input = conv_cnode->input(1);
auto conv_input = conv_cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(conv_input);
if (!conv_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(conv_input) ||
fusion_id_allocator->HasFusionIdAttr(conv_input)) {

View File

@ -39,8 +39,6 @@ const int8_t ELTWISE_USE = 1;
const int8_t MULTI_ELTWISE_USE = 2;
const int8_t MAX_MULTI_ELTWISE_SIZE = 4;
const int8_t MAX_PURE_BUFFER_SUCC_SIZE = 3;
constexpr size_t kInputIndex1 = 1;
constexpr size_t kInputIndex2 = 2;
constexpr size_t kFusionNodeNumThreshold = 2;
constexpr auto kOpAttrFusionId = "fusion_id";
@ -132,12 +130,10 @@ kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr
if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
auto tuple_getitem = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem);
outputs_format.emplace_back(
AnfAlgo::GetOutputFormat(tuple_getitem->input(kInputIndex1),
LongToSize(GetValue<int64_t>(GetValueNode(tuple_getitem->input(kInputIndex2))))));
outputs_format.emplace_back(AnfAlgo::GetOutputFormat(
tuple_getitem->input(kIndex1), LongToSize(GetValue<int64_t>(GetValueNode(tuple_getitem->input(kIndex2))))));
outputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(
tuple_getitem->input(kInputIndex1),
LongToSize(GetValue<int64_t>(GetValueNode(tuple_getitem->input(kInputIndex2))))));
tuple_getitem->input(kIndex1), LongToSize(GetValue<int64_t>(GetValueNode(tuple_getitem->input(kIndex2))))));
} else {
outputs_format.emplace_back(AnfAlgo::GetOutputFormat(output, 0));
outputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(output, 0));
@ -190,6 +186,7 @@ void ReplaceInputNodeInOtherFusionScope(std::unordered_map<int64_t, BufferFusion
void ReplaceOldNode(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos, int64_t fusion_id,
const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
auto manager = kernel_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
@ -275,8 +272,8 @@ bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) {
MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1["
<< getitem2->DebugString() << "]";
}
auto output_idx1 = GetValue<int64_t>(GetValueNode(getitem1->input(kInputIndex2)));
auto output_idx2 = GetValue<int64_t>(GetValueNode(getitem2->input(kInputIndex2)));
auto output_idx1 = GetValue<int64_t>(GetValueNode(getitem1->input(kIndex2)));
auto output_idx2 = GetValue<int64_t>(GetValueNode(getitem2->input(kIndex2)));
return output_idx1 < output_idx2;
}
@ -311,7 +308,8 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
for (auto &getitem : tuple_getitem_nodes) {
MS_EXCEPTION_IF_NULL(getitem);
auto getitem_ptr = getitem->cast<CNodePtr>();
auto input2 = getitem_ptr->input(kInputIndex2);
MS_EXCEPTION_IF_NULL(getitem_ptr);
auto input2 = getitem_ptr->input(kIndex2);
auto output_idx = GetValue<int64_t>(GetValueNode(input2));
for (int64_t stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) {
auto stub_node = CreateTupleGetItem(node, kernel_graph, LongToSize(stub_idx));
@ -343,7 +341,7 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A
auto real_output = AnfAlgo::VisitKernel(output, 0);
auto output_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
auto input2 = output_cnode->input(kInputIndex2);
auto input2 = output_cnode->input(kIndex2);
auto output_idx = GetValue<int64_t>(GetValueNode(input2));
session::AnfWithOutIndex out_pair(real_output.first, output_idx);
if (kernel_graph->IsInRefOutputMap(out_pair)) {
@ -398,6 +396,7 @@ void RemoveCircle(const session::KernelGraph &kernel_graph,
void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) const {
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
MS_EXCEPTION_IF_NULL(kernel_graph);
GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos);
GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos);
GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);

View File

@ -27,9 +27,6 @@ namespace {
const std::vector<int64_t> kOutputIndex{0, 3, 4, 5};
constexpr size_t kBatchNormRealOutputNum = 3;
constexpr size_t kBatchNormRealInputNum = 3;
constexpr size_t kInputIndex2 = 2;
constexpr size_t kInputIndex3 = 3;
constexpr size_t kInputIndex4 = 4;
bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
MS_EXCEPTION_IF_NULL(func_graph);
@ -72,12 +69,12 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
<< (kBatchNormRealInputNum + 1) << " trace: " << trace::DumpSourceLines(bn);
}
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)};
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(kIndex1)};
auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
MS_EXCEPTION_IF_NULL(bn_training_reduce);
auto bn_input1 = bn_cnode->input(kInputIndex2);
auto bn_input1 = bn_cnode->input(kIndex2);
MS_EXCEPTION_IF_NULL(bn_input1);
auto bn_input2 = bn_cnode->input(kInputIndex3);
auto bn_input2 = bn_cnode->input(kIndex3);
MS_EXCEPTION_IF_NULL(bn_input2);
AbstractBasePtrList abstract_list{bn_input1->abstract(), bn_input2->abstract()};
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
@ -104,11 +101,11 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
}
std::vector<AnfNodePtr> bn_training_update_v2_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateV2OpName)),
bn_cnode->input(1),
bn_training_reduce_outputs[0],
bn_training_reduce_outputs[1],
bn_cnode->input(kInputIndex2),
bn_cnode->input(kInputIndex3)};
bn_cnode->input(kIndex1),
bn_training_reduce_outputs[kIndex0],
bn_training_reduce_outputs[kIndex1],
bn_cnode->input(kIndex2),
bn_cnode->input(kIndex3)};
auto bn_training_update_v2 = func_graph->NewCNode(bn_training_update_v2_inputs);
MS_EXCEPTION_IF_NULL(bn_training_update_v2);
@ -118,9 +115,9 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is "
<< bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn);
}
std::vector<AbstractBasePtr> abstract_list{bn_abstract_tuple->elements()[0],
bn_abstract_tuple->elements()[kInputIndex3],
bn_abstract_tuple->elements()[kInputIndex4]};
std::vector<AbstractBasePtr> abstract_list{bn_abstract_tuple->elements()[kIndex0],
bn_abstract_tuple->elements()[kIndex3],
bn_abstract_tuple->elements()[kIndex4]};
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
bn_training_update_v2->set_abstract(abstract_tuple);
bn_training_update_v2->set_scope(bn->scope());

View File

@ -23,8 +23,6 @@ namespace mindspore {
namespace opt {
namespace {
constexpr size_t kBatchNormGradInferOutputNum = 3;
constexpr size_t kElementIndex1 = 1;
constexpr size_t kElementIndex2 = 2;
bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
@ -136,8 +134,8 @@ AnfNodePtr BatchNormGradInferFission::CreateBNTrainingUpdateGrad(const FuncGraph
MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be less than 3"
<< trace::DumpSourceLines(bn_grad);
}
std::vector<AbstractBasePtr> abstract_list{bn_grad_abstract_tuple->elements()[kElementIndex1],
bn_grad_abstract_tuple->elements()[kElementIndex2]};
std::vector<AbstractBasePtr> abstract_list{bn_grad_abstract_tuple->elements()[kIndex1],
bn_grad_abstract_tuple->elements()[kIndex2]};
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
bn_training_update_grad->set_abstract(abstract_tuple);
AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_training_update_grad);

View File

@ -28,13 +28,6 @@
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
constexpr size_t kIndex4 = 4;
constexpr size_t kIndex5 = 5;
void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
std::vector<AnfNodePtr> *bn_update_grad_outputs) {
MS_EXCEPTION_IF_NULL(graph);

View File

@ -38,7 +38,7 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
new_simoid_inputs.insert(new_simoid_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
CNodePtr new_cnode = func_graph->NewCNode(new_simoid_inputs);
MS_EXCEPTION_IF_NULL(new_cnode);
auto predict_input = cnode->inputs()[1];
auto predict_input = cnode->inputs()[kIndex1];
auto new_node_dtype = {AnfAlgo::GetOutputInferDataType(predict_input, 0)};
auto new_node_shape = {AnfAlgo::GetOutputInferShape(predict_input, 0)};
AnfAlgo::SetOutputInferTypeAndShape(new_node_dtype, new_node_shape, new_cnode.get());

View File

@ -29,12 +29,6 @@
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
constexpr size_t kIndex4 = 4;
constexpr size_t kIndex5 = 5;
void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
std::vector<AnfNodePtr> *bn_update_grad_outputs) {
MS_EXCEPTION_IF_NULL(graph);

View File

@ -32,12 +32,6 @@ namespace opt {
namespace {
constexpr auto kReduceOpSum = "sum";
constexpr auto kDeviceNum = "device_num";
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
constexpr size_t kIndex4 = 4;
constexpr size_t kIndex5 = 5;
constexpr size_t kPositionOffset = 3;
constexpr int64_t kFusionNumThreshold = 2;
@ -51,7 +45,7 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
}
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName))};
bn_training_reduce_inputs.push_back(bn_cnode->input(1));
bn_training_reduce_inputs.push_back(bn_cnode->input(kIndex1));
auto bn_training_reduce = graph->NewCNode(bn_training_reduce_inputs);
MS_EXCEPTION_IF_NULL(bn_training_reduce);
auto kernel_info = std::make_shared<device::KernelInfo>();
@ -62,7 +56,7 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
MS_LOG(INFO) << "The BatchNorm's first input's shape dims less than " << kShape2dDims;
return false;
}
std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[1]};
std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[kDim1]};
auto types = {kNumberTypeFloat32, kNumberTypeFloat32};
auto shapes = {bn_training_reduce_shape, bn_training_reduce_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_training_reduce.get());
@ -191,6 +185,8 @@ AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const
}
AnfNodePtr InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &input, const TypeId dst_type) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input);
if (AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) {
AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input});
AnfAlgo::SetOutputInferTypeAndShape({dst_type}, {AnfAlgo::GetOutputInferShape(input, 0)}, cast.get());

View File

@ -31,22 +31,6 @@ constexpr size_t kGRUV2HiddenGradOutputNum = 3;
constexpr size_t kConcatNum = 2;
constexpr size_t kGateNum = 3;
constexpr size_t k3Dims = 3;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
constexpr size_t kIndex4 = 4;
constexpr size_t kIndex5 = 5;
constexpr size_t kIndex6 = 6;
constexpr size_t kIndex7 = 7;
constexpr size_t kIndex8 = 8;
constexpr size_t kIndex9 = 9;
constexpr size_t kIndex10 = 10;
constexpr size_t kIndex11 = 11;
constexpr size_t kIndex12 = 12;
constexpr size_t DIM0 = 0;
constexpr size_t DIM1 = 1;
constexpr size_t DIM2 = 2;
AnfNodePtr CreateGRUV2HiddenGradNode(const FuncGraphPtr &graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
@ -74,9 +58,9 @@ AnfNodePtr CreateGRUV2HiddenGradNode(const FuncGraphPtr &graph, const AnfNodePtr
auto types = {h_dtype, h_dtype, h_dtype};
std::vector<size_t> dh_preh_shape = AnfAlgo::GetOutputInferShape(ori_outputs[kIndex5], 0);
std::vector<size_t> dgate_h_shape = {
AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[DIM0],
AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[DIM1],
kGateNum * AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[DIM2]};
AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[kDim0],
AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[kDim1],
kGateNum * AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[kDim2]};
std::vector<size_t> dnx_t_shape = AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0);
auto shapes = {dh_preh_shape, dgate_h_shape, dnx_t_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, gru_v2_hidden_grad_op.get());
@ -93,14 +77,14 @@ AnfNodePtr CreateHSplitVDNode(const FuncGraphPtr &graph, const AnfNodePtr &node)
auto split_vd = graph->NewCNode(splitvd_input);
MS_EXCEPTION_IF_NULL(split_vd);
auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)};
size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[DIM0];
size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[DIM1];
size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[DIM2];
size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[kDim0];
size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[kDim1];
size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[kDim2];
std::vector<size_t> shape = {t_size - IntToSize(1), batch, hidden_size};
std::vector<size_t> shape2 = {IntToSize(1), batch, hidden_size};
std::vector<std::vector<size_t>> shapes = {shape, shape2};
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(DIM0)), split_vd);
AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(kDim0)), split_vd);
AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToLong(kSplitVOutputNum)), split_vd);
std::vector<int64_t> size_splits = {SizeToLong(t_size - 1), SizeToLong(1)};
AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split_vd);
@ -116,7 +100,7 @@ AnfNodePtr CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) {
if (ori_shape.size() == k3Dims) {
shape_tmp = {ori_shape};
} else {
shape_tmp = {{IntToSize(1), ori_shape[DIM0], ori_shape[DIM1]}};
shape_tmp = {{IntToSize(1), ori_shape[kDim0], ori_shape[kDim1]}};
}
auto ori_dtype = {AnfAlgo::GetOutputInferDataType(node, 0)};
// reshape
@ -140,14 +124,14 @@ AnfNodePtr CreateHConcatDNode(const FuncGraphPtr &graph, const AnfNodePtr &node1
auto concat_op = graph->NewCNode(concat_inputs);
MS_EXCEPTION_IF_NULL(concat_op);
std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[DIM0] + 1,
AnfAlgo::GetOutputInferShape(node2, 0)[DIM1],
AnfAlgo::GetOutputInferShape(node2, 0)[DIM2]};
std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[kDim0] + 1,
AnfAlgo::GetOutputInferShape(node2, 0)[kDim1],
AnfAlgo::GetOutputInferShape(node2, 0)[kDim2]};
auto types = {AnfAlgo::GetOutputInferDataType(node2, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, concat_op.get());
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kConcatNum)), concat_op);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat_op);
AnfAlgo::SetNodeAttr("axis", MakeValue(SizeToLong(0)), concat_op);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), concat_op);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
return concat_op;
}
@ -160,14 +144,14 @@ AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &graph, const AnfNodePtr &
auto split_vd = graph->NewCNode(splitvd_input);
MS_EXCEPTION_IF_NULL(split_vd);
auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)};
size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[DIM0];
size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[DIM1];
size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[DIM2] / kGateNum;
size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[kDim0];
size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[kDim1];
size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[kDim2] / kGateNum;
std::vector<size_t> shape = {t_size, batch, hidden_size << 1};
std::vector<size_t> shape2 = {t_size, batch, hidden_size};
std::vector<std::vector<size_t>> shapes = {shape, shape2};
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(DIM2)), split_vd);
AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(kDim2)), split_vd);
AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToLong(kSplitVOutputNum)), split_vd);
std::vector<int64_t> size_splits = {SizeToLong(hidden_size + hidden_size), SizeToLong(hidden_size)};
AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split_vd);
@ -190,13 +174,13 @@ AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &graph, const AnfNodePtr &
auto concat_op = graph->NewCNode(concat_inputs);
MS_EXCEPTION_IF_NULL(concat_op);
std::vector<size_t> shape = {
AnfAlgo::GetOutputInferShape(node2, 0)[DIM0], AnfAlgo::GetOutputInferShape(node2, 0)[DIM1],
AnfAlgo::GetOutputInferShape(node1, 0)[DIM2] + AnfAlgo::GetOutputInferShape(node2, 0)[DIM2]};
AnfAlgo::GetOutputInferShape(node2, 0)[kDim0], AnfAlgo::GetOutputInferShape(node2, 0)[kDim1],
AnfAlgo::GetOutputInferShape(node1, 0)[kDim2] + AnfAlgo::GetOutputInferShape(node2, 0)[kDim2]};
auto types = {AnfAlgo::GetOutputInferDataType(node2, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, concat_op.get());
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kConcatNum)), concat_op);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat_op);
AnfAlgo::SetNodeAttr("axis", MakeValue(SizeToLong(DIM2)), concat_op);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(kDim2)), concat_op);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
return concat_op;
}
@ -211,15 +195,15 @@ AnfNodePtr CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &
std::vector<AnfNodePtr> braodcast_to_input = {NewValueNode(std::make_shared<Primitive>(kBroadcastToOpName)), node1};
auto broadcast_to_d = graph->NewCNode(braodcast_to_input);
MS_EXCEPTION_IF_NULL(broadcast_to_d);
size_t t_size = AnfAlgo::GetOutputInferShape(node2, 0)[DIM0];
size_t batch = AnfAlgo::GetOutputInferShape(node1, 0)[DIM0];
size_t gate_size = AnfAlgo::GetOutputInferShape(node1, 0)[DIM1];
size_t t_size = AnfAlgo::GetOutputInferShape(node2, 0)[kDim0];
size_t batch = AnfAlgo::GetOutputInferShape(node1, 0)[kDim0];
size_t gate_size = AnfAlgo::GetOutputInferShape(node1, 0)[kDim1];
std::vector<size_t> shape = {t_size, batch, gate_size};
auto type = {AnfAlgo::GetOutputInferDataType(node1, 0)};
AnfAlgo::SetOutputInferTypeAndShape(type, {shape}, broadcast_to_d.get());
std::vector<int64_t> attr_shape = {SizeToLong(t_size), SizeToLong(batch), SizeToLong(gate_size)};
AnfAlgo::SetNodeAttr("shape", MakeValue(attr_shape), broadcast_to_d);
AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(attr_shape), broadcast_to_d);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), broadcast_to_d);
return broadcast_to_d;
}
@ -233,9 +217,9 @@ AnfNodePtr CreateDhxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &nod
node1, node2};
auto batch_matmul = graph->NewCNode(matmul_inputs);
MS_EXCEPTION_IF_NULL(batch_matmul);
std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[DIM0],
AnfAlgo::GetOutputInferShape(node1, 0)[DIM2],
AnfAlgo::GetOutputInferShape(node2, 0)[DIM2]};
std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[kDim0],
AnfAlgo::GetOutputInferShape(node1, 0)[kDim2],
AnfAlgo::GetOutputInferShape(node2, 0)[kDim2]};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {shape}, batch_matmul.get());
AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
@ -252,9 +236,9 @@ AnfNodePtr CreateDwhBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &nod
node1, node2};
auto batch_matmul = graph->NewCNode(matmul_inputs);
MS_EXCEPTION_IF_NULL(batch_matmul);
std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[DIM0],
AnfAlgo::GetOutputInferShape(node1, 0)[DIM1],
AnfAlgo::GetOutputInferShape(node2, 0)[DIM1]};
std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[kDim0],
AnfAlgo::GetOutputInferShape(node1, 0)[kDim1],
AnfAlgo::GetOutputInferShape(node2, 0)[kDim1]};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {shape}, batch_matmul.get());
AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), batch_matmul);
AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), batch_matmul);
@ -290,7 +274,7 @@ AnfNodePtr CreateDbReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &n
MS_EXCEPTION_IF_NULL(reduce_sumd);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
std::vector<size_t> shape = {kGateNum * AnfAlgo::GetOutputInferShape(node2, 0)[DIM1]};
std::vector<size_t> shape = {kGateNum * AnfAlgo::GetOutputInferShape(node2, 0)[kDim1]};
AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, reduce_sumd.get());
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0, 1}), reduce_sumd);
AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
@ -322,7 +306,7 @@ const AnfNodePtr DynamicGRUV2GradFission::Process(const FuncGraphPtr &func_graph
auto gru_v2_gru_hidden = CreateGRUV2HiddenGradNode(func_graph, dynamic_gru_v2_grad_cnode);
std::vector<AnfNodePtr> gru_hidden_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, gru_v2_gru_hidden, kGRUV2HiddenGradOutputNum, &gru_hidden_outputs);
size_t step_num = AnfAlgo::GetOutputInferShape(ori_inputs[kIndex1], 0)[DIM0];
size_t step_num = AnfAlgo::GetOutputInferShape(ori_inputs[kIndex1], 0)[kDim0];
AnfNodePtr dwh_batch_matmul = nullptr;
if (step_num != 1) {
// split h

View File

@ -36,9 +36,9 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
std::vector<AnfNodePtr> matmul_nodes;
std::vector<AnfNodePtr> split_nodes;
// Get the size of t
auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(11), 0);
size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(9), 0)[0];
auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(12), 0);
auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex11), 0);
size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex9), 0)[0];
auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex12), 0);
for (size_t i = 0; i < t_size; ++i) {
// Create basic_lstm_cell_c_state_grad
@ -46,15 +46,16 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
std::vector<size_t> output0_dims{origin_input9_shape[0], 4 * (((origin_input9_shape[1] + 15) / 16) * 16)};
std::vector<size_t> output1_dims{input_i_shape[1], input_i_shape[2]};
std::vector<size_t> output0_dims{origin_input9_shape[kDim0],
4 * (((origin_input9_shape[kDim1] + kCubeSize - 1) / kCubeSize) * kCubeSize)};
std::vector<size_t> output1_dims{input_i_shape[kDim1], input_i_shape[kDim2]};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16, kNumberTypeFloat32}, {output0_dims, output1_dims},
basic_lstm_cell_c_state_grad.get());
AnfAlgo::SetNodeAttr("forget_bias", MakeValue(1.0f), basic_lstm_cell_c_state_grad);
AnfAlgo::SetNodeAttr("activation", MakeValue("Tanh"), basic_lstm_cell_c_state_grad);
// Create matmul
auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(2), 0);
auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex2), 0);
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
auto matmul = func_graph->NewCNode(matmul_inputs);
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{IntToSize(1), output0_dims[0], origin_input1_shape[0]}},
@ -67,14 +68,15 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
auto split_v = func_graph->NewCNode(splitv_input);
auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2);
auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 3);
std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[1], origin_output2_shape[2]};
std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[0], origin_output3_shape[1]};
std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[kDim1], origin_output2_shape[kDim2]};
std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[kDim0], origin_output3_shape[kDim1]};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},
{split_v_output0_shape, split_v_output1_shape}, split_v.get());
AnfAlgo::SetNodeAttr(kAttrSizeSplits,
MakeValue(std::vector<int64_t>{SizeToLong((origin_output2_shape[2] + 15) / 16 * 16),
SizeToLong((origin_output3_shape[1] + 15) / 16 * 16)}),
MakeValue(std::vector<int64_t>{
SizeToLong((origin_output2_shape[kDim2] + kCubeSize - 1) / kCubeSize * kCubeSize),
SizeToLong((origin_output3_shape[kDim1] + kCubeSize - 1) / kCubeSize * kCubeSize)}),
split_v);
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(2)), split_v);
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(2)), split_v);
@ -107,10 +109,10 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
std::vector<std::vector<AnfNodePtr>> result_nodes;
CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);
auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0);
auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0);
std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};
auto origin_input7 = dynamic_rnn_grad_cnode->input(8);
auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
std::vector<std::vector<size_t>> split_shapes;
std::vector<TypeId> split_types;
@ -126,47 +128,47 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_c, num_split_x, &lstm_split_c_outputs);
// Create lstm_split_dy
auto lstm_split_dy =
CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(9), split_shapes, split_types, size_split, num_split_x);
auto lstm_split_dy = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex9), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_dy_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_dy, num_split_x, &lstm_split_dy_outputs);
// Create lstm_split_i
auto lstm_split_i =
CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(12), split_shapes, split_types, size_split, num_split_x);
auto lstm_split_i = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex12), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_i_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_i, num_split_x, &lstm_split_i_outputs);
// Create lstm_split_j
auto lstm_split_j =
CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(13), split_shapes, split_types, size_split, num_split_x);
auto lstm_split_j = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex13), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_j_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_j, num_split_x, &lstm_split_j_outputs);
// Create lstm_split_f
auto lstm_split_f =
CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(14), split_shapes, split_types, size_split, num_split_x);
auto lstm_split_f = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex14), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_f_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_f, num_split_x, &lstm_split_f_outputs);
// Create lstm_split_o
auto lstm_split_o =
CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(15), split_shapes, split_types, size_split, num_split_x);
auto lstm_split_o = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex15), split_shapes, split_types,
size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_o_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_o, num_split_x, &lstm_split_o_outputs);
// Create lstm_split_tanh
auto lstm_split_tanh =
CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(16), split_shapes, split_types, size_split, num_split_x);
auto lstm_split_tanh = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex16), split_shapes,
split_types, size_split, num_split_x);
std::vector<AnfNodePtr> lstm_split_tanh_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_tanh, num_split_x, &lstm_split_tanh_outputs);
// Add edges
std::vector<AnfNodePtr> pre_basic_lstm_cell_c_state_grad_outputs;
std::vector<AnfNodePtr> pre_split_outputs;
auto basic_lstm_cell_c_state_grad_nodes = result_nodes[0];
auto matmul_nodes = result_nodes[1];
auto split_nodes = result_nodes[2];
auto basic_lstm_cell_c_state_grad_nodes = result_nodes[kIndex0];
auto matmul_nodes = result_nodes[kIndex1];
auto split_nodes = result_nodes[kIndex2];
std::vector<AnfNodePtr> lstm_x_concat_input(num_split_x + 1);
lstm_x_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));
std::vector<AnfNodePtr> lstm_gage_concat_input(num_split_x + 1);
@ -181,8 +183,9 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
dynamic_rnn_grad_cnode->input(6)};
auto reshape = func_graph->NewCNode(reshape_inputs);
auto reshape_out_shape = {IntToSize(1), AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0)[0],
AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0)[1]};
auto reshape_out_shape = {IntToSize(1),
AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0)[0],
AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0)[1]};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {reshape_out_shape}, reshape.get());
basic_lstm_cell_c_state_grad_inputs.emplace_back(reshape);
} else {
@ -190,8 +193,8 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
}
basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_dy_outputs[idx]);
if (i == 0) {
basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(10));
basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(11));
basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex10));
basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex11));
} else {
basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_split_outputs[1]);
basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_basic_lstm_cell_c_state_grad_outputs[1]);
@ -213,7 +216,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
// Create MatMul
std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
matmul_inputs.emplace_back(basic_lstm_cell_c_state_grad_outputs[0]);
matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(2));
matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex2));
auto matmul = func_graph->NewCNode(matmul_inputs);
MS_EXCEPTION_IF_NULL(matmul);
matmul->set_abstract(matmul_nodes[i]->abstract());
@ -237,7 +240,8 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
auto basic_lstm_cell_c_state_grad_outputs_0_shape =
AnfAlgo::GetOutputInferShape(basic_lstm_cell_c_state_grad_outputs[0], 0);
std::vector<size_t> temp_shape;
if (basic_lstm_cell_c_state_grad_outputs_0_shape.size() == 3) {
constexpr size_t kBasicLstmCStateGradOutput0DimNum = 3;
if (basic_lstm_cell_c_state_grad_outputs_0_shape.size() == kBasicLstmCStateGradOutput0DimNum) {
temp_shape = basic_lstm_cell_c_state_grad_outputs_0_shape;
} else {
temp_shape = {1, basic_lstm_cell_c_state_grad_outputs_0_shape[0],
@ -262,9 +266,9 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
// Create lstm_gage_concat
auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input);
auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16},
{{origin_input7_shape[0], origin_input7_shape[1], 4 * origin_input7_shape[2]}},
lstm_gage_concat.get());
AnfAlgo::SetOutputInferTypeAndShape(
{kNumberTypeFloat16}, {{origin_input7_shape[kDim0], origin_input7_shape[kDim1], 4 * origin_input7_shape[kDim2]}},
lstm_gage_concat.get());
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gage_concat);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), lstm_gage_concat);
@ -279,15 +283,15 @@ AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
// Create node
auto origin_input6 = dynamic_rnn_grad_cnode->input(7);
auto origin_input6 = dynamic_rnn_grad_cnode->input(kIndex7);
std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
origin_input6};
auto split_v = func_graph->NewCNode(splitv_input);
// Set infer data type and shape
auto dtypes = {AnfAlgo::GetOutputInferDataType(origin_input6, 0), AnfAlgo::GetOutputInferDataType(origin_input6, 0)};
auto origin_input6_shape = AnfAlgo::GetOutputInferShape(origin_input6, 0);
std::vector<size_t> shape1 = {origin_input6_shape[0] - 1, origin_input6_shape[1], origin_input6_shape[2]};
std::vector<size_t> shape2 = {1, origin_input6_shape[1], origin_input6_shape[2]};
std::vector<size_t> shape1 = {origin_input6_shape[kDim0] - 1, origin_input6_shape[kDim1], origin_input6_shape[kDim2]};
std::vector<size_t> shape2 = {1, origin_input6_shape[kDim1], origin_input6_shape[kDim2]};
std::vector<std::vector<size_t>> shapes = {shape1, shape2};
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get());
// Set attr
@ -311,11 +315,12 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
MS_LOG(EXCEPTION) << "Create outputs of node " << splitv->DebugString() << " failed"
<< " trace: " << trace::DumpSourceLines(dynamic_rnn_grad_cnode);
}
auto origin_input4 = dynamic_rnn_grad_cnode->input(5);
auto origin_input4 = dynamic_rnn_grad_cnode->input(kIndex5);
auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
// Create reshape to change shape
std::vector<size_t> shape_tmp;
if (origin_input4_shape.size() == 3) {
constexpr size_t kInput4DimNum = 3;
if (origin_input4_shape.size() == kInput4DimNum) {
shape_tmp = origin_input4_shape;
} else {
shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};
@ -351,8 +356,8 @@ AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
// Set infer data type and shape
auto origin_output0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
auto h_concat_output_shape = AnfAlgo::GetOutputInferShape(h_concat, 0);
std::vector<size_t> shape = {origin_output0_shape[0], origin_output0_shape[1],
origin_output0_shape[2] + h_concat_output_shape[2]};
std::vector<size_t> shape = {origin_output0_shape[kDim0], origin_output0_shape[kDim1],
origin_output0_shape[kDim2] + h_concat_output_shape[kDim2]};
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
// Set attr
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat);
@ -366,8 +371,8 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
// Create node
auto origin_input0 = dynamic_rnn_grad_cnode->input(1);
auto origin_input4 = dynamic_rnn_grad_cnode->input(5);
auto origin_input0 = dynamic_rnn_grad_cnode->input(kIndex1);
auto origin_input4 = dynamic_rnn_grad_cnode->input(kIndex5);
auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
// Create reshape to change shape
std::vector<size_t> shape_tmp;
@ -386,7 +391,8 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
auto concat = func_graph->NewCNode(concat_inputs);
// Set infer data type and shape
auto origin_input0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
std::vector<size_t> shape = {origin_input0_shape[0], origin_input0_shape[1], origin_input0_shape[2] + shape_tmp[2]};
std::vector<size_t> shape = {origin_input0_shape[kDim0], origin_input0_shape[kDim1],
origin_input0_shape[kDim2] + shape_tmp[kDim2]};
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
// Set attr
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat);
@ -406,7 +412,7 @@ AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &l
// Set infer data type and shape
auto concat_shape = AnfAlgo::GetOutputInferShape(concat, 0);
auto lstm_input_grad_shape = AnfAlgo::GetOutputInferShape(lstm_input_grad, 0);
std::vector<size_t> shape = {concat_shape[0], concat_shape[2], lstm_input_grad_shape[2]};
std::vector<size_t> shape = {concat_shape[kDim0], concat_shape[kDim2], lstm_input_grad_shape[kDim2]};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {shape}, batch_matmul.get());
// Set attr
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
@ -451,7 +457,7 @@ AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
}
AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
auto origin_input7 = dynamic_rnn_grad_cnode->input(8);
auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
auto t_size = origin_input7_shape[0];
auto n_size = origin_input7_shape[1];
@ -477,7 +483,7 @@ AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
batch_matmul};
auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
// Set infer data type and shape
auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[2]};
auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kDim2]};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, reduce_sum.get());
// Set attr
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
@ -506,7 +512,7 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
std::vector<AnfNodePtr> new_outputs;
auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs);
size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(7), 0)[0];
size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[0];
AnfNodePtr concat = nullptr;
if (t_size != 1) {
auto splitv = CreateSplitV(func_graph, dynamic_rnn_grad_cnode);

View File

@ -53,6 +53,7 @@ CNodePtr CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node, const
shape[shape.size() - 1] = SizeToLong(pad_dim_size);
auto type_id = AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0);
auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), shape);
MS_EXCEPTION_IF_NULL(abstract);
if (param_dyn_shape->max_shape().size() == param_dyn_shape->shape().size() &&
param_dyn_shape->min_shape().size() == param_dyn_shape->shape().size()) {
ShapeVector max_shape(param_dyn_shape->max_shape());
@ -134,8 +135,7 @@ bool CheckInputs(const CNodePtr &origin_node) {
auto param_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
auto indice_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1);
// this optimizer only support embedding_table has dynamic shape
constexpr size_t DIM2 = 2;
if (param_shape.empty() || indice_shape.empty() || AnfAlgo::IsDynamicShape(origin_node->input(DIM2))) {
if (param_shape.empty() || indice_shape.empty() || AnfAlgo::IsDynamicShape(origin_node->input(kDim2))) {
return false;
}
if (param_shape[param_shape.size() - 1] != 1) {

View File

@ -28,6 +28,7 @@ constexpr size_t kLinSpaceInputNum = 3;
constexpr size_t kFloat32Len = 4;
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
// 1 get tensor value of input num
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_num = cnode->input(kLinSpaceInputNum);

View File

@ -26,12 +26,10 @@ namespace opt {
constexpr size_t kInputNum = 3;
constexpr size_t kFloat16Len = 2; // size of float16;
constexpr size_t kKernelSizeNum = 5;
constexpr size_t DIM2 = 2;
constexpr size_t DIM3 = 3;
constexpr size_t DIM4 = 4;
namespace {
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
// 1 get attr ksize
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, "kernel_size");
@ -42,9 +40,9 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
if (ksize.size() != kKernelSizeNum) {
MS_LOG(EXCEPTION) << "kernel_size of MaxPool3DGradGrad must be five, but got :" << ksize;
}
int64_t d = ksize[DIM2];
int64_t h = ksize[DIM3];
int64_t w = ksize[DIM4];
int64_t d = ksize[kDim2];
int64_t h = ksize[kDim3];
int64_t w = ksize[kDim4];
// 1 create tensor
std::vector<int64_t> assist_shape = {1, 1, d, h, w}; // shape:NCDHW

View File

@ -27,6 +27,7 @@ CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, con
MS_EXCEPTION_IF_NULL(old_node);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMin->name())), input};
CNodePtr reduce_min = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(reduce_min);
reduce_min->set_scope(old_node->scope());
AnfAlgo::CopyNodeAttr(kAttrKeepDims, old_node, reduce_min);
return reduce_min;

View File

@ -40,8 +40,7 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
MS_EXCEPTION_IF_NULL(bn_training_reduce);
// set abstract
constexpr size_t DIM2 = 2;
auto bn_input1 = bn_cnode->input(DIM2);
auto bn_input1 = bn_cnode->input(kDim2);
MS_EXCEPTION_IF_NULL(bn_input1);
AbstractBasePtrList abstract_list{bn_input1->abstract(), bn_input1->abstract()};
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
@ -67,11 +66,11 @@ AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNod
}
std::vector<AnfNodePtr> bn_training_update_v3_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateV3OpName)),
bn_cnode->input(1),
bn_training_reduce_outputs[0],
bn_training_reduce_outputs[1],
bn_cnode->input(2),
bn_cnode->input(3)};
bn_cnode->input(kIndex1),
bn_training_reduce_outputs[kIndex0],
bn_training_reduce_outputs[kIndex1],
bn_cnode->input(kIndex2),
bn_cnode->input(kIndex3)};
auto bn_training_update_v3 = func_graph->NewCNode(bn_training_update_v3_inputs);
MS_EXCEPTION_IF_NULL(bn_training_update_v3);

View File

@ -26,20 +26,17 @@ namespace opt {
namespace {
constexpr size_t kFloat16Len = 2;
constexpr size_t kSpaceToDepthInputNum = 1;
constexpr size_t kInputIndex1 = 1;
constexpr size_t DIM1 = 1;
constexpr size_t DIM2 = 2;
constexpr size_t DIM3 = 3;
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
// 1 create tensor
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_x = cnode->input(kSpaceToDepthInputNum);
int64_t block_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, "block_size");
std::vector<size_t> x_shape = AnfAlgo::GetOutputInferShape(input_x, 0);
int64_t input_channel = SizeToLong(x_shape[DIM1]);
int64_t assist_input_channel = SizeToLong(x_shape[DIM1]) * block_size * block_size;
int64_t input_channel = SizeToLong(x_shape[kDim1]);
int64_t assist_input_channel = SizeToLong(x_shape[kDim1]) * block_size * block_size;
std::vector<int64_t> assist_input_shape = {assist_input_channel, input_channel, block_size, block_size};
int64_t dest_size = assist_input_channel * input_channel * block_size * block_size;
MS_LOG(DEBUG) << "For SpaceToDepth op, assist input shape is: (" << assist_input_channel << ", " << input_channel
@ -50,7 +47,7 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
assist_tensor->set_device_info(device_info);
// 2 set value of tensor
int64_t window_size = assist_input_shape[DIM2] * assist_input_shape[DIM3];
int64_t window_size = assist_input_shape[kDim2] * assist_input_shape[kDim3];
int64_t channel_size = input_channel;
auto data_ptr = assist_tensor->data_c();
MS_EXCEPTION_IF_NULL(data_ptr);
@ -108,7 +105,7 @@ const AnfNodePtr SpaceToDepthSplit::Process(const FuncGraphPtr &graph, const Anf
return nullptr;
}
const auto &ori_inputs = cnode->inputs();
TypeId x_dtype = AnfAlgo::GetOutputInferDataType(ori_inputs[kInputIndex1], 0);
TypeId x_dtype = AnfAlgo::GetOutputInferDataType(ori_inputs[kIndex1], 0);
if (x_dtype != kNumberTypeFloat16) {
MS_LOG(INFO) << "Node " << cnode->DebugString() << ": The data type of node's first input is: " << x_dtype
<< ", not fp16, cannot do fusion.";

View File

@ -62,7 +62,7 @@ CNodePtr TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePt
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_node = node->cast<CNodePtr>()->input(1);
auto input_node = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(input_node);
auto input_format = AnfAlgo::GetInputFormat(node, 0);

View File

@ -27,7 +27,7 @@ CNodePtr CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node, c
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(origin_node);
std::vector<AnfNodePtr> padding_inputs = {NewValueNode(std::make_shared<Primitive>(kPaddingOpName)),
origin_node->input(1)};
origin_node->input(kIndex1)};
auto padding = graph->NewCNode(padding_inputs);
MS_EXCEPTION_IF_NULL(padding);
padding->set_scope(origin_node->scope());
@ -45,7 +45,8 @@ CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &ori
MS_EXCEPTION_IF_NULL(origin_node);
MS_EXCEPTION_IF_NULL(padding);
std::vector<AnfNodePtr> unsorted_segment_sum8_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSum->name())), padding, origin_node->input(2)};
NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSum->name())), padding,
origin_node->input(kIndex2)};
auto unsorted_segment_sum = graph->NewCNode(unsorted_segment_sum8_inputs);
MS_EXCEPTION_IF_NULL(unsorted_segment_sum);
unsorted_segment_sum->set_scope(origin_node->scope());

View File

@ -35,10 +35,6 @@ constexpr size_t kAvgPoolGradInputNum = 3;
constexpr size_t kShapeDimNum = 4;
constexpr float kKernelMatrixInitNum = 1.0;
constexpr size_t kFloat32Len = 4; // size of float32
constexpr size_t DIM0 = 0;
constexpr size_t DIM1 = 1;
constexpr size_t DIM2 = 2;
constexpr size_t DIM3 = 3;
std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
std::vector<int64_t> shapes;
@ -49,6 +45,8 @@ std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
int64_t windowed_output_size(int64_t input_size, int64_t ksize, int64_t stride, PadMode pad_mode, int64_t *pad_before,
int64_t *pad_after) {
MS_EXCEPTION_IF_NULL(pad_before);
MS_EXCEPTION_IF_NULL(pad_after);
int64_t output = 0;
*pad_before = 0;
*pad_after = 0;
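The body of windowed_output_size is elided by the hunk; it computes the pooled extent and implicit padding for one spatial axis. A hedged standalone sketch, assuming the usual SAME/VALID pooling conventions (this is an illustration, not the exact implementation):
#include <algorithm>
#include <cassert>
#include <cstdint>
enum class PadModeSketch { SAME, VALID };
int64_t WindowedOutputSizeSketch(int64_t input_size, int64_t ksize, int64_t stride,
                                 PadModeSketch pad_mode, int64_t *pad_before, int64_t *pad_after) {
  *pad_before = 0;
  *pad_after = 0;
  if (pad_mode == PadModeSketch::VALID) {
    return (input_size - ksize + stride) / stride;  // no implicit padding
  }
  int64_t output = (input_size + stride - 1) / stride;  // ceil(input / stride)
  int64_t pad_needed = std::max<int64_t>(0, (output - 1) * stride + ksize - input_size);
  *pad_before = pad_needed / 2;  // split as evenly as possible, remainder goes after
  *pad_after = pad_needed - *pad_before;
  return output;
}
int main() {
  int64_t before = 0, after = 0;
  assert(WindowedOutputSizeSketch(14, 3, 2, PadModeSketch::SAME, &before, &after) == 7);
  assert(before + after == 1);  // (7 - 1) * 2 + 3 - 14 = 1 padded cell in total
  return 0;
}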
@ -77,17 +75,20 @@ std::vector<std::vector<float>> GetAssistInputMatrix(const std::vector<int64_t>
// w*w_stride : w*w_stride + w_ksize]` of `assist_input_matrix`, so the sum of the slide window is the
// number of inputs associated with the output element.
std::vector<std::vector<float>> assist_input_matrix;
std::vector<int64_t> in_shape_after_padding_2d = {x_shape[DIM2] + pad_top + pad_bottom,
x_shape[DIM3] + pad_left + pad_right};
std::vector<float> tmp_zero_vector(in_shape_after_padding_2d[1], 0.0);
std::vector<float> tmp_one_vector(in_shape_after_padding_2d[1], 1.0);
for (int64_t i = 0; i < in_shape_after_padding_2d[1]; ++i) {
if (i < pad_left || i >= (in_shape_after_padding_2d[1] - pad_right)) {
if (x_shape.size() < kShapeDimNum) {
MS_LOG(EXCEPTION) << "The dim of x_shape should not be less than 4.";
}
std::vector<int64_t> in_shape_after_padding_2d = {x_shape[kDim2] + pad_top + pad_bottom,
x_shape[kDim3] + pad_left + pad_right};
std::vector<float> tmp_zero_vector(in_shape_after_padding_2d[kDim1], 0.0);
std::vector<float> tmp_one_vector(in_shape_after_padding_2d[kDim1], 1.0);
for (int64_t i = 0; i < in_shape_after_padding_2d[kDim1]; ++i) {
if (i < pad_left || i >= (in_shape_after_padding_2d[kDim1] - pad_right)) {
tmp_one_vector[LongToSize(i)] = 0.0;
}
}
for (int64_t i = 0; i < in_shape_after_padding_2d[0]; ++i) {
if (i < pad_top || i >= (in_shape_after_padding_2d[0] - pad_bottom)) {
for (int64_t i = 0; i < in_shape_after_padding_2d[kDim0]; ++i) {
if (i < pad_top || i >= (in_shape_after_padding_2d[kDim0] - pad_bottom)) {
assist_input_matrix.emplace_back(tmp_zero_vector);
} else {
assist_input_matrix.emplace_back(tmp_one_vector);
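The comment above describes a padded 0/1 matrix whose window sums count the real (non-padded) inputs under each pooling window. A small self-contained illustration with hypothetical sizes (4x4 plane, one cell of padding, 3x3 window):
#include <cassert>
#include <vector>
int main() {
  const int pad = 1, in = 4, padded = in + 2 * pad;
  // 1.0 marks real input cells, 0.0 marks padding.
  std::vector<std::vector<float>> m(padded, std::vector<float>(padded, 0.0f));
  for (int i = pad; i < pad + in; ++i)
    for (int j = pad; j < pad + in; ++j) m[i][j] = 1.0f;
  float window_sum = 0.0f;
  for (int i = 0; i < 3; ++i)
    for (int j = 0; j < 3; ++j) window_sum += m[i][j];  // top-left 3x3 window
  assert(window_sum == 4.0f);  // only 4 real inputs fall under the corner window
  return 0;
}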
@ -106,8 +107,10 @@ ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std
MS_LOG(EXCEPTION) << "The dim of x_shape or kernel_size or strides of AvgPoolGrad should be 4.";
}
int64_t pad_top, pad_bottom, pad_left, pad_right;
int64_t h_output = windowed_output_size(x_shape[DIM2], k_size[DIM2], stride[DIM2], pad_mode, &pad_top, &pad_bottom);
int64_t w_output = windowed_output_size(x_shape[DIM3], k_size[DIM3], stride[DIM3], pad_mode, &pad_left, &pad_right);
int64_t h_output =
windowed_output_size(x_shape[kDim2], k_size[kDim2], stride[kDim2], pad_mode, &pad_top, &pad_bottom);
int64_t w_output =
windowed_output_size(x_shape[kDim3], k_size[kDim3], stride[kDim3], pad_mode, &pad_left, &pad_right);
auto assist_input_matrix = GetAssistInputMatrix(x_shape, pad_top, pad_bottom, pad_left, pad_right);
// calculate output
@ -115,8 +118,8 @@ ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std
for (int64_t h = 0; h < h_output; ++h) {
for (int64_t w = 0; w < w_output; ++w) {
float curr_sum = 0;
for (int64_t i = h * stride[DIM2]; i < h * stride[DIM2] + k_size[DIM2]; ++i) {
for (int64_t j = w * stride[DIM3]; j < w * stride[DIM3] + k_size[DIM3]; ++j) {
for (int64_t i = h * stride[kDim2]; i < h * stride[kDim2] + k_size[kDim2]; ++i) {
for (int64_t j = w * stride[kDim3]; j < w * stride[kDim3] + k_size[kDim3]; ++j) {
curr_sum += assist_input_matrix[LongToSize(i)][LongToSize(j)];
}
}
@ -127,12 +130,12 @@ ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std
}
// make output tensor
std::vector<int64_t> output_shape = {x_shape[0], x_shape[1], h_output, w_output};
std::vector<int64_t> output_shape = {x_shape[kDim0], x_shape[kDim1], h_output, w_output};
auto output_size = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());
std::vector<float> output(output_size, 0.0);
for (int64_t i = 0; i < output_shape[0] * output_shape[1]; ++i) {
for (int64_t i = 0; i < output_shape[kDim0] * output_shape[kDim1]; ++i) {
size_t src_size = hw_output.size() * kFloat32Len;
auto dst_size = LongToSize(output_shape[DIM2]) * LongToSize(output_shape[DIM3]) * kFloat32Len;
auto dst_size = LongToSize(output_shape[kDim2]) * LongToSize(output_shape[kDim3]) * kFloat32Len;
auto ret = memcpy_s(&output[LongToSize(i) * hw_output.size()], dst_size, &hw_output[0], src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
@ -157,7 +160,7 @@ ValueNodePtr CreateKernelMatrixValueNode(const FuncGraphPtr &func_graph, const s
if (x_shape.size() != kShapeDimNum || k_size.size() != kShapeDimNum) {
MS_LOG(EXCEPTION) << "The dim of x_shape or kernel_size of AvgPoolGrad should be 4.";
}
std::vector<int64_t> kernel_shape = {1, x_shape[DIM1], k_size[DIM2], k_size[DIM3]};
std::vector<int64_t> kernel_shape = {1, x_shape[kDim1], k_size[kDim2], k_size[kDim3]};
auto data_size = std::accumulate(kernel_shape.begin(), kernel_shape.end(), int64_t(1), std::multiplies<int64_t>());
std::vector<float> data(data_size, kKernelMatrixInitNum);
auto kernel_matrix_tensor = std::make_shared<tensor::Tensor>(x_dtype, kernel_shape, &data[0], kNumberTypeFloat32);

View File

@ -35,17 +35,12 @@ AnfNodePtr CreateNewBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &bn_
size_t kBNGradInputNum = 6;
const auto &bn_grad_node_inputs = bn_grad_node->inputs();
CheckCNodeInputSize(bn_grad_node, kBNGradInputNum);
constexpr size_t DIM1 = 1;
constexpr size_t DIM2 = 2;
constexpr size_t DIM3 = 3;
constexpr size_t DIM4 = 4;
constexpr size_t DIM5 = 5;
std::vector<AnfNodePtr> bn_grad_inputs = {NewValueNode(std::make_shared<Primitive>(kBatchNormGradOpName)),
bn_grad_node_inputs[DIM1],
bn_grad_node_inputs[DIM2],
bn_grad_node_inputs[DIM3],
bn_grad_node_inputs[DIM4],
bn_grad_node_inputs[DIM5]};
bn_grad_node_inputs[kDim1],
bn_grad_node_inputs[kDim2],
bn_grad_node_inputs[kDim3],
bn_grad_node_inputs[kDim4],
bn_grad_node_inputs[kDim5]};
auto new_bn_grad = graph->NewCNode(bn_grad_inputs);
MS_EXCEPTION_IF_NULL(new_bn_grad);
new_bn_grad->set_scope(bn_grad_node->scope());

View File

@ -40,8 +40,8 @@ constexpr auto kAttrChannelMultiplier = "channel_multiplier";
constexpr auto kAttrPerm = "perm";
constexpr auto kAttrInputSizes = "input_sizes";
constexpr auto kAttrInputSize = "input_size";
constexpr auto kInput2 = 2;
constexpr auto kInput3 = 3;
constexpr auto kIndex2 = 2;
constexpr auto kIndex3 = 3;
bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vector<size_t> out_shape) {
MS_EXCEPTION_IF_NULL(conv2d);
@ -63,8 +63,8 @@ bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vecto
MS_LOG(EXCEPTION) << "Conv2D's input and output should have 4 axis, but got input axis num: " << in_shape.size()
<< "output axis num: " << out_shape.size();
}
auto in_channel = in_shape[1];
auto out_channel = out_shape[1];
auto in_channel = in_shape[kDim1];
auto out_channel = out_shape[kDim1];
if (group != in_channel || group != out_channel) {
return false;
}
@ -117,7 +117,7 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
MS_LOG(EXCEPTION) << "Conv2D's output axis number should be " << kConv2DAxisNum << ", but got "
<< out_shape.size();
}
std::swap(out_shape[0], out_shape[1]);
std::swap(out_shape[kDim0], out_shape[kDim1]);
auto shapes = {out_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, transpose.get());
} else {
@ -139,7 +139,7 @@ CNodePtr CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d
MS_EXCEPTION_IF_NULL(conv2d);
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
std::vector<AnfNodePtr> depth_conv_inputs = {NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeOpName)),
conv2d->input(1), transpose};
conv2d->input(kIndex1), transpose};
auto depth_conv = graph->NewCNode(depth_conv_inputs);
MS_EXCEPTION_IF_NULL(depth_conv);
depth_conv->set_abstract(conv2d->abstract());
@ -155,15 +155,15 @@ CNodePtr CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph, const CNo
CNodePtr depth_conv_backin = nullptr;
if (conv2d_backin->inputs().size() == kConv2DBackpropInputNum) {
std::vector<AnfNodePtr> depth_conv_backin_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), conv2d_backin->input(3),
transpose, conv2d_backin->input(1)};
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)),
conv2d_backin->input(kIndex3), transpose, conv2d_backin->input(kIndex1)};
depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
} else {
// In nn.Conv2DTranspose, Conv2DBackpropInput is a forward op and the input_sizes input will be converted to an attr
// in pynative mode.
std::vector<AnfNodePtr> depth_conv_backin_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), transpose,
conv2d_backin->input(1)};
conv2d_backin->input(kIndex1)};
depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
AnfAlgo::CopyNodeAttr(kAttrInputSizes, kAttrInputSize, conv2d_backin, depth_conv_backin);
}
@ -180,7 +180,7 @@ CNodePtr CreateDepthwiseConv2DBackpropFilter(const FuncGraphPtr &graph, const CN
MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's input number should be " << (kConv2DBackpropInputNum - 1)
<< ", but got " << (conv2d_backfil->inputs().size() - 1);
}
auto filter_size_node = conv2d_backfil->input(kInput3);
auto filter_size_node = conv2d_backfil->input(kIndex3);
MS_EXCEPTION_IF_NULL(filter_size_node);
auto filter_size_vnode = filter_size_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(filter_size_vnode);
@ -189,11 +189,11 @@ CNodePtr CreateDepthwiseConv2DBackpropFilter(const FuncGraphPtr &graph, const CN
// when the filter_size value is the same.
if (filter_size[0] != 1) {
std::swap(filter_size[0], filter_size[1]);
conv2d_backfil->input(kInput3)->cast<ValueNodePtr>()->set_value(MakeValue(filter_size));
conv2d_backfil->input(kIndex3)->cast<ValueNodePtr>()->set_value(MakeValue(filter_size));
}
std::vector<AnfNodePtr> depth_conv_backfil_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropFilterOpName)), conv2d_backfil->input(2),
conv2d_backfil->input(kInput3), conv2d_backfil->input(1)};
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropFilterOpName)),
conv2d_backfil->input(kIndex2), conv2d_backfil->input(kIndex3), conv2d_backfil->input(kIndex1)};
auto depth_conv_backfil = graph->NewCNode(depth_conv_backfil_inputs);
MS_EXCEPTION_IF_NULL(depth_conv_backfil);
depth_conv_backfil->set_scope(conv2d_backfil->scope());
@ -276,7 +276,7 @@ const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const Anf
return nullptr;
}
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
auto transpose = CreateTranspose(graph, conv2d, conv2d->input(kInput2), true);
auto transpose = CreateTranspose(graph, conv2d, conv2d->input(kIndex2), true);
auto depth_conv = CreateDepthwiseConv2D(graph, conv2d, transpose);
SetConv2DAttrs(conv2d, depth_conv);
return depth_conv;
@ -307,7 +307,7 @@ const AnfNodePtr Conv2DBackpropInputUnifyMindIR::Process(const FuncGraphPtr &gra
MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << (kConv2DBackpropInputNum - 1) << " or "
<< (kConv2DBackpropInputNum - 2) << ", but got " << (input_size - 1);
}
auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(kInput2), true);
auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(kIndex2), true);
auto depth_conv_backin = CreateDepthwiseConv2DBackpropInput(graph, conv2d_backin, transpose);
SetConv2DBackpropInputAttrs(conv2d_backin, depth_conv_backin);
return depth_conv_backin;

View File

@ -40,7 +40,6 @@ constexpr int64_t kMaskAlignNum = 128;
constexpr int64_t kMaskMultiNum = 16;
constexpr size_t kFloat16Len = 2; // size of float16
constexpr size_t kInt64Len = 8; // size of int64
constexpr auto kInput2 = 2; // size of int64
TypeId GetInputXDataType(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
@ -116,8 +115,8 @@ std::vector<int64_t> CalDropoutGenMaskOutput(const std::vector<int64_t> &shape)
bool NeedUpdate(const CNodePtr &getitem_cnode) {
MS_EXCEPTION_IF_NULL(getitem_cnode);
MS_EXCEPTION_IF_NULL(getitem_cnode->input(kInput2));
auto index_vnode = getitem_cnode->input(kInput2)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(getitem_cnode->input(kIndex2));
auto index_vnode = getitem_cnode->input(kIndex2)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(index_vnode);
auto index_value = index_vnode->value();
MS_EXCEPTION_IF_NULL(index_value);
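The constants kMaskAlignNum (128) and kMaskMultiNum (16) feed CalDropoutGenMaskOutput, whose body is not shown in this hunk. A hedged sketch of the typical computation, under the assumption that DropoutGenMask emits one bit per element padded up to a 128-bit boundary and reported as uint8 bytes:
#include <cassert>
#include <cstdint>
#include <vector>
std::vector<int64_t> CalDropoutGenMaskOutputSketch(const std::vector<int64_t> &shape) {
  constexpr int64_t kMaskAlignNum = 128;  // bits per alignment block
  constexpr int64_t kMaskMultiNum = 16;   // bytes per alignment block (128 bits)
  int64_t count = 1;
  for (int64_t dim : shape) {
    count *= dim;
  }
  return {((count + kMaskAlignNum - 1) / kMaskAlignNum) * kMaskMultiNum};
}
int main() {
  assert(CalDropoutGenMaskOutputSketch({32, 100}).front() == 400);  // 3200 elems -> 25 blocks -> 400 bytes
  return 0;
}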
@ -127,6 +126,7 @@ bool NeedUpdate(const CNodePtr &getitem_cnode) {
CNodePtr CreateDynamicShapeCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_input,
const abstract::ShapePtr &input_shape) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> dynamic_shape_inputs{NewValueNode(std::make_shared<Primitive>("DynamicShape")), node_input};
CNodePtr dynamic_shape = func_graph->NewCNode(dynamic_shape_inputs);
MS_EXCEPTION_IF_NULL(dynamic_shape);
@ -143,6 +143,8 @@ CNodePtr CreateDynamicShapeCNode(const FuncGraphPtr &func_graph, const AnfNodePt
CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dropout,
const ValueNodePtr &keep_prob_value, const AnfNodePtr &dropout_input,
const abstract::ShapePtr &input_shape) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dropout);
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName))};
if (input_shape->IsDynamic()) {
CNodePtr dynamic_shape = CreateDynamicShapeCNode(func_graph, dropout_input, input_shape);
@ -163,6 +165,7 @@ CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const AnfNode
ShapeVector mask_min_shp = CalDropoutGenMaskOutput(input_shape->min_shape());
ShapeVector mask_max_shp = CalDropoutGenMaskOutput(input_shape->max_shape());
auto gen_mask_shape = std::make_shared<abstract::Shape>(mask_shp, mask_min_shp, mask_max_shp);
MS_EXCEPTION_IF_NULL(gen_mask_shape);
gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
} else {
auto gen_mask_shape = CalDropoutGenMaskOutput(input_shape->shape());
@ -204,11 +207,11 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f
MS_EXCEPTION_IF_NULL(node);
auto dropout_grad_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_grad_cnode);
auto getitem1_node = dropout_grad_cnode->input(kInput2);
auto getitem1_node = dropout_grad_cnode->input(kIndex2);
MS_EXCEPTION_IF_NULL(getitem1_node);
auto getitem1_cnode = getitem1_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem1_cnode);
auto dropout_node = getitem1_cnode->input(1);
auto dropout_node = getitem1_cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(dropout_node);
auto dropout_cnode = dropout_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_cnode);
@ -216,7 +219,7 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f
auto inputx_type_id = GetInputXDataType(dropout_node);
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);
auto dropout_input = dropout_cnode->input(1);
auto dropout_input = dropout_cnode->input(kIndex1);
auto input_shape = GetDropoutInputShape(dropout_input);
// CreateDropoutGenMask
auto dropout_gen_mask =
@ -284,14 +287,14 @@ const AnfNodePtr DropoutUnifyMindIR0::Process(const FuncGraphPtr &func_graph, co
return nullptr;
}
auto dropout_node = tuple_cnode->input(1);
auto dropout_node = tuple_cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(dropout_node);
auto inputx_type_id = GetInputXDataType(dropout_node);
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);
auto dropout_cnode = dropout_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_cnode);
auto dropout_input = dropout_cnode->input(1);
auto dropout_input = dropout_cnode->input(kIndex1);
auto input_shape = GetDropoutInputShape(dropout_input);
// CreateDropoutGenMask
@ -332,7 +335,7 @@ const AnfNodePtr DropoutUnifyMindIR1::Process(const FuncGraphPtr &func_graph, co
auto inputx_type_id = GetInputXDataType(dropout_node);
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);
auto dropout_input = dropout_node->input(1);
auto dropout_input = dropout_node->input(kIndex1);
auto input_shape = GetDropoutInputShape(dropout_input);
// CreateDropoutGenMask
auto dropout_gen_mask =
@ -371,7 +374,7 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
// DropoutGrad may not be in the same graph as Dropout in a heterogeneous scene, and the mask input, which is a
// parameter in that scene, needs to be updated.
auto mask_input = dropout_grad_cnode->input(kInput2);
auto mask_input = dropout_grad_cnode->input(kIndex2);
if (mask_input->isa<Parameter>()) {
// update abstract
auto mask_abstract = mask_input->abstract();
@ -387,7 +390,7 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
}
// CreateDropoutDoMask
auto grad_input = dropout_grad_cnode->input(1);
auto grad_input = dropout_grad_cnode->input(kIndex1);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
grad_input, mask_input, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);

View File

@ -38,8 +38,9 @@ void CreateOutputsOfLSQPerLayerGradD(const FuncGraphPtr &graph, const CNodePtr &
<< " trace: " << trace::DumpSourceLines(lsq_perlayer_grad_node);
}
std::vector<AnfNodePtr> lsq_perlayer_grad_d_inputs = {
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerLayerGradDOpName)), lsq_perlayer_grad_inputs[1],
lsq_perlayer_grad_inputs[2], lsq_perlayer_grad_inputs[3], lsq_perlayer_grad_inputs[4]};
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerLayerGradDOpName)),
lsq_perlayer_grad_inputs[kIndex1], lsq_perlayer_grad_inputs[kIndex2], lsq_perlayer_grad_inputs[kIndex3],
lsq_perlayer_grad_inputs[kIndex4]};
auto lsq_perlayer_grad_d = graph->NewCNode(lsq_perlayer_grad_d_inputs);
MS_EXCEPTION_IF_NULL(lsq_perlayer_grad_d);
lsq_perlayer_grad_d->set_scope(lsq_perlayer_grad_node->scope());
@ -72,7 +73,7 @@ void CreateOutputsOfLSQPerLayerReduceGrad(const FuncGraphPtr &graph, const CNode
}
std::vector<AnfNodePtr> lsq_perlayer_reduce_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerLayerGradDReduceOpName)),
lsq_perlayer_grad_d_outputs[1]};
lsq_perlayer_grad_d_outputs[kIndex1]};
auto lsq_perlayer_reduce_grad = graph->NewCNode(lsq_perlayer_reduce_grad_inputs);
MS_EXCEPTION_IF_NULL(lsq_perlayer_reduce_grad);
lsq_perlayer_reduce_grad->set_scope(lsq_perlayer_grad_node->scope());
@ -130,7 +131,7 @@ void CreateOutputsOfLSQPerChannelReduceGrad(const FuncGraphPtr &graph, const CNo
}
std::vector<AnfNodePtr> lsq_perchannel_reduce_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerChannelGradDReduceOpName)),
lsq_perchannel_grad_d_outputs[1]};
lsq_perchannel_grad_d_outputs[kIndex1]};
auto lsq_perchannel_reduce_grad = graph->NewCNode(lsq_perchannel_reduce_grad_inputs);
MS_EXCEPTION_IF_NULL(lsq_perchannel_reduce_grad);
lsq_perchannel_reduce_grad->set_scope(lsq_perchannel_grad_node->scope());

View File

@ -32,9 +32,6 @@ constexpr size_t kMaxPoolInputNum = 2;
constexpr size_t kMaxPoolAttrAxisNum = 4;
constexpr size_t kMaxPoolGradInputNum = 4;
constexpr size_t kMaxPoolWithArgmaxOutputNum = 2;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
CNodePtr GetMaxPool(const CNodePtr &maxpool_grad) {
MS_EXCEPTION_IF_NULL(maxpool_grad);
@ -55,7 +52,7 @@ CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxp
<< (maxpool->inputs().size() - 1);
}
std::vector<AnfNodePtr> maxpool_argmax_inputs = {NewValueNode(std::make_shared<Primitive>(kMaxPoolWithArgmaxOpName)),
maxpool->input(1)};
maxpool->input(kIndex1)};
auto maxpool_argmax = graph->NewCNode(maxpool_argmax_inputs);
MS_EXCEPTION_IF_NULL(maxpool_argmax);
maxpool_argmax->set_scope(maxpool->scope());
@ -80,8 +77,8 @@ CNodePtr CreateMaxPoolGradWithArgmax(const FuncGraphPtr &graph, const CNodePtr &
// MaxPoolGrad's inputs are {input, output, grad_input}, MaxPoolGradWithArgmax's inputs are
// {input, grad_input, argmax_output}
std::vector<AnfNodePtr> maxpool_grad_argmax_inputs = {
NewValueNode(std::make_shared<Primitive>(kMaxPoolGradWithArgmaxOpName)), maxpool_grad->input(1),
maxpool_grad->input(kIndex3), maxpool_argmax_outputs[1]};
NewValueNode(std::make_shared<Primitive>(kMaxPoolGradWithArgmaxOpName)), maxpool_grad->input(kIndex1),
maxpool_grad->input(kIndex3), maxpool_argmax_outputs[kIndex1]};
auto maxpool_grad_argmax = graph->NewCNode(maxpool_grad_argmax_inputs);
MS_EXCEPTION_IF_NULL(maxpool_grad_argmax);
maxpool_grad_argmax->set_scope(maxpool_grad->scope());

View File

@ -29,10 +29,6 @@ namespace {
constexpr size_t kMaxPoolGradWithArgmaxInputNum = 4;
constexpr size_t kMaxPoolWithArgmaxShape = 4;
constexpr size_t kAlignBytes = 16;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
constexpr size_t kIndex4 = 4;
bool IsC(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
@ -71,11 +67,11 @@ const AnfNodePtr MaxPoolWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph
auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool_with_argmax, kAttrKernelSize);
auto output_shape = AnfAlgo::GetOutputInferShape(maxpool_with_argmax, 0);
auto argmax_shape = output_shape;
if (argmax_shape.size() != kMaxPoolWithArgmaxShape) {
MS_LOG(DEBUG) << "argmax's infer shape size not equal 4";
if (argmax_shape.size() != kMaxPoolWithArgmaxShape || ksize.size() != kMaxPoolWithArgmaxShape) {
MS_LOG(EXCEPTION) << "argmax or kernel_size's shape size not equal to 4";
}
argmax_shape[kIndex2] = LongToSize(ksize[kIndex1] * ksize[kIndex2]);
argmax_shape[kIndex3] = (output_shape[kIndex2] * output_shape[kIndex3] + kAlignBytes - 1) / kAlignBytes + 1;
argmax_shape[kDim2] = LongToSize(ksize[kDim1] * ksize[kDim2]);
argmax_shape[kDim3] = (output_shape[kDim2] * output_shape[kDim3] + kAlignBytes - 1) / kAlignBytes + 1;
auto types = {AnfAlgo::GetOutputInferDataType(maxpool_with_argmax, 0), argmax_dtype};
auto shapes = {output_shape, argmax_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, maxpool_with_argmax.get());
@ -105,11 +101,11 @@ const AnfNodePtr MaxPoolGradWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &g
TypeId argmax_dtype = kNumberTypeUInt16;
auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool_grad_with_argmax, kAttrKernelSize);
auto argmax_shape = AnfAlgo::GetOutputInferShape(tuple_getitem0_anf, 0);
if (argmax_shape.size() != kMaxPoolWithArgmaxShape) {
MS_LOG(DEBUG) << "argmax's infer shape size not equal 4";
if (argmax_shape.size() != kMaxPoolWithArgmaxShape || ksize.size() != kMaxPoolWithArgmaxShape) {
MS_LOG(EXCEPTION) << "argmax or kernel_size's shape size not equal to 4";
}
argmax_shape[kIndex3] = (argmax_shape[kIndex2] * argmax_shape[kIndex3] + kAlignBytes - 1) / kAlignBytes + 1;
argmax_shape[kIndex2] = LongToSize(ksize[kIndex1] * ksize[kIndex2]);
argmax_shape[kDim3] = (argmax_shape[kDim2] * argmax_shape[kDim3] + kAlignBytes - 1) / kAlignBytes + 1;
argmax_shape[kDim2] = LongToSize(ksize[kDim1] * ksize[kDim2]);
AnfAlgo::SetOutputInferTypeAndShape({argmax_dtype}, {argmax_shape}, tuple_getitem0_anf.get());
return maxpool_grad_with_argmax;
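The two hunks above rewrite the argmax shape with the new kDim constants: the second axis becomes kernel_h * kernel_w and the last axis is the output plane size rounded up to kAlignBytes plus one. A worked example with hypothetical sizes (3x3 kernel, 14x14 output plane):
#include <cassert>
#include <cstddef>
int main() {
  constexpr size_t kAlignBytes = 16;
  size_t ksize_h = 3, ksize_w = 3, out_h = 14, out_w = 14;  // hypothetical values
  size_t argmax_dim2 = ksize_h * ksize_w;                                    // 9
  size_t argmax_dim3 = (out_h * out_w + kAlignBytes - 1) / kAlignBytes + 1;  // (196 + 15) / 16 + 1 = 14
  assert(argmax_dim2 == 9 && argmax_dim3 == 14);
  return 0;
}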

View File

@ -72,10 +72,8 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
// set attr paddings
auto x_shape = GetInputXShape(slice_grad);
constexpr auto kInput3 = 3;
constexpr auto kInput4 = 4;
auto begins = GetTupleValue(slice_grad->input(kInput3));
auto sizes = GetTupleValue(slice_grad->input(kInput4));
auto begins = GetTupleValue(slice_grad->input(kIndex3));
auto sizes = GetTupleValue(slice_grad->input(kIndex4));
if (x_shape.size() != begins.size() || begins.size() != sizes.size()) {
MS_LOG(EXCEPTION) << "For SliceGrad, x's shape dim number should be equal to len(begin) and len(size).";
}

View File

@ -33,7 +33,6 @@ constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
namespace mindspore {
namespace opt {
namespace {
constexpr auto kInput2 = 2;
ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) {
MS_EXCEPTION_IF_NULL(value_ptr);
auto new_node = std::make_shared<ValueNode>(value_ptr);
@ -72,6 +71,7 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
auto value_off_node = CreateValueNode(value_off, kNumberTypeFloat32);
MS_EXCEPTION_IF_NULL(value_off_node);
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
kernel_graph->AddValueNodeToGraph(value_on_node);
kernel_graph->AddValueNodeToGraph(value_off_node);
@ -83,14 +83,16 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
std::vector<AnfNodePtr> one_hot_inputs;
if (is_convert_const_to_attr) {
one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), value_on_node, value_off_node};
one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(kIndex2), value_on_node,
value_off_node};
} else {
auto depth_node = NewValueNode(depth);
MS_EXCEPTION_IF_NULL(depth_node);
auto depth_abstract = std::make_shared<abstract::AbstractScalar>();
MS_EXCEPTION_IF_NULL(depth_abstract);
depth_abstract->set_type(kInt64);
depth_node->set_abstract(depth_abstract);
one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), depth_node, value_on_node,
one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(kIndex2), depth_node, value_on_node,
value_off_node};
}
auto one_hot_node = graph->NewCNode(one_hot_inputs);
@ -112,7 +114,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN
MS_EXCEPTION_IF_NULL(one_hot_node);
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSoftmaxCrossEntropyWithLogitsOpName)),
sparse_softmax_node->input(1), one_hot_node};
sparse_softmax_node->input(kIndex1), one_hot_node};
auto softmax_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(softmax_node);
softmax_node->set_scope(sparse_softmax_node->scope());
@ -171,6 +173,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
reduce_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<AnfNodePtr> inputs;
if (is_pynative) {
inputs = {NewValueNode(reduce_primitive), softmax_output_node};
@ -199,6 +202,7 @@ CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_no
auto axis_node = NewValueNode(axis);
MS_EXCEPTION_IF_NULL(axis_node);
auto axis_abstract = std::make_shared<abstract::AbstractScalar>();
MS_EXCEPTION_IF_NULL(axis_abstract);
axis_abstract->set_type(kInt64);
axis_node->set_abstract(axis_abstract);
@ -305,6 +309,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax
auto y_node = CreateValueNode(y, kNumberTypeFloat32);
MS_EXCEPTION_IF_NULL(y_node);
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
kernel_graph->AddValueNodeToGraph(y_node);
auto real_div_primitive = std::make_shared<Primitive>(kRealDivOpName);
@ -332,7 +337,7 @@ CNodePtr GetSparseNode(const CNodePtr &depend_node, size_t index) {
CNodePtr GetDependNode(const CNodePtr &mul_node) {
MS_EXCEPTION_IF_NULL(mul_node);
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
auto depend_node = mul_node->input(1);
auto depend_node = mul_node->input(kIndex1);
MS_EXCEPTION_IF_NULL(depend_node);
return depend_node->cast<CNodePtr>();
}
@ -357,6 +362,7 @@ CNodePtr CreateMul(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_nod
MS_EXCEPTION_IF_NULL(y_node);
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
kernel_graph->AddValueNodeToGraph(y_node);
auto mul_primitive = std::make_shared<Primitive>(kMulOpName);
@ -427,7 +433,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
auto depend_node = GetDependNode(mul_node);
auto sparse_softmax_node = GetSparseNode(depend_node, kInput2);
auto sparse_softmax_node = GetSparseNode(depend_node, kIndex2);
auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
@ -441,7 +447,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
CNodePtr real_div_node;
if (tile_node == nullptr) {
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kInput2));
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kIndex2));
} else {
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
}
@ -481,7 +487,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c
auto depend_node = node->cast<CNodePtr>();
auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
auto sparse_softmax_node = GetSparseNode(depend_node, kInput2);
auto sparse_softmax_node = GetSparseNode(depend_node, kIndex2);
CNodePtr softmax_node;
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
@ -541,7 +547,7 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro
MS_EXCEPTION_IF_NULL(mul_node);
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
auto sparse_softmax_node = mul_node->input(1);
auto sparse_softmax_node = mul_node->input(kIndex1);
auto sparse_softmax_node_grad = sparse_softmax_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sparse_softmax_node_grad);
CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
@ -555,7 +561,7 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
CNodePtr real_div_node;
if (tile_node == nullptr) {
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kInput2));
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kIndex2));
} else {
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
}

View File

@ -883,12 +883,14 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
MS_EXCEPTION_IF_NULL(value_node);
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
MS_EXCEPTION_IF_NULL(new_value_node);
new_value_node->set_abstract(value_node->abstract());
// create kernel_info for new value node
auto kernel_info = std::make_shared<device::KernelInfo>();
new_value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
// set value node initial device data type = infer data type
@ -920,6 +922,7 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C
}
}
}
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(prim);
auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();

View File

@ -484,6 +484,35 @@ constexpr auto kUpdateStateRealInput = 2;
// index define of Load
constexpr auto kLoadRealInput = 1;
constexpr auto kLoadStateInput = 2;
// index of input or output
enum Index : size_t {
kIndex0 = 0,
kIndex1,
kIndex2,
kIndex3,
kIndex4,
kIndex5,
kIndex6,
kIndex7,
kIndex8,
kIndex9,
kIndex10,
kIndex11,
kIndex12,
kIndex13,
kIndex14,
kIndex15,
kIndex16,
};
// dim of shape
enum Dim : size_t {
kDim0 = 0,
kDim1,
kDim2,
kDim3,
kDim4,
kDim5,
};
// format
constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
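The Index and Dim enums added above centralize the input-index and shape-dimension constants that the earlier hunks now reference instead of bare literals. A minimal standalone usage sketch (the shape values are hypothetical, only the indexing pattern mirrors the change):
#include <cstddef>
#include <iostream>
#include <vector>
enum Index : size_t { kIndex0 = 0, kIndex1, kIndex2, kIndex3 };
enum Dim : size_t { kDim0 = 0, kDim1, kDim2, kDim3 };
int main() {
  std::vector<size_t> shape = {8, 3, 224, 224};  // hypothetical NCHW shape
  // shape[kDim1] reads as "the channel dimension" rather than a magic shape[1].
  std::cout << "channels: " << shape[kDim1] << ", height: " << shape[kDim2] << "\n";
  return 0;
}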