support fp32 gradient for fusedcastadam
This commit is contained in:
parent 40d8645b5e
commit 378820c531
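For orientation: every kernel touched below applies the same per-element AdamWeightDecay step to a gradient scaled by the reciprocal of the global norm, as seen in the scalar fallback loops in the diff. The commit adds a float32 gradient path next to the existing float16 one, dispatching on gradient_dtype_ to the new FusedCastAdamFp32Fp16/Fp32Fp32 and FusedCastAdamFp16Fp16/Fp16Fp32 helpers. A minimal NumPy sketch of that update, written purely as an illustrative reference (the function name, the min_global_norm argument, and the array handling are assumptions of this sketch, not code from the commit):

import numpy as np

def fused_cast_adam_reference(var, m, v, grad, lr, beta1, beta2, eps, decay,
                              global_norm, min_global_norm):
    # Illustrative per-element math of the fused kernels (hypothetical helper).
    # var may originally be float16; the kernels do the arithmetic in float32
    # and cast the result back, which this sketch mirrors with astype().
    if global_norm < min_global_norm:      # corresponds to the kMinGlobalNorm guard
        global_norm = 1.0
    scale = 1.0 / global_norm
    g = grad.astype(np.float32) * scale    # grad may be float16 or float32
    var32 = var.astype(np.float32)
    m += (g - m) * (1 - beta1)
    v += (g * g - v) * (1 - beta2)
    update = m / (np.sqrt(v) + eps) + decay * var32
    var32 -= lr * update
    return var32.astype(var.dtype), m, v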
@@ -50,7 +50,6 @@ void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp32(const std::ve
auto beta2 = reinterpret_cast<float *>(inputs[kBeta2Index]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<float *>(inputs[kEpsIndex]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[kDecayIndex]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<float16 *>(inputs[kGradIndex]->addr);
auto var = reinterpret_cast<float *>(inputs[kVarIndex]->addr);
auto global_norm = reinterpret_cast<float *>(inputs[kGlobalNormIndex]->addr)[kScalarIndex];
if (global_norm < kMinGlobalNorm) {
@@ -64,19 +63,36 @@ void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp32(const std::ve
size_t lens = inputs[kVarIndex]->size > 0 ? static_cast<size_t>(inputs[kVarIndex]->size / kSizeFloat32) : 1;
std::function<void(size_t, size_t)> task;

task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp32(var, m, v, lr, beta1, beta2, epsilon, decay, reinterpret_cast<int16_t *>(gradient16),
global_norm_reciprocal, start, end);
// remaining
for (; i < end; i++) {
auto temp = static_cast<float>(gradient16[i]) * global_norm_reciprocal;
m[i] += (temp - m[i]) * beta1_minus;
v[i] += (temp * temp - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * var[i];
var[i] -= lr * update;
}
};
if (gradient_dtype_ == kNumberTypeFloat16) {
float16 *gradient16 = reinterpret_cast<float16 *>(inputs[kGradIndex]->addr);
task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp32Fp16(var, reinterpret_cast<int16_t *>(gradient16), m, v, lr, beta1, beta2, epsilon,
decay, global_norm_reciprocal, start, end);
for (; i < end; ++i) {
auto temp = static_cast<float>(gradient16[i]) * global_norm_reciprocal;
m[i] += (temp - m[i]) * beta1_minus;
v[i] += (temp * temp - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * var[i];
var[i] -= lr * update;
}
};
} else {
float *gradient32 = reinterpret_cast<float *>(inputs[kGradIndex]->addr);
task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp32Fp32(var, gradient32, m, v, lr, beta1, beta2, epsilon, decay, global_norm_reciprocal,
start, end);
for (; i < end; ++i) {
auto temp = gradient32[i] * global_norm_reciprocal;
m[i] += (temp - m[i]) * beta1_minus;
v[i] += (temp * temp - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * var[i];
var[i] -= lr * update;
}
};
}

CPUKernelUtils::ParallelFor(task, lens, kBatchSize);
}
@@ -89,12 +105,12 @@ void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp16(const std::ve
auto beta2 = reinterpret_cast<float *>(inputs[kBeta2Index]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<float *>(inputs[kEpsIndex]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[kDecayIndex]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<float16 *>(inputs[kGradIndex]->addr);
auto var16 = reinterpret_cast<float16 *>(inputs[kVarIndex]->addr);
auto global_norm = reinterpret_cast<float *>(inputs[kGlobalNormIndex]->addr)[kScalarIndex];
if (global_norm < kMinGlobalNorm) {
global_norm = 1.0f;
}

auto global_norm_reciprocal = 1.0f / global_norm;
const auto beta1_minus = 1 - beta1;
const auto beta2_minus = 1 - beta2;
@@ -103,21 +119,40 @@ void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp16(const std::ve
size_t lens = inputs[kVarIndex]->size > 0 ? static_cast<size_t>(inputs[kVarIndex]->size / kSizeFloat16) : 1;
std::function<void(size_t, size_t)> task;

task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp16(reinterpret_cast<int16_t *>(var16), m, v, lr, beta1, beta2, epsilon, decay,
reinterpret_cast<int16_t *>(gradient16), global_norm_reciprocal, start, end);
// remaining
for (; i < end; i++) {
auto temp_var = static_cast<float>(var16[i]);
auto temp_grad = static_cast<float>(gradient16[i]) * global_norm_reciprocal;
m[i] += (temp_grad - m[i]) * beta1_minus;
v[i] += (temp_grad * temp_grad - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * temp_var;
temp_var -= lr * update;
var16[i] = static_cast<float16>(temp_var);
}
};
if (gradient_dtype_ == kNumberTypeFloat16) {
float16 *gradient16 = reinterpret_cast<float16 *>(inputs[kGradIndex]->addr);
task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp16Fp16(reinterpret_cast<int16_t *>(var16), reinterpret_cast<int16_t *>(gradient16), m,
v, lr, beta1, beta2, epsilon, decay, global_norm_reciprocal, start, end);
for (; i < end; i++) {
auto temp_var = static_cast<float>(var16[i]);
auto temp_grad = static_cast<float>(gradient16[i]) * global_norm_reciprocal;
m[i] += (temp_grad - m[i]) * beta1_minus;
v[i] += (temp_grad * temp_grad - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * temp_var;
temp_var -= lr * update;
var16[i] = static_cast<float16>(temp_var);
}
};
} else {
float *gradient32 = reinterpret_cast<float *>(inputs[kGradIndex]->addr);
task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp16Fp32(reinterpret_cast<int16_t *>(var16), gradient32, m, v, lr, beta1, beta2, epsilon,
decay, global_norm_reciprocal, start, end);
for (; i < end; i++) {
auto temp_var = static_cast<float>(var16[i]);
auto temp_grad = gradient32[i] * global_norm_reciprocal;
m[i] += (temp_grad - m[i]) * beta1_minus;
v[i] += (temp_grad * temp_grad - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * temp_var;
temp_var -= lr * update;
var16[i] = static_cast<float16>(temp_var);
}
};
}

CPUKernelUtils::ParallelFor(task, lens, kBatchSize);
}
@@ -147,8 +182,9 @@ void FusedCastAdamWeightDecayCpuKernelMod::InitKernel(const CNodePtr &kernel_nod
if (elem_num_ < 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'var' can not be zero.";
}
if (gradient_dtype_ != kNumberTypeFloat16) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'gradient' must be float16, but got "

if (gradient_dtype_ != kNumberTypeFloat32 && gradient_dtype_ != kNumberTypeFloat16) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'gradient' must be float16 or float32, but got "
<< TypeIdToType(gradient_dtype_)->ToString();
}
if (var_dtype_ != kNumberTypeFloat32 && var_dtype_ != kNumberTypeFloat16) {
@@ -170,6 +206,7 @@ void FusedCastAdamWeightDecayCpuKernelMod::CheckParam(const std::vector<kernel::
size_t elem_size_fp32 = elem_num_ * kSizeFloat32;
size_t elem_size_fp16 = elem_num_ * kSizeFloat16;
size_t var_size = var_dtype_ == kNumberTypeFloat16 ? elem_size_fp16 : elem_size_fp32;
size_t grad_size = gradient_dtype_ == kNumberTypeFloat16 ? elem_size_fp16 : elem_size_fp32;
if (inputs[kVarIndex]->size != var_size) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'var' must be " << var_size << ", but got "
<< inputs[kVarIndex]->size;
@@ -182,8 +219,8 @@ void FusedCastAdamWeightDecayCpuKernelMod::CheckParam(const std::vector<kernel::
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'v' must be " << elem_size_fp32
<< ", but got " << inputs[kVIndex]->size;
}
if (inputs[kGradIndex]->size != elem_size_fp16) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'gradient' must be " << elem_size_fp16
if (inputs[kGradIndex]->size != grad_size) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'gradient' must be " << grad_size
<< ", but got " << inputs[kGradIndex]->size;
}
if (inputs[kLRIndex]->size != kSizeFloat32) {
@@ -47,6 +47,34 @@ class FusedCastAdamWeightDecayCpuKernelMod : public DeprecatedNativeCpuKernelMod
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
@@ -172,20 +172,40 @@ int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, f
return NNACL_OK;
}

size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, float global_norm_reciprocal, size_t start, size_t end) {
size_t FusedCastAdamFp32Fp16(float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end) {
size_t c1 = start;

SIMD_RUN_AVX512(FusedCastAdamFp32, c1, var, m, v, lr, beta1, beta2, epsilon, decay, gradient16,
SIMD_RUN_AVX512(FusedCastAdamFp32Fp16, c1, var, gradient16, m, v, lr, beta1, beta2, epsilon, decay,
global_norm_reciprocal, end);
return c1;
}

size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const int16_t *gradient16, float global_norm_reciprocal, size_t start,
size_t end) {
size_t FusedCastAdamFp32Fp32(float *var, const float *gradient32, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end) {
size_t c1 = start;
SIMD_RUN_AVX512(FusedCastAdamFp16, c1, var16, m, v, lr, beta1, beta2, epsilon, decay, gradient16,

SIMD_RUN_AVX512(FusedCastAdamFp32Fp32, c1, var, gradient32, m, v, lr, beta1, beta2, epsilon, decay,
global_norm_reciprocal, end);
return c1;
}

size_t FusedCastAdamFp16Fp16(int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end) {
size_t c1 = start;
SIMD_RUN_AVX512(FusedCastAdamFp16Fp16, c1, var16, gradient16, m, v, lr, beta1, beta2, epsilon, decay,
global_norm_reciprocal, end);
return c1;
}

size_t FusedCastAdamFp16Fp32(int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end) {
size_t c1 = start;
SIMD_RUN_AVX512(FusedCastAdamFp16Fp32, c1, var16, gradient32, m, v, lr, beta1, beta2, epsilon, decay,
global_norm_reciprocal, end);
return c1;
}
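The wrappers above only advance c1 over whole SIMD blocks; whatever remains is finished element by element back in the CPU kernel lambdas (the "// remaining" loops earlier in the diff). A small Python sketch of that split, with hypothetical function names, purely to illustrate the pattern:

def blocked_then_tail(process_block, process_scalar, start, end, block_num):
    # process_block, process_scalar, and block_num stand in for the SIMD_RUN_AVX512
    # body, the scalar update, and BLOCK_NUM; none of these names come from the commit.
    i = start
    while i + block_num <= end:   # full vector blocks, mirroring the BLOCK_NUM stepping
        process_block(i)
        i += block_num
    for j in range(i, end):       # scalar tail, like the C++ "remaining" loop
        process_scalar(j)
    return end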
@@ -28,11 +28,18 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
const float *gradient, size_t start, size_t end, bool use_nesterov);
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const float *gradient, size_t start, size_t end);
size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, float global_norm_reciprocal, size_t start, size_t end);
size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const int16_t *gradient16, float global_norm_reciprocal, size_t start,
size_t end);
size_t FusedCastAdamFp32Fp16(float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end);
size_t FusedCastAdamFp32Fp32(float *var, const float *gradient32, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end);
size_t FusedCastAdamFp16Fp16(int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end);
size_t FusedCastAdamFp16Fp32(int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1,
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
size_t end);
#ifdef __cplusplus
}
#endif
@@ -23,10 +23,9 @@
extern "C" {
#endif
@SIMD_INSTRUCTION_BEGIN@

#ifdef MS_SIMD_AVX512
static inline size_t AdamWeightDecayFp32@SIMD_INSTRUCTION@(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const float *gradient, size_t end) {
static inline size_t AdamWeightDecayFp32@SIMD_INSTRUCTION@(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const float *gradient, size_t end) {
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
@@ -58,8 +57,8 @@ static inline size_t AdamWeightDecayFp32@SIMD_INSTRUCTION@(size_t index, float *
return index;
}

static inline size_t FusedCastAdamFp32@SIMD_INSTRUCTION@(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, float global_norm_reciprocal, size_t end) {
static inline size_t FusedCastAdamFp32Fp16@SIMD_INSTRUCTION@(size_t index, float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
float global_norm_reciprocal, size_t end) {
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
@@ -93,8 +92,43 @@ static inline size_t FusedCastAdamFp32@SIMD_INSTRUCTION@(size_t index, float *va
return index;
}

static inline size_t FusedCastAdamFp16@SIMD_INSTRUCTION@(size_t index, int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const int16_t *gradient16, float global_norm_reciprocal, size_t end) {
static inline size_t FusedCastAdamFp32Fp32@SIMD_INSTRUCTION@(size_t index, float *var, const float *gradient32, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
float global_norm_reciprocal, size_t end) {
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2);
SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr);
SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon);
SIMD_F32 decay_r = SIMD_MOV_F32(decay);
SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal);

for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_F32 var_r = SIMD_LD_F32(var + index);
SIMD_F32 m_r = SIMD_LD_F32(m + index);
SIMD_F32 v_r = SIMD_LD_F32(v + index);
SIMD_F32 g_r = SIMD_LD_F32(gradient32 + index);

g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r);
m_r = SIMD_MUL_F32(m_r, beta1_r);
v_r = SIMD_MUL_F32(v_r, beta2_r);
SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r);
m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r);
v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r);
avx_r0 = SIMD_SQRT_F32(v_r);
avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r));
avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0);
var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r);
SIMD_ST_F32(var + index, var_r);
SIMD_ST_F32(m + index, m_r);
SIMD_ST_F32(v + index, v_r);
}

return index;
}

static inline size_t FusedCastAdamFp16Fp16@SIMD_INSTRUCTION@(size_t index, int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
float global_norm_reciprocal, size_t end) {
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
@@ -126,6 +160,40 @@ static inline size_t FusedCastAdamFp16@SIMD_INSTRUCTION@(size_t index, int16_t *

return index;
}

static inline size_t FusedCastAdamFp16Fp32@SIMD_INSTRUCTION@(size_t index, int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
float global_norm_reciprocal, size_t end) {
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2);
SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr);
SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon);
SIMD_F32 decay_r = SIMD_MOV_F32(decay);
SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal);

for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_F32 var_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(var16 + index));
SIMD_F32 m_r = SIMD_LD_F32(m + index);
SIMD_F32 v_r = SIMD_LD_F32(v + index);
SIMD_F32 g_r = SIMD_LD_F32(gradient32 + index);
g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r);
m_r = SIMD_MUL_F32(m_r, beta1_r);
v_r = SIMD_MUL_F32(v_r, beta2_r);
SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r);
m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r);
v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r);
avx_r0 = SIMD_SQRT_F32(v_r);
avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r));
avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0);
var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r);
SIMD_ST_F32(m + index, m_r);
SIMD_ST_F32(v + index, v_r);
SIMD_ST_HALF_EPI32(var16 + index, SIMD_F32_TO_F16(var_r, 0));
}

return index;
}
#endif

@SIMD_INSTRUCTION_END@
@@ -493,7 +493,7 @@ class FusedCastAdamWeightDecay(PrimitiveWithInfer):
args = {"m": m_dtype, "v": v_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same({"var": var_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16], self.name)
validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16, mstype.float32], self.name)

args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
"decay": decay_dtype}
@@ -0,0 +1,180 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops.composite.clip_ops import get_square_sum


class LeNet(nn.Cell):
    """
    Implements lenet.
    """

    def __init__(self):
        super(LeNet, self).__init__()
        self.relu = P.ReLU()
        self.batch_size = 1
        weight1 = Tensor(np.ones([6, 3, 5, 5]).astype(np.float32) * 0.01)
        weight2 = Tensor(np.ones([16, 6, 5, 5]).astype(np.float16) * 0.01)
        self.conv1 = nn.Conv2d(3, 6, (5, 5), weight_init=weight1, stride=1, padding=0, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, (5, 5), weight_init=weight2, pad_mode='valid', stride=1, padding=0)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="valid")

        self.reshape = P.Reshape()
        self.reshape1 = P.Reshape()

        self.fc1 = nn.Dense(400, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)

    def construct(self, input_x):
        output = self.conv1(input_x)
        output = self.relu(output)
        output = self.pool(output)
        output = P.Cast()(output, mstype.float16)
        output = self.conv2(output)
        output = P.Cast()(output, mstype.float32)
        output = self.relu(output)
        output = self.pool(output)
        output = self.reshape(output, (self.batch_size, -1))
        output = self.fc1(output)
        output = self.fc2(output)
        output = self.fc3(output)
        return output


_adam_opt = C.MultitypeFuncGraph("adam_opt")


@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Bool", "Bool")
def _fused_update_with_global_norm(opt, global_norm, beta1, beta2, eps, lr, weight_decay,
                                   param, m, v, gradient, decay_flags, optim_filter):
    """
    Update parameters by FusedAdamWeightDecay.
    """
    success = True
    if optim_filter:
        if decay_flags:
            next_param = opt(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient, global_norm)
        else:
            next_param = opt(param, m, v, lr, beta1, beta2, eps, 0.0, gradient, global_norm)
        return F.depend(success, next_param)
    return success


def clone_state(parameter_tuple, prefix, init):
    new = []
    for old_param in parameter_tuple:
        new_state = Parameter(initializer(init, shape=old_param.shape, dtype=mstype.float32))
        new_state.param_info = old_param.param_info.clone()
        new_state.is_init = False
        new_state.name = prefix + '.' + new_state.name
        new.append(new_state)
    return ParameterTuple(new)


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")


@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
    return grad * clip_norm / global_norm


class GlobalNorm(nn.Cell):
    """
    Calculate the global norm value of given tensors
    """

    def __init__(self):
        super(GlobalNorm, self).__init__()
        self.norm = nn.Norm()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        """Calculate global norm construct"""
        square_sum = self.hyper_map(get_square_sum, grads)
        global_norms = F.sqrt(F.addn(square_sum))
        return global_norms


class FusedAdamWeightDecayWithGlobalNorm(Optimizer):
    """
    Implements the gradient clipping by global norm for a AdamWeightDecay optimizer.
    """

    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(FusedAdamWeightDecayWithGlobalNorm, self).__init__(learning_rate, params, weight_decay)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = clone_state(self._parameters, prefix="adam_m", init='zeros')
        self.moments2 = clone_state(self._parameters, prefix="adam_v", init='zeros')
        self.norm = GlobalNorm()
        self.opt = P.FusedCastAdamWeightDecay()
        self.opt.add_prim_attr("primitive_target", "CPU")

    def construct(self, gradients):
        """construct with gradients"""
        global_norm = self.norm(gradients)
        lr = self.get_lr()
        optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, global_norm,
                                                  self.beta1, self.beta2, self.eps, lr, self.weight_decay),
                                        self._parameters, self.moments1, self.moments2,
                                        gradients, self.decay_flags, self.optim_filter)
        return optim_result


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fused_cast_adam_weight_decay():
    '''
    Feature: FusedCastAdamWeightDecay
    Description: Test FusedCastAdamWeightDecay
    Expectation: Run lenet success
    '''
    context.set_context(mode=context.GRAPH_MODE)
    data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([32]).astype(np.int32))
    net = LeNet()
    net.batch_size = 32
    learning_rate = 0.01
    optimizer = FusedAdamWeightDecayWithGlobalNorm(filter(lambda x: x.requires_grad, net.get_parameters()),
                                                   learning_rate)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)
    train_network.set_train()
    loss = []
    for _ in range(10):
        res = train_network(data, label)
        loss.append(res.asnumpy())
    assert np.all(loss[-1] < 0.1)