mindspore/tests/ut/python/dataset/prep_data.py

125 lines
4.5 KiB
Python

# Copyright 2020 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# import jsbeautifier
import os
import shutil
import subprocess
import urllib
import urllib.request
def create_data_cache_dir():
    """Create (if needed) and return the ``data_cache`` directory under the CWD.

    Returns:
        str: ``data_cache`` joined onto the current working directory. The
        path is returned even when the directory could not be created; in
        that case a message is printed instead of an exception being raised.
    """
    target_directory = os.path.join(os.getcwd(), "data_cache")
    try:
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(target_directory, exist_ok=True)
    except OSError:
        print("Creation of the directory %s failed" % target_directory)
    return target_directory
def download_and_uncompress(files, source_url, target_directory, is_tar=False):
    """Download each archive and unpack it inside ``target_directory``.

    A file is skipped when either the archive itself or the archive path
    minus its last three characters (for a ``.gz`` archive, the name of the
    unpacked file) already exists.

    Args:
        files: Iterable of archive file names; each one is appended to
            ``source_url`` to form the download URL.
        source_url: URL prefix the file names are appended to.
        target_directory: Directory that receives the downloads.
        is_tar: When true, unpack with ``tar -xvf`` into
            ``target_directory``; otherwise unpack with ``gunzip -f``.

    Raises:
        SystemError: If unpacking fails. The downloaded archive is removed
            first, so a later run downloads it again.
    """
    for file_name in files:
        url = source_url + file_name
        target_file = os.path.join(target_directory, file_name)

        # Check whether the file was already downloaded.
        if os.path.exists(target_file) or os.path.exists(target_file[:-3]):
            print("Using cached dataset at ", target_file)
            continue

        urllib.request.urlretrieve(url, target_file)
        if is_tar:
            print("extracting from local tar file " + target_file)
            command = ["tar", "-C", target_directory, "-xvf", target_file]
        else:
            print("unzipping " + target_file)
            command = ["gunzip", "-f", target_file]

        # An argument list (no shell) passes paths containing spaces or shell
        # metacharacters through verbatim.
        try:
            rc = subprocess.run(command, check=False).returncode
        except OSError:
            # The tool could not be started; treat it as an unpack failure.
            rc = 1

        if rc != 0:
            print("Failed to uncompress ", target_file, " removing")
            if os.path.exists(target_file):
                os.remove(target_file)
            # Exit with an error so that the build script fails.
            raise SystemError
def download_mnist(target_directory=None):
    """Download and unpack the four MNIST ``.gz`` files.

    Args:
        target_directory: Directory to download into. When ``None``, the
            files go into a ``mnist`` sub-directory of the directory returned
            by ``create_data_cache_dir()``; that sub-directory is created if
            it is missing. A directory passed explicitly is used as given.

    Returns:
        tuple: ``(directory, schema_path)`` where ``directory`` is the
        download directory and ``schema_path`` is ``datasetSchema.json``
        joined onto it. This function does not create that file.
    """
    if target_directory is None:
        # Create the mnist directory inside the default data cache.
        target_directory = os.path.join(create_data_cache_dir(), "mnist")
        try:
            # exist_ok avoids the check-then-create race of exists() + mkdir().
            os.makedirs(target_directory, exist_ok=True)
        except OSError:
            print("Creation of the directory %s failed" % target_directory)

    MNIST_URL = "http://yann.lecun.com/exdb/mnist/"
    files = [
        'train-images-idx3-ubyte.gz',
        'train-labels-idx1-ubyte.gz',
        't10k-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz',
    ]
    download_and_uncompress(files, MNIST_URL, target_directory, is_tar=False)

    return target_directory, os.path.join(target_directory, "datasetSchema.json")
# Base URL that download_cifar() prefixes to the CIFAR archive file name.
CIFAR_URL = "https://www.cs.toronto.edu/~kriz/"
def download_cifar(target_directory, files, directory_from_tar):
    """Download a CIFAR archive, unpack it and flatten the extracted folder.

    After unpacking, every entry of the directory created by ``tar`` is moved
    up into ``target_directory`` and the then-empty extracted directory is
    removed.

    Args:
        target_directory: Directory to download into; ``None`` selects the
            directory returned by ``create_data_cache_dir()``.
        files: Name of the single archive file, appended to ``CIFAR_URL``.
        directory_from_tar: Name of the directory the archive unpacks to
            inside ``target_directory``.

    Returns:
        str: ``directory_from_tar`` joined onto ``target_directory``.

    Raises:
        SystemError: If moving the extracted files or removing the extracted
            directory fails. The downloaded archive is removed first.
    """
    if target_directory is None:
        target_directory = create_data_cache_dir()

    download_and_uncompress([files], CIFAR_URL, target_directory, is_tar=True)

    # Move the data out of the directory created by tar and into the target
    # directory itself.
    tar_dir_full_path = os.path.join(target_directory, directory_from_tar)
    if os.path.exists(tar_dir_full_path):
        print("copy files back to target_directory")
        print("Moving contents of ", tar_dir_full_path, " to ", target_directory)
        try:
            # shutil instead of a shell "mv dir/*": no quoting problems with
            # unusual paths.
            for entry in os.listdir(tar_dir_full_path):
                shutil.move(os.path.join(tar_dir_full_path, entry),
                            os.path.join(target_directory, entry))
            shutil.rmtree(tar_dir_full_path)
        except OSError as err:
            print("error when moving files out of ", tar_dir_full_path)
            download_file = os.path.join(target_directory, files)
            print("removing " + download_file)
            if os.path.exists(download_file):
                os.remove(download_file)
            # Exit with an error so that the build script fails.
            raise SystemError from err

    # NOTE(review): this is the path of the extracted directory, which the
    # branch above removes after moving its contents. Kept unchanged for
    # backward compatibility -- confirm what callers expect.
    return os.path.join(target_directory, directory_from_tar)
def download_cifar10(target_directory=None):
    """Download CIFAR-10 via ``download_cifar`` and return its result."""
    archive_name = "cifar-10-binary.tar.gz"
    extracted_dir = "cifar-10-batches-bin"
    return download_cifar(target_directory, archive_name, extracted_dir)
def download_cifar100(target_directory=None):
    """Download CIFAR-100 via ``download_cifar`` and return its result."""
    archive_name = "cifar-100-binary.tar.gz"
    extracted_dir = "cifar-100-binary"
    return download_cifar(target_directory, archive_name, extracted_dir)
def download_all_for_test(cwd):
    """Download the MNIST data into the ``testMnistData`` directory under ``cwd``."""
    mnist_directory = os.path.join(cwd, "testMnistData")
    download_mnist(mnist_directory)
# Script entry point: download the test data into the test directory located
# under the current working directory.
if __name__ == "__main__":
    working_directory = os.getcwd()
    download_all_for_test(working_directory)