zsd AdamWeightDecay CPU op avx512
This commit is contained in:
parent
20a5e30481
commit
9a42eda0c8
|
@ -66,18 +66,6 @@ if(ENABLE_CPU)
|
|||
if(PLATFORM_ARM64)
|
||||
add_compile_definitions(ENABLE_ARM)
|
||||
endif()
|
||||
|
||||
if("${ARM_SIMD}" STREQUAL "neon")
|
||||
set(CPU_SIMD_SRC "${CMAKE_CURRENT_SOURCE_DIR}/cpu/adam_weight_decay_cpu_kernel.cc")
|
||||
add_compile_definitions(ENABLE_NEON)
|
||||
set_property(SOURCE ${CPU_SIMD_SRC} PROPERTY COMPILE_OPTIONS -O3 -ffast-math)
|
||||
endif()
|
||||
|
||||
if("${X86_64_SIMD}" STREQUAL "avx")
|
||||
set(CPU_SIMD_SRC "${CMAKE_CURRENT_SOURCE_DIR}/cpu/adam_weight_decay_cpu_kernel.cc")
|
||||
add_compile_definitions(ENABLE_AVX512)
|
||||
set_property(SOURCE ${CPU_SIMD_SRC} PROPERTY COMPILE_OPTIONS -O3 -fopenmp -mavx512f -ffast-math)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(NOT ENABLE_CPU OR WIN32)
|
||||
|
|
|
@ -18,75 +18,94 @@
|
|||
#include <cmath>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "nnacl/fp32/adam_fp32.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "common/thread_pool.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(T *var, T *m, T *v, float lr, float beta1, float beta2,
|
||||
float epsilon, T *decay, const T *gradient, size_t size) {
|
||||
constexpr size_t kSizeFloat16 = sizeof(float16);
|
||||
constexpr size_t kSizeFloat32 = sizeof(float);
|
||||
constexpr size_t kAdamWeightDecayInputSize = 9;
|
||||
constexpr size_t kAdamWeightDecayOutputSize = 3;
|
||||
|
||||
void AdamWeightDecayCPUKernel::ParallelForAdam(const CTask &task, size_t count) {
|
||||
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
|
||||
const float block_size = 128.0;
|
||||
const float align_size = 16.0;
|
||||
size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num;
|
||||
std::vector<common::Task> tasks;
|
||||
size_t start = 0;
|
||||
size_t once_compute_size = align_size * std::ceil(count / (align_size * thread_num));
|
||||
while (start < count) {
|
||||
size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size);
|
||||
auto block = [&, start, end]() {
|
||||
task(start, end);
|
||||
return common::SUCCESS;
|
||||
};
|
||||
tasks.emplace_back(block);
|
||||
start += once_compute_size;
|
||||
}
|
||||
common::ThreadPool::GetInstance().SyncRun(tasks);
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void AdamWeightDecayCPUKernel::LaunchFusedAdam(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto var = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto m = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto v = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
auto lr = reinterpret_cast<T *>(inputs[3]->addr)[0];
|
||||
auto beta1 = reinterpret_cast<T *>(inputs[4]->addr)[0];
|
||||
auto beta2 = reinterpret_cast<T *>(inputs[5]->addr)[0];
|
||||
auto epsilon = reinterpret_cast<T *>(inputs[6]->addr)[0];
|
||||
auto decay = reinterpret_cast<T *>(inputs[7]->addr);
|
||||
auto gradient16 = reinterpret_cast<S *>(inputs[8]->addr);
|
||||
|
||||
float beta1_minus = 1 - beta1;
|
||||
float beta2_minus = 1 - beta2;
|
||||
#if defined(ENABLE_AVX512)
|
||||
MS_FLOAT32X16 beta1_16 = MS_MOV512_F32(beta1);
|
||||
MS_FLOAT32X16 beta2_16 = MS_MOV512_F32(beta2);
|
||||
MS_FLOAT32X16 beta1_minus_16 = MS_MOV512_F32(beta1_minus);
|
||||
MS_FLOAT32X16 beta2_minus_16 = MS_MOV512_F32(beta2_minus);
|
||||
MS_FLOAT32X16 lr_neg_16 = MS_MOV512_F32(-lr);
|
||||
MS_FLOAT32X16 epsilon_16 = MS_MOV512_F32(epsilon);
|
||||
MS_FLOAT32X16 decay_16 = MS_MOV512_F32(*decay);
|
||||
#endif
|
||||
#if defined(ENABLE_NEON)
|
||||
MS_FLOAT32X4 epsilon_4 = MS_MOVQ_F32(epsilon);
|
||||
float lr_neg = -lr;
|
||||
#endif
|
||||
// multithreading
|
||||
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
|
||||
std::function<void(size_t, size_t)> task;
|
||||
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
size_t i = start;
|
||||
#if defined(ENABLE_AVX512)
|
||||
if (end >= MS_AVX512_WIDTH) {
|
||||
for (; i <= end - MS_AVX512_WIDTH; i += MS_AVX512_WIDTH) {
|
||||
MS_FLOAT32X16 var_16 = MS_LD512_F32(var + i);
|
||||
MS_FLOAT32X16 m_16 = MS_LD512_F32(m + i);
|
||||
MS_FLOAT32X16 v_16 = MS_LD512_F32(v + i);
|
||||
MS_FLOAT32X16 g_16 = MS_LD512_F32(gradient + i);
|
||||
m_16 = MS_MUL512_F32(m_16, beta1_16);
|
||||
m_16 = MS_FMA512_F32(g_16, beta1_minus_16, m_16);
|
||||
v_16 = MS_MUL512_F32(v_16, beta2_16);
|
||||
v_16 = MS_MUL512_F32(g_16, g_16);
|
||||
v_16 = MS_FMA512_F32(g_16, beta2_minus_16, v_16);
|
||||
g_16 = MS_SQRT512_F32(v_16);
|
||||
g_16 = MS_DIV512_F32(m_16, MS_ADD512_F32(g_16, epsilon_16));
|
||||
g_16 = MS_FMA512_F32(var_16, decay_16, g_16);
|
||||
var_16 = MS_FMA512_F32(g_16, lr_neg_16, var_16);
|
||||
MS_ST512_F32(var + i, var_16);
|
||||
MS_ST512_F32(m + i, m_16);
|
||||
MS_ST512_F32(v + i, v_16);
|
||||
}
|
||||
task = [&](size_t start, size_t end) {
|
||||
size_t i =
|
||||
FusedAdamFp32(var, m, v, lr, beta1, beta2, epsilon, decay, reinterpret_cast<int16_t *>(gradient16), start, end);
|
||||
// remaining
|
||||
for (; i < end; i++) {
|
||||
auto temp = static_cast<float>(gradient16[i]);
|
||||
m[i] += (temp - m[i]) * beta1_minus;
|
||||
v[i] += (temp * temp - v[i]) * beta2_minus;
|
||||
T update = m[i] / (std::sqrt(v[i]) + epsilon);
|
||||
update += *decay * var[i];
|
||||
var[i] -= lr * update;
|
||||
}
|
||||
#endif
|
||||
#if defined(ENABLE_NEON)
|
||||
if (end >= MS_NEON_WIDTH) {
|
||||
for (; i <= end - MS_NEON_WIDTH; i += MS_NEON_WIDTH) {
|
||||
MS_FLOAT32X4 var_4 = MS_LDQ_F32(var + i);
|
||||
MS_FLOAT32X4 m_4 = MS_LDQ_F32(m + i);
|
||||
MS_FLOAT32X4 v_4 = MS_LDQ_F32(v + i);
|
||||
MS_FLOAT32X4 g_4 = MS_LDQ_F32(gradient + i);
|
||||
m_4 = MS_MULQ_N_F32(m_4, beta1);
|
||||
m_4 = MS_MLAQ_N_F32(m_4, g_4, beta1_minus);
|
||||
v_4 = MS_MULQ_N_F32(v_4, beta2);
|
||||
g_4 = MS_MULQ_F32(g_4, g_4);
|
||||
v_4 = MS_MLAQ_N_F32(v_4, g_4, beta2_minus);
|
||||
g_4 = MS_SQRT_F32(v_4);
|
||||
g_4 = MS_DIVQ_F32(m_4, MS_ADDQ_F32(g_4, epsilon_4));
|
||||
g_4 = MS_MLAQ_N_F32(g_4, var_4, *decay);
|
||||
var_4 = MS_MLAQ_N_F32(var_4, g_4, lr_neg);
|
||||
MS_STQ_F32(var + i, var_4);
|
||||
MS_STQ_F32(m + i, m_4);
|
||||
MS_STQ_F32(v + i, v_4);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
};
|
||||
ParallelForAdam(task, lens);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto var = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto m = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto v = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
auto lr = reinterpret_cast<T *>(inputs[3]->addr)[0];
|
||||
auto beta1 = reinterpret_cast<T *>(inputs[4]->addr)[0];
|
||||
auto beta2 = reinterpret_cast<T *>(inputs[5]->addr)[0];
|
||||
auto epsilon = reinterpret_cast<T *>(inputs[6]->addr)[0];
|
||||
auto decay = reinterpret_cast<T *>(inputs[7]->addr);
|
||||
auto gradient = reinterpret_cast<T *>(inputs[8]->addr);
|
||||
auto beta1_minus = 1 - beta1;
|
||||
auto beta2_minus = 1 - beta2;
|
||||
|
||||
// multithreading
|
||||
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
|
||||
std::function<void(size_t, size_t)> task;
|
||||
|
||||
task = [&](size_t start, size_t end) {
|
||||
size_t i = AdamWeightDecayFp32(var, m, v, lr, beta1, beta2, epsilon, decay, gradient, start, end);
|
||||
// remaining
|
||||
for (; i < end; i++) {
|
||||
m[i] += (gradient[i] - m[i]) * beta1_minus;
|
||||
v[i] += (gradient[i] * gradient[i] - v[i]) * beta2_minus;
|
||||
|
@ -95,51 +114,66 @@ void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(T *var, T *m, T *v, float l
|
|||
var[i] -= lr * update;
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, size);
|
||||
ParallelForAdam(task, lens);
|
||||
}
|
||||
|
||||
void AdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<size_t> var_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
gradient_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 8);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 9) {
|
||||
if (input_num != kAdamWeightDecayInputSize) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but AdamWeightDecay needs 9 inputs.";
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 3) {
|
||||
if (output_num != kAdamWeightDecayOutputSize) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AdamWeightDecay needs 3 outputs.";
|
||||
}
|
||||
elem_num_ = 1;
|
||||
for (size_t i : var_shape) {
|
||||
elem_num_ *= i;
|
||||
}
|
||||
if (elem_num_ < 1) {
|
||||
MS_LOG(EXCEPTION) << "Invalid parameter shape";
|
||||
}
|
||||
if (dtype_ != kNumberTypeFloat32) {
|
||||
MS_LOG(EXCEPTION) << "The dtype of parameter must be float32!";
|
||||
}
|
||||
if (gradient_dtype_ != kNumberTypeFloat32 && gradient_dtype_ != kNumberTypeFloat16) {
|
||||
MS_LOG(EXCEPTION) << "The dtype of gradient must be float32 or float16!";
|
||||
}
|
||||
}
|
||||
|
||||
// Validates the launch-time buffers against what InitKernel recorded.
//
// Checks, in order:
//   1. the number of inputs and outputs;
//   2. that var, m and v (inputs 0-2) hold elem_num_ float32 values and that the gradient
//      (input 8) holds elem_num_ values of its recorded dtype (float16 or float32);
//   3. that lr, beta1, beta2, epsilon and decay (inputs 3-7) are each a single float32.
// Any violation is reported through MS_LOG(EXCEPTION).
void AdamWeightDecayCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &inputs,
                                          const std::vector<kernel::AddressPtr> &outputs) {
  const bool input_count_ok = (inputs.size() == kAdamWeightDecayInputSize);
  if (!input_count_ok) {
    MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs 9 inputs.";
  }
  const bool output_count_ok = (outputs.size() == kAdamWeightDecayOutputSize);
  if (!output_count_ok) {
    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs 3 outputs.";
  }

  // Expected byte sizes: the optimizer state is always float32, the gradient may be float16.
  const size_t state_bytes = elem_num_ * kSizeFloat32;
  size_t grad_bytes = state_bytes;
  if (gradient_dtype_ == kNumberTypeFloat16) {
    grad_bytes = elem_num_ * kSizeFloat16;
  }
  const bool tensors_ok = inputs[0]->size == state_bytes && inputs[1]->size == state_bytes &&
                          inputs[2]->size == state_bytes && inputs[8]->size == grad_bytes;
  if (!tensors_ok) {
    MS_LOG(EXCEPTION) << "Error input data size!";
  }

  const bool scalars_ok = inputs[3]->size == kSizeFloat32 && inputs[4]->size == kSizeFloat32 &&
                          inputs[5]->size == kSizeFloat32 && inputs[6]->size == kSizeFloat32 &&
                          inputs[7]->size == kSizeFloat32;
  if (!scalars_ok) {
    MS_LOG(EXCEPTION) << "The attribute beta, lr, epsilon and weight decay must be float!";
  }
}
|
||||
|
||||
bool AdamWeightDecayCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (inputs.size() != 9) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs 9 inputs.";
|
||||
CheckParam(inputs, outputs);
|
||||
if (gradient_dtype_ == kNumberTypeFloat16) {
|
||||
LaunchFusedAdam<float, float16>(inputs, outputs);
|
||||
} else if (gradient_dtype_ == kNumberTypeFloat32) {
|
||||
LaunchAdamWeightDecay<float>(inputs, outputs);
|
||||
}
|
||||
if (outputs.size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs 3 outputs.";
|
||||
}
|
||||
if (inputs[0]->size != inputs[1]->size || inputs[0]->size != inputs[2]->size || inputs[0]->size != inputs[8]->size) {
|
||||
MS_LOG(EXCEPTION) << "Error input data size!";
|
||||
}
|
||||
size_t f_size = sizeof(float);
|
||||
if (inputs[3]->size != f_size || inputs[4]->size != f_size || inputs[5]->size != f_size ||
|
||||
inputs[6]->size != f_size || inputs[7]->size != f_size) {
|
||||
MS_LOG(EXCEPTION) << "The attribute beta, lr and epsilon must be float!";
|
||||
}
|
||||
auto var = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto m = reinterpret_cast<float *>(inputs[1]->addr);
|
||||
auto v = reinterpret_cast<float *>(inputs[2]->addr);
|
||||
float lr = reinterpret_cast<float *>(inputs[3]->addr)[0];
|
||||
float beta1 = reinterpret_cast<float *>(inputs[4]->addr)[0];
|
||||
float beta2 = reinterpret_cast<float *>(inputs[5]->addr)[0];
|
||||
float epsilon = reinterpret_cast<float *>(inputs[6]->addr)[0];
|
||||
auto decay = reinterpret_cast<float *>(inputs[7]->addr);
|
||||
auto gradient = reinterpret_cast<float *>(inputs[8]->addr);
|
||||
|
||||
// multithreading
|
||||
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
|
||||
LaunchAdamWeightDecay<float>(var, m, v, lr, beta1, beta2, epsilon, decay, gradient, lens);
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
|
|
|
@ -21,59 +21,26 @@
|
|||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_AVX512)
|
||||
#include <x86intrin.h>
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#define MS_FLOAT32X4 float32x4_t
|
||||
#define MS_LDQ_F32 vld1q_f32
|
||||
#define MS_MOVQ_F32 vmovq_n_f32
|
||||
#define MS_STQ_F32 vst1q_f32
|
||||
#define MS_ADDQ_F32(src1, src2) vaddq_f32(src1, src2)
|
||||
#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2)
|
||||
#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2)
|
||||
#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
|
||||
#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
|
||||
#define MS_MLAQ_N_F32(src1, src2, src3) vmlaq_n_f32(src1, src2, src3)
|
||||
#define MS_SQRT_F32(src) vsqrtq_f32(src)
|
||||
#define MS_CAST_F32_F16(src) vreinterpretq_f32_f16(src)
|
||||
#define MS_NEON_WIDTH 4
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_AVX512)
|
||||
#define MS_FLOAT32X16 __m512
|
||||
#define MS_LD512_F32 _mm512_loadu_ps
|
||||
#define MS_ST512_F32 _mm512_storeu_ps
|
||||
#define MS_MOV512_F32 _mm512_set1_ps
|
||||
#define MS_ADD512_F32(src1, src2) _mm512_add_ps(src1, src2)
|
||||
#define MS_MUL512_F32(src1, src2) _mm512_mul_ps(src1, src2)
|
||||
#define MS_DIV512_F32(src1, src2) _mm512_div_ps(src1, src2)
|
||||
#define MS_FMA512_F32(src1, src2, src3) _mm512_fmadd_ps(src1, src2, src3)
|
||||
#define MS_SQRT512_F32(src) _mm512_sqrt_ps(src)
|
||||
#define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src)
|
||||
#define MS_AVX512_WIDTH 16
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class AdamWeightDecayCPUKernel : public CPUKernel {
|
||||
public:
|
||||
AdamWeightDecayCPUKernel() = default;
|
||||
~AdamWeightDecayCPUKernel() override = default;
|
||||
template <typename T>
|
||||
void LaunchAdamWeightDecay(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, T *decay,
|
||||
const T *gradient, size_t size);
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
void ParallelForAdam(const CTask &task, size_t count);
|
||||
void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
template <typename T, typename S>
|
||||
void LaunchFusedAdam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
template <typename T>
|
||||
void LaunchAdamWeightDecay(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
size_t elem_num_{0};
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
TypeId gradient_dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(AdamWeightDecay,
|
||||
|
@ -91,6 +58,22 @@ MS_REG_CPU_KERNEL(AdamWeightDecay,
|
|||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
AdamWeightDecayCPUKernel)
|
||||
|
||||
MS_REG_CPU_KERNEL(AdamWeightDecay,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
AdamWeightDecayCPUKernel)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -72,7 +72,8 @@ if(ENABLE_CPU)
|
|||
elseif("${X86_64_SIMD}" STREQUAL "sse")
|
||||
target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE)
|
||||
elseif("${X86_64_SIMD}" STREQUAL "avx")
|
||||
target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX)
|
||||
target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX ENABLE_AVX512)
|
||||
target_compile_options(nnacl_mid PRIVATE -mavx512f)
|
||||
endif()
|
||||
target_compile_options(nnacl_mid PRIVATE -fPIC)
|
||||
if(NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
|
|
|
@ -25,6 +25,25 @@
|
|||
#include "nnacl/fp32/exp_fp32.h"
|
||||
#include "nnacl/fp32/adam_fp32.h"
|
||||
|
||||
#ifdef ENABLE_AVX512
|
||||
struct AVX_Data {
|
||||
__m512 data;
|
||||
};
|
||||
|
||||
inline void LoadStep4(struct AVX_Data *inp0, const float *inp1) {
|
||||
inp0[0].data = _mm512_loadu_ps(inp1);
|
||||
inp0[1].data = _mm512_loadu_ps(inp1 + C16NUM);
|
||||
inp0[2].data = _mm512_loadu_ps(inp1 + C16NUM * 2);
|
||||
inp0[3].data = _mm512_loadu_ps(inp1 + C16NUM * 3);
|
||||
}
|
||||
|
||||
inline void StoreStep4(float *inp0, struct AVX_Data *inp1) {
|
||||
_mm512_storeu_ps(inp0, inp1[0].data);
|
||||
_mm512_storeu_ps(inp0 + C16NUM, inp1[1].data);
|
||||
_mm512_storeu_ps(inp0 + C16NUM * 2, inp1[2].data);
|
||||
_mm512_storeu_ps(inp0 + C16NUM * 3, inp1[3].data);
|
||||
}
|
||||
#endif
|
||||
int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
|
||||
size_t start, size_t end, bool use_nesterov) {
|
||||
size_t c1 = start;
|
||||
|
@ -159,3 +178,223 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
|
|||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_AVX512
// Hyper-parameters broadcast to all 16 lanes; built once per AdamWeightDecayFp32 call.
struct AdamWeightDecayCoef {
  __m512 beta1;
  __m512 beta2;
  __m512 beta1_minus;
  __m512 beta2_minus;
  __m512 lr_neg;
  __m512 epsilon;
  __m512 decay;
};

// One AdamWeightDecay update on 16 lanes held entirely in registers:
//   m   = beta1 * m + (1 - beta1) * g
//   v   = beta2 * v + (1 - beta2) * g * g
//   var = var - lr * (m / (sqrt(v) + epsilon) + decay * var)
// `var`, `m` and `v` are updated in place; `grad` is consumed by value.
static inline void AdamWeightDecayUpdate16(const struct AdamWeightDecayCoef *coef, __m512 grad, __m512 *var,
                                           __m512 *m, __m512 *v) {
  const __m512 m_next = _mm512_fmadd_ps(grad, coef->beta1_minus, _mm512_mul_ps(*m, coef->beta1));
  const __m512 v_scaled = _mm512_mul_ps(*v, coef->beta2);
  const __m512 grad_sq = _mm512_mul_ps(grad, grad);
  const __m512 v_next = _mm512_fmadd_ps(grad_sq, coef->beta2_minus, v_scaled);
  __m512 step = _mm512_div_ps(m_next, _mm512_add_ps(_mm512_sqrt_ps(v_next), coef->epsilon));
  step = _mm512_fmadd_ps(*var, coef->decay, step);
  *var = _mm512_fmadd_ps(step, coef->lr_neg, *var);
  *m = m_next;
  *v = v_next;
}
#endif

// Vectorised AdamWeightDecay over the element range [start, end).
//
// When ENABLE_AVX512 is defined, elements are processed 64 at a time and then 16 at a time;
// fewer than 16 trailing elements are left untouched. Without ENABLE_AVX512 nothing is
// processed. The return value is the index of the first element NOT processed, so the caller
// can finish the range [return, end) with scalar code.
//
// var/m/v are updated in place, gradient is read only, `decay` points at a single float.
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
                        const float *gradient, size_t start, size_t end) {
  size_t c1 = start;
#ifdef ENABLE_AVX512
  struct AdamWeightDecayCoef coef;
  coef.beta1 = _mm512_set1_ps(beta1);
  coef.beta2 = _mm512_set1_ps(beta2);
  coef.beta1_minus = _mm512_set1_ps(1 - beta1);
  coef.beta2_minus = _mm512_set1_ps(1 - beta2);
  coef.lr_neg = _mm512_set1_ps(-lr);
  coef.epsilon = _mm512_set1_ps(epsilon);
  coef.decay = _mm512_set1_ps(*decay);

  // End indices of the 64-wide and 16-wide passes (range length rounded down to the step).
  const size_t length = end - start;
  const size_t c64 = start + (length / C64NUM) * C64NUM;
  const size_t c16 = start + (length / C16NUM) * C16NUM;

  // 64 elements per iteration: load everything, update four registers, then store.
  while (c1 < c64) {
    struct AVX_Data g_r[4], var_r[4], m_r[4], v_r[4];
    LoadStep4(g_r, gradient + c1);
    LoadStep4(var_r, var + c1);
    LoadStep4(m_r, m + c1);
    LoadStep4(v_r, v + c1);

    for (int lane = 0; lane < 4; ++lane) {
      AdamWeightDecayUpdate16(&coef, g_r[lane].data, &var_r[lane].data, &m_r[lane].data, &v_r[lane].data);
    }

    StoreStep4(var + c1, var_r);
    StoreStep4(m + c1, m_r);
    StoreStep4(v + c1, v_r);
    c1 += C64NUM;
  }

  // Remaining whole groups of 16 elements.
  while (c1 < c16) {
    __m512 g16 = _mm512_loadu_ps(gradient + c1);
    __m512 var16 = _mm512_loadu_ps(var + c1);
    __m512 m16 = _mm512_loadu_ps(m + c1);
    __m512 v16 = _mm512_loadu_ps(v + c1);

    AdamWeightDecayUpdate16(&coef, g16, &var16, &m16, &v16);

    _mm512_storeu_ps(var + c1, var16);
    _mm512_storeu_ps(m + c1, m16);
    _mm512_storeu_ps(v + c1, v16);
    c1 += C16NUM;
  }
#endif
  return c1;
}
|
||||
|
||||
#ifdef ENABLE_AVX512
// Hyper-parameters broadcast to all 16 lanes; built once per FusedAdamFp32 call.
struct FusedAdamCoef {
  __m512 beta1;
  __m512 beta2;
  __m512 beta1_minus;
  __m512 beta2_minus;
  __m512 lr_neg;
  __m512 epsilon;
  __m512 decay;
};

