diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 09d3dba818c..d3ff88e572f 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -177,7 +177,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { // Accelerated Algorithm less_batch_normalization_ = - MakeSubstitution(std::make_shared(), "less_batch_normalization", prim::kPrimAdd); + MakeSubstitution(std::make_shared(), "less_batch_normalization", + {prim::kPrimAdd, prim::kPrimRelu6, prim::kPrimMatMul, prim::kPrimMakeTuple, prim::kPrimMaxPool}); // inline inline_ = MakeSubstitution(std::make_shared(), "inline", IsCNodeGraph); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc b/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc index 9ce30eeaa7f..77ca4b7f752 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc @@ -31,8 +31,8 @@ constexpr auto kFirstBranchPattern1 = 12; constexpr auto kSecondBranchPattern1 = 3; constexpr auto kFirstBranchStartIndexPattern1 = 4; constexpr auto kFirstBranchEndIndexPattern1 = 11; -constexpr auto kSecondBranchStartIndexPattern1 = 12; -constexpr auto kSecondBranchEndIndexPattern1 = 14; +constexpr auto kSecondBranchStartIndexPattern1 = kFirstBranchPattern1; +constexpr auto kSecondBranchEndIndexPattern1 = 2 + kFirstBranchPattern1; const std::vector ResidualStructureBasePattern{ {kFirstBranchPattern1, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu}, @@ -47,8 +47,8 @@ constexpr auto kFirstBranchPattern2 = 12; constexpr auto kSecondBranchPattern2 = 1; constexpr auto kFirstBranchStartIndexPattern2 = 4; constexpr auto kFirstBranchEndIndexPattern2 = 11; -constexpr auto kSecondBranchStartIndexPattern2 = 12; -constexpr auto kSecondBranchEndIndexPattern2 = 13; +constexpr auto kSecondBranchStartIndexPattern2 = kFirstBranchPattern2; +constexpr auto kSecondBranchEndIndexPattern2 = 1 + kSecondBranchPattern2; const std::vector ResidualStructureShortCutPattern{ {kFirstBranchPattern2, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu}, @@ -61,8 +61,8 @@ constexpr auto kFirstBranchPattern3 = 11; constexpr auto kSecondBranchPattern3 = 3; constexpr auto kFirstBranchStartIndexPattern3 = 4; constexpr auto kFirstBranchEndIndexPattern3 = 10; -constexpr auto kSecondBranchStartIndexPattern3 = 11; -constexpr auto kSecondBranchEndIndexPattern3 = 13; +constexpr auto kSecondBranchStartIndexPattern3 = kFirstBranchPattern3; +constexpr auto kSecondBranchEndIndexPattern3 = 2 + kFirstBranchPattern3; const std::vector ResidualStructureFirstStepPattern{ {kFirstBranchPattern3, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, @@ -73,15 +73,13 @@ const std::vector ResidualStructureFirstStepPattern{ {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {kSecondBranchStartIndexPattern3, kSecondBranchEndIndexPattern3}}}; // Pattern 4 -// Add -> BatchNorm -> Conv2D -> Relu ... -> End -// ↘ BatchNorm -> Conv2D -> -> -> -> ↗ constexpr auto kFirstBranchPattern4 = 8; constexpr auto kSecondBranchPattern4 = 3; constexpr auto kFirstBranchStartIndexPattern4 = 4; constexpr auto kFirstBranchEndIndexPattern4 = 6; -constexpr auto kSecondBranchStartIndexPattern4 = 8; -constexpr auto kSecondBranchEndIndexPattern4 = 11; -const std::vector BasicStructureBasePattern{ +constexpr auto kSecondBranchStartIndexPattern4 = kFirstBranchPattern4; +constexpr auto kSecondBranchEndIndexPattern4 = 3 + kFirstBranchPattern4; +const std::vector BasicStructBasePattern{ {kFirstBranchPattern4, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu}, {kFirstBranchStartIndexPattern4, kFirstBranchEndIndexPattern4}}, @@ -89,37 +87,163 @@ const std::vector BasicStructureBasePattern{ {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {kSecondBranchStartIndexPattern4, kSecondBranchEndIndexPattern4}}}; // Pattern 5 -// Add -> BatchNorm -> Conv2D -> Relu ... -> End -// ↘ -> -> -> -> Relu -> -> -> -> ↗ -constexpr auto kFirstBranchPattern5 = 8; +constexpr auto kFirstBranchPattern5 = 7; constexpr auto kSecondBranchPattern5 = 1; constexpr auto kFirstBranchStartIndexPattern5 = 4; constexpr auto kFirstBranchEndIndexPattern5 = 6; -constexpr auto kSecondBranchStartIndexPattern5 = 8; -constexpr auto kSecondBranchEndIndexPattern5 = 11; -const std::vector BasicStructureShortCutPattern{ +constexpr auto kSecondBranchStartIndexPattern5 = kFirstBranchPattern5; +constexpr auto kSecondBranchEndIndexPattern5 = 3 + kFirstBranchPattern5; +const std::vector BasicStructFirstStepPattern{ {kFirstBranchPattern5, - {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu}, + {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, + prim::kPrimBatchNorm, prim::kPrimConv2D}, {kFirstBranchStartIndexPattern5, kFirstBranchEndIndexPattern5}}, - {kSecondBranchPattern5, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}}; + {kSecondBranchPattern5, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}}; // Pattern 6 -// Add -> BatchNorm -> Conv2D -> Relu ... -> End -// ↘ -> -> -> -> MaxPool -> -> -> ↗ -constexpr auto kFirstBranchPattern6 = 7; +constexpr auto kFirstBranchPattern6 = 8; constexpr auto kSecondBranchPattern6 = 1; constexpr auto kFirstBranchStartIndexPattern6 = 4; constexpr auto kFirstBranchEndIndexPattern6 = 6; -constexpr auto kSecondBranchStartIndexPattern6 = 7; -constexpr auto kSecondBranchEndIndexPattern6 = 10; -const std::vector BasicStructureFirstStepPattern{ +constexpr auto kSecondBranchStartIndexPattern6 = kFirstBranchPattern6; +constexpr auto kSecondBranchEndIndexPattern6 = 3 + kFirstBranchPattern6; +const std::vector BasicStructShortCutPattern{ {kFirstBranchPattern6, - {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, - prim::kPrimBatchNorm, prim::kPrimConv2D}, + {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu}, {kFirstBranchStartIndexPattern6, kFirstBranchEndIndexPattern6}}, - {kSecondBranchPattern6, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}}; -static const std::vector> kNeedMatchPattern = { - ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern, - BasicStructureBasePattern, BasicStructureShortCutPattern, BasicStructureFirstStepPattern}; + {kSecondBranchPattern6, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}}; +// Pattern 7 +constexpr auto kFirstBranchPattern7 = 1; +constexpr auto kSecondBranchPattern7 = 13; +constexpr auto kFirstBranchStartIndexPattern7 = SIZE_MAX; +constexpr auto kFirstBranchEndIndexPattern7 = SIZE_MAX; +constexpr auto kSecondBranchStartIndexPattern7 = 7; +constexpr auto kSecondBranchEndIndexPattern7 = 10; +const std::vector InvertedResidualShortCutPattern{ + {kFirstBranchPattern7, + {prim::kPrimTupleGetItem, prim::kPrimBatchNorm}, + {kFirstBranchStartIndexPattern7, kFirstBranchEndIndexPattern7}}, + {kSecondBranchPattern7, + {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, + prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, + prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm}, + {kSecondBranchStartIndexPattern7, kSecondBranchEndIndexPattern7}}}; +// Pattern 8 +constexpr auto kFirstBranchPattern8 = 4; +constexpr auto kFirstBranchStartIndexPattern8 = 0; +constexpr auto kFirstBranchEndIndexPattern8 = 3; +const std::vector InvertedResidualPattern{ + {kFirstBranchPattern8, + {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimAdd}, + {kFirstBranchStartIndexPattern8, kFirstBranchEndIndexPattern8}}}; +// Pattern 9 +constexpr auto kFirstBranchPattern9 = 1; +constexpr auto kSecondBranchPattern9 = 12; +constexpr auto kFirstBranchStartIndexPattern9 = SIZE_MAX; +constexpr auto kFirstBranchEndIndexPattern9 = SIZE_MAX; +constexpr auto kSecondBranchStartIndexPattern9 = 7; +constexpr auto kSecondBranchEndIndexPattern9 = 10; +const std::vector InvertedResidualShortCutPattern2{ + {kFirstBranchPattern9, {prim::kPrimAdd}, {kFirstBranchStartIndexPattern9, kFirstBranchEndIndexPattern9}}, + {kSecondBranchPattern9, + {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, + prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, + prim::kPrimConv2D, prim::kPrimAdd}, + {kSecondBranchStartIndexPattern9, kSecondBranchEndIndexPattern9}}}; +// Pattern 10 +constexpr auto kFirstBranchPattern10 = 5; +constexpr auto kFirstBranchStartIndexPattern10 = 0; +constexpr auto kFirstBranchEndIndexPattern10 = 4; +const std::vector InvertedResidualPattern2{ + {kFirstBranchPattern10, + {prim::kPrimReduceMean, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, + {kFirstBranchStartIndexPattern10, kFirstBranchEndIndexPattern10}}}; +// Pattern 11 +constexpr auto kFirstBranchPattern11 = 17; +constexpr auto kFirstBranchStartIndexPattern11 = 3; +constexpr auto kFirstBranchEndIndexPattern11 = 6; +const std::vector InvertedResidualPattern3{ + {kFirstBranchPattern11, + {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, + prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, + prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, + prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, + {kFirstBranchStartIndexPattern11, kFirstBranchEndIndexPattern11}}}; +// Pattern 12 +constexpr auto kFirstBranchPattern12 = 1; +constexpr auto kSecondBranchPattern12 = 9; +constexpr auto kFirstBranchStartIndexPattern12 = SIZE_MAX; +constexpr auto kFirstBranchEndIndexPattern12 = SIZE_MAX; +constexpr auto kSecondBranchStartIndexPattern12 = kFirstBranchPattern12 + 5; +constexpr auto kSecondBranchEndIndexPattern12 = kFirstBranchPattern12 + 8; +const std::vector DenseBlockShortCutPattern{ + {kFirstBranchPattern12, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern12, kFirstBranchEndIndexPattern12}}, + {kSecondBranchPattern12, + {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, + prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat}, + {kSecondBranchStartIndexPattern12, kSecondBranchEndIndexPattern12}}}; +// Pattern 13 +constexpr auto kFirstBranchPattern13 = 5; +constexpr auto kFirstBranchStartIndexPattern13 = 0; +constexpr auto kFirstBranchEndIndexPattern13 = 4; +const std::vector DenseBlockPattern{ + {kFirstBranchPattern13, + {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat}, + {kFirstBranchStartIndexPattern13, kFirstBranchEndIndexPattern13}}}; +// Pattern 14 +constexpr auto kFirstBranchPattern14 = 9; +constexpr auto kSecondBranchPattern14 = 1; +constexpr auto kFirstBranchStartIndexPattern14 = 5; +constexpr auto kFirstBranchEndIndexPattern14 = 8; +constexpr auto kSecondBranchStartIndexPattern14 = SIZE_MAX; +constexpr auto kSecondBranchEndIndexPattern14 = SIZE_MAX; +const std::vector DenseBlockShortCutPattern2{ + {kFirstBranchPattern14, + {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, + prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat}, + {kFirstBranchStartIndexPattern14, kFirstBranchEndIndexPattern14}}, + {kSecondBranchPattern14, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern14, kSecondBranchEndIndexPattern14}}}; +// Pattern 15 +constexpr auto kFirstBranchPattern15 = 9; +constexpr auto kSecondBranchPattern15 = 1; +constexpr auto kFirstBranchStartIndexPattern15 = 0; +constexpr auto kFirstBranchEndIndexPattern15 = 4; +constexpr auto kSecondBranchStartIndexPattern15 = SIZE_MAX; +constexpr auto kSecondBranchEndIndexPattern15 = SIZE_MAX; +const std::vector DenseBlockPoolPattern{ + {kFirstBranchPattern15, + {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, + prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool}, + {kFirstBranchStartIndexPattern15, kFirstBranchEndIndexPattern15}}, + {kSecondBranchPattern15, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern15, kSecondBranchEndIndexPattern15}}}; +// Pattern 16 +constexpr auto kFirstBranchPattern16 = 1; +constexpr auto kSecondBranchPattern16 = 9; +constexpr auto kFirstBranchStartIndexPattern16 = SIZE_MAX; +constexpr auto kFirstBranchEndIndexPattern16 = SIZE_MAX; +constexpr auto kSecondBranchStartIndexPattern16 = kFirstBranchPattern16; +constexpr auto kSecondBranchEndIndexPattern16 = kFirstBranchPattern16 + 4; +const std::vector DenseBlockPoolPatter2{ + {kFirstBranchPattern16, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern16, kFirstBranchEndIndexPattern16}}, + {kSecondBranchPattern16, + {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, + prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool}, + {kSecondBranchStartIndexPattern16, kSecondBranchEndIndexPattern16}}}; +static const std::vector> kNeedMatchPattern = {ResidualStructureBasePattern, + ResidualStructureShortCutPattern, + ResidualStructureFirstStepPattern, + BasicStructBasePattern, + BasicStructFirstStepPattern, + BasicStructShortCutPattern, + InvertedResidualShortCutPattern, + InvertedResidualPattern, + InvertedResidualShortCutPattern2, + InvertedResidualPattern2, + InvertedResidualPattern3, + DenseBlockShortCutPattern, + DenseBlockPattern, + DenseBlockShortCutPattern2, + DenseBlockPoolPattern, + DenseBlockPoolPatter2}; const std::set kNeedRemoveNodeSet{ prim::kPrimLoad, prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum, prim::kPrimApplyFtrl, prim::kPrimSGD, prim::kPrimApplyRMSProp, prim::kPrimAdam}; @@ -286,7 +410,13 @@ AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, con sum_match_node += std::get<0>(t); total_match_node_.emplace_back(sum_match_node); }); - AnfVisitor::Match(prim::kPrimAdd, {IsCNode, IsCNode})(node); + auto cnode = node->cast(); + if (cnode == nullptr || cnode->inputs().empty()) { + return nullptr; + } + auto prim = GetValueNode(cnode->input(0)); + std::vector funcs(cnode->inputs().size() - 1, IsCNode); + AnfVisitor::Match(prim, funcs)(node); if (is_match_) { break; } diff --git a/model_zoo/official/cv/resnet/src/dataset.py b/model_zoo/official/cv/resnet/src/dataset.py index 13d76701f7f..34ab2869a6b 100755 --- a/model_zoo/official/cv/resnet/src/dataset.py +++ b/model_zoo/official/cv/resnet/src/dataset.py @@ -124,9 +124,9 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target= device_num = 1 if device_num == 1: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True) else: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True, num_shards=device_num, shard_id=rank_id) image_size = 224 @@ -152,7 +152,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target= type_cast_op = C2.TypeCast(mstype.int32) - data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8) + data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12) # only enable cache for eval if do_train: enable_cache = False @@ -160,10 +160,10 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target= if not cache_session_id: raise ValueError("A cache session_id must be provided to use cache.") eval_cache = ds.DatasetCache(session_id=int(cache_session_id), size=0) - data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8, + data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12, cache=eval_cache) else: - data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8) + data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12) # apply batch operations data_set = data_set.batch(batch_size, drop_remainder=True)