Optimize CPU Adam op

This commit is contained in:
wangyanling 2021-04-22 19:33:57 +08:00
parent 2955f9e84a
commit ee1b803416
11 changed files with 287 additions and 95 deletions

View File

@@ -422,7 +422,7 @@ build_mindspore()
echo "start build mindspore project."
mkdir -pv "${BUILD_PATH}/mindspore"
cd "${BUILD_PATH}/mindspore"
CMAKE_ARGS="-DDEBUG_MODE=$DEBUG_MODE -DBUILD_PATH=$BUILD_PATH"
CMAKE_ARGS="-DDEBUG_MODE=$DEBUG_MODE -DBUILD_PATH=$BUILD_PATH -DX86_64_SIMD=${X86_64_SIMD}"
if [[ "X$ENABLE_COVERAGE" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_COVERAGE=ON"
fi

View File

@@ -7,11 +7,14 @@ if(ENABLE_CPU)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/backend/kernel_compiler/cpu)
if("${X86_64_SIMD}" STREQUAL "sse")
add_compile_definitions(ENABLE_SSE)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -msse4.2")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -msse4.2")
endif()
if("${X86_64_SIMD}" STREQUAL "avx")
add_compile_definitions(ENABLE_SSE)
add_compile_definitions(ENABLE_AVX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -msse4.2 -mfma -mavx -mavx2")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -msse4.2 -mfma -mavx -mavx2")
endif()
add_subdirectory(backend/kernel_compiler/cpu/nnacl)
endif()
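For reference, the ENABLE_SSE / ENABLE_AVX definitions set here are what the new nnacl source further down keys on at compile time. A minimal sketch of the same gating pattern, as a standalone illustration rather than code from this commit:

// dispatch_sketch.cc - hypothetical illustration of compile-time SIMD gating.
// Build with e.g. g++ -DENABLE_SSE -DENABLE_AVX dispatch_sketch.cc
#include <cstdio>

int main() {
#if defined(ENABLE_AVX)
  std::printf("AVX2 kernel path (compiled with -mavx -mavx2 -mfma)\n");
#elif defined(ENABLE_SSE)
  std::printf("SSE4 kernel path (compiled with -msse4.1 -msse4.2)\n");
#else
  std::printf("portable scalar path\n");
#endif
  return 0;
}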

View File

@@ -13,11 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/adam_cpu_kernel.h"
#include <cmath>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/adam_cpu_kernel.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/adam_fp32.h"
#include "utils/ms_utils.h"
namespace mindspore {
@@ -25,23 +26,31 @@ namespace kernel {
template <typename T>
void AdamCPUKernel::LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
m[i] += (gradient[i] - m[i]) * (1 - beta1);
v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2);
if (use_nesterov) {
var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon);
} else {
var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon);
std::function<void(size_t, size_t)> task;
if (dtype_ == kNumberTypeFloat32) {
task = [&](size_t start, size_t end) {
AdamFp32(var, m, v, lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_);
};
} else {
task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
m[i] += (gradient[i] - m[i]) * (1 - beta1);
v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2);
if (use_nesterov_) {
var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon);
} else {
var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon);
}
}
}
};
};
}
CPUKernelUtils::ParallelFor(task, size);
}
void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
if (input_num != 10) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but Adam needs 10 inputs.";
}
@@ -49,7 +58,7 @@ void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
if (output_num != 3) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but Adam needs 3 outputs.";
}
use_nesterov = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
}
bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
@@ -83,7 +92,6 @@ bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
MS_LOG(EXCEPTION) << "The beta1_power can't be set 1.";
}
float new_lr = lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power);
// multithreading
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
LaunchAdam<float>(var, m, v, new_lr, beta1, beta2, epsilon, gradient, lens);
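For reference, the scalar path above is the standard Adam update, with the bias correction sqrt(1 - beta2_power) / (1 - beta1_power) folded into new_lr before the element loop runs. A self-contained sketch of one element's update under that assumption (illustrative, not part of the commit):

#include <cmath>
#include <cstdio>

// One Adam step for a single element; lr is assumed to be pre-scaled by
// sqrt(1 - beta2^t) / (1 - beta1^t), as done in Launch() above.
void AdamStep(float *var, float *m, float *v, float grad, float lr,
              float beta1, float beta2, float epsilon, bool use_nesterov) {
  *m += (grad - *m) * (1 - beta1);         // m <- beta1*m + (1-beta1)*grad
  *v += (grad * grad - *v) * (1 - beta2);  // v <- beta2*v + (1-beta2)*grad^2
  float update = use_nesterov ? (*m * beta1 + (1 - beta1) * grad) : *m;
  *var -= lr * update / (std::sqrt(*v) + epsilon);
}

int main() {
  float var = 1.0f, m = 0.0f, v = 0.0f;
  AdamStep(&var, &m, &v, 0.5f, 0.001f, 0.9f, 0.999f, 1e-8f, false);
  std::printf("var=%g m=%g v=%g\n", var, m, v);
  return 0;
}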

View File

@@ -36,7 +36,8 @@ class AdamCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
private:
bool use_nesterov{false};
bool use_nesterov_{false};
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(Adam,

View File

@@ -13,59 +13,45 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/adam_delta_cpu_kernel.h"
#include <thread>
#include <vector>
#include <string>
#include <memory>
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h"
#include "backend/kernel_compiler/cpu/adam_delta_cpu_kernel.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/adam_fp32.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
constexpr size_t kAdamDeltaInputSize = 9;
namespace {
struct ComputeParam {
float *delta_{nullptr};
float *m_{nullptr};
float *v_{nullptr};
float *grad_{nullptr};
float beta1_{0};
float beta2_{0};
float epsilon_{0};
float lr_{0};
bool use_nesterov_{0};
};
void ComputeWeightDelta(const std::shared_ptr<ComputeParam> &input_params, size_t start, size_t end) {
MS_EXCEPTION_IF_NULL(input_params);
MS_EXCEPTION_IF_NULL(input_params->delta_);
MS_EXCEPTION_IF_NULL(input_params->m_);
MS_EXCEPTION_IF_NULL(input_params->v_);
MS_EXCEPTION_IF_NULL(input_params->grad_);
auto delta = input_params->delta_;
auto m = input_params->m_;
auto v = input_params->v_;
auto lr = input_params->lr_;
auto beta1 = input_params->beta1_;
auto beta2 = input_params->beta2_;
auto epsilon = input_params->epsilon_;
auto use_nesterov = input_params->use_nesterov_;
auto grad = input_params->grad_;
for (size_t i = start; i < end; ++i) {
m[i] *= beta1;
v[i] *= beta2;
m[i] += (1 - beta1) * grad[i];
v[i] += (1 - beta2) * grad[i] * grad[i];
if (use_nesterov) {
delta[i] = -lr * (m[i] * beta1 + (1 - beta1) * grad[i]) / (std::sqrt(v[i]) + epsilon);
} else {
delta[i] = -lr * m[i] / (std::sqrt(v[i]) + epsilon);
}
template <typename T>
void AdamDeltaCPUKernel::LaunchAdamDelta(T *delta, T *m, T *v, float lr, float beta1, float beta2, float epsilon,
const T *gradient, size_t size) {
std::function<void(size_t, size_t)> task;
if (dtype_ == kNumberTypeFloat32) {
task = [&](size_t start, size_t end) {
AdamDeltaFp32(delta, m, v, lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_);
};
} else {
task = [&](size_t start, size_t end) {
for (size_t c1 = start; c1 < end; ++c1) {
m[c1] *= beta1;
m[c1] += (1 - beta1) * gradient[c1];
v[c1] *= beta2;
v[c1] += (1 - beta2) * gradient[c1] * gradient[c1];
if (use_nesterov_) {
delta[c1] = -lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (std::sqrt(v[c1]) + epsilon);
} else {
delta[c1] = -lr * m[c1] / (std::sqrt(v[c1]) + epsilon);
}
}
};
}
CPUKernelUtils::ParallelFor(task, size);
}
} // namespace
void AdamDeltaCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
@@ -73,6 +59,7 @@ void AdamDeltaCPUKernel::InitKernel(const CNodePtr &kernel_node) {
std::vector<size_t> m_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> v_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> grad_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 8);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
if (!IsSameShape(delta_shape, m_shape)) {
MS_LOG(EXCEPTION) << "Delta and m should have the same shape";
}
@@ -134,42 +121,15 @@ bool AdamDeltaCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
auto epsilon = reinterpret_cast<float *>(inputs[7]->addr)[0];
auto grad = reinterpret_cast<float *>(inputs[8]->addr);
auto delta = reinterpret_cast<float *>(outputs[0]->addr);
lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
if (elem_num_ < thread_num) {
thread_num = elem_num_;
}
std::vector<common::Task> tasks;
std::vector<std::shared_ptr<ComputeParam>> thread_params;
tasks.reserve(thread_num);
MS_EXCEPTION_IF_NULL(m);
MS_EXCEPTION_IF_NULL(v);
MS_EXCEPTION_IF_NULL(grad);
MS_EXCEPTION_IF_NULL(delta);
size_t end = 0;
size_t offset = elem_num_ / thread_num;
size_t left = elem_num_ % thread_num;
for (size_t i = 0; i < thread_num; ++i) {
auto params = std::make_shared<ComputeParam>();
params->delta_ = delta;
params->m_ = m;
params->v_ = v;
params->grad_ = grad;
params->beta1_ = beta1;
params->beta2_ = beta2;
params->use_nesterov_ = use_nesterov_;
params->lr_ = lr;
params->epsilon_ = epsilon;
size_t start = end;
end = start + offset;
if (i < left) {
end += 1;
}
auto task = [&params, start, end]() {
ComputeWeightDelta(params, start, end);
return common::SUCCESS;
};
tasks.emplace_back(task);
thread_params.emplace_back(params);
}
common::ThreadPool::GetInstance().SyncRun(tasks);
lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
// multithreading
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
LaunchAdamDelta<float>(delta, m, v, lr, beta1, beta2, epsilon, grad, lens);
return true;
}
} // namespace kernel
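The hand-rolled ThreadPool chunking removed here is essentially the range splitting that CPUKernelUtils::ParallelFor now performs for the kernel. A minimal std::thread sketch of that pattern, assuming the same task(start, end) signature; this is an illustration, not MindSpore's implementation:

#include <functional>
#include <thread>
#include <vector>

// Split [0, size) into near-equal contiguous ranges, one per thread,
// handing the remainder out one element at a time to the first threads.
void ParallelForSketch(const std::function<void(size_t, size_t)> &task,
                       size_t size, size_t thread_num) {
  if (thread_num == 0 || thread_num > size) thread_num = size > 0 ? size : 1;
  size_t offset = size / thread_num;
  size_t left = size % thread_num;
  std::vector<std::thread> threads;
  size_t start = 0;
  for (size_t i = 0; i < thread_num; ++i) {
    size_t end = start + offset + (i < left ? 1 : 0);
    threads.emplace_back(task, start, end);
    start = end;
  }
  for (auto &t : threads) t.join();
}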

View File

@@ -32,8 +32,12 @@ class AdamDeltaCPUKernel : public CPUKernel {
protected:
void CheckParams(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const;
template <typename T>
void LaunchAdamDelta(T *delta, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
size_t size);
bool use_nesterov_{false};
size_t elem_num_{0};
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(AdamNoUpdateParam,

View File

@@ -0,0 +1,182 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef ENABLE_SSE
#include <x86intrin.h>
#endif
#ifdef ENABLE_AVX
#include <immintrin.h>
#endif
#include <math.h>
#include "nnacl/fp32/exp_fp32.h"
#include "nnacl/fp32/adam_fp32.h"
#include "nnacl/op_base.h"
int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
size_t start, size_t end, bool use_nesterov) {
size_t c1 = start;
#ifdef ENABLE_AVX
float coeff1 = 1 - beta1;
float coeff2 = 1 - beta2;
float *m_ptr = m + start;
float *v_ptr = v + start;
float *var_ptr = var + start;
const float *gradient_ptr = gradient + start;
size_t c8 = start + ((end - start) / C8NUM) * C8NUM;
__m256 avx_r0, avx_r1, avx_r2, avx_r3, avx_r4, avx_r5, avx_r6, gradient_r;
for (; c1 < c8; c1 += C8NUM) {
avx_r0 = _mm256_set1_ps(coeff1);
gradient_r = _mm256_loadu_ps(gradient_ptr);
avx_r2 = _mm256_loadu_ps(m_ptr);
avx_r3 = _mm256_sub_ps(gradient_r, avx_r2);
avx_r4 = _mm256_mul_ps(avx_r3, avx_r0);
avx_r3 = _mm256_add_ps(avx_r4, avx_r2); // m[i]..m[i+7]
_mm256_storeu_ps(m_ptr, avx_r3); // write the updated first moment back
avx_r2 = _mm256_mul_ps(gradient_r, gradient_r);
avx_r4 = _mm256_loadu_ps(v_ptr);
avx_r5 = _mm256_sub_ps(avx_r2, avx_r4);
avx_r1 = _mm256_set1_ps(coeff2);
avx_r2 = _mm256_mul_ps(avx_r5, avx_r1);
avx_r5 = _mm256_add_ps(avx_r4, avx_r2); // v[i]..v[i+7]
_mm256_storeu_ps(v_ptr, avx_r5); // write the updated second moment back
if (use_nesterov) {
avx_r1 = _mm256_set1_ps(beta1);
avx_r2 = _mm256_mul_ps(avx_r3, avx_r1);
avx_r4 = _mm256_mul_ps(gradient_r, avx_r0);
avx_r6 = _mm256_add_ps(avx_r2, avx_r4);
avx_r0 = _mm256_set1_ps(lr);
avx_r2 = _mm256_mul_ps(avx_r6, avx_r0);
avx_r0 = _mm256_set1_ps(epsilon);
avx_r1 = _mm256_sqrt_ps(avx_r5);
avx_r4 = _mm256_add_ps(avx_r0, avx_r1);
avx_r0 = _mm256_div_ps(avx_r2, avx_r4); // lr * update / (sqrt(v) + epsilon)
avx_r1 = _mm256_loadu_ps(var_ptr);
avx_r2 = _mm256_sub_ps(avx_r1, avx_r0);
_mm256_storeu_ps(var_ptr, avx_r2);
} else {
avx_r0 = _mm256_set1_ps(lr);
avx_r1 = _mm256_mul_ps(avx_r3, avx_r0);
avx_r0 = _mm256_set1_ps(epsilon);
avx_r2 = _mm256_sqrt_ps(avx_r5);
avx_r4 = _mm256_add_ps(avx_r0, avx_r2);
avx_r0 = _mm256_div_ps(avx_r1, avx_r4);
avx_r1 = _mm256_loadu_ps(var_ptr);
avx_r3 = _mm256_sub_ps(avx_r1, avx_r0);
_mm256_storeu_ps(var_ptr, avx_r3);
}
m_ptr += C8NUM;
v_ptr += C8NUM;
var_ptr += C8NUM;
gradient_ptr += C8NUM;
}
#endif
// scalar tail for the remaining elements
for (; c1 < end; c1++) {
m[c1] += (gradient[c1] - m[c1]) * (1 - beta1);
v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * (1 - beta2);
if (use_nesterov) {
var[c1] -= lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon);
} else {
var[c1] -= lr * m[c1] / (sqrt(v[c1]) + epsilon);
}
}
return NNACL_OK;
}
int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
const float *gradient, size_t start, size_t end, bool use_nesterov) {
size_t c1 = start;
#ifdef ENABLE_AVX
float coeff1 = 1 - beta1;
float coeff2 = 1 - beta2;
float *m_ptr = m + start;
float *v_ptr = v + start;
float *delta_ptr = delta + start;
const float *gradient_ptr = gradient + start;
size_t c8 = ((end - start) / C8NUM) * C8NUM;
__m256 gradient_r0, m_r1, v_r2, beta1_r3, beta2_r4, var_r5, var_r6, var_r7;
for (; c1 < c8 + start; c1 += C8NUM) {
gradient_r0 = _mm256_loadu_ps(gradient_ptr);
beta1_r3 = _mm256_set1_ps(beta1); // loop-invariant
var_r5 = _mm256_loadu_ps(m_ptr);
var_r6 = _mm256_mul_ps(beta1_r3, var_r5); // m[i] = m[i] * beta1
var_r7 = _mm256_set1_ps(coeff1);
var_r5 = _mm256_mul_ps(var_r7, gradient_r0); // (1 - beta1) * grad[i]
m_r1 = _mm256_add_ps(var_r6, var_r5);
_mm256_storeu_ps(m_ptr, m_r1);
beta2_r4 = _mm256_set1_ps(beta2); // loop-invariant
var_r5 = _mm256_loadu_ps(v_ptr);
var_r6 = _mm256_mul_ps(beta2_r4, var_r5); // v[i] * beta2
var_r7 = _mm256_set1_ps(coeff2);
var_r5 = _mm256_mul_ps(var_r7, gradient_r0);
var_r7 = _mm256_mul_ps(var_r5, gradient_r0);
v_r2 = _mm256_add_ps(var_r7, var_r6);
_mm256_storeu_ps(v_ptr, v_r2);
if (use_nesterov) {
var_r5 = _mm256_mul_ps(beta1_r3, m_r1);
var_r6 = _mm256_set1_ps(coeff1);
var_r7 = _mm256_mul_ps(gradient_r0, var_r6);
var_r6 = _mm256_add_ps(var_r5, var_r7); // m[i] * beta1 + (1 - beta1) * grad[i]
var_r5 = _mm256_set1_ps(lr);
var_r7 = _mm256_mul_ps(var_r6, var_r5);
var_r5 = _mm256_set1_ps(epsilon);
var_r6 = _mm256_sqrt_ps(v_r2);
v_r2 = _mm256_add_ps(var_r5, var_r6);
var_r5 = _mm256_div_ps(var_r7, v_r2);
var_r6 = _mm256_set1_ps(0.f);
var_r7 = _mm256_sub_ps(var_r6, var_r5);
_mm256_storeu_ps(delta_ptr, var_r7);
} else {
var_r5 = _mm256_set1_ps(lr);
var_r6 = _mm256_mul_ps(var_r5, m_r1);
var_r7 = _mm256_set1_ps(epsilon);
var_r5 = _mm256_sqrt_ps(v_r2);
v_r2 = _mm256_add_ps(var_r5, var_r7);
var_r5 = _mm256_div_ps(var_r6, v_r2);
var_r6 = _mm256_set1_ps(0.f);
var_r7 = _mm256_sub_ps(var_r6, var_r5);
_mm256_storeu_ps(delta_ptr, var_r7);
}
m_ptr += C8NUM;
v_ptr += C8NUM;
delta_ptr += C8NUM;
gradient_ptr += C8NUM;
}
#endif
// scalar tail for the remaining elements
for (; c1 < end; ++c1) {
m[c1] *= beta1;
m[c1] += (1 - beta1) * gradient[c1];
v[c1] *= beta2;
v[c1] += (1 - beta2) * gradient[c1] * gradient[c1];
if (use_nesterov) {
delta[c1] = -lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon);
} else {
delta[c1] = -lr * m[c1] / (sqrt(v[c1]) + epsilon);
}
}
return NNACL_OK;
}
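Because ParallelFor may hand AdamFp32 a sub-range whose length is not a multiple of C8NUM (eight floats per __m256 register), the scalar tail loop must agree exactly with the vector body. A sanity-check sketch, assuming the adam_fp32.h header below is on the include path:

#include <cmath>
#include <cstdio>
#include "nnacl/fp32/adam_fp32.h"

int main() {
  constexpr size_t n = 19;  // deliberately not a multiple of 8, to hit the tail loop
  float var[n], m[n], v[n], grad[n], var_ref[n], m_ref[n], v_ref[n];
  const float lr = 0.001f, b1 = 0.9f, b2 = 0.999f, eps = 1e-8f;
  for (size_t i = 0; i < n; ++i) {
    var[i] = var_ref[i] = 1.0f;
    m[i] = m_ref[i] = 0.0f;
    v[i] = v_ref[i] = 0.0f;
    grad[i] = 0.01f * static_cast<float>(i + 1);
  }
  AdamFp32(var, m, v, lr, b1, b2, eps, grad, 0, n, false);
  // Scalar reference recurrence, mirroring the remainder loop in the new file.
  for (size_t i = 0; i < n; ++i) {
    m_ref[i] += (grad[i] - m_ref[i]) * (1 - b1);
    v_ref[i] += (grad[i] * grad[i] - v_ref[i]) * (1 - b2);
    var_ref[i] -= lr * m_ref[i] / (std::sqrt(v_ref[i]) + eps);
  }
  for (size_t i = 0; i < n; ++i) {
    if (std::fabs(var[i] - var_ref[i]) > 1e-6f) {
      std::printf("mismatch at %zu\n", i);
      return 1;
    }
  }
  std::printf("ok\n");
  return 0;
}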

View File

@@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_ADAM_FP32_H
#define MINDSPORE_NNACL_ADAM_FP32_H
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#ifdef __cplusplus
extern "C" {
#endif
int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
size_t start, size_t end, bool use_nesterov);
int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
const float *gradient, size_t start, size_t end, bool use_nesterov);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_ADAM_FP32_H

View File

@@ -401,7 +401,7 @@ class GELU(Cell):
TypeError: If dtype of `input_data` is neither float16 nor float32.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)

View File

@@ -809,7 +809,7 @@ class LayerNorm(Cell):
TypeError: If `epsilon` is not a float.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)

View File

@@ -24,6 +24,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/)
include_directories(${CMAKE_BINARY_DIR})
include_directories(${CUDA_INCLUDE_DIRS})
include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu)
MESSAGE("check ut_test ${CMAKE_BINARY_DIR}")
link_directories(${MS_CCSRC_BUILD_PATH})
@@ -150,6 +151,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/transform/graph_ir/op_declare/*.cc"
"../../../mindspore/ccsrc/ps/*.cc"
"../../../mindspore/ccsrc/profiler/device/common/*.cc"
"../../../mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.c"
)
list(REMOVE_ITEM MINDSPORE_SRC_LIST