!10507 Modify dataset preprocess of textcnn
From: @zhouyaqiang0 Reviewed-by: @linqingke, @oacjiewen Signed-off-by: @linqingke
Commit 0903777d3f
README.md
@@ -104,6 +104,7 @@ Parameters for both training and evaluation can be set in config.py
 'checkpoint_path': './train_textcnn.ckpt' # the absolute full path to save the checkpoint file
 'word_len': 51 # The length of the word
 'vec_length': 40 # The length of the vector
+'base_lr': 1e-3 # The base learning rate
 ```

 For more configuration details, please refer the script `config.py`.
eval.py
@@ -24,20 +24,29 @@ from mindspore import context
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net

-from src.config import cfg
+from src.config import cfg_mr, cfg_subj, cfg_sst2
 from src.textcnn import TextCNN
-from src.dataset import MovieReview
+from src.dataset import MovieReview, SST2, Subjectivity

 parser = argparse.ArgumentParser(description='TextCNN')
 parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
+parser.add_argument('--dataset', type=str, default="MR", choices=['MR', 'SUBJ', 'SST2'])
 args_opt = parser.parse_args()

 if __name__ == '__main__':
+    if args_opt.dataset == 'MR':
+        cfg = cfg_mr
+        instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
+    elif args_opt.dataset == 'SUBJ':
+        cfg = cfg_subj
+        instance = Subjectivity(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
+    elif args_opt.dataset == 'SST2':
+        cfg = cfg_sst2
+        instance = SST2(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
     device_target = cfg.device_target
     context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
     if device_target == "Ascend":
         context.set_context(device_id=cfg.device_id)
-    instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
     dataset = instance.create_test_dataset(batch_size=cfg.batch_size)
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
     net = TextCNN(vocab_len=instance.get_dict_len(), word_len=cfg.word_len,
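With the new flag in place, evaluation is selected per dataset at launch time, for example `python eval.py --checkpoint_path=<checkpoint file> --dataset=SUBJ` (the checkpoint path here is only a placeholder for the file produced by training).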
run_eval.sh
@@ -13,5 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================

-python eval.py --checkpoint_path="$1" > eval.log 2>&1 &
+dataset_type='MR'
+if [ $# == 2 ]
+then
+    if [ $2 != "MR" ] && [ $2 != "SUBJ" ] && [ $2 != "SST2" ]
+    then
+        echo "error: the selected dataset is not in supported set{MR, SUBJ, SST2}"
+        exit 1
+    fi
+    dataset_type=$2
+fi
+python eval.py --checkpoint_path="$1" --dataset=$dataset_type > eval.log 2>&1 &
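The wrapper keeps its old one-argument form (which now defaults to MR) and accepts the dataset as an optional second argument, e.g. `bash run_eval.sh <checkpoint file> SST2`, assuming the script keeps its conventional run_eval.sh name.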
run_train.sh
@@ -13,5 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================

-python train.py > train.log 2>&1 &
+dataset_type='MR'
+if [ $# == 1 ]
+then
+    if [ $1 != "MR" ] && [ $1 != "SUBJ" ] && [ $1 != "SST2" ]
+    then
+        echo "error: the selected dataset is not in supported set{MR, SUBJ, SST2}"
+        exit 1
+    fi
+    dataset_type=$1
+fi
+rm ./ckpt_0 -rf
+python train.py --dataset=$dataset_type > train.log 2>&1 &
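Training works the same way with a single optional argument, e.g. `bash run_train.sh SUBJ` (again assuming the conventional run_train.sh name); note that the script now also removes any stale ./ckpt_0 directory before relaunching.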
src/config.py
@@ -17,7 +17,7 @@ network config setting, will be used in main.py
 """
 from easydict import EasyDict as edict

-cfg = edict({
+cfg_mr = edict({
     'name': 'movie review',
     'pre_trained': False,
     'num_classes': 2,
@@ -30,5 +30,40 @@ cfg = edict({
     'keep_checkpoint_max': 1,
     'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt',
     'word_len': 51,
-    'vec_length': 40
+    'vec_length': 40,
+    'base_lr': 1e-3
 })
+
+cfg_subj = edict({
+    'name': 'subjectivity',
+    'pre_trained': False,
+    'num_classes': 2,
+    'batch_size': 64,
+    'epoch_size': 5,
+    'weight_decay': 3e-5,
+    'data_path': './Subj/',
+    'device_target': 'Ascend',
+    'device_id': 7,
+    'keep_checkpoint_max': 1,
+    'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt',
+    'word_len': 51,
+    'vec_length': 40,
+    'base_lr': 8e-4
+})
+
+cfg_sst2 = edict({
+    'name': 'SST2',
+    'pre_trained': False,
+    'num_classes': 2,
+    'batch_size': 64,
+    'epoch_size': 4,
+    'weight_decay': 3e-5,
+    'data_path': './SST-2/',
+    'device_target': 'Ascend',
+    'device_id': 7,
+    'keep_checkpoint_max': 1,
+    'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt',
+    'word_len': 51,
+    'vec_length': 40,
+    'base_lr': 5e-3
+})
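Each dataset now has its own EasyDict with the same field names, so only values such as `base_lr`, `epoch_size` and `data_path` differ between MR, SUBJ and SST2. A small self-contained sketch of how these configs behave; the dicts below are trimmed stand-ins and the name-to-config dictionary is only an illustration (eval.py and train.py actually select the config with an if/elif chain):

```python
from easydict import EasyDict as edict

# Trimmed stand-ins for cfg_mr / cfg_subj / cfg_sst2; only a few fields are
# shown and the MR data_path is invented for this demo.
cfg_mr = edict({'name': 'movie review', 'base_lr': 1e-3, 'data_path': './data/'})
cfg_subj = edict({'name': 'subjectivity', 'base_lr': 8e-4, 'data_path': './Subj/'})
cfg_sst2 = edict({'name': 'SST2', 'base_lr': 5e-3, 'data_path': './SST-2/'})

configs = {'MR': cfg_mr, 'SUBJ': cfg_subj, 'SST2': cfg_sst2}
cfg = configs['SUBJ']
print(cfg.name, cfg.base_lr)   # subjectivity 0.0008
```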
src/dataset.py
@@ -21,6 +21,7 @@ import random
 import codecs
 from pathlib import Path
 import numpy as np
+import pandas as pd
 import mindspore.dataset as ds


@@ -36,9 +37,56 @@ class Generator():
         return len(self.input_list)


-class MovieReview:
+class DataProcessor:
     """
-    preprocess MR dataset
+    preprocess dataset
     """
+    def get_dict_len(self):
+        """
+        get number of different words in the whole dataset
+        """
+        if self.doConvert:
+            return len(self.Vocab)
+        return -1
+
+    def collect_weight(self, glove_path, embed_size):
+        """ collect weight """
+        vocab_size = self.get_dict_len()
+        embedding_index = {}
+        f = open(glove_path)
+        for line in f:
+            values = line.split()
+            word = values[0]
+            vec = np.array(values[1:], dtype='float32')
+            embedding_index[word] = vec
+        weight_np = np.zeros((vocab_size, embed_size)).astype(np.float32)
+
+        for word, vec in embedding_index.items():
+            try:
+                index = self.Vocab[word]
+            except KeyError:
+                continue
+            weight_np[index, :] = vec
+        return weight_np
+
+    def create_train_dataset(self, epoch_size, batch_size, collect_weight=False, glove_path='', embed_size=50):
+        if collect_weight:
+            weight_np = self.collect_weight(glove_path, embed_size)
+            np.savetxt('./weight.txt', weight_np)
+        dataset = ds.GeneratorDataset(source=Generator(input_list=self.train),
+                                      column_names=["data", "label"], shuffle=False)
+        dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
+        return dataset
+
+    def create_test_dataset(self, batch_size):
+        dataset = ds.GeneratorDataset(source=Generator(input_list=self.test),
+                                      column_names=["data", "label"], shuffle=False)
+        dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
+        return dataset
+
+class MovieReview(DataProcessor):
+    """
+    preprocess MovieReview dataset
+    """
     def __init__(self, root_dir, maxlen, split):
         """
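`DataProcessor.collect_weight` parses a GloVe-style text file (one `<word> v1 v2 ... vN` entry per line) and fills a `(vocab_size, embed_size)` matrix row by row using the vocabulary built by `text2vec`; words without a pre-trained vector keep all-zero rows. A minimal self-contained sketch of that lookup-and-fill logic, with an invented vocabulary and two fake GloVe lines:

```python
import numpy as np

vocab = {"good": 0, "movie": 1, "boring": 2}   # invented word -> index map, as text2vec would build
embed_size = 4                                 # collect_weight defaults to 50; small here for readability
glove_lines = ["good 0.1 0.2 0.3 0.4",         # fake GloVe entries: "<word> v1 v2 ... vN"
               "movie 0.5 0.6 0.7 0.8"]

embedding_index = {}
for line in glove_lines:
    values = line.split()
    embedding_index[values[0]] = np.array(values[1:], dtype="float32")

weight_np = np.zeros((len(vocab), embed_size), dtype=np.float32)
for word, vec in embedding_index.items():
    if word in vocab:                          # same effect as the try/except KeyError above
        weight_np[vocab[word], :] = vec

print(weight_np)   # rows 0 and 1 carry the fake vectors, row 2 ("boring") stays zero
```

When `create_train_dataset(..., collect_weight=True, glove_path=..., embed_size=...)` is called, the matrix produced this way is additionally written to `./weight.txt`.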
src/dataset.py
@@ -151,7 +199,6 @@ class MovieReview:
         # Vocab = {word : index}
         self.Vocab = dict()

-        # self.Vocab['None']
         for SentenceLabel in self.Pos+self.Neg:
             vector = [0]*maxlen
             for index, word in enumerate(SentenceLabel[0]):
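`text2vec` grows the vocabulary on the fly while converting each tokenised sentence into a fixed-length index vector. A toy self-contained walk-through of the same indexing scheme (the sentences and `maxlen` are made up):

```python
maxlen = 5
sentences = [["a", "good", "movie"], ["a", "dull", "movie"]]   # invented, already tokenised

Vocab = {}           # word -> index, grown on first sight
vectors = []
for words in sentences:
    vector = [0] * maxlen        # note: the padding value 0 is also the id of the first word seen
    for index, word in enumerate(words):
        if index >= maxlen:
            break
        if word not in Vocab:
            Vocab[word] = len(Vocab)
        vector[index] = Vocab[word]
    vectors.append(vector)

print(Vocab)    # {'a': 0, 'good': 1, 'movie': 2, 'dull': 3}
print(vectors)  # [[0, 1, 2, 0, 0], [0, 3, 2, 0, 0]]
```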
src/dataset.py
@@ -185,27 +232,256 @@ class MovieReview:
         self.train = [i for item in pos_temp+neg_temp for i in item]

         random.shuffle(self.train)
-        # random.shuffle(self.test)

-    def get_dict_len(self):
-        """
-        get number of different words in the whole dataset
-        """
-        if self.doConvert:
-            return len(self.Vocab)
-        return -1
-        #else:
-        #    print("Haven't finished Text2Vec")
-        #    return -1
-
-    def create_train_dataset(self, epoch_size, batch_size):
-        dataset = ds.GeneratorDataset(source=Generator(input_list=self.train),
-                                      column_names=["data", "label"], shuffle=False)
-        dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
-        return dataset
-
-    def create_test_dataset(self, batch_size):
-        dataset = ds.GeneratorDataset(source=Generator(input_list=self.test),
-                                      column_names=["data", "label"], shuffle=False)
-        dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
-        return dataset
+
+class Subjectivity(DataProcessor):
+    """
+    preprocess Subjectivity dataset
+    """
+    def __init__(self, root_dir, maxlen, split):
+        self.path = root_dir
+        self.feelMap = {
+            'neg': 0,
+            'pos': 1
+        }
+        self.files = []
+        self.doConvert = False
+        mypath = Path(self.path)
+
+        if not mypath.exists() or not mypath.is_dir():
+            print("please check the root_dir!")
+            raise ValueError
+
+        # walk through the root_dir
+        for root, _, filename in os.walk(self.path):
+            for each in filename:
+                self.files.append(os.path.join(root, each))
+            break
+
+        # begin to read data
+        self.word_num = 0
+        self.maxlen = 0
+        self.minlen = float("inf")
+        self.maxlen = float("-inf")
+        self.Pos = []
+        self.Neg = []
+        for filename in self.files:
+            self.read_data(filename)
+        self.PosNeg = self.Pos + self.Neg
+        self.text2vec(maxlen=maxlen)
+        self.split_dataset(split=split)
+
+    def read_data(self, filePath):
+        """
+        read text into memory
+
+        input:
+            filePath: the path where the data is stored in
+        """
+        with open(filePath, 'r', encoding="ISO-8859-1") as f:
+            for sentence in f.readlines():
+                sentence = sentence.replace('\n', '')\
+                    .replace('"', '')\
+                    .replace('\'', '')\
+                    .replace('.', '')\
+                    .replace(',', '')\
+                    .replace('[', '')\
+                    .replace(']', '')\
+                    .replace('(', '')\
+                    .replace(')', '')\
+                    .replace(':', '')\
+                    .replace('--', '')\
+                    .replace('-', '')\
+                    .replace('\\', '')\
+                    .replace('0', '')\
+                    .replace('1', '')\
+                    .replace('2', '')\
+                    .replace('3', '')\
+                    .replace('4', '')\
+                    .replace('5', '')\
+                    .replace('6', '')\
+                    .replace('7', '')\
+                    .replace('8', '')\
+                    .replace('9', '')\
+                    .replace('`', '')\
+                    .replace('=', '')\
+                    .replace('$', '')\
+                    .replace('/', '')\
+                    .replace('*', '')\
+                    .replace(';', '')\
+                    .replace('<b>', '')\
+                    .replace('%', '')
+                sentence = sentence.split(' ')
+                sentence = list(filter(lambda x: x, sentence))
+                if sentence:
+                    self.word_num += len(sentence)
+                    self.maxlen = self.maxlen if self.maxlen >= len(sentence) else len(sentence)
+                    self.minlen = self.minlen if self.minlen <= len(sentence) else len(sentence)
+                    if 'quote' in filePath:
+                        self.Pos.append([sentence, self.feelMap['pos']])
+                    elif 'plot' in filePath:
+                        self.Neg.append([sentence, self.feelMap['neg']])
+
+    def text2vec(self, maxlen):
+        """
+        convert the sentence into a vector in an int type
+
+        input:
+            maxlen: max length of the sentence
+        """
+        # Vocab = {word : index}
+        self.Vocab = dict()
+
+        for SentenceLabel in self.Pos+self.Neg:
+            vector = [0]*maxlen
+            for index, word in enumerate(SentenceLabel[0]):
+                if index >= maxlen:
+                    break
+                if word not in self.Vocab.keys():
+                    self.Vocab[word] = len(self.Vocab)
+                    vector[index] = len(self.Vocab) - 1
+                else:
+                    vector[index] = self.Vocab[word]
+            SentenceLabel[0] = vector
+        self.doConvert = True
+
+    def split_dataset(self, split):
+        """
+        split the dataset into training set and test set
+        input:
+            split: the ratio of training set to test set
+            rank: logic order
+            size: device num
+        """
+        trunk_pos_size = math.ceil((1-split)*len(self.Pos))
+        trunk_neg_size = math.ceil((1-split)*len(self.Neg))
+        trunk_num = int(1/(1-split))
+        pos_temp = list()
+        neg_temp = list()
+        for index in range(trunk_num):
+            pos_temp.append(self.Pos[index*trunk_pos_size:(index+1)*trunk_pos_size])
+            neg_temp.append(self.Neg[index*trunk_neg_size:(index+1)*trunk_neg_size])
+        self.test = pos_temp.pop(2)+neg_temp.pop(2)
+        self.train = [i for item in pos_temp+neg_temp for i in item]
+
+        random.shuffle(self.train)
+
+
+class SST2(DataProcessor):
+    """
+    preprocess SST2 dataset
+    """
+    def __init__(self, root_dir, maxlen, split):
+        self.path = root_dir
+        self.files = []
+        self.train = []
+        self.test = []
+        self.doConvert = False
+        mypath = Path(self.path)
+
+        if not mypath.exists() or not mypath.is_dir():
+            print("please check the root_dir!")
+            raise ValueError
+
+        # walk through the root_dir
+        for root, _, filename in os.walk(self.path):
+            for each in filename:
+                self.files.append(os.path.join(root, each))
+            break
+
+        # begin to read data
+        self.word_num = 0
+        self.maxlen = 0
+        self.minlen = float("inf")
+        self.maxlen = float("-inf")
+        for filename in self.files:
+            if 'train' in filename or 'dev' in filename:
+                f = codecs.open(filename, 'r')
+                ff = f.read()
+                file_object = codecs.open(filename, 'w', 'utf-8')
+                file_object.write(ff)
+                self.read_data(filename)
+        self.text2vec(maxlen=maxlen)
+        self.split_dataset(split=split)
+
+    def read_data(self, filePath):
+        """
+        read text into memory
+
+        input:
+            filePath: the path where the data is stored in
+        """
+        df = pd.read_csv(filePath, delimiter='\t')
+        for sentence, label in zip(df['sentence'], df['label']):
+            sentence = sentence.replace('\n', '')\
+                .replace('"', '')\
+                .replace('\'', '')\
+                .replace('.', '')\
+                .replace(',', '')\
+                .replace('[', '')\
+                .replace(']', '')\
+                .replace('(', '')\
+                .replace(')', '')\
+                .replace(':', '')\
+                .replace('--', '')\
+                .replace('-', '')\
+                .replace('\\', '')\
+                .replace('0', '')\
+                .replace('1', '')\
+                .replace('2', '')\
+                .replace('3', '')\
+                .replace('4', '')\
+                .replace('5', '')\
+                .replace('6', '')\
+                .replace('7', '')\
+                .replace('8', '')\
+                .replace('9', '')\
+                .replace('`', '')\
+                .replace('=', '')\
+                .replace('$', '')\
+                .replace('/', '')\
+                .replace('*', '')\
+                .replace(';', '')\
+                .replace('<b>', '')\
+                .replace('%', '')
+            sentence = sentence.split(' ')
+            sentence = list(filter(lambda x: x, sentence))
+            if sentence:
+                self.word_num += len(sentence)
+                self.maxlen = self.maxlen if self.maxlen >= len(sentence) else len(sentence)
+                self.minlen = self.minlen if self.minlen <= len(sentence) else len(sentence)
+                if 'train' in filePath:
+                    self.train.append([sentence, label])
+                elif 'dev' in filePath:
+                    self.test.append([sentence, label])
+
+    def text2vec(self, maxlen):
+        """
+        convert the sentence into a vector in an int type
+
+        input:
+            maxlen: max length of the sentence
+        """
+        # Vocab = {word : index}
+        self.Vocab = dict()
+
+        for SentenceLabel in self.train+self.test:
+            vector = [0]*maxlen
+            for index, word in enumerate(SentenceLabel[0]):
+                if index >= maxlen:
+                    break
+                if word not in self.Vocab.keys():
+                    self.Vocab[word] = len(self.Vocab)
+                    vector[index] = len(self.Vocab) - 1
+                else:
+                    vector[index] = self.Vocab[word]
+            SentenceLabel[0] = vector
+        self.doConvert = True
+
+    def split_dataset(self, split):
+        """
+        split the dataset into training set and test set
+        input:
+            split: the ratio of training set to test set
+            rank: logic order
+            size: device num
+        """
+        random.shuffle(self.train)
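For MovieReview and the new Subjectivity class, `split_dataset` slices the positive and negative lists into `int(1/(1-split))` chunks of `ceil((1-split)*len(...))` sentences and keeps chunk index 2 of each polarity as the test set; SST2 already ships separate train/dev files, so its override only shuffles. The arithmetic for `split=0.9`, with invented list sizes:

```python
import math

split = 0.9
pos, neg = list(range(1000)), list(range(900))        # invented sentence counts

trunk_pos_size = math.ceil((1 - split) * len(pos))    # 100
trunk_neg_size = math.ceil((1 - split) * len(neg))    # 90
trunk_num = int(1 / (1 - split))                      # 10 chunks

pos_chunks = [pos[i * trunk_pos_size:(i + 1) * trunk_pos_size] for i in range(trunk_num)]
neg_chunks = [neg[i * trunk_neg_size:(i + 1) * trunk_neg_size] for i in range(trunk_num)]

test = pos_chunks.pop(2) + neg_chunks.pop(2)          # the third chunk of each polarity
train = [x for chunk in pos_chunks + neg_chunks for x in chunk]
print(len(train), len(test))                          # 1710 190
```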
src/textcnn.py
@@ -97,14 +97,14 @@ class TextCNN(nn.Cell):
     """
     TextCNN architecture
     """
-    def __init__(self, vocab_len, word_len, num_classes, vec_length):
+    def __init__(self, vocab_len, word_len, num_classes, vec_length, embedding_table='uniform'):
         super(TextCNN, self).__init__()
         self.vec_length = vec_length
         self.word_len = word_len
         self.num_classes = num_classes

         self.unsqueeze = P.ExpandDims()
-        self.embedding = nn.Embedding(vocab_len, self.vec_length, embedding_table='uniform')
+        self.embedding = nn.Embedding(vocab_len, self.vec_length, embedding_table=embedding_table)

         self.slice = P.Slice()
         self.layer1 = self.make_layer(kernel_height=3)
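The new `embedding_table` argument is forwarded unchanged to `nn.Embedding`, so a pre-trained matrix can replace the default `'uniform'` initializer. A hedged sketch, assuming `nn.Embedding` accepts a `Tensor` initializer here and that a `./weight.txt` produced by an earlier `create_train_dataset(collect_weight=True, ...)` run exists; `word_len` and `num_classes` simply mirror the MR config:

```python
import numpy as np
from mindspore import Tensor
from src.textcnn import TextCNN

weight_np = np.loadtxt('./weight.txt', dtype=np.float32)   # shape (vocab_len, embed_size)
vocab_len, vec_length = weight_np.shape

net = TextCNN(vocab_len=vocab_len, word_len=51, num_classes=2, vec_length=vec_length,
              embedding_table=Tensor(weight_np))
```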
train.py
@@ -26,15 +26,16 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net

-from src.config import cfg
+from src.config import cfg_mr, cfg_subj, cfg_sst2
 from src.textcnn import TextCNN
 from src.textcnn import SoftmaxCrossEntropyExpand
-from src.dataset import MovieReview
+from src.dataset import MovieReview, SST2, Subjectivity

 parser = argparse.ArgumentParser(description='TextCNN')
 parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                     help='device where the code will be implemented (default: Ascend)')
 parser.add_argument('--device_id', type=int, default=5, help='device id of GPU or Ascend.')
+parser.add_argument('--dataset', type=str, default="MR", choices=['MR', 'SUBJ', 'SST2'])
 args_opt = parser.parse_args()

 if __name__ == '__main__':
train.py
@@ -42,16 +43,25 @@ if __name__ == '__main__':
     # set context
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
     context.set_context(device_id=args_opt.device_id)
+    if args_opt.dataset == 'MR':
+        cfg = cfg_mr
+        instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
+    elif args_opt.dataset == 'SUBJ':
+        cfg = cfg_subj
+        instance = Subjectivity(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
+    elif args_opt.dataset == 'SST2':
+        cfg = cfg_sst2
+        instance = SST2(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)

-    instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
     dataset = instance.create_train_dataset(batch_size=cfg.batch_size, epoch_size=cfg.epoch_size)
     batch_num = dataset.get_dataset_size()

+    base_lr = cfg.base_lr
     learning_rate = []
-    warm_up = [1e-3 / math.floor(cfg.epoch_size / 5) * (i + 1) for _ in range(batch_num) for i in
+    warm_up = [base_lr / math.floor(cfg.epoch_size / 5) * (i + 1) for _ in range(batch_num) for i in
                range(math.floor(cfg.epoch_size / 5))]
-    shrink = [1e-3 / (16 * (i + 1)) for _ in range(batch_num) for i in range(math.floor(cfg.epoch_size * 3 / 5))]
-    normal_run = [1e-3 for _ in range(batch_num) for i in
+    shrink = [base_lr / (16 * (i + 1)) for _ in range(batch_num) for i in range(math.floor(cfg.epoch_size * 3 / 5))]
+    normal_run = [base_lr for _ in range(batch_num) for i in
                   range(cfg.epoch_size - math.floor(cfg.epoch_size / 5) - math.floor(cfg.epoch_size * 2 / 5))]
     learning_rate = learning_rate + warm_up + normal_run + shrink
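Apart from reading `base_lr` from the selected config, the schedule is unchanged: a warm-up segment that steps up to `base_lr`, a flat segment at `base_lr`, then a shrinking segment at `base_lr/(16*(k+1))`, all expanded to one value per training step. A self-contained sketch with illustrative numbers (`epoch_size=5` and `base_lr=8e-4` match `cfg_subj`; `batch_num` is invented, the real value comes from `dataset.get_dataset_size()`):

```python
import math

base_lr = 8e-4    # cfg_subj.base_lr
epoch_size = 5    # cfg_subj.epoch_size
batch_num = 4     # invented steps-per-epoch for the demo

warm_up = [base_lr / math.floor(epoch_size / 5) * (i + 1)
           for _ in range(batch_num) for i in range(math.floor(epoch_size / 5))]
shrink = [base_lr / (16 * (i + 1))
          for _ in range(batch_num) for i in range(math.floor(epoch_size * 3 / 5))]
normal_run = [base_lr for _ in range(batch_num)
              for i in range(epoch_size - math.floor(epoch_size / 5) - math.floor(epoch_size * 2 / 5))]
learning_rate = warm_up + normal_run + shrink

print(len(warm_up), len(normal_run), len(shrink))   # 4 8 12
print(learning_rate[0], learning_rate[-1])          # 0.0008 1.6666666666666667e-05
```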