update resnet network performence.

less_bn pattern update.

fix clang-format
This commit is contained in:
linqingke 2021-06-22 15:26:56 +08:00
parent b698b41b51
commit 5ebf3bdf67
3 changed files with 169 additions and 38 deletions

View File

@ -177,7 +177,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Accelerated Algorithm
less_batch_normalization_ =
MakeSubstitution(std::make_shared<LessBatchNormalization>(), "less_batch_normalization", prim::kPrimAdd);
MakeSubstitution(std::make_shared<LessBatchNormalization>(), "less_batch_normalization",
{prim::kPrimAdd, prim::kPrimRelu6, prim::kPrimMatMul, prim::kPrimMakeTuple, prim::kPrimMaxPool});
// inline
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);

View File

@ -31,8 +31,8 @@ constexpr auto kFirstBranchPattern1 = 12;
constexpr auto kSecondBranchPattern1 = 3;
constexpr auto kFirstBranchStartIndexPattern1 = 4;
constexpr auto kFirstBranchEndIndexPattern1 = 11;
constexpr auto kSecondBranchStartIndexPattern1 = 12;
constexpr auto kSecondBranchEndIndexPattern1 = 14;
constexpr auto kSecondBranchStartIndexPattern1 = kFirstBranchPattern1;
constexpr auto kSecondBranchEndIndexPattern1 = 2 + kFirstBranchPattern1;
const std::vector<kStructureTuple> ResidualStructureBasePattern{
{kFirstBranchPattern1,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
@ -47,8 +47,8 @@ constexpr auto kFirstBranchPattern2 = 12;
constexpr auto kSecondBranchPattern2 = 1;
constexpr auto kFirstBranchStartIndexPattern2 = 4;
constexpr auto kFirstBranchEndIndexPattern2 = 11;
constexpr auto kSecondBranchStartIndexPattern2 = 12;
constexpr auto kSecondBranchEndIndexPattern2 = 13;
constexpr auto kSecondBranchStartIndexPattern2 = kFirstBranchPattern2;
constexpr auto kSecondBranchEndIndexPattern2 = 1 + kSecondBranchPattern2;
const std::vector<kStructureTuple> ResidualStructureShortCutPattern{
{kFirstBranchPattern2,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
@ -61,8 +61,8 @@ constexpr auto kFirstBranchPattern3 = 11;
constexpr auto kSecondBranchPattern3 = 3;
constexpr auto kFirstBranchStartIndexPattern3 = 4;
constexpr auto kFirstBranchEndIndexPattern3 = 10;
constexpr auto kSecondBranchStartIndexPattern3 = 11;
constexpr auto kSecondBranchEndIndexPattern3 = 13;
constexpr auto kSecondBranchStartIndexPattern3 = kFirstBranchPattern3;
constexpr auto kSecondBranchEndIndexPattern3 = 2 + kFirstBranchPattern3;
const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
{kFirstBranchPattern3,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
@ -73,15 +73,13 @@ const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
{kSecondBranchStartIndexPattern3, kSecondBranchEndIndexPattern3}}};
// Pattern 4
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
// ↘ BatchNorm -> Conv2D -> -> -> -> ↗
constexpr auto kFirstBranchPattern4 = 8;
constexpr auto kSecondBranchPattern4 = 3;
constexpr auto kFirstBranchStartIndexPattern4 = 4;
constexpr auto kFirstBranchEndIndexPattern4 = 6;
constexpr auto kSecondBranchStartIndexPattern4 = 8;
constexpr auto kSecondBranchEndIndexPattern4 = 11;
const std::vector<kStructureTuple> BasicStructureBasePattern{
constexpr auto kSecondBranchStartIndexPattern4 = kFirstBranchPattern4;
constexpr auto kSecondBranchEndIndexPattern4 = 3 + kFirstBranchPattern4;
const std::vector<kStructureTuple> BasicStructBasePattern{
{kFirstBranchPattern4,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
{kFirstBranchStartIndexPattern4, kFirstBranchEndIndexPattern4}},
@ -89,37 +87,163 @@ const std::vector<kStructureTuple> BasicStructureBasePattern{
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
{kSecondBranchStartIndexPattern4, kSecondBranchEndIndexPattern4}}};
// Pattern 5
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
// ↘ -> -> -> -> Relu -> -> -> -> ↗
constexpr auto kFirstBranchPattern5 = 8;
constexpr auto kFirstBranchPattern5 = 7;
constexpr auto kSecondBranchPattern5 = 1;
constexpr auto kFirstBranchStartIndexPattern5 = 4;
constexpr auto kFirstBranchEndIndexPattern5 = 6;
constexpr auto kSecondBranchStartIndexPattern5 = 8;
constexpr auto kSecondBranchEndIndexPattern5 = 11;
const std::vector<kStructureTuple> BasicStructureShortCutPattern{
constexpr auto kSecondBranchStartIndexPattern5 = kFirstBranchPattern5;
constexpr auto kSecondBranchEndIndexPattern5 = 3 + kFirstBranchPattern5;
const std::vector<kStructureTuple> BasicStructFirstStepPattern{
{kFirstBranchPattern5,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
prim::kPrimBatchNorm, prim::kPrimConv2D},
{kFirstBranchStartIndexPattern5, kFirstBranchEndIndexPattern5}},
{kSecondBranchPattern5, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}};
{kSecondBranchPattern5, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}};
// Pattern 6
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
// ↘ -> -> -> -> MaxPool -> -> -> ↗
constexpr auto kFirstBranchPattern6 = 7;
constexpr auto kFirstBranchPattern6 = 8;
constexpr auto kSecondBranchPattern6 = 1;
constexpr auto kFirstBranchStartIndexPattern6 = 4;
constexpr auto kFirstBranchEndIndexPattern6 = 6;
constexpr auto kSecondBranchStartIndexPattern6 = 7;
constexpr auto kSecondBranchEndIndexPattern6 = 10;
const std::vector<kStructureTuple> BasicStructureFirstStepPattern{
constexpr auto kSecondBranchStartIndexPattern6 = kFirstBranchPattern6;
constexpr auto kSecondBranchEndIndexPattern6 = 3 + kFirstBranchPattern6;
const std::vector<kStructureTuple> BasicStructShortCutPattern{
{kFirstBranchPattern6,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
prim::kPrimBatchNorm, prim::kPrimConv2D},
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
{kFirstBranchStartIndexPattern6, kFirstBranchEndIndexPattern6}},
{kSecondBranchPattern6, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}};
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern,
BasicStructureBasePattern, BasicStructureShortCutPattern, BasicStructureFirstStepPattern};
{kSecondBranchPattern6, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}};
// Pattern 7
constexpr auto kFirstBranchPattern7 = 1;
constexpr auto kSecondBranchPattern7 = 13;
constexpr auto kFirstBranchStartIndexPattern7 = SIZE_MAX;
constexpr auto kFirstBranchEndIndexPattern7 = SIZE_MAX;
constexpr auto kSecondBranchStartIndexPattern7 = 7;
constexpr auto kSecondBranchEndIndexPattern7 = 10;
const std::vector<kStructureTuple> InvertedResidualShortCutPattern{
{kFirstBranchPattern7,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
{kFirstBranchStartIndexPattern7, kFirstBranchEndIndexPattern7}},
{kSecondBranchPattern7,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
{kSecondBranchStartIndexPattern7, kSecondBranchEndIndexPattern7}}};
// Pattern 8
constexpr auto kFirstBranchPattern8 = 4;
constexpr auto kFirstBranchStartIndexPattern8 = 0;
constexpr auto kFirstBranchEndIndexPattern8 = 3;
const std::vector<kStructureTuple> InvertedResidualPattern{
{kFirstBranchPattern8,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimAdd},
{kFirstBranchStartIndexPattern8, kFirstBranchEndIndexPattern8}}};
// Pattern 9
constexpr auto kFirstBranchPattern9 = 1;
constexpr auto kSecondBranchPattern9 = 12;
constexpr auto kFirstBranchStartIndexPattern9 = SIZE_MAX;
constexpr auto kFirstBranchEndIndexPattern9 = SIZE_MAX;
constexpr auto kSecondBranchStartIndexPattern9 = 7;
constexpr auto kSecondBranchEndIndexPattern9 = 10;
const std::vector<kStructureTuple> InvertedResidualShortCutPattern2{
{kFirstBranchPattern9, {prim::kPrimAdd}, {kFirstBranchStartIndexPattern9, kFirstBranchEndIndexPattern9}},
{kSecondBranchPattern9,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
prim::kPrimConv2D, prim::kPrimAdd},
{kSecondBranchStartIndexPattern9, kSecondBranchEndIndexPattern9}}};
// Pattern 10
constexpr auto kFirstBranchPattern10 = 5;
constexpr auto kFirstBranchStartIndexPattern10 = 0;
constexpr auto kFirstBranchEndIndexPattern10 = 4;
const std::vector<kStructureTuple> InvertedResidualPattern2{
{kFirstBranchPattern10,
{prim::kPrimReduceMean, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
{kFirstBranchStartIndexPattern10, kFirstBranchEndIndexPattern10}}};
// Pattern 11
constexpr auto kFirstBranchPattern11 = 17;
constexpr auto kFirstBranchStartIndexPattern11 = 3;
constexpr auto kFirstBranchEndIndexPattern11 = 6;
const std::vector<kStructureTuple> InvertedResidualPattern3{
{kFirstBranchPattern11,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6,
prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
{kFirstBranchStartIndexPattern11, kFirstBranchEndIndexPattern11}}};
// Pattern 12
constexpr auto kFirstBranchPattern12 = 1;
constexpr auto kSecondBranchPattern12 = 9;
constexpr auto kFirstBranchStartIndexPattern12 = SIZE_MAX;
constexpr auto kFirstBranchEndIndexPattern12 = SIZE_MAX;
constexpr auto kSecondBranchStartIndexPattern12 = kFirstBranchPattern12 + 5;
constexpr auto kSecondBranchEndIndexPattern12 = kFirstBranchPattern12 + 8;
const std::vector<kStructureTuple> DenseBlockShortCutPattern{
{kFirstBranchPattern12, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern12, kFirstBranchEndIndexPattern12}},
{kSecondBranchPattern12,
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
{kSecondBranchStartIndexPattern12, kSecondBranchEndIndexPattern12}}};
// Pattern 13
constexpr auto kFirstBranchPattern13 = 5;
constexpr auto kFirstBranchStartIndexPattern13 = 0;
constexpr auto kFirstBranchEndIndexPattern13 = 4;
const std::vector<kStructureTuple> DenseBlockPattern{
{kFirstBranchPattern13,
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
{kFirstBranchStartIndexPattern13, kFirstBranchEndIndexPattern13}}};
// Pattern 14
constexpr auto kFirstBranchPattern14 = 9;
constexpr auto kSecondBranchPattern14 = 1;
constexpr auto kFirstBranchStartIndexPattern14 = 5;
constexpr auto kFirstBranchEndIndexPattern14 = 8;
constexpr auto kSecondBranchStartIndexPattern14 = SIZE_MAX;
constexpr auto kSecondBranchEndIndexPattern14 = SIZE_MAX;
const std::vector<kStructureTuple> DenseBlockShortCutPattern2{
{kFirstBranchPattern14,
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
{kFirstBranchStartIndexPattern14, kFirstBranchEndIndexPattern14}},
{kSecondBranchPattern14, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern14, kSecondBranchEndIndexPattern14}}};
// Pattern 15
constexpr auto kFirstBranchPattern15 = 9;
constexpr auto kSecondBranchPattern15 = 1;
constexpr auto kFirstBranchStartIndexPattern15 = 0;
constexpr auto kFirstBranchEndIndexPattern15 = 4;
constexpr auto kSecondBranchStartIndexPattern15 = SIZE_MAX;
constexpr auto kSecondBranchEndIndexPattern15 = SIZE_MAX;
const std::vector<kStructureTuple> DenseBlockPoolPattern{
{kFirstBranchPattern15,
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
{kFirstBranchStartIndexPattern15, kFirstBranchEndIndexPattern15}},
{kSecondBranchPattern15, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern15, kSecondBranchEndIndexPattern15}}};
// Pattern 16
constexpr auto kFirstBranchPattern16 = 1;
constexpr auto kSecondBranchPattern16 = 9;
constexpr auto kFirstBranchStartIndexPattern16 = SIZE_MAX;
constexpr auto kFirstBranchEndIndexPattern16 = SIZE_MAX;
constexpr auto kSecondBranchStartIndexPattern16 = kFirstBranchPattern16;
constexpr auto kSecondBranchEndIndexPattern16 = kFirstBranchPattern16 + 4;
const std::vector<kStructureTuple> DenseBlockPoolPatter2{
{kFirstBranchPattern16, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern16, kFirstBranchEndIndexPattern16}},
{kSecondBranchPattern16,
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
{kSecondBranchStartIndexPattern16, kSecondBranchEndIndexPattern16}}};
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {ResidualStructureBasePattern,
ResidualStructureShortCutPattern,
ResidualStructureFirstStepPattern,
BasicStructBasePattern,
BasicStructFirstStepPattern,
BasicStructShortCutPattern,
InvertedResidualShortCutPattern,
InvertedResidualPattern,
InvertedResidualShortCutPattern2,
InvertedResidualPattern2,
InvertedResidualPattern3,
DenseBlockShortCutPattern,
DenseBlockPattern,
DenseBlockShortCutPattern2,
DenseBlockPoolPattern,
DenseBlockPoolPatter2};
const std::set<PrimitivePtr> kNeedRemoveNodeSet{
prim::kPrimLoad, prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
prim::kPrimApplyFtrl, prim::kPrimSGD, prim::kPrimApplyRMSProp, prim::kPrimAdam};
@ -286,7 +410,13 @@ AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, con
sum_match_node += std::get<0>(t);
total_match_node_.emplace_back(sum_match_node);
});
AnfVisitor::Match(prim::kPrimAdd, {IsCNode, IsCNode})(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr || cnode->inputs().empty()) {
return nullptr;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
std::vector<PredicateFuncType> funcs(cnode->inputs().size() - 1, IsCNode);
AnfVisitor::Match(prim, funcs)(node);
if (is_match_) {
break;
}

View File

@ -124,9 +124,9 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
device_num = 1
if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True,
num_shards=device_num, shard_id=rank_id)
image_size = 224
@ -152,7 +152,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12)
# only enable cache for eval
if do_train:
enable_cache = False
@ -160,10 +160,10 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
if not cache_session_id:
raise ValueError("A cache session_id must be provided to use cache.")
eval_cache = ds.DatasetCache(session_id=int(cache_session_id), size=0)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8,
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12,
cache=eval_cache)
else:
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12)
# apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True)