zsd: add AVX512-accelerated AdamWeightDecay CPU op

This commit is contained in:
zhaosida 2021-07-01 17:20:54 +08:00
parent 20a5e30481
commit 9a42eda0c8
8 changed files with 399 additions and 148 deletions

View File

@ -66,18 +66,6 @@ if(ENABLE_CPU)
if(PLATFORM_ARM64)
add_compile_definitions(ENABLE_ARM)
endif()
if("${ARM_SIMD}" STREQUAL "neon")
set(CPU_SIMD_SRC "${CMAKE_CURRENT_SOURCE_DIR}/cpu/adam_weight_decay_cpu_kernel.cc")
add_compile_definitions(ENABLE_NEON)
set_property(SOURCE ${CPU_SIMD_SRC} PROPERTY COMPILE_OPTIONS -O3 -ffast-math)
endif()
if("${X86_64_SIMD}" STREQUAL "avx")
set(CPU_SIMD_SRC "${CMAKE_CURRENT_SOURCE_DIR}/cpu/adam_weight_decay_cpu_kernel.cc")
add_compile_definitions(ENABLE_AVX512)
set_property(SOURCE ${CPU_SIMD_SRC} PROPERTY COMPILE_OPTIONS -O3 -fopenmp -mavx512f -ffast-math)
endif()
endif()
if(NOT ENABLE_CPU OR WIN32)

View File

