forked from mindspore-Ecosystem/mindspore
!22067 squad add gpu train and eval
Merge pull request !22067 from chenweitao_295/bert_squad_gpu
commit 0b11b86bd7
@@ -92,17 +92,17 @@ bash scripts/run_distributed_pretrain_ascend.sh /path/cn-wiki-128 /path/hccl.jso
 - Set bert network config and optimizer hyperparameters in `finetune_eval_config.py`.

 - Classification task: Set task related hyperparameters in scripts/run_classifier.sh.
-- Run `bash scripts/run_classifier.py` for fine-tuning of BERT-base and BERT-NEZHA model.
+- Run `bash scripts/run_classifier.sh` for fine-tuning of BERT-base and BERT-NEZHA model.

   bash scripts/run_classifier.sh

 - NER task: Set task related hyperparameters in scripts/run_ner.sh.
-- Run `bash scripts/run_ner.py` for fine-tuning of BERT-base and BERT-NEZHA model.
+- Run `bash scripts/run_ner.sh` for fine-tuning of BERT-base and BERT-NEZHA model.

   bash scripts/run_ner.sh

 - SQuAD task: Set task related hyperparameters in scripts/run_squad.sh.
-- Run `bash scripts/run_squad.py` for fine-tuning of BERT-base and BERT-NEZHA model.
+- Run `bash scripts/run_squad.sh` for fine-tuning of BERT-base and BERT-NEZHA model.

   bash scripts/run_squad.sh
 ```
@@ -121,19 +121,19 @@ bash scripts/run_distributed_pretrain_for_gpu.sh 8 40 /path/cn-wiki-128
 - Set bert network config and optimizer hyperparameters in `finetune_eval_config.py`.

 - Classification task: Set task related hyperparameters in scripts/run_classifier.sh.
-- Run `bash scripts/run_classifier.py` for fine-tuning of BERT-base and BERT-NEZHA model.
+- Run `bash scripts/run_classifier.sh` for fine-tuning of BERT-base and BERT-NEZHA model.

   bash scripts/run_classifier.sh

 - NER task: Set task related hyperparameters in scripts/run_ner.sh.
-- Run `bash scripts/run_ner.py` for fine-tuning of BERT-base and BERT-NEZHA model.
+- Run `bash scripts/run_ner.sh` for fine-tuning of BERT-base and BERT-NEZHA model.

   bash scripts/run_ner.sh

-- SQuAD task: Set task related hyperparameters in scripts/run_squad.sh.
-- Run `bash scripts/run_squad.py` for fine-tuning of BERT-base and BERT-NEZHA model.
+- SQuAD task: Set task related hyperparameters in scripts/run_squad_gpu.sh.
+- Run `bash scripts/run_squad_gpu.sh` for fine-tuning of BERT-base and BERT-NEZHA model.

-  bash scripts/run_squad.sh
+  bash scripts/run_squad_gpu.sh
 ```

 - running on ModelArts
@@ -268,7 +268,8 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol
 ├─README.md
 ├─run_classifier.sh                  # shell script for standalone classifier task on ascend or gpu
 ├─run_ner.sh                         # shell script for standalone NER task on ascend or gpu
-├─run_squad.sh                       # shell script for standalone SQUAD task on ascend or gpu
+├─run_squad.sh                       # shell script for standalone SQUAD task on ascend
+├─run_squad_gpu.sh                   # shell script for standalone SQUAD task on gpu
 ├─run_standalone_pretrain_ascend.sh  # shell script for standalone pretrain on ascend
 ├─run_distributed_pretrain_ascend.sh # shell script for distributed pretrain on ascend
 ├─run_distributed_pretrain_gpu.sh    # shell script for distributed pretrain on gpu
@@ -95,17 +95,17 @@ bash scripts/run_distributed_pretrain_ascend.sh /path/cn-wiki-128 /path/hccl.jso
 - Set the BERT network configuration and optimizer hyperparameters in `finetune_eval_config.py`.

 - Classification task: set the task-related hyperparameters in scripts/run_classifier.sh.
-- Run `bash scripts/run_classifier.py` to fine-tune the BERT-base and BERT-NEZHA models.
+- Run `bash scripts/run_classifier.sh` to fine-tune the BERT-base and BERT-NEZHA models.

   bash scripts/run_classifier.sh

 - NER task: set the task-related hyperparameters in scripts/run_ner.sh.
-- Run `bash scripts/run_ner.py` to fine-tune the BERT-base and BERT-NEZHA models.
+- Run `bash scripts/run_ner.sh` to fine-tune the BERT-base and BERT-NEZHA models.

   bash scripts/run_ner.sh

 - SQuAD task: set the task-related hyperparameters in scripts/run_squad.sh.
-- Run `bash scripts/run_squad.py` to fine-tune the BERT-base and BERT-NEZHA models.
+- Run `bash scripts/run_squad.sh` to fine-tune the BERT-base and BERT-NEZHA models.

   bash scripts/run_squad.sh
 ```
@@ -124,19 +124,19 @@ bash scripts/run_distributed_pretrain_for_gpu.sh 8 40 /path/cn-wiki-128
 - Set the BERT network configuration and optimizer hyperparameters in `finetune_eval_config.py`.

 - Classification task: set the task-related hyperparameters in scripts/run_classifier.sh.
-- Run `bash scripts/run_classifier.py` to fine-tune the BERT-base and BERT-NEZHA models.
+- Run `bash scripts/run_classifier.sh` to fine-tune the BERT-base and BERT-NEZHA models.

   bash scripts/run_classifier.sh

 - NER task: set the task-related hyperparameters in scripts/run_ner.sh.
-- Run `bash scripts/run_ner.py` to fine-tune the BERT-base and BERT-NEZHA models.
+- Run `bash scripts/run_ner.sh` to fine-tune the BERT-base and BERT-NEZHA models.

   bash scripts/run_ner.sh

-- SQuAD task: set the task-related hyperparameters in scripts/run_squad.sh.
-- Run `bash scripts/run_squad.py` to fine-tune the BERT-base and BERT-NEZHA models.
+- SQuAD task: set the task-related hyperparameters in scripts/run_squad_gpu.sh.
+- Run `bash scripts/run_squad_gpu.sh` to fine-tune the BERT-base and BERT-NEZHA models.

-  bash scripts/run_squad.sh
+  bash scripts/run_squad_gpu.sh
 ```

 - Running on ModelArts (to run on ModelArts, see the following documentation: [modelarts](https://support.huaweicloud.com/modelarts/))
@@ -266,7 +266,8 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol
 ├─README.md
 ├─run_classifier.sh                  # shell script for the standalone classifier task on Ascend or GPU
 ├─run_ner.sh                         # shell script for the standalone NER task on Ascend or GPU
-├─run_squad.sh                       # shell script for the standalone SQuAD task on Ascend or GPU
+├─run_squad.sh                       # shell script for the standalone SQuAD task on Ascend
+├─run_squad_gpu.sh                   # shell script for the standalone SQuAD task on GPU
 ├─run_standalone_pretrain_ascend.sh  # shell script for standalone pretraining on Ascend
 ├─run_distributed_pretrain_ascend.sh # shell script for distributed pretraining on Ascend
 ├─run_distributed_pretrain_gpu.sh    # shell script for distributed pretraining on GPU
@@ -78,7 +78,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
     param_dict = load_checkpoint(load_checkpoint_path)
     load_param_into_net(network, param_dict)

-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
+    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
     netwithgrads = BertSquadCell(network, optimizer=optimizer, scale_update_cell=update_cell)
     model = Model(netwithgrads)
     callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(dataset.get_dataset_size()), ckpoint_cb]
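This hunk only reformats `2**32` as `2 ** 32`; the arguments themselves encode the usual dynamic loss-scaling policy: start from a large scale, shrink it whenever gradients overflow, and grow it back after a window of clean steps. A minimal plain-Python sketch of that policy (an illustration of the scheme, not MindSpore's implementation):

```python
class DynamicLossScaleSketch:
    """Plain-Python sketch of the dynamic loss-scale policy configured above."""

    def __init__(self, loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000):
        self.scale = float(loss_scale_value)
        self.factor = scale_factor
        self.window = scale_window
        self.good_steps = 0

    def update(self, overflow: bool) -> float:
        if overflow:
            # Gradients contained inf/nan: shrink the scale and restart the window.
            self.scale = max(self.scale / self.factor, 1.0)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps == self.window:
                # A full window of clean steps: it is safe to grow the scale again.
                self.scale *= self.factor
                self.good_steps = 0
        return self.scale
```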
@@ -157,6 +157,7 @@ def run_squad():
         context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
     elif target == "GPU":
         context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+        context.set_context(enable_graph_kernel=True)
         if bert_net_cfg.compute_type != mstype.float32:
             logger.warning('GPU only support fp32 temporarily, run with fp32.')
             bert_net_cfg.compute_type = mstype.float32
@@ -0,0 +1,47 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash scripts/run_squad_gpu.sh"
+echo "for example: bash scripts/run_squad_gpu.sh"
+echo "assessment_method include: [Accuracy]"
+echo "=============================================================================================================="
+
+mkdir -p ms_log
+CUR_DIR=`pwd`
+PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
+export GLOG_log_dir=${CUR_DIR}/ms_log
+export GLOG_logtostderr=0
+python ${PROJECT_DIR}/../run_squad.py \
+    --config_path="../../task_squad_config.yaml" \
+    --device_target="GPU" \
+    --do_train="true" \
+    --do_eval="false" \
+    --device_id=0 \
+    --epoch_num=3 \
+    --num_class=2 \
+    --train_data_shuffle="true" \
+    --eval_data_shuffle="false" \
+    --train_batch_size=32 \
+    --eval_batch_size=1 \
+    --vocab_file_path="" \
+    --save_finetune_checkpoint_path="" \
+    --load_pretrain_checkpoint_path="" \
+    --load_finetune_checkpoint_path="" \
+    --train_data_file_path="" \
+    --eval_json_path="" \
+    --schema_file_path="" > squad_log.txt 2>&1 &
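The path flags in the new script are left empty as placeholders: fill in `--vocab_file_path`, `--train_data_file_path`, `--eval_json_path`, and the checkpoint paths before launching. The trailing `> squad_log.txt 2>&1 &` runs the job in the background and redirects all output, so progress can be followed with, for example, `tail -f squad_log.txt`.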
@@ -32,21 +32,26 @@ from .bert_for_pre_training import clip_grad
 from .finetune_eval_model import BertCLSModel, BertNERModel, BertSquadModel
 from .utils import CrossEntropyCalculation

 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
 grad_scale = C.MultitypeFuncGraph("grad_scale")
 reciprocal = P.Reciprocal()


 @grad_scale.register("Tensor", "Tensor")
 def tensor_grad_scale(scale, grad):
     return grad * reciprocal(scale)


+_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
+grad_overflow = P.FloatStatus()
+
+
+@_grad_overflow.register("Tensor")
+def _tensor_grad_overflow(grad):
+    return grad_overflow(grad)


 class BertFinetuneCell(nn.Cell):
     """
     Especially defined for finetuning where only four inputs tensor are needed.
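`grad_scale` and the newly added `_grad_overflow` are `MultitypeFuncGraph` entries that `hyper_map` applies element-wise across the tuple of gradients: the first multiplies each gradient by the reciprocal of the loss scale, the second runs `FloatStatus` on each tensor. A NumPy sketch of that mapping (an illustration under this reading, not framework code):

```python
import numpy as np

def tensor_grad_scale_np(scale, grad):
    # Undo the loss scale: grad * reciprocal(scale).
    return grad * (1.0 / scale)

def hyper_map_np(fn, grads):
    # hyper_map applies the registered per-tensor function to every gradient.
    return tuple(fn(g) for g in grads)

scale = 1024.0
grads = (np.full((2, 2), 2048.0), np.full((3,), 4096.0))
unscaled = hyper_map_np(lambda g: tensor_grad_scale_np(scale, g), grads)
assert unscaled[0][0, 0] == 2.0 and unscaled[1][0] == 4.0
```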
@@ -61,6 +66,7 @@ class BertFinetuneCell(nn.Cell):
         optimizer (Optimizer): Optimizer for updating the weights.
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """

     def __init__(self, network, optimizer, scale_update_cell=None):
         super(BertFinetuneCell, self).__init__(auto_prefix=False)
@@ -156,10 +162,12 @@ class BertFinetuneCell(nn.Cell):
         self.optimizer(grads)
         return (loss, cond)


 class BertSquadCell(nn.Cell):
     """
     specifically defined for finetuning where only four inputs tensor are needed.
     """

     def __init__(self, network, optimizer, scale_update_cell=None):
         super(BertSquadCell, self).__init__(auto_prefix=False)
         self.network = network
@@ -179,6 +187,13 @@ class BertSquadCell(nn.Cell):
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
         self.cast = P.Cast()
-        self.alloc_status = P.NPUAllocFloatStatus()
-        self.get_status = P.NPUGetFloatStatus()
-        self.clear_status = P.NPUClearFloatStatus()
+        self.gpu_target = False
+        if context.get_context("device_target") == "GPU":
+            self.gpu_target = True
+            self.float_status = P.FloatStatus()
+            self.addn = P.AddN()
+            self.reshape = P.Reshape()
+        else:
+            self.alloc_status = P.NPUAllocFloatStatus()
+            self.get_status = P.NPUGetFloatStatus()
+            self.clear_status = P.NPUClearFloatStatus()
@@ -202,7 +217,7 @@ class BertSquadCell(nn.Cell):
                   sens=None):
         """BertSquad"""
         weights = self.weights
-        init = self.alloc_status()
+        init = False
         loss = self.network(input_ids,
                             input_mask,
                             token_type_id,
@@ -214,6 +229,8 @@ class BertSquadCell(nn.Cell):
             scaling_sens = self.loss_scale
         else:
             scaling_sens = sens
-        init = F.depend(init, loss)
-        clear_status = self.clear_status(init)
-        scaling_sens = F.depend(scaling_sens, clear_status)
+        if not self.gpu_target:
+            init = self.alloc_status()
+            init = F.depend(init, loss)
+            clear_status = self.clear_status(init)
+            scaling_sens = F.depend(scaling_sens, clear_status)
@@ -230,10 +247,15 @@ class BertSquadCell(nn.Cell):
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         if self.reducer_flag:
             grads = self.grad_reducer(grads)
-        init = F.depend(init, grads)
-        get_status = self.get_status(init)
-        init = F.depend(init, get_status)
-        flag_sum = self.reduce_sum(init, (0,))
+        if not self.gpu_target:
+            init = F.depend(init, grads)
+            get_status = self.get_status(init)
+            init = F.depend(init, get_status)
+            flag_sum = self.reduce_sum(init, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
+            flag_sum = self.addn(flag_sum)
+            flag_sum = self.reshape(flag_sum, (()))
         if self.is_distributed:
             flag_reduce = self.allreduce(flag_sum)
             cond = self.less_equal(self.base, flag_reduce)
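On Ascend the overflow flag comes from the NPU float-status registers (`alloc_status`/`get_status` plus a `reduce_sum`), while on GPU it is computed from the gradients themselves: `FloatStatus` per tensor, `AddN` over the flags, `Reshape` to a scalar, and `less_equal(base, flag_sum)` as the overflow condition. A NumPy sketch of the GPU branch (assuming `base` is a scalar 1, which this hunk does not show):

```python
import numpy as np

def float_status(x):
    # FloatStatus: a one-element flag, nonzero when the tensor contains inf/nan.
    return np.array([0.0 if np.all(np.isfinite(x)) else 1.0])

def gpu_overflow(grads, base=1.0):
    flag_sum = np.add.reduce([float_status(g) for g in grads])  # AddN over flags
    flag_sum = flag_sum.reshape(())                             # reshape(flag_sum, (()))
    return bool(base <= flag_sum)                               # cond = less_equal(base, flag_sum)

assert gpu_overflow((np.ones((2, 3)), np.array([0.5, np.nan])))  # nan trips the flag
assert not gpu_overflow((np.ones(4),))
```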
@@ -246,10 +268,12 @@ class BertSquadCell(nn.Cell):
         self.optimizer(grads)
         return (loss, cond)


 class BertCLS(nn.Cell):
     """
     Train interface for classification finetuning task.
     """

     def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False,
                  assessment_method=""):
         super(BertCLS, self).__init__()
@@ -259,6 +283,7 @@ class BertCLS(nn.Cell):
         self.num_labels = num_labels
         self.assessment_method = assessment_method
+        self.is_training = is_training

     def construct(self, input_ids, input_mask, token_type_id, label_ids):
         logits = self.bert(input_ids, input_mask, token_type_id)
         if self.assessment_method == "spearman_correlation":
@@ -275,6 +300,7 @@ class BertNER(nn.Cell):
     """
     Train interface for sequence labeling finetuning task.
     """

     def __init__(self, config, batch_size, is_training, num_labels=11, use_crf=False,
                  tag_to_index=None, dropout_prob=0.0, use_one_hot_embeddings=False):
         super(BertNER, self).__init__()
@@ -288,6 +314,7 @@ class BertNER(nn.Cell):
         self.loss = CrossEntropyCalculation(is_training)
         self.num_labels = num_labels
         self.use_crf = use_crf

     def construct(self, input_ids, input_mask, token_type_id, label_ids):
         logits = self.bert(input_ids, input_mask, token_type_id)
         if self.use_crf:
@@ -296,10 +323,12 @@ class BertNER(nn.Cell):
             loss = self.loss(logits, label_ids, self.num_labels)
         return loss


 class BertSquad(nn.Cell):
     '''
     Train interface for SQuAD finetuning task.
     '''

     def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
         super(BertSquad, self).__init__()
         self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
@@ -20,14 +20,18 @@ Bert finetune and evaluation model script.
 import mindspore.nn as nn
 from mindspore.common.initializer import TruncatedNormal
 from mindspore.ops import operations as P
+from mindspore import context
 from .bert_model import BertModel


 class BertCLSModel(nn.Cell):
     """
     This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
     LCQMC(num_labels=2), Chnsenti(num_labels=2). The returned output represents the final
     logits as the results of log_softmax is proportional to that of softmax.
     """

     def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False,
                  assessment_method=""):
         super(BertCLSModel, self).__init__()
@@ -46,8 +50,7 @@ class BertCLSModel(nn.Cell):
         self.assessment_method = assessment_method

     def construct(self, input_ids, input_mask, token_type_id):
-        _, pooled_output, _ = \
-            self.bert(input_ids, token_type_id, input_mask)
+        _, pooled_output, _ = self.bert(input_ids, token_type_id, input_mask)
         cls = self.cast(pooled_output, self.dtype)
         cls = self.dropout(cls)
         logits = self.dense_1(cls)
@@ -56,10 +59,12 @@ class BertCLSModel(nn.Cell):
             logits = self.log_softmax(logits)
         return logits


 class BertSquadModel(nn.Cell):
     '''
     This class is responsible for SQuAD
     '''

     def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
         super(BertSquadModel, self).__init__()
         if not is_training:
@@ -73,22 +78,36 @@ class BertSquadModel(nn.Cell):
         self.dtype = config.dtype
         self.log_softmax = P.LogSoftmax(axis=1)
         self.is_training = is_training
+        self.gpu_target = context.get_context("device_target") == "GPU"
+        self.cast = P.Cast()
+        self.reshape = P.Reshape()
+        self.transpose = P.Transpose()
+        self.shape = (-1, config.hidden_size)
+        self.origin_shape = (-1, config.seq_length, self.num_labels)
+        self.transpose_shape = (-1, self.num_labels, config.seq_length)

     def construct(self, input_ids, input_mask, token_type_id):
+        """Return the final logits as the results of log_softmax."""
         sequence_output, _, _ = self.bert(input_ids, token_type_id, input_mask)
-        batch_size, seq_length, hidden_size = P.Shape()(sequence_output)
-        sequence = P.Reshape()(sequence_output, (-1, hidden_size))
+        sequence = self.reshape(sequence_output, self.shape)
         logits = self.dense1(sequence)
-        logits = P.Cast()(logits, self.dtype)
-        logits = P.Reshape()(logits, (batch_size, seq_length, self.num_labels))
-        logits = self.log_softmax(logits)
+        logits = self.cast(logits, self.dtype)
+        logits = self.reshape(logits, self.origin_shape)
+        if self.gpu_target:
+            logits = self.transpose(logits, (0, 2, 1))
+            logits = self.log_softmax(self.reshape(logits, (-1, self.transpose_shape[-1])))
+            logits = self.transpose(self.reshape(logits, self.transpose_shape), (0, 2, 1))
+        else:
+            logits = self.log_softmax(logits)
         return logits


 class BertNERModel(nn.Cell):
     """
     This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11).
     The returned output represents the final logits as the results of log_softmax is proportional to that of softmax.
     """

     def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0,
                  use_one_hot_embeddings=False):
         super(BertNERModel, self).__init__()
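The GPU branch of `BertSquadModel.construct` above avoids applying `LogSoftmax(axis=1)` to the 3-D logits tensor on GPU: it moves `seq_length` to the last axis, flattens to 2-D, applies log-softmax there, and transposes back, which is mathematically the same as applying log-softmax over axis 1 directly. A NumPy sketch checking that equivalence (illustrative shapes, not the model's real ones):

```python
import numpy as np

def log_softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

batch, seq_length, num_labels = 2, 5, 2
logits = np.random.randn(batch, seq_length, num_labels)

# Ascend branch: log-softmax over the seq_length axis directly.
direct = log_softmax(logits, axis=1)

# GPU branch: transpose -> reshape to (-1, seq_length) -> log-softmax -> undo.
t = np.transpose(logits, (0, 2, 1)).reshape(-1, seq_length)
gpu = np.transpose(log_softmax(t, axis=-1).reshape(-1, num_labels, seq_length), (0, 2, 1))

assert np.allclose(direct, gpu)
```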
@@ -111,8 +130,7 @@ class BertNERModel(nn.Cell):

     def construct(self, input_ids, input_mask, token_type_id):
         """Return the final logits as the results of log_softmax."""
-        sequence_output, _, _ = \
-            self.bert(input_ids, token_type_id, input_mask)
+        sequence_output, _, _ = self.bert(input_ids, token_type_id, input_mask)
         seq = self.dropout(sequence_output)
         seq = self.reshape(seq, self.shape)
         logits = self.dense_1(seq)
@@ -36,7 +36,7 @@ export_file_name: 'bert_squad'
 file_format: 'AIR'

 optimizer_cfg:
-    optimizer: 'Lamb'
+    optimizer: 'Momentum'
     AdamWeightDecay:
         learning_rate: 0.0001 # 1e-4
         end_learning_rate: 0.00000000001 # 1e-11