!18361 code check clean
Merge pull request !18361 from yuchaojie/code-clean
commit 1c991331b9
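The changes below follow one mechanical pattern across the Ascend buffer-fusion and fission passes: literal operand indices and per-file index constants (`1`, `2`, `kInputIndex2`, `INPUT2`, `DIM0`, ...) are replaced with shared `kIndexN`/`kDimN` constants, a few magic numbers gain named constants, and some missing `MS_EXCEPTION_IF_NULL` checks are added. A minimal sketch of the idea follows; the `FakeCNode` type and `main` function are illustrative stand-ins, not MindSpore code, and only the constant names mirror the diff.

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Shared, named indices. In the real change these come from a common header,
// so individual passes no longer redefine kInputIndexN / INPUT2 / DIMN locally.
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;

// Hypothetical stand-in for a CNode: element 0 is the primitive, the real
// operands start at index 1, which is exactly what kIndex1/kIndex2 document.
struct FakeCNode {
  std::vector<std::string> inputs;
  const std::string &input(size_t i) const { return inputs.at(i); }
};

int main() {
  FakeCNode add{{"Add", "%x", "%y"}};
  // Before the clean-up this read add.input(1) and add.input(2).
  std::cout << add.input(kIndex1) << ", " << add.input(kIndex2) << "\n";  // prints: %x, %y
  return 0;
}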
@@ -34,7 +34,7 @@ void BatchMatmulFusedMulAddFusionPass::MatchBatchMatmulFusedMulAdd(const CNodePt
 MS_EXCEPTION_IF_NULL(candidate_fusion);
 auto manager = kernel_graph.manager();
 MS_EXCEPTION_IF_NULL(manager);
-auto batch_matmul = cnode->input(2);
+auto batch_matmul = cnode->input(kIndex2);
 MS_EXCEPTION_IF_NULL(batch_matmul);
 if (batch_matmul->isa<CNode>() && AnfAlgo::CheckPrimitiveType(batch_matmul, prim::kPrimBatchMatMul)) {
 std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[batch_matmul].size())};

@@ -31,8 +31,6 @@ namespace opt {
 namespace {
 constexpr size_t kEltwiseInputSize = 2;
 constexpr size_t kEltwiseOutputSize = 2;
-constexpr size_t kInputIndex1 = 1;
-constexpr size_t kInputIndex2 = 2;
 bool CheckEltwiseInputAndOutputSize(const AnfNodePtr &node) {
 if (AnfAlgo::GetInputTensorNum(node) == kEltwiseInputSize) {
 return true;
@@ -54,7 +52,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
 MS_EXCEPTION_IF_NULL(relu_input);
 auto add = relu_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(add);
-auto tuple_getitem = add->input(1);
+auto tuple_getitem = add->input(kIndex1);
 std::vector<int64_t> add_output_used_num;
 add_output_used_num.emplace_back(SizeToLong(manager->node_users()[add].size()));
 AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(add_output_used_num), add);
@@ -62,7 +60,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
 if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) {
 auto getitem = tuple_getitem->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(getitem);
-auto bnupdate = getitem->input(kInputIndex1);
+auto bnupdate = getitem->input(kIndex1);
 MS_EXCEPTION_IF_NULL(bnupdate);
 if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
 std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
@@ -73,7 +71,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
 }
 auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(out_getitem_ptr);
-auto input2 = out_getitem_ptr->input(kInputIndex2);
+auto input2 = out_getitem_ptr->input(kIndex2);
 auto output_idx = GetValue<int64_t>(GetValueNode(input2));
 output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
 }
@@ -98,7 +96,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
 MS_EXCEPTION_IF_NULL(cnode);
 if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
 AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) {
-auto eltwise_input = cnode->input(1);
+auto eltwise_input = cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(eltwise_input);
 if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimAdd)) {
 MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);

@@ -27,7 +27,6 @@
 
 namespace mindspore {
 namespace opt {
-constexpr size_t INPUT2 = 2;
 void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr &cnode, const AnfNodePtr &eltwise_input,
 const session::KernelGraph &kernel_graph,
 FusedNodeRecord *candidate_fusion) {
@@ -38,7 +37,7 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr
 MS_EXCEPTION_IF_NULL(eltwise_input);
 auto getitem = eltwise_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(getitem);
-auto bnupdate = getitem->input(1);
+auto bnupdate = getitem->input(kIndex1);
 MS_EXCEPTION_IF_NULL(bnupdate);
 if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
 std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
@@ -49,7 +48,7 @@
 }
 auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(out_getitem_ptr);
-auto input2 = out_getitem_ptr->input(INPUT2);
+auto input2 = out_getitem_ptr->input(kIndex2);
 auto output_idx = GetValue<int64_t>(GetValueNode(input2));
 output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
 }

@@ -26,15 +26,12 @@
 
 namespace mindspore {
 namespace opt {
-namespace {
-constexpr size_t kInputIndex2 = 2;
-}
 void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltwise(
 const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
 MS_EXCEPTION_IF_NULL(cnode);
 MS_EXCEPTION_IF_NULL(candidate_fusion);
 std::unordered_set<AnfNodePtr> record{cnode};
-auto eltwise_input = cnode->input(1);
+auto eltwise_input = cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(eltwise_input);
 if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) {
 (void)record.insert(eltwise_input);
@@ -45,7 +42,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
 MS_EXCEPTION_IF_NULL(manager);
 auto input_cnode = eltwise_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(input_cnode);
-auto double_in_eltwise_input = input_cnode->input(kInputIndex2);
+auto double_in_eltwise_input = input_cnode->input(kIndex2);
 MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
 std::vector<int64_t> conv2d_bp_output_used_num;
 if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input)) {
@@ -59,7 +56,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
 conv2d_bp_output_used_num.emplace_back(SizeToLong(manager->node_users()[double_in_eltwise_input].size()));
 AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv2d_bp_output_used_num), double_in_eltwise_input);
 } else {
-auto double_in_eltwise_input_1 = input_cnode->input(1);
+auto double_in_eltwise_input_1 = input_cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(double_in_eltwise_input_1);
 if (!double_in_eltwise_input_1->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input_1)) {
 return;

@@ -32,7 +32,7 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
 MS_EXCEPTION_IF_NULL(cnode);
 MS_EXCEPTION_IF_NULL(candidate_fusion);
 std::unordered_set<AnfNodePtr> record{cnode};
-auto eltwise_input = cnode->input(1);
+auto eltwise_input = cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(eltwise_input);
 if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
 fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {

@@ -34,7 +34,7 @@ void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const sess
 MS_EXCEPTION_IF_NULL(candidate_fusion);
 auto manager = kernel_graph.manager();
 MS_EXCEPTION_IF_NULL(manager);
-auto conv = cnode->input(1);
+auto conv = cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(conv);
 if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) {
 std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[conv].size())};

@@ -31,7 +31,7 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
 MS_EXCEPTION_IF_NULL(cnode);
 MS_EXCEPTION_IF_NULL(candidate_fusion);
 std::unordered_set<AnfNodePtr> record{cnode};
-auto eltwise_input = cnode->input(1);
+auto eltwise_input = cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(eltwise_input);
 if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) {
 (void)record.insert(eltwise_input);
@@ -40,7 +40,7 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
 }
 auto input_cnode = eltwise_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(input_cnode);
-auto double_in_eltwise_input = input_cnode->input(1);
+auto double_in_eltwise_input = input_cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
 if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) ||
 fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) {

@@ -31,12 +31,12 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con
 MS_EXCEPTION_IF_NULL(cnode);
 MS_EXCEPTION_IF_NULL(candidate_fusion);
 std::unordered_set<AnfNodePtr> record{cnode};
-auto eltwise_input = cnode->input(1);
+auto eltwise_input = cnode->input(kIndex1);
 while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
 (void)record.insert(eltwise_input);
 auto input_cnode = eltwise_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(input_cnode);
-eltwise_input = input_cnode->input(1);
+eltwise_input = input_cnode->input(kIndex1);
 if (record.size() == MAX_ELTWISE_NUM) {
 break;
 }

@@ -37,7 +37,7 @@ void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnod
 MS_EXCEPTION_IF_NULL(manager);
 if (is_order) {
 // DepthwiseConvolution--->Elemwise
-auto depthwise_conv = cnode->input(1);
+auto depthwise_conv = cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(depthwise_conv);
 if (cnode->isa<CNode>() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) {
 std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[depthwise_conv].size())};
@@ -48,7 +48,7 @@ void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnod
 }
 } else {
 // Elemwise-->DepthwiseConvolution
-auto relu = cnode->input(1);
+auto relu = cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(relu);
 if (cnode->isa<CNode>() && (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) {
 std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[relu].size())};
@@ -73,7 +73,7 @@ void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::Ker
 MS_EXCEPTION_IF_NULL(cnode);
 if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
 AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
-auto eltwise_input = cnode->input(1);
+auto eltwise_input = cnode->input(kIndex1);
 if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) {
 MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
 }

@@ -31,7 +31,7 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne
 MS_EXCEPTION_IF_NULL(cnode);
 MS_EXCEPTION_IF_NULL(candidate_fusion);
 std::unordered_set<AnfNodePtr> record{cnode};
-auto eltwise_input = cnode->input(1);
+auto eltwise_input = cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(eltwise_input);
 while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
 (void)record.insert(eltwise_input);
@@ -40,7 +40,7 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne
 }
 auto input_cnode = eltwise_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(input_cnode);
-eltwise_input = input_cnode->input(1);
+eltwise_input = input_cnode->input(kIndex1);
 }
 if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) {
 (void)record.insert(eltwise_input);

@@ -34,7 +34,7 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode
 MS_EXCEPTION_IF_NULL(candidate_fusion);
 auto manager = kernel_graph.manager();
 MS_EXCEPTION_IF_NULL(manager);
-auto matmul = cnode->input(1);
+auto matmul = cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(matmul);
 if (matmul->isa<CNode>() && (AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) ||
 AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimBatchMatMul))) {

@@ -56,7 +56,7 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
 MS_EXCEPTION_IF_NULL(cnode);
 if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
 AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
-auto eltwise_input = cnode->input(1);
+auto eltwise_input = cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(eltwise_input);
 if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) {
 MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);

@@ -33,7 +33,7 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
 auto manager = kernel_graph.manager();
 MS_EXCEPTION_IF_NULL(manager);
 std::unordered_set<AnfNodePtr> record{cnode};
-auto eltwise_input = cnode->input(1);
+auto eltwise_input = cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(eltwise_input);
 if (CheckMultiOutputEltWiseNode(kernel_graph, eltwise_input)) {
 std::vector<int64_t> output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())};
@@ -41,7 +41,7 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
 (void)record.insert(eltwise_input);
 auto input_cnode = eltwise_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(input_cnode);
-eltwise_input = input_cnode->input(1);
+eltwise_input = input_cnode->input(kIndex1);
 } else {
 return;
 }
@@ -52,7 +52,7 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
 }
 auto input_cnode = eltwise_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(input_cnode);
-eltwise_input = input_cnode->input(1);
+eltwise_input = input_cnode->input(kIndex1);
 }
 if (record.size() != MULTI_ELTWISE_SIZE) {
 return;

@@ -32,12 +32,12 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
 MS_EXCEPTION_IF_NULL(cnode);
 MS_EXCEPTION_IF_NULL(candidate_fusion);
 std::unordered_set<AnfNodePtr> record{cnode};
-auto eltwise_input = cnode->input(1);
+auto eltwise_input = cnode->input(kIndex1);
 while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
 (void)record.insert(eltwise_input);
 auto input_cnode = eltwise_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(input_cnode);
-eltwise_input = input_cnode->input(1);
+eltwise_input = input_cnode->input(kIndex1);
 if (record.size() == MAX_ELTWISE_NUM) {
 break;
 }
@@ -52,13 +52,13 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
 (void)record.insert(eltwise_input);
 auto previous_input_cnode = eltwise_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(previous_input_cnode);
-auto previous_eltwise_input = previous_input_cnode->input(1);
+auto previous_eltwise_input = previous_input_cnode->input(kIndex1);
 auto previous_size = record.size();
 while (CheckEltWiseNode(kernel_graph, previous_eltwise_input)) {
 (void)record.insert(previous_eltwise_input);
 auto previous_node = previous_eltwise_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(previous_node);
-previous_eltwise_input = previous_node->input(1);
+previous_eltwise_input = previous_node->input(kIndex1);
 if (record.size() - previous_size == MAX_ELTWISE_NUM) {
 break;
 }

@@ -31,12 +31,12 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
 MS_EXCEPTION_IF_NULL(cnode);
 MS_EXCEPTION_IF_NULL(candidate_fusion);
 std::unordered_set<AnfNodePtr> record{cnode};
-auto eltwise_input = cnode->input(1);
+auto eltwise_input = cnode->input(kIndex1);
 while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
 (void)record.insert(eltwise_input);
 auto input_cnode = eltwise_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(input_cnode);
-eltwise_input = input_cnode->input(1);
+eltwise_input = input_cnode->input(kIndex1);
 if (record.size() == MAX_ELTWISE_NUM) {
 break;
 }
@@ -51,13 +51,13 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
 (void)record.insert(eltwise_input);
 auto previous_input_cnode = eltwise_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(previous_input_cnode);
-auto previous_eltwise_input = previous_input_cnode->input(1);
+auto previous_eltwise_input = previous_input_cnode->input(kIndex1);
 auto previous_size = record.size();
 while (CheckEltWiseNode(kernel_graph, previous_eltwise_input)) {
 (void)record.insert(previous_eltwise_input);
 auto previous_node = previous_eltwise_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(previous_node);
-previous_eltwise_input = previous_node->input(1);
+previous_eltwise_input = previous_node->input(kIndex1);
 if (record.size() - previous_size == MAX_ELTWISE_NUM) {
 break;
 }

@@ -34,12 +34,12 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
 MS_EXCEPTION_IF_NULL(cnode);
 MS_EXCEPTION_IF_NULL(candidate_fusion);
 std::unordered_set<AnfNodePtr> record{cnode};
-auto write_input = cnode->input(1);
+auto write_input = cnode->input(kIndex1);
 if (CheckEltWiseNode(kernel_graph, write_input)) {
 (void)record.insert(write_input);
 auto input_cnode = write_input->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(input_cnode);
-write_input = input_cnode->input(1);
+write_input = input_cnode->input(kIndex1);
 }
 MS_EXCEPTION_IF_NULL(write_input);
 if (!write_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(write_input) ||
@@ -53,7 +53,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
 conv_cnode->inputs().size() >= CONV_DOUBLE_IN_INPUT_SIZE &&
 conv_cnode->inputs().size() <= CONV_QUART_IN_INPUT_SIZE) {
 (void)record.insert(write_input);
-auto conv_input = conv_cnode->input(1);
+auto conv_input = conv_cnode->input(kIndex1);
 MS_EXCEPTION_IF_NULL(conv_input);
 if (!conv_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(conv_input) ||
 fusion_id_allocator->HasFusionIdAttr(conv_input)) {

@@ -39,8 +39,6 @@ const int8_t ELTWISE_USE = 1;
 const int8_t MULTI_ELTWISE_USE = 2;
 const int8_t MAX_MULTI_ELTWISE_SIZE = 4;
 const int8_t MAX_PURE_BUFFER_SUCC_SIZE = 3;
-constexpr size_t kInputIndex1 = 1;
-constexpr size_t kInputIndex2 = 2;
 constexpr size_t kFusionNodeNumThreshold = 2;
 constexpr auto kOpAttrFusionId = "fusion_id";
 
@@ -132,12 +130,10 @@ kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr
 if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
 auto tuple_getitem = output->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(tuple_getitem);
-outputs_format.emplace_back(
-AnfAlgo::GetOutputFormat(tuple_getitem->input(kInputIndex1),
-LongToSize(GetValue<int64_t>(GetValueNode(tuple_getitem->input(kInputIndex2))))));
+outputs_format.emplace_back(AnfAlgo::GetOutputFormat(
+tuple_getitem->input(kIndex1), LongToSize(GetValue<int64_t>(GetValueNode(tuple_getitem->input(kIndex2))))));
 outputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(
-tuple_getitem->input(kInputIndex1),
-LongToSize(GetValue<int64_t>(GetValueNode(tuple_getitem->input(kInputIndex2))))));
+tuple_getitem->input(kIndex1), LongToSize(GetValue<int64_t>(GetValueNode(tuple_getitem->input(kIndex2))))));
 } else {
 outputs_format.emplace_back(AnfAlgo::GetOutputFormat(output, 0));
 outputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(output, 0));
@@ -190,6 +186,7 @@ void ReplaceInputNodeInOtherFusionScope(std::unordered_map<int64_t, BufferFusion
 void ReplaceOldNode(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos, int64_t fusion_id,
 const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph) {
+MS_EXCEPTION_IF_NULL(kernel_graph);
 MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
 auto manager = kernel_graph->manager();
 MS_EXCEPTION_IF_NULL(manager);
 auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
@@ -275,8 +272,8 @@ bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) {
 MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1["
 << getitem2->DebugString() << "]";
 }
-auto output_idx1 = GetValue<int64_t>(GetValueNode(getitem1->input(kInputIndex2)));
-auto output_idx2 = GetValue<int64_t>(GetValueNode(getitem2->input(kInputIndex2)));
+auto output_idx1 = GetValue<int64_t>(GetValueNode(getitem1->input(kIndex2)));
+auto output_idx2 = GetValue<int64_t>(GetValueNode(getitem2->input(kIndex2)));
 return output_idx1 < output_idx2;
 }
 
@@ -311,7 +308,8 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
 for (auto &getitem : tuple_getitem_nodes) {
 MS_EXCEPTION_IF_NULL(getitem);
 auto getitem_ptr = getitem->cast<CNodePtr>();
-auto input2 = getitem_ptr->input(kInputIndex2);
+MS_EXCEPTION_IF_NULL(getitem_ptr);
+auto input2 = getitem_ptr->input(kIndex2);
 auto output_idx = GetValue<int64_t>(GetValueNode(input2));
 for (int64_t stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) {
 auto stub_node = CreateTupleGetItem(node, kernel_graph, LongToSize(stub_idx));
@@ -343,7 +341,7 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A
 auto real_output = AnfAlgo::VisitKernel(output, 0);
 auto output_cnode = output->cast<CNodePtr>();
 MS_EXCEPTION_IF_NULL(output_cnode);
-auto input2 = output_cnode->input(kInputIndex2);
+auto input2 = output_cnode->input(kIndex2);
 auto output_idx = GetValue<int64_t>(GetValueNode(input2));
 session::AnfWithOutIndex out_pair(real_output.first, output_idx);
 if (kernel_graph->IsInRefOutputMap(out_pair)) {
@@ -398,6 +396,7 @@ void RemoveCircle(const session::KernelGraph &kernel_graph,
 void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
 std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) const {
 MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
+MS_EXCEPTION_IF_NULL(kernel_graph);
 GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos);
 GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos);
 GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);

@@ -27,9 +27,6 @@ namespace {
 const std::vector<int64_t> kOutputIndex{0, 3, 4, 5};
 constexpr size_t kBatchNormRealOutputNum = 3;
 constexpr size_t kBatchNormRealInputNum = 3;
-constexpr size_t kInputIndex2 = 2;
-constexpr size_t kInputIndex3 = 3;
-constexpr size_t kInputIndex4 = 4;
 
 bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
 MS_EXCEPTION_IF_NULL(func_graph);
@@ -72,12 +69,12 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
 << (kBatchNormRealInputNum + 1) << " trace: " << trace::DumpSourceLines(bn);
 }
 std::vector<AnfNodePtr> bn_training_reduce_inputs = {
-NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)};
+NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(kIndex1)};
 auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
 MS_EXCEPTION_IF_NULL(bn_training_reduce);
-auto bn_input1 = bn_cnode->input(kInputIndex2);
+auto bn_input1 = bn_cnode->input(kIndex2);
 MS_EXCEPTION_IF_NULL(bn_input1);
-auto bn_input2 = bn_cnode->input(kInputIndex3);
+auto bn_input2 = bn_cnode->input(kIndex3);
 MS_EXCEPTION_IF_NULL(bn_input2);
 AbstractBasePtrList abstract_list{bn_input1->abstract(), bn_input2->abstract()};
 auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
@@ -104,11 +101,11 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
 }
 std::vector<AnfNodePtr> bn_training_update_v2_inputs = {
 NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateV2OpName)),
-bn_cnode->input(1),
-bn_training_reduce_outputs[0],
-bn_training_reduce_outputs[1],
-bn_cnode->input(kInputIndex2),
-bn_cnode->input(kInputIndex3)};
+bn_cnode->input(kIndex1),
+bn_training_reduce_outputs[kIndex0],
+bn_training_reduce_outputs[kIndex1],
+bn_cnode->input(kIndex2),
+bn_cnode->input(kIndex3)};
 auto bn_training_update_v2 = func_graph->NewCNode(bn_training_update_v2_inputs);
 MS_EXCEPTION_IF_NULL(bn_training_update_v2);
 
@@ -118,9 +115,9 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
 MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is "
 << bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn);
 }
-std::vector<AbstractBasePtr> abstract_list{bn_abstract_tuple->elements()[0],
-bn_abstract_tuple->elements()[kInputIndex3],
-bn_abstract_tuple->elements()[kInputIndex4]};
+std::vector<AbstractBasePtr> abstract_list{bn_abstract_tuple->elements()[kIndex0],
+bn_abstract_tuple->elements()[kIndex3],
+bn_abstract_tuple->elements()[kIndex4]};
 auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
 bn_training_update_v2->set_abstract(abstract_tuple);
 bn_training_update_v2->set_scope(bn->scope());

@@ -23,8 +23,6 @@ namespace mindspore {
 namespace opt {
 namespace {
 constexpr size_t kBatchNormGradInferOutputNum = 3;
-constexpr size_t kElementIndex1 = 1;
-constexpr size_t kElementIndex2 = 2;
 bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
 MS_EXCEPTION_IF_NULL(func_graph);
 MS_EXCEPTION_IF_NULL(node);
@@ -136,8 +134,8 @@ AnfNodePtr BatchNormGradInferFission::CreateBNTrainingUpdateGrad(const FuncGraph
 MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be less than 3"
 << trace::DumpSourceLines(bn_grad);
 }
-std::vector<AbstractBasePtr> abstract_list{bn_grad_abstract_tuple->elements()[kElementIndex1],
-bn_grad_abstract_tuple->elements()[kElementIndex2]};
+std::vector<AbstractBasePtr> abstract_list{bn_grad_abstract_tuple->elements()[kIndex1],
+bn_grad_abstract_tuple->elements()[kIndex2]};
 auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
 bn_training_update_grad->set_abstract(abstract_tuple);
 AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_training_update_grad);

@@ -28,13 +28,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-constexpr size_t kIndex0 = 0;
-constexpr size_t kIndex1 = 1;
-constexpr size_t kIndex2 = 2;
-constexpr size_t kIndex3 = 3;
-constexpr size_t kIndex4 = 4;
-constexpr size_t kIndex5 = 5;
-
 void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
 std::vector<AnfNodePtr> *bn_update_grad_outputs) {
 MS_EXCEPTION_IF_NULL(graph);

@@ -38,7 +38,7 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
 new_simoid_inputs.insert(new_simoid_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
 CNodePtr new_cnode = func_graph->NewCNode(new_simoid_inputs);
 MS_EXCEPTION_IF_NULL(new_cnode);
-auto predict_input = cnode->inputs()[1];
+auto predict_input = cnode->inputs()[kIndex1];
 auto new_node_dtype = {AnfAlgo::GetOutputInferDataType(predict_input, 0)};
 auto new_node_shape = {AnfAlgo::GetOutputInferShape(predict_input, 0)};
 AnfAlgo::SetOutputInferTypeAndShape(new_node_dtype, new_node_shape, new_cnode.get());

@@ -29,12 +29,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-constexpr size_t kIndex0 = 0;
-constexpr size_t kIndex1 = 1;
-constexpr size_t kIndex2 = 2;
-constexpr size_t kIndex3 = 3;
-constexpr size_t kIndex4 = 4;
-constexpr size_t kIndex5 = 5;
 void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
 std::vector<AnfNodePtr> *bn_update_grad_outputs) {
 MS_EXCEPTION_IF_NULL(graph);

@@ -32,12 +32,6 @@ namespace opt {
 namespace {
 constexpr auto kReduceOpSum = "sum";
 constexpr auto kDeviceNum = "device_num";
-constexpr size_t kIndex0 = 0;
-constexpr size_t kIndex1 = 1;
-constexpr size_t kIndex2 = 2;
-constexpr size_t kIndex3 = 3;
-constexpr size_t kIndex4 = 4;
-constexpr size_t kIndex5 = 5;
 constexpr size_t kPositionOffset = 3;
 constexpr int64_t kFusionNumThreshold = 2;
 
@@ -51,7 +45,7 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
 }
 std::vector<AnfNodePtr> bn_training_reduce_inputs = {
 NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName))};
-bn_training_reduce_inputs.push_back(bn_cnode->input(1));
+bn_training_reduce_inputs.push_back(bn_cnode->input(kIndex1));
 auto bn_training_reduce = graph->NewCNode(bn_training_reduce_inputs);
 MS_EXCEPTION_IF_NULL(bn_training_reduce);
 auto kernel_info = std::make_shared<device::KernelInfo>();
@@ -62,7 +56,7 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
 MS_LOG(INFO) << "The BatchNorm's first input's shape dims less than " << kShape2dDims;
 return false;
 }
-std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[1]};
+std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[kDim1]};
 auto types = {kNumberTypeFloat32, kNumberTypeFloat32};
 auto shapes = {bn_training_reduce_shape, bn_training_reduce_shape};
 AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_training_reduce.get());
@@ -191,6 +185,8 @@ AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const
 }
 
 AnfNodePtr InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &input, const TypeId dst_type) {
+MS_EXCEPTION_IF_NULL(graph);
+MS_EXCEPTION_IF_NULL(input);
 if (AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) {
 AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input});
 AnfAlgo::SetOutputInferTypeAndShape({dst_type}, {AnfAlgo::GetOutputInferShape(input, 0)}, cast.get());

@@ -31,22 +31,6 @@ constexpr size_t kGRUV2HiddenGradOutputNum = 3;
 constexpr size_t kConcatNum = 2;
 constexpr size_t kGateNum = 3;
 constexpr size_t k3Dims = 3;
-constexpr size_t kIndex0 = 0;
-constexpr size_t kIndex1 = 1;
-constexpr size_t kIndex2 = 2;
-constexpr size_t kIndex3 = 3;
-constexpr size_t kIndex4 = 4;
-constexpr size_t kIndex5 = 5;
-constexpr size_t kIndex6 = 6;
-constexpr size_t kIndex7 = 7;
-constexpr size_t kIndex8 = 8;
-constexpr size_t kIndex9 = 9;
-constexpr size_t kIndex10 = 10;
-constexpr size_t kIndex11 = 11;
-constexpr size_t kIndex12 = 12;
-constexpr size_t DIM0 = 0;
-constexpr size_t DIM1 = 1;
-constexpr size_t DIM2 = 2;
 
 AnfNodePtr CreateGRUV2HiddenGradNode(const FuncGraphPtr &graph, const AnfNodePtr &node) {
 MS_EXCEPTION_IF_NULL(graph);
@@ -74,9 +58,9 @@ AnfNodePtr CreateGRUV2HiddenGradNode(const FuncGraphPtr &graph, const AnfNodePtr
 auto types = {h_dtype, h_dtype, h_dtype};
 std::vector<size_t> dh_preh_shape = AnfAlgo::GetOutputInferShape(ori_outputs[kIndex5], 0);
 std::vector<size_t> dgate_h_shape = {
-AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[DIM0],
-AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[DIM1],
-kGateNum * AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[DIM2]};
+AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[kDim0],
+AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[kDim1],
+kGateNum * AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0)[kDim2]};
 std::vector<size_t> dnx_t_shape = AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[kIndex6], 0);
 auto shapes = {dh_preh_shape, dgate_h_shape, dnx_t_shape};
 AnfAlgo::SetOutputInferTypeAndShape(types, shapes, gru_v2_hidden_grad_op.get());
@@ -93,14 +77,14 @@ AnfNodePtr CreateHSplitVDNode(const FuncGraphPtr &graph, const AnfNodePtr &node)
 auto split_vd = graph->NewCNode(splitvd_input);
 MS_EXCEPTION_IF_NULL(split_vd);
 auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)};
-size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[DIM0];
-size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[DIM1];
-size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[DIM2];
+size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[kDim0];
+size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[kDim1];
+size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[kDim2];
 std::vector<size_t> shape = {t_size - IntToSize(1), batch, hidden_size};
 std::vector<size_t> shape2 = {IntToSize(1), batch, hidden_size};
 std::vector<std::vector<size_t>> shapes = {shape, shape2};
 AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
-AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(DIM0)), split_vd);
+AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(kDim0)), split_vd);
 AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToLong(kSplitVOutputNum)), split_vd);
 std::vector<int64_t> size_splits = {SizeToLong(t_size - 1), SizeToLong(1)};
 AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split_vd);
@@ -116,7 +100,7 @@ AnfNodePtr CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) {
 if (ori_shape.size() == k3Dims) {
 shape_tmp = {ori_shape};
 } else {
-shape_tmp = {{IntToSize(1), ori_shape[DIM0], ori_shape[DIM1]}};
+shape_tmp = {{IntToSize(1), ori_shape[kDim0], ori_shape[kDim1]}};
 }
 auto ori_dtype = {AnfAlgo::GetOutputInferDataType(node, 0)};
 // reshape
@@ -140,14 +124,14 @@ AnfNodePtr CreateHConcatDNode(const FuncGraphPtr &graph, const AnfNodePtr &node1
 auto concat_op = graph->NewCNode(concat_inputs);
 MS_EXCEPTION_IF_NULL(concat_op);
 
-std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[DIM0] + 1,
-AnfAlgo::GetOutputInferShape(node2, 0)[DIM1],
-AnfAlgo::GetOutputInferShape(node2, 0)[DIM2]};
+std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[kDim0] + 1,
+AnfAlgo::GetOutputInferShape(node2, 0)[kDim1],
+AnfAlgo::GetOutputInferShape(node2, 0)[kDim2]};
 auto types = {AnfAlgo::GetOutputInferDataType(node2, 0)};
 AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, concat_op.get());
 AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kConcatNum)), concat_op);
 AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat_op);
-AnfAlgo::SetNodeAttr("axis", MakeValue(SizeToLong(0)), concat_op);
+AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), concat_op);
 AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
 return concat_op;
 }
@@ -160,14 +144,14 @@ AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &graph, const AnfNodePtr &
 auto split_vd = graph->NewCNode(splitvd_input);
 MS_EXCEPTION_IF_NULL(split_vd);
 auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)};
-size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[DIM0];
-size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[DIM1];
-size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[DIM2] / kGateNum;
+size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[kDim0];
+size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[kDim1];
+size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[kDim2] / kGateNum;
 std::vector<size_t> shape = {t_size, batch, hidden_size << 1};
 std::vector<size_t> shape2 = {t_size, batch, hidden_size};
 std::vector<std::vector<size_t>> shapes = {shape, shape2};
 AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
-AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(DIM2)), split_vd);
+AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(kDim2)), split_vd);
 AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToLong(kSplitVOutputNum)), split_vd);
 std::vector<int64_t> size_splits = {SizeToLong(hidden_size + hidden_size), SizeToLong(hidden_size)};
 AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split_vd);
@@ -190,13 +174,13 @@ AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &graph, const AnfNodePtr &
 auto concat_op = graph->NewCNode(concat_inputs);
 MS_EXCEPTION_IF_NULL(concat_op);
 std::vector<size_t> shape = {
-AnfAlgo::GetOutputInferShape(node2, 0)[DIM0], AnfAlgo::GetOutputInferShape(node2, 0)[DIM1],
-AnfAlgo::GetOutputInferShape(node1, 0)[DIM2] + AnfAlgo::GetOutputInferShape(node2, 0)[DIM2]};
+AnfAlgo::GetOutputInferShape(node2, 0)[kDim0], AnfAlgo::GetOutputInferShape(node2, 0)[kDim1],
+AnfAlgo::GetOutputInferShape(node1, 0)[kDim2] + AnfAlgo::GetOutputInferShape(node2, 0)[kDim2]};
 auto types = {AnfAlgo::GetOutputInferDataType(node2, 0)};
 AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, concat_op.get());
 AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kConcatNum)), concat_op);
 AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat_op);
-AnfAlgo::SetNodeAttr("axis", MakeValue(SizeToLong(DIM2)), concat_op);
+AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(kDim2)), concat_op);
 AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
 return concat_op;
 }
@@ -211,15 +195,15 @@ AnfNodePtr CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &
 std::vector<AnfNodePtr> braodcast_to_input = {NewValueNode(std::make_shared<Primitive>(kBroadcastToOpName)), node1};
 auto broadcast_to_d = graph->NewCNode(braodcast_to_input);
 MS_EXCEPTION_IF_NULL(broadcast_to_d);
-size_t t_size = AnfAlgo::GetOutputInferShape(node2, 0)[DIM0];
-size_t batch = AnfAlgo::GetOutputInferShape(node1, 0)[DIM0];
-size_t gate_size = AnfAlgo::GetOutputInferShape(node1, 0)[DIM1];
+size_t t_size = AnfAlgo::GetOutputInferShape(node2, 0)[kDim0];
+size_t batch = AnfAlgo::GetOutputInferShape(node1, 0)[kDim0];
+size_t gate_size = AnfAlgo::GetOutputInferShape(node1, 0)[kDim1];
 std::vector<size_t> shape = {t_size, batch, gate_size};
 auto type = {AnfAlgo::GetOutputInferDataType(node1, 0)};
 AnfAlgo::SetOutputInferTypeAndShape(type, {shape}, broadcast_to_d.get());
 
 std::vector<int64_t> attr_shape = {SizeToLong(t_size), SizeToLong(batch), SizeToLong(gate_size)};
-AnfAlgo::SetNodeAttr("shape", MakeValue(attr_shape), broadcast_to_d);
+AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(attr_shape), broadcast_to_d);
 AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), broadcast_to_d);
 return broadcast_to_d;
 }
@@ -233,9 +217,9 @@ AnfNodePtr CreateDhxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &nod
 node1, node2};
 auto batch_matmul = graph->NewCNode(matmul_inputs);
 MS_EXCEPTION_IF_NULL(batch_matmul);
-std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[DIM0],
-AnfAlgo::GetOutputInferShape(node1, 0)[DIM2],
-AnfAlgo::GetOutputInferShape(node2, 0)[DIM2]};
+std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[kDim0],
+AnfAlgo::GetOutputInferShape(node1, 0)[kDim2],
+AnfAlgo::GetOutputInferShape(node2, 0)[kDim2]};
 AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {shape}, batch_matmul.get());
 AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
 AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
@@ -252,9 +236,9 @@ AnfNodePtr CreateDwhBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &nod
 node1, node2};
 auto batch_matmul = graph->NewCNode(matmul_inputs);
 MS_EXCEPTION_IF_NULL(batch_matmul);
-std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[DIM0],
-AnfAlgo::GetOutputInferShape(node1, 0)[DIM1],
-AnfAlgo::GetOutputInferShape(node2, 0)[DIM1]};
+std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[kDim0],
+AnfAlgo::GetOutputInferShape(node1, 0)[kDim1],
+AnfAlgo::GetOutputInferShape(node2, 0)[kDim1]};
 AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {shape}, batch_matmul.get());
 AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), batch_matmul);
 AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), batch_matmul);
@@ -290,7 +274,7 @@ AnfNodePtr CreateDbReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &n
 MS_EXCEPTION_IF_NULL(reduce_sumd);
 
 auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
-std::vector<size_t> shape = {kGateNum * AnfAlgo::GetOutputInferShape(node2, 0)[DIM1]};
+std::vector<size_t> shape = {kGateNum * AnfAlgo::GetOutputInferShape(node2, 0)[kDim1]};
 AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, reduce_sumd.get());
 AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0, 1}), reduce_sumd);
 AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
@@ -322,7 +306,7 @@ const AnfNodePtr DynamicGRUV2GradFission::Process(const FuncGraphPtr &func_graph
 auto gru_v2_gru_hidden = CreateGRUV2HiddenGradNode(func_graph, dynamic_gru_v2_grad_cnode);
 std::vector<AnfNodePtr> gru_hidden_outputs;
 CreateMultipleOutputsOfAnfNode(func_graph, gru_v2_gru_hidden, kGRUV2HiddenGradOutputNum, &gru_hidden_outputs);
-size_t step_num = AnfAlgo::GetOutputInferShape(ori_inputs[kIndex1], 0)[DIM0];
+size_t step_num = AnfAlgo::GetOutputInferShape(ori_inputs[kIndex1], 0)[kDim0];
 AnfNodePtr dwh_batch_matmul = nullptr;
 if (step_num != 1) {
 // split h

@@ -36,9 +36,9 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
 std::vector<AnfNodePtr> matmul_nodes;
 std::vector<AnfNodePtr> split_nodes;
 // Get the size of t
-auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(11), 0);
-size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(9), 0)[0];
-auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(12), 0);
+auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex11), 0);
+size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex9), 0)[0];
+auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex12), 0);
 
 for (size_t i = 0; i < t_size; ++i) {
 // Create basic_lstm_cell_c_state_grad
@@ -46,15 +46,16 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
 NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
 auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
 
-std::vector<size_t> output0_dims{origin_input9_shape[0], 4 * (((origin_input9_shape[1] + 15) / 16) * 16)};
-std::vector<size_t> output1_dims{input_i_shape[1], input_i_shape[2]};
+std::vector<size_t> output0_dims{origin_input9_shape[kDim0],
+4 * (((origin_input9_shape[kDim1] + kCubeSize - 1) / kCubeSize) * kCubeSize)};
+std::vector<size_t> output1_dims{input_i_shape[kDim1], input_i_shape[kDim2]};
 AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16, kNumberTypeFloat32}, {output0_dims, output1_dims},
 basic_lstm_cell_c_state_grad.get());
 AnfAlgo::SetNodeAttr("forget_bias", MakeValue(1.0f), basic_lstm_cell_c_state_grad);
 AnfAlgo::SetNodeAttr("activation", MakeValue("Tanh"), basic_lstm_cell_c_state_grad);
 
 // Create matmul
-auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(2), 0);
+auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex2), 0);
 std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
 auto matmul = func_graph->NewCNode(matmul_inputs);
 AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{IntToSize(1), output0_dims[0], origin_input1_shape[0]}},
@@ -67,14 +68,15 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
 auto split_v = func_graph->NewCNode(splitv_input);
 auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2);
 auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 3);
-std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[1], origin_output2_shape[2]};
-std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[0], origin_output3_shape[1]};
+std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[kDim1], origin_output2_shape[kDim2]};
+std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[kDim0], origin_output3_shape[kDim1]};
 AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},
 {split_v_output0_shape, split_v_output1_shape}, split_v.get());
 
 AnfAlgo::SetNodeAttr(kAttrSizeSplits,
-MakeValue(std::vector<int64_t>{SizeToLong((origin_output2_shape[2] + 15) / 16 * 16),
-SizeToLong((origin_output3_shape[1] + 15) / 16 * 16)}),
+MakeValue(std::vector<int64_t>{
+SizeToLong((origin_output2_shape[kDim2] + kCubeSize - 1) / kCubeSize * kCubeSize),
+SizeToLong((origin_output3_shape[kDim1] + kCubeSize - 1) / kCubeSize * kCubeSize)}),
 split_v);
 AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(2)), split_v);
 AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(2)), split_v);
@@ -107,10 +109,10 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
 std::vector<std::vector<AnfNodePtr>> result_nodes;
 CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);
 
-auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0);
+auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0);
 std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};
 
-auto origin_input7 = dynamic_rnn_grad_cnode->input(8);
+auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
 size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
 std::vector<std::vector<size_t>> split_shapes;
 std::vector<TypeId> split_types;
@@ -126,47 +128,47 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
 CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_c, num_split_x, &lstm_split_c_outputs);
 
 // Create lstm_split_dy
-auto lstm_split_dy =
-CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(9), split_shapes, split_types, size_split, num_split_x);
+auto lstm_split_dy = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex9), split_shapes, split_types,
+size_split, num_split_x);
 std::vector<AnfNodePtr> lstm_split_dy_outputs;
 CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_dy, num_split_x, &lstm_split_dy_outputs);
 
 // Create lstm_split_i
-auto lstm_split_i =
-CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(12), split_shapes, split_types, size_split, num_split_x);
+auto lstm_split_i = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex12), split_shapes, split_types,
+size_split, num_split_x);
 std::vector<AnfNodePtr> lstm_split_i_outputs;
 CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_i, num_split_x, &lstm_split_i_outputs);
 
 // Create lstm_split_j
-auto lstm_split_j =
-CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(13), split_shapes, split_types, size_split, num_split_x);
+auto lstm_split_j = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex13), split_shapes, split_types,
+size_split, num_split_x);
 std::vector<AnfNodePtr> lstm_split_j_outputs;
 CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_j, num_split_x, &lstm_split_j_outputs);
 
 // Create lstm_split_f
-auto lstm_split_f =
-CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(14), split_shapes, split_types, size_split, num_split_x);
+auto lstm_split_f = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex14), split_shapes, split_types,
+size_split, num_split_x);
 std::vector<AnfNodePtr> lstm_split_f_outputs;
 CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_f, num_split_x, &lstm_split_f_outputs);
 
 // Create lstm_split_o
-auto lstm_split_o =
-CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(15), split_shapes, split_types, size_split, num_split_x);
+auto lstm_split_o = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex15), split_shapes, split_types,
+size_split, num_split_x);
 std::vector<AnfNodePtr> lstm_split_o_outputs;
 CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_o, num_split_x, &lstm_split_o_outputs);
 
 // Create lstm_split_tanh
-auto lstm_split_tanh =
-CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(16), split_shapes, split_types, size_split, num_split_x);
+auto lstm_split_tanh = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex16), split_shapes,
+split_types, size_split, num_split_x);
 std::vector<AnfNodePtr> lstm_split_tanh_outputs;
 CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_tanh, num_split_x, &lstm_split_tanh_outputs);
 
 // Add edges
 std::vector<AnfNodePtr> pre_basic_lstm_cell_c_state_grad_outputs;
 std::vector<AnfNodePtr> pre_split_outputs;
-auto basic_lstm_cell_c_state_grad_nodes = result_nodes[0];
-auto matmul_nodes = result_nodes[1];
-auto split_nodes = result_nodes[2];
+auto basic_lstm_cell_c_state_grad_nodes = result_nodes[kIndex0];
+auto matmul_nodes = result_nodes[kIndex1];
+auto split_nodes = result_nodes[kIndex2];
 std::vector<AnfNodePtr> lstm_x_concat_input(num_split_x + 1);
 lstm_x_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));
 std::vector<AnfNodePtr> lstm_gage_concat_input(num_split_x + 1);
@@ -181,8 +183,9 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
 std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
 dynamic_rnn_grad_cnode->input(6)};
 auto reshape = func_graph->NewCNode(reshape_inputs);
-auto reshape_out_shape = {IntToSize(1), AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0)[0],
-AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(6), 0)[1]};
+auto reshape_out_shape = {IntToSize(1),
+AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0)[0],
+AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0)[1]};
 AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {reshape_out_shape}, reshape.get());
 basic_lstm_cell_c_state_grad_inputs.emplace_back(reshape);
 } else {
@@ -190,8 +193,8 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
 }
 basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_dy_outputs[idx]);
 if (i == 0) {
-basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(10));
-basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(11));
+basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex10));
+basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex11));
 } else {
 basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_split_outputs[1]);
 basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_basic_lstm_cell_c_state_grad_outputs[1]);
@@ -213,7 +216,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
 // Create MatMul
 std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
 matmul_inputs.emplace_back(basic_lstm_cell_c_state_grad_outputs[0]);
-matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(2));
+matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex2));
 auto matmul = func_graph->NewCNode(matmul_inputs);
 MS_EXCEPTION_IF_NULL(matmul);
 matmul->set_abstract(matmul_nodes[i]->abstract());
@@ -237,7 +240,8 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
 auto basic_lstm_cell_c_state_grad_outputs_0_shape =
 AnfAlgo::GetOutputInferShape(basic_lstm_cell_c_state_grad_outputs[0], 0);
 std::vector<size_t> temp_shape;
-if (basic_lstm_cell_c_state_grad_outputs_0_shape.size() == 3) {
+constexpr size_t kBasicLstmCStateGradOutput0DimNum = 3;
+if (basic_lstm_cell_c_state_grad_outputs_0_shape.size() == kBasicLstmCStateGradOutput0DimNum) {
 temp_shape = basic_lstm_cell_c_state_grad_outputs_0_shape;
 } else {
 temp_shape = {1, basic_lstm_cell_c_state_grad_outputs_0_shape[0],
@@ -262,9 +266,9 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
 // Create lstm_gage_concat
 auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input);
 auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
-AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16},
-{{origin_input7_shape[0], origin_input7_shape[1], 4 * origin_input7_shape[2]}},
-lstm_gage_concat.get());
+AnfAlgo::SetOutputInferTypeAndShape(
+{kNumberTypeFloat16}, {{origin_input7_shape[kDim0], origin_input7_shape[kDim1], 4 * origin_input7_shape[kDim2]}},
+lstm_gage_concat.get());
 AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
 AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gage_concat);
 AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), lstm_gage_concat);
@@ -279,15 +283,15 @@ AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
 MS_EXCEPTION_IF_NULL(func_graph);
 MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
 // Create node
-auto origin_input6 = dynamic_rnn_grad_cnode->input(7);
+auto origin_input6 = dynamic_rnn_grad_cnode->input(kIndex7);
 std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
 origin_input6};
 auto split_v = func_graph->NewCNode(splitv_input);
 // Set infer data type and shape
 auto dtypes = {AnfAlgo::GetOutputInferDataType(origin_input6, 0), AnfAlgo::GetOutputInferDataType(origin_input6, 0)};
 auto origin_input6_shape = AnfAlgo::GetOutputInferShape(origin_input6, 0);
-std::vector<size_t> shape1 = {origin_input6_shape[0] - 1, origin_input6_shape[1], origin_input6_shape[2]};
-std::vector<size_t> shape2 = {1, origin_input6_shape[1], origin_input6_shape[2]};
+std::vector<size_t> shape1 = {origin_input6_shape[kDim0] - 1, origin_input6_shape[kDim1], origin_input6_shape[kDim2]};
+std::vector<size_t> shape2 = {1, origin_input6_shape[kDim1], origin_input6_shape[kDim2]};
 std::vector<std::vector<size_t>> shapes = {shape1, shape2};
 AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get());
 // Set attr
@@ -311,11 +315,12 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
 MS_LOG(EXCEPTION) << "Create outputs of node " << splitv->DebugString() << " failed"
 << " trace: " << trace::DumpSourceLines(dynamic_rnn_grad_cnode);
 }
-auto origin_input4 = dynamic_rnn_grad_cnode->input(5);
+auto origin_input4 = dynamic_rnn_grad_cnode->input(kIndex5);
 auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
 // Create reshape to change shape
 std::vector<size_t> shape_tmp;
-if (origin_input4_shape.size() == 3) {
+constexpr size_t kInput4DimNum = 3;
+if (origin_input4_shape.size() == kInput4DimNum) {
 shape_tmp = origin_input4_shape;
 } else {
 shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};
@@ -351,8 +356,8 @@ AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
 // Set infer data type and shape
 auto origin_output0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
 auto h_concat_output_shape = AnfAlgo::GetOutputInferShape(h_concat, 0);
-std::vector<size_t> shape = {origin_output0_shape[0], origin_output0_shape[1],
-origin_output0_shape[2] + h_concat_output_shape[2]};
+std::vector<size_t> shape = {origin_output0_shape[kDim0], origin_output0_shape[kDim1],
+origin_output0_shape[kDim2] + h_concat_output_shape[kDim2]};
 AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
 // Set attr
 AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat);
|
||||
|
@ -366,8 +371,8 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
|
||||
// Create node
|
||||
auto origin_input0 = dynamic_rnn_grad_cnode->input(1);
|
||||
auto origin_input4 = dynamic_rnn_grad_cnode->input(5);
|
||||
auto origin_input0 = dynamic_rnn_grad_cnode->input(kIndex1);
|
||||
auto origin_input4 = dynamic_rnn_grad_cnode->input(kIndex5);
|
||||
auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
|
||||
// Create reshape to change shape
|
||||
std::vector<size_t> shape_tmp;
|
||||
|
@ -386,7 +391,8 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
|
|||
auto concat = func_graph->NewCNode(concat_inputs);
|
||||
// Set infer data type and shape
|
||||
auto origin_input0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
|
||||
std::vector<size_t> shape = {origin_input0_shape[0], origin_input0_shape[1], origin_input0_shape[2] + shape_tmp[2]};
|
||||
std::vector<size_t> shape = {origin_input0_shape[kDim0], origin_input0_shape[kDim1],
|
||||
origin_input0_shape[kDim2] + shape_tmp[kDim2]};
|
||||
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
|
||||
// Set attr
|
||||
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat);
|
||||
|
@ -406,7 +412,7 @@ AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &l
|
|||
// Set infer data type and shape
|
||||
auto concat_shape = AnfAlgo::GetOutputInferShape(concat, 0);
|
||||
auto lstm_input_grad_shape = AnfAlgo::GetOutputInferShape(lstm_input_grad, 0);
|
||||
std::vector<size_t> shape = {concat_shape[0], concat_shape[2], lstm_input_grad_shape[2]};
|
||||
std::vector<size_t> shape = {concat_shape[kDim0], concat_shape[kDim2], lstm_input_grad_shape[kDim2]};
|
||||
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {shape}, batch_matmul.get());
|
||||
// Set attr
|
||||
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
|
||||
|
@ -451,7 +457,7 @@ AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
|
|||
}
|
||||
|
||||
AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
|
||||
auto origin_input7 = dynamic_rnn_grad_cnode->input(8);
|
||||
auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
|
||||
auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
|
||||
auto t_size = origin_input7_shape[0];
|
||||
auto n_size = origin_input7_shape[1];
|
||||
|
@ -477,7 +483,7 @@ AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
|
|||
batch_matmul};
|
||||
auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
|
||||
// Set infer data type and shape
|
||||
auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[2]};
|
||||
auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kDim2]};
|
||||
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, reduce_sum.get());
|
||||
// Set attr
|
||||
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
|
||||
|
@ -506,7 +512,7 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
|
|||
std::vector<AnfNodePtr> new_outputs;
|
||||
auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs);
|
||||
|
||||
size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(7), 0)[0];
|
||||
size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[0];
|
||||
AnfNodePtr concat = nullptr;
|
||||
if (t_size != 1) {
|
||||
auto splitv = CreateSplitV(func_graph, dynamic_rnn_grad_cnode);
|
||||
|
|
|
@ -53,6 +53,7 @@ CNodePtr CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node, const
|
|||
shape[shape.size() - 1] = SizeToLong(pad_dim_size);
|
||||
auto type_id = AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0);
|
||||
auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), shape);
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
if (param_dyn_shape->max_shape().size() == param_dyn_shape->shape().size() &&
|
||||
param_dyn_shape->min_shape().size() == param_dyn_shape->shape().size()) {
|
||||
ShapeVector max_shape(param_dyn_shape->max_shape());
|
||||
|
@ -134,8 +135,7 @@ bool CheckInputs(const CNodePtr &origin_node) {
|
|||
auto param_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
|
||||
auto indice_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1);
|
||||
// this optimizer only support embedding_table has dynamic shape
|
||||
constexpr size_t DIM2 = 2;
|
||||
if (param_shape.empty() || indice_shape.empty() || AnfAlgo::IsDynamicShape(origin_node->input(DIM2))) {
|
||||
if (param_shape.empty() || indice_shape.empty() || AnfAlgo::IsDynamicShape(origin_node->input(kDim2))) {
|
||||
return false;
|
||||
}
|
||||
if (param_shape[param_shape.size() - 1] != 1) {
|
||||
|
|
|
@ -28,6 +28,7 @@ constexpr size_t kLinSpaceInputNum = 3;
|
|||
constexpr size_t kFloat32Len = 4;
|
||||
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
|
||||
// 1 get tensor value of input num
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input_num = cnode->input(kLinSpaceInputNum);
|
||||
|
|
|
@ -26,12 +26,10 @@ namespace opt {
|
|||
constexpr size_t kInputNum = 3;
|
||||
constexpr size_t kFloat16Len = 2; // size of float16;
|
||||
constexpr size_t kKernelSizeNum = 5;
|
||||
constexpr size_t DIM2 = 2;
|
||||
constexpr size_t DIM3 = 3;
|
||||
constexpr size_t DIM4 = 4;
|
||||
namespace {
|
||||
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
|
||||
// 1 get attr ksize
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, "kernel_size");
|
||||
|
@ -42,9 +40,9 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
|
|||
if (ksize.size() != kKernelSizeNum) {
|
||||
MS_LOG(EXCEPTION) << "kernel_size of MaxPool3DGradGrad must be five, but got :" << ksize;
|
||||
}
|
||||
int64_t d = ksize[DIM2];
|
||||
int64_t h = ksize[DIM3];
|
||||
int64_t w = ksize[DIM4];
|
||||
int64_t d = ksize[kDim2];
|
||||
int64_t h = ksize[kDim3];
|
||||
int64_t w = ksize[kDim4];
|
||||
|
||||
// 1 create tensor
|
||||
std::vector<int64_t> assist_shape = {1, 1, d, h, w}; // shape:NCDHW
|
||||
|
|
|
@ -27,6 +27,7 @@ CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, con
|
|||
MS_EXCEPTION_IF_NULL(old_node);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMin->name())), input};
|
||||
CNodePtr reduce_min = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(reduce_min);
|
||||
reduce_min->set_scope(old_node->scope());
|
||||
AnfAlgo::CopyNodeAttr(kAttrKeepDims, old_node, reduce_min);
|
||||
return reduce_min;
|
||||
|
|
|
@ -40,8 +40,7 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
|
|||
MS_EXCEPTION_IF_NULL(bn_training_reduce);
|
||||
|
||||
// set abstract
|
||||
constexpr size_t DIM2 = 2;
|
||||
auto bn_input1 = bn_cnode->input(DIM2);
|
||||
auto bn_input1 = bn_cnode->input(kDim2);
|
||||
MS_EXCEPTION_IF_NULL(bn_input1);
|
||||
AbstractBasePtrList abstract_list{bn_input1->abstract(), bn_input1->abstract()};
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
|
@ -67,11 +66,11 @@ AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNod
|
|||
}
|
||||
std::vector<AnfNodePtr> bn_training_update_v3_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateV3OpName)),
|
||||
bn_cnode->input(1),
|
||||
bn_training_reduce_outputs[0],
|
||||
bn_training_reduce_outputs[1],
|
||||
bn_cnode->input(2),
|
||||
bn_cnode->input(3)};
|
||||
bn_cnode->input(kIndex1),
|
||||
bn_training_reduce_outputs[kIndex0],
|
||||
bn_training_reduce_outputs[kIndex1],
|
||||
bn_cnode->input(kIndex2),
|
||||
bn_cnode->input(kIndex3)};
|
||||
auto bn_training_update_v3 = func_graph->NewCNode(bn_training_update_v3_inputs);
|
||||
MS_EXCEPTION_IF_NULL(bn_training_update_v3);
|
||||
|
||||
|
|
|
@ -26,20 +26,17 @@ namespace opt {
|
|||
namespace {
|
||||
constexpr size_t kFloat16Len = 2;
|
||||
constexpr size_t kSpaceToDepthInputNum = 1;
|
||||
constexpr size_t kInputIndex1 = 1;
|
||||
constexpr size_t DIM1 = 1;
|
||||
constexpr size_t DIM2 = 2;
|
||||
constexpr size_t DIM3 = 3;
|
||||
|
||||
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
|
||||
// 1 create tensor
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input_x = cnode->input(kSpaceToDepthInputNum);
|
||||
int64_t block_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, "block_size");
|
||||
std::vector<size_t> x_shape = AnfAlgo::GetOutputInferShape(input_x, 0);
|
||||
int64_t input_channel = SizeToLong(x_shape[DIM1]);
|
||||
int64_t assist_input_channel = SizeToLong(x_shape[DIM1]) * block_size * block_size;
|
||||
int64_t input_channel = SizeToLong(x_shape[kDim1]);
|
||||
int64_t assist_input_channel = SizeToLong(x_shape[kDim1]) * block_size * block_size;
|
||||
std::vector<int64_t> assist_input_shape = {assist_input_channel, input_channel, block_size, block_size};
|
||||
int64_t dest_size = assist_input_channel * input_channel * block_size * block_size;
|
||||
MS_LOG(DEBUG) << "For SpaceToDepth op, assist input shape is: (" << assist_input_channel << ", " << input_channel
|
||||
|
@ -50,7 +47,7 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
|
|||
assist_tensor->set_device_info(device_info);
|
||||
|
||||
// 2 set value of tensor
|
||||
int64_t window_size = assist_input_shape[DIM2] * assist_input_shape[DIM3];
|
||||
int64_t window_size = assist_input_shape[kDim2] * assist_input_shape[kDim3];
|
||||
int64_t channel_size = input_channel;
|
||||
auto data_ptr = assist_tensor->data_c();
|
||||
MS_EXCEPTION_IF_NULL(data_ptr);
|
||||
|
@ -108,7 +105,7 @@ const AnfNodePtr SpaceToDepthSplit::Process(const FuncGraphPtr &graph, const Anf
|
|||
return nullptr;
|
||||
}
|
||||
const auto &ori_inputs = cnode->inputs();
|
||||
TypeId x_dtype = AnfAlgo::GetOutputInferDataType(ori_inputs[kInputIndex1], 0);
|
||||
TypeId x_dtype = AnfAlgo::GetOutputInferDataType(ori_inputs[kIndex1], 0);
|
||||
if (x_dtype != kNumberTypeFloat16) {
|
||||
MS_LOG(INFO) << "Node " << cnode->DebugString() << ": The data type of node's first input is: " << x_dtype
|
||||
<< ", not fp16, cannot do fusion.";
|
||||
|
|
|
@ -62,7 +62,7 @@ CNodePtr TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input_node = node->cast<CNodePtr>()->input(1);
|
||||
auto input_node = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
|
||||
auto input_format = AnfAlgo::GetInputFormat(node, 0);
|
||||
|
|
|
@ -27,7 +27,7 @@ CNodePtr CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node, c
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(origin_node);
|
||||
std::vector<AnfNodePtr> padding_inputs = {NewValueNode(std::make_shared<Primitive>(kPaddingOpName)),
|
||||
origin_node->input(1)};
|
||||
origin_node->input(kIndex1)};
|
||||
auto padding = graph->NewCNode(padding_inputs);
|
||||
MS_EXCEPTION_IF_NULL(padding);
|
||||
padding->set_scope(origin_node->scope());
|
||||
|
@ -45,7 +45,8 @@ CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &ori
|
|||
MS_EXCEPTION_IF_NULL(origin_node);
|
||||
MS_EXCEPTION_IF_NULL(padding);
|
||||
std::vector<AnfNodePtr> unsorted_segment_sum8_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSum->name())), padding, origin_node->input(2)};
|
||||
NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSum->name())), padding,
|
||||
origin_node->input(kIndex2)};
|
||||
auto unsorted_segment_sum = graph->NewCNode(unsorted_segment_sum8_inputs);
|
||||
MS_EXCEPTION_IF_NULL(unsorted_segment_sum);
|
||||
unsorted_segment_sum->set_scope(origin_node->scope());
|
||||
|
|
|
@ -35,10 +35,6 @@ constexpr size_t kAvgPoolGradInputNum = 3;
|
|||
constexpr size_t kShapeDimNum = 4;
|
||||
constexpr float kKernelMatrixInitNum = 1.0;
|
||||
constexpr size_t kFloat32Len = 4; // size of float32
|
||||
constexpr size_t DIM0 = 0;
|
||||
constexpr size_t DIM1 = 1;
|
||||
constexpr size_t DIM2 = 2;
|
||||
constexpr size_t DIM3 = 3;
|
||||
std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
std::vector<int64_t> shapes;
|
||||
|
@ -49,6 +45,8 @@ std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
|
|||
|
||||
int64_t windowed_output_size(int64_t input_size, int64_t ksize, int64_t stride, PadMode pad_mode, int64_t *pad_before,
|
||||
int64_t *pad_after) {
|
||||
MS_EXCEPTION_IF_NULL(pad_before);
|
||||
MS_EXCEPTION_IF_NULL(pad_after);
|
||||
int64_t output = 0;
|
||||
*pad_before = 0;
|
||||
*pad_after = 0;
|
||||
|
@ -77,17 +75,20 @@ std::vector<std::vector<float>> GetAssistInputMatrix(const std::vector<int64_t>
|
|||
// w*w_stride : w*w_stride + w_ksize]` of `assist_input_matrix`, so the sum of slide window is the
|
||||
// number of input that associate with output element.
|
||||
std::vector<std::vector<float>> assist_input_matrix;
|
||||
std::vector<int64_t> in_shape_after_padding_2d = {x_shape[DIM2] + pad_top + pad_bottom,
|
||||
x_shape[DIM3] + pad_left + pad_right};
|
||||
std::vector<float> tmp_zero_vector(in_shape_after_padding_2d[1], 0.0);
|
||||
std::vector<float> tmp_one_vector(in_shape_after_padding_2d[1], 1.0);
|
||||
for (int64_t i = 0; i < in_shape_after_padding_2d[1]; ++i) {
|
||||
if (i < pad_left || i >= (in_shape_after_padding_2d[1] - pad_right)) {
|
||||
if (x_shape.size() < kShapeDimNum) {
|
||||
MS_LOG(EXCEPTION) << "The dim of x_shape should not be less than 4.";
|
||||
}
|
||||
std::vector<int64_t> in_shape_after_padding_2d = {x_shape[kDim2] + pad_top + pad_bottom,
|
||||
x_shape[kDim3] + pad_left + pad_right};
|
||||
std::vector<float> tmp_zero_vector(in_shape_after_padding_2d[kDim1], 0.0);
|
||||
std::vector<float> tmp_one_vector(in_shape_after_padding_2d[kDim1], 1.0);
|
||||
for (int64_t i = 0; i < in_shape_after_padding_2d[kDim1]; ++i) {
|
||||
if (i < pad_left || i >= (in_shape_after_padding_2d[kDim1] - pad_right)) {
|
||||
tmp_one_vector[LongToSize(i)] = 0.0;
|
||||
}
|
||||
}
|
||||
for (int64_t i = 0; i < in_shape_after_padding_2d[0]; ++i) {
|
||||
if (i < pad_top || i >= (in_shape_after_padding_2d[0] - pad_bottom)) {
|
||||
for (int64_t i = 0; i < in_shape_after_padding_2d[kDim0]; ++i) {
|
||||
if (i < pad_top || i >= (in_shape_after_padding_2d[kDim0] - pad_bottom)) {
|
||||
assist_input_matrix.emplace_back(tmp_zero_vector);
|
||||
} else {
|
||||
assist_input_matrix.emplace_back(tmp_one_vector);
|
||||
|
@ -106,8 +107,10 @@ ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std
|
|||
MS_LOG(EXCEPTION) << "The dim of x_shape or kernel_size or strides of AvgPoolGrad should be 4.";
|
||||
}
|
||||
int64_t pad_top, pad_bottom, pad_left, pad_right;
|
||||
int64_t h_output = windowed_output_size(x_shape[DIM2], k_size[DIM2], stride[DIM2], pad_mode, &pad_top, &pad_bottom);
|
||||
int64_t w_output = windowed_output_size(x_shape[DIM3], k_size[DIM3], stride[DIM3], pad_mode, &pad_left, &pad_right);
|
||||
int64_t h_output =
|
||||
windowed_output_size(x_shape[kDim2], k_size[kDim2], stride[kDim2], pad_mode, &pad_top, &pad_bottom);
|
||||
int64_t w_output =
|
||||
windowed_output_size(x_shape[kDim3], k_size[kDim3], stride[kDim3], pad_mode, &pad_left, &pad_right);
|
||||
auto assist_input_matrix = GetAssistInputMatrix(x_shape, pad_top, pad_bottom, pad_left, pad_right);
|
||||
|
||||
// calculate output
|
||||
|
@ -115,8 +118,8 @@ ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std
|
|||
for (int64_t h = 0; h < h_output; ++h) {
|
||||
for (int64_t w = 0; w < w_output; ++w) {
|
||||
float curr_sum = 0;
|
||||
for (int64_t i = h * stride[DIM2]; i < h * stride[DIM2] + k_size[DIM2]; ++i) {
|
||||
for (int64_t j = w * stride[DIM3]; j < w * stride[DIM3] + k_size[DIM3]; ++j) {
|
||||
for (int64_t i = h * stride[kDim2]; i < h * stride[kDim2] + k_size[kDim2]; ++i) {
|
||||
for (int64_t j = w * stride[kDim3]; j < w * stride[kDim3] + k_size[kDim3]; ++j) {
|
||||
curr_sum += assist_input_matrix[LongToSize(i)][LongToSize(j)];
|
||||
}
|
||||
}
|
||||
|
@ -127,12 +130,12 @@ ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std
|
|||
}
|
||||
|
||||
// make output tensor
|
||||
std::vector<int64_t> output_shape = {x_shape[0], x_shape[1], h_output, w_output};
|
||||
std::vector<int64_t> output_shape = {x_shape[kDim0], x_shape[kDim1], h_output, w_output};
|
||||
auto output_size = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());
|
||||
std::vector<float> output(output_size, 0.0);
|
||||
for (int64_t i = 0; i < output_shape[0] * output_shape[1]; ++i) {
|
||||
for (int64_t i = 0; i < output_shape[kDim0] * output_shape[kDim1]; ++i) {
|
||||
size_t src_size = hw_output.size() * kFloat32Len;
|
||||
auto dst_size = LongToSize(output_shape[DIM2]) * LongToSize(output_shape[DIM3]) * kFloat32Len;
|
||||
auto dst_size = LongToSize(output_shape[kDim2]) * LongToSize(output_shape[kDim3]) * kFloat32Len;
|
||||
auto ret = memcpy_s(&output[LongToSize(i) * hw_output.size()], dst_size, &hw_output[0], src_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
|
@ -157,7 +160,7 @@ ValueNodePtr CreateKernelMatrixValueNode(const FuncGraphPtr &func_graph, const s
|
|||
if (x_shape.size() != kShapeDimNum || k_size.size() != kShapeDimNum) {
|
||||
MS_LOG(EXCEPTION) << "The dim of x_shape or kernel_size of AvgPoolGrad should be 4.";
|
||||
}
|
||||
std::vector<int64_t> kernel_shape = {1, x_shape[DIM1], k_size[DIM2], k_size[DIM3]};
|
||||
std::vector<int64_t> kernel_shape = {1, x_shape[kDim1], k_size[kDim2], k_size[kDim3]};
|
||||
auto data_size = std::accumulate(kernel_shape.begin(), kernel_shape.end(), int64_t(1), std::multiplies<int64_t>());
|
||||
std::vector<float> data(data_size, kKernelMatrixInitNum);
|
||||
auto kernel_matrix_tensor = std::make_shared<tensor::Tensor>(x_dtype, kernel_shape, &data[0], kNumberTypeFloat32);
|
||||
|
|
|
@ -35,17 +35,12 @@ AnfNodePtr CreateNewBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &bn_
|
|||
size_t kBNGradInputNum = 6;
|
||||
const auto &bn_grad_node_inputs = bn_grad_node->inputs();
|
||||
CheckCNodeInputSize(bn_grad_node, kBNGradInputNum);
|
||||
constexpr size_t DIM1 = 1;
|
||||
constexpr size_t DIM2 = 2;
|
||||
constexpr size_t DIM3 = 3;
|
||||
constexpr size_t DIM4 = 4;
|
||||
constexpr size_t DIM5 = 5;
|
||||
std::vector<AnfNodePtr> bn_grad_inputs = {NewValueNode(std::make_shared<Primitive>(kBatchNormGradOpName)),
|
||||
bn_grad_node_inputs[DIM1],
|
||||
bn_grad_node_inputs[DIM2],
|
||||
bn_grad_node_inputs[DIM3],
|
||||
bn_grad_node_inputs[DIM4],
|
||||
bn_grad_node_inputs[DIM5]};
|
||||
bn_grad_node_inputs[kDim1],
|
||||
bn_grad_node_inputs[kDim2],
|
||||
bn_grad_node_inputs[kDim3],
|
||||
bn_grad_node_inputs[kDim4],
|
||||
bn_grad_node_inputs[kDim5]};
|
||||
auto new_bn_grad = graph->NewCNode(bn_grad_inputs);
|
||||
MS_EXCEPTION_IF_NULL(new_bn_grad);
|
||||
new_bn_grad->set_scope(bn_grad_node->scope());
|
||||
|
|
|
@ -40,8 +40,8 @@ constexpr auto kAttrChannelMultiplier = "channel_multiplier";
|
|||
constexpr auto kAttrPerm = "perm";
|
||||
constexpr auto kAttrInputSizes = "input_sizes";
|
||||
constexpr auto kAttrInputSize = "input_size";
|
||||
constexpr auto kInput2 = 2;
|
||||
constexpr auto kInput3 = 3;
|
||||
constexpr auto kIndex2 = 2;
|
||||
constexpr auto kIndex3 = 3;
|
||||
|
||||
bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vector<size_t> out_shape) {
|
||||
MS_EXCEPTION_IF_NULL(conv2d);
|
||||
|
@ -63,8 +63,8 @@ bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vecto
|
|||
MS_LOG(EXCEPTION) << "Conv2D's input and output should have 4 axis, but got input axis num: " << in_shape.size()
|
||||
<< "output axis num: " << out_shape.size();
|
||||
}
|
||||
auto in_channel = in_shape[1];
|
||||
auto out_channel = out_shape[1];
|
||||
auto in_channel = in_shape[kDim1];
|
||||
auto out_channel = out_shape[kDim1];
|
||||
if (group != in_channel || group != out_channel) {
|
||||
return false;
|
||||
}
|
||||
|
@ -117,7 +117,7 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
|
|||
MS_LOG(EXCEPTION) << "Conv2D's output axis number should be " << kConv2DAxisNum << ", but got "
|
||||
<< out_shape.size();
|
||||
}
|
||||
std::swap(out_shape[0], out_shape[1]);
|
||||
std::swap(out_shape[kDim0], out_shape[kDim1]);
|
||||
auto shapes = {out_shape};
|
||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, transpose.get());
|
||||
} else {
|
||||
|
@ -139,7 +139,7 @@ CNodePtr CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d
|
|||
MS_EXCEPTION_IF_NULL(conv2d);
|
||||
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
|
||||
std::vector<AnfNodePtr> depth_conv_inputs = {NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeOpName)),
|
||||
conv2d->input(1), transpose};
|
||||
conv2d->input(kIndex1), transpose};
|
||||
auto depth_conv = graph->NewCNode(depth_conv_inputs);
|
||||
MS_EXCEPTION_IF_NULL(depth_conv);
|
||||
depth_conv->set_abstract(conv2d->abstract());
|
||||
|
@ -155,15 +155,15 @@ CNodePtr CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph, const CNo
|
|||
CNodePtr depth_conv_backin = nullptr;
|
||||
if (conv2d_backin->inputs().size() == kConv2DBackpropInputNum) {
|
||||
std::vector<AnfNodePtr> depth_conv_backin_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), conv2d_backin->input(3),
|
||||
transpose, conv2d_backin->input(1)};
|
||||
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)),
|
||||
conv2d_backin->input(kIndex3), transpose, conv2d_backin->input(kIndex1)};
|
||||
depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
|
||||
} else {
|
||||
// In nn.Conv2DTranspose, Conv2DBackpropInput is a forward op and the input_sizes input will be convert to attr
|
||||
// in pynative mode.
|
||||
std::vector<AnfNodePtr> depth_conv_backin_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), transpose,
|
||||
conv2d_backin->input(1)};
|
||||
conv2d_backin->input(kIndex1)};
|
||||
depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
|
||||
AnfAlgo::CopyNodeAttr(kAttrInputSizes, kAttrInputSize, conv2d_backin, depth_conv_backin);
|
||||
}
|
||||
|
@ -180,7 +180,7 @@ CNodePtr CreateDepthwiseConv2DBackpropFilter(const FuncGraphPtr &graph, const CN
|
|||
MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's input number should be " << (kConv2DBackpropInputNum - 1)
|
||||
<< ", but got " << (conv2d_backfil->inputs().size() - 1);
|
||||
}
|
||||
auto filter_size_node = conv2d_backfil->input(kInput3);
|
||||
auto filter_size_node = conv2d_backfil->input(kIndex3);
|
||||
MS_EXCEPTION_IF_NULL(filter_size_node);
|
||||
auto filter_size_vnode = filter_size_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(filter_size_vnode);
|
||||
|
@ -189,11 +189,11 @@ CNodePtr CreateDepthwiseConv2DBackpropFilter(const FuncGraphPtr &graph, const CN
|
|||
// when the filter_size value is same.
|
||||
if (filter_size[0] != 1) {
|
||||
std::swap(filter_size[0], filter_size[1]);
|
||||
conv2d_backfil->input(kInput3)->cast<ValueNodePtr>()->set_value(MakeValue(filter_size));
|
||||
conv2d_backfil->input(kIndex3)->cast<ValueNodePtr>()->set_value(MakeValue(filter_size));
|
||||
}
|
||||
std::vector<AnfNodePtr> depth_conv_backfil_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropFilterOpName)), conv2d_backfil->input(2),
|
||||
conv2d_backfil->input(kInput3), conv2d_backfil->input(1)};
|
||||
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropFilterOpName)),
|
||||
conv2d_backfil->input(kIndex2), conv2d_backfil->input(kIndex3), conv2d_backfil->input(kIndex1)};
|
||||
auto depth_conv_backfil = graph->NewCNode(depth_conv_backfil_inputs);
|
||||
MS_EXCEPTION_IF_NULL(depth_conv_backfil);
|
||||
depth_conv_backfil->set_scope(conv2d_backfil->scope());
|
||||
|
@ -276,7 +276,7 @@ const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const Anf
|
|||
return nullptr;
|
||||
}
|
||||
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
|
||||
auto transpose = CreateTranspose(graph, conv2d, conv2d->input(kInput2), true);
|
||||
auto transpose = CreateTranspose(graph, conv2d, conv2d->input(kIndex2), true);
|
||||
auto depth_conv = CreateDepthwiseConv2D(graph, conv2d, transpose);
|
||||
SetConv2DAttrs(conv2d, depth_conv);
|
||||
return depth_conv;
|
||||
|
@ -307,7 +307,7 @@ const AnfNodePtr Conv2DBackpropInputUnifyMindIR::Process(const FuncGraphPtr &gra
|
|||
MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << (kConv2DBackpropInputNum - 1) << " or "
|
||||
<< (kConv2DBackpropInputNum - 2) << ", but got " << (input_size - 1);
|
||||
}
|
||||
auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(kInput2), true);
|
||||
auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(kIndex2), true);
|
||||
auto depth_conv_backin = CreateDepthwiseConv2DBackpropInput(graph, conv2d_backin, transpose);
|
||||
SetConv2DBackpropInputAttrs(conv2d_backin, depth_conv_backin);
|
||||
return depth_conv_backin;
|
||||
|
|
|
@ -40,7 +40,6 @@ constexpr int64_t kMaskAlignNum = 128;
|
|||
constexpr int64_t kMaskMultiNum = 16;
|
||||
constexpr size_t kFloat16Len = 2; // size of float16
|
||||
constexpr size_t kInt64Len = 8; // size of int64
|
||||
constexpr auto kInput2 = 2; // size of int64
|
||||
|
||||
TypeId GetInputXDataType(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -116,8 +115,8 @@ std::vector<int64_t> CalDropoutGenMaskOutput(const std::vector<int64_t> &shape)
|
|||
|
||||
bool NeedUpdate(const CNodePtr &getitem_cnode) {
|
||||
MS_EXCEPTION_IF_NULL(getitem_cnode);
|
||||
MS_EXCEPTION_IF_NULL(getitem_cnode->input(kInput2));
|
||||
auto index_vnode = getitem_cnode->input(kInput2)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(getitem_cnode->input(kIndex2));
|
||||
auto index_vnode = getitem_cnode->input(kIndex2)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(index_vnode);
|
||||
auto index_value = index_vnode->value();
|
||||
MS_EXCEPTION_IF_NULL(index_value);
|
||||
|
@ -127,6 +126,7 @@ bool NeedUpdate(const CNodePtr &getitem_cnode) {
|
|||
|
||||
CNodePtr CreateDynamicShapeCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_input,
|
||||
const abstract::ShapePtr &input_shape) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> dynamic_shape_inputs{NewValueNode(std::make_shared<Primitive>("DynamicShape")), node_input};
|
||||
CNodePtr dynamic_shape = func_graph->NewCNode(dynamic_shape_inputs);
|
||||
MS_EXCEPTION_IF_NULL(dynamic_shape);
|
||||
|
@ -143,6 +143,8 @@ CNodePtr CreateDynamicShapeCNode(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &dropout,
|
||||
const ValueNodePtr &keep_prob_value, const AnfNodePtr &dropout_input,
|
||||
const abstract::ShapePtr &input_shape) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(dropout);
|
||||
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName))};
|
||||
if (input_shape->IsDynamic()) {
|
||||
CNodePtr dynamic_shape = CreateDynamicShapeCNode(func_graph, dropout_input, input_shape);
|
||||
|
@ -163,6 +165,7 @@ CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const AnfNode
|
|||
ShapeVector mask_min_shp = CalDropoutGenMaskOutput(input_shape->min_shape());
|
||||
ShapeVector mask_max_shp = CalDropoutGenMaskOutput(input_shape->max_shape());
|
||||
auto gen_mask_shape = std::make_shared<abstract::Shape>(mask_shp, mask_min_shp, mask_max_shp);
|
||||
MS_EXCEPTION_IF_NULL(gen_mask_shape);
|
||||
gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
|
||||
} else {
|
||||
auto gen_mask_shape = CalDropoutGenMaskOutput(input_shape->shape());
|
||||
|
@ -204,11 +207,11 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto dropout_grad_cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(dropout_grad_cnode);
|
||||
auto getitem1_node = dropout_grad_cnode->input(kInput2);
|
||||
auto getitem1_node = dropout_grad_cnode->input(kIndex2);
|
||||
MS_EXCEPTION_IF_NULL(getitem1_node);
|
||||
auto getitem1_cnode = getitem1_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(getitem1_cnode);
|
||||
auto dropout_node = getitem1_cnode->input(1);
|
||||
auto dropout_node = getitem1_cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(dropout_node);
|
||||
auto dropout_cnode = dropout_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(dropout_cnode);
|
||||
|
@ -216,7 +219,7 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f
|
|||
auto inputx_type_id = GetInputXDataType(dropout_node);
|
||||
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);
|
||||
|
||||
auto dropout_input = dropout_cnode->input(1);
|
||||
auto dropout_input = dropout_cnode->input(kIndex1);
|
||||
auto input_shape = GetDropoutInputShape(dropout_input);
|
||||
// CreateDropoutGenMask
|
||||
auto dropout_gen_mask =
|
||||
|
@ -284,14 +287,14 @@ const AnfNodePtr DropoutUnifyMindIR0::Process(const FuncGraphPtr &func_graph, co
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
auto dropout_node = tuple_cnode->input(1);
|
||||
auto dropout_node = tuple_cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(dropout_node);
|
||||
auto inputx_type_id = GetInputXDataType(dropout_node);
|
||||
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);
|
||||
|
||||
auto dropout_cnode = dropout_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(dropout_cnode);
|
||||
auto dropout_input = dropout_cnode->input(1);
|
||||
auto dropout_input = dropout_cnode->input(kIndex1);
|
||||
auto input_shape = GetDropoutInputShape(dropout_input);
|
||||
|
||||
// CreateDropoutGenMask
|
||||
|
@ -332,7 +335,7 @@ const AnfNodePtr DropoutUnifyMindIR1::Process(const FuncGraphPtr &func_graph, co
|
|||
auto inputx_type_id = GetInputXDataType(dropout_node);
|
||||
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);
|
||||
|
||||
auto dropout_input = dropout_node->input(1);
|
||||
auto dropout_input = dropout_node->input(kIndex1);
|
||||
auto input_shape = GetDropoutInputShape(dropout_input);
|
||||
// CreateDropoutGenMask
|
||||
auto dropout_gen_mask =
|
||||
|
@ -371,7 +374,7 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
|
|||
|
||||
// DropoutGrad may not in the same graph with Dropout in heterogeneous scene, and mask input which is a parameter
|
||||
// in that scene, need to be updated.
|
||||
auto mask_input = dropout_grad_cnode->input(kInput2);
|
||||
auto mask_input = dropout_grad_cnode->input(kIndex2);
|
||||
if (mask_input->isa<Parameter>()) {
|
||||
// update abstract
|
||||
auto mask_abstract = mask_input->abstract();
|
||||
|
@ -387,7 +390,7 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
|
|||
}
|
||||
|
||||
// CreateDropoutDoMask
|
||||
auto grad_input = dropout_grad_cnode->input(1);
|
||||
auto grad_input = dropout_grad_cnode->input(kIndex1);
|
||||
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
|
||||
grad_input, mask_input, keep_prob_value};
|
||||
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
|
||||
|
|
|
@ -38,8 +38,9 @@ void CreateOutputsOfLSQPerLayerGradD(const FuncGraphPtr &graph, const CNodePtr &
|
|||
<< " trace: " << trace::DumpSourceLines(lsq_perlayer_grad_node);
|
||||
}
|
||||
std::vector<AnfNodePtr> lsq_perlayer_grad_d_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerLayerGradDOpName)), lsq_perlayer_grad_inputs[1],
|
||||
lsq_perlayer_grad_inputs[2], lsq_perlayer_grad_inputs[3], lsq_perlayer_grad_inputs[4]};
|
||||
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerLayerGradDOpName)),
|
||||
lsq_perlayer_grad_inputs[kIndex1], lsq_perlayer_grad_inputs[kIndex2], lsq_perlayer_grad_inputs[kIndex3],
|
||||
lsq_perlayer_grad_inputs[kIndex4]};
|
||||
auto lsq_perlayer_grad_d = graph->NewCNode(lsq_perlayer_grad_d_inputs);
|
||||
MS_EXCEPTION_IF_NULL(lsq_perlayer_grad_d);
|
||||
lsq_perlayer_grad_d->set_scope(lsq_perlayer_grad_node->scope());
|
||||
|
@ -72,7 +73,7 @@ void CreateOutputsOfLSQPerLayerReduceGrad(const FuncGraphPtr &graph, const CNode
|
|||
}
|
||||
std::vector<AnfNodePtr> lsq_perlayer_reduce_grad_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerLayerGradDReduceOpName)),
|
||||
lsq_perlayer_grad_d_outputs[1]};
|
||||
lsq_perlayer_grad_d_outputs[kIndex1]};
|
||||
auto lsq_perlayer_reduce_grad = graph->NewCNode(lsq_perlayer_reduce_grad_inputs);
|
||||
MS_EXCEPTION_IF_NULL(lsq_perlayer_reduce_grad);
|
||||
lsq_perlayer_reduce_grad->set_scope(lsq_perlayer_grad_node->scope());
|
||||
|
@ -130,7 +131,7 @@ void CreateOutputsOfLSQPerChannelReduceGrad(const FuncGraphPtr &graph, const CNo
|
|||
}
|
||||
std::vector<AnfNodePtr> lsq_perchannel_reduce_grad_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kFakeLearnedScaleQuantPerChannelGradDReduceOpName)),
|
||||
lsq_perchannel_grad_d_outputs[1]};
|
||||
lsq_perchannel_grad_d_outputs[kIndex1]};
|
||||
auto lsq_perchannel_reduce_grad = graph->NewCNode(lsq_perchannel_reduce_grad_inputs);
|
||||
MS_EXCEPTION_IF_NULL(lsq_perchannel_reduce_grad);
|
||||
lsq_perchannel_reduce_grad->set_scope(lsq_perchannel_grad_node->scope());
|
||||
|
|
|
@ -32,9 +32,6 @@ constexpr size_t kMaxPoolInputNum = 2;
|
|||
constexpr size_t kMaxPoolAttrAxisNum = 4;
|
||||
constexpr size_t kMaxPoolGradInputNum = 4;
|
||||
constexpr size_t kMaxPoolWithArgmaxOutputNum = 2;
|
||||
constexpr size_t kIndex1 = 1;
|
||||
constexpr size_t kIndex2 = 2;
|
||||
constexpr size_t kIndex3 = 3;
|
||||
|
||||
CNodePtr GetMaxPool(const CNodePtr &maxpool_grad) {
|
||||
MS_EXCEPTION_IF_NULL(maxpool_grad);
|
||||
|
@ -55,7 +52,7 @@ CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxp
|
|||
<< (maxpool->inputs().size() - 1);
|
||||
}
|
||||
std::vector<AnfNodePtr> maxpool_argmax_inputs = {NewValueNode(std::make_shared<Primitive>(kMaxPoolWithArgmaxOpName)),
|
||||
maxpool->input(1)};
|
||||
maxpool->input(kIndex1)};
|
||||
auto maxpool_argmax = graph->NewCNode(maxpool_argmax_inputs);
|
||||
MS_EXCEPTION_IF_NULL(maxpool_argmax);
|
||||
maxpool_argmax->set_scope(maxpool->scope());
|
||||
|
@ -80,8 +77,8 @@ CNodePtr CreateMaxPoolGradWithArgmax(const FuncGraphPtr &graph, const CNodePtr &
|
|||
// MaxPoolGrad's inputs are {input, output, grad_input}, MaxPoolGradWithArgmax's inputs are
|
||||
// {input, grad_input, argmax_output}
|
||||
std::vector<AnfNodePtr> maxpool_grad_argmax_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kMaxPoolGradWithArgmaxOpName)), maxpool_grad->input(1),
|
||||
maxpool_grad->input(kIndex3), maxpool_argmax_outputs[1]};
|
||||
NewValueNode(std::make_shared<Primitive>(kMaxPoolGradWithArgmaxOpName)), maxpool_grad->input(kIndex1),
|
||||
maxpool_grad->input(kIndex3), maxpool_argmax_outputs[kIndex1]};
|
||||
auto maxpool_grad_argmax = graph->NewCNode(maxpool_grad_argmax_inputs);
|
||||
MS_EXCEPTION_IF_NULL(maxpool_grad_argmax);
|
||||
maxpool_grad_argmax->set_scope(maxpool_grad->scope());
|
||||
|
|
|
@ -29,10 +29,6 @@ namespace {
|
|||
constexpr size_t kMaxPoolGradWithArgmaxInputNum = 4;
|
||||
constexpr size_t kMaxPoolWithArgmaxShape = 4;
|
||||
constexpr size_t kAlignBytes = 16;
|
||||
constexpr size_t kIndex1 = 1;
|
||||
constexpr size_t kIndex2 = 2;
|
||||
constexpr size_t kIndex3 = 3;
|
||||
constexpr size_t kIndex4 = 4;
|
||||
|
||||
bool IsC(const BaseRef &n) {
|
||||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
|
@ -71,11 +67,11 @@ const AnfNodePtr MaxPoolWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph
|
|||
auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool_with_argmax, kAttrKernelSize);
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(maxpool_with_argmax, 0);
|
||||
auto argmax_shape = output_shape;
|
||||
if (argmax_shape.size() != kMaxPoolWithArgmaxShape) {
|
||||
MS_LOG(DEBUG) << "argmax's infer shape size not equal 4";
|
||||
if (argmax_shape.size() != kMaxPoolWithArgmaxShape || ksize.size() != kMaxPoolWithArgmaxShape) {
|
||||
MS_LOG(EXCEPTION) << "argmax or kernel_size's shape size not equal to 4";
|
||||
}
|
||||
argmax_shape[kIndex2] = LongToSize(ksize[kIndex1] * ksize[kIndex2]);
|
||||
argmax_shape[kIndex3] = (output_shape[kIndex2] * output_shape[kIndex3] + kAlignBytes - 1) / kAlignBytes + 1;
|
||||
argmax_shape[kDim2] = LongToSize(ksize[kDim1] * ksize[kDim2]);
|
||||
argmax_shape[kDim3] = (output_shape[kDim2] * output_shape[kDim3] + kAlignBytes - 1) / kAlignBytes + 1;
|
||||
auto types = {AnfAlgo::GetOutputInferDataType(maxpool_with_argmax, 0), argmax_dtype};
|
||||
auto shapes = {output_shape, argmax_shape};
|
||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, maxpool_with_argmax.get());
|
||||
|
@ -105,11 +101,11 @@ const AnfNodePtr MaxPoolGradWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &g
|
|||
TypeId argmax_dtype = kNumberTypeUInt16;
|
||||
auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool_grad_with_argmax, kAttrKernelSize);
|
||||
auto argmax_shape = AnfAlgo::GetOutputInferShape(tuple_getitem0_anf, 0);
|
||||
if (argmax_shape.size() != kMaxPoolWithArgmaxShape) {
|
||||
MS_LOG(DEBUG) << "argmax's infer shape size not equal 4";
|
||||
if (argmax_shape.size() != kMaxPoolWithArgmaxShape || ksize.size() != kMaxPoolWithArgmaxShape) {
|
||||
MS_LOG(EXCEPTION) << "argmax or kernel_size's shape size not equal to 4";
|
||||
}
|
||||
argmax_shape[kIndex3] = (argmax_shape[kIndex2] * argmax_shape[kIndex3] + kAlignBytes - 1) / kAlignBytes + 1;
|
||||
argmax_shape[kIndex2] = LongToSize(ksize[kIndex1] * ksize[kIndex2]);
|
||||
argmax_shape[kDim3] = (argmax_shape[kDim2] * argmax_shape[kDim3] + kAlignBytes - 1) / kAlignBytes + 1;
|
||||
argmax_shape[kDim2] = LongToSize(ksize[kDim1] * ksize[kDim2]);
|
||||
AnfAlgo::SetOutputInferTypeAndShape({argmax_dtype}, {argmax_shape}, tuple_getitem0_anf.get());
|
||||
|
||||
return maxpool_grad_with_argmax;
|
||||
|
|
|
@ -72,10 +72,8 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
|
|||
|
||||
// set attr paddings
|
||||
auto x_shape = GetInputXShape(slice_grad);
|
||||
constexpr auto kInput3 = 3;
|
||||
constexpr auto kInput4 = 4;
|
||||
auto begins = GetTupleValue(slice_grad->input(kInput3));
|
||||
auto sizes = GetTupleValue(slice_grad->input(kInput4));
|
||||
auto begins = GetTupleValue(slice_grad->input(kIndex3));
|
||||
auto sizes = GetTupleValue(slice_grad->input(kIndex4));
|
||||
if (x_shape.size() != begins.size() || begins.size() != sizes.size()) {
|
||||
MS_LOG(EXCEPTION) << "For SliceGrad, x's shape dim number should be equal to len(begin) and len(size).";
|
||||
}
|
||||
|
|
|
@ -33,7 +33,6 @@ constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto kInput2 = 2;
|
||||
ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) {
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
auto new_node = std::make_shared<ValueNode>(value_ptr);
|
||||
|
@ -72,6 +71,7 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
|
|||
auto value_off_node = CreateValueNode(value_off, kNumberTypeFloat32);
|
||||
MS_EXCEPTION_IF_NULL(value_off_node);
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
kernel_graph->AddValueNodeToGraph(value_on_node);
|
||||
kernel_graph->AddValueNodeToGraph(value_off_node);
|
||||
|
||||
|
@ -83,14 +83,16 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
|
|||
|
||||
std::vector<AnfNodePtr> one_hot_inputs;
|
||||
if (is_convert_const_to_attr) {
|
||||
one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), value_on_node, value_off_node};
|
||||
one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(kIndex2), value_on_node,
|
||||
value_off_node};
|
||||
} else {
|
||||
auto depth_node = NewValueNode(depth);
|
||||
MS_EXCEPTION_IF_NULL(depth_node);
|
||||
auto depth_abstract = std::make_shared<abstract::AbstractScalar>();
|
||||
MS_EXCEPTION_IF_NULL(depth_abstract);
|
||||
depth_abstract->set_type(kInt64);
|
||||
depth_node->set_abstract(depth_abstract);
|
||||
one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), depth_node, value_on_node,
|
||||
one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(kIndex2), depth_node, value_on_node,
|
||||
value_off_node};
|
||||
}
|
||||
auto one_hot_node = graph->NewCNode(one_hot_inputs);
|
||||
|
@ -112,7 +114,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN
|
|||
MS_EXCEPTION_IF_NULL(one_hot_node);
|
||||
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSoftmaxCrossEntropyWithLogitsOpName)),
|
||||
sparse_softmax_node->input(1), one_hot_node};
|
||||
sparse_softmax_node->input(kIndex1), one_hot_node};
|
||||
auto softmax_node = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(softmax_node);
|
||||
softmax_node->set_scope(sparse_softmax_node->scope());
|
||||
|
@ -171,6 +173,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
|
|||
reduce_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
|
||||
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
if (is_pynative) {
|
||||
inputs = {NewValueNode(reduce_primitive), softmax_output_node};
|
||||
|
@ -199,6 +202,7 @@ CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_no
|
|||
auto axis_node = NewValueNode(axis);
|
||||
MS_EXCEPTION_IF_NULL(axis_node);
|
||||
auto axis_abstract = std::make_shared<abstract::AbstractScalar>();
|
||||
MS_EXCEPTION_IF_NULL(axis_abstract);
|
||||
axis_abstract->set_type(kInt64);
|
||||
axis_node->set_abstract(axis_abstract);
|
||||
|
||||
|
@ -305,6 +309,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax
|
|||
auto y_node = CreateValueNode(y, kNumberTypeFloat32);
|
||||
MS_EXCEPTION_IF_NULL(y_node);
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
kernel_graph->AddValueNodeToGraph(y_node);
|
||||
|
||||
auto real_div_primitive = std::make_shared<Primitive>(kRealDivOpName);
|
||||
|
@ -332,7 +337,7 @@ CNodePtr GetSparseNode(const CNodePtr &depend_node, size_t index) {
|
|||
CNodePtr GetDependNode(const CNodePtr &mul_node) {
|
||||
MS_EXCEPTION_IF_NULL(mul_node);
|
||||
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
|
||||
auto depend_node = mul_node->input(1);
|
||||
auto depend_node = mul_node->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
return depend_node->cast<CNodePtr>();
|
||||
}
|
||||
|
@ -357,6 +362,7 @@ CNodePtr CreateMul(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_nod
|
|||
MS_EXCEPTION_IF_NULL(y_node);
|
||||
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
kernel_graph->AddValueNodeToGraph(y_node);
|
||||
|
||||
auto mul_primitive = std::make_shared<Primitive>(kMulOpName);
|
||||
|
@ -427,7 +433,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
|
|||
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
|
||||
|
||||
auto depend_node = GetDependNode(mul_node);
|
||||
auto sparse_softmax_node = GetSparseNode(depend_node, kInput2);
|
||||
auto sparse_softmax_node = GetSparseNode(depend_node, kIndex2);
|
||||
auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
|
||||
CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
|
||||
|
||||
|
@ -441,7 +447,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
|
|||
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
|
||||
CNodePtr real_div_node;
|
||||
if (tile_node == nullptr) {
|
||||
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kInput2));
|
||||
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kIndex2));
|
||||
} else {
|
||||
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
|
||||
}
|
||||
|
@ -481,7 +487,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c
|
|||
|
||||
auto depend_node = node->cast<CNodePtr>();
|
||||
auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
|
||||
auto sparse_softmax_node = GetSparseNode(depend_node, kInput2);
|
||||
auto sparse_softmax_node = GetSparseNode(depend_node, kIndex2);
|
||||
|
||||
CNodePtr softmax_node;
|
||||
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
|
||||
|
@ -541,7 +547,7 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro
|
|||
MS_EXCEPTION_IF_NULL(mul_node);
|
||||
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
|
||||
|
||||
auto sparse_softmax_node = mul_node->input(1);
|
||||
auto sparse_softmax_node = mul_node->input(kIndex1);
|
||||
auto sparse_softmax_node_grad = sparse_softmax_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sparse_softmax_node_grad);
|
||||
CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
|
||||
|
@ -555,7 +561,7 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro
|
|||
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
|
||||
CNodePtr real_div_node;
|
||||
if (tile_node == nullptr) {
|
||||
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kInput2));
|
||||
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(kIndex2));
|
||||
} else {
|
||||
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
|
||||
}
|
||||
|
|
|
@ -883,12 +883,14 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
|
|||
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
new_value_node->set_abstract(value_node->abstract());
|
||||
// create kernel_info fo new value node
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
new_value_node->set_kernel_info(kernel_info);
|
||||
// create kernel_build_info for new value node
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
|
||||
// set the format of value_node to DEFAULT_FORMAT
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
// set value node initial device data type = infer data type
|
||||
|
@ -920,6 +922,7 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
|
||||
|
|
|
@ -484,6 +484,35 @@ constexpr auto kUpdateStateRealInput = 2;
|
|||
// index define of Load
|
||||
constexpr auto kLoadRealInput = 1;
|
||||
constexpr auto kLoadStateInput = 2;
|
||||
// index of input or output
|
||||
enum Index : size_t {
|
||||
kIndex0 = 0,
|
||||
kIndex1,
|
||||
kIndex2,
|
||||
kIndex3,
|
||||
kIndex4,
|
||||
kIndex5,
|
||||
kIndex6,
|
||||
kIndex7,
|
||||
kIndex8,
|
||||
kIndex9,
|
||||
kIndex10,
|
||||
kIndex11,
|
||||
kIndex12,
|
||||
kIndex13,
|
||||
kIndex14,
|
||||
kIndex15,
|
||||
kIndex16,
|
||||
};
|
||||
// dim of shape
|
||||
enum Dim : size_t {
|
||||
kDim0 = 0,
|
||||
kDim1,
|
||||
kDim2,
|
||||
kDim3,
|
||||
kDim4,
|
||||
kDim5,
|
||||
};
|
||||
|
||||
// format
|
||||
constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
|
||||
|
|
Loading…
Reference in New Issue