forked from mindspore-Ecosystem/mindspore
add mindspore hub for downloading ckpt files
add mindspore.hub and change model_zoo
This commit is contained in:
parent
800b9dc596
commit
8918c90b66
|
@ -0,0 +1,212 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
hub for loading models:
|
||||
Users can load pre-trained models using mindspore.hub.load() API.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tarfile
|
||||
import hashlib
|
||||
from urllib.request import urlretrieve
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
from mindspore import log as logger
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
DOWNLOAD_BASIC_URL = "http://download.mindspore.cn/model_zoo"
|
||||
OFFICIAL_NAME = "official"
|
||||
DEFAULT_CACHE_DIR = '~/.cache'
|
||||
MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet',
|
||||
'lenet', 'resnet', 'ssd', 'vgg', 'yolo']
|
||||
MODEL_TARGET_NLP = ['bert', 'mass', 'transformer']
|
||||
|
||||
|
||||
def _packing_targz(output_filename, savepath="./"):
|
||||
"""
|
||||
Packing the input filename to filename.tar.gz in source dir.
|
||||
"""
|
||||
try:
|
||||
with tarfile.open(output_filename, "w:gz") as tar:
|
||||
tar.add(savepath, arcname=os.path.basename(savepath))
|
||||
except Exception as e:
|
||||
raise OSError("Cannot tar file {} for - {}".format(output_filename, e))
|
||||
|
||||
|
||||
def _unpacking_targz(input_filename, savepath="./"):
|
||||
"""
|
||||
Unpacking the input filename to dirs.
|
||||
"""
|
||||
try:
|
||||
t = tarfile.open(input_filename)
|
||||
t.extractall(path=savepath)
|
||||
except Exception as e:
|
||||
raise OSError("Cannot untar file {} for - {}".format(input_filename, e))
|
||||
|
||||
|
||||
def _remove_path_if_exists(path):
|
||||
if os.path.exists(path):
|
||||
if os.path.isfile(path):
|
||||
os.remove(path)
|
||||
else:
|
||||
shutil.rmtree(path)
|
||||
|
||||
|
||||
def _create_path_if_not_exists(path):
|
||||
if os.path.exists(path):
|
||||
if os.path.isfile(path):
|
||||
os.remove(path)
|
||||
else:
|
||||
os.mkdir(path)
|
||||
|
||||
|
||||
def _get_weights_file(url, hash_md5=None, savepath='./'):
    """
    Download a checkpoint .tar.gz archive from *url* into *savepath*.

    Args:
        url (str): URL of the checkpoint tar.gz file.
        hash_md5 (str): Expected md5 hexdigest; when it matches an already
            downloaded file, the download is skipped. Default: None.
        savepath (str): Directory to download into ('~' is expanded).
            Default: './'.

    Returns:
        str, local path of the downloaded archive.
    """
    # Local import: urllib.error was never imported at module level, so the
    # original except clauses raised NameError instead of reporting the
    # real download failure.
    from urllib.error import HTTPError, URLError

    def reporthook(blocks, block_size, total_size):
        # Progress callback for urlretrieve; draws a 70-char bar.
        # The original computed int(percent * 80) * '#', which produced an
        # absurdly long bar; scale percent into [0, 70] instead.
        percent = min(blocks * block_size * 100.0 / total_size, 100.0)
        bar = '#' * int(percent * 70 / 100)
        show_str = ('[%%-%ds]' % 70) % bar
        print("\rDownloading:", show_str, " %5.1f%%" % percent, end="")

    def md5sum(file_name, hash_md5):
        # True when the file's md5 hexdigest equals the expected one.
        m = hashlib.md5()
        with open(file_name, 'rb') as fp:
            # The original called content.encode('utf-8') on the bytes read
            # from a binary file, which always raised AttributeError.
            m.update(fp.read())
        return m.hexdigest() == hash_md5

    # Expand '~' so DEFAULT_CACHE_DIR ('~/.cache') works as intended.
    savepath = os.path.expanduser(savepath)
    _create_path_if_not_exists(savepath)
    ckpt_name = os.path.basename(url.split("/")[-1])
    file_path = os.path.join(savepath, ckpt_name)

    # Reuse a previously downloaded archive when its checksum still matches.
    if os.path.isfile(file_path):
        if hash_md5 and md5sum(file_path, hash_md5):
            print('File already exists!')
            return file_path

    _remove_path_if_exists(file_path)

    # Download the checkpoint archive.
    print('Downloading data from url {}'.format(url))
    try:
        urlretrieve(url, file_path, reporthook=reporthook)
    except HTTPError as e:
        raise Exception(e.code, e.msg, url)
    except URLError as e:
        raise Exception(e.errno, e.reason, url)
    print('\nDownload finished!')

    # Unpack into savepath (the original extracted into the current working
    # directory, not the cache directory).
    _unpacking_targz(file_path, savepath)

    # Report the archive size in Mb.
    filesize = os.path.getsize(file_path)
    print('File size = %.2f Mb' % (filesize / 1024 / 1024))
    return file_path
|
||||
|
||||
|
||||
def _get_url_paths(url, ext='.tar.gz'):
    """
    List file links on an HTTP index page that end with *ext*.

    Args:
        url (str): Page URL to scan.
        ext (str): Required link suffix. Default: '.tar.gz'.

    Returns:
        list[str], absolute URLs of the matching links.
    """
    response = requests.get(url)
    if not response.ok:
        # raise_for_status() raises for 4xx/5xx responses.
        return response.raise_for_status()
    soup = BeautifulSoup(response.text, 'html.parser')
    # Guard against anchors without an href attribute: node.get('href')
    # returns None there and the original crashed on None.endswith(ext).
    return [url + node.get('href') for node in soup.find_all('a')
            if node.get('href') and node.get('href').endswith(ext)]
