Add mindspore.hub for downloading checkpoint files

Add mindspore.hub and update model_zoo references
This commit is contained in:
chenzomi 2020-07-29 14:26:28 +08:00
parent 800b9dc596
commit 8918c90b66
60 changed files with 227 additions and 1 deletions

212
mindspore/hub.py Normal file
View File

@ -0,0 +1,212 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
hub for loading models:
Users can load pre-trained models using mindspore.hub.load() API.
"""
# Standard library imports
import hashlib
import os
import re
import shutil
import tarfile
from urllib.error import HTTPError, URLError
from urllib.request import urlretrieve

# Third-party imports
import requests
from bs4 import BeautifulSoup

# MindSpore imports
import mindspore
import mindspore.nn as nn
from mindspore import log as logger
from mindspore.train.serialization import load_checkpoint, load_param_into_net
# Root url of the MindSpore model zoo checkpoint repository.
DOWNLOAD_BASIC_URL = "http://download.mindspore.cn/model_zoo"
# Sub-directory of the repository holding officially released models.
OFFICIAL_NAME = "official"
# Default local cache directory for downloaded archives.
# NOTE(review): the '~' is never passed through os.path.expanduser before use -- verify.
DEFAULT_CACHE_DIR = '~/.cache'
# Known network-name prefixes for computer-vision models (first '_'-separated token).
MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet',
                   'lenet', 'resnet', 'ssd', 'vgg', 'yolo']
# Known network-name prefixes for natural-language-processing models.
MODEL_TARGET_NLP = ['bert', 'mass', 'transformer']
def _packing_targz(output_filename, savepath="./"):
"""
Packing the input filename to filename.tar.gz in source dir.
"""
try:
with tarfile.open(output_filename, "w:gz") as tar:
tar.add(savepath, arcname=os.path.basename(savepath))
except Exception as e:
raise OSError("Cannot tar file {} for - {}".format(output_filename, e))
def _unpacking_targz(input_filename, savepath="./"):
"""
Unpacking the input filename to dirs.
"""
try:
t = tarfile.open(input_filename)
t.extractall(path=savepath)
except Exception as e:
raise OSError("Cannot untar file {} for - {}".format(input_filename, e))
def _remove_path_if_exists(path):
if os.path.exists(path):
if os.path.isfile(path):
os.remove(path)
else:
shutil.rmtree(path)
def _create_path_if_not_exists(path):
if os.path.exists(path):
if os.path.isfile(path):
os.remove(path)
else:
os.mkdir(path)
def _get_weights_file(url, hash_md5=None, savepath='./'):
    """
    Download a checkpoint archive from the given url into ``savepath``.

    Args:
        url (str): Url of the checkpoint tar.gz archive.
        hash_md5 (str): Expected md5 hex digest of the archive; when it
            matches an already-downloaded file the download is skipped.
            Default: None.
        savepath (str): Directory the archive is downloaded into. Default: './'.

    Returns:
        str, local path of the downloaded archive.

    Raises:
        Exception: If the download fails with an HTTP or URL error.
    """
    def reporthook(blocks, block_size, total_size):
        # urlretrieve progress callback: blocks transferred so far, bytes per
        # block, total download size (may be <= 0 when the size is unknown).
        if total_size > 0:
            percent = min(blocks * block_size * 100.0 / total_size, 100.0)
        else:
            percent = 0.0
        # 70-character progress bar; the original multiplied percent by 80,
        # printing thousands of '#' characters.
        show_str = ('[%-70s]') % (int(percent * 0.7) * '#')
        print("\rDownloading:", show_str, " %5.1f%%" % percent, end="")

    def md5sum(file_name, expected_md5):
        # The file is opened in binary mode, so hash the raw bytes directly;
        # the original called .encode('utf-8') on bytes (AttributeError).
        m = hashlib.md5()
        with open(file_name, 'rb') as fp:
            m.update(fp.read())
        return m.hexdigest() == expected_md5

    _create_path_if_not_exists(savepath)
    ckpt_name = os.path.basename(url.split("/")[-1])
    file_path = os.path.join(savepath, ckpt_name)

    # Reuse a previous download when its checksum still matches.
    if os.path.isfile(file_path) and hash_md5 and md5sum(file_path, hash_md5):
        print('File already exists!')
        return file_path

    # Drop any stale archive and its extracted directory before re-downloading.
    extracted_dir = file_path[:-len(".tar.gz")] if file_path.endswith(".tar.gz") else file_path
    _remove_path_if_exists(extracted_dir)
    _remove_path_if_exists(file_path)

    # Download the checkpoint archive.
    print('Downloading data from url {}'.format(url))
    try:
        urlretrieve(url, file_path, reporthook=reporthook)
    except HTTPError as e:
        raise Exception(e.code, e.msg, url)
    except URLError as e:
        raise Exception(e.errno, e.reason, url)
    print('\nDownload finished!')

    # Extract next to the archive; the original extracted into the current
    # working directory regardless of savepath.
    _unpacking_targz(file_path, savepath)

    # Report the archive size in megabytes.
    filesize = os.path.getsize(file_path)
    print('File size = %.2f Mb' % (filesize / 1024 / 1024))
    return file_path
def _get_url_paths(url, ext='.tar.gz'):
    """
    Collect the links of a directory-listing page that end with ``ext``.

    Args:
        url (str): Base url of the listing page.
        ext (str): Required link suffix. Default: '.tar.gz'.

    Returns:
        list[str], absolute urls of the matching links.

    Raises:
        requests.HTTPError: If the listing page request fails (4xx/5xx).
    """
    response = requests.get(url)
    # Surface HTTP errors instead of attempting to parse an error page.
    response.raise_for_status()
    soup = BeautifulSoup(response.text, 'html.parser')
    parent = []
    for node in soup.find_all('a'):
        href = node.get('href')
        # Skip anchors without an href attribute -- the original crashed
        # with AttributeError calling None.endswith(ext).
        if href and href.endswith(ext):
            parent.append(url + href)
    return parent
def _get_file_from_url(base_url, base_name):
    """
    Pick the download url whose file name starts with ``base_name``.

    Args:
        base_url (str): Listing-page url searched for .tar.gz links.
        base_name (str): Expected file-name prefix.

    Returns:
        str, the first matching url; the first listed url when nothing
        matches (kept for backward compatibility with the original lookup).
    """
    urls = _get_url_paths(base_url)
    for url in urls:
        # The original used re.match(base_name + '*', name): the '*'
        # quantified the LAST CHARACTER of base_name instead of acting as a
        # wildcard.  A plain prefix test expresses the real intent.
        if url.split('/')[-1].startswith(base_name):
            return url
    return urls[0]
def load_weights(network, network_name=None, force_reload=True, **kwargs):
    r"""
    Load pretrained weights from the MindSpore model zoo into ``network``.

    Args:
        network (Cell): Cell network instance to load the parameters into.
        network_name (str, optional): Name of the network; when None it is
            read from ``network.network_name``. Default: None.
        force_reload (bool, optional): Whether to download the checkpoint
            unconditionally. Only True is currently supported. Default: True.
        **kwargs (optional): Keywords selecting the published checkpoint.

            - device_target (str): Runtime device target. Default: 'ascend'.
            - dataset (str): Dataset the network was trained on. Default: 'imagenet'.
            - version (str): MindSpore version of the checkpoint. Default: the
              installed MindSpore version.

    Raises:
        TypeError: If ``network`` is not a Cell or no network name is available.
        ValueError: If the network-name prefix matches no known model category,
            or ``force_reload`` is False.

    Example:
        >>> mindspore.hub.load_weights(network, network_name='lenet',
        ...     **{'device_target': 'ascend', 'dataset': 'cifar10', 'version': '0.5.0'})
    """
    if not isinstance(network, nn.Cell):
        logger.error("Failed to combine the net and the parameters.")
        msg = ("Argument net should be a Cell, but got {}.".format(type(network)))
        raise TypeError(msg)

    if network_name is None:
        # The original tested hasattr(network, None), which itself raises
        # TypeError -- the attribute name must be the string 'network_name'.
        if hasattr(network, 'network_name'):
            network_name = network.network_name
        else:
            msg = "Should input network name, but got None."
            raise TypeError(msg)

    # dict.get avoids the KeyError the original raised whenever one of the
    # documented-optional keywords was omitted.
    device_target = kwargs.get('device_target') or 'ascend'
    dataset = kwargs.get('dataset') or 'imagenet'
    version = kwargs.get('version') or mindspore.version.__version__

    model_prefix = network_name.split("_")[0]
    if model_prefix in MODEL_TARGET_CV:
        model_type = "cv"
    elif model_prefix in MODEL_TARGET_NLP:
        model_type = "nlp"
    else:
        # The original fell through with model_type unbound (NameError).
        raise ValueError("Unknown model type for network name {}.".format(network_name))

    download_base_url = "/".join([DOWNLOAD_BASIC_URL,
                                  OFFICIAL_NAME, model_type])
    download_file_name = "_".join(
        [network_name, device_target, version, dataset, OFFICIAL_NAME])
    download_url = _get_file_from_url(download_base_url, download_file_name)

    if force_reload:
        ckpt_path = _get_weights_file(download_url, None, DEFAULT_CACHE_DIR)
    else:
        raise ValueError("Unsupported not force reload.")

    ckpt_file = os.path.join(ckpt_path, network_name + ".ckpt")
    param_dict = load_checkpoint(ckpt_file)
    load_param_into_net(network, param_dict)

View File

@ -905,6 +905,8 @@ class DepthwiseConv2d(Cell):
self.dilation = dilation
self.group = group
self.has_bias = has_bias
self.weight_init = weight_init
self.bias_init = bias_init
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=self.pad_mode,

View File

@ -48,10 +48,16 @@ class LossMonitor(Callback):
self.lr_init = lr_init
def epoch_begin(self, run_context):
    """
    Called at the start of each epoch.

    Resets the per-epoch loss accumulator and records the epoch start time,
    which epoch_end uses to compute per-epoch / per-step durations.

    Args:
        run_context (RunContext): Context of the running process (unused here).
    """
    self.losses = []
    self.epoch_time = time.time()
def epoch_end(self, run_context):
"""
epoch end
"""
cb_params = run_context.original_args()
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / cb_params.batch_num
@ -62,9 +68,15 @@ class LossMonitor(Callback):
print("*" * 60)
def step_begin(self, run_context):
    """
    Called at the start of each step.

    Records the step start time, which step_end uses to compute the
    per-step duration.

    Args:
        run_context (RunContext): Context of the running process (unused here).
    """
    self.step_time = time.time()
def step_end(self, run_context):
"""
step end
"""
cb_params = run_context.original_args()
step_mseconds = (time.time() - self.step_time) * 1000
step_loss = cb_params.net_outputs

View File

@ -20,7 +20,7 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore import nn
from mindspore.train.quant import quant as qat
from model_zoo.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2
from model_zoo.official.cv.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")