forked from mindspore-Ecosystem/mindspore
update resnet network performence.
less_bn pattern update. fix clang-format
This commit is contained in:
parent
b698b41b51
commit
5ebf3bdf67
|
@ -177,7 +177,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||||
|
|
||||||
// Accelerated Algorithm
|
// Accelerated Algorithm
|
||||||
less_batch_normalization_ =
|
less_batch_normalization_ =
|
||||||
MakeSubstitution(std::make_shared<LessBatchNormalization>(), "less_batch_normalization", prim::kPrimAdd);
|
MakeSubstitution(std::make_shared<LessBatchNormalization>(), "less_batch_normalization",
|
||||||
|
{prim::kPrimAdd, prim::kPrimRelu6, prim::kPrimMatMul, prim::kPrimMakeTuple, prim::kPrimMaxPool});
|
||||||
|
|
||||||
// inline
|
// inline
|
||||||
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
|
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
|
||||||
|
|
|
@ -31,8 +31,8 @@ constexpr auto kFirstBranchPattern1 = 12;
|
||||||
constexpr auto kSecondBranchPattern1 = 3;
|
constexpr auto kSecondBranchPattern1 = 3;
|
||||||
constexpr auto kFirstBranchStartIndexPattern1 = 4;
|
constexpr auto kFirstBranchStartIndexPattern1 = 4;
|
||||||
constexpr auto kFirstBranchEndIndexPattern1 = 11;
|
constexpr auto kFirstBranchEndIndexPattern1 = 11;
|
||||||
constexpr auto kSecondBranchStartIndexPattern1 = 12;
|
constexpr auto kSecondBranchStartIndexPattern1 = kFirstBranchPattern1;
|
||||||
constexpr auto kSecondBranchEndIndexPattern1 = 14;
|
constexpr auto kSecondBranchEndIndexPattern1 = 2 + kFirstBranchPattern1;
|
||||||
const std::vector<kStructureTuple> ResidualStructureBasePattern{
|
const std::vector<kStructureTuple> ResidualStructureBasePattern{
|
||||||
{kFirstBranchPattern1,
|
{kFirstBranchPattern1,
|
||||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
||||||
|
@ -47,8 +47,8 @@ constexpr auto kFirstBranchPattern2 = 12;
|
||||||
constexpr auto kSecondBranchPattern2 = 1;
|
constexpr auto kSecondBranchPattern2 = 1;
|
||||||
constexpr auto kFirstBranchStartIndexPattern2 = 4;
|
constexpr auto kFirstBranchStartIndexPattern2 = 4;
|
||||||
constexpr auto kFirstBranchEndIndexPattern2 = 11;
|
constexpr auto kFirstBranchEndIndexPattern2 = 11;
|
||||||
constexpr auto kSecondBranchStartIndexPattern2 = 12;
|
constexpr auto kSecondBranchStartIndexPattern2 = kFirstBranchPattern2;
|
||||||
constexpr auto kSecondBranchEndIndexPattern2 = 13;
|
constexpr auto kSecondBranchEndIndexPattern2 = 1 + kSecondBranchPattern2;
|
||||||
const std::vector<kStructureTuple> ResidualStructureShortCutPattern{
|
const std::vector<kStructureTuple> ResidualStructureShortCutPattern{
|
||||||
{kFirstBranchPattern2,
|
{kFirstBranchPattern2,
|
||||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
||||||
|
@ -61,8 +61,8 @@ constexpr auto kFirstBranchPattern3 = 11;
|
||||||
constexpr auto kSecondBranchPattern3 = 3;
|
constexpr auto kSecondBranchPattern3 = 3;
|
||||||
constexpr auto kFirstBranchStartIndexPattern3 = 4;
|
constexpr auto kFirstBranchStartIndexPattern3 = 4;
|
||||||
constexpr auto kFirstBranchEndIndexPattern3 = 10;
|
constexpr auto kFirstBranchEndIndexPattern3 = 10;
|
||||||
constexpr auto kSecondBranchStartIndexPattern3 = 11;
|
constexpr auto kSecondBranchStartIndexPattern3 = kFirstBranchPattern3;
|
||||||
constexpr auto kSecondBranchEndIndexPattern3 = 13;
|
constexpr auto kSecondBranchEndIndexPattern3 = 2 + kFirstBranchPattern3;
|
||||||
const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
|
const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
|
||||||
{kFirstBranchPattern3,
|
{kFirstBranchPattern3,
|
||||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
|
||||||
|
@ -73,15 +73,13 @@ const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
|
||||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
|
||||||
{kSecondBranchStartIndexPattern3, kSecondBranchEndIndexPattern3}}};
|
{kSecondBranchStartIndexPattern3, kSecondBranchEndIndexPattern3}}};
|
||||||
// Pattern 4
|
// Pattern 4
|
||||||
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
|
|
||||||
// ↘ BatchNorm -> Conv2D -> -> -> -> ↗
|
|
||||||
constexpr auto kFirstBranchPattern4 = 8;
|
constexpr auto kFirstBranchPattern4 = 8;
|
||||||
constexpr auto kSecondBranchPattern4 = 3;
|
constexpr auto kSecondBranchPattern4 = 3;
|
||||||
constexpr auto kFirstBranchStartIndexPattern4 = 4;
|
constexpr auto kFirstBranchStartIndexPattern4 = 4;
|
||||||
constexpr auto kFirstBranchEndIndexPattern4 = 6;
|
constexpr auto kFirstBranchEndIndexPattern4 = 6;
|
||||||
constexpr auto kSecondBranchStartIndexPattern4 = 8;
|
constexpr auto kSecondBranchStartIndexPattern4 = kFirstBranchPattern4;
|
||||||
constexpr auto kSecondBranchEndIndexPattern4 = 11;
|
constexpr auto kSecondBranchEndIndexPattern4 = 3 + kFirstBranchPattern4;
|
||||||
const std::vector<kStructureTuple> BasicStructureBasePattern{
|
const std::vector<kStructureTuple> BasicStructBasePattern{
|
||||||
{kFirstBranchPattern4,
|
{kFirstBranchPattern4,
|
||||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
||||||
{kFirstBranchStartIndexPattern4, kFirstBranchEndIndexPattern4}},
|
{kFirstBranchStartIndexPattern4, kFirstBranchEndIndexPattern4}},
|
||||||
|
@ -89,37 +87,163 @@ const std::vector<kStructureTuple> BasicStructureBasePattern{
|
||||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
|
||||||
{kSecondBranchStartIndexPattern4, kSecondBranchEndIndexPattern4}}};
|
{kSecondBranchStartIndexPattern4, kSecondBranchEndIndexPattern4}}};
|
||||||
// Pattern 5
|
// Pattern 5
|
||||||
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
|
constexpr auto kFirstBranchPattern5 = 7;
|
||||||
// ↘ -> -> -> -> Relu -> -> -> -> ↗
|
|
||||||
constexpr auto kFirstBranchPattern5 = 8;
|
|
||||||
constexpr auto kSecondBranchPattern5 = 1;
|
constexpr auto kSecondBranchPattern5 = 1;
|
||||||
constexpr auto kFirstBranchStartIndexPattern5 = 4;
|
constexpr auto kFirstBranchStartIndexPattern5 = 4;
|
||||||
constexpr auto kFirstBranchEndIndexPattern5 = 6;
|
constexpr auto kFirstBranchEndIndexPattern5 = 6;
|
||||||
constexpr auto kSecondBranchStartIndexPattern5 = 8;
|
constexpr auto kSecondBranchStartIndexPattern5 = kFirstBranchPattern5;
|
||||||
constexpr auto kSecondBranchEndIndexPattern5 = 11;
|
constexpr auto kSecondBranchEndIndexPattern5 = 3 + kFirstBranchPattern5;
|
||||||
const std::vector<kStructureTuple> BasicStructureShortCutPattern{
|
const std::vector<kStructureTuple> BasicStructFirstStepPattern{
|
||||||
{kFirstBranchPattern5,
|
{kFirstBranchPattern5,
|
||||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
|
||||||
|
prim::kPrimBatchNorm, prim::kPrimConv2D},
|
||||||
{kFirstBranchStartIndexPattern5, kFirstBranchEndIndexPattern5}},
|
{kFirstBranchStartIndexPattern5, kFirstBranchEndIndexPattern5}},
|
||||||
{kSecondBranchPattern5, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}};
|
{kSecondBranchPattern5, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}};
|
||||||
// Pattern 6
|
// Pattern 6
|
||||||
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
|
constexpr auto kFirstBranchPattern6 = 8;
|
||||||
// ↘ -> -> -> -> MaxPool -> -> -> ↗
|
|
||||||
constexpr auto kFirstBranchPattern6 = 7;
|
|
||||||
constexpr auto kSecondBranchPattern6 = 1;
|
constexpr auto kSecondBranchPattern6 = 1;
|
||||||
constexpr auto kFirstBranchStartIndexPattern6 = 4;
|
constexpr auto kFirstBranchStartIndexPattern6 = 4;
|
||||||
constexpr auto kFirstBranchEndIndexPattern6 = 6;
|
constexpr auto kFirstBranchEndIndexPattern6 = 6;
|
||||||
constexpr auto kSecondBranchStartIndexPattern6 = 7;
|
constexpr auto kSecondBranchStartIndexPattern6 = kFirstBranchPattern6;
|
||||||
constexpr auto kSecondBranchEndIndexPattern6 = 10;
|
constexpr auto kSecondBranchEndIndexPattern6 = 3 + kFirstBranchPattern6;
|
||||||
const std::vector<kStructureTuple> BasicStructureFirstStepPattern{
|
const std::vector<kStructureTuple> BasicStructShortCutPattern{
|
||||||
{kFirstBranchPattern6,
|
{kFirstBranchPattern6,
|
||||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
||||||
prim::kPrimBatchNorm, prim::kPrimConv2D},
|
|
||||||
{kFirstBranchStartIndexPattern6, kFirstBranchEndIndexPattern6}},
|
{kFirstBranchStartIndexPattern6, kFirstBranchEndIndexPattern6}},
|
||||||
{kSecondBranchPattern6, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}};
|
{kSecondBranchPattern6, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}};
|
||||||
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
|
// Pattern 7
|
||||||
ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern,
|
constexpr auto kFirstBranchPattern7 = 1;
|
||||||
BasicStructureBasePattern, BasicStructureShortCutPattern, BasicStructureFirstStepPattern};
|
constexpr auto kSecondBranchPattern7 = 13;
|
||||||
|
constexpr auto kFirstBranchStartIndexPattern7 = SIZE_MAX;
|
||||||
|
constexpr auto kFirstBranchEndIndexPattern7 = SIZE_MAX;
|
||||||
|
constexpr auto kSecondBranchStartIndexPattern7 = 7;
|
||||||
|
constexpr auto kSecondBranchEndIndexPattern7 = 10;
|
||||||
|
const std::vector<kStructureTuple> InvertedResidualShortCutPattern{
|
||||||
|
{kFirstBranchPattern7,
|
||||||
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
|
||||||
|
{kFirstBranchStartIndexPattern7, kFirstBranchEndIndexPattern7}},
|
||||||
|
{kSecondBranchPattern7,
|
||||||
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
|
||||||
|
prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
|
||||||
|
prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
|
||||||
|
{kSecondBranchStartIndexPattern7, kSecondBranchEndIndexPattern7}}};
|
||||||
|
// Pattern 8
|
||||||
|
constexpr auto kFirstBranchPattern8 = 4;
|
||||||
|
constexpr auto kFirstBranchStartIndexPattern8 = 0;
|
||||||
|
constexpr auto kFirstBranchEndIndexPattern8 = 3;
|
||||||
|
const std::vector<kStructureTuple> InvertedResidualPattern{
|
||||||
|
{kFirstBranchPattern8,
|
||||||
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimAdd},
|
||||||
|
{kFirstBranchStartIndexPattern8, kFirstBranchEndIndexPattern8}}};
|
||||||
|
// Pattern 9
|
||||||
|
constexpr auto kFirstBranchPattern9 = 1;
|
||||||
|
constexpr auto kSecondBranchPattern9 = 12;
|
||||||
|
constexpr auto kFirstBranchStartIndexPattern9 = SIZE_MAX;
|
||||||
|
constexpr auto kFirstBranchEndIndexPattern9 = SIZE_MAX;
|
||||||
|
constexpr auto kSecondBranchStartIndexPattern9 = 7;
|
||||||
|
constexpr auto kSecondBranchEndIndexPattern9 = 10;
|
||||||
|
const std::vector<kStructureTuple> InvertedResidualShortCutPattern2{
|
||||||
|
{kFirstBranchPattern9, {prim::kPrimAdd}, {kFirstBranchStartIndexPattern9, kFirstBranchEndIndexPattern9}},
|
||||||
|
{kSecondBranchPattern9,
|
||||||
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
|
||||||
|
prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
|
||||||
|
prim::kPrimConv2D, prim::kPrimAdd},
|
||||||
|
{kSecondBranchStartIndexPattern9, kSecondBranchEndIndexPattern9}}};
|
||||||
|
// Pattern 10
|
||||||
|
constexpr auto kFirstBranchPattern10 = 5;
|
||||||
|
constexpr auto kFirstBranchStartIndexPattern10 = 0;
|
||||||
|
constexpr auto kFirstBranchEndIndexPattern10 = 4;
|
||||||
|
const std::vector<kStructureTuple> InvertedResidualPattern2{
|
||||||
|
{kFirstBranchPattern10,
|
||||||
|
{prim::kPrimReduceMean, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
|
||||||
|
{kFirstBranchStartIndexPattern10, kFirstBranchEndIndexPattern10}}};
|
||||||
|
// Pattern 11
|
||||||
|
constexpr auto kFirstBranchPattern11 = 17;
|
||||||
|
constexpr auto kFirstBranchStartIndexPattern11 = 3;
|
||||||
|
constexpr auto kFirstBranchEndIndexPattern11 = 6;
|
||||||
|
const std::vector<kStructureTuple> InvertedResidualPattern3{
|
||||||
|
{kFirstBranchPattern11,
|
||||||
|
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
|
||||||
|
prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
|
||||||
|
prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6,
|
||||||
|
prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
|
||||||
|
{kFirstBranchStartIndexPattern11, kFirstBranchEndIndexPattern11}}};
|
||||||
|
// Pattern 12
|
||||||
|
constexpr auto kFirstBranchPattern12 = 1;
|
||||||
|
constexpr auto kSecondBranchPattern12 = 9;
|
||||||
|
constexpr auto kFirstBranchStartIndexPattern12 = SIZE_MAX;
|
||||||
|
constexpr auto kFirstBranchEndIndexPattern12 = SIZE_MAX;
|
||||||
|
constexpr auto kSecondBranchStartIndexPattern12 = kFirstBranchPattern12 + 5;
|
||||||
|
constexpr auto kSecondBranchEndIndexPattern12 = kFirstBranchPattern12 + 8;
|
||||||
|
const std::vector<kStructureTuple> DenseBlockShortCutPattern{
|
||||||
|
{kFirstBranchPattern12, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern12, kFirstBranchEndIndexPattern12}},
|
||||||
|
{kSecondBranchPattern12,
|
||||||
|
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
|
||||||
|
prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
|
||||||
|
{kSecondBranchStartIndexPattern12, kSecondBranchEndIndexPattern12}}};
|
||||||
|
// Pattern 13
|
||||||
|
constexpr auto kFirstBranchPattern13 = 5;
|
||||||
|
constexpr auto kFirstBranchStartIndexPattern13 = 0;
|
||||||
|
constexpr auto kFirstBranchEndIndexPattern13 = 4;
|
||||||
|
const std::vector<kStructureTuple> DenseBlockPattern{
|
||||||
|
{kFirstBranchPattern13,
|
||||||
|
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
|
||||||
|
{kFirstBranchStartIndexPattern13, kFirstBranchEndIndexPattern13}}};
|
||||||
|
// Pattern 14
|
||||||
|
constexpr auto kFirstBranchPattern14 = 9;
|
||||||
|
constexpr auto kSecondBranchPattern14 = 1;
|
||||||
|
constexpr auto kFirstBranchStartIndexPattern14 = 5;
|
||||||
|
constexpr auto kFirstBranchEndIndexPattern14 = 8;
|
||||||
|
constexpr auto kSecondBranchStartIndexPattern14 = SIZE_MAX;
|
||||||
|
constexpr auto kSecondBranchEndIndexPattern14 = SIZE_MAX;
|
||||||
|
const std::vector<kStructureTuple> DenseBlockShortCutPattern2{
|
||||||
|
{kFirstBranchPattern14,
|
||||||
|
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
|
||||||
|
prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
|
||||||
|
{kFirstBranchStartIndexPattern14, kFirstBranchEndIndexPattern14}},
|
||||||
|
{kSecondBranchPattern14, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern14, kSecondBranchEndIndexPattern14}}};
|
||||||
|
// Pattern 15
|
||||||
|
constexpr auto kFirstBranchPattern15 = 9;
|
||||||
|
constexpr auto kSecondBranchPattern15 = 1;
|
||||||
|
constexpr auto kFirstBranchStartIndexPattern15 = 0;
|
||||||
|
constexpr auto kFirstBranchEndIndexPattern15 = 4;
|
||||||
|
constexpr auto kSecondBranchStartIndexPattern15 = SIZE_MAX;
|
||||||
|
constexpr auto kSecondBranchEndIndexPattern15 = SIZE_MAX;
|
||||||
|
const std::vector<kStructureTuple> DenseBlockPoolPattern{
|
||||||
|
{kFirstBranchPattern15,
|
||||||
|
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
|
||||||
|
prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
|
||||||
|
{kFirstBranchStartIndexPattern15, kFirstBranchEndIndexPattern15}},
|
||||||
|
{kSecondBranchPattern15, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern15, kSecondBranchEndIndexPattern15}}};
|
||||||
|
// Pattern 16
|
||||||
|
constexpr auto kFirstBranchPattern16 = 1;
|
||||||
|
constexpr auto kSecondBranchPattern16 = 9;
|
||||||
|
constexpr auto kFirstBranchStartIndexPattern16 = SIZE_MAX;
|
||||||
|
constexpr auto kFirstBranchEndIndexPattern16 = SIZE_MAX;
|
||||||
|
constexpr auto kSecondBranchStartIndexPattern16 = kFirstBranchPattern16;
|
||||||
|
constexpr auto kSecondBranchEndIndexPattern16 = kFirstBranchPattern16 + 4;
|
||||||
|
const std::vector<kStructureTuple> DenseBlockPoolPatter2{
|
||||||
|
{kFirstBranchPattern16, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern16, kFirstBranchEndIndexPattern16}},
|
||||||
|
{kSecondBranchPattern16,
|
||||||
|
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
|
||||||
|
prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
|
||||||
|
{kSecondBranchStartIndexPattern16, kSecondBranchEndIndexPattern16}}};
|
||||||
|
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {ResidualStructureBasePattern,
|
||||||
|
ResidualStructureShortCutPattern,
|
||||||
|
ResidualStructureFirstStepPattern,
|
||||||
|
BasicStructBasePattern,
|
||||||
|
BasicStructFirstStepPattern,
|
||||||
|
BasicStructShortCutPattern,
|
||||||
|
InvertedResidualShortCutPattern,
|
||||||
|
InvertedResidualPattern,
|
||||||
|
InvertedResidualShortCutPattern2,
|
||||||
|
InvertedResidualPattern2,
|
||||||
|
InvertedResidualPattern3,
|
||||||
|
DenseBlockShortCutPattern,
|
||||||
|
DenseBlockPattern,
|
||||||
|
DenseBlockShortCutPattern2,
|
||||||
|
DenseBlockPoolPattern,
|
||||||
|
DenseBlockPoolPatter2};
|
||||||
const std::set<PrimitivePtr> kNeedRemoveNodeSet{
|
const std::set<PrimitivePtr> kNeedRemoveNodeSet{
|
||||||
prim::kPrimLoad, prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
|
prim::kPrimLoad, prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
|
||||||
prim::kPrimApplyFtrl, prim::kPrimSGD, prim::kPrimApplyRMSProp, prim::kPrimAdam};
|
prim::kPrimApplyFtrl, prim::kPrimSGD, prim::kPrimApplyRMSProp, prim::kPrimAdam};
|
||||||
|
@ -286,7 +410,13 @@ AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, con
|
||||||
sum_match_node += std::get<0>(t);
|
sum_match_node += std::get<0>(t);
|
||||||
total_match_node_.emplace_back(sum_match_node);
|
total_match_node_.emplace_back(sum_match_node);
|
||||||
});
|
});
|
||||||
AnfVisitor::Match(prim::kPrimAdd, {IsCNode, IsCNode})(node);
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (cnode == nullptr || cnode->inputs().empty()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||||
|
std::vector<PredicateFuncType> funcs(cnode->inputs().size() - 1, IsCNode);
|
||||||
|
AnfVisitor::Match(prim, funcs)(node);
|
||||||
if (is_match_) {
|
if (is_match_) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
@ -124,9 +124,9 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
||||||
device_num = 1
|
device_num = 1
|
||||||
|
|
||||||
if device_num == 1:
|
if device_num == 1:
|
||||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True)
|
||||||
else:
|
else:
|
||||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True,
|
||||||
num_shards=device_num, shard_id=rank_id)
|
num_shards=device_num, shard_id=rank_id)
|
||||||
|
|
||||||
image_size = 224
|
image_size = 224
|
||||||
|
@ -152,7 +152,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
||||||
|
|
||||||
type_cast_op = C2.TypeCast(mstype.int32)
|
type_cast_op = C2.TypeCast(mstype.int32)
|
||||||
|
|
||||||
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12)
|
||||||
# only enable cache for eval
|
# only enable cache for eval
|
||||||
if do_train:
|
if do_train:
|
||||||
enable_cache = False
|
enable_cache = False
|
||||||
|
@ -160,10 +160,10 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
||||||
if not cache_session_id:
|
if not cache_session_id:
|
||||||
raise ValueError("A cache session_id must be provided to use cache.")
|
raise ValueError("A cache session_id must be provided to use cache.")
|
||||||
eval_cache = ds.DatasetCache(session_id=int(cache_session_id), size=0)
|
eval_cache = ds.DatasetCache(session_id=int(cache_session_id), size=0)
|
||||||
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8,
|
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12,
|
||||||
cache=eval_cache)
|
cache=eval_cache)
|
||||||
else:
|
else:
|
||||||
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
|
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12)
|
||||||
|
|
||||||
# apply batch operations
|
# apply batch operations
|
||||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||||
|
|
Loading…
Reference in New Issue