forked from OSSInnovation/mindspore
!6637 modify resnet50_adv_pruning for hub loading
Merge pull request !6637 from hzf/master
This commit is contained in:
commit
a57c06f021
|
@ -14,9 +14,21 @@
|
|||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.resnet_imgnet import resnet50
|
||||
from mindspore import Tensor
|
||||
import numpy as np
|
||||
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
if name == 'resnet-0.65x':
|
||||
return resnet50(*args, **kwargs)
|
||||
def get_index(filename):
|
||||
index = []
|
||||
with open(filename) as fr:
|
||||
for line in fr:
|
||||
ind = Tensor((np.array(line.strip('\n').split(' ')[:-1])).astype(np.int32).reshape(-1, 1))
|
||||
index.append(ind)
|
||||
return index
|
||||
|
||||
|
||||
def create_network(name, rate=0.65, index_filename='index.txt', **kwargs):
|
||||
index = get_index(index_filename)
|
||||
if name == 'resnet50-0.65x':
|
||||
return resnet50(rate=rate, index=index, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
||||
|
|
Loading…
Reference in New Issue