@ -18,75 +18,94 @@
#include <cmath>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "nnacl/fp32/adam_fp32.h"
#include "utils/ms_utils.h"
#include "common/thread_pool.h"
namespace mindspore {
namespace kernel {
template <typename T>
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(T *var, T *m, T *v, float lr, float beta1, float beta2,
float epsilon, T *decay, const T *gradient, size_t size) {
constexpr size_t kSizeFloat16 = sizeof(float16);
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kAdamWeightDecayInputSize = 9;
constexpr size_t kAdamWeightDecayOutputSize = 3;
// Split [0, count) into SIMD-aligned chunks and run `task` on the thread pool.
//
// When the total work is smaller than 128 elements per available thread, only
// ceil(count / 128) workers are used; otherwise all pool threads are used.
// Chunk boundaries are rounded up to multiples of 16 floats so the vectorized
// task bodies operate on aligned spans.
// NOTE(review): assumes GetSyncRunThreadNum() >= 1 — a zero thread count would
// divide by zero below; confirm the pool guarantees this.
void AdamWeightDecayCPUKernel::ParallelForAdam(const CTask &task, size_t count) {
  const float kBlockSize = 128.0;
  const float kAlignSize = 16.0;
  auto pool_threads = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
  // Spawn all pool threads only when each of them gets at least one full block.
  size_t used_threads = count < kBlockSize * pool_threads ? std::ceil(count / kBlockSize) : pool_threads;
  // Per-task span, rounded up to a multiple of the SIMD alignment.
  size_t chunk = kAlignSize * std::ceil(count / (kAlignSize * used_threads));
  std::vector<common::Task> tasks;
  for (size_t offset = 0; offset < count; offset += chunk) {
    size_t upper = (offset + chunk) > count ? count : (offset + chunk);
    tasks.emplace_back([&task, offset, upper]() {
      task(offset, upper);
      return common::SUCCESS;
    });
  }
  common::ThreadPool::GetInstance().SyncRun(tasks);
}
template <typename T, typename S>
void AdamWeightDecayCPUKernel::LaunchFusedAdam(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto var = reinterpret_cast<T *>(inputs[0]->addr);
auto m = reinterpret_cast<T *>(inputs[1]->addr);
auto v = reinterpret_cast<T *>(inputs[2]->addr);
auto lr = reinterpret_cast<T *>(inputs[3]->addr)[0];
auto beta1 = reinterpret_cast<T *>(inputs[4]->addr)[0];
auto beta2 = reinterpret_cast<T *>(inputs[5]->addr)[0];
auto epsilon = reinterpret_cast<T *>(inputs[6]->addr)[0];
auto decay = reinterpret_cast<T *>(inputs[7]->addr);
auto gradient16 = reinterpret_cast<S *>(inputs[8]->addr);
float beta1_minus = 1 - beta1;
float beta2_minus = 1 - beta2;
#if defined(ENABLE_AVX512)
MS_FLOAT32X16 beta1_16 = MS_MOV512_F32(beta1);
MS_FLOAT32X16 beta2_16 = MS_MOV512_F32(beta2);
MS_FLOAT32X16 beta1_minus_16 = MS_MOV512_F32(beta1_minus);
MS_FLOAT32X16 beta2_minus_16 = MS_MOV512_F32(beta2_minus);
MS_FLOAT32X16 lr_neg_16 = MS_MOV512_F32(-lr);
MS_FLOAT32X16 epsilon_16 = MS_MOV512_F32(epsilon);
MS_FLOAT32X16 decay_16 = MS_MOV512_F32(*decay);
#endif
#if defined(ENABLE_NEON)
MS_FLOAT32X4 epsilon_4 = MS_MOVQ_F32(epsilon);
float lr_neg = -lr;
#endif
// multithreading
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
std::function<void(size_t, size_t)> task;
auto task = [&](size_t start, size_t end) {
size_t i = start;
#if defined(ENABLE_AVX512)
if (end >= MS_AVX512_WIDTH) {
for (; i <= end - MS_AVX512_WIDTH; i += MS_AVX512_WIDTH) {
MS_FLOAT32X16 var_16 = MS_LD512_F32(var + i);
MS_FLOAT32X16 m_16 = MS_LD512_F32(m + i);
MS_FLOAT32X16 v_16 = MS_LD512_F32(v + i);
MS_FLOAT32X16 g_16 = MS_LD512_F32(gradient + i);
m_16 = MS_MUL512_F32(m_16, beta1_16);
m_16 = MS_FMA512_F32(g_16, beta1_minus_16, m_16);
v_16 = MS_MUL512_F32(v_16, beta2_16);
v_16 = MS_MUL512_F32(g_16, g_16);
v_16 = MS_FMA512_F32(g_16, beta2_minus_16, v_16);
g_16 = MS_SQRT512_F32(v_16);
g_16 = MS_DIV512_F32(m_16, MS_ADD512_F32(g_16, epsilon_16));
g_16 = MS_FMA512_F32(var_16, decay_16, g_16);
var_16 = MS_FMA512_F32(g_16, lr_neg_16, var_16);
MS_ST512_F32(var + i, var_16);
MS_ST512_F32(m + i, m_16);
MS_ST512_F32(v + i, v_16);
}
task = [&](size_t start, size_t end) {
size_t i =
FusedAdamFp32(var, m, v, lr, beta1, beta2, epsilon, decay, reinterpret_cast<int16_t *>(gradient16), start, end);
// remaining
for (; i < end; i++) {
auto temp = static_cast<float>(gradient16[i]);
m[i] += (temp - m[i]) * beta1_minus;
v[i] += (temp * temp - v[i]) * beta2_minus;
T update = m[i] / (std::sqrt(v[i]) + epsilon);
update += *decay * var[i];
var[i] -= lr * update;
}
#endif
#if defined(ENABLE_NEON)
if (end >= MS_NEON_WIDTH) {
for (; i <= end - MS_NEON_WIDTH; i += MS_NEON_WIDTH) {
MS_FLOAT32X4 var_4 = MS_LDQ_F32(var + i);
MS_FLOAT32X4 m_4 = MS_LDQ_F32(m + i);
MS_FLOAT32X4 v_4 = MS_LDQ_F32(v + i);
MS_FLOAT32X4 g_4 = MS_LDQ_F32(gradient + i);
m_4 = MS_MULQ_N_F32(m_4, beta1);
m_4 = MS_MLAQ_N_F32(m_4, g_4, beta1_minus);
v_4 = MS_MULQ_N_F32(v_4, beta2);
g_4 = MS_MULQ_F32(g_4, g_4);
v_4 = MS_MLAQ_N_F32(v_4, g_4, beta2_minus);
g_4 = MS_SQRT_F32(v_4);
g_4 = MS_DIVQ_F32(m_4, MS_ADDQ_F32(g_4, epsilon_4));
g_4 = MS_MLAQ_N_F32(g_4, var_4, *decay);
var_4 = MS_MLAQ_N_F32(var_4, g_4, lr_neg);
MS_STQ_F32(var + i, var_4);
MS_STQ_F32(m + i, m_4);
MS_STQ_F32(v + i, v_4);
}
}
#endif
};
ParallelForAdam(task, lens);
}
// AdamWeightDecay update for float32 gradients.
//
// All nine inputs are float32. The vectorized nnacl kernel
// AdamWeightDecayFp32 handles the aligned bulk of each chunk and returns the
// first index it did not cover; the scalar tail loop finishes the rest.
//
// NOTE(review): the diff hunk marker truncated two lines of the tail loop;
// they are reconstructed here from the identical fp16 variant above.
template <typename T>
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPtr> &inputs,
                                                     const std::vector<AddressPtr> &outputs) {
  auto var = reinterpret_cast<T *>(inputs[0]->addr);
  auto m = reinterpret_cast<T *>(inputs[1]->addr);
  auto v = reinterpret_cast<T *>(inputs[2]->addr);
  auto lr = reinterpret_cast<T *>(inputs[3]->addr)[0];
  auto beta1 = reinterpret_cast<T *>(inputs[4]->addr)[0];
  auto beta2 = reinterpret_cast<T *>(inputs[5]->addr)[0];
  auto epsilon = reinterpret_cast<T *>(inputs[6]->addr)[0];
  auto decay = reinterpret_cast<T *>(inputs[7]->addr);
  auto gradient = reinterpret_cast<T *>(inputs[8]->addr);
  auto beta1_minus = 1 - beta1;
  auto beta2_minus = 1 - beta2;
  // Treat an empty tensor as one element so the task body still runs once.
  size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
  std::function<void(size_t, size_t)> task;
  task = [&](size_t start, size_t end) {
    // SIMD path returns the first unprocessed index; finish with scalar code.
    size_t i = AdamWeightDecayFp32(var, m, v, lr, beta1, beta2, epsilon, decay, gradient, start, end);
    for (; i < end; i++) {
      m[i] += (gradient[i] - m[i]) * beta1_minus;
      v[i] += (gradient[i] * gradient[i] - v[i]) * beta2_minus;
      T update = m[i] / (std::sqrt(v[i]) + epsilon);
      update += *decay * var[i];
      var[i] -= lr * update;
    }
  };
  ParallelForAdam(task, lens);
}
// Validate node arity and dtypes, and cache the element count and the
// gradient dtype used by Launch() to pick the fp16/fp32 path.
//
// NOTE(review): the diff view left the pre-commit literal checks
// (`!= 9`, `!= 3`) next to the new named-constant checks; only the
// named-constant form is kept here.
void AdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  std::vector<size_t> var_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  gradient_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 8);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != kAdamWeightDecayInputSize) {
    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but AdamWeightDecay needs 9 inputs.";
  }
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  if (output_num != kAdamWeightDecayOutputSize) {
    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AdamWeightDecay needs 3 outputs.";
  }
  // Total parameter element count = product of the var shape dims.
  elem_num_ = 1;
  for (size_t i : var_shape) {
    elem_num_ *= i;
  }
  if (elem_num_ < 1) {
    MS_LOG(EXCEPTION) << "Invalid parameter shape";
  }
  if (dtype_ != kNumberTypeFloat32) {
    MS_LOG(EXCEPTION) << "The dtype of parameter must be float32!";
  }
  if (gradient_dtype_ != kNumberTypeFloat32 && gradient_dtype_ != kNumberTypeFloat16) {
    MS_LOG(EXCEPTION) << "The dtype of gradient must be float32 or float16!";
  }
}
// Verify per-launch buffer counts and byte sizes against the shape/dtype
// information cached by InitKernel.
void AdamWeightDecayCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &inputs,
                                          const std::vector<kernel::AddressPtr> &outputs) {
  if (inputs.size() != kAdamWeightDecayInputSize) {
    MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs 9 inputs.";
  }
  if (outputs.size() != kAdamWeightDecayOutputSize) {
    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs 3 outputs.";
  }
  // var/m/v are float32; the gradient buffer may be float16 (half the bytes).
  size_t param_bytes = elem_num_ * kSizeFloat32;
  size_t grad_bytes = gradient_dtype_ == kNumberTypeFloat16 ? elem_num_ * kSizeFloat16 : param_bytes;
  bool tensor_sizes_ok = inputs[0]->size == param_bytes && inputs[1]->size == param_bytes &&
                         inputs[2]->size == param_bytes && inputs[8]->size == grad_bytes;
  if (!tensor_sizes_ok) {
    MS_LOG(EXCEPTION) << "Error input data size!";
  }
  // Inputs 3..7 (lr, beta1, beta2, epsilon, decay) are single float scalars.
  for (size_t idx = 3; idx <= 7; ++idx) {
    if (inputs[idx]->size != kSizeFloat32) {
      MS_LOG(EXCEPTION) << "The attribute beta, lr, epsilon and weight decay must be float!";
    }
  }
}
// Entry point: validate the buffers, then dispatch on the gradient dtype
// established in InitKernel — fp16 gradients take the fused half-precision
// path, fp32 gradients take the plain path.
//
// NOTE(review): the diff view left the whole pre-commit body (manual size
// checks and the old pointer-based LaunchAdamWeightDecay call) interleaved
// after the new dispatch; only the new implementation is kept here.
bool AdamWeightDecayCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> &,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  CheckParam(inputs, outputs);
  if (gradient_dtype_ == kNumberTypeFloat16) {
    LaunchFusedAdam<float, float16>(inputs, outputs);
  } else if (gradient_dtype_ == kNumberTypeFloat32) {
    LaunchAdamWeightDecay<float>(inputs, outputs);
  }
  return true;
}
} // namespace kernel

