forked from mindspore-Ecosystem/mindspore
update resnet network performence.
less_bn pattern update. fix clang-format
This commit is contained in:
parent
b698b41b51
commit
5ebf3bdf67
|
@ -177,7 +177,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
|
||||
// Accelerated Algorithm
|
||||
less_batch_normalization_ =
|
||||
MakeSubstitution(std::make_shared<LessBatchNormalization>(), "less_batch_normalization", prim::kPrimAdd);
|
||||
MakeSubstitution(std::make_shared<LessBatchNormalization>(), "less_batch_normalization",
|
||||
{prim::kPrimAdd, prim::kPrimRelu6, prim::kPrimMatMul, prim::kPrimMakeTuple, prim::kPrimMaxPool});
|
||||
|
||||
// inline
|
||||
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
|
||||
|
|
|
@ -31,8 +31,8 @@ constexpr auto kFirstBranchPattern1 = 12;
|
|||
constexpr auto kSecondBranchPattern1 = 3;
|
||||
constexpr auto kFirstBranchStartIndexPattern1 = 4;
|
||||
constexpr auto kFirstBranchEndIndexPattern1 = 11;
|
||||
constexpr auto kSecondBranchStartIndexPattern1 = 12;
|
||||
constexpr auto kSecondBranchEndIndexPattern1 = 14;
|
||||
constexpr auto kSecondBranchStartIndexPattern1 = kFirstBranchPattern1;
|
||||
constexpr auto kSecondBranchEndIndexPattern1 = 2 + kFirstBranchPattern1;
|
||||
const std::vector<kStructureTuple> ResidualStructureBasePattern{
|
||||
{kFirstBranchPattern1,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
||||
|
@ -47,8 +47,8 @@ constexpr auto kFirstBranchPattern2 = 12;
|
|||
constexpr auto kSecondBranchPattern2 = 1;
|
||||
constexpr auto kFirstBranchStartIndexPattern2 = 4;
|
||||
constexpr auto kFirstBranchEndIndexPattern2 = 11;
|
||||
constexpr auto kSecondBranchStartIndexPattern2 = 12;
|
||||
constexpr auto kSecondBranchEndIndexPattern2 = 13;
|
||||
constexpr auto kSecondBranchStartIndexPattern2 = kFirstBranchPattern2;
|
||||
constexpr auto kSecondBranchEndIndexPattern2 = 1 + kSecondBranchPattern2;
|
||||
const std::vector<kStructureTuple> ResidualStructureShortCutPattern{
|
||||
{kFirstBranchPattern2,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
||||
|
@ -61,8 +61,8 @@ constexpr auto kFirstBranchPattern3 = 11;
|
|||
constexpr auto kSecondBranchPattern3 = 3;
|
||||
constexpr auto kFirstBranchStartIndexPattern3 = 4;
|
||||
constexpr auto kFirstBranchEndIndexPattern3 = 10;
|
||||
constexpr auto kSecondBranchStartIndexPattern3 = 11;
|
||||
constexpr auto kSecondBranchEndIndexPattern3 = 13;
|
||||
constexpr auto kSecondBranchStartIndexPattern3 = kFirstBranchPattern3;
|
||||
constexpr auto kSecondBranchEndIndexPattern3 = 2 + kFirstBranchPattern3;
|
||||
const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
|
||||
{kFirstBranchPattern3,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
|
||||
|
@ -73,15 +73,13 @@ const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
|
|||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
|
||||
{kSecondBranchStartIndexPattern3, kSecondBranchEndIndexPattern3}}};
|
||||
// Pattern 4
|
||||
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
|
||||
// ↘ BatchNorm -> Conv2D -> -> -> -> ↗
|
||||
constexpr auto kFirstBranchPattern4 = 8;
|
||||
constexpr auto kSecondBranchPattern4 = 3;
|
||||
constexpr auto kFirstBranchStartIndexPattern4 = 4;
|
||||
constexpr auto kFirstBranchEndIndexPattern4 = 6;
|
||||
constexpr auto kSecondBranchStartIndexPattern4 = 8;
|
||||
constexpr auto kSecondBranchEndIndexPattern4 = 11;
|
||||
const std::vector<kStructureTuple> BasicStructureBasePattern{
|
||||
constexpr auto kSecondBranchStartIndexPattern4 = kFirstBranchPattern4;
|
||||
constexpr auto kSecondBranchEndIndexPattern4 = 3 + kFirstBranchPattern4;
|
||||
const std::vector<kStructureTuple> BasicStructBasePattern{
|
||||
{kFirstBranchPattern4,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
||||
{kFirstBranchStartIndexPattern4, kFirstBranchEndIndexPattern4}},
|
||||
|
@ -89,37 +87,163 @@ const std::vector<kStructureTuple> BasicStructureBasePattern{
|
|||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
|
||||
{kSecondBranchStartIndexPattern4, kSecondBranchEndIndexPattern4}}};
|
||||
// Pattern 5
|
||||
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
|
||||
// ↘ -> -> -> -> Relu -> -> -> -> ↗
|
||||
constexpr auto kFirstBranchPattern5 = 8;
|
||||
constexpr auto kFirstBranchPattern5 = 7;
|
||||
constexpr auto kSecondBranchPattern5 = 1;
|
||||
constexpr auto kFirstBranchStartIndexPattern5 = 4;
|
||||
constexpr auto kFirstBranchEndIndexPattern5 = 6;
|
||||
constexpr auto kSecondBranchStartIndexPattern5 = 8;
|
||||
constexpr auto kSecondBranchEndIndexPattern5 = 11;
|
||||
const std::vector<kStructureTuple> BasicStructureShortCutPattern{
|
||||
constexpr auto kSecondBranchStartIndexPattern5 = kFirstBranchPattern5;
|
||||
constexpr auto kSecondBranchEndIndexPattern5 = 3 + kFirstBranchPattern5;
|
||||
const std::vector<kStructureTuple> BasicStructFirstStepPattern{
|
||||
{kFirstBranchPattern5,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
|
||||
prim::kPrimBatchNorm, prim::kPrimConv2D},
|
||||
{kFirstBranchStartIndexPattern5, kFirstBranchEndIndexPattern5}},
|
||||
{kSecondBranchPattern5, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}};
|
||||
{kSecondBranchPattern5, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}};
|
||||
// Pattern 6
|
||||
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
|
||||
// ↘ -> -> -> -> MaxPool -> -> -> ↗
|
||||
constexpr auto kFirstBranchPattern6 = 7;
|
||||
constexpr auto kFirstBranchPattern6 = 8;
|
||||
constexpr auto kSecondBranchPattern6 = 1;
|
||||
constexpr auto kFirstBranchStartIndexPattern6 = 4;
|
||||
constexpr auto kFirstBranchEndIndexPattern6 = 6;
|
||||
constexpr auto kSecondBranchStartIndexPattern6 = 7;
|
||||
constexpr auto kSecondBranchEndIndexPattern6 = 10;
|
||||
const std::vector<kStructureTuple> BasicStructureFirstStepPattern{
|
||||
constexpr auto kSecondBranchStartIndexPattern6 = kFirstBranchPattern6;
|
||||
constexpr auto kSecondBranchEndIndexPattern6 = 3 + kFirstBranchPattern6;
|
||||
const std::vector<kStructureTuple> BasicStructShortCutPattern{
|
||||
{kFirstBranchPattern6,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
|
||||
prim::kPrimBatchNorm, prim::kPrimConv2D},
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
|
||||
{kFirstBranchStartIndexPattern6, kFirstBranchEndIndexPattern6}},
|
||||
{kSecondBranchPattern6, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}};
|
||||
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
|
||||
ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern,
|
||||
BasicStructureBasePattern, BasicStructureShortCutPattern, BasicStructureFirstStepPattern};
|
||||
{kSecondBranchPattern6, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}};
|
||||
// Pattern 7
|
||||
constexpr auto kFirstBranchPattern7 = 1;
|
||||
constexpr auto kSecondBranchPattern7 = 13;
|
||||
constexpr auto kFirstBranchStartIndexPattern7 = SIZE_MAX;
|
||||
constexpr auto kFirstBranchEndIndexPattern7 = SIZE_MAX;
|
||||
constexpr auto kSecondBranchStartIndexPattern7 = 7;
|
||||
constexpr auto kSecondBranchEndIndexPattern7 = 10;
|
||||
const std::vector<kStructureTuple> InvertedResidualShortCutPattern{
|
||||
{kFirstBranchPattern7,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
|
||||
{kFirstBranchStartIndexPattern7, kFirstBranchEndIndexPattern7}},
|
||||
{kSecondBranchPattern7,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
|
||||
prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
|
||||
prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
|
||||
{kSecondBranchStartIndexPattern7, kSecondBranchEndIndexPattern7}}};
|
||||
// Pattern 8
|
||||
constexpr auto kFirstBranchPattern8 = 4;
|
||||
constexpr auto kFirstBranchStartIndexPattern8 = 0;
|
||||
constexpr auto kFirstBranchEndIndexPattern8 = 3;
|
||||
const std::vector<kStructureTuple> InvertedResidualPattern{
|
||||
{kFirstBranchPattern8,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimAdd},
|
||||
{kFirstBranchStartIndexPattern8, kFirstBranchEndIndexPattern8}}};
|
||||
// Pattern 9
|
||||
constexpr auto kFirstBranchPattern9 = 1;
|
||||
constexpr auto kSecondBranchPattern9 = 12;
|
||||
constexpr auto kFirstBranchStartIndexPattern9 = SIZE_MAX;
|
||||
constexpr auto kFirstBranchEndIndexPattern9 = SIZE_MAX;
|
||||
constexpr auto kSecondBranchStartIndexPattern9 = 7;
|
||||
constexpr auto kSecondBranchEndIndexPattern9 = 10;
|
||||
const std::vector<kStructureTuple> InvertedResidualShortCutPattern2{
|
||||
{kFirstBranchPattern9, {prim::kPrimAdd}, {kFirstBranchStartIndexPattern9, kFirstBranchEndIndexPattern9}},
|
||||
{kSecondBranchPattern9,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
|
||||
prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
|
||||
prim::kPrimConv2D, prim::kPrimAdd},
|
||||
{kSecondBranchStartIndexPattern9, kSecondBranchEndIndexPattern9}}};
|
||||
// Pattern 10
|
||||
constexpr auto kFirstBranchPattern10 = 5;
|
||||
constexpr auto kFirstBranchStartIndexPattern10 = 0;
|
||||
constexpr auto kFirstBranchEndIndexPattern10 = 4;
|
||||
const std::vector<kStructureTuple> InvertedResidualPattern2{
|
||||
{kFirstBranchPattern10,
|
||||
{prim::kPrimReduceMean, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
|
||||
{kFirstBranchStartIndexPattern10, kFirstBranchEndIndexPattern10}}};
|
||||
// Pattern 11
|
||||
constexpr auto kFirstBranchPattern11 = 17;
|
||||
constexpr auto kFirstBranchStartIndexPattern11 = 3;
|
||||
constexpr auto kFirstBranchEndIndexPattern11 = 6;
|
||||
const std::vector<kStructureTuple> InvertedResidualPattern3{
|
||||
{kFirstBranchPattern11,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
|
||||
prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
|
||||
prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6,
|
||||
prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
|
||||
{kFirstBranchStartIndexPattern11, kFirstBranchEndIndexPattern11}}};
|
||||
// Pattern 12
|
||||
constexpr auto kFirstBranchPattern12 = 1;
|
||||
constexpr auto kSecondBranchPattern12 = 9;
|
||||
constexpr auto kFirstBranchStartIndexPattern12 = SIZE_MAX;
|
||||
constexpr auto kFirstBranchEndIndexPattern12 = SIZE_MAX;
|
||||
constexpr auto kSecondBranchStartIndexPattern12 = kFirstBranchPattern12 + 5;
|
||||
constexpr auto kSecondBranchEndIndexPattern12 = kFirstBranchPattern12 + 8;
|
||||
const std::vector<kStructureTuple> DenseBlockShortCutPattern{
|
||||
{kFirstBranchPattern12, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern12, kFirstBranchEndIndexPattern12}},
|
||||
{kSecondBranchPattern12,
|
||||
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
|
||||
prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
|
||||
{kSecondBranchStartIndexPattern12, kSecondBranchEndIndexPattern12}}};
|
||||
// Pattern 13
|
||||
constexpr auto kFirstBranchPattern13 = 5;
|
||||
constexpr auto kFirstBranchStartIndexPattern13 = 0;
|
||||
constexpr auto kFirstBranchEndIndexPattern13 = 4;
|
||||
const std::vector<kStructureTuple> DenseBlockPattern{
|
||||
{kFirstBranchPattern13,
|
||||
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
|
||||
{kFirstBranchStartIndexPattern13, kFirstBranchEndIndexPattern13}}};
|
||||
// Pattern 14
|
||||
constexpr auto kFirstBranchPattern14 = 9;
|
||||
constexpr auto kSecondBranchPattern14 = 1;
|
||||
constexpr auto kFirstBranchStartIndexPattern14 = 5;
|
||||
constexpr auto kFirstBranchEndIndexPattern14 = 8;
|
||||
constexpr auto kSecondBranchStartIndexPattern14 = SIZE_MAX;
|
||||
constexpr auto kSecondBranchEndIndexPattern14 = SIZE_MAX;
|
||||
const std::vector<kStructureTuple> DenseBlockShortCutPattern2{
|
||||
{kFirstBranchPattern14,
|
||||
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
|
||||
prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
|
||||
{kFirstBranchStartIndexPattern14, kFirstBranchEndIndexPattern14}},
|
||||
{kSecondBranchPattern14, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern14, kSecondBranchEndIndexPattern14}}};
|
||||
// Pattern 15
|
||||
constexpr auto kFirstBranchPattern15 = 9;
|
||||
constexpr auto kSecondBranchPattern15 = 1;
|
||||
constexpr auto kFirstBranchStartIndexPattern15 = 0;
|
||||
constexpr auto kFirstBranchEndIndexPattern15 = 4;
|
||||
constexpr auto kSecondBranchStartIndexPattern15 = SIZE_MAX;
|
||||
constexpr auto kSecondBranchEndIndexPattern15 = SIZE_MAX;
|
||||
const std::vector<kStructureTuple> DenseBlockPoolPattern{
|
||||
{kFirstBranchPattern15,
|
||||
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
|
||||
prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
|
||||
{kFirstBranchStartIndexPattern15, kFirstBranchEndIndexPattern15}},
|
||||
{kSecondBranchPattern15, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern15, kSecondBranchEndIndexPattern15}}};
|
||||
// Pattern 16
|
||||
constexpr auto kFirstBranchPattern16 = 1;
|
||||
constexpr auto kSecondBranchPattern16 = 9;
|
||||
constexpr auto kFirstBranchStartIndexPattern16 = SIZE_MAX;
|
||||
constexpr auto kFirstBranchEndIndexPattern16 = SIZE_MAX;
|
||||
constexpr auto kSecondBranchStartIndexPattern16 = kFirstBranchPattern16;
|
||||
constexpr auto kSecondBranchEndIndexPattern16 = kFirstBranchPattern16 + 4;
|
||||
const std::vector<kStructureTuple> DenseBlockPoolPatter2{
|
||||
{kFirstBranchPattern16, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern16, kFirstBranchEndIndexPattern16}},
|
||||
{kSecondBranchPattern16,
|
||||
{prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
|
||||
prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
|
||||
{kSecondBranchStartIndexPattern16, kSecondBranchEndIndexPattern16}}};
|
||||
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {ResidualStructureBasePattern,
|
||||
ResidualStructureShortCutPattern,
|
||||
ResidualStructureFirstStepPattern,
|
||||
BasicStructBasePattern,
|
||||
BasicStructFirstStepPattern,
|
||||
BasicStructShortCutPattern,
|
||||
InvertedResidualShortCutPattern,
|
||||
InvertedResidualPattern,
|
||||
InvertedResidualShortCutPattern2,
|
||||
InvertedResidualPattern2,
|
||||
InvertedResidualPattern3,
|
||||
DenseBlockShortCutPattern,
|
||||
DenseBlockPattern,
|
||||
DenseBlockShortCutPattern2,
|
||||
DenseBlockPoolPattern,
|
||||
DenseBlockPoolPatter2};
|
||||
const std::set<PrimitivePtr> kNeedRemoveNodeSet{
|
||||
prim::kPrimLoad, prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
|
||||
prim::kPrimApplyFtrl, prim::kPrimSGD, prim::kPrimApplyRMSProp, prim::kPrimAdam};
|
||||
|
@ -286,7 +410,13 @@ AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, con
|
|||
sum_match_node += std::get<0>(t);
|
||||
total_match_node_.emplace_back(sum_match_node);
|
||||
});
|
||||
AnfVisitor::Match(prim::kPrimAdd, {IsCNode, IsCNode})(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr || cnode->inputs().empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
std::vector<PredicateFuncType> funcs(cnode->inputs().size() - 1, IsCNode);
|
||||
AnfVisitor::Match(prim, funcs)(node);
|
||||
if (is_match_) {
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -124,9 +124,9 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|||
device_num = 1
|
||||
|
||||
if device_num == 1:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True)
|
||||
else:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
image_size = 224
|
||||
|
@ -152,7 +152,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
|
||||
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12)
|
||||
# only enable cache for eval
|
||||
if do_train:
|
||||
enable_cache = False
|
||||
|
@ -160,10 +160,10 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|||
if not cache_session_id:
|
||||
raise ValueError("A cache session_id must be provided to use cache.")
|
||||
eval_cache = ds.DatasetCache(session_id=int(cache_session_id), size=0)
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8,
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12,
|
||||
cache=eval_cache)
|
||||
else:
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12)
|
||||
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||
|
|
Loading…
Reference in New Issue