forked from mindspore-Ecosystem/mindspore
!22120 Transformer API For Rewriting the PanGU model
Merge pull request !22120 from huangxinjing/transformer_example
This commit is contained in:
commit
570363c2a3
|
@ -28,8 +28,9 @@ from mindspore.parallel import set_algo_parameters
|
|||
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_distributed_checkpoint
|
||||
from src.pangu_alpha import PanguAlpha, EvalNet
|
||||
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
|
||||
from mindspore.parallel.nn.transformer import TransformerOpParallelConfig
|
||||
from src.pangu_alpha import EvalNet, PanguAlphaModel
|
||||
from src.pangu_alpha_config import set_parse, PanguAlphaConfig
|
||||
from src.utils import get_args
|
||||
|
||||
|
||||
|
@ -74,29 +75,31 @@ def load_model(args_opt):
|
|||
# Set model property
|
||||
model_parallel_num = args_opt.op_level_model_parallel_num
|
||||
data_parallel_num = int(device_num / model_parallel_num)
|
||||
|
||||
parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
|
||||
model_parallel=model_parallel_num,
|
||||
pipeline_stage=args_opt.stage_num,
|
||||
micro_batch_num=args_opt.micro_size,
|
||||
recompute=True)
|
||||
|
||||
per_batch_size = args_opt.per_batch_size
|
||||
batch_size = per_batch_size * data_parallel_num
|
||||
# Now only support single batch_size for predict
|
||||
if args_opt.run_type == "predict":
|
||||
batch_size = 1
|
||||
config = PANGUALPHAConfig(
|
||||
data_parallel_num=data_parallel_num,
|
||||
model_parallel_num=model_parallel_num,
|
||||
config = PanguAlphaConfig(
|
||||
batch_size=batch_size,
|
||||
seq_length=args_opt.seq_length,
|
||||
vocab_size=args_opt.vocab_size,
|
||||
embedding_size=args_opt.embedding_size,
|
||||
hidden_size=args_opt.embedding_size,
|
||||
num_layers=args_opt.num_layers,
|
||||
num_heads=args_opt.num_heads,
|
||||
expand_ratio=4,
|
||||
post_layernorm_residual=False,
|
||||
dropout_rate=0.0,
|
||||
compute_dtype=mstype.float16,
|
||||
ffn_hidden_size=args_opt.embedding_size * 4,
|
||||
use_past=use_past,
|
||||
stage_num=args_opt.stage_num,
|
||||
micro_size=args_opt.micro_size,
|
||||
eod_reset=False,
|
||||
word_emb_dp=True,
|
||||
parallel_config=parallel_config,
|
||||
load_ckpt_path=args_opt.load_ckpt_path,
|
||||
param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16)
|
||||
print("===config is: ", config, flush=True)
|
||||
|
@ -104,7 +107,7 @@ def load_model(args_opt):
|
|||
|
||||
ckpt_name = args_opt.load_ckpt_name
|
||||
# Define network
|
||||
pangu_alpha = PanguAlpha(config)
|
||||
pangu_alpha = PanguAlphaModel(config)
|
||||
eval_net = EvalNet(pangu_alpha)
|
||||
eval_net.set_train(False)
|
||||
model_predict = Model(eval_net)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
numpy
|
||||
trnsformers
|
||||
transformers
|
||||
sentencepiece
|
||||
jieba
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_distributed_train_gpu.sh RANK_SIZE HOSTFILE DATASET PER_BATCH MODE"
|
||||
echo "for example: bash run_distributed_train_gpu.sh 16 hostfile_16p /mass_dataset/train_data/ 16 2.6B"
|
||||
echo "bash run_distributed_train_gpu.sh RANK_SIZE HOSTFILE DATASET MODE"
|
||||
echo "for example: bash run_distributed_train_gpu.sh 16 hostfile_16p /mass_dataset/train_data/ 2.6B"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
|
@ -26,8 +26,7 @@ self_path=$(dirname "${script_self}")
|
|||
RANK_SIZE=$1
|
||||
HOSTFILE=$2
|
||||
DATASET=$3
|
||||
PER_BATCH=$4
|
||||
MODE=$5
|
||||
MODE=$4
|
||||
|
||||
mpirun --allow-run-as-root -x PATH -x LD_LIBRARY_PATH -x PYTHONPATH -x NCCL_DEBUG -x GLOG_v -n $RANK_SIZE --hostfile $HOSTFILE --output-filename log_output --merge-stderr-to-stdout \
|
||||
python -s ${self_path}/../train.py \
|
||||
|
@ -36,5 +35,4 @@ mpirun --allow-run-as-root -x PATH -x LD_LIBRARY_PATH -x PYTHONPATH -x NCCL_DEBU
|
|||
--device_target="GPU" \
|
||||
--data_url=$DATASET \
|
||||
--mode=$MODE \
|
||||
--per_batch_size=$PER_BATCH \
|
||||
--run_type=train > train_log.txt 2>&1 &
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -17,57 +17,47 @@ network config setting
|
|||
"""
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
class PanguAlphaConfig:
|
||||
"""
|
||||
PanGUConfig config class which defines the model size
|
||||
"""
|
||||
|
||||
class PANGUALPHAConfig:
|
||||
"""
|
||||
PANGUALPHA config class which defines the model size
|
||||
"""
|
||||
def __init__(self,
|
||||
data_parallel_num,
|
||||
model_parallel_num,
|
||||
batch_size=32,
|
||||
seq_length=1024,
|
||||
vocab_size=50257,
|
||||
embedding_size=768,
|
||||
vocab_size=40000,
|
||||
hidden_size=768,
|
||||
ffn_hidden_size=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
expand_ratio=4,
|
||||
load_ckpt_path=None,
|
||||
param_init_type=mstype.float32,
|
||||
post_layernorm_residual=False,
|
||||
dropout_rate=0.1,
|
||||
compute_dtype=mstype.float16,
|
||||
eod_token=6,
|
||||
use_past=False,
|
||||
word_emb_dp=True,
|
||||
stage_num=16,
|
||||
hidden_act='gelu',
|
||||
eod_reset=True,
|
||||
micro_size=32,
|
||||
load_ckpt_path=None,
|
||||
use_top_query_attention=True,
|
||||
param_init_type=mstype.float32,
|
||||
enable_offload=False):
|
||||
enable_offload=False,
|
||||
parallel_config=None):
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_size = embedding_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
# The expand ratio of feature size in FFN
|
||||
self.expand_ratio = expand_ratio
|
||||
self.eod_token = eod_token
|
||||
# Use post-layernorm or pre-layernorm, default:pre-layernorm
|
||||
self.post_layernorm_residual = post_layernorm_residual
|
||||
self.dropout_rate = dropout_rate
|
||||
self.compute_dtype = compute_dtype
|
||||
# Whether use incremental inference
|
||||
self.use_past = use_past
|
||||
self.dp = data_parallel_num
|
||||
self.mp = model_parallel_num
|
||||
self.stage_num = stage_num
|
||||
self.micro_size = micro_size
|
||||
self.word_emb_dp = word_emb_dp
|
||||
self.eod_reset = eod_reset
|
||||
# Used for loading embedding tables
|
||||
self.load_ckpt_path = load_ckpt_path
|
||||
self.use_top_query_attention = use_top_query_attention
|
||||
self.param_init_type = param_init_type
|
||||
self.dropout_rate = dropout_rate
|
||||
self.compute_dtype = mstype.float16
|
||||
self.parallel_config = parallel_config
|
||||
self.ffn_hidden_size = ffn_hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.use_past = use_past
|
||||
self.eod_reset = eod_reset
|
||||
self.enable_offload = enable_offload
|
||||
|
||||
def __str__(self):
|
||||
|
@ -78,6 +68,7 @@ class PANGUALPHAConfig:
|
|||
info += '=' * 10
|
||||
return info
|
||||
|
||||
|
||||
def set_parse(args_opt):
|
||||
r"""
|
||||
Set config according to the mode
|
||||
|
|
|
@ -70,6 +70,7 @@ reciprocal = P.Reciprocal()
|
|||
def tensor_grad_scale(scale, grad):
|
||||
return grad * reciprocal(scale)
|
||||
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor", "Tensor")
|
||||
def tensor_grad_scale_pipeline(scale, grad, accu_grad):
|
||||
accu_grad = F.depend(accu_grad, grad)
|
||||
|
@ -79,6 +80,7 @@ def tensor_grad_scale_pipeline(scale, grad, accu_grad):
|
|||
new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
|
||||
return new_grad
|
||||
|
||||
|
||||
@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
|
||||
def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
|
||||
new_grad = grad * reciprocal(scale)
|
||||
|
@ -151,6 +153,7 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
|
|||
self.optimizer(grads)
|
||||
return loss, cond, scaling_sens
|
||||
|
||||
|
||||
class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of PanguAlpha network training.
|
||||
|
@ -200,7 +203,7 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
|
|||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
|
||||
name="loss_scale")
|
||||
self.clip = ClipByGlobalNorm(self.weights, self.config)
|
||||
self.micro_size = config.micro_size
|
||||
self.micro_size = config.parallel_config.micro_batch_num
|
||||
self.opt_shard = _get_enable_parallel_optimizer()
|
||||
|
||||
@C.add_flags(has_effect=True)
|
||||
|
|
|
@ -139,12 +139,12 @@ class GlobalNorm(nn.Cell):
|
|||
super(GlobalNorm, self).__init__()
|
||||
self.norm = nn.Norm()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.is_pipeline = (config.stage_num > 1)
|
||||
self.is_pipeline = context.get_auto_parallel_context("pipeline_stages") > 1
|
||||
if self.is_pipeline:
|
||||
if context.get_auto_parallel_context("enable_parallel_optimizer"):
|
||||
group_size = get_group_size() // config.stage_num
|
||||
group_size = get_group_size() // config.parallel_config.pipeline_stage
|
||||
else:
|
||||
group_size = config.mp
|
||||
group_size = config.parallel_config.model_parallel
|
||||
group_list, group_name = _get_model_parallel_group(group_size)
|
||||
create_group(group_name, group_list)
|
||||
self.allreduce = P.AllReduce(group=group_name)
|
||||
|
@ -160,9 +160,10 @@ class GlobalNorm(nn.Cell):
|
|||
elif "embedding_table" not in x.name:
|
||||
self.allreduce_group_size = self.allreduce_group_size + (group_size * 1.0,)
|
||||
else:
|
||||
if not config.word_emb_dp and "position_embedding.embedding_table" not in x.name \
|
||||
if not config.parallel_config.vocab_emb_dp and "position_embedding.embedding_table" not in x.name \
|
||||
and "top_query_embedding_table" not in x.name:
|
||||
self.allreduce_group_size = self.allreduce_group_size + (config.dp * 1.0,)
|
||||
self.allreduce_group_size = self.allreduce_group_size +\
|
||||
(config.parallel_config.data_parallel * 1.0,)
|
||||
else:
|
||||
self.allreduce_group_size = self.allreduce_group_size + (group_size * 1.0,)
|
||||
|
||||
|
|
|
@ -29,17 +29,17 @@ import mindspore.common.dtype as mstype
|
|||
from mindspore.parallel import set_algo_parameters
|
||||
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
|
||||
from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
|
||||
from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss
|
||||
from src.adam import AdamWeightDecayOp
|
||||
from src.dataset import create_dataset
|
||||
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
|
||||
from src.pangu_alpha import PanGUAlphaWithLoss, PanguAlphaModel
|
||||
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
|
||||
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
|
||||
from src.pangu_alpha_config import set_parse, PanguAlphaConfig
|
||||
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
|
||||
from src.utils import download_data
|
||||
from src.callbacks import EvalCallBack, LossCallBack
|
||||
from src.metrics import PPLMetric
|
||||
|
||||
|
||||
project_root = os.path.abspath(
|
||||
os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..")
|
||||
print('project_root:', project_root)
|
||||
|
@ -69,7 +69,8 @@ def run_train(args_opt):
|
|||
The main training process.
|
||||
"""
|
||||
# Set execution mode
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, variable_memory_max_size="31GB")
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
context.set_context(variable_memory_max_size="31GB")
|
||||
# Set parallel context
|
||||
if args_opt.distribute == "true":
|
||||
D.init()
|
||||
|
@ -100,27 +101,37 @@ def run_train(args_opt):
|
|||
model_parallel_num = args_opt.op_level_model_parallel_num
|
||||
data_parallel_num = int(device_num / model_parallel_num)
|
||||
batch_size = args_opt.per_batch_size * data_parallel_num
|
||||
config = PANGUALPHAConfig(
|
||||
data_parallel_num=data_parallel_num, model_parallel_num=model_parallel_num,
|
||||
batch_size=batch_size, seq_length=args_opt.seq_length,
|
||||
vocab_size=args_opt.vocab_size, embedding_size=args_opt.embedding_size,
|
||||
num_layers=args_opt.num_layers, num_heads=args_opt.num_heads,
|
||||
expand_ratio=4, dropout_rate=0.1, compute_dtype=mstype.float16,
|
||||
stage_num=args_opt.stage_num, micro_size=args_opt.micro_size,
|
||||
eod_reset=bool(args_opt.eod_reset), load_ckpt_path=args_opt.load_ckpt_path,
|
||||
param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
|
||||
word_emb_dp=bool(args_opt.word_emb_dp), enable_offload=bool(args_opt.opt_offload))
|
||||
parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
|
||||
model_parallel=model_parallel_num,
|
||||
pipeline_stage=args_opt.stage_num,
|
||||
micro_batch_num=args_opt.micro_size,
|
||||
optimizer_shard=bool(args_opt.optimizer_shard),
|
||||
recompute=True)
|
||||
config = PanguAlphaConfig(batch_size=batch_size, num_heads=args_opt.num_heads,
|
||||
hidden_size=args_opt.embedding_size, seq_length=args_opt.seq_length,
|
||||
vocab_size=args_opt.vocab_size, num_layers=args_opt.num_layers,
|
||||
ffn_hidden_size=args_opt.embedding_size * 4,
|
||||
eod_token=bool(args_opt.eod_reset),
|
||||
load_ckpt_path=args_opt.load_ckpt_path,
|
||||
param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
|
||||
enable_offload=bool(args_opt.opt_offload),
|
||||
parallel_config=parallel_config)
|
||||
print("===config is: ", config, flush=True)
|
||||
|
||||
# Define network
|
||||
pangu_alpha = PanguAlpha(config)
|
||||
loss = CrossEntropyLoss(config)
|
||||
pangu_alpha_with_loss_net = PanguAlphaWithLoss(config, pangu_alpha, loss)
|
||||
pangu_alpha = PanguAlphaModel(config=config)
|
||||
loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
|
||||
pangu_alpha_with_loss_net = PanGUAlphaWithLoss(config, pangu_alpha, loss)
|
||||
pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
|
||||
|
||||
print("=====args_opt is: ", args_opt, flush=True)
|
||||
# Warm-up and cosine decay learning rate
|
||||
lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
|
||||
warmup_steps=args_opt.warmup_step, decay_steps=200000)
|
||||
params = pangu_alpha.trainable_params()
|
||||
lr = LearningRate(learning_rate=args_opt.start_lr,
|
||||
end_learning_rate=args_opt.end_lr,
|
||||
warmup_steps=args_opt.warmup_step,
|
||||
decay_steps=200000)
|
||||
|
||||
params = pangu_alpha_with_loss.trainable_params()
|
||||
group_params = set_weight_decay(params)
|
||||
if args_opt.optimizer == "lamb":
|
||||
optimizer = nn.Lamb(group_params, learning_rate=lr)
|
||||
|
@ -165,6 +176,7 @@ def run_train(args_opt):
|
|||
print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True)
|
||||
model.train(actual_epoch_num, ds, callbacks=callback, sink_size=args_opt.sink_size, dataset_sink_mode=True)
|
||||
|
||||
|
||||
def run_train_pipeline(args_opt):
|
||||
r"""
|
||||
The main training process in pipeline.
|
||||
|
@ -203,27 +215,26 @@ def run_train_pipeline(args_opt):
|
|||
raise ValueError("The dp must large than 1 when applying optimizer shard.")
|
||||
per_batch_size = args_opt.per_batch_size
|
||||
batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
|
||||
config = PANGUALPHAConfig(
|
||||
data_parallel_num=data_parallel_num,
|
||||
model_parallel_num=model_parallel_num,
|
||||
batch_size=batch_size,
|
||||
seq_length=args_opt.seq_length,
|
||||
vocab_size=args_opt.vocab_size,
|
||||
embedding_size=args_opt.embedding_size,
|
||||
num_layers=args_opt.num_layers,
|
||||
num_heads=args_opt.num_heads,
|
||||
expand_ratio=4,
|
||||
post_layernorm_residual=False,
|
||||
dropout_rate=0.1,
|
||||
compute_dtype=mstype.float16,
|
||||
use_past=False,
|
||||
stage_num=args_opt.stage_num,
|
||||
micro_size=args_opt.micro_size,
|
||||
word_emb_dp=bool(args_opt.word_emb_dp), enable_offload=bool(args_opt.opt_offload))
|
||||
|
||||
parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
|
||||
model_parallel=model_parallel_num,
|
||||
pipeline_stage=args_opt.stage_num,
|
||||
micro_batch_num=args_opt.micro_size,
|
||||
optimizer_shard=bool(args_opt.optimizer_shard),
|
||||
recompute=True)
|
||||
config = PanguAlphaConfig(batch_size=batch_size // parallel_config.micro_batch_num,
|
||||
num_heads=args_opt.num_heads, hidden_size=args_opt.embedding_size,
|
||||
seq_length=args_opt.seq_length, vocab_size=args_opt.vocab_size,
|
||||
num_layers=args_opt.num_layers, ffn_hidden_size=args_opt.embedding_size * 4,
|
||||
eod_token=bool(args_opt.eod_reset), load_ckpt_path=args_opt.load_ckpt_path,
|
||||
param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
|
||||
enable_offload=bool(args_opt.opt_offload), parallel_config=parallel_config)
|
||||
|
||||
print("===config is: ", config, flush=True)
|
||||
pangu_alpha = PanguAlpha(config)
|
||||
loss = CrossEntropyLoss(config)
|
||||
pangu_alpha_with_loss_net = PipelineCell(PanguAlphaWithLoss(config, pangu_alpha, loss), config.micro_size)
|
||||
pangu_alpha = PanguAlphaModel(config=config)
|
||||
loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
|
||||
pangu_alpha_with_loss_net = PipelineCell(PanGUAlphaWithLoss(config, pangu_alpha, loss),
|
||||
config.parallel_config.micro_batch_num)
|
||||
pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
|
||||
print("=====args_opt is: ", args_opt, flush=True)
|
||||
lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
|
||||
|
@ -238,7 +249,8 @@ def run_train_pipeline(args_opt):
|
|||
else:
|
||||
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
|
||||
|
||||
ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
|
||||
ds = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=cache_url,
|
||||
device_num=stage_device_num,
|
||||
rank=rank_id % stage_device_num, eod_reset=True, data_start_index=0,
|
||||
full_batch=context.get_auto_parallel_context("full_batch"),
|
||||
column_name=args_opt.data_column_name)
|
||||
|
@ -246,19 +258,20 @@ def run_train_pipeline(args_opt):
|
|||
step_per_epoch = ds.get_dataset_size()
|
||||
callback_size = args_opt.sink_size
|
||||
actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
|
||||
callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, micro_size=config.micro_size)]
|
||||
callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id,
|
||||
micro_size=parallel_config.micro_batch_num)]
|
||||
loss_scale_value = math.pow(2, 32)
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
|
||||
pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
|
||||
pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell)
|
||||
if args_opt.train_and_eval_mode:
|
||||
ds_eval = create_dataset(config.batch_size // config.micro_size, data_path=eval_cache_url,
|
||||
ds_eval = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=eval_cache_url,
|
||||
device_num=stage_device_num, rank=rank_id % stage_device_num, eod_reset=True,
|
||||
data_start_index=0, full_batch=bool(args_opt.full_batch),
|
||||
column_name=args_opt.data_column_name,
|
||||
num_samples=args_opt.eval_steps * config.batch_size)
|
||||
ppl_metric = PPLMetric(config.seq_length)
|
||||
pangu_alpha_with_loss_eval_net = _VirtualDatasetCell(PanguAlphaWithLoss(config, pangu_alpha, loss))
|
||||
pangu_alpha_with_loss_eval_net = _VirtualDatasetCell(PanGUAlphaWithLoss(config, pangu_alpha, loss))
|
||||
model = Model(pangu_alpha_with_grads, eval_network=pangu_alpha_with_loss_eval_net, metrics={"ppl": ppl_metric})
|
||||
model.build(ds, ds_eval, sink_size=callback_size)
|
||||
eval_callback = EvalCallBack(model, ds_eval, ppl_metric)
|
||||
|
|
Loading…
Reference in New Issue