check model_zoo again
This commit is contained in:
parent
ade60ad3d3
commit
569fdd1037
|
@ -58,9 +58,6 @@ def _load_dataset(input_files, batch_size, epoch_count=1,
|
||||||
ori_dataset_size = ds.get_dataset_size()
|
ori_dataset_size = ds.get_dataset_size()
|
||||||
print(f" | Dataset size: {ori_dataset_size}.")
|
print(f" | Dataset size: {ori_dataset_size}.")
|
||||||
repeat_count = epoch_count
|
repeat_count = epoch_count
|
||||||
if sink_mode:
|
|
||||||
ds.set_dataset_size(sink_step * batch_size)
|
|
||||||
repeat_count = epoch_count * ori_dataset_size // ds.get_dataset_size()
|
|
||||||
|
|
||||||
type_cast_op = deC.TypeCast(mstype.int32)
|
type_cast_op = deC.TypeCast(mstype.int32)
|
||||||
ds = ds.map(input_columns="src", operations=type_cast_op)
|
ds = ds.map(input_columns="src", operations=type_cast_op)
|
||||||
|
|
|
@ -79,11 +79,15 @@ def _train(model, config: TransformerConfig,
|
||||||
|
|
||||||
if pre_training_dataset is not None:
|
if pre_training_dataset is not None:
|
||||||
print(" | Start pre-training job.")
|
print(" | Start pre-training job.")
|
||||||
epoch_size = pre_training_dataset.get_repeat_count()
|
epoch_size = config.epochs * pre_training_dataset.get_dataset_size() // config.dataset_sink_step
|
||||||
|
|
||||||
if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
|
if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
|
||||||
print(f" | Rank {MultiAscend.get_rank()} Call model train.")
|
print(f" | Rank {MultiAscend.get_rank()} Call model train.")
|
||||||
|
|
||||||
model.train(epoch_size, pre_training_dataset,
|
model.train(epoch_size, pre_training_dataset,
|
||||||
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)
|
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode,
|
||||||
|
sink_size=config.dataset_sink_step)
|
||||||
|
|
||||||
# Test the accuracy of the model.
|
# Test the accuracy of the model.
|
||||||
if test_dataset is not None:
|
if test_dataset is not None:
|
||||||
print(" | Start test job.")
|
print(" | Start test job.")
|
||||||
|
@ -93,10 +97,11 @@ def _train(model, config: TransformerConfig,
|
||||||
|
|
||||||
if fine_tune_dataset is not None:
|
if fine_tune_dataset is not None:
|
||||||
print(" | Start fine-tuning job.")
|
print(" | Start fine-tuning job.")
|
||||||
epoch_size = fine_tune_dataset.get_repeat_count()
|
epoch_size = config.epochs * fine_tune_dataset.get_dataset_size() // config.dataset_sink_step
|
||||||
|
|
||||||
model.train(epoch_size, fine_tune_dataset,
|
model.train(epoch_size, fine_tune_dataset,
|
||||||
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)
|
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode,
|
||||||
|
sink_size=config.dataset_sink_step)
|
||||||
|
|
||||||
# Test the accuracy of the model.
|
# Test the accuracy of the model.
|
||||||
if test_dataset is not None:
|
if test_dataset is not None:
|
||||||
|
|
Loading…
Reference in New Issue