|
||||
|
||||
|
||||
def _get_file_from_url(base_url, base_name):
    """
    Return the first archive URL under *base_url* whose file name starts
    with *base_name*; falls back to the first listed URL when none match.

    Args:
        base_url (str): Directory URL listing .tar.gz files.
        base_name (str): Expected file-name prefix.

    Returns:
        str, the matched download URL.
    """
    urls = _get_url_paths(base_url)
    idx = 0
    for i, candidate in enumerate(urls):
        name = candidate.split('/')[-1]
        # The original used re.match(base_name + '*', name), which only makes
        # the LAST character of base_name optional/repeatable ('*' is a regex
        # quantifier, not a glob) — a plain prefix test is what was intended.
        if name.startswith(base_name):
            idx = i
            break
    return urls[idx]
|
||||
|
||||
|
||||
def load_weights(network, network_name=None, force_reload=True, **kwargs):
    r"""
    Load pretrained model-zoo weights into *network*.

    Args:
        network (Cell): Cell network.
        network_name (str, optional): Network name used to locate the
            checkpoint. Default: None (read from ``network.network_name``).
        force_reload (bool, optional): Whether to force a fresh download
            unconditionally; only True is currently supported. Default: True.
        **kwargs (optional): Extra download selectors.

            - device_target (str, optional): Runtime device target.
              Default: 'ascend'.
            - dataset (str, optional): Dataset the weights were trained on.
              Default: 'imagenet'.
            - version (str, optional): MindSpore version of the checkpoint.
              Default: the installed MindSpore version.

    Raises:
        TypeError: If *network* is not a Cell, or no network name is given.
        ValueError: If the network name is not a known model, or
            force_reload is False.

    Example:
        >>> mindspore.hub.load(network, network_name='lenet',
        ...     **{'device_target': 'ascend', 'dataset': 'cifar10', 'version': 'beta0.5'})
    """
    if not isinstance(network, nn.Cell):
        logger.error("Failed to combine the net and the parameters.")
        msg = ("Argument net should be a Cell, but got {}.".format(type(network)))
        raise TypeError(msg)

    if network_name is None:
        # The original tested hasattr(network, network_name) while
        # network_name was still None, so the attribute lookup always failed.
        if hasattr(network, 'network_name'):
            network_name = network.network_name
        else:
            msg = "Should input network name, but got None."
            raise TypeError(msg)

    # dict.get avoids the KeyError the original raised whenever one of the
    # selectors was omitted from kwargs.
    device_target = kwargs.get('device_target') or 'ascend'
    dataset = kwargs.get('dataset') or 'imagenet'
    version = kwargs.get('version') or mindspore.version.__version__

    prefix = network_name.split("_")[0]
    if prefix in MODEL_TARGET_CV:
        model_type = "cv"
    elif prefix in MODEL_TARGET_NLP:
        model_type = "nlp"
    else:
        # The original left model_type unbound here and crashed later with
        # an unrelated NameError; fail fast with a clear message instead.
        raise ValueError("Unknown network name {}.".format(network_name))

    download_base_url = "/".join([DOWNLOAD_BASIC_URL, OFFICIAL_NAME, model_type])
    download_file_name = "_".join(
        [network_name, device_target, version, dataset, OFFICIAL_NAME])
    download_url = _get_file_from_url(download_base_url, download_file_name)

    if force_reload:
        ckpt_path = _get_weights_file(download_url, None, DEFAULT_CACHE_DIR)
    else:
        raise ValueError("Unsupported not force reload.")

    ckpt_file = os.path.join(ckpt_path, network_name + ".ckpt")
    param_dict = load_checkpoint(ckpt_file)
    load_param_into_net(network, param_dict)
|
|
@ -905,6 +905,8 @@ class DepthwiseConv2d(Cell):
|
|||
self.dilation = dilation
|
||||
self.group = group
|
||||
self.has_bias = has_bias
|
||||
self.weight_init = weight_init
|
||||
self.bias_init = bias_init
|
||||
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
|
||||
kernel_size=self.kernel_size,
|
||||
pad_mode=self.pad_mode,
|
||||
|
|
|
@ -48,10 +48,16 @@ class LossMonitor(Callback):
|
|||
self.lr_init = lr_init
|
||||
|
||||
def epoch_begin(self, run_context):
    """Reset per-epoch state: record the epoch start time and clear losses."""
    self.epoch_time = time.time()
    self.losses = []
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""
|
||||
epoch end
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
epoch_mseconds = (time.time() - self.epoch_time) * 1000
|
||||
per_step_mseconds = epoch_mseconds / cb_params.batch_num
|
||||
|
@ -62,9 +68,15 @@ class LossMonitor(Callback):
|
|||
print("*" * 60)
|
||||
|
||||
def step_begin(self, run_context):
    """Record the wall-clock time at which the step starts."""
    now = time.time()
    self.step_time = now
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
step end
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
step_mseconds = (time.time() - self.step_time) * 1000
|
||||
step_loss = cb_params.net_outputs
|
|
@ -20,7 +20,7 @@ import mindspore.context as context
|
|||
from mindspore import Tensor
|
||||
from mindspore import nn
|
||||
from mindspore.train.quant import quant as qat
|
||||
from model_zoo.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2
|
||||
from model_zoo.official.cv.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
|
Loading…
Reference in New Issue