modify code of the file mindspore_hub_conf.py which network is gat and mass

This commit is contained in:
dinglinhe 2021-05-20 15:58:00 +08:00
parent 90a070e1ce
commit 0f1fae7198
2 changed files with 34 additions and 12 deletions

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,12 +14,34 @@
# ============================================================================ # ============================================================================
"""hub config.""" """hub config."""
from src.gat import GAT from src.gat import GAT
from src.config import GatConfig
def gat(*args, **kwargs):
return GAT(*args, **kwargs)
def create_network(name, *args, **kwargs): def create_network(name, *args, **kwargs):
""" create net work"""
if name == "gat": if name == "gat":
return gat(*args, **kwargs)
if "ftr_dims" in kwargs:
featureDims = kwargs.get("ftr_dims")
else:
featureDims = 3706
if "num_class" in kwargs:
numClass = kwargs.get("num_class")
else:
numClass = 10
if "num_nodes" in kwargs:
numNodes = kwargs.get("num_nodes")
else:
numNodes = 30
gat_net = GAT(featureDims,
numClass,
numNodes,
GatConfig.hid_units,
GatConfig.n_heads,
attn_drop=GatConfig.attn_dropout,
ftr_drop=GatConfig.feature_dropout)
return gat_net
raise NotImplementedError(f"{name} is not implemented in the repo") raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""hub config.""" """hub config."""
import os
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from config import TransformerConfig from config import TransformerConfig
from src.transformer import TransformerNetworkWithLoss, TransformerInferModel from src.transformer import TransformerNetworkWithLoss, TransformerInferModel
@ -27,10 +27,10 @@ def get_config(config):
def create_network(name, *args, **kwargs): def create_network(name, *args, **kwargs):
"""create mass network.""" """create mass network."""
if name == "mass": if name == "mass":
if "config" in kwargs: # get the config running dir
config = get_config(kwargs["config"]) configDir = os.path.split(os.path.realpath(__file__))[0] + "/config/config.json"
else: # get the config
raise NotImplementedError(f"Please make sure the configuration file path is correct") config = get_config(configDir)
is_training = kwargs.get("is_training", False) is_training = kwargs.get("is_training", False)
if is_training: if is_training:
return TransformerNetworkWithLoss(config, is_training=is_training, *args) return TransformerNetworkWithLoss(config, is_training=is_training, *args)