[MD] Update set_enable_autotune API to add save filepath

Cathy Wong 2022-02-18 12:38:55 -05:00
parent 2f3d807773
commit 46e223e569
9 changed files with 411 additions and 42 deletions
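
In short: set_enable_autotune gains an optional json_filepath argument so the AutoTuned data pipeline configuration can be saved to a JSON file. A quick before/after sketch of the Python API surface (the filename is an arbitrary example):

    import mindspore.dataset as ds

    # Before this commit: enable/disable only
    ds.config.set_enable_autotune(True)

    # After this commit: optionally save the AutoTuned pipeline configuration
    ds.config.set_enable_autotune(True, "at_out.json")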

View File

@@ -61,7 +61,10 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
                     .def("get_enable_shared_mem", &ConfigManager::enable_shared_mem)
                     .def("set_auto_offload", &ConfigManager::set_auto_offload)
                     .def("get_auto_offload", &ConfigManager::get_auto_offload)
-                    .def("set_enable_autotune", &ConfigManager::set_enable_autotune)
+                    .def("set_enable_autotune",
+                         [](ConfigManager &c, bool enable, bool save_autoconfig, std::string json_filepath) {
+                           THROW_IF_ERROR(c.set_enable_autotune(enable, save_autoconfig, json_filepath));
+                         })
                     .def("get_enable_autotune", &ConfigManager::enable_autotune)
                     .def("set_autotune_interval", &ConfigManager::set_autotune_interval)
                     .def("get_autotune_interval", &ConfigManager::autotune_interval)

View File

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
  */
 #include "minddata/dataset/core/config_manager.h"
+#include <unistd.h>
 #include <fstream>
 #include <iostream>
 #include <limits>
@ -27,6 +28,7 @@
#else #else
#include "mindspore/lite/src/common/log_adapter.h" #include "mindspore/lite/src/common/log_adapter.h"
#endif #endif
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/system_pool.h" #include "minddata/dataset/util/system_pool.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
@@ -53,7 +55,9 @@ ConfigManager::ConfigManager()
       enable_shared_mem_(true),
       auto_offload_(false),
       enable_autotune_(false),
+      save_autoconfig_(false),
       autotune_interval_(kCfgAutoTuneInterval) {
+  autotune_json_filepath_ = kEmptyString;
   num_cpu_threads_ = num_cpu_threads_ > 0 ? num_cpu_threads_ : std::numeric_limits<uint16_t>::max();
   num_parallel_workers_ = num_parallel_workers_ < num_cpu_threads_ ? num_parallel_workers_ : num_cpu_threads_;
   std::string env_cache_host = common::GetEnv("MS_CACHE_HOST");
@@ -126,7 +130,7 @@ Status ConfigManager::set_num_parallel_workers(int32_t num_parallel_workers) {
   if (num_parallel_workers > num_cpu_threads_ || num_parallel_workers < 1) {
     std::string err_msg = "Invalid Parameter, num_parallel_workers exceeds the boundary between 1 and " +
                           std::to_string(num_cpu_threads_) + ", as got " + std::to_string(num_parallel_workers) + ".";
-    RETURN_STATUS_UNEXPECTED(err_msg);
+    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   num_parallel_workers_ = num_parallel_workers;
   return Status::OK();
@@ -162,5 +166,56 @@ void ConfigManager::set_num_connections(int32_t num_connections) { num_connectio
 void ConfigManager::set_cache_prefetch_size(int32_t cache_prefetch_size) { cache_prefetch_size_ = cache_prefetch_size; }
+
+Status ConfigManager::set_enable_autotune(bool enable, bool save_autoconfig, const std::string &json_filepath) {
+  enable_autotune_ = enable;
+  save_autoconfig_ = save_autoconfig;
+
+  // Check if not requested to save the AutoTune configuration
+  if (!save_autoconfig_) {
+    // No further processing (such as validating the json_filepath input) is needed
+    return Status::OK();
+  }
+
+  Path jsonpath(json_filepath);
+  if (jsonpath.IsDirectory()) {
+    std::string err_msg = "Invalid json_filepath parameter. <" + json_filepath + "> is a directory, not a filename.";
+    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
+  }
+
+  std::string parent_path = jsonpath.ParentPath();
+  if (parent_path != "") {
+    if (!Path(parent_path).Exists()) {
+      std::string err_msg = "Invalid json_filepath parameter. Directory <" + parent_path + "> does not exist.";
+      LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
+    }
+  } else {
+    // Set parent_path to the current working directory
+    parent_path = ".";
+  }
+
+  std::string real_path;
+  if (Path::RealPath(parent_path, real_path).IsError()) {
+    std::string err_msg = "Invalid json_filepath parameter. Cannot get real json_filepath <" + real_path + ">.";
+    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
+  }
+  if (access(real_path.c_str(), W_OK) == -1) {
+    std::string err_msg = "Invalid json_filepath parameter. No access to write to <" + real_path + ">.";
+    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
+  }
+
+  if (jsonpath.Exists()) {
+    // Note: The file is allowed to be overwritten (like serialize)
+    std::string err_msg = "Invalid json_filepath parameter. File: <" + json_filepath + "> already exists." +
+                          " File will be overwritten with the AutoTuned data pipeline configuration.";
+    MS_LOG(WARNING) << err_msg;
+  }
+
+  // Save the final AutoTune configuration JSON filepath name
+  autotune_json_filepath_ = json_filepath;
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
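
The validation above accepts a bare filename (the parent path defaults to the current working directory) and rejects directories, missing parent directories, and unwritable locations; an existing file only triggers an overwrite warning. A rough Python rendering of the same decision sequence, for illustration only (os.path calls stand in for the C++ Path helpers):

    import os

    def validate_autotune_filepath(json_filepath):
        """Illustrative mirror of the checks in ConfigManager::set_enable_autotune."""
        if os.path.isdir(json_filepath):
            raise RuntimeError("<{}> is a directory, not a filename.".format(json_filepath))
        parent = os.path.dirname(json_filepath) or "."  # default to current working directory
        if not os.path.exists(parent):
            raise RuntimeError("Directory <{}> does not exist.".format(parent))
        if not os.access(os.path.realpath(parent), os.W_OK):
            raise RuntimeError("No access to write to <{}>.".format(parent))
        if os.path.exists(json_filepath):
            print("Warning: file <{}> already exists and will be overwritten.".format(json_filepath))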

View File

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -36,6 +36,8 @@
 namespace mindspore {
 namespace dataset {
+const char kEmptyString[] = "";
+const char kJsonExtension[] = ".json";
 // The ConfigManager is a class for managing default values. When a user is constructing any objects
 // in the framework, often they may choose to omit some settings instead of overriding them.
 // This class manages some of the default values, for cases when the user does not manually specify
@@ -232,12 +234,23 @@ class ConfigManager {
   // setter function
   // @param enable - To enable autotune
-  void set_enable_autotune(bool enable) { enable_autotune_ = enable; }
+  // @param save_autoconfig - True if the AutoTuned data pipeline configuration should be saved
+  // @param json_filepath - JSON filepath where the final AutoTune data pipeline will be generated
+  // @return Status error code
+  Status set_enable_autotune(bool enable, bool save_autoconfig, const std::string &json_filepath);

   // getter function
   // @return - Flag to indicate whether autotune is enabled
   bool enable_autotune() const { return enable_autotune_; }

+  // getter function
+  // @return - Flag to indicate whether to save the AutoTune configuration
+  bool save_autoconfig() const { return save_autoconfig_; }
+
+  // getter function
+  // @return - The final AutoTune configuration JSON filepath
+  std::string get_autotune_json_filepath() const { return autotune_json_filepath_; }
+
   // getter function
   // @return - autotune interval in steps
   int64_t autotune_interval() const { return autotune_interval_; }
@@ -270,6 +283,8 @@ class ConfigManager {
   bool enable_shared_mem_;
   bool auto_offload_;
   bool enable_autotune_;
+  bool save_autoconfig_;                // True if the AutoTune configuration should be saved
+  std::string autotune_json_filepath_;  // Filepath of the final AutoTune configuration JSON file
   int64_t autotune_interval_;

   // Private helper function that takes a nlohmann json format and populates the settings
   // @param j - The json nlohmann json info

View File

@@ -19,6 +19,7 @@
 #include <algorithm>
 #include <functional>
 #include <memory>
+#include <utility>
 #include <vector>
 #ifndef ENABLE_ANDROID
 #include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
@@ -39,6 +40,8 @@ AutoTune::AutoTune(TreeAdapter *tree_adap, ProfilingManager *profiling_mgr)
   tree_modifier_ = std::make_unique<TreeModifier>(tree_adapter_);
   max_workers_ = GlobalContext::config_manager()->num_cpu_threads();
   step_gap_ = GlobalContext::config_manager()->autotune_interval();
+  save_autoconfig_ = GlobalContext::config_manager()->save_autoconfig();
+  autotune_json_filepath_ = GlobalContext::config_manager()->get_autotune_json_filepath();
 }

 Status AutoTune::Main() {
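
Note that AutoTune reads save_autoconfig() and get_autotune_json_filepath() from the global ConfigManager in its constructor, so (as every new unit test does) set_enable_autotune() should be called before the pipeline is built and iterated. A hedged usage sketch, with an arbitrary example path:

    import numpy as np
    import mindspore.dataset as ds

    ds.config.set_enable_autotune(True, "at_out.json")  # configure before building the pipeline
    data = ds.GeneratorDataset([(np.array([x]),) for x in range(64)], ["data"])
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        pass
    ds.config.set_enable_autotune(False)  # stop AutoTune for subsequent pipelines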

View File

@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,12 +14,13 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_AUTO_TUNE_H_
-#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_AUTO_TUNE_H_
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_PERF_AUTO_TUNE_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_PERF_AUTO_TUNE_H_

 #include <map>
 #include <memory>
 #include <mutex>
+#include <string>
 #include <vector>

 #include "minddata/dataset/util/status.h"
 #include "minddata/dataset/util/log_adapter.h"
@@ -191,7 +192,12 @@ class AutoTune {
   int64_t step_gap_;
   int32_t last_step_profiled_;
   bool skip_bool_;
+
+  /// True if the AutoTune configuration should be saved
+  bool save_autoconfig_;
+  /// Filepath of the final AutoTune configuration JSON file
+  std::string autotune_json_filepath_;
 };
 }  // namespace dataset
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_AUTO_TUNE_H_
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_PERF_AUTO_TUNE_H_

View File

@@ -28,6 +28,7 @@ import random
 import numpy
 import mindspore._c_dataengine as cde
 from mindspore import log as logger
+from .validator_helpers import replace_none

 __all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
            'get_num_parallel_workers', 'set_numa_enable', 'get_numa_enable', 'set_monitor_sampling_interval',
@@ -421,24 +422,55 @@ def load(file):
     _config.load(file)


-def set_enable_autotune(enable):
+def set_enable_autotune(enable, json_filepath=None):
     """
-    Set the default state of AutoTune flag. If it is True, will facilitate users to improve
-    performance for a given workload by automatically finding the better settings for data pipeline.
+    Set the default state of the AutoTune flag. If it is True, AutoTune helps users improve
+    performance for a given workload by automatically finding better settings for the data pipeline.
+    Optionally save the AutoTuned data pipeline configuration to a JSON file, which
+    can be loaded with deserialize().

     Args:
         enable (bool): Whether to use AutoTune feature when running data pipeline.
+        json_filepath (str, optional): The filepath where the AutoTuned data pipeline
+            configuration will be generated as a JSON file. If the file already exists,
+            it will be overwritten. If no AutoTuned data pipeline configuration is desired,
+            then set json_filepath to None (Default=None).

     Raises:
         TypeError: If enable is not a boolean data type.
+        TypeError: If json_filepath is not a str value.
+        RuntimeError: If the value of json_filepath is the empty string.
+        RuntimeError: If json_filepath is a directory.
+        RuntimeError: If the parent path of json_filepath does not exist.
+        RuntimeError: If the parent path of json_filepath does not have write permission.
+
+    Note:
+        When enable is False, the value of json_filepath is ignored.

     Examples:
+        >>> # Enable AutoTune and save the AutoTuned data pipeline configuration
+        >>> ds.config.set_enable_autotune(True, "/path/to/autotune_out.json")
+        >>>
         >>> # Enable AutoTune
         >>> ds.config.set_enable_autotune(True)
     """
     if not isinstance(enable, bool):
         raise TypeError("enable must be of type bool.")
-    _config.set_enable_autotune(enable)
+
+    save_autoconfig = bool(enable and json_filepath is not None)
+
+    if json_filepath and not isinstance(json_filepath, str):
+        raise TypeError("json_filepath must be a str value but was: {}.".format(json_filepath))
+
+    if enable and json_filepath == "":
+        raise RuntimeError("The value of json_filepath cannot be the empty string.")
+
+    if not enable and json_filepath is not None:
+        logger.warning("The value of json_filepath is ignored when enable is False.")
+
+    json_filepath = replace_none(json_filepath, "")
+    _config.set_enable_autotune(enable, save_autoconfig, json_filepath)


 def get_enable_autotune():
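
Since the docstring notes that the saved JSON can be loaded with deserialize(), a typical round trip looks roughly like the sketch below (the path is illustrative, and the pipeline must actually run for the tuned configuration to be written):

    import numpy as np
    import mindspore.dataset as ds

    ds.config.set_enable_autotune(True, "/tmp/at_out.json")  # save the tuned config here
    data = ds.GeneratorDataset([(np.array([x]),) for x in range(256)], ["data"]).batch(32)
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        pass
    ds.config.set_enable_autotune(False)

    # Later: recreate the tuned pipeline from the saved JSON file
    tuned = ds.deserialize(json_filepath="/tmp/at_out.json")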

View File

@@ -203,32 +203,3 @@ class TestAutotuneWithProfiler:
             pass

         ds.config.set_enable_autotune(False)
-
-    def test_autotune_config(self):
-        """
-        Feature: Autotuning
-        Description: test basic config of autotune
-        Expectation: config can be set successfully
-        """
-        autotune_state = ds.config.get_enable_autotune()
-        assert autotune_state is False
-
-        ds.config.set_enable_autotune(False)
-        autotune_state = ds.config.get_enable_autotune()
-        assert autotune_state is False
-
-        with pytest.raises(TypeError):
-            ds.config.set_enable_autotune(1)
-
-        autotune_interval = ds.config.get_autotune_interval()
-        assert autotune_interval == 0
-
-        ds.config.set_autotune_interval(200)
-        autotune_interval = ds.config.get_autotune_interval()
-        assert autotune_interval == 200
-
-        with pytest.raises(TypeError):
-            ds.config.set_autotune_interval(20.012)
-
-        with pytest.raises(ValueError):
-            ds.config.set_autotune_interval(-999)

View File

@@ -0,0 +1,112 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Dataset AutoTune Configuration Support
"""
import pytest
import mindspore.dataset as ds


@pytest.mark.forked
class TestAutotuneConfig:
    @staticmethod
    def test_autotune_config_basic():
        """
        Feature: Autotuning
        Description: Test basic config of AutoTune
        Expectation: Config can be set successfully
        """
        autotune_state = ds.config.get_enable_autotune()
        assert autotune_state is False

        ds.config.set_enable_autotune(False)
        autotune_state = ds.config.get_enable_autotune()
        assert autotune_state is False

        with pytest.raises(TypeError):
            ds.config.set_enable_autotune(1)

        autotune_interval = ds.config.get_autotune_interval()
        assert autotune_interval == 0

        ds.config.set_autotune_interval(200)
        autotune_interval = ds.config.get_autotune_interval()
        assert autotune_interval == 200

        with pytest.raises(TypeError):
            ds.config.set_autotune_interval(20.012)

        with pytest.raises(ValueError):
            ds.config.set_autotune_interval(-999)

    @staticmethod
    def test_autotune_config_filepath_invalid():
        """
        Feature: Autotuning
        Description: Test set_enable_autotune() with invalid json_filepath
        Expectation: Invalid input is detected
        """
        with pytest.raises(TypeError):
            ds.config.set_enable_autotune(True, 123)
        with pytest.raises(TypeError):
            ds.config.set_enable_autotune(True, 0)
        with pytest.raises(TypeError):
            ds.config.set_enable_autotune(True, True)
        with pytest.raises(TypeError):
            ds.config.set_enable_autotune(False, 1.1)

        with pytest.raises(RuntimeError) as error_info:
            ds.config.set_enable_autotune(True, "")
        assert "cannot be the empty string" in str(error_info.value)

        with pytest.raises(RuntimeError) as error_info:
            ds.config.set_enable_autotune(True, "/tmp")
        assert "is a directory" in str(error_info.value)

        with pytest.raises(RuntimeError) as error_info:
            ds.config.set_enable_autotune(True, ".")
        assert "is a directory" in str(error_info.value)

        with pytest.raises(RuntimeError) as error_info:
            ds.config.set_enable_autotune(True, "/JUNKPATH/at_out.json")
        assert "Directory" in str(error_info.value)
        assert "does not exist" in str(error_info.value)

    @staticmethod
    def test_autotune_config_filepath_success():
        """
        Feature: Autotuning
        Description: Test set_enable_autotune() with valid filepath input
        Expectation: set_enable_autotune() executes successfully
        """
        # Note: Sequential calls to set_enable_autotune() are fine
        ds.config.set_enable_autotune(True, "file1.json")
        ds.config.set_enable_autotune(True, "file1.json")
        ds.config.set_enable_autotune(True, "file2.json")

        # Note: The preferred '.json' extension is not required for json_filepath
        ds.config.set_enable_autotune(True, "at_out.JSON")
        ds.config.set_enable_autotune(True, "/tmp/at_out.txt")
        ds.config.set_enable_autotune(True, "at_out")

        # Note: When enable is False, the json_filepath parameter is ignored
        ds.config.set_enable_autotune(False, "/NONEXISTDIR/junk.json")
        ds.config.set_enable_autotune(False, "")
        ds.config.set_enable_autotune(False, None)

View File

@@ -0,0 +1,172 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Dataset AutoTune's Save and Load Configuration support
"""
import filecmp
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c_transforms

MNIST_DATA_DIR = "../data/dataset/testMnistData"


@pytest.mark.forked
class TestAutotuneSaveLoad:
    # Note: Use the pytest fixture tmp_path to create files within this temporary directory,
    # which is automatically created for each test and deleted at the end of the test.

    @staticmethod
    def test_autotune_generator_pipeline(tmp_path):
        """
        Feature: Autotuning
        Description: Test save final config with GeneratorDataset pipeline: Generator -> Shuffle -> Batch
        Expectation: Pipeline runs successfully
        """
        ds.config.set_enable_autotune(True, str(tmp_path / "test_autotune_generator_atfinal.json"))

        source = [(np.array([x]),) for x in range(1024)]
        data1 = ds.GeneratorDataset(source, ["data"])
        data1 = data1.shuffle(64)
        data1 = data1.batch(32)

        ds.serialize(data1, str(tmp_path / "test_autotune_generator_serialized.json"))

        itr = data1.create_dict_iterator(num_epochs=5)
        for _ in range(5):
            for _ in itr:
                pass

        ds.config.set_enable_autotune(False)

    @staticmethod
    def skip_test_autotune_mnist_pipeline(tmp_path):
        """
        Feature: Autotuning
        Description: Test save final config with Mnist pipeline: Mnist -> Batch -> Map
        Expectation: Pipeline runs successfully
        """
        ds.config.set_enable_autotune(True, str(tmp_path / "test_autotune_mnist_pipeline_atfinal.json"))
        ds.config.set_seed(1)

        data1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=100)
        one_hot_encode = c_transforms.OneHot(10)  # num_classes is an input argument
        data1 = data1.map(operations=one_hot_encode, input_columns="label")
        data1 = data1.batch(batch_size=10, drop_remainder=True)

        ds.serialize(data1, str(tmp_path / "test_autotune_mnist_pipeline_serialized.json"))

        for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass

        ds.config.set_enable_autotune(False)

        # Confirm the final AutoTune config file is identical to the serialized file.
        assert filecmp.cmp(str(tmp_path / "test_autotune_mnist_pipeline_atfinal.json"),
                           str(tmp_path / "test_autotune_mnist_pipeline_serialized.json"))

        desdata1 = ds.deserialize(json_filepath=str(tmp_path / "test_autotune_mnist_pipeline_atfinal.json"))
        desdata2 = ds.deserialize(json_filepath=str(tmp_path / "test_autotune_mnist_pipeline_serialized.json"))

        num = 0
        for newdata1, newdata2 in zip(desdata1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                      desdata2.create_dict_iterator(num_epochs=1, output_numpy=True)):
            np.testing.assert_array_equal(newdata1['image'], newdata2['image'])
            np.testing.assert_array_equal(newdata1['label'], newdata2['label'])
            num += 1
        assert num == 10

    @staticmethod
    def test_autotune_save_overwrite_generator(tmp_path):
        """
        Feature: Autotuning
        Description: Test set_enable_autotune() when the existing json_filepath is overwritten
        Expectation: set_enable_autotune() executes successfully with a file-exists warning produced.
            Execution of the 2nd pipeline overwrites the AutoTune configuration file of the 1st pipeline.
        """
        source = [(np.array([x]),) for x in range(1024)]
        at_final_json_filename = "test_autotune_save_overwrite_generator_atfinal.json"

        ds.config.set_enable_autotune(True, str(tmp_path / at_final_json_filename))

        data1 = ds.GeneratorDataset(source, ["data"])
        for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass

        ds.config.set_enable_autotune(False)

        ds.config.set_enable_autotune(True, str(tmp_path / at_final_json_filename))

        data2 = ds.GeneratorDataset(source, ["data"])
        data2 = data2.shuffle(64)
        for _ in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass

        ds.config.set_enable_autotune(False)

    @staticmethod
    def skip_test_autotune_save_overwrite_mnist(tmp_path):
        """
        Feature: Autotuning
        Description: Test set_enable_autotune() when the existing json_filepath is overwritten
        Expectation: set_enable_autotune() executes successfully with a file-exists warning produced.
            Execution of the 2nd pipeline overwrites the AutoTune configuration file of the 1st pipeline.
        """
        ds.config.set_seed(1)
        at_final_json_filename = "test_autotune_save_overwrite_mnist_atfinal.json"

        # Pipeline#1
        ds.config.set_enable_autotune(True, str(tmp_path / at_final_json_filename))

        data1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=100)
        one_hot_encode = c_transforms.OneHot(10)  # num_classes is an input argument
        data1 = data1.map(operations=one_hot_encode, input_columns="label")
        data1 = data1.batch(batch_size=10, drop_remainder=True)

        ds.serialize(data1, str(tmp_path / "test_autotune_save_overwrite_mnist_serialized1.json"))

        for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass

        ds.config.set_enable_autotune(False)

        # Pipeline#2
        ds.config.set_enable_autotune(True, str(tmp_path / at_final_json_filename))

        data1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=200)
        data1 = data1.map(operations=one_hot_encode, input_columns="label")
        data1 = data1.shuffle(40)
        data1 = data1.batch(batch_size=20, drop_remainder=False)

        ds.serialize(data1, str(tmp_path / "test_autotune_save_overwrite_mnist_serialized2.json"))

        for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass

        ds.config.set_enable_autotune(False)

        # Confirm the 2nd serialized file is identical to the final AutoTune config file.
        assert filecmp.cmp(str(tmp_path / "test_autotune_save_overwrite_mnist_atfinal.json"),
                           str(tmp_path / "test_autotune_save_overwrite_mnist_serialized2.json"))

        # Confirm the serialized files for the 2 different pipelines are different
        assert not filecmp.cmp(str(tmp_path / "test_autotune_save_overwrite_mnist_serialized1.json"),
                               str(tmp_path / "test_autotune_save_overwrite_mnist_serialized2.json"))