forked from OSSInnovation/mindspore
!5239 reduce cyclomatic complexity in model zoo
Merge pull request !5239 from zhaoting/master
commit c95ed54fe1

@@ -17,6 +17,120 @@ import math
import numpy as np


def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps):
    """
    Applies three steps decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            if i < decay_epoch_index[0]:
                lr = lr_max
            elif i < decay_epoch_index[1]:
                lr = lr_max * 0.1
            elif i < decay_epoch_index[2]:
                lr = lr_max * 0.01
            else:
                lr = lr_max * 0.001
        lr_each_step.append(lr)
    return lr_each_step


def _generate_exponential_lr(lr_init, lr_max, total_steps, warmup_steps, steps_per_epoch):
    """
    Applies exponential decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.
        steps_per_epoch(int): steps of one epoch

    Returns:
        np.array, learning rate array.
    """
    lr_each_step = []
    if warmup_steps != 0:
        inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
    else:
        inc_each_step = 0
    for i in range(total_steps):
        if i < warmup_steps:
            lr = float(lr_init) + inc_each_step * float(i)
        else:
            decay_nums = math.floor((float(i - warmup_steps) / steps_per_epoch) / 2)
            decay_rate = pow(0.94, decay_nums)
            lr = float(lr_max) * decay_rate
            if lr < 0.0:
                lr = 0.0
        lr_each_step.append(lr)
    return lr_each_step


def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies cosine decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    decay_steps = total_steps - warmup_steps
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
            lr = float(lr_init) + lr_inc * (i + 1)
        else:
            cosine_decay = 0.5 * (1 + math.cos(math.pi * (i-warmup_steps) / decay_steps))
            lr = (lr_max-lr_end)*cosine_decay + lr_end
        lr_each_step.append(lr)
    return lr_each_step


def _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies liner decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
        lr_each_step.append(lr)
    return lr_each_step


def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    generate learning rate array

@@ -28,60 +142,20 @@ def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch
        warmup_epochs(int): number of warmup epochs
        total_epochs(int): total epoch of training
        steps_per_epoch(int): steps of one epoch
        lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or default
        lr_decay_mode(string): learning rate decay mode, including steps, steps_decay, cosine or liner(default)

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    if lr_decay_mode == 'steps':
        decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
        for i in range(total_steps):
            if i < warmup_steps:
                lr = lr_init + (lr_max - lr_init) * i / warmup_steps
            else:
                if i < decay_epoch_index[0]:
                    lr = lr_max
                elif i < decay_epoch_index[1]:
                    lr = lr_max * 0.1
                elif i < decay_epoch_index[2]:
                    lr = lr_max * 0.01
                else:
                    lr = lr_max * 0.001
            lr_each_step.append(lr)
        lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps)
    elif lr_decay_mode == 'steps_decay':
        if warmup_steps != 0:
            inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
        else:
            inc_each_step = 0
        for i in range(total_steps):
            if i < warmup_steps:
                lr = float(lr_init) + inc_each_step * float(i)
            else:
                decay_nums = math.floor((float(i-warmup_steps)/steps_per_epoch) / 2)
                decay_rate = pow(0.94, decay_nums)
                lr = float(lr_max)*decay_rate
                if lr < 0.0:
                    lr = 0.0
            lr_each_step.append(lr)
        lr_each_step = _generate_exponential_lr(lr_init, lr_max, total_steps, warmup_steps, steps_per_epoch)
    elif lr_decay_mode == 'cosine':
        decay_steps = total_steps - warmup_steps
        for i in range(total_steps):
            if i < warmup_steps:
                lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
                lr = float(lr_init) + lr_inc * (i + 1)
            else:
                cosine_decay = 0.5 * (1 + math.cos(math.pi * (i-warmup_steps) / decay_steps))
                lr = (lr_max-lr_end)*cosine_decay + lr_end
            lr_each_step.append(lr)
        lr_each_step = _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
    else:
        for i in range(total_steps):
            if i < warmup_steps:
                lr = lr_init + (lr_max - lr_init) * i / warmup_steps
            else:
                lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
            lr_each_step.append(lr)
        lr_each_step = _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
    learning_rate = np.array(lr_each_step).astype(np.float32)
    return learning_rate
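
Note: the refactor above only moves the existing schedule math into the _generate_*_lr helpers, so get_lr keeps the same numerical behaviour while the per-branch loops disappear from the dispatcher. As a rough, self-contained sketch (not part of the commit; the function name and toy numbers below are made up for illustration), the three-step decay can be exercised on its own:

# Illustrative sketch only: a local restatement of the three-step decay
# that _generate_steps_lr produces, run with toy numbers.
def steps_lr(lr_init, lr_max, total_steps, warmup_steps):
    decay_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps   # linear warmup
        elif i < decay_index[0]:
            lr = lr_max
        elif i < decay_index[1]:
            lr = lr_max * 0.1
        elif i < decay_index[2]:
            lr = lr_max * 0.01
        else:
            lr = lr_max * 0.001
        lr_each_step.append(lr)
    return lr_each_step

lr = steps_lr(lr_init=0.01, lr_max=0.1, total_steps=100, warmup_steps=10)
# boundaries: warmup start, end of warmup, and the 30%/60%/80% decay points
print(lr[0], lr[10], lr[30], lr[60], lr[80])  # roughly 0.01, 0.1, 0.01, 0.001, 0.0001
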

@@ -17,6 +17,120 @@ import math
import numpy as np


def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps):
    """
    Applies three steps decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            if i < decay_epoch_index[0]:
                lr = lr_max
            elif i < decay_epoch_index[1]:
                lr = lr_max * 0.1
            elif i < decay_epoch_index[2]:
                lr = lr_max * 0.01
            else:
                lr = lr_max * 0.001
        lr_each_step.append(lr)
    return lr_each_step


def _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies polynomial decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    lr_each_step = []
    if warmup_steps != 0:
        inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
    else:
        inc_each_step = 0
    for i in range(total_steps):
        if i < warmup_steps:
            lr = float(lr_init) + inc_each_step * float(i)
        else:
            base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
            lr = float(lr_max) * base * base
            if lr < 0.0:
                lr = 0.0
        lr_each_step.append(lr)
    return lr_each_step


def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies cosine decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    decay_steps = total_steps - warmup_steps
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
            lr = float(lr_init) + lr_inc * (i + 1)
        else:
            cosine_decay = 0.5 * (1 + math.cos(math.pi * (i-warmup_steps) / decay_steps))
            lr = (lr_max-lr_end)*cosine_decay + lr_end
        lr_each_step.append(lr)
    return lr_each_step


def _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies liner decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
        lr_each_step.append(lr)
    return lr_each_step



def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    generate learning rate array

@@ -28,7 +142,7 @@ def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch
        warmup_epochs(int): number of warmup epochs
        total_epochs(int): total epoch of training
        steps_per_epoch(int): steps of one epoch
        lr_decay_mode(string): learning rate decay mode, including steps, poly or default
        lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or liner(default)

    Returns:
        np.array, learning rate array

@@ -36,54 +150,17 @@ def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs

    if lr_decay_mode == 'steps':
        decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
        for i in range(total_steps):
            if i < decay_epoch_index[0]:
                lr = lr_max
            elif i < decay_epoch_index[1]:
                lr = lr_max * 0.1
            elif i < decay_epoch_index[2]:
                lr = lr_max * 0.01
            else:
                lr = lr_max * 0.001
            lr_each_step.append(lr)
        lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps)
    elif lr_decay_mode == 'poly':
        if warmup_steps != 0:
            inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
        else:
            inc_each_step = 0
        for i in range(total_steps):
            if i < warmup_steps:
                lr = float(lr_init) + inc_each_step * float(i)
            else:
                base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
                lr = float(lr_max) * base * base
                if lr < 0.0:
                    lr = 0.0
            lr_each_step.append(lr)
        lr_each_step = _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
    elif lr_decay_mode == 'cosine':
        decay_steps = total_steps - warmup_steps
        for i in range(total_steps):
            if i < warmup_steps:
                lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
                lr = float(lr_init) + lr_inc * (i + 1)
            else:
                linear_decay = (total_steps - i) / decay_steps
                cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
                decayed = linear_decay * cosine_decay + 0.00001
                lr = lr_max * decayed
            lr_each_step.append(lr)
        lr_each_step = _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
    else:
        for i in range(total_steps):
            if i < warmup_steps:
                lr = lr_init + (lr_max - lr_init) * i / warmup_steps
            else:
                lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
            lr_each_step.append(lr)
        lr_each_step = _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)

    lr_each_step = np.array(lr_each_step).astype(np.float32)

    return lr_each_step

@@ -17,6 +17,120 @@ import math
import numpy as np


def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps):
    """
    Applies three steps decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            if i < decay_epoch_index[0]:
                lr = lr_max
            elif i < decay_epoch_index[1]:
                lr = lr_max * 0.1
            elif i < decay_epoch_index[2]:
                lr = lr_max * 0.01
            else:
                lr = lr_max * 0.001
        lr_each_step.append(lr)
    return lr_each_step


def _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies polynomial decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    lr_each_step = []
    if warmup_steps != 0:
        inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
    else:
        inc_each_step = 0
    for i in range(total_steps):
        if i < warmup_steps:
            lr = float(lr_init) + inc_each_step * float(i)
        else:
            base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
            lr = float(lr_max) * base * base
            if lr < 0.0:
                lr = 0.0
        lr_each_step.append(lr)
    return lr_each_step


def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies cosine decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    decay_steps = total_steps - warmup_steps
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
            lr = float(lr_init) + lr_inc * (i + 1)
        else:
            cosine_decay = 0.5 * (1 + math.cos(math.pi * (i-warmup_steps) / decay_steps))
            lr = (lr_max-lr_end)*cosine_decay + lr_end
        lr_each_step.append(lr)
    return lr_each_step


def _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies liner decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
        lr_each_step.append(lr)
    return lr_each_step



def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    generate learning rate array

@@ -28,7 +142,7 @@ def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch
        warmup_epochs(int): number of warmup epochs
        total_epochs(int): total epoch of training
        steps_per_epoch(int): steps of one epoch
        lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or default
        lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or liner(default)

    Returns:
        np.array, learning rate array

@@ -36,52 +150,15 @@ def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs

    if lr_decay_mode == 'steps':
        decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
        for i in range(total_steps):
            if i < decay_epoch_index[0]:
                lr = lr_max
            elif i < decay_epoch_index[1]:
                lr = lr_max * 0.1
            elif i < decay_epoch_index[2]:
                lr = lr_max * 0.01
            else:
                lr = lr_max * 0.001
            lr_each_step.append(lr)
        lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps)
    elif lr_decay_mode == 'poly':
        if warmup_steps != 0:
            inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
        else:
            inc_each_step = 0
        for i in range(total_steps):
            if i < warmup_steps:
                lr = float(lr_init) + inc_each_step * float(i)
            else:
                base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
                lr = float(lr_max) * base * base
                if lr < 0.0:
                    lr = 0.0
            lr_each_step.append(lr)
        lr_each_step = _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
    elif lr_decay_mode == 'cosine':
        decay_steps = total_steps - warmup_steps
        for i in range(total_steps):
            if i < warmup_steps:
                lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
                lr = float(lr_init) + lr_inc * (i + 1)
            else:
                linear_decay = (total_steps - i) / decay_steps
                cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
                decayed = linear_decay * cosine_decay + 0.00001
                lr = lr_max * decayed
            lr_each_step.append(lr)
        lr_each_step = _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
    else:
        for i in range(total_steps):
            if i < warmup_steps:
                lr = lr_init + (lr_max - lr_init) * i / warmup_steps
            else:
                lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
            lr_each_step.append(lr)
        lr_each_step = _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)

    learning_rate = np.array(lr_each_step).astype(np.float32)

    return learning_rate
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step

@@ -22,14 +22,15 @@ import numpy as np
import mindspore.nn as nn

from mindspore import Tensor, context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size, release
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype

from src.utils.logging import get_logger
from src.utils.auto_mixed_precision import auto_mixed_precision
from src.utils.var_init import load_pretrain_model
from src.image_classification import get_network
from src.dataset import classification_dataset
from src.config import config

@@ -79,6 +80,22 @@ def parse_args(cloud_args=None):

    args.image_size = list(map(int, args.image_size.split(',')))

    # init distributed
    if args.is_distributed:
        if args.platform == "Ascend":
            init()
        elif args.platform == "GPU":
            init("nccl")
        args.rank = get_rank()
        args.group_size = get_group_size()
    else:
        args.rank = 0
        args.group_size = 1

    args.outputs_dir = os.path.join(args.log_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    args.logger = get_logger(args.outputs_dir, args.rank)
    return args


@@ -102,6 +119,53 @@ def merge_args(args, cloud_args):
                args_dict[key] = val
    return args


def get_result(args, model, top1_correct, top5_correct, img_tot):
    """calculate top1 and top5 value."""
    results = [[top1_correct], [top5_correct], [img_tot]]
    args.logger.info('before results={}'.format(results))
    if args.is_distributed:
        model_md5 = model.replace('/', '')
        tmp_dir = '/cache'
        if not os.path.exists(tmp_dir):
            os.mkdir(tmp_dir)
        top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(args.rank, model_md5)
        top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(args.rank, model_md5)
        img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(args.rank, model_md5)
        np.save(top1_correct_npy, top1_correct)
        np.save(top5_correct_npy, top5_correct)
        np.save(img_tot_npy, img_tot)
        while True:
            rank_ok = True
            for other_rank in range(args.group_size):
                top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
                top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
                img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)
                if not os.path.exists(top1_correct_npy) or not os.path.exists(top5_correct_npy) or \
                   not os.path.exists(img_tot_npy):
                    rank_ok = False
            if rank_ok:
                break

        top1_correct_all = 0
        top5_correct_all = 0
        img_tot_all = 0
        for other_rank in range(args.group_size):
            top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
            top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
            img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)
            top1_correct_all += np.load(top1_correct_npy)
            top5_correct_all += np.load(top5_correct_npy)
            img_tot_all += np.load(img_tot_npy)
        results = [[top1_correct_all], [top5_correct_all], [img_tot_all]]
        results = np.array(results)
    else:
        results = np.array(results)

    args.logger.info('after results={}'.format(results))
    return results


def test(cloud_args=None):
    """test"""
    args = parse_args(cloud_args)

@@ -112,20 +176,10 @@ def test(cloud_args=None):

    # init distributed
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
                                          parameter_broadcast=True, gradients_mean=True)
    else:
        args.rank = 0
        args.group_size = 1

    args.outputs_dir = os.path.join(args.log_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    # network

@@ -151,18 +205,7 @@ def test(cloud_args=None):
        if network is None:
            raise NotImplementedError('not implement {}'.format(args.backbone))

        param_dict = load_checkpoint(model)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values

        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(model))
        load_pretrain_model(model, network, args)

        img_tot = 0
        top1_correct = 0

@@ -193,47 +236,7 @@ def test(cloud_args=None):
        time_used = time.time() - t_end
        fps = (img_tot - args.per_batch_size) * args.group_size / time_used
        args.logger.info('Inference Performance: {:.2f} img/sec'.format(fps))
        results = [[top1_correct], [top5_correct], [img_tot]]
        args.logger.info('before results={}'.format(results))
        if args.is_distributed:
            model_md5 = model.replace('/', '')
            tmp_dir = '/cache'
            if not os.path.exists(tmp_dir):
                os.mkdir(tmp_dir)
            top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(args.rank, model_md5)
            top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(args.rank, model_md5)
            img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(args.rank, model_md5)
            np.save(top1_correct_npy, top1_correct)
            np.save(top5_correct_npy, top5_correct)
            np.save(img_tot_npy, img_tot)
            while True:
                rank_ok = True
                for other_rank in range(args.group_size):
                    top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
                    top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
                    img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)
                    if not os.path.exists(top1_correct_npy) or not os.path.exists(top5_correct_npy) or \
                       not os.path.exists(img_tot_npy):
                        rank_ok = False
                if rank_ok:
                    break

            top1_correct_all = 0
            top5_correct_all = 0
            img_tot_all = 0
            for other_rank in range(args.group_size):
                top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
                top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
                img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)
                top1_correct_all += np.load(top1_correct_npy)
                top5_correct_all += np.load(top5_correct_npy)
                img_tot_all += np.load(img_tot_npy)
            results = [[top1_correct_all], [top5_correct_all], [img_tot_all]]
            results = np.array(results)
        else:
            results = np.array(results)

        args.logger.info('after results={}'.format(results))
        results = get_result(args, model, top1_correct, top5_correct, img_tot)
        top1_correct = results[0, 0]
        top5_correct = results[1, 0]
        img_tot = results[2, 0]
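
Note: the new get_result() helper keeps the old behaviour of the eval script: every rank dumps its counters as .npy files under the shared /cache directory, then polls until the files from all ranks exist and sums them; the aggregation goes through the filesystem rather than a communication collective. A minimal local sketch of just the dump-and-sum step (not part of the commit, polling omitted, paths and values made up):

# Illustrative sketch only: file-based aggregation of per-rank counters.
import os
import tempfile
import numpy as np

tmp_dir = tempfile.mkdtemp()
group_size = 2
for rank in range(group_size):
    # each "rank" saves its own top-1 counter to a shared directory
    np.save(os.path.join(tmp_dir, 'top1_rank_{}.npy'.format(rank)), np.array(10 + rank))

top1_all = 0
for rank in range(group_size):
    # later, one process loads every file and sums the counters
    top1_all += np.load(os.path.join(tmp_dir, 'top1_rank_{}.npy'.format(rank)))
print(top1_all)  # 21
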

@@ -1,21 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
linear warm up learning rate.
"""
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    lr = float(init_lr) + lr_inc * current_step
    return lr

@@ -0,0 +1,142 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
learning rate generator.
"""
import math
from collections import Counter
import numpy as np


def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    """
    Applies liner Increasing to generate learning rate array in warmup stage.

    Args:
        current_step(int): current step in warmup stage.
        warmup_steps(int): all steps in warmup stage.
        base_lr(float): init learning rate.
        init_lr(float): end learning rate

    Returns:
        float, learning rate.
    """
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    lr = float(init_lr) + lr_inc * current_step
    return lr


def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
    """
    Applies cosine decay to generate learning rate array with warmup.

    Args:
        lr(float): init learning rate
        steps_per_epoch(int): steps of one epoch
        warmup_epochs(int): number of warmup epochs
        max_epoch(int): total epoch of training
        T_max(int): max epoch in decay.
        eta_min(float): end learning rate

    Returns:
        np.array, learning rate array.
    """
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)

    lr_each_step = []
    for i in range(total_steps):
        last_epoch = i // steps_per_epoch
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
        lr_each_step.append(lr)

    return np.array(lr_each_step).astype(np.float32)


def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
    """
    Applies step decay to generate learning rate array with warmup.

    Args:
        lr(float): init learning rate
        lr_epochs(list): learning rate decay epoches list
        steps_per_epoch(int): steps of one epoch
        warmup_epochs(int): number of warmup epochs
        max_epoch(int): total epoch of training
        gamma(float): attenuation constants.

    Returns:
        np.array, learning rate array.
    """
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    milestones = lr_epochs
    milestones_steps = []
    for milestone in milestones:
        milestones_step = milestone * steps_per_epoch
        milestones_steps.append(milestones_step)

    lr_each_step = []
    lr = base_lr
    milestones_steps_counter = Counter(milestones_steps)
    for i in range(total_steps):
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = lr * gamma**milestones_steps_counter[i]
        lr_each_step.append(lr)

    return np.array(lr_each_step).astype(np.float32)


def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
    return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)


def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
    lr_epochs = []
    for i in range(1, max_epoch):
        if i % epoch_size == 0:
            lr_epochs.append(i)
    return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)


def get_lr(args):
    """generate learning rate array."""
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma,
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)
    return lr
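
Note: warmup_step_lr() applies the gamma factor through a Counter over the milestone steps, so the learning rate drops exactly once each time a milestone step is reached. A usage sketch with toy numbers (not part of the commit, assuming the function is imported from the new lr_generator module above):

# Illustrative sketch only: 2 warmup epochs, decay at epochs 4 and 8, 10 steps per epoch.
lr = warmup_step_lr(lr=0.1, lr_epochs=[4, 8], steps_per_epoch=10,
                    warmup_epochs=2, max_epoch=10, gamma=0.1)
print(lr.shape)        # (100,)
print(lr[0], lr[19])   # linear warmup from 0.005 up to 0.1 over the first 20 steps
print(lr[50], lr[90])  # 0.01 after the step-40 milestone, 0.001 after step 80
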

@@ -15,11 +15,13 @@
"""
Initialize.
"""
import os
import math
from functools import reduce
import numpy as np
import mindspore.nn as nn
from mindspore.common import initializer as init
from mindspore.train.serialization import load_checkpoint, load_param_into_net

def _calculate_gain(nonlinearity, param=None):
    r"""

@@ -208,3 +210,19 @@ def default_recurisive_init(custom_cell):
                                              cell.bias.dtype))
        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            pass


def load_pretrain_model(ckpt_file, network, args):
    """load pretrain model."""
    if os.path.isfile(ckpt_file):
        param_dict = load_checkpoint(ckpt_file)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(ckpt_file))
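
Note: load_pretrain_model() centralises the checkpoint clean-up that the eval and train scripts used to repeat inline: optimizer moment parameters are skipped and the 'network.' wrapper prefix is stripped before load_param_into_net(). A standalone sketch of just that key clean-up (not part of the commit; the key names are made up):

# Illustrative sketch only: the key filtering/renaming applied to a checkpoint dict.
ckpt = {'moments.conv1.weight': 1, 'network.conv1.weight': 2, 'fc.bias': 3}
cleaned = {}
for key, value in ckpt.items():
    if key.startswith('moments.'):            # optimizer state is dropped
        continue
    if key.startswith('network.'):            # wrapper prefix is stripped (key[8:])
        cleaned[key[len('network.'):]] = value
    else:
        cleaned[key] = value
print(cleaned)  # {'conv1.weight': 2, 'fc.bias': 3}
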

@@ -1,40 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
warm up cosine annealing learning rate.
"""
import math
import numpy as np

from .linear_warmup import linear_warmup_lr


def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
    """warm up cosine annealing learning rate."""
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)

    lr_each_step = []
    for i in range(total_steps):
        last_epoch = i // steps_per_epoch
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
        lr_each_step.append(lr)

    return np.array(lr_each_step).astype(np.float32)

@@ -1,56 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
warm up step learning rate.
"""
from collections import Counter
import numpy as np

from .linear_warmup import linear_warmup_lr


def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
    """warmup_step_lr"""
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    milestones = lr_epochs
    milestones_steps = []
    for milestone in milestones:
        milestones_step = milestone * steps_per_epoch
        milestones_steps.append(milestones_step)

    lr_each_step = []
    lr = base_lr
    milestones_steps_counter = Counter(milestones_steps)
    for i in range(total_steps):
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = lr * gamma**milestones_steps_counter[i]
        lr_each_step.append(lr)

    return np.array(lr_each_step).astype(np.float32)

def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
    return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)

def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
    lr_epochs = []
    for i in range(1, max_epoch):
        if i % epoch_size == 0:
            lr_epochs.append(i)
    return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)

@@ -25,17 +25,16 @@ from mindspore.nn.optim import Momentum
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig, Callback
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.common import set_seed

from src.dataset import classification_dataset
from src.crossentropy import CrossEntropy
from src.warmup_step_lr import warmup_step_lr
from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
from src.lr_generator import get_lr
from src.utils.logging import get_logger
from src.utils.optimizers__init__ import get_param_groups
from src.utils.var_init import load_pretrain_model
from src.image_classification import get_network
from src.config import config

@@ -149,37 +148,11 @@ def parse_args(cloud_args=None):
    args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
    args.image_size = list(map(int, args.image_size.split(',')))

    return args

def merge_args(args, cloud_args):
    """dictionary"""
    args_dict = vars(args)
    if isinstance(cloud_args, dict):
        for key in cloud_args.keys():
            val = cloud_args[key]
            if key in args_dict and val:
                arg_type = type(args_dict[key])
                if arg_type is not type(None):
                    val = arg_type(val)
                args_dict[key] = val
    return args

def train(cloud_args=None):
    """training process"""
    args = parse_args(cloud_args)
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=args.platform, save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
                                          parameter_broadcast=True, gradients_mean=True)
    else:
        args.rank = 0
        args.group_size = 1

@@ -199,7 +172,35 @@ def train(cloud_args=None):
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    return args

def merge_args(args, cloud_args):
    """dictionary"""
    args_dict = vars(args)
    if isinstance(cloud_args, dict):
        for key in cloud_args.keys():
            val = cloud_args[key]
            if key in args_dict and val:
                arg_type = type(args_dict[key])
                if arg_type is not type(None):
                    val = arg_type(val)
                args_dict[key] = val
    return args


def train(cloud_args=None):
    """training process"""
    args = parse_args(cloud_args)
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=args.platform, save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
                                          parameter_broadcast=True, gradients_mean=True)
    # dataloader
    de_dataset = classification_dataset(args.data_dir, args.image_size,
                                        args.per_batch_size, 1,

@@ -216,38 +217,10 @@ def train(cloud_args=None):
    if network is None:
        raise NotImplementedError('not implement {}'.format(args.backbone))

    # load pretrain model
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(args.pretrained))
    load_pretrain_model(args.pretrained, network, args)

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma,
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)
    lr = get_lr(args)

    # optimizer
    opt = Momentum(params=get_param_groups(network),
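
Note: merge_args() overlays a cloud-side dict of overrides onto the argparse namespace, coercing each string value back to the type of the existing argument. A standalone sketch of that pattern (not part of the commit; the argument names and values are made up):

# Illustrative sketch only: folding a dict of string overrides into an argparse Namespace.
import argparse

args = argparse.Namespace(lr=0.1, max_epoch=90, platform='Ascend')
cloud_args = {'lr': '0.05', 'max_epoch': '120', 'unknown': 'ignored'}

args_dict = vars(args)                       # mutating this dict updates args in place
for key, val in cloud_args.items():
    if key in args_dict and val:
        arg_type = type(args_dict[key])      # coerce '0.05' -> float, '120' -> int
        if arg_type is not type(None):
            val = arg_type(val)
        args_dict[key] = val
print(args)  # Namespace(lr=0.05, max_epoch=120, platform='Ascend')
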

@@ -18,8 +18,9 @@ from functools import reduce
import numpy as np
from mindspore.common import initializer as init
from mindspore.common.initializer import Initializer as MeInitializer
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.nn as nn

from .util import load_backbone

def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

@@ -176,3 +177,28 @@ def default_recurisive_init(custom_cell):
                                              cell.bias.dtype))
        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            pass

def load_yolov3_params(args, network):
    """Load yolov3 darknet parameter from checkpoint."""
    if args.pretrained_backbone:
        network = load_backbone(network, args.pretrained_backbone, args)
        args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
    else:
        args.logger.info('Not load pre-trained backbone, please be careful')

    if args.resume_yolov3:
        param_dict = load_checkpoint(args.resume_yolov3)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('yolo_network.'):
                param_dict_new[key[13:]] = values
                args.logger.info('in resume {}'.format(key))
            else:
                param_dict_new[key] = values
                args.logger.info('in resume {}'.format(key))

        args.logger.info('resume finished')
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume_yolov3))

@@ -142,3 +142,39 @@ def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_ep

    assert total_steps == len(lr_each_step)
    return np.array(lr_each_step).astype(np.float32)


def get_lr(args):
    """generate learning rate."""
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma,
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_V2(args.lr,
                                           args.steps_per_epoch,
                                           args.warmup_epochs,
                                           args.max_epoch,
                                           args.T_max,
                                           args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr,
                                               args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch,
                                               args.T_max,
                                               args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)
    return lr

@@ -27,18 +27,16 @@ from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
import mindspore as ms
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import amp
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed

from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
from src.logger import get_logger
from src.util import AverageMeter, load_backbone, get_param_groups
from src.lr_scheduler import warmup_step_lr, warmup_cosine_annealing_lr, \
    warmup_cosine_annealing_lr_V2, warmup_cosine_annealing_lr_sample
from src.util import AverageMeter, get_param_groups
from src.lr_scheduler import get_lr
from src.yolo_dataset import create_yolo_dataset
from src.initializer import default_recurisive_init
from src.initializer import default_recurisive_init, load_yolov3_params
from src.config import ConfigYOLOV3DarkNet53
from src.util import keep_loss_fp32

@@ -126,22 +124,6 @@ def parse_args():
    args.data_root = os.path.join(args.data_dir, 'train2014')
    args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json')

    return args


def conver_training_shape(args):
    training_shape = [int(args.training_shape), int(args.training_shape)]
    return training_shape


def train():
    """Train function."""
    args = parse_args()

    devid = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=args.device_target, save_graphs=True, device_id=devid)

    # init distributed
    if args.is_distributed:
        if args.device_target == "Ascend":

@@ -165,6 +147,20 @@ def train():
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    return args


def conver_training_shape(args):
    training_shape = [int(args.training_shape), int(args.training_shape)]
    return training_shape


def train():
    """Train function."""
    args = parse_args()
    devid = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=args.device_target, save_graphs=True, device_id=devid)
    if args.need_profiler:
        from mindspore.profiler.profiling import Profiler
        profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)

@@ -172,40 +168,17 @@ def train():
    loss_meter = AverageMeter('loss')

    context.reset_auto_parallel_context()
    parallel_mode = ParallelMode.STAND_ALONE
    degree = 1
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    else:
        parallel_mode = ParallelMode.STAND_ALONE
        degree = 1
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)

    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recurisive_init(network)

    if args.pretrained_backbone:
        network = load_backbone(network, args.pretrained_backbone, args)
        args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
    else:
        args.logger.info('Not load pre-trained backbone, please be careful')

    if args.resume_yolov3:
        param_dict = load_checkpoint(args.resume_yolov3)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('yolo_network.'):
                param_dict_new[key[13:]] = values
                args.logger.info('in resume {}'.format(key))
            else:
                param_dict_new[key] = values
                args.logger.info('in resume {}'.format(key))

        args.logger.info('resume finished')
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume_yolov3))
    load_yolov3_params(args, network)

    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

@@ -230,49 +203,15 @@ def train():
    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma,
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_V2(args.lr,
                                           args.steps_per_epoch,
                                           args.warmup_epochs,
                                           args.max_epoch,
                                           args.T_max,
                                           args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr,
                                               args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch,
                                               args.T_max,
                                               args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)
    lr = get_lr(args)

    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)
    enable_amp = False
    is_gpu = context.get_context("device_target") == "GPU"
    if is_gpu:
        enable_amp = True
    if enable_amp:
        loss_scale_value = 1.0
        loss_scale = FixedLossScaleManager(loss_scale_value, drop_overflow_update=False)
        network = amp.build_train_network(network, optimizer=opt, loss_scale_manager=loss_scale,
|
|||
from mindspore.common.initializer import Initializer as MeInitializer
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
|
@ -174,3 +175,51 @@ def default_recurisive_init(custom_cell):
|
|||
cell.bias.data.dtype))
|
||||
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
|
||||
pass
|
||||
|
||||
def load_yolov3_quant_params(args, network):
|
||||
"""Load quant yolov3 darknet parameter from checkpoint."""
|
||||
if args.resume_yolov3:
|
||||
param_dict = load_checkpoint(args.resume_yolov3)
|
||||
param_dict_new = {}
|
||||
for key, values in param_dict.items():
|
||||
args.logger.info('ckpt param name = {}'.format(key))
|
||||
if key.startswith('moments.') or key.startswith('global_') or \
|
||||
key.startswith('learning_rate') or key.startswith('momentum'):
|
||||
continue
|
||||
elif key.startswith('yolo_network.'):
|
||||
key_new = key[13:]
|
||||
|
||||
if key_new.endswith('1.beta'):
|
||||
key_new = key_new.replace('1.beta', 'batchnorm.beta')
|
||||
|
||||
if key_new.endswith('1.gamma'):
|
||||
key_new = key_new.replace('1.gamma', 'batchnorm.gamma')
|
||||
|
||||
if key_new.endswith('1.moving_mean'):
|
||||
key_new = key_new.replace('1.moving_mean', 'batchnorm.moving_mean')
|
||||
|
||||
if key_new.endswith('1.moving_variance'):
|
||||
key_new = key_new.replace('1.moving_variance', 'batchnorm.moving_variance')
|
||||
|
||||
if key_new.endswith('.weight'):
|
||||
if key_new.endswith('0.weight'):
|
||||
key_new = key_new.replace('0.weight', 'conv.weight')
|
||||
else:
|
||||
key_new = key_new.replace('.weight', '.conv.weight')
|
||||
|
||||
if key_new.endswith('.bias'):
|
||||
key_new = key_new.replace('.bias', '.conv.bias')
|
||||
param_dict_new[key_new] = values
|
||||
|
||||
args.logger.info('in resume {}'.format(key_new))
|
||||
else:
|
||||
param_dict_new[key] = values
|
||||
args.logger.info('in resume {}'.format(key))
|
||||
|
||||
args.logger.info('resume finished')
|
||||
for _, param in network.parameters_and_names():
|
||||
args.logger.info('network param name = {}'.format(param.name))
|
||||
if param.name not in param_dict_new:
|
||||
args.logger.info('not match param name = {}'.format(param.name))
|
||||
load_param_into_net(network, param_dict_new)
|
||||
args.logger.info('load_model {} success'.format(args.resume_yolov3))
|
||||
|
|
|
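As a rough, standalone illustration of the renaming rules applied above when resuming a quantization-aware YOLOv3 checkpoint (the helper name and the sample key below are hypothetical, not part of this patch):

def remap_quant_ckpt_key(key_new):
    """Sketch of the batchnorm/conv suffix remapping done in load_yolov3_quant_params."""
    for old, new in (('1.beta', 'batchnorm.beta'),
                     ('1.gamma', 'batchnorm.gamma'),
                     ('1.moving_mean', 'batchnorm.moving_mean'),
                     ('1.moving_variance', 'batchnorm.moving_variance')):
        if key_new.endswith(old):
            key_new = key_new.replace(old, new)
    if key_new.endswith('.weight'):
        if key_new.endswith('0.weight'):
            key_new = key_new.replace('0.weight', 'conv.weight')
        else:
            key_new = key_new.replace('.weight', '.conv.weight')
    if key_new.endswith('.bias'):
        key_new = key_new.replace('.bias', '.conv.bias')
    return key_new

# Hypothetical key, just to show the mapping:
print(remap_quant_ckpt_key('backbone.conv0.1.gamma'))   # -> backbone.conv0.batchnorm.gamma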
@@ -141,3 +141,39 @@ def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_ep

    assert total_steps == len(lr_each_step)
    return np.array(lr_each_step).astype(np.float32)


def get_lr(args):
    """generate learning rate."""
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma,
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_V2(args.lr,
                                           args.steps_per_epoch,
                                           args.warmup_epochs,
                                           args.max_epoch,
                                           args.T_max,
                                           args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr,
                                               args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch,
                                               args.T_max,
                                               args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)
    return lr
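A minimal usage sketch of the new get_lr helper. The attribute values below are placeholders; the real ones come from the training script's argparse options, and the import assumes the model directory's src package is on the path:

from argparse import Namespace

from src.lr_scheduler import get_lr

# Placeholder hyper-parameters, for illustration only.
args = Namespace(lr_scheduler='cosine_annealing', lr=0.012, lr_epochs=[220, 250],
                 lr_gamma=0.1, steps_per_epoch=100, warmup_epochs=5,
                 max_epoch=320, T_max=320, eta_min=0.0)
lr = get_lr(args)       # one float32 learning-rate value per training step
print(lr.shape)         # (steps_per_epoch * max_epoch,)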
@@ -27,17 +27,15 @@ from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
import mindspore as ms
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.quant import quant
from mindspore.common import set_seed

from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
from src.logger import get_logger
from src.util import AverageMeter, get_param_groups
from src.lr_scheduler import warmup_step_lr, warmup_cosine_annealing_lr, \
    warmup_cosine_annealing_lr_V2, warmup_cosine_annealing_lr_sample
from src.lr_scheduler import get_lr
from src.yolo_dataset import create_yolo_dataset
from src.initializer import default_recurisive_init
from src.initializer import default_recurisive_init, load_yolov3_quant_params
from src.config import ConfigYOLOV3DarkNet53
from src.transforms import batch_preprocess_true_box, batch_preprocess_true_box_single
from src.util import ShapeRecord

@@ -117,18 +115,6 @@ def parse_args():
    args.data_root = os.path.join(args.data_dir, 'train2014')
    args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json')

    return args


def conver_training_shape(args):
    training_shape = [int(args.training_shape), int(args.training_shape)]
    return training_shape


def train():
    """Train function."""
    args = parse_args()

    # init distributed
    if args.is_distributed:
        init()

@@ -147,6 +133,17 @@ def train():
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    return args


def conver_training_shape(args):
    training_shape = [int(args.training_shape), int(args.training_shape)]
    return training_shape


def train():
    """Train function."""
    args = parse_args()
    args.logger.save_args(args)

    if args.need_profiler:
@@ -156,63 +153,17 @@ def train():
    loss_meter = AverageMeter('loss')

    context.reset_auto_parallel_context()
    parallel_mode = ParallelMode.STAND_ALONE
    degree = 1
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    else:
        parallel_mode = ParallelMode.STAND_ALONE
        degree = 1
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)

    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recurisive_init(network)

    if args.resume_yolov3:
        param_dict = load_checkpoint(args.resume_yolov3)
        param_dict_new = {}
        for key, values in param_dict.items():
            args.logger.info('ckpt param name = {}'.format(key))
            if key.startswith('moments.') or key.startswith('global_') or \
               key.startswith('learning_rate') or key.startswith('momentum'):
                continue
            elif key.startswith('yolo_network.'):
                key_new = key[13:]

                if key_new.endswith('1.beta'):
                    key_new = key_new.replace('1.beta', 'batchnorm.beta')

                if key_new.endswith('1.gamma'):
                    key_new = key_new.replace('1.gamma', 'batchnorm.gamma')

                if key_new.endswith('1.moving_mean'):
                    key_new = key_new.replace('1.moving_mean', 'batchnorm.moving_mean')

                if key_new.endswith('1.moving_variance'):
                    key_new = key_new.replace('1.moving_variance', 'batchnorm.moving_variance')

                if key_new.endswith('.weight'):
                    if key_new.endswith('0.weight'):
                        key_new = key_new.replace('0.weight', 'conv.weight')
                    else:
                        key_new = key_new.replace('.weight', '.conv.weight')

                if key_new.endswith('.bias'):
                    key_new = key_new.replace('.bias', '.conv.bias')
                param_dict_new[key_new] = values

                args.logger.info('in resume {}'.format(key_new))
            else:
                param_dict_new[key] = values
                args.logger.info('in resume {}'.format(key))

        args.logger.info('resume finished')
        for _, param in network.parameters_and_names():
            args.logger.info('network param name = {}'.format(param.name))
            if param.name not in param_dict_new:
                args.logger.info('not match param name = {}'.format(param.name))
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume_yolov3))
    load_yolov3_quant_params(args, network)

    config = ConfigYOLOV3DarkNet53()
    # convert fusion network to quantization aware network
@@ -244,38 +195,7 @@ def train():
    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma,
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_V2(args.lr,
                                           args.steps_per_epoch,
                                           args.warmup_epochs,
                                           args.max_epoch,
                                           args.T_max,
                                           args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr,
                                               args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch,
                                               args.T_max,
                                               args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)
    lr = get_lr(args)

    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
@@ -139,8 +139,9 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_meth
    eval_result_print(assessment_method, callback)
    print("==============================================================")

def run_ner():
    """run ner task"""

def parse_args():
    """set and check parameters."""
    parser = argparse.ArgumentParser(description="run classifier")
    parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"],
                        help="Device type, default is Ascend")

@@ -171,12 +172,6 @@ def run_ner():
    parser.add_argument("--schema_file_path", type=str, default="",
                        help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()
    epoch_num = args_opt.epoch_num
    assessment_method = args_opt.assessment_method.lower()
    load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path
    save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path
    load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path

    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":

@@ -189,7 +184,17 @@ def run_ner():
        raise ValueError("'label2id_file_path' must be set to use crf")
    if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label2id_file_path == "":
        raise ValueError("'label2id_file_path' must be set to do clue benchmark")
    return args_opt


def run_ner():
    """run ner task"""
    args_opt = parse_args()
    epoch_num = args_opt.epoch_num
    assessment_method = args_opt.assessment_method.lower()
    load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path
    save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path
    load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path
    target = args_opt.device_target
    if target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
@@ -39,6 +39,58 @@ from src.utils import LossCallBack, BertLearningRate
_current_dir = os.path.dirname(os.path.realpath(__file__))


def _set_bert_all_reduce_split():
    """set bert all_reduce fusion split, support num_hidden_layers is 12 and 24."""
    if bert_net_cfg.num_hidden_layers == 12:
        if bert_net_cfg.use_relative_positions:
            context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 87, 116, 145, 174, 203, 217])
        else:
            context.set_auto_parallel_context(all_reduce_fusion_config=[28, 55, 82, 109, 136, 163, 190, 205])
    elif bert_net_cfg.num_hidden_layers == 24:
        if bert_net_cfg.use_relative_positions:
            context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421])
        else:
            context.set_auto_parallel_context(all_reduce_fusion_config=[38, 93, 148, 203, 258, 313, 368, 397])


def _get_optimizer(args_opt, network):
    """get bert optimizer, support Lamb, Momentum, AdamWeightDecay."""
    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=cfg.Lamb.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.Lamb.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
                        {'params': other_params},
                        {'order_params': params}]
        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': params}]

        optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
                         format(cfg.optimizer))
    return optimizer


def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')

@@ -88,16 +140,7 @@ def run_pretrain():
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 87, 116, 145, 174, 203, 217])
            else:
                context.set_auto_parallel_context(all_reduce_fusion_config=[28, 55, 82, 109, 136, 163, 190, 205])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421])
            else:
                context.set_auto_parallel_context(all_reduce_fusion_config=[38, 93, 148, 203, 258, 313, 368, 397])
        _set_bert_all_reduce_split()
    else:
        rank = 0
        device_num = 1

@@ -127,39 +170,7 @@ def run_pretrain():
        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size() // args_opt.accumulation_steps
    logger.info("train steps: {}".format(args_opt.train_steps))

    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=cfg.Lamb.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.Lamb.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
                        {'params': other_params},
                        {'order_params': params}]
        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.AdamWeightDecay.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': params}]

        optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
                         format(cfg.optimizer))
    optimizer = _get_optimizer(args_opt, net_with_loss)
    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(ds.get_dataset_size())]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
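The Lamb and AdamWeightDecay branches above share one pattern: split the trainable parameters with a decay filter, then hand grouped hyper-parameters to the optimizer. A minimal sketch of that pattern follows; the toy network and the filter are made up for illustration, while the real filter comes from cfg.Lamb.decay_filter / cfg.AdamWeightDecay.decay_filter:

import mindspore.nn as nn
from mindspore.nn.optim import Lamb

def decay_filter(param):
    # Hypothetical stand-in for cfg.Lamb.decay_filter: apply weight decay to everything but biases.
    return 'bias' not in param.name.lower()

net = nn.Dense(8, 2)                      # toy network, illustration only
params = net.trainable_params()
decay_params = list(filter(decay_filter, params))
other_params = [p for p in params if not decay_filter(p)]
group_params = [{'params': decay_params, 'weight_decay': 0.01},
                {'params': other_params},
                {'order_params': params}]
optimizer = Lamb(group_params, learning_rate=1e-3, eps=1e-6)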
@@ -28,7 +28,6 @@ from src.model_thor import Model
from src.utils import LossCallBack, BertLearningRate
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore.communication.management import get_rank
from mindspore import context
from mindspore import log as logger
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay

@@ -41,6 +40,83 @@ from mindspore.common import set_seed
_current_dir = os.path.dirname(os.path.realpath(__file__))


def _set_bert_all_reduce_split():
    """set bert all_reduce fusion split, support num_hidden_layers is 12 and 24."""
    from mindspore.parallel._auto_parallel_context import auto_parallel_context
    if bert_net_cfg.num_hidden_layers == 12:
        if bert_net_cfg.use_relative_positions:
            auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217],
                                                                        "hccl_world_groupsum1")
            auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217],
                                                                        "hccl_world_groupsum3")
        else:
            auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205],
                                                                        "hccl_world_groupsum1")
            auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205],
                                                                        "hccl_world_groupsum3")
    elif bert_net_cfg.num_hidden_layers == 24:
        if bert_net_cfg.use_relative_positions:
            auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421],
                                                                        "hccl_world_groupsum1")
            auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421],
                                                                        "hccl_world_groupsum3")
        else:
            auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397],
                                                                        "hccl_world_groupsum1")
            auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397],
                                                                        "hccl_world_groupsum3")


def _get_optimizer(args_opt, network):
    """get bert optimizer, support Lamb, Momentum, AdamWeightDecay and Thor."""
    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=cfg.Lamb.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.Lamb.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
                        {'params': other_params},
                        {'order_params': params}]
        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': params}]

        optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == "Thor":
        if args_opt.distribute == "true":
            from src.thor_for_bert_arg import THOR
        else:
            from src.thor_for_bert import THOR
        lr = get_bert_lr()
        damping = get_bert_damping()
        optimizer = THOR(filter(lambda x: x.requires_grad, network.get_parameters()), lr, cfg.Thor.momentum,
                         filter(lambda x: 'matrix_A' in x.name, network.get_parameters()),
                         filter(lambda x: 'matrix_G' in x.name, network.get_parameters()),
                         cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
                         bert_net_cfg.batch_size, damping)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
                         format(cfg.optimizer))
    return optimizer


def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')

@@ -66,10 +142,6 @@ def run_pretrain():
    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    if args_opt.distribute == "true":
        from src.thor_for_bert_arg import THOR
    else:
        from src.thor_for_bert import THOR
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id, save_graphs=False)
    context.set_context(reserve_class_name_in_scope=False)

@@ -77,42 +149,15 @@ def run_pretrain():
    context.set_context(max_call_depth=3000)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'

        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
        _set_bert_all_reduce_split()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)
        from mindspore.parallel._auto_parallel_context import auto_parallel_context
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217],
                                                                            "hccl_world_groupsum1")
                auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217],
                                                                            "hccl_world_groupsum3")
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205],
                                                                            "hccl_world_groupsum1")
                auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205],
                                                                            "hccl_world_groupsum3")
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421],
                                                                            "hccl_world_groupsum1")
                auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421],
                                                                            "hccl_world_groupsum3")
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397],
                                                                            "hccl_world_groupsum1")
                auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397],
                                                                            "hccl_world_groupsum3")

    else:
        rank = 0
        device_num = 1

@@ -131,47 +176,7 @@ def run_pretrain():
    args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
    logger.info("train steps: {}".format(args_opt.train_steps))

    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=cfg.Lamb.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.Lamb.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
                        {'params': other_params},
                        {'order_params': params}]
        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.AdamWeightDecay.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': params}]

        optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == "Thor":
        lr = get_bert_lr()
        damping = get_bert_damping()
        optimizer = THOR(filter(lambda x: x.requires_grad, net_with_loss.get_parameters()), lr, cfg.Thor.momentum,
                         filter(lambda x: 'matrix_A' in x.name, net_with_loss.get_parameters()),
                         filter(lambda x: 'matrix_G' in x.name, net_with_loss.get_parameters()),
                         cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
                         bert_net_cfg.batch_size, damping)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
                         format(cfg.optimizer))
    optimizer = _get_optimizer(args_opt, net_with_loss)
    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
    if args_opt.enable_save_ckpt == "true" and rank == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
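Both run_pretrain scripts key the all-reduce fusion split on (num_hidden_layers, use_relative_positions). One further way to flatten the nested ifs, sketched here in plain Python with the index lists taken from this patch (this is an alternative shape, not how the patch itself does it):

# Fusion split indices copied from the patch, keyed by (num_hidden_layers, use_relative_positions).
ALL_REDUCE_FUSION_SPLIT = {
    (12, True): [29, 58, 87, 116, 145, 174, 203, 217],
    (12, False): [28, 55, 82, 109, 136, 163, 190, 205],
    (24, True): [30, 90, 150, 210, 270, 330, 390, 421],
    (24, False): [38, 93, 148, 203, 258, 313, 368, 397],
}

def lookup_split(num_hidden_layers, use_relative_positions):
    # Returns the indices that would be passed to the all-reduce fusion configuration,
    # or None for depths the helper does not handle (it only supports 12 and 24 layers).
    return ALL_REDUCE_FUSION_SPLIT.get((num_hidden_layers, bool(use_relative_positions)))

print(lookup_split(12, False))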
@@ -100,6 +100,81 @@ def _train(model, config: TransformerConfig,
            pickle.dump(result, f, 1)


def _load_checkpoint_to_net(config, network):
    """load parameters to network from checkpoint."""
    if config.existed_ckpt:
        if config.existed_ckpt.endswith(".npz"):
            weights = np.load(config.existed_ckpt)
        else:
            weights = load_checkpoint(config.existed_ckpt)
        for param in network.trainable_params():
            weights_name = param.name
            if weights_name not in weights:
                raise ValueError(f"Param {weights_name} is not found in ckpt file.")

            if isinstance(weights[weights_name], Parameter):
                param.set_data(weights[weights_name].data)
            elif isinstance(weights[weights_name], Tensor):
                param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
            elif isinstance(weights[weights_name], np.ndarray):
                param.set_data(Tensor(weights[weights_name], config.dtype))
            else:
                param.set_data(weights[weights_name])
    else:
        for param in network.trainable_params():
            name = param.name
            value = param.data
            if isinstance(value, Tensor):
                if name.endswith(".gamma"):
                    param.set_data(one_weight(value.asnumpy().shape))
                elif name.endswith(".beta") or name.endswith(".bias"):
                    param.set_data(zero_weight(value.asnumpy().shape))
                else:
                    param.set_data(weight_variable(value.asnumpy().shape))


def _get_lr(config, update_steps):
    """generate learning rate."""
    if config.lr_scheduler == "isr":
        lr = Tensor(square_root_schedule(lr=config.lr,
                                         update_num=update_steps,
                                         decay_start_step=config.decay_start_step,
                                         warmup_steps=config.warmup_steps,
                                         min_lr=config.min_lr), dtype=mstype.float32)
    elif config.lr_scheduler == "poly":
        lr = Tensor(polynomial_decay_scheduler(lr=config.lr,
                                               min_lr=config.min_lr,
                                               decay_steps=config.decay_steps,
                                               total_update_num=update_steps,
                                               warmup_steps=config.warmup_steps,
                                               power=config.poly_lr_scheduler_power), dtype=mstype.float32)
    else:
        lr = config.lr
    return lr


def _get_optimizer(config, network, lr):
    """get mass optimizer, support Adam, Lamb, Momentum."""
    if config.optimizer.lower() == "adam":
        optimizer = Adam(network.trainable_params(), lr, beta1=0.9, beta2=0.98)
    elif config.optimizer.lower() == "lamb":
        lr = BertLearningRate(decay_steps=12000, learning_rate=config.lr, end_learning_rate=config.min_lr,
                              power=10.0, warmup_steps=config.warmup_steps)
        decay_params = list(filter(lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
                                   network.trainable_params()))
        other_params = list(filter(lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower(),
                                   network.trainable_params()))
        group_params = [{'params': decay_params, 'weight_decay': 0.01},
                        {'params': other_params}]

        optimizer = Lamb(group_params, lr, eps=1e-6)
    elif config.optimizer.lower() == "momentum":
        optimizer = Momentum(network.trainable_params(), lr, momentum=0.9)
    else:
        raise ValueError(f"optimizer only support `adam` and `momentum` now.")
    return optimizer


def _build_training_pipeline(config: TransformerConfig,
                             pre_training_dataset=None,
                             fine_tune_dataset=None,

@@ -116,36 +191,7 @@ def _build_training_pipeline(config: TransformerConfig,
    """
    net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
    net_with_loss.init_parameters_data()

    if config.existed_ckpt:
        if config.existed_ckpt.endswith(".npz"):
            weights = np.load(config.existed_ckpt)
        else:
            weights = load_checkpoint(config.existed_ckpt)
        for param in net_with_loss.trainable_params():
            weights_name = param.name
            if weights_name not in weights:
                raise ValueError(f"Param {weights_name} is not found in ckpt file.")

            if isinstance(weights[weights_name], Parameter):
                param.set_data(weights[weights_name].data)
            elif isinstance(weights[weights_name], Tensor):
                param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
            elif isinstance(weights[weights_name], np.ndarray):
                param.set_data(Tensor(weights[weights_name], config.dtype))
            else:
                param.set_data(weights[weights_name])
    else:
        for param in net_with_loss.trainable_params():
            name = param.name
            value = param.data
            if isinstance(value, Tensor):
                if name.endswith(".gamma"):
                    param.set_data(one_weight(value.asnumpy().shape))
                elif name.endswith(".beta") or name.endswith(".bias"):
                    param.set_data(zero_weight(value.asnumpy().shape))
                else:
                    param.set_data(weight_variable(value.asnumpy().shape))
    _load_checkpoint_to_net(config, net_with_loss)

    dataset = pre_training_dataset if pre_training_dataset is not None \
        else fine_tune_dataset

@@ -154,39 +200,10 @@ def _build_training_pipeline(config: TransformerConfig,
        raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.")

    update_steps = config.epochs * dataset.get_dataset_size()
    if config.lr_scheduler == "isr":
        lr = Tensor(square_root_schedule(lr=config.lr,
                                         update_num=update_steps,
                                         decay_start_step=config.decay_start_step,
                                         warmup_steps=config.warmup_steps,
                                         min_lr=config.min_lr), dtype=mstype.float32)
    elif config.lr_scheduler == "poly":
        lr = Tensor(polynomial_decay_scheduler(lr=config.lr,
                                               min_lr=config.min_lr,
                                               decay_steps=config.decay_steps,
                                               total_update_num=update_steps,
                                               warmup_steps=config.warmup_steps,
                                               power=config.poly_lr_scheduler_power), dtype=mstype.float32)
    else:
        lr = config.lr

    if config.optimizer.lower() == "adam":
        optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.98)
    elif config.optimizer.lower() == "lamb":
        lr = BertLearningRate(decay_steps=12000, learning_rate=config.lr, end_learning_rate=config.min_lr,
                              power=10.0, warmup_steps=config.warmup_steps)
        decay_params = list(filter(lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
                                   net_with_loss.trainable_params()))
        other_params = list(filter(lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower(),
                                   net_with_loss.trainable_params()))
        group_params = [{'params': decay_params, 'weight_decay': 0.01},
                        {'params': other_params}]
    lr = _get_lr(config, update_steps)

        optimizer = Lamb(group_params, lr, eps=1e-6)
    elif config.optimizer.lower() == "momentum":
        optimizer = Momentum(net_with_loss.trainable_params(), lr, momentum=0.9)
    else:
        raise ValueError(f"optimizer only support `adam` and `momentum` now.")
    optimizer = _get_optimizer(config, net_with_loss, lr)

    # loss scale.
    if config.loss_scale_mode == "dynamic":
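When no checkpoint is configured, _load_checkpoint_to_net falls back to re-initializing parameters by name suffix. A rough numpy-only sketch of that branch follows; one_weight, zero_weight and weight_variable are approximated with plain numpy here, and the parameter name is hypothetical:

import numpy as np

def default_init(name, shape):
    """Approximate the fallback: ones for *.gamma, zeros for *.beta / *.bias, random otherwise."""
    if name.endswith(".gamma"):
        return np.ones(shape, np.float32)           # stands in for one_weight
    if name.endswith(".beta") or name.endswith(".bias"):
        return np.zeros(shape, np.float32)          # stands in for zero_weight
    return np.random.normal(0.0, 0.02, shape).astype(np.float32)   # stands in for weight_variable

print(default_init("encoder.layer0.layernorm.gamma", (4,)))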