View File

@ -21,59 +21,26 @@
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#if defined(ENABLE_AVX512)
#include <x86intrin.h>
#endif
#ifdef ENABLE_NEON
#define MS_FLOAT32X4 float32x4_t
#define MS_LDQ_F32 vld1q_f32
#define MS_MOVQ_F32 vmovq_n_f32
#define MS_STQ_F32 vst1q_f32
#define MS_ADDQ_F32(src1, src2) vaddq_f32(src1, src2)
#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2)
#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2)
#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
#define MS_MLAQ_N_F32(src1, src2, src3) vmlaq_n_f32(src1, src2, src3)
#define MS_SQRT_F32(src) vsqrtq_f32(src)
#define MS_CAST_F32_F16(src) vreinterpretq_f32_f16(src)
#define MS_NEON_WIDTH 4
#endif
#if defined(ENABLE_AVX512)
#define MS_FLOAT32X16 __m512
#define MS_LD512_F32 _mm512_loadu_ps
#define MS_ST512_F32 _mm512_storeu_ps
#define MS_MOV512_F32 _mm512_set1_ps
#define MS_ADD512_F32(src1, src2) _mm512_add_ps(src1, src2)
#define MS_MUL512_F32(src1, src2) _mm512_mul_ps(src1, src2)
#define MS_DIV512_F32(src1, src2) _mm512_div_ps(src1, src2)
#define MS_FMA512_F32(src1, src2, src3) _mm512_fmadd_ps(src1, src2, src3)
#define MS_SQRT512_F32(src) _mm512_sqrt_ps(src)
#define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src)
#define MS_AVX512_WIDTH 16
#endif
namespace mindspore {
namespace kernel {
class AdamWeightDecayCPUKernel : public CPUKernel {
public:
AdamWeightDecayCPUKernel() = default;
~AdamWeightDecayCPUKernel() override = default;
template <typename T>
void LaunchAdamWeightDecay(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, T *decay,
const T *gradient, size_t size);
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
private:
void ParallelForAdam(const CTask &task, size_t count);
void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T, typename S>
void LaunchFusedAdam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
void LaunchAdamWeightDecay(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
size_t elem_num_{0};
TypeId dtype_{kTypeUnknown};
TypeId gradient_dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(AdamWeightDecay,
@ -91,6 +58,22 @@ MS_REG_CPU_KERNEL(AdamWeightDecay,
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
AdamWeightDecayCPUKernel)
// Registration for the float16-gradient variant: inputs 0-7 (var, m, v, lr,
// beta1, beta2, epsilon, decay) are float32, input 8 (gradient) is float16;
// all three outputs (var, m, v) are float32.
MS_REG_CPU_KERNEL(AdamWeightDecay,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
AdamWeightDecayCPUKernel)
} // namespace kernel
} // namespace mindspore

View File

@ -72,7 +72,8 @@ if(ENABLE_CPU)
elseif("${X86_64_SIMD}" STREQUAL "sse")
target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE)
elseif("${X86_64_SIMD}" STREQUAL "avx")
target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX)
target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX ENABLE_AVX512)
target_compile_options(nnacl_mid PRIVATE -mavx512f)
endif()
target_compile_options(nnacl_mid PRIVATE -fPIC)
if(NOT CMAKE_SYSTEM_NAME MATCHES "Windows")

