!19003 Update the mindspore_hub_conf file required for the operation of the hub warehouse in the mindspore warehouse

Merge pull request !19003 from dinglinhe/dlh_code_ms_I3J8EP_1
This commit is contained in:
i-robot 2021-06-29 08:21:47 +00:00 committed by Gitee
commit d23ddffc16
28 changed files with 625 additions and 11 deletions

View File

@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.nets.FCN8s import FCN8s
def fcn8s_net(*args, **kwargs):
return FCN8s(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about FCN8s"""
if name == "fcn8s":
num_classes = 21
return fcn8s_net(n_class=num_classes, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.cnn_direction_model import CNNDirectionModel
def cnnDirection_net(*args, **kwargs):
return CNNDirectionModel(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about CNNDirectionModel"""
if name == "cnn_direction_model":
in_channels = [3, 64, 48, 48, 64]
out_channels = [64, 48, 48, 64, 64]
dense_layers = [256, 64]
image_size = [64, 512]
return cnnDirection_net(in_channels, out_channels, dense_layers, image_size, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.crnn import crnn
from src.config import config1
def crnn_net(*args, **kwargs):
return crnn(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about crnn"""
if name == "crnn":
return crnn_net(config1, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,34 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.attention_ocr import AttentionOCRInfer
from src.config import config
def crnnseq2seqocr_net(*args, **kwargs):
return AttentionOCRInfer(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about crnn_seq2seq_ocr"""
if name == "crnn_seq2seq_ocr":
return crnnseq2seqocr_net(config.batch_size,
int(config.img_width / 4),
config.encoder_hidden_size,
config.decoder_hidden_size,
config.decoder_output_size,
config.max_length,
config.dropout_p,
*args,
**kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -19,9 +19,10 @@ def create_network(name, *args, **kwargs):
freeze_bn = True
num_classes = kwargs["num_classes"]
if name == 'deeplab_v3_s16':
deeplab_v3_s16_network = net_factory.nets_map["deeplab_v3_s16"]('eval', num_classes, 16, freeze_bn)
deeplab_v3_s16_network = net_factory.nets_map["deeplab_v3_s16"](num_classes, 16)
return deeplab_v3_s16_network
if name == 'deeplab_v3_s8':
deeplab_v3_s8_network = net_factory.nets_map["deeplab_v3_s8"]('eval', num_classes, 8, freeze_bn)
# deeplab_v3_s8_network = net_factory.nets_map["deeplab_v3_s8"]('eval', num_classes, 8, freeze_bn)
deeplab_v3_s8_network = net_factory.nets_map["deeplab_v3_s8"](num_classes, 8)
return deeplab_v3_s8_network
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.Deeptext.deeptext_vgg16 import Deeptext_VGG16
from src.config import config
def deeptext_net(*args, **kwargs):
return Deeptext_VGG16(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about deeptext"""
if name == "deeptext":
return deeptext_net(config=config, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,25 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.dpn import DPN
def dpn_net(*args, **kwargs):
return DPN(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about dpn"""
if name == "dpn":
return dpn_net(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.lenet_fusion import LeNet5 as LeNet5Fusion
from src.config import mnist_cfg as cfg
def lenet_net(*args, **kwargs):
return LeNet5Fusion(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about lenet_quant"""
if name == "lenet_quant":
return lenet_net(cfg.num_classes, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -16,13 +16,14 @@
from src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2
def create_network(name, *args, **kwargs):
"""create_network about mobilenetv2"""
if name == "mobilenetv2":
backbone_net = MobileNetV2Backbone()
include_top = kwargs["include_top"]
include_top = kwargs.get("include_top", True)
if include_top is None:
include_top = True
if include_top:
activation = kwargs["activation"]
activation = kwargs.get("activation", True)
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels,
num_classes=int(kwargs["num_classes"]),
activation=activation)

View File

@ -0,0 +1,35 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from mindspore.compression.quant import QuantizationAwareTraining
from src.config import config_ascend_quant
from src.mobilenetV2 import mobilenetV2
def mobilenetv2_quant_net(*args, **kwargs):
symmetric_list = [False, False]
# define fusion network
network = mobilenetV2(num_classes=config_ascend_quant.num_classes)
# convert fusion network to quantization aware network
quantizer = QuantizationAwareTraining(bn_fold=True,
per_channel=[True, False],
symmetric=symmetric_list)
network = quantizer.quantize(network)
return network
def create_network(name, *args, **kwargs):
"""create_network about mobilenetv2_quant"""
if name == "mobilenetv2_quant":
return mobilenetv2_quant_net(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -13,9 +13,12 @@
# limitations under the License.
# ============================================================================
"""hub config."""
from src.resnet import resnet50, resnet101, se_resnet50
from src.resnet import resnet50, resnet101, se_resnet50, resnet18
def create_network(name, *args, **kwargs):
"""create_network about resnet"""
if name == 'resnet18':
return resnet18(*args, **kwargs)
if name == 'resnet50':
return resnet50(*args, **kwargs)
if name == 'resnet101':

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50
from src.config import config_quant as config
def resnet50_quant_net(*args, **kwargs):
return resnet50_quant(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about resnet50_quant"""
if name == "resnet50_quant":
return resnet50_quant_net(class_num=config.class_num, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.image_classification import get_network
def resnext101_net(*args, **kwargs):
return get_network(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about resnext101"""
if name == "resnext101":
platform = "Ascend"
return resnext101_net(platform=platform, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,28 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.retinanet import retinanetWithLossCell, retinanet50, resnet50
from src.config import config
def retinanet_net(*args, **kwargs):
return retinanetWithLossCell(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about retinanet"""
if name == "retinanet":
backbone = resnet50(config.num_classes)
retinanet = retinanet50(backbone, config)
return retinanet_net(retinanet, config, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.model import get_pose_net
from src.config import config
def simplepose_net(*args, **kwargs):
return get_pose_net(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about simple_pose"""
if name == "simple_pose":
return simplepose_net(config, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.unet3d_model import UNet3d
from src.config import config as cfg
def unet3d_net(*args, **kwargs):
return UNet3d(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about unet3d"""
if name == "unet3d":
return unet3d_net(config=cfg, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -16,8 +16,13 @@
from src.yolo import YOLOV4CspDarkNet53
def create_network(name, *args, **kwargs):
"""create_network about yolov4"""
if name == "yolov4_cspdarknet53":
yolov4_cspdarknet53_net = YOLOV4CspDarkNet53()
yolov4_cspdarknet53_net.set_train(False)
return yolov4_cspdarknet53_net
if name == "yolov4_shape416":
yolov4_shape416 = YOLOV4CspDarkNet53()
yolov4_shape416.set_train(False)
return yolov4_shape416
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""hub config."""
import os
import mindspore.common.dtype as mstype
from config import GNMTConfig
@ -29,10 +30,9 @@ def get_config(config):
def create_network(name, *args, **kwargs):
"""create gnmt network."""
if name == "gnmt":
if "config" in kwargs:
config = get_config(kwargs["config"])
else:
raise NotImplementedError(f"Please make sure the configuration file path is correct")
default_config_path = os.path.join(os.path.split(os.path.realpath(__file__))[0], "config/config.json")
config_path = kwargs.get("config", default_config_path)
config = get_config(config_path)
is_training = kwargs.get("is_training", False)
if is_training:
return GNMTNetworkWithLoss(config, is_training=is_training, *args)

View File

@ -0,0 +1,40 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from mindspore import dtype
from src.gpt import GPT
from src.utils import GPTConfig
def gpt_net(*args, **kwargs):
return GPT(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""
create net work gpt
"""
if name == "gpt":
config = GPTConfig(batch_size=16,
seq_length=1024,
vocab_size=50257,
embedding_size=1024,
num_layers=24,
num_heads=16,
expand_ratio=4,
post_layernorm_residual=False,
dropout_rate=0.0,
compute_dtype=dtype.float16,
use_past=False)
return gpt_net(config, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,28 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.config import config
from src.seq2seq import Seq2Seq
from src.gru_for_train import GRUTrainOneStepWithLossScaleCell
def gru_net(*args, **kwargs):
network = Seq2Seq(*args, **kwargs)
return GRUTrainOneStepWithLossScaleCell(network)
def create_network(name, *args, **kwargs):
"""create_network about gru"""
if name == "gru":
return gru_net(config, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,28 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.naml import NAML, NAMLWithLossCell
from src.option import get_args
def naml_net(*args, **kwargs):
return NAMLWithLossCell(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about naml"""
if name == "naml":
arg = get_args("eval")
net = NAML(arg)
return naml_net(net, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -14,12 +14,24 @@
# ============================================================================
"""hub config"""
from src.ncf import NCFModel
from src.config import cfg
def ncfnet(*args, **kwargs):
def ncf_net(*args, **kwargs):
return NCFModel(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about ncf"""
if name == "ncf":
return NCFModel(*args, **kwargs)
layers = cfg.layers
num_factors = cfg.num_factors
num_users = 6040
num_items = 3706
return ncf_net(num_users=num_users,
num_items=num_items,
num_factors=num_factors,
model_layers=layers,
mf_regularization=0,
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
mf_dim=16)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -26,6 +26,7 @@ def get_WideDeep_net(config):
return eval_net
def create_network(name, *args, **kwargs):
"""create_network about wide_and_deep"""
if name == 'wide_and_deep':
eval_net = get_WideDeep_net(cfg)
return eval_net

View File

@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3
from src.config import config
def facedetection_net(*args, **kwargs):
return backbone_HwYolov3(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about facedetection"""
if name == "facedetection":
num_classes = config.num_classes
anchors_mask = config.anchors_masl
num_anchors_list = [len(x) for x in anchors_mask]
return facedetection_net(num_classes, num_anchors_list, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,25 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.face_qa import FaceQABackbone
def facequality_net(*args, **kwargs):
return FaceQABackbone(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about face_quality_assesment"""
if name == "face_quality_assessment":
return facequality_net(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.config import config_inference
from src.backbone.resnet import get_backbone
def facerecognition_net(*args, **kwargs):
return get_backbone(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about facerecognition"""
if name == "facerecognition":
arg = config_inference
return facerecognition_net(arg, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.reid import SphereNet
def faceRecognitionTrack_net(*args, **kwargs):
return SphereNet(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about face_recognition_for_tracking"""
if name == "face_recognition_for_tracking":
layers = 12
return faceRecognitionTrack_net(num_layers=layers, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,28 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src import CenterNetMultiPoseEval
from src.config import net_config, eval_config
def centernet_net(*args, **kwargs):
return CenterNetMultiPoseEval(*args, **kwargs)
def create_network(name, *args, **kwargs):
"""create_network about centernet"""
if name == "centernet":
# True, if device is Ascend
enable_nms_fp16 = True
return centernet_net(net_config, eval_config, enable_nms_fp16, *args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")