diff --git a/model_zoo/official/gnn/gat/mindspore_hub_conf.py b/model_zoo/official/gnn/gat/mindspore_hub_conf.py index 23915bb7365..0eb2cd8397d 100644 --- a/model_zoo/official/gnn/gat/mindspore_hub_conf.py +++ b/model_zoo/official/gnn/gat/mindspore_hub_conf.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,12 +14,34 @@ # ============================================================================ """hub config.""" from src.gat import GAT - -def gat(*args, **kwargs): - return GAT(*args, **kwargs) - +from src.config import GatConfig def create_network(name, *args, **kwargs): + """ create net work""" if name == "gat": - return gat(*args, **kwargs) + + if "ftr_dims" in kwargs: + featureDims = kwargs.get("ftr_dims") + else: + featureDims = 3706 + + if "num_class" in kwargs: + numClass = kwargs.get("num_class") + else: + numClass = 10 + + if "num_nodes" in kwargs: + numNodes = kwargs.get("num_nodes") + else: + numNodes = 30 + + gat_net = GAT(featureDims, + numClass, + numNodes, + GatConfig.hid_units, + GatConfig.n_heads, + attn_drop=GatConfig.attn_dropout, + ftr_drop=GatConfig.feature_dropout) + + return gat_net raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/model_zoo/official/nlp/mass/mindspore_hub_conf.py b/model_zoo/official/nlp/mass/mindspore_hub_conf.py index 999ad31ffba..1c2e2463dff 100644 --- a/model_zoo/official/nlp/mass/mindspore_hub_conf.py +++ b/model_zoo/official/nlp/mass/mindspore_hub_conf.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================ """hub config.""" +import os import mindspore.common.dtype as mstype - from config import TransformerConfig from src.transformer import TransformerNetworkWithLoss, TransformerInferModel @@ -27,10 +27,10 @@ def get_config(config): def create_network(name, *args, **kwargs): """create mass network.""" if name == "mass": - if "config" in kwargs: - config = get_config(kwargs["config"]) - else: - raise NotImplementedError(f"Please make sure the configuration file path is correct") + # get the config running dir + configDir = os.path.split(os.path.realpath(__file__))[0] + "/config/config.json" + # get the config + config = get_config(configDir) is_training = kwargs.get("is_training", False) if is_training: return TransformerNetworkWithLoss(config, is_training=is_training, *args)