View File

@ -25,6 +25,25 @@
#include "nnacl/fp32/exp_fp32.h"
#include "nnacl/fp32/adam_fp32.h"
#ifdef ENABLE_AVX512
// Thin wrapper around one 512-bit AVX-512 register (16 x float32).
struct AVX_Data {
__m512 data;
};
// Load four consecutive 16-float vectors (64 floats total) from inp1 into
// inp0[0..3]. Unaligned loads, so inp1 needs no particular alignment.
inline void LoadStep4(struct AVX_Data *inp0, const float *inp1) {
  for (int j = 0; j < 4; ++j) {
    inp0[j].data = _mm512_loadu_ps(inp1 + C16NUM * j);
  }
}
// Store four consecutive 16-float vectors (64 floats total) from inp1[0..3]
// to inp0. Unaligned stores, so inp0 needs no particular alignment.
inline void StoreStep4(float *inp0, struct AVX_Data *inp1) {
  for (int j = 0; j < 4; ++j) {
    _mm512_storeu_ps(inp0 + C16NUM * j, inp1[j].data);
  }
}
#endif
int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
size_t start, size_t end, bool use_nesterov) {
size_t c1 = start;
@ -159,3 +178,223 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
}
return NNACL_OK;
}
// Vectorized AdamWeightDecay step over [start, end):
//   m   = beta1 * m + (1 - beta1) * g
//   v   = beta2 * v + (1 - beta2) * g^2
//   var = var - lr * (m / (sqrt(v) + epsilon) + decay * var)
// The AVX-512 path consumes 64-element then 16-element chunks and returns the
// index of the first element left for the caller's scalar tail loop; without
// ENABLE_AVX512 it returns `start` unchanged.
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
                        const float *gradient, size_t start, size_t end) {
  size_t c1 = start;
#ifdef ENABLE_AVX512
  float beta1_minus = 1 - beta1;
  float beta2_minus = 1 - beta2;
  // Broadcast all scalar hyper-parameters once.
  struct AVX_Data beta1_r, beta2_r, beta1_minus_r, beta2_minus_r, lr_neg_r, epsilon_r, decay_r;
  beta1_r.data = _mm512_set1_ps(beta1);
  beta2_r.data = _mm512_set1_ps(beta2);
  beta1_minus_r.data = _mm512_set1_ps(beta1_minus);
  beta2_minus_r.data = _mm512_set1_ps(beta2_minus);
  lr_neg_r.data = _mm512_set1_ps(-lr);
  epsilon_r.data = _mm512_set1_ps(epsilon);
  decay_r.data = _mm512_set1_ps(*decay);
  // Largest offsets covered by whole 16- and 64-element chunks.
  size_t c16 = ((end - start) / C16NUM) * C16NUM + start;
  size_t c64 = ((end - start) / C64NUM) * C64NUM + start;
  const float *gradient_ptr = gradient + start;
  float *var_ptr = var + start;
  float *m_ptr = m + start;
  float *v_ptr = v + start;
  // Main loop: four independent 16-float lanes per iteration.
  for (; c1 < c64; c1 += C64NUM) {
    struct AVX_Data g_r[4], var_r[4], m_r[4], v_r[4];
    LoadStep4(g_r, gradient_ptr);
    LoadStep4(var_r, var_ptr);
    LoadStep4(m_r, m_ptr);
    LoadStep4(v_r, v_ptr);
    for (int j = 0; j < 4; ++j) {
      // m = beta1 * m + (1 - beta1) * g
      m_r[j].data = _mm512_mul_ps(m_r[j].data, beta1_r.data);
      m_r[j].data = _mm512_fmadd_ps(g_r[j].data, beta1_minus_r.data, m_r[j].data);
      // v = beta2 * v + (1 - beta2) * g^2 (g_r is reused as scratch)
      v_r[j].data = _mm512_mul_ps(v_r[j].data, beta2_r.data);
      g_r[j].data = _mm512_mul_ps(g_r[j].data, g_r[j].data);
      v_r[j].data = _mm512_fmadd_ps(g_r[j].data, beta2_minus_r.data, v_r[j].data);
      // var += -lr * (m / (sqrt(v) + eps) + decay * var)
      g_r[j].data = _mm512_sqrt_ps(v_r[j].data);
      g_r[j].data = _mm512_div_ps(m_r[j].data, _mm512_add_ps(g_r[j].data, epsilon_r.data));
      g_r[j].data = _mm512_fmadd_ps(var_r[j].data, decay_r.data, g_r[j].data);
      var_r[j].data = _mm512_fmadd_ps(g_r[j].data, lr_neg_r.data, var_r[j].data);
    }
    StoreStep4(var_ptr, var_r);
    StoreStep4(m_ptr, m_r);
    StoreStep4(v_ptr, v_r);
    gradient_ptr += C64NUM;
    var_ptr += C64NUM;
    m_ptr += C64NUM;
    v_ptr += C64NUM;
  }
  // Cleanup loop: one 16-float vector at a time.
  for (; c1 < c16; c1 += C16NUM) {
    struct AVX_Data g_r, var_r, m_r, v_r, scratch;
    g_r.data = _mm512_loadu_ps(gradient_ptr);
    var_r.data = _mm512_loadu_ps(var_ptr);
    m_r.data = _mm512_loadu_ps(m_ptr);
    v_r.data = _mm512_loadu_ps(v_ptr);
    scratch.data = _mm512_mul_ps(g_r.data, g_r.data);
    m_r.data = _mm512_mul_ps(m_r.data, beta1_r.data);
    m_r.data = _mm512_fmadd_ps(g_r.data, beta1_minus_r.data, m_r.data);
    v_r.data = _mm512_mul_ps(v_r.data, beta2_r.data);
    v_r.data = _mm512_fmadd_ps(scratch.data, beta2_minus_r.data, v_r.data);
    scratch.data = _mm512_sqrt_ps(v_r.data);
    scratch.data = _mm512_div_ps(m_r.data, _mm512_add_ps(scratch.data, epsilon_r.data));
    scratch.data = _mm512_fmadd_ps(var_r.data, decay_r.data, scratch.data);
    var_r.data = _mm512_fmadd_ps(scratch.data, lr_neg_r.data, var_r.data);
    _mm512_storeu_ps(var_ptr, var_r.data);
    _mm512_storeu_ps(m_ptr, m_r.data);
    _mm512_storeu_ps(v_ptr, v_r.data);
    gradient_ptr += C16NUM;
    var_ptr += C16NUM;
    m_ptr += C16NUM;
    v_ptr += C16NUM;
  }
#endif
  return c1;
}
// Fused AdamWeightDecay step for float16 gradients over [start, end).
// Identical math to AdamWeightDecayFp32, except each gradient vector is
// loaded as 16 half-precision values and widened to float32 with
// _mm512_cvtph_ps before use. Returns the first index the AVX-512 path did
// not process (== start when ENABLE_AVX512 is off).
int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
                  const int16_t *gradient16, size_t start, size_t end) {
  size_t c1 = start;
#ifdef ENABLE_AVX512
  float beta1_minus = 1 - beta1;
  float beta2_minus = 1 - beta2;
  // Broadcast all scalar hyper-parameters once.
  struct AVX_Data beta1_r, beta2_r, beta1_minus_r, beta2_minus_r, lr_neg_r, epsilon_r, decay_r;
  beta1_r.data = _mm512_set1_ps(beta1);
  beta2_r.data = _mm512_set1_ps(beta2);
  beta1_minus_r.data = _mm512_set1_ps(beta1_minus);
  beta2_minus_r.data = _mm512_set1_ps(beta2_minus);
  lr_neg_r.data = _mm512_set1_ps(-lr);
  epsilon_r.data = _mm512_set1_ps(epsilon);
  decay_r.data = _mm512_set1_ps(*decay);
  // Largest offsets covered by whole 16- and 64-element chunks.
  size_t c16 = ((end - start) / C16NUM) * C16NUM + start;
  size_t c64 = ((end - start) / C64NUM) * C64NUM + start;
  const int16_t *gradient16_ptr = gradient16 + start;
  float *var_ptr = var + start;
  float *m_ptr = m + start;
  float *v_ptr = v + start;
  // Main loop: four independent 16-float lanes per iteration.
  for (; c1 < c64; c1 += C64NUM) {
    struct AVX_Data g_r[4], var_r[4], m_r[4], v_r[4];
    // Widen four vectors of fp16 gradients to fp32.
    for (int j = 0; j < 4; ++j) {
      g_r[j].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr + C16NUM * j)));
    }
    LoadStep4(var_r, var_ptr);
    LoadStep4(m_r, m_ptr);
    LoadStep4(v_r, v_ptr);
    for (int j = 0; j < 4; ++j) {
      // m = beta1 * m + (1 - beta1) * g
      m_r[j].data = _mm512_mul_ps(m_r[j].data, beta1_r.data);
      m_r[j].data = _mm512_fmadd_ps(g_r[j].data, beta1_minus_r.data, m_r[j].data);
      // v = beta2 * v + (1 - beta2) * g^2 (g_r is reused as scratch)
      v_r[j].data = _mm512_mul_ps(v_r[j].data, beta2_r.data);
      g_r[j].data = _mm512_mul_ps(g_r[j].data, g_r[j].data);
      v_r[j].data = _mm512_fmadd_ps(g_r[j].data, beta2_minus_r.data, v_r[j].data);
      // var += -lr * (m / (sqrt(v) + eps) + decay * var)
      g_r[j].data = _mm512_sqrt_ps(v_r[j].data);
      g_r[j].data = _mm512_div_ps(m_r[j].data, _mm512_add_ps(g_r[j].data, epsilon_r.data));
      g_r[j].data = _mm512_fmadd_ps(var_r[j].data, decay_r.data, g_r[j].data);
      var_r[j].data = _mm512_fmadd_ps(g_r[j].data, lr_neg_r.data, var_r[j].data);
    }
    StoreStep4(var_ptr, var_r);
    StoreStep4(m_ptr, m_r);
    StoreStep4(v_ptr, v_r);
    gradient16_ptr += C64NUM;
    var_ptr += C64NUM;
    m_ptr += C64NUM;
    v_ptr += C64NUM;
  }
  // Cleanup loop: one 16-float vector at a time.
  for (; c1 < c16; c1 += C16NUM) {
    struct AVX_Data g_r, var_r, m_r, v_r, scratch;
    g_r.data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
    var_r.data = _mm512_loadu_ps(var_ptr);
    m_r.data = _mm512_loadu_ps(m_ptr);
    v_r.data = _mm512_loadu_ps(v_ptr);
    scratch.data = _mm512_mul_ps(g_r.data, g_r.data);
    m_r.data = _mm512_mul_ps(m_r.data, beta1_r.data);
    m_r.data = _mm512_fmadd_ps(g_r.data, beta1_minus_r.data, m_r.data);
    v_r.data = _mm512_mul_ps(v_r.data, beta2_r.data);
    v_r.data = _mm512_fmadd_ps(scratch.data, beta2_minus_r.data, v_r.data);
    scratch.data = _mm512_sqrt_ps(v_r.data);
    scratch.data = _mm512_div_ps(m_r.data, _mm512_add_ps(scratch.data, epsilon_r.data));
    scratch.data = _mm512_fmadd_ps(var_r.data, decay_r.data, scratch.data);
    var_r.data = _mm512_fmadd_ps(scratch.data, lr_neg_r.data, var_r.data);
    _mm512_storeu_ps(var_ptr, var_r.data);
    _mm512_storeu_ps(m_ptr, m_r.data);
    _mm512_storeu_ps(v_ptr, v_r.data);
    gradient16_ptr += C16NUM;
    var_ptr += C16NUM;
    m_ptr += C16NUM;
    v_ptr += C16NUM;
  }
#endif
  return c1;
}

View File

@ -26,6 +26,10 @@ int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2,
size_t start, size_t end, bool use_nesterov);
int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
const float *gradient, size_t start, size_t end, bool use_nesterov);
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
const float *gradient, size_t start, size_t end);
int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
const int16_t *gradient16, size_t start, size_t end);
#ifdef __cplusplus
}
#endif

View File

@ -34,6 +34,7 @@
#define C16NUM 16
#define C24NUM 24
#define C32NUM 32
#define C64NUM 64
#define TILE_NUM 8
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))

View File

@ -533,8 +533,9 @@ class AdamWeightDecay(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
epsilon_dtype, decay, grad_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16, mstype.float32], self.name)
args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
"decay": decay}