support fp32 gradient for FusedCastAdamWeightDecay

kswang 2022-07-27 09:37:31 +08:00
parent 40d8645b5e
commit 378820c531
12 changed files with 394 additions and 54 deletions

View File

@@ -50,7 +50,6 @@ void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp32(const std::ve
auto beta2 = reinterpret_cast<float *>(inputs[kBeta2Index]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<float *>(inputs[kEpsIndex]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[kDecayIndex]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<float16 *>(inputs[kGradIndex]->addr);
auto var = reinterpret_cast<float *>(inputs[kVarIndex]->addr);
auto global_norm = reinterpret_cast<float *>(inputs[kGlobalNormIndex]->addr)[kScalarIndex];
if (global_norm < kMinGlobalNorm) {
@@ -64,19 +63,36 @@ void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp32(const std::ve
size_t lens = inputs[kVarIndex]->size > 0 ? static_cast<size_t>(inputs[kVarIndex]->size / kSizeFloat32) : 1;
std::function<void(size_t, size_t)> task;
task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp32(var, m, v, lr, beta1, beta2, epsilon, decay, reinterpret_cast<int16_t *>(gradient16),
global_norm_reciprocal, start, end);
// remaining
for (; i < end; i++) {
auto temp = static_cast<float>(gradient16[i]) * global_norm_reciprocal;
m[i] += (temp - m[i]) * beta1_minus;
v[i] += (temp * temp - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * var[i];
var[i] -= lr * update;
}
};
if (gradient_dtype_ == kNumberTypeFloat16) {
float16 *gradient16 = reinterpret_cast<float16 *>(inputs[kGradIndex]->addr);
task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp32Fp16(var, reinterpret_cast<int16_t *>(gradient16), m, v, lr, beta1, beta2, epsilon,
decay, global_norm_reciprocal, start, end);
for (; i < end; ++i) {
auto temp = static_cast<float>(gradient16[i]) * global_norm_reciprocal;
m[i] += (temp - m[i]) * beta1_minus;
v[i] += (temp * temp - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * var[i];
var[i] -= lr * update;
}
};
} else {
float *gradient32 = reinterpret_cast<float *>(inputs[kGradIndex]->addr);
task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp32Fp32(var, gradient32, m, v, lr, beta1, beta2, epsilon, decay, global_norm_reciprocal,
start, end);
for (; i < end; ++i) {
auto temp = gradient32[i] * global_norm_reciprocal;
m[i] += (temp - m[i]) * beta1_minus;
v[i] += (temp * temp - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * var[i];
var[i] -= lr * update;
}
};
}
CPUKernelUtils::ParallelFor(task, lens, kBatchSize);
}
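
Both branches above compute the same per-element AdamWeightDecay step; they differ only in how the gradient is loaded (float16 widened to float32 versus float32 read directly) before being scaled by the global-norm reciprocal. A minimal NumPy sketch of that update, illustrative only and not part of this commit (the name adam_weight_decay_reference is hypothetical):

    import numpy as np

    def adam_weight_decay_reference(var, m, v, grad, lr, beta1, beta2, epsilon, decay, global_norm_reciprocal):
        # Mirrors the scalar tail loops: scale the gradient, update the first and
        # second moments, add decoupled weight decay, then take the step in place.
        g = grad.astype(np.float32) * global_norm_reciprocal
        m += (g - m) * (1.0 - beta1)
        v += (g * g - v) * (1.0 - beta2)
        update = m / (np.sqrt(v) + epsilon) + decay * var
        var -= lr * update
        return var, m, v

Passing a float16 grad array reproduces the float16-gradient branch, since the kernel widens the gradient to float32 before using it.
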
@@ -89,12 +105,12 @@ void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp16(const std::ve
auto beta2 = reinterpret_cast<float *>(inputs[kBeta2Index]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<float *>(inputs[kEpsIndex]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[kDecayIndex]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<float16 *>(inputs[kGradIndex]->addr);
auto var16 = reinterpret_cast<float16 *>(inputs[kVarIndex]->addr);
auto global_norm = reinterpret_cast<float *>(inputs[kGlobalNormIndex]->addr)[kScalarIndex];
if (global_norm < kMinGlobalNorm) {
global_norm = 1.0f;
}
auto global_norm_reciprocal = 1.0f / global_norm;
const auto beta1_minus = 1 - beta1;
const auto beta2_minus = 1 - beta2;
@@ -103,21 +119,40 @@ void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp16(const std::ve
size_t lens = inputs[kVarIndex]->size > 0 ? static_cast<size_t>(inputs[kVarIndex]->size / kSizeFloat16) : 1;
std::function<void(size_t, size_t)> task;
task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp16(reinterpret_cast<int16_t *>(var16), m, v, lr, beta1, beta2, epsilon, decay,
reinterpret_cast<int16_t *>(gradient16), global_norm_reciprocal, start, end);
// remaining
for (; i < end; i++) {
auto temp_var = static_cast<float>(var16[i]);
auto temp_grad = static_cast<float>(gradient16[i]) * global_norm_reciprocal;
m[i] += (temp_grad - m[i]) * beta1_minus;
v[i] += (temp_grad * temp_grad - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * temp_var;
temp_var -= lr * update;
var16[i] = static_cast<float16>(temp_var);
}
};
if (gradient_dtype_ == kNumberTypeFloat16) {
float16 *gradient16 = reinterpret_cast<float16 *>(inputs[kGradIndex]->addr);
task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp16Fp16(reinterpret_cast<int16_t *>(var16), reinterpret_cast<int16_t *>(gradient16), m,
v, lr, beta1, beta2, epsilon, decay, global_norm_reciprocal, start, end);
for (; i < end; i++) {
auto temp_var = static_cast<float>(var16[i]);
auto temp_grad = static_cast<float>(gradient16[i]) * global_norm_reciprocal;
m[i] += (temp_grad - m[i]) * beta1_minus;
v[i] += (temp_grad * temp_grad - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * temp_var;
temp_var -= lr * update;
var16[i] = static_cast<float16>(temp_var);
}
};
} else {
float *gradient32 = reinterpret_cast<float *>(inputs[kGradIndex]->addr);
task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp16Fp32(reinterpret_cast<int16_t *>(var16), gradient32, m, v, lr, beta1, beta2, epsilon,
decay, global_norm_reciprocal, start, end);
for (; i < end; i++) {
auto temp_var = static_cast<float>(var16[i]);
auto temp_grad = gradient32[i] * global_norm_reciprocal;
m[i] += (temp_grad - m[i]) * beta1_minus;
v[i] += (temp_grad * temp_grad - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * temp_var;
temp_var -= lr * update;
var16[i] = static_cast<float16>(temp_var);
}
};
}
CPUKernelUtils::ParallelFor(task, lens, kBatchSize);
}
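
The float16-parameter launch above performs the same arithmetic; the only difference is that var is widened to float32 for the update and rounded back to float16 on store. In terms of the sketch above (same hypothetical names, illustrative only):

    # float16 parameters: compute the update in float32, round back per element.
    var_f32 = var16.astype(np.float32)
    var_f32, m, v = adam_weight_decay_reference(var_f32, m, v, grad, lr, beta1,
                                                beta2, epsilon, decay, global_norm_reciprocal)
    var16 = var_f32.astype(np.float16)
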
@@ -147,8 +182,9 @@ void FusedCastAdamWeightDecayCpuKernelMod::InitKernel(const CNodePtr &kernel_nod
if (elem_num_ < 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'var' can not be zero.";
}
if (gradient_dtype_ != kNumberTypeFloat16) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'gradient' must be float16, but got "
if (gradient_dtype_ != kNumberTypeFloat16 && gradient_dtype_ != kNumberTypeFloat32) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'gradient' must be float16 or float32, but got "
<< TypeIdToType(gradient_dtype_)->ToString();
}
if (var_dtype_ != kNumberTypeFloat32 && var_dtype_ != kNumberTypeFloat16) {
@@ -170,6 +206,7 @@ void FusedCastAdamWeightDecayCpuKernelMod::CheckParam(const std::vector<kernel::
size_t elem_size_fp32 = elem_num_ * kSizeFloat32;
size_t elem_size_fp16 = elem_num_ * kSizeFloat16;
size_t var_size = var_dtype_ == kNumberTypeFloat16 ? elem_size_fp16 : elem_size_fp32;
size_t grad_size = gradient_dtype_ == kNumberTypeFloat16 ? elem_size_fp16 : elem_size_fp32;
if (inputs[kVarIndex]->size != var_size) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'var' must be " << var_size << ", but got "
<< inputs[kVarIndex]->size;
@@ -182,8 +219,8 @@ void FusedCastAdamWeightDecayCpuKernelMod::CheckParam(const std::vector<kernel::
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'v' must be " << elem_size_fp32
<< ", but got " << inputs[kVIndex]->size;
}
if (inputs[kGradIndex]->size != elem_size_fp16) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'gradient' must be " << elem_size_fp16
if (inputs[kGradIndex]->size != grad_size) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'gradient' must be " << grad_size
<< ", but got " << inputs[kGradIndex]->size;
}
if (inputs[kLRIndex]->size != kSizeFloat32) {

View File

@@ -47,6 +47,34 @@ class FusedCastAdamWeightDecayCpuKernelMod : public DeprecatedNativeCpuKernelMod
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)

View File

@@ -172,20 +172,40 @@ int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, f
return NNACL_OK;
}
size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, float global_norm_reciprocal, size_t start, size_t end) {
size_t FusedCastAdamFp32Fp16(float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end) {
size_t c1 = start;
SIMD_RUN_AVX512(FusedCastAdamFp32, c1, var, m, v, lr, beta1, beta2, epsilon, decay, gradient16,
SIMD_RUN_AVX512(FusedCastAdamFp32Fp16, c1, var, gradient16, m, v, lr, beta1, beta2, epsilon, decay,
global_norm_reciprocal, end);
return c1;
}
size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const int16_t *gradient16, float global_norm_reciprocal, size_t start,
size_t end) {
size_t FusedCastAdamFp32Fp32(float *var, const float *gradient32, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end) {
size_t c1 = start;
SIMD_RUN_AVX512(FusedCastAdamFp16, c1, var16, m, v, lr, beta1, beta2, epsilon, decay, gradient16,
SIMD_RUN_AVX512(FusedCastAdamFp32Fp32, c1, var, gradient32, m, v, lr, beta1, beta2, epsilon, decay,
global_norm_reciprocal, end);
return c1;
}
size_t FusedCastAdamFp16Fp16(int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end) {
size_t c1 = start;
SIMD_RUN_AVX512(FusedCastAdamFp16Fp16, c1, var16, gradient16, m, v, lr, beta1, beta2, epsilon, decay,
global_norm_reciprocal, end);
return c1;
}
size_t FusedCastAdamFp16Fp32(int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end) {
size_t c1 = start;
SIMD_RUN_AVX512(FusedCastAdamFp16Fp32, c1, var16, gradient32, m, v, lr, beta1, beta2, epsilon, decay,
global_norm_reciprocal, end);
return c1;
}

View File

@@ -28,11 +28,18 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
const float *gradient, size_t start, size_t end, bool use_nesterov);
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const float *gradient, size_t start, size_t end);
size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, float global_norm_reciprocal, size_t start, size_t end);
size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const int16_t *gradient16, float global_norm_reciprocal, size_t start,
size_t end);
size_t FusedCastAdamFp32Fp16(float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end);
size_t FusedCastAdamFp32Fp32(float *var, const float *gradient32, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end);
size_t FusedCastAdamFp16Fp16(int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end);
size_t FusedCastAdamFp16Fp32(int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end);
#ifdef __cplusplus
}
#endif

View File

@@ -23,10 +23,9 @@
extern "C" {
#endif
@SIMD_INSTRUCTION_BEGIN@
#ifdef MS_SIMD_AVX512
static inline size_t AdamWeightDecayFp32@SIMD_INSTRUCTION@(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const float *gradient, size_t end) {
static inline size_t AdamWeightDecayFp32@SIMD_INSTRUCTION@(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const float *gradient, size_t end) {
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
@@ -58,8 +57,8 @@ static inline size_t AdamWeightDecayFp32@SIMD_INSTRUCTION@(size_t index, float *
return index;
}
static inline size_t FusedCastAdamFp32@SIMD_INSTRUCTION@(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, float global_norm_reciprocal, size_t end) {
static inline size_t FusedCastAdamFp32Fp16@SIMD_INSTRUCTION@(size_t index, float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
float global_norm_reciprocal, size_t end) {
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
@@ -93,8 +92,43 @@ static inline size_t FusedCastAdamFp32@SIMD_INSTRUCTION@(size_t index, float *va
return index;
}
static inline size_t FusedCastAdamFp16@SIMD_INSTRUCTION@(size_t index, int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const int16_t *gradient16, float global_norm_reciprocal, size_t end) {
static inline size_t FusedCastAdamFp32Fp32@SIMD_INSTRUCTION@(size_t index, float *var, const float *gradient32, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
float global_norm_reciprocal, size_t end) {
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2);
SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr);
SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon);
SIMD_F32 decay_r = SIMD_MOV_F32(decay);
SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal);
for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_F32 var_r = SIMD_LD_F32(var + index);
SIMD_F32 m_r = SIMD_LD_F32(m + index);
SIMD_F32 v_r = SIMD_LD_F32(v + index);
SIMD_F32 g_r = SIMD_LD_F32(gradient32 + index);
g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r);
m_r = SIMD_MUL_F32(m_r, beta1_r);
v_r = SIMD_MUL_F32(v_r, beta2_r);
SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r);
m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r);
v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r);
avx_r0 = SIMD_SQRT_F32(v_r);
avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r));
avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0);
var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r);
SIMD_ST_F32(var + index, var_r);
SIMD_ST_F32(m + index, m_r);
SIMD_ST_F32(v + index, v_r);
}
return index;
}
static inline size_t FusedCastAdamFp16Fp16@SIMD_INSTRUCTION@(size_t index, int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
float global_norm_reciprocal, size_t end) {
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
@@ -126,6 +160,40 @@ static inline size_t FusedCastAdamFp16@SIMD_INSTRUCTION@(size_t index, int16_t *
return index;
}
static inline size_t FusedCastAdamFp16Fp32@SIMD_INSTRUCTION@(size_t index, int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
float global_norm_reciprocal, size_t end) {
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2);
SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr);
SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon);
SIMD_F32 decay_r = SIMD_MOV_F32(decay);
SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal);
for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_F32 var_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(var16 + index));
SIMD_F32 m_r = SIMD_LD_F32(m + index);
SIMD_F32 v_r = SIMD_LD_F32(v + index);
SIMD_F32 g_r = SIMD_LD_F32(gradient32 + index);
g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r);
m_r = SIMD_MUL_F32(m_r, beta1_r);
v_r = SIMD_MUL_F32(v_r, beta2_r);
SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r);
m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r);
v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r);
avx_r0 = SIMD_SQRT_F32(v_r);
avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r));
avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0);
var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r);
SIMD_ST_F32(m + index, m_r);
SIMD_ST_F32(v + index, v_r);
SIMD_ST_HALF_EPI32(var16 + index, SIMD_F32_TO_F16(var_r, 0));
}
return index;
}
#endif
@SIMD_INSTRUCTION_END@

View File

@@ -493,7 +493,7 @@ class FusedCastAdamWeightDecay(PrimitiveWithInfer):
args = {"m": m_dtype, "v": v_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same({"var": var_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16], self.name)
validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16, mstype.float32], self.name)
args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
"decay": decay_dtype}

View File

@@ -0,0 +1,180 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops.composite.clip_ops import get_square_sum
class LeNet(nn.Cell):
"""
Implements LeNet.
"""
def __init__(self):
super(LeNet, self).__init__()
self.relu = P.ReLU()
self.batch_size = 1
weight1 = Tensor(np.ones([6, 3, 5, 5]).astype(np.float32) * 0.01)
weight2 = Tensor(np.ones([16, 6, 5, 5]).astype(np.float16) * 0.01)
self.conv1 = nn.Conv2d(3, 6, (5, 5), weight_init=weight1, stride=1, padding=0, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, (5, 5), weight_init=weight2, pad_mode='valid', stride=1, padding=0)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="valid")
self.reshape = P.Reshape()
self.reshape1 = P.Reshape()
self.fc1 = nn.Dense(400, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
def construct(self, input_x):
output = self.conv1(input_x)
output = self.relu(output)
output = self.pool(output)
output = P.Cast()(output, mstype.float16)
output = self.conv2(output)
output = P.Cast()(output, mstype.float32)
output = self.relu(output)
output = self.pool(output)
output = self.reshape(output, (self.batch_size, -1))
output = self.fc1(output)
output = self.fc2(output)
output = self.fc3(output)
return output
_adam_opt = C.MultitypeFuncGraph("adam_opt")
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Bool", "Bool")
def _fused_update_with_global_norm(opt, global_norm, beta1, beta2, eps, lr, weight_decay,
param, m, v, gradient, decay_flags, optim_filter):
"""
Update parameters by FusedAdamWeightDecay.
"""
success = True
if optim_filter:
if decay_flags:
next_param = opt(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient, global_norm)
else:
next_param = opt(param, m, v, lr, beta1, beta2, eps, 0.0, gradient, global_norm)
return F.depend(success, next_param)
return success
def clone_state(parameter_tuple, prefix, init):
new = []
for old_param in parameter_tuple:
new_state = Parameter(initializer(init, shape=old_param.shape, dtype=mstype.float32))
new_state.param_info = old_param.param_info.clone()
new_state.is_init = False
new_state.name = prefix + '.' + new_state.name
new.append(new_state)
return ParameterTuple(new)
apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
return grad * clip_norm / global_norm
class GlobalNorm(nn.Cell):
"""
Calculate the global norm value of given tensors
"""
def __init__(self):
super(GlobalNorm, self).__init__()
self.norm = nn.Norm()
self.hyper_map = C.HyperMap()
def construct(self, grads):
"""Calculate global norm construct"""
square_sum = self.hyper_map(get_square_sum, grads)
global_norms = F.sqrt(F.addn(square_sum))
return global_norms
class FusedAdamWeightDecayWithGlobalNorm(Optimizer):
"""
Implements gradient clipping by global norm for an AdamWeightDecay optimizer.
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
super(FusedAdamWeightDecayWithGlobalNorm, self).__init__(learning_rate, params, weight_decay)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.moments1 = clone_state(self._parameters, prefix="adam_m", init='zeros')
self.moments2 = clone_state(self._parameters, prefix="adam_v", init='zeros')
self.norm = GlobalNorm()
self.opt = P.FusedCastAdamWeightDecay()
self.opt.add_prim_attr("primitive_target", "CPU")
def construct(self, gradients):
"""construct with gradients"""
global_norm = self.norm(gradients)
lr = self.get_lr()
optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, global_norm,
self.beta1, self.beta2, self.eps, lr, self.weight_decay),
self._parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
return optim_result
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fused_cast_adam_weight_decay():
'''
Feature: FusedCastAdamWeightDecay
Description: Test FusedCastAdamWeightDecay
Expectation: Run lenet success
'''
context.set_context(mode=context.GRAPH_MODE)
data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32))
net = LeNet()
net.batch_size = 32
learning_rate = 0.01
optimizer = FusedAdamWeightDecayWithGlobalNorm(filter(lambda x: x.requires_grad, net.get_parameters()),
learning_rate)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer)
train_network.set_train()
loss = []
for _ in range(10):
res = train_network(data, label)
loss.append(res.asnumpy())
assert np.all(loss[-1] < 0.1)