// Reads 16 half-precision values (stored as raw 16-bit words) and widens them to float32.
static inline __m512 FusedAdamLoadHalf16(const int16_t *src) {
  return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(src)));
}

// One AdamWeightDecay update on 16 lanes held entirely in registers:
//   m   = beta1 * m + (1 - beta1) * g
//   v   = beta2 * v + (1 - beta2) * g * g
//   var = var - lr * (m / (sqrt(v) + epsilon) + decay * var)
// `var`, `m` and `v` are updated in place; `grad` is consumed by value.
static inline void FusedAdamUpdate16(const struct FusedAdamCoef *coef, __m512 grad, __m512 *var, __m512 *m,
                                     __m512 *v) {
  const __m512 m_next = _mm512_fmadd_ps(grad, coef->beta1_minus, _mm512_mul_ps(*m, coef->beta1));
  const __m512 v_scaled = _mm512_mul_ps(*v, coef->beta2);
  const __m512 grad_sq = _mm512_mul_ps(grad, grad);
  const __m512 v_next = _mm512_fmadd_ps(grad_sq, coef->beta2_minus, v_scaled);
  __m512 step = _mm512_div_ps(m_next, _mm512_add_ps(_mm512_sqrt_ps(v_next), coef->epsilon));
  step = _mm512_fmadd_ps(*var, coef->decay, step);
  *var = _mm512_fmadd_ps(step, coef->lr_neg, *var);
  *m = m_next;
  *v = v_next;
}
#endif

// Vectorised AdamWeightDecay over [start, end) where the gradient is half precision
// (passed as raw 16-bit words) and var/m/v are float32.
//
// When ENABLE_AVX512 is defined, elements are processed 64 at a time and then 16 at a time;
// fewer than 16 trailing elements are left untouched. Without ENABLE_AVX512 nothing is
// processed. The return value is the index of the first element NOT processed, so the caller
// can finish the range [return, end) with scalar code.
int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
                  const int16_t *gradient16, size_t start, size_t end) {
  size_t c1 = start;
#ifdef ENABLE_AVX512
  struct FusedAdamCoef coef;
  coef.beta1 = _mm512_set1_ps(beta1);
  coef.beta2 = _mm512_set1_ps(beta2);
  coef.beta1_minus = _mm512_set1_ps(1 - beta1);
  coef.beta2_minus = _mm512_set1_ps(1 - beta2);
  coef.lr_neg = _mm512_set1_ps(-lr);
  coef.epsilon = _mm512_set1_ps(epsilon);
  coef.decay = _mm512_set1_ps(*decay);

  // End indices of the 64-wide and 16-wide passes (range length rounded down to the step).
  const size_t length = end - start;
  const size_t c64 = start + (length / C64NUM) * C64NUM;
  const size_t c16 = start + (length / C16NUM) * C16NUM;

  // 64 elements per iteration: load everything, update four registers, then store.
  while (c1 < c64) {
    struct AVX_Data g_r[4], var_r[4], m_r[4], v_r[4];
    const int16_t *grad_src = gradient16 + c1;
    g_r[0].data = FusedAdamLoadHalf16(grad_src);
    g_r[1].data = FusedAdamLoadHalf16(grad_src + C16NUM);
    g_r[2].data = FusedAdamLoadHalf16(grad_src + C16NUM * 2);
    g_r[3].data = FusedAdamLoadHalf16(grad_src + C16NUM * 3);

    LoadStep4(var_r, var + c1);
    LoadStep4(m_r, m + c1);
    LoadStep4(v_r, v + c1);

    for (int lane = 0; lane < 4; ++lane) {
      FusedAdamUpdate16(&coef, g_r[lane].data, &var_r[lane].data, &m_r[lane].data, &v_r[lane].data);
    }

    StoreStep4(var + c1, var_r);
    StoreStep4(m + c1, m_r);
    StoreStep4(v + c1, v_r);
    c1 += C64NUM;
  }

  // Remaining whole groups of 16 elements.
  while (c1 < c16) {
    __m512 g16 = FusedAdamLoadHalf16(gradient16 + c1);
    __m512 var16 = _mm512_loadu_ps(var + c1);
    __m512 m16 = _mm512_loadu_ps(m + c1);
    __m512 v16 = _mm512_loadu_ps(v + c1);

    FusedAdamUpdate16(&coef, g16, &var16, &m16, &v16);

    _mm512_storeu_ps(var + c1, var16);
    _mm512_storeu_ps(m + c1, m16);
    _mm512_storeu_ps(v + c1, v16);
    c1 += C16NUM;
  }
#endif
  return c1;
}
|
||||
|
|
|
@ -26,6 +26,10 @@ int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2,
|
|||
size_t start, size_t end, bool use_nesterov);
|
||||
int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
|
||||
const float *gradient, size_t start, size_t end, bool use_nesterov);
|
||||
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
|
||||
const float *gradient, size_t start, size_t end);
|
||||
int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
|
||||
const int16_t *gradient16, size_t start, size_t end);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
#define C16NUM 16
|
||||
#define C24NUM 24
|
||||
#define C32NUM 32
|
||||
#define C64NUM 64
|
||||
#define TILE_NUM 8
|
||||
|
||||
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))
|
||||
|
|
|
@ -533,8 +533,9 @@ class AdamWeightDecay(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
|
||||
epsilon_dtype, decay, grad_dtype):
|
||||
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
|
||||
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
|
||||
validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16, mstype.float32], self.name)
|
||||
|
||||
args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
|
||||
"decay": decay}
|
||||
|
|
Loading…
Reference in New Issue