forked from mindspore-Ecosystem/mindspore
!20111 add adam offload for pangu & fix AdamWeightDecay nnacl
Merge pull request !20111 from zhaosida/zsd_adam_simd
commit eaf9588ac9
@@ -59,11 +59,11 @@ void AdamWeightDecayCPUKernel::LaunchFusedAdam(const std::vector<AddressPtr> &in
   auto beta1 = reinterpret_cast<T *>(inputs[4]->addr)[0];
   auto beta2 = reinterpret_cast<T *>(inputs[5]->addr)[0];
   auto epsilon = reinterpret_cast<T *>(inputs[6]->addr)[0];
-  auto decay = reinterpret_cast<T *>(inputs[7]->addr);
+  auto decay = reinterpret_cast<T *>(inputs[7]->addr)[0];
   auto gradient16 = reinterpret_cast<S *>(inputs[8]->addr);
-  const auto beta1_minus = 1 - beta1;
-  const auto beta2_minus = 1 - beta2;
+  float beta1_minus = 1 - beta1;
+  float beta2_minus = 1 - beta2;
   // multithreading
   size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
   std::function<void(size_t, size_t)> task;
@@ -77,7 +77,7 @@ void AdamWeightDecayCPUKernel::LaunchFusedAdam(const std::vector<AddressPtr> &in
       m[i] += (temp - m[i]) * beta1_minus;
       v[i] += (temp * temp - v[i]) * beta2_minus;
       T update = m[i] / (std::sqrt(v[i]) + epsilon);
-      update += *decay * var[i];
+      update += decay * var[i];
       var[i] -= lr * update;
     }
   };
@@ -94,10 +94,10 @@ void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPt
   auto beta1 = reinterpret_cast<T *>(inputs[4]->addr)[0];
   auto beta2 = reinterpret_cast<T *>(inputs[5]->addr)[0];
   auto epsilon = reinterpret_cast<T *>(inputs[6]->addr)[0];
-  auto decay = reinterpret_cast<T *>(inputs[7]->addr);
+  auto decay = reinterpret_cast<T *>(inputs[7]->addr)[0];
   auto gradient = reinterpret_cast<T *>(inputs[8]->addr);
-  auto beta1_minus = 1 - beta1;
-  auto beta2_minus = 1 - beta2;
+  const auto beta1_minus = 1 - beta1;
+  const auto beta2_minus = 1 - beta2;
   // multithreading
   size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
@@ -110,7 +110,7 @@ void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPt
      m[i] += (gradient[i] - m[i]) * beta1_minus;
      v[i] += (gradient[i] * gradient[i] - v[i]) * beta2_minus;
      T update = m[i] / (std::sqrt(v[i]) + epsilon);
-      update += decay[0] * var[i];
+      update += decay * var[i];
      var[i] -= lr * update;
    }
  };
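Note on the kernel change above: the decay coefficient is now read once as a scalar (`reinterpret_cast<T *>(inputs[7]->addr)[0]`) instead of being carried around as a raw pointer, so the update uses `decay * var[i]` rather than `*decay * var[i]`. For reference, a minimal C sketch of the per-element AdamWeightDecay update these loops perform (variable names follow the diff and are assumed to be in scope):

  // One element of the AdamWeightDecay update, mirroring the loop bodies above.
  m[i] += (gradient[i] - m[i]) * beta1_minus;                 // m = beta1*m + (1-beta1)*g
  v[i] += (gradient[i] * gradient[i] - v[i]) * beta2_minus;   // v = beta2*v + (1-beta2)*g*g
  float update = m[i] / (sqrtf(v[i]) + epsilon);              // Adam step
  update += decay * var[i];                                   // decoupled weight decay (scalar decay)
  var[i] -= lr * update;                                      // apply with learning rate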
@@ -16,6 +16,9 @@ endif()
 if("${X86_64_SIMD}" STREQUAL "avx")
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2 -mfma")
 endif()
+if("${X86_64_SIMD}" STREQUAL "avx512")
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2 -mfma")
+endif()
 if("${X86_64_SIMD}" STREQUAL "sse")
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1")
 endif()
@@ -52,6 +55,13 @@ if("${X86_64_SIMD}" STREQUAL "avx")
     set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
 endif()
 
+if("${X86_64_SIMD}" STREQUAL "avx512")
+    file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/intrinsics/sse/*.c
+            ${NNACL_DIR}/intrinsics/avx/*.c
+            ${NNACL_DIR}/assembly/avx/*.S)
+    set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
+endif()
+
 if(APPLE)
     set_source_files_properties(${ASSEMBLY_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp")
 endif()
@@ -74,7 +84,7 @@ if(ENABLE_CPU)
     elseif("${X86_64_SIMD}" STREQUAL "avx")
         target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX)
     elseif("${X86_64_SIMD}" STREQUAL "avx512")
-        target_compile_definitions(nnacl_mid PRIVATE ENABLE_AVX512)
+        target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX ENABLE_AVX512)
         target_compile_options(nnacl_mid PRIVATE -mavx512f)
     endif()
     target_compile_options(nnacl_mid PRIVATE -fPIC)
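These CMake branches only add the AVX-512 compiler flags (-mavx512f on the nnacl_mid target) and the ENABLE_SSE / ENABLE_AVX / ENABLE_AVX512 definitions; the C sources then select the vector path at compile time. A minimal sketch of how such a definition typically gates an intrinsic path with a scalar fallback (hypothetical helper, not nnacl code):

  #ifdef ENABLE_AVX512
  #include <immintrin.h>
  #endif
  #include <stddef.h>

  // Hypothetical example: scale a buffer, using AVX-512 only when the build defines ENABLE_AVX512.
  static void ScaleInPlace(float *dst, float scale, size_t len) {
    size_t i = 0;
  #ifdef ENABLE_AVX512
    for (; i + 16 <= len; i += 16) {  // 16 fp32 lanes per __m512 register
      __m512 v = _mm512_loadu_ps(dst + i);
      _mm512_storeu_ps(dst + i, _mm512_mul_ps(v, _mm512_set1_ps(scale)));
    }
  #endif
    for (; i < len; ++i) {
      dst[i] *= scale;  // scalar remainder / non-AVX-512 builds
    }
  }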
@@ -152,12 +152,12 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
   return NNACL_OK;
 }
 
-int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
+int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
                         const float *gradient, size_t start, size_t end) {
   size_t c1 = start;
 #ifdef ENABLE_AVX512
-  float beta1_minus = 1 - beta1;
-  float beta2_minus = 1 - beta2;
+  const float beta1_minus = 1 - beta1;
+  const float beta2_minus = 1 - beta2;
   struct AVX_Data beta1_r, beta2_r, beta1_minus_r, beta2_minus_r, lr_neg_r, epsilon_r, decay_r;
   beta1_r.data = _mm512_set1_ps(beta1);
   beta2_r.data = _mm512_set1_ps(beta2);
@@ -165,7 +165,7 @@ int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, f
   beta2_minus_r.data = _mm512_set1_ps(beta2_minus);
   lr_neg_r.data = _mm512_set1_ps(-lr);
   epsilon_r.data = _mm512_set1_ps(epsilon);
-  decay_r.data = _mm512_set1_ps(*decay);
+  decay_r.data = _mm512_set1_ps(decay);
   size_t c16 = ((end - start) / C16NUM) * C16NUM + start;
   size_t c64 = ((end - start) / C64NUM) * C64NUM + start;
@@ -175,11 +175,11 @@ int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, f
   float *v_ptr = v + start;
 
   for (; c1 < c64; c1 += C64NUM) {
-    struct AVX_Data g_r[4], var_r[4], m_r[4], v_r[4];
-    LoadStep4(g_r, gradient_ptr);
-    LoadStep4(var_r, var_ptr);
-    LoadStep4(m_r, m_ptr);
-    LoadStep4(v_r, v_ptr);
+    struct AVX_Data g_r[kUnrollSize], var_r[kUnrollSize], m_r[kUnrollSize], v_r[kUnrollSize];
+    LoadStep4(g_r, gradient_ptr, kUnrollSize);
+    LoadStep4(var_r, var_ptr, kUnrollSize);
+    LoadStep4(m_r, m_ptr, kUnrollSize);
+    LoadStep4(v_r, v_ptr, kUnrollSize);
 
     m_r[0].data = _mm512_mul_ps(m_r[0].data, beta1_r.data);
     m_r[1].data = _mm512_mul_ps(m_r[1].data, beta1_r.data);
@@ -221,9 +221,9 @@ int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, f
     var_r[2].data = _mm512_fmadd_ps(g_r[2].data, lr_neg_r.data, var_r[2].data);
     var_r[3].data = _mm512_fmadd_ps(g_r[3].data, lr_neg_r.data, var_r[3].data);
 
-    StoreStep4(var_ptr, var_r);
-    StoreStep4(m_ptr, m_r);
-    StoreStep4(v_ptr, v_r);
+    StoreStep4(var_ptr, var_r, kUnrollSize);
+    StoreStep4(m_ptr, m_r, kUnrollSize);
+    StoreStep4(v_ptr, v_r, kUnrollSize);
 
     gradient_ptr += C64NUM;
     var_ptr += C64NUM;
@@ -260,12 +260,12 @@ int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, f
   return c1;
 }
 
-int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
+int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
                   const int16_t *gradient16, size_t start, size_t end) {
   size_t c1 = start;
 #ifdef ENABLE_AVX512
-  float beta1_minus = 1 - beta1;
-  float beta2_minus = 1 - beta2;
+  const float beta1_minus = 1 - beta1;
+  const float beta2_minus = 1 - beta2;
   struct AVX_Data beta1_r, beta2_r, beta1_minus_r, beta2_minus_r, lr_neg_r, epsilon_r, decay_r;
   beta1_r.data = _mm512_set1_ps(beta1);
   beta2_r.data = _mm512_set1_ps(beta2);
@@ -273,7 +273,7 @@ int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float b
   beta2_minus_r.data = _mm512_set1_ps(beta2_minus);
   lr_neg_r.data = _mm512_set1_ps(-lr);
   epsilon_r.data = _mm512_set1_ps(epsilon);
-  decay_r.data = _mm512_set1_ps(*decay);
+  decay_r.data = _mm512_set1_ps(decay);
   size_t c16 = ((end - start) / C16NUM) * C16NUM + start;
   size_t c64 = ((end - start) / C64NUM) * C64NUM + start;
@@ -283,15 +283,15 @@ int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float b
   float *v_ptr = v + start;
 
   for (; c1 < c64; c1 += C64NUM) {
-    struct AVX_Data g_r[4], var_r[4], m_r[4], v_r[4];
+    struct AVX_Data g_r[kUnrollSize], var_r[kUnrollSize], m_r[kUnrollSize], v_r[kUnrollSize];
     g_r[0].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
     g_r[1].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr + C16NUM)));
     g_r[2].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr + C16NUM * 2)));
     g_r[3].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr + C16NUM * 3)));
 
-    LoadStep4(var_r, var_ptr);
-    LoadStep4(m_r, m_ptr);
-    LoadStep4(v_r, v_ptr);
+    LoadStep4(var_r, var_ptr, kUnrollSize);
+    LoadStep4(m_r, m_ptr, kUnrollSize);
+    LoadStep4(v_r, v_ptr, kUnrollSize);
 
     m_r[0].data = _mm512_mul_ps(m_r[0].data, beta1_r.data);
     m_r[1].data = _mm512_mul_ps(m_r[1].data, beta1_r.data);
@@ -333,9 +333,9 @@ int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float b
     var_r[2].data = _mm512_fmadd_ps(g_r[2].data, lr_neg_r.data, var_r[2].data);
     var_r[3].data = _mm512_fmadd_ps(g_r[3].data, lr_neg_r.data, var_r[3].data);
 
-    StoreStep4(var_ptr, var_r);
-    StoreStep4(m_ptr, m_r);
-    StoreStep4(v_ptr, v_r);
+    StoreStep4(var_ptr, var_r, kUnrollSize);
+    StoreStep4(m_ptr, m_r, kUnrollSize);
+    StoreStep4(v_ptr, v_r, kUnrollSize);
 
     gradient16_ptr += C64NUM;
     var_ptr += C64NUM;
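The main loop above consumes C64NUM (64) floats per iteration as four __m512 registers (kUnrollSize = 4); FusedAdamFp32 additionally widens fp16 gradients with _mm512_cvtph_ps before the same arithmetic. A condensed sketch of a single 16-lane step, reusing the broadcast registers set up in the diff (beta1_r, beta2_r, beta1_minus_r, beta2_minus_r, epsilon_r, decay_r, lr_neg_r); the unrolled code simply repeats this for indices 0..3:

  // One 16-float AdamWeightDecay step with AVX-512F intrinsics (illustrative fragment).
  __m512 g = _mm512_loadu_ps(gradient_ptr);
  __m512 m0 = _mm512_loadu_ps(m_ptr);
  __m512 v0 = _mm512_loadu_ps(v_ptr);
  __m512 w0 = _mm512_loadu_ps(var_ptr);
  m0 = _mm512_add_ps(_mm512_mul_ps(m0, beta1_r.data), _mm512_mul_ps(g, beta1_minus_r.data));
  v0 = _mm512_add_ps(_mm512_mul_ps(v0, beta2_r.data),
                     _mm512_mul_ps(_mm512_mul_ps(g, g), beta2_minus_r.data));
  __m512 update = _mm512_div_ps(m0, _mm512_add_ps(_mm512_sqrt_ps(v0), epsilon_r.data));
  update = _mm512_fmadd_ps(decay_r.data, w0, update);   // update += decay * var
  w0 = _mm512_fmadd_ps(update, lr_neg_r.data, w0);      // var -= lr * update
  _mm512_storeu_ps(var_ptr, w0);
  _mm512_storeu_ps(m_ptr, m0);
  _mm512_storeu_ps(v_ptr, v0);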
@@ -27,22 +27,37 @@
 #include <immintrin.h>
 #endif
 #ifdef ENABLE_AVX512
+const size_t kUnrollSize = 4;
 struct AVX_Data {
   __m512 data;
 };
 
-static inline void LoadStep4(struct AVX_Data *inp0, const float *inp1) {
+static inline int LoadStep4(struct AVX_Data *inp0, const float *inp1, const size_t arrLen) {
+  if (arrLen != kUnrollSize) {
+    return NNACL_ERR;
+  }
+  if (inp0 == NULL || inp1 == NULL) {
+    return NNACL_NULL_PTR;
+  }
   inp0[0].data = _mm512_loadu_ps(inp1);
   inp0[1].data = _mm512_loadu_ps(inp1 + C16NUM);
   inp0[2].data = _mm512_loadu_ps(inp1 + C16NUM * 2);
   inp0[3].data = _mm512_loadu_ps(inp1 + C16NUM * 3);
+  return NNACL_OK;
 }
 
-static inline void StoreStep4(float *inp0, struct AVX_Data *inp1) {
+static inline int StoreStep4(float *inp0, struct AVX_Data *inp1, const size_t arrLen) {
+  if (arrLen != kUnrollSize) {
+    return NNACL_ERR;
+  }
+  if (inp0 == NULL || inp1 == NULL) {
+    return NNACL_NULL_PTR;
+  }
   _mm512_storeu_ps(inp0, inp1[0].data);
   _mm512_storeu_ps(inp0 + C16NUM, inp1[1].data);
   _mm512_storeu_ps(inp0 + C16NUM * 2, inp1[2].data);
   _mm512_storeu_ps(inp0 + C16NUM * 3, inp1[3].data);
+  return NNACL_OK;
 }
 #endif
 #ifdef __cplusplus
@@ -52,9 +67,9 @@ int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2,
              size_t start, size_t end, bool use_nesterov);
 int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
                   const float *gradient, size_t start, size_t end, bool use_nesterov);
-int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
+int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
                         const float *gradient, size_t start, size_t end);
-int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
+int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
                   const int16_t *gradient16, size_t start, size_t end);
 #ifdef __cplusplus
 }
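With the added arrLen parameter, LoadStep4 and StoreStep4 now report NNACL_ERR / NNACL_NULL_PTR on misuse instead of silently assuming four valid registers. A short hypothetical caller showing the intended check (the kernels in this diff call the helpers with kUnrollSize and do not inspect the return value):

  // Hypothetical usage of the updated helpers; src_ptr and dst_ptr are assumed fp32 buffers of at least 64 floats.
  struct AVX_Data buf[kUnrollSize];
  if (LoadStep4(buf, src_ptr, kUnrollSize) != NNACL_OK) {
    return NNACL_ERR;  // wrong unroll length or NULL pointer
  }
  // ... operate on buf[0..3].data ...
  if (StoreStep4(dst_ptr, buf, kUnrollSize) != NNACL_OK) {
    return NNACL_ERR;
  }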
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""AdamWeightDecay, a customized Adam for pangu1. Input: gradient."""
+"""AdamWeightDecay, a customized Adam for pangu alpha. Input: gradient."""
 import numpy as np
 
 from mindspore.common import dtype as mstype
@@ -35,15 +35,14 @@ def _update_run_kernel(opt, beta1, beta2, eps, lr, weight_decay, param, m, v, gr
     """
     Update parameters by AdamWeightDecay op.
     """
+    success = True
     if optim_filter:
-        op_cast = P.Cast()
-        gradient_fp32 = op_cast(gradient, mstype.float32)
         if decay_flags:
-            next_param = opt(param, m, v, lr, beta1, beta2, eps, F.cast(weight_decay, mstype.float32), gradient_fp32)
+            next_param = opt(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
         else:
-            next_param = opt(param, m, v, lr, beta1, beta2, eps, F.cast(0.0, mstype.float32), gradient_fp32)
-        return next_param
-    return gradient
+            next_param = opt(param, m, v, lr, beta1, beta2, eps, 0.0, gradient)
+        return F.depend(success, next_param)
+    return success
 
 
 def _check_param_value(beta1, beta2, eps, prim_name):
@@ -145,18 +144,17 @@ class AdamWeightDecayOp(Optimizer):
         lr = self.get_lr()
         if self.is_group:
             if self.is_group_lr:
-                optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps),
-                                         lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
-                                         gradients, self.decay_flags, self.optim_filter)
+                optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps),
+                                                lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
+                                                gradients, self.decay_flags, self.optim_filter)
             else:
-                optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr),
-                                         self.weight_decay, self.parameters, self.moments1, self.moments2,
-                                         gradients, self.decay_flags, self.optim_filter)
+                optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr),
+                                                self.weight_decay, self.parameters, self.moments1, self.moments2,
+                                                gradients, self.decay_flags, self.optim_filter)
         else:
-            optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr,
-                                               self.weight_decay),
-                                     self.parameters, self.moments1, self.moments2, gradients, self.decay_flags,
-                                     self.optim_filter)
+            optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr,
+                                                      self.weight_decay), self.parameters, self.moments1, self.moments2,
+                                            gradients, self.decay_flags, self.optim_filter)
         if self.use_parallel:
             self.broadcast_params(optim_result)
         return optim_result
@@ -326,6 +326,10 @@ def add_training_params(opt):
                      default="adam",
                      choices=["adam", "lamb"],
                      help="select which optimizer to be used, default adam")
+    opt.add_argument("--opt_offload",
+                     type=int,
+                     default=0,
+                     help="Enable optimizer status offload to host CPU, default is 0")
     opt.add_argument("--eod_id",
                      type=int,
                      default=6,
@@ -30,6 +30,7 @@ import mindspore.common.dtype as mstype
 from mindspore.parallel import set_algo_parameters
 from mindspore.parallel._cost_model_context import _set_multi_subgraphs
 from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
+from src.adam import AdamWeightDecayOp
 from src.dataset import create_dataset
 from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
 from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
@@ -76,13 +77,31 @@ project_root = os.path.abspath(
 print('project_root:', project_root)
 
 
+def set_weight_decay(params):
+    """
+    Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest
+    """
+    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
+    decay_params = list(filter(decay_filter, params))
+    other_params = list(filter(lambda x: not decay_filter(x), params))
+    group_params = [{
+        'params': decay_params,
+        'weight_decay': 1e-1
+    }, {
+        'params': other_params,
+        'weight_decay': 0.0
+    }, {
+        'order_params': params
+    }]
+    return group_params
+
+
 def run_train(args_opt):
     r"""
     The main training process.
     """
     # Set execution mode
-    context.set_context(mode=context.GRAPH_MODE,
-                        device_target=args_opt.device_target)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
     context.set_context(variable_memory_max_size="30GB")
     # Set parallel context
     if args_opt.distribute == "true":
@@ -149,22 +168,12 @@ def run_train(args_opt):
                       warmup_steps=args_opt.warmup_step,
                       decay_steps=200000)
 
-    # Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest
-    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
     params = pangu_alpha.trainable_params()
-    decay_params = list(filter(decay_filter, params))
-    other_params = list(filter(lambda x: not decay_filter(x), params))
-    group_params = [{
-        'params': decay_params,
-        'weight_decay': 1e-1
-    }, {
-        'params': other_params,
-        'weight_decay': 0.0
-    }, {
-        'order_params': params
-    }]
+    group_params = set_weight_decay(params)
     if args_opt.optimizer == "lamb":
         optimizer = nn.Lamb(group_params, learning_rate=lr)
+    elif args_opt.opt_offload:
+        optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
     else:
         optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
     # Initial scaling sens
@@ -201,6 +210,7 @@ def run_train(args_opt):
     print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True)
     model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size, dataset_sink_mode=True)
 
+
 def run_train_pipeline(args_opt):
     r"""
     The main training process in pipeline.
@@ -263,20 +273,11 @@ def run_train_pipeline(args_opt):
     lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
                       warmup_steps=args_opt.warmup_step, decay_steps=args_opt.decay_steps)
     params = pangu_alpha.infer_param_pipeline_stage()
-    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
-    decay_params = list(filter(decay_filter, params))
-    other_params = list(filter(lambda x: not decay_filter(x), params))
-    group_params = [{
-        'params': decay_params,
-        'weight_decay': 1e-1
-    }, {
-        'params': other_params,
-        'weight_decay': 0.0
-    }, {
-        'order_params': params
-    }]
+    group_params = set_weight_decay(params)
     if args_opt.optimizer == "lamb":
         optimizer = nn.Lamb(group_params, learning_rate=lr)
+    elif args_opt.opt_offload:
+        optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
     else:
         optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
 
@@ -297,6 +298,7 @@ def run_train_pipeline(args_opt):
     model.train(actual_epoch_num, ds, callbacks=callback,
                 sink_size=callback_size, dataset_sink_mode=True)
 
+
 if __name__ == "__main__":
     opt = get_args()
     set_parse(opt)
@@ -22,7 +22,7 @@ from mindspore import Tensor
 from mindspore.nn import Dense
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops import operations as P
-from model_zoo.official.nlp.gpt.src.adam import AdamWeightDecayOp
+from model_zoo.official.nlp.pangu_alpha.src.adam import AdamWeightDecayOp
 
 context.set_context(mode=context.GRAPH_MODE, device_target="CPU")