forked from mindspore-Ecosystem/mindspore
!19003 Update the mindspore_hub_conf file required for the operation of the hub warehouse in the mindspore warehouse
Merge pull request !19003 from dinglinhe/dlh_code_ms_I3J8EP_1
This commit is contained in:
commit
d23ddffc16
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.nets.FCN8s import FCN8s
|
||||
|
||||
def fcn8s_net(*args, **kwargs):
|
||||
return FCN8s(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about FCN8s"""
|
||||
if name == "fcn8s":
|
||||
num_classes = 21
|
||||
return fcn8s_net(n_class=num_classes, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.cnn_direction_model import CNNDirectionModel
|
||||
|
||||
def cnnDirection_net(*args, **kwargs):
|
||||
return CNNDirectionModel(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about CNNDirectionModel"""
|
||||
if name == "cnn_direction_model":
|
||||
in_channels = [3, 64, 48, 48, 64]
|
||||
out_channels = [64, 48, 48, 64, 64]
|
||||
dense_layers = [256, 64]
|
||||
image_size = [64, 512]
|
||||
return cnnDirection_net(in_channels, out_channels, dense_layers, image_size, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.crnn import crnn
|
||||
from src.config import config1
|
||||
|
||||
def crnn_net(*args, **kwargs):
|
||||
return crnn(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about crnn"""
|
||||
if name == "crnn":
|
||||
return crnn_net(config1, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.attention_ocr import AttentionOCRInfer
|
||||
from src.config import config
|
||||
|
||||
def crnnseq2seqocr_net(*args, **kwargs):
|
||||
return AttentionOCRInfer(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about crnn_seq2seq_ocr"""
|
||||
if name == "crnn_seq2seq_ocr":
|
||||
return crnnseq2seqocr_net(config.batch_size,
|
||||
int(config.img_width / 4),
|
||||
config.encoder_hidden_size,
|
||||
config.decoder_hidden_size,
|
||||
config.decoder_output_size,
|
||||
config.max_length,
|
||||
config.dropout_p,
|
||||
*args,
|
||||
**kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -19,9 +19,10 @@ def create_network(name, *args, **kwargs):
|
|||
freeze_bn = True
|
||||
num_classes = kwargs["num_classes"]
|
||||
if name == 'deeplab_v3_s16':
|
||||
deeplab_v3_s16_network = net_factory.nets_map["deeplab_v3_s16"]('eval', num_classes, 16, freeze_bn)
|
||||
deeplab_v3_s16_network = net_factory.nets_map["deeplab_v3_s16"](num_classes, 16)
|
||||
return deeplab_v3_s16_network
|
||||
if name == 'deeplab_v3_s8':
|
||||
deeplab_v3_s8_network = net_factory.nets_map["deeplab_v3_s8"]('eval', num_classes, 8, freeze_bn)
|
||||
# deeplab_v3_s8_network = net_factory.nets_map["deeplab_v3_s8"]('eval', num_classes, 8, freeze_bn)
|
||||
deeplab_v3_s8_network = net_factory.nets_map["deeplab_v3_s8"](num_classes, 8)
|
||||
return deeplab_v3_s8_network
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.Deeptext.deeptext_vgg16 import Deeptext_VGG16
|
||||
from src.config import config
|
||||
|
||||
def deeptext_net(*args, **kwargs):
|
||||
return Deeptext_VGG16(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about deeptext"""
|
||||
if name == "deeptext":
|
||||
return deeptext_net(config=config, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.dpn import DPN
|
||||
|
||||
def dpn_net(*args, **kwargs):
|
||||
return DPN(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about dpn"""
|
||||
if name == "dpn":
|
||||
return dpn_net(*args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.lenet_fusion import LeNet5 as LeNet5Fusion
|
||||
from src.config import mnist_cfg as cfg
|
||||
|
||||
def lenet_net(*args, **kwargs):
|
||||
return LeNet5Fusion(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about lenet_quant"""
|
||||
if name == "lenet_quant":
|
||||
return lenet_net(cfg.num_classes, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -16,13 +16,14 @@
|
|||
from src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about mobilenetv2"""
|
||||
if name == "mobilenetv2":
|
||||
backbone_net = MobileNetV2Backbone()
|
||||
include_top = kwargs["include_top"]
|
||||
include_top = kwargs.get("include_top", True)
|
||||
if include_top is None:
|
||||
include_top = True
|
||||
if include_top:
|
||||
activation = kwargs["activation"]
|
||||
activation = kwargs.get("activation", True)
|
||||
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels,
|
||||
num_classes=int(kwargs["num_classes"]),
|
||||
activation=activation)
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from src.config import config_ascend_quant
|
||||
from src.mobilenetV2 import mobilenetV2
|
||||
|
||||
def mobilenetv2_quant_net(*args, **kwargs):
|
||||
symmetric_list = [False, False]
|
||||
# define fusion network
|
||||
network = mobilenetV2(num_classes=config_ascend_quant.num_classes)
|
||||
# convert fusion network to quantization aware network
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=symmetric_list)
|
||||
network = quantizer.quantize(network)
|
||||
return network
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about mobilenetv2_quant"""
|
||||
if name == "mobilenetv2_quant":
|
||||
return mobilenetv2_quant_net(*args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -13,9 +13,12 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.resnet import resnet50, resnet101, se_resnet50
|
||||
from src.resnet import resnet50, resnet101, se_resnet50, resnet18
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about resnet"""
|
||||
if name == 'resnet18':
|
||||
return resnet18(*args, **kwargs)
|
||||
if name == 'resnet50':
|
||||
return resnet50(*args, **kwargs)
|
||||
if name == 'resnet101':
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
|
||||
from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50
|
||||
from src.config import config_quant as config
|
||||
|
||||
def resnet50_quant_net(*args, **kwargs):
|
||||
return resnet50_quant(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about resnet50_quant"""
|
||||
if name == "resnet50_quant":
|
||||
return resnet50_quant_net(class_num=config.class_num, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.image_classification import get_network
|
||||
|
||||
def resnext101_net(*args, **kwargs):
|
||||
return get_network(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about resnext101"""
|
||||
if name == "resnext101":
|
||||
platform = "Ascend"
|
||||
return resnext101_net(platform=platform, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.retinanet import retinanetWithLossCell, retinanet50, resnet50
|
||||
from src.config import config
|
||||
|
||||
def retinanet_net(*args, **kwargs):
|
||||
return retinanetWithLossCell(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about retinanet"""
|
||||
if name == "retinanet":
|
||||
backbone = resnet50(config.num_classes)
|
||||
retinanet = retinanet50(backbone, config)
|
||||
return retinanet_net(retinanet, config, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.model import get_pose_net
|
||||
from src.config import config
|
||||
|
||||
def simplepose_net(*args, **kwargs):
|
||||
return get_pose_net(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about simple_pose"""
|
||||
if name == "simple_pose":
|
||||
return simplepose_net(config, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.unet3d_model import UNet3d
|
||||
from src.config import config as cfg
|
||||
|
||||
def unet3d_net(*args, **kwargs):
|
||||
return UNet3d(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about unet3d"""
|
||||
if name == "unet3d":
|
||||
return unet3d_net(config=cfg, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -16,8 +16,13 @@
|
|||
from src.yolo import YOLOV4CspDarkNet53
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about yolov4"""
|
||||
if name == "yolov4_cspdarknet53":
|
||||
yolov4_cspdarknet53_net = YOLOV4CspDarkNet53()
|
||||
yolov4_cspdarknet53_net.set_train(False)
|
||||
return yolov4_cspdarknet53_net
|
||||
if name == "yolov4_shape416":
|
||||
yolov4_shape416 = YOLOV4CspDarkNet53()
|
||||
yolov4_shape416.set_train(False)
|
||||
return yolov4_shape416
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
import os
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from config import GNMTConfig
|
||||
|
@ -29,10 +30,9 @@ def get_config(config):
|
|||
def create_network(name, *args, **kwargs):
|
||||
"""create gnmt network."""
|
||||
if name == "gnmt":
|
||||
if "config" in kwargs:
|
||||
config = get_config(kwargs["config"])
|
||||
else:
|
||||
raise NotImplementedError(f"Please make sure the configuration file path is correct")
|
||||
default_config_path = os.path.join(os.path.split(os.path.realpath(__file__))[0], "config/config.json")
|
||||
config_path = kwargs.get("config", default_config_path)
|
||||
config = get_config(config_path)
|
||||
is_training = kwargs.get("is_training", False)
|
||||
if is_training:
|
||||
return GNMTNetworkWithLoss(config, is_training=is_training, *args)
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from mindspore import dtype
|
||||
from src.gpt import GPT
|
||||
from src.utils import GPTConfig
|
||||
|
||||
def gpt_net(*args, **kwargs):
|
||||
return GPT(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""
|
||||
create net work gpt
|
||||
"""
|
||||
if name == "gpt":
|
||||
config = GPTConfig(batch_size=16,
|
||||
seq_length=1024,
|
||||
vocab_size=50257,
|
||||
embedding_size=1024,
|
||||
num_layers=24,
|
||||
num_heads=16,
|
||||
expand_ratio=4,
|
||||
post_layernorm_residual=False,
|
||||
dropout_rate=0.0,
|
||||
compute_dtype=dtype.float16,
|
||||
use_past=False)
|
||||
return gpt_net(config, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.config import config
|
||||
from src.seq2seq import Seq2Seq
|
||||
from src.gru_for_train import GRUTrainOneStepWithLossScaleCell
|
||||
|
||||
def gru_net(*args, **kwargs):
|
||||
network = Seq2Seq(*args, **kwargs)
|
||||
return GRUTrainOneStepWithLossScaleCell(network)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about gru"""
|
||||
if name == "gru":
|
||||
return gru_net(config, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.naml import NAML, NAMLWithLossCell
|
||||
from src.option import get_args
|
||||
|
||||
def naml_net(*args, **kwargs):
|
||||
return NAMLWithLossCell(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about naml"""
|
||||
if name == "naml":
|
||||
arg = get_args("eval")
|
||||
net = NAML(arg)
|
||||
return naml_net(net, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -14,12 +14,24 @@
|
|||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.ncf import NCFModel
|
||||
from src.config import cfg
|
||||
|
||||
def ncfnet(*args, **kwargs):
|
||||
def ncf_net(*args, **kwargs):
|
||||
return NCFModel(*args, **kwargs)
|
||||
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about ncf"""
|
||||
if name == "ncf":
|
||||
return NCFModel(*args, **kwargs)
|
||||
layers = cfg.layers
|
||||
num_factors = cfg.num_factors
|
||||
num_users = 6040
|
||||
num_items = 3706
|
||||
return ncf_net(num_users=num_users,
|
||||
num_items=num_items,
|
||||
num_factors=num_factors,
|
||||
model_layers=layers,
|
||||
mf_regularization=0,
|
||||
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
|
||||
mf_dim=16)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
||||
|
|
|
@ -26,6 +26,7 @@ def get_WideDeep_net(config):
|
|||
return eval_net
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about wide_and_deep"""
|
||||
if name == 'wide_and_deep':
|
||||
eval_net = get_WideDeep_net(cfg)
|
||||
return eval_net
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3
|
||||
from src.config import config
|
||||
|
||||
def facedetection_net(*args, **kwargs):
|
||||
return backbone_HwYolov3(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about facedetection"""
|
||||
if name == "facedetection":
|
||||
num_classes = config.num_classes
|
||||
anchors_mask = config.anchors_masl
|
||||
num_anchors_list = [len(x) for x in anchors_mask]
|
||||
return facedetection_net(num_classes, num_anchors_list, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.face_qa import FaceQABackbone
|
||||
|
||||
def facequality_net(*args, **kwargs):
|
||||
return FaceQABackbone(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about face_quality_assesment"""
|
||||
if name == "face_quality_assessment":
|
||||
return facequality_net(*args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.config import config_inference
|
||||
from src.backbone.resnet import get_backbone
|
||||
|
||||
def facerecognition_net(*args, **kwargs):
|
||||
return get_backbone(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about facerecognition"""
|
||||
if name == "facerecognition":
|
||||
arg = config_inference
|
||||
return facerecognition_net(arg, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src.reid import SphereNet
|
||||
|
||||
def faceRecognitionTrack_net(*args, **kwargs):
|
||||
return SphereNet(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about face_recognition_for_tracking"""
|
||||
if name == "face_recognition_for_tracking":
|
||||
layers = 12
|
||||
return faceRecognitionTrack_net(num_layers=layers, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config"""
|
||||
from src import CenterNetMultiPoseEval
|
||||
from src.config import net_config, eval_config
|
||||
|
||||
def centernet_net(*args, **kwargs):
|
||||
return CenterNetMultiPoseEval(*args, **kwargs)
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
"""create_network about centernet"""
|
||||
if name == "centernet":
|
||||
# True, if device is Ascend
|
||||
enable_nms_fp16 = True
|
||||
return centernet_net(net_config, eval_config, enable_nms_fp16, *args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
Loading…
Reference in New Issue