forked from mindspore-Ecosystem/mindspore
!20268 modify modelzoo network cpm scripts to adapt cloud environment.
Merge pull request !20268 from anzhengqi/modify-modelzoo-cpm
This commit is contained in:
commit
6abd98bf22
|
@ -0,0 +1,141 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: Ascend
|
||||
enable_profiling: False
|
||||
|
||||
# ==============================================================================
|
||||
# config
|
||||
|
||||
dataset: ""
|
||||
dataset_path: ""
|
||||
truth_labels_path: ""
|
||||
dev_dataset: ""
|
||||
dev_data_path: ""
|
||||
test_dataset: ""
|
||||
test_data_path: ""
|
||||
pretrain_ckpt_path: ""
|
||||
save_checkpoint_path: "./"
|
||||
ckpt_path_doc: ""
|
||||
ckpt_partition: 8
|
||||
distribute: False
|
||||
has_train_strategy: True
|
||||
result_path: ""
|
||||
ckpt_epoch: 4
|
||||
multi_machine: False
|
||||
|
||||
file_name: "cpm"
|
||||
file_format: "AIR" # ["AIR", "MINDIR"]
|
||||
|
||||
config_zero_shot_standalone:
|
||||
dp: 1
|
||||
mp: 1
|
||||
batch_size: 1
|
||||
rank_size: 1
|
||||
vocab_size: 30000
|
||||
seq_length: 571
|
||||
hidden_size: 2560
|
||||
num_hidden_layers: 32
|
||||
num_attention_heads: 32
|
||||
|
||||
config_zero_shot_distrubute:
|
||||
dp: 1
|
||||
mp: 2
|
||||
batch_size: 2
|
||||
rank_size: 2
|
||||
vocab_size: 30000
|
||||
seq_length: 571
|
||||
hidden_size: 2560
|
||||
num_hidden_layers: 32
|
||||
num_attention_heads: 32
|
||||
|
||||
finetune_dev_standalone:
|
||||
dp: 1
|
||||
mp: 2
|
||||
batch_size: 2
|
||||
rank_size: 2
|
||||
vocab_size: 30000
|
||||
seq_length: 696
|
||||
hidden_size: 2560
|
||||
num_hidden_layers: 32
|
||||
num_attention_heads: 32
|
||||
|
||||
finetune_dev_distrubute:
|
||||
dp: 1
|
||||
mp: 2
|
||||
batch_size: 1
|
||||
rank_size: 2
|
||||
vocab_size: 30000
|
||||
seq_length: 696
|
||||
hidden_size: 2560
|
||||
num_hidden_layers: 32
|
||||
num_attention_heads: 32
|
||||
|
||||
finetune_test_standalone:
|
||||
dp: 1
|
||||
mp: 1
|
||||
batch_size: 1
|
||||
rank_size: 1
|
||||
vocab_size: 30000
|
||||
seq_length: 666
|
||||
hidden_size: 2560
|
||||
num_hidden_layers: 32
|
||||
num_attention_heads: 32
|
||||
|
||||
finetune_test_distrubute:
|
||||
dp: 1
|
||||
mp: 2
|
||||
batch_size: 1
|
||||
rank_size: 2
|
||||
vocab_size: 30000
|
||||
seq_length: 666
|
||||
hidden_size: 2560
|
||||
num_hidden_layers: 32
|
||||
num_attention_heads: 32
|
||||
|
||||
config_train_single_machine:
|
||||
dp: 4
|
||||
mp: 2
|
||||
epoch: 10
|
||||
batch_size: 16
|
||||
rank_size: 8
|
||||
vocab_size: 30000
|
||||
seq_length: 725
|
||||
hidden_size: 2560
|
||||
num_hidden_layers: 32
|
||||
num_attention_heads: 32
|
||||
lr: 0.00001
|
||||
eps: 0.00000001
|
||||
dropout: 0.2
|
||||
end_learning_rate: 0.0000001
|
||||
weight_decay: 0.01
|
||||
warmup_steps: 0.05
|
||||
power: 1.0
|
||||
grad_accumulation_step: 4
|
||||
sink_size: 1
|
||||
|
||||
config_train_multi_machine:
|
||||
dp: 16
|
||||
mp: 2
|
||||
epoch: 10
|
||||
batch_size: 128
|
||||
rank_size: 32
|
||||
vocab_size: 30000
|
||||
seq_length: 725
|
||||
hidden_size: 2560
|
||||
num_hidden_layers: 32
|
||||
num_attention_heads: 32
|
||||
lr: 0.00002
|
||||
eps: 0.00000001
|
||||
dropout: 0.1
|
||||
end_learning_rate: 0.0000001
|
||||
weight_decay: 0.01
|
||||
warmup_steps: 0.1
|
||||
power: 1.0
|
||||
grad_accumulation_step: 1
|
||||
sink_size: 1
|
|
@ -14,8 +14,6 @@
|
|||
# ============================================================================
|
||||
"""Eval."""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
@ -32,15 +30,16 @@ from mindspore.parallel import set_algo_parameters
|
|||
from src.cpm import CPMModel
|
||||
from src.cpm_train import VirtualDatasetOneInputCell
|
||||
from src.cpm_loss import Cross_entropy_eval
|
||||
from src.config import finetune_test_distrubute, finetune_test_standalone
|
||||
from train import load_dataset
|
||||
|
||||
device_id = int(os.getenv("DEVICE_ID"))
|
||||
rank_size = os.getenv('RANK_SIZE')
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend",
|
||||
device_id=device_id)
|
||||
device_id=get_device_id())
|
||||
|
||||
|
||||
class CPMForInfer(nn.Cell):
|
||||
|
@ -52,12 +51,12 @@ class CPMForInfer(nn.Cell):
|
|||
batch_size (int): Batch size of input dataset.
|
||||
seq_length (int): Length of input tensor sequence.
|
||||
vocab_size (int): Size of the dictionary of embeddings.
|
||||
config: The config of networks.
|
||||
cfg: The config of networks.
|
||||
|
||||
Returns:
|
||||
Tensor, losses.
|
||||
"""
|
||||
def __init__(self, network, batch_size, seq_length, vocab_size, config):
|
||||
def __init__(self, network, batch_size, seq_length, vocab_size, cfg):
|
||||
super(CPMForInfer, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.batch_size = batch_size
|
||||
|
@ -66,7 +65,7 @@ class CPMForInfer(nn.Cell):
|
|||
self.loss_net = Cross_entropy_eval(batch_size=self.batch_size,
|
||||
seq_length=self.seq_length,
|
||||
vocab_size=self.vocab_size,
|
||||
config=config)
|
||||
config=cfg)
|
||||
|
||||
def construct(self, input_ids, position_ids, attention_mask, loss_mask):
|
||||
logits = self.network(input_ids, position_ids, attention_mask)
|
||||
|
@ -94,19 +93,19 @@ class CPM_LAYER(nn.Cell):
|
|||
return output
|
||||
|
||||
|
||||
def run_eval(args, config_eval, ckpt_file_list=None):
|
||||
def do_eval(args, config_eval, ckpt_file_list=None):
|
||||
"""
|
||||
Building infer pipeline
|
||||
"""
|
||||
with open(args.data_path, "r") as f:
|
||||
with open(args.dataset_path, "r") as f:
|
||||
# cand_ids, data
|
||||
cand_ids, _ = json.load(f)
|
||||
print("++++ cand_ids: ", cand_ids)
|
||||
|
||||
if args.distribute:
|
||||
dataset = load_dataset(args.dataset, config_eval.batch_size,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
rank_id=MultiAscend.get_rank(),
|
||||
rank_size=get_device_num(),
|
||||
rank_id=get_rank_id(),
|
||||
drop_remainder=False,
|
||||
is_training=False,
|
||||
shuffle=False)
|
||||
|
@ -146,7 +145,7 @@ def run_eval(args, config_eval, ckpt_file_list=None):
|
|||
batch_size=config_eval.batch_size,
|
||||
seq_length=config_eval.seq_length,
|
||||
vocab_size=config_eval.vocab_size,
|
||||
config=config_eval)
|
||||
cfg=config_eval)
|
||||
|
||||
model = Model(infer_net)
|
||||
|
||||
|
@ -212,7 +211,7 @@ def set_parallel_env():
|
|||
MultiAscend.init()
|
||||
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
|
||||
device_num=MultiAscend.get_group_size(),
|
||||
device_num=get_device_num(),
|
||||
gradients_mean=True,
|
||||
full_batch=True)
|
||||
set_algo_parameters(elementwise_op_strategy_follow=True)
|
||||
|
@ -260,38 +259,36 @@ def create_ckpt_file_list(args, max_index=None, train_strategy=None, steps_per_e
|
|||
raise Exception("+++ ckpt not found!!! +++")
|
||||
return ckpt_file_list
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="CPM inference")
|
||||
parser.add_argument('--dataset', type=str, default="", help="dataset path.")
|
||||
parser.add_argument("--data_path", type=str, default="/disk0/dataset/finetune_dataset/test.json",
|
||||
help='test_json path.')
|
||||
parser.add_argument('--ckpt_path_doc', type=str, default="", help="Checkpoint path document.")
|
||||
parser.add_argument('--ckpt_partition', type=int, default=8, help="Number of checkpoint partition.")
|
||||
parser.add_argument("--distribute", type=ast.literal_eval, default=False,
|
||||
help='Distribute evaluating with model parallel.')
|
||||
parser.add_argument("--has_train_strategy", type=ast.literal_eval, default=True,
|
||||
help='Model has distributed training strategy.')
|
||||
args_eval = parser.parse_args()
|
||||
if args_eval.distribute:
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_eval():
|
||||
'''eval cpm network'''
|
||||
finetune_test_distrubute = config.finetune_test_distrubute
|
||||
finetune_test_standalone = config.finetune_test_standalone
|
||||
if config.distribute:
|
||||
set_parallel_env()
|
||||
|
||||
ckpt_file_list_test = None
|
||||
if args_eval.has_train_strategy:
|
||||
if config.has_train_strategy:
|
||||
# Get the checkpoint with train strategy.
|
||||
train_strategy_list = create_ckpt_file_list(args_eval, train_strategy="train_strategy.ckpt")
|
||||
train_strategy_list = create_ckpt_file_list(config, train_strategy="train_strategy.ckpt")
|
||||
context.set_auto_parallel_context(
|
||||
strategy_ckpt_load_file=train_strategy_list[0]
|
||||
)
|
||||
ckpt_file_list_test = create_ckpt_file_list(args_eval)
|
||||
ckpt_file_list_test = create_ckpt_file_list(config)
|
||||
print("++++ Get sliced checkpoint file, lists: ", ckpt_file_list_test, flush=True)
|
||||
|
||||
result_accuracy = 0.0
|
||||
if args_eval.distribute:
|
||||
if config.distribute:
|
||||
print("Start validation on 2 devices with model parallel.")
|
||||
result_accuracy = run_eval(args_eval, finetune_test_distrubute, ckpt_file_list_test)
|
||||
result_accuracy = do_eval(config, finetune_test_distrubute, ckpt_file_list_test)
|
||||
else:
|
||||
print("Start validation on 1 device without model parallel.")
|
||||
result_accuracy = run_eval(args_eval, finetune_test_standalone, ckpt_file_list_test)
|
||||
result_accuracy = do_eval(config, finetune_test_standalone, ckpt_file_list_test)
|
||||
|
||||
print("++++ Accuracy=", result_accuracy)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_eval()
|
||||
|
|
|
@ -13,9 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export checkpoint file into air models"""
|
||||
import ast
|
||||
import argparse
|
||||
from easydict import EasyDict as ed
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context, load_distributed_checkpoint
|
||||
|
@ -25,39 +22,24 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net,
|
|||
|
||||
from eval import CPM_LAYER, create_ckpt_file_list
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="CPM export")
|
||||
parser.add_argument('--ckpt_path_doc', type=str, default="", help="checkpoint path document.")
|
||||
parser.add_argument('--ckpt_partition', type=int, default=8, help="Number of checkpoint partition.")
|
||||
parser.add_argument("--has_train_strategy", type=ast.literal_eval, default=True,
|
||||
help='has distributed training strategy')
|
||||
parser.add_argument("--file_name", type=str, default="cpm", help="output file name.")
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format')
|
||||
args = parser.parse_args()
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend")
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
|
||||
finetune_eval_single = ed({
|
||||
"dp": 1,
|
||||
"mp": 1,
|
||||
"batch_size": 1,
|
||||
"rank_size": 1,
|
||||
"vocab_size": 30000,
|
||||
'seq_length': 666,
|
||||
"hidden_size": 2560,
|
||||
"num_hidden_layers": 32,
|
||||
"num_attention_heads": 32
|
||||
})
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_export():
|
||||
'''export cpm network'''
|
||||
finetune_test_standalone = config.finetune_test_standalone
|
||||
cpm_model = CPM_LAYER(finetune_test_standalone)
|
||||
|
||||
if __name__ == '__main__':
|
||||
config_eval = finetune_eval_single
|
||||
cpm_model = CPM_LAYER(config_eval)
|
||||
|
||||
if not args.has_train_strategy:
|
||||
weights = load_checkpoint(args.ckpt_path_doc)
|
||||
if not config.has_train_strategy:
|
||||
weights = load_checkpoint(config.ckpt_path_doc)
|
||||
can_be_loaded = {}
|
||||
print("+++++++loading weights+++++")
|
||||
for name, _ in weights.items():
|
||||
|
@ -71,17 +53,22 @@ if __name__ == '__main__':
|
|||
load_param_into_net(cpm_model, parameter_dict=can_be_loaded)
|
||||
else:
|
||||
context.set_auto_parallel_context(
|
||||
strategy_ckpt_load_file=args.ckpt_path_doc + "/train_strategy.ckpt"
|
||||
strategy_ckpt_load_file=config.ckpt_path_doc + "/train_strategy.ckpt"
|
||||
)
|
||||
ckpt_file_list = create_ckpt_file_list(args)
|
||||
ckpt_file_list = create_ckpt_file_list(config)
|
||||
print("Get checkpoint file lists++++", ckpt_file_list, flush=True)
|
||||
load_distributed_checkpoint(cpm_model, ckpt_file_list, None)
|
||||
|
||||
input_ids = Tensor(np.ones((config_eval.batch_size, config_eval.seq_length)), mstype.int64)
|
||||
position_ids = Tensor(np.random.randint(0, 10, [config_eval.batch_size, config_eval.seq_length]),
|
||||
mstype.int64)
|
||||
attention_mask = Tensor(np.random.randn(config_eval.batch_size, config_eval.seq_length, config_eval.seq_length),
|
||||
mstype.float16)
|
||||
input_ids = Tensor(np.ones((finetune_test_standalone.batch_size, finetune_test_standalone.seq_length)),
|
||||
mstype.int64)
|
||||
position_ids = Tensor(np.random.randint(0, 10, [finetune_test_standalone.batch_size,
|
||||
finetune_test_standalone.seq_length]), mstype.int64)
|
||||
attention_mask = Tensor(np.random.randn(finetune_test_standalone.batch_size,
|
||||
finetune_test_standalone.seq_length,
|
||||
finetune_test_standalone.seq_length), mstype.float16)
|
||||
|
||||
export(cpm_model, input_ids, position_ids, attention_mask, file_name=args.file_name,
|
||||
file_format=args.file_format)
|
||||
export(cpm_model, input_ids, position_ids, attention_mask, file_name=config.file_name,
|
||||
file_format=config.file_format)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_export()
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
# Run the main function
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -59,7 +59,7 @@ do
|
|||
export DEVICE_ID=$i
|
||||
echo "start eval for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python ../../eval.py --dataset $DATASET --data_path $LABEL --ckpt_path_doc $MODEL_CKPT --ckpt_partition $CKPT_NUMBER --distribute True --has_train_strategy True> log_cpm.log 2>&1 &
|
||||
python ../../eval.py --dataset $DATASET --dataset_path $LABEL --ckpt_path_doc $MODEL_CKPT --ckpt_partition $CKPT_NUMBER --distribute True --has_train_strategy True> log_cpm.log 2>&1 &
|
||||
cd ${current_exec_path} || exit
|
||||
done
|
||||
cd ${current_exec_path} || exit
|
||||
|
|
|
@ -70,7 +70,7 @@ do
|
|||
python ../../test.py --dev_dataset $DEV_DATASET --dev_data_path $DEV_LABEL \
|
||||
--test_dataset $TEST_DATASET --test_data_path $TEST_LABEL \
|
||||
--ckpt_path_doc $MODEL_CKPT --ckpt_partition $CKPT_NUMBER \
|
||||
--distribute True --has_train_strategy True> log_cpm.log 2>&1 &
|
||||
--distribute True --has_train_strategy True --result_path ./result.txt > log_cpm.log 2>&1 &
|
||||
|
||||
cd ${current_exec_path} || exit
|
||||
done
|
||||
|
|
|
@ -73,7 +73,7 @@ do
|
|||
--test_dataset $TEST_DATASET --test_data_path $TEST_LABEL \
|
||||
--ckpt_path_doc $MODEL_CKPT --ckpt_partition $CKPT_NUMBER \
|
||||
--ckpt_epoch $ckptepoch --result_path $result_path \
|
||||
--distribute False --has_train_strategy True> log_cpm.log 2>&1
|
||||
--distribute False --has_train_strategy True --result_path ./result.txt > log_cpm.log 2>&1
|
||||
|
||||
cd ${current_exec_path} || exit
|
||||
done
|
||||
|
|
|
@ -1,132 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Configure"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config_zero_shot_standalone = ed({
|
||||
"dp": 1,
|
||||
"mp": 1,
|
||||
"batch_size": 1,
|
||||
"rank_size": 1,
|
||||
"vocab_size": 30000,
|
||||
'seq_length': 571,
|
||||
"hidden_size": 2560,
|
||||
"num_hidden_layers": 32,
|
||||
"num_attention_heads": 32
|
||||
})
|
||||
|
||||
config_zero_shot_distrubute = ed({
|
||||
"dp": 1,
|
||||
"mp": 2,
|
||||
"batch_size": 2,
|
||||
"rank_size": 2,
|
||||
"vocab_size": 30000,
|
||||
'seq_length': 571,
|
||||
"hidden_size": 2560,
|
||||
"num_hidden_layers": 32,
|
||||
"num_attention_heads": 32
|
||||
})
|
||||
|
||||
finetune_dev_standalone = ed({
|
||||
"dp": 1,
|
||||
"mp": 1,
|
||||
"batch_size": 1,
|
||||
"rank_size": 1,
|
||||
"vocab_size": 30000,
|
||||
'seq_length': 696,
|
||||
"hidden_size": 2560,
|
||||
"num_hidden_layers": 32,
|
||||
"num_attention_heads": 32
|
||||
})
|
||||
|
||||
finetune_dev_distrubute = ed({
|
||||
"dp": 1,
|
||||
"mp": 2,
|
||||
"batch_size": 1,
|
||||
"rank_size": 2,
|
||||
"vocab_size": 30000,
|
||||
'seq_length': 696,
|
||||
"hidden_size": 2560,
|
||||
"num_hidden_layers": 32,
|
||||
"num_attention_heads": 32
|
||||
})
|
||||
|
||||
finetune_test_standalone = ed({
|
||||
"dp": 1,
|
||||
"mp": 1,
|
||||
"batch_size": 1,
|
||||
"rank_size": 1,
|
||||
"vocab_size": 30000,
|
||||
'seq_length': 666,
|
||||
"hidden_size": 2560,
|
||||
"num_hidden_layers": 32,
|
||||
"num_attention_heads": 32
|
||||
})
|
||||
|
||||
finetune_test_distrubute = ed({
|
||||
"dp": 1,
|
||||
"mp": 2,
|
||||
"batch_size": 1,
|
||||
"rank_size": 2,
|
||||
"vocab_size": 30000,
|
||||
'seq_length': 666,
|
||||
"hidden_size": 2560,
|
||||
"num_hidden_layers": 32,
|
||||
"num_attention_heads": 32
|
||||
})
|
||||
|
||||
config_train_single_machine = ed({
|
||||
"dp": 4,
|
||||
"mp": 2,
|
||||
"epoch": 10,
|
||||
"batch_size": 16,
|
||||
"rank_size": 8,
|
||||
"vocab_size": 30000,
|
||||
'seq_length': 725,
|
||||
"hidden_size": 2560,
|
||||
"num_hidden_layers": 32,
|
||||
"num_attention_heads": 32,
|
||||
"lr": 1e-5,
|
||||
"eps": 1e-8,
|
||||
"dropout": 0.2,
|
||||
"end_learning_rate": 1e-7,
|
||||
"weight_decay": 1e-2,
|
||||
"warmup_steps": 0.05,
|
||||
"power": 1.0,
|
||||
"grad_accumulation_step": 4,
|
||||
"sink_size": 1
|
||||
})
|
||||
|
||||
config_train_multi_machine = ed({
|
||||
"dp": 16,
|
||||
"mp": 2,
|
||||
"epoch": 10,
|
||||
"batch_size": 128,
|
||||
"rank_size": 32,
|
||||
"vocab_size": 30000,
|
||||
'seq_length': 725,
|
||||
"hidden_size": 2560,
|
||||
"num_hidden_layers": 32,
|
||||
"num_attention_heads": 32,
|
||||
"lr": 2e-5,
|
||||
"eps": 1e-8,
|
||||
"dropout": 0.1,
|
||||
"end_learning_rate": 1e-7,
|
||||
"weight_decay": 1e-2,
|
||||
"warmup_steps": 0.1,
|
||||
"power": 1.0,
|
||||
"grad_accumulation_step": 1,
|
||||
"sink_size": 1
|
||||
})
|
|
@ -14,25 +14,22 @@
|
|||
# ============================================================================
|
||||
"""Test."""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.communication import management as MultiAscend
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.parallel import set_algo_parameters
|
||||
|
||||
from src.config import finetune_test_standalone, finetune_test_distrubute, \
|
||||
finetune_dev_distrubute, finetune_dev_standalone
|
||||
from eval import run_eval, create_ckpt_file_list
|
||||
from eval import do_eval, create_ckpt_file_list
|
||||
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
device_id = int(os.getenv("DEVICE_ID"))
|
||||
rank_size = os.getenv('RANK_SIZE')
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend",
|
||||
device_id=device_id)
|
||||
|
||||
device_id=get_device_id())
|
||||
|
||||
def set_parallel_env():
|
||||
r"""
|
||||
|
@ -42,79 +39,71 @@ def set_parallel_env():
|
|||
MultiAscend.init()
|
||||
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
|
||||
device_num=MultiAscend.get_group_size(),
|
||||
device_num=get_device_num(),
|
||||
gradients_mean=True,
|
||||
full_batch=True)
|
||||
set_algo_parameters(elementwise_op_strategy_follow=True)
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="CPM inference")
|
||||
parser.add_argument('--dev_dataset', type=str, default="", help="dev_dataset path.")
|
||||
parser.add_argument("--dev_data_path", type=str, default="/disk0/dataset/finetune_dataset/dev.json",
|
||||
help='dev_json path.')
|
||||
parser.add_argument('--test_dataset', type=str, default="", help="test_dataset path.")
|
||||
parser.add_argument("--test_data_path", type=str, default="/disk0/dataset/finetune_dataset/test.json",
|
||||
help='test_json path.')
|
||||
parser.add_argument('--ckpt_path_doc', type=str, default="", help="checkpoint path document.")
|
||||
parser.add_argument('--ckpt_partition', type=int, default=8, help="Number of checkpoint partition.")
|
||||
parser.add_argument("--distribute", type=ast.literal_eval, default=False,
|
||||
help='Whether distributed evaluation with model parallel.')
|
||||
parser.add_argument("--has_train_strategy", type=ast.literal_eval, default=True,
|
||||
help='Whether the loaded checkpoints have distributed training strategy.')
|
||||
parser.add_argument("--result_path", type=str, default="/home/result.txt",
|
||||
help='Text save address.')
|
||||
parser.add_argument("--ckpt_epoch", type=int, default=4,
|
||||
help='The number of checkpoint epochs.')
|
||||
args_eval = parser.parse_args()
|
||||
|
||||
if args_eval.distribute:
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_test():
|
||||
'''test cpm network'''
|
||||
finetune_test_standalone = config.finetune_test_standalone
|
||||
finetune_test_distrubute = config.finetune_test_distrubute
|
||||
finetune_dev_standalone = config.finetune_dev_standalone
|
||||
finetune_dev_distrubute = config.finetune_dev_distrubute
|
||||
if config.distribute:
|
||||
set_parallel_env()
|
||||
print("Start validation on 2 devices.")
|
||||
else:
|
||||
print("Start validation on 1 device.")
|
||||
|
||||
args_eval.dataset = args_eval.dev_dataset
|
||||
args_eval.data_path = args_eval.dev_data_path
|
||||
if args_eval.has_train_strategy:
|
||||
config.dataset = config.dev_dataset
|
||||
config.dataset_path = config.dev_data_path
|
||||
if config.has_train_strategy:
|
||||
# Get the checkpoint with train strategy.
|
||||
train_strategy_list = create_ckpt_file_list(args_eval, train_strategy="train_strategy.ckpt")
|
||||
train_strategy_list = create_ckpt_file_list(config, train_strategy="train_strategy.ckpt")
|
||||
context.set_auto_parallel_context(
|
||||
strategy_ckpt_load_file=train_strategy_list[0]
|
||||
)
|
||||
# start run in dev dataset.
|
||||
ckpt_file_list_dev = None
|
||||
if args_eval.has_train_strategy:
|
||||
if config.has_train_strategy:
|
||||
# Get the checkpoint slice.
|
||||
ckpt_file_list_dev = create_ckpt_file_list(args_eval, args_eval.ckpt_epoch)
|
||||
ckpt_file_list_dev = create_ckpt_file_list(config, config.ckpt_epoch)
|
||||
print("++++ Get sliced checkpoint file, lists: ", ckpt_file_list_dev, flush=True)
|
||||
result_i = 0.0
|
||||
if args_eval.distribute:
|
||||
result_i = run_eval(args_eval, finetune_dev_distrubute, ckpt_file_list_dev)
|
||||
if config.distribute:
|
||||
result_i = do_eval(config, finetune_dev_distrubute, ckpt_file_list_dev)
|
||||
else:
|
||||
result_i = run_eval(args_eval, finetune_dev_standalone, ckpt_file_list_dev)
|
||||
print("+++++ ckpt_epoch=", args_eval.ckpt_epoch, ", dev_dataset Accuracy: ", result_i)
|
||||
result_i = do_eval(config, finetune_dev_standalone, ckpt_file_list_dev)
|
||||
print("+++++ ckpt_epoch=", config.ckpt_epoch, ", dev_dataset Accuracy: ", result_i)
|
||||
print("++++ Then we take the model to predict on the test dataset.")
|
||||
ckpt_file_list_test = None
|
||||
if args_eval.has_train_strategy:
|
||||
if config.has_train_strategy:
|
||||
# Get the best precision checkpoint slice.
|
||||
ckpt_file_list_test = create_ckpt_file_list(args_eval, args_eval.ckpt_epoch)
|
||||
ckpt_file_list_test = create_ckpt_file_list(config, config.ckpt_epoch)
|
||||
|
||||
args_eval.dataset = args_eval.test_dataset
|
||||
args_eval.data_path = args_eval.test_data_path
|
||||
config.dataset = config.test_dataset
|
||||
config.dataset_path = config.test_data_path
|
||||
# start run in test dataset.
|
||||
result_last = 0.0
|
||||
if args_eval.distribute:
|
||||
result_last = run_eval(args_eval, finetune_test_distrubute, ckpt_file_list_test)
|
||||
if config.distribute:
|
||||
result_last = do_eval(config, finetune_test_distrubute, ckpt_file_list_test)
|
||||
else:
|
||||
result_last = run_eval(args_eval, finetune_test_standalone, ckpt_file_list_test)
|
||||
result_last = do_eval(config, finetune_test_standalone, ckpt_file_list_test)
|
||||
print("++++ Accuracy on test dataset is: ", result_last)
|
||||
|
||||
# write to file.
|
||||
result_path = args_eval.result_path
|
||||
result_path = config.result_path
|
||||
if not os.path.exists(result_path):
|
||||
with open(result_path, "w") as file:
|
||||
file.write("CkptEpcoh Accuracy_dev Accuracy_test\n")
|
||||
|
||||
with open(result_path, "a") as file:
|
||||
file.write(str(args_eval.ckpt_epoch) + " " + str(result_i) + " " + str(result_last) + "\n")
|
||||
file.write(str(config.ckpt_epoch) + " " + str(result_i) + " " + str(result_last) + "\n")
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_test()
|
||||
|
|
|
@ -14,8 +14,6 @@
|
|||
# ============================================================================
|
||||
"""Train."""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
@ -34,20 +32,22 @@ import mindspore.common.dtype as mstype
|
|||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from mindspore.parallel import set_algo_parameters
|
||||
|
||||
from src.config import config_train_single_machine, config_train_multi_machine
|
||||
from src.cpm_train import CPMWithLoss, CPMTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell, \
|
||||
CPMTrainAccuStepsWithLossScaleCell
|
||||
from src.lr_schedule import CPMLearningRate
|
||||
from src.loss_monitor import LossCallBack, TimeCallBack
|
||||
from src.model_cpm import Model_ACCU as Model_CPM
|
||||
|
||||
device_id = int(os.getenv("DEVICE_ID"))
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
|
||||
|
||||
|
||||
set_seed(23333)
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend",
|
||||
device_id=device_id)
|
||||
device_id=get_device_id())
|
||||
context.set_context(variable_memory_max_size="30GB")
|
||||
|
||||
|
||||
|
@ -188,7 +188,8 @@ def _build_training_pipeline(datasets, pretrain_ckpt_path, config_train):
|
|||
integrated_save=False,
|
||||
keep_checkpoint_max=config_train.epoch)
|
||||
ckpt_model = ModelCheckpoint(prefix='cpm_rank_{}'.format(os.getenv("RANK_ID")),
|
||||
directory=os.path.join('./', 'ckpt_rank_{}'.format(os.getenv("RANK_ID"))),
|
||||
directory=os.path.join(config.save_checkpoint_path,
|
||||
'ckpt_rank_{}'.format(get_rank_id())),
|
||||
config=ckpt_config)
|
||||
callback.append(ckpt_model)
|
||||
|
||||
|
@ -221,7 +222,7 @@ def set_parallel_env(config_train):
|
|||
context.reset_auto_parallel_context()
|
||||
MultiAscend.init()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
|
||||
device_num=MultiAscend.get_group_size(),
|
||||
device_num=get_device_num(),
|
||||
gradients_mean=True,
|
||||
grad_accumulation_step=config_train.grad_accumulation_step,
|
||||
full_batch=True)
|
||||
|
@ -248,22 +249,24 @@ def train_paralle(input_file, pretrain_ckpt_path, config_train):
|
|||
print("Staring training on multiple device")
|
||||
processed_data = load_dataset(dataset=input_file,
|
||||
batch_size=config_train.batch_size,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
rank_id=MultiAscend.get_rank())
|
||||
rank_size=get_device_num(),
|
||||
rank_id=get_rank_id())
|
||||
_build_training_pipeline(processed_data, pretrain_ckpt_path, config_train)
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
config.save_checkpoint_path = config.output_path
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="CPM training.")
|
||||
parser.add_argument("--dataset", type=str, default="", help="CPM dataset path")
|
||||
parser.add_argument("--pretrain_ckpt_path", type=str, default="",
|
||||
help="Load the checkpoint file path for train.")
|
||||
parser.add_argument("--multi_machine", type=ast.literal_eval, default=False, help='distributed training')
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.multi_machine:
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_train():
|
||||
config_train_single_machine = config.config_train_single_machine
|
||||
config_train_multi_machine = config.config_train_multi_machine
|
||||
if config.multi_machine:
|
||||
print("Training on multiple machines.")
|
||||
train_paralle(args.dataset, args.pretrain_ckpt_path, config_train_multi_machine)
|
||||
train_paralle(config.dataset, config.pretrain_ckpt_path, config_train_multi_machine)
|
||||
else:
|
||||
print("Training on single machine.")
|
||||
train_paralle(args.dataset, args.pretrain_ckpt_path, config_train_single_machine)
|
||||
train_paralle(config.dataset, config.pretrain_ckpt_path, config_train_single_machine)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_train()
|
||||
|
|
|
@ -13,9 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Zero-shot."""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
|
@ -33,14 +30,16 @@ from mindspore.parallel import set_algo_parameters
|
|||
from src.cpm import CPMModel
|
||||
from src.cpm_train import VirtualDatasetOneInputCell
|
||||
from src.cpm_loss import Cross_entropy
|
||||
from src.config import config_zero_shot_standalone, config_zero_shot_distrubute
|
||||
from eval import create_ckpt_file_list
|
||||
|
||||
device_id = int(os.getenv("DEVICE_ID"))
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend",
|
||||
device_id=device_id)
|
||||
device_id=get_device_id())
|
||||
|
||||
|
||||
class CPMForInfer(nn.Cell):
|
||||
|
@ -52,12 +51,12 @@ class CPMForInfer(nn.Cell):
|
|||
batch_size (int): Batch size of input dataset.
|
||||
seq_length (int): Length of input tensor sequence.
|
||||
vocab_size (int): Size of the dictionary of embeddings.
|
||||
config: The config of networks.
|
||||
cfg: The config of networks.
|
||||
|
||||
Returns:
|
||||
Tensor, losses.
|
||||
"""
|
||||
def __init__(self, network, batch_size, seq_length, vocab_size, config):
|
||||
def __init__(self, network, batch_size, seq_length, vocab_size, cfg):
|
||||
super(CPMForInfer, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.batch_size = batch_size
|
||||
|
@ -66,7 +65,7 @@ class CPMForInfer(nn.Cell):
|
|||
self.loss_net = Cross_entropy(batch_size=self.batch_size,
|
||||
seq_length=self.seq_length,
|
||||
vocab_size=self.vocab_size,
|
||||
config=config)
|
||||
config=cfg)
|
||||
|
||||
def construct(self, input_ids, target, loss_mask):
|
||||
"""Defines the computation performed."""
|
||||
|
@ -75,28 +74,6 @@ class CPMForInfer(nn.Cell):
|
|||
return loss
|
||||
|
||||
|
||||
def collate(sid, cid, input_ids, BatchInfo):
|
||||
"""Collate operation for dataset."""
|
||||
bs = len(sid)
|
||||
max_size = np.size(input_ids, 1)
|
||||
|
||||
attn_mask = np.tril(np.ones(shape=(max_size, max_size),))
|
||||
attention_mask = np.expand_dims(attn_mask, 0)
|
||||
attention_mask = np.tile(attention_mask, (bs, 1, 1))
|
||||
|
||||
position_ids = np.expand_dims(np.arange(max_size * 1), 0)
|
||||
position_ids = np.tile(position_ids, (bs, 1))
|
||||
|
||||
sids_list = np.zeros(bs, dtype=np.int64)
|
||||
cids_list = np.zeros(bs, dtype=np.int64)
|
||||
|
||||
for i in range(bs):
|
||||
sids_list[i] = sid[i]
|
||||
cids_list[i] = cid[i]
|
||||
|
||||
return input_ids, attention_mask, position_ids, sids_list, cids_list
|
||||
|
||||
|
||||
def _load_dataset(dataset_path, batch_size, rank_size=None, rank_id=None, shuffle=True, drop_remainder=True):
|
||||
"""Loader for data."""
|
||||
data = ds.MindDataset(dataset_file=dataset_path,
|
||||
|
@ -163,8 +140,8 @@ def run_eval(args, config_eval, ckpt_file_list=None):
|
|||
|
||||
if args.distribute:
|
||||
dataset = load_dataset(args.dataset, config_eval.batch_size,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
rank_id=MultiAscend.get_rank(),
|
||||
rank_size=get_device_num(),
|
||||
rank_id=get_rank_id(),
|
||||
drop_remainder=False,
|
||||
shuffle=False)
|
||||
else:
|
||||
|
@ -192,7 +169,7 @@ def run_eval(args, config_eval, ckpt_file_list=None):
|
|||
batch_size=config_eval.batch_size,
|
||||
seq_length=config_eval.seq_length,
|
||||
vocab_size=config_eval.vocab_size,
|
||||
config=config_eval)
|
||||
cfg=config_eval)
|
||||
|
||||
model = Model(infer_net)
|
||||
|
||||
|
@ -254,38 +231,37 @@ def set_parallel_env():
|
|||
MultiAscend.init()
|
||||
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
|
||||
device_num=MultiAscend.get_group_size(),
|
||||
device_num=get_device_num(),
|
||||
gradients_mean=True,
|
||||
full_batch=True)
|
||||
set_algo_parameters(elementwise_op_strategy_follow=True)
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="CPM inference")
|
||||
parser.add_argument('--dataset', type=str, default="", help="dataset path.")
|
||||
parser.add_argument('--truth_labels_path', type=str, default="", help="truth_labels path.")
|
||||
parser.add_argument('--ckpt_path_doc', type=str, default="", help="checkpoint path doc or checkpoint path.")
|
||||
parser.add_argument("--distribute", type=ast.literal_eval, default=False, help='Whether distributed evaluation'
|
||||
' with model parallel.')
|
||||
parser.add_argument("--has_train_strategy", type=ast.literal_eval, default=False,
|
||||
help='Whether the loaded checkpoints have distributed training strategy.')
|
||||
parser.add_argument('--ckpt_partition', type=int, default=1, help="Number of checkpoint partition.")
|
||||
args_parse = parser.parse_args()
|
||||
if args_parse.distribute:
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_test():
|
||||
'''test cpm network with zero_shot dataset.'''
|
||||
config_zero_shot_standalone = config.config_zero_shot_standalone
|
||||
config_zero_shot_distrubute = config.config_zero_shot_distrubute
|
||||
if config.distribute:
|
||||
set_parallel_env()
|
||||
|
||||
ckpt_file_list_test = None
|
||||
if args_parse.has_train_strategy:
|
||||
if config.has_train_strategy:
|
||||
# Get the checkpoint with train strategy.
|
||||
train_strategy_list = create_ckpt_file_list(args_parse, train_strategy="train_strategy.ckpt")
|
||||
train_strategy_list = create_ckpt_file_list(config, train_strategy="train_strategy.ckpt")
|
||||
context.set_auto_parallel_context(
|
||||
strategy_ckpt_load_file=train_strategy_list[0]
|
||||
)
|
||||
ckpt_file_list_test = create_ckpt_file_list(args_parse)
|
||||
ckpt_file_list_test = create_ckpt_file_list(config)
|
||||
print("Get checkpoint file lists++++", ckpt_file_list_test, flush=True)
|
||||
if args_parse.distribute:
|
||||
if config.distribute:
|
||||
print("Staring evaluating on 2 devices with model parallel.")
|
||||
run_eval(args_parse, config_zero_shot_distrubute, ckpt_file_list_test)
|
||||
run_eval(config, config_zero_shot_distrubute, ckpt_file_list_test)
|
||||
else:
|
||||
print("Staring evaluating on 1 device without model parallel.")
|
||||
run_eval(args_parse, config_zero_shot_standalone, ckpt_file_list_test)
|
||||
run_eval(config, config_zero_shot_standalone, ckpt_file_list_test)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_test()
|
||||
|
|
Loading…
Reference in New Issue