optimize cpu adam op
This commit is contained in:
parent
2955f9e84a
commit
ee1b803416
2
build.sh
2
build.sh
|
@ -422,7 +422,7 @@ build_mindspore()
|
|||
echo "start build mindspore project."
|
||||
mkdir -pv "${BUILD_PATH}/mindspore"
|
||||
cd "${BUILD_PATH}/mindspore"
|
||||
CMAKE_ARGS="-DDEBUG_MODE=$DEBUG_MODE -DBUILD_PATH=$BUILD_PATH"
|
||||
CMAKE_ARGS="-DDEBUG_MODE=$DEBUG_MODE -DBUILD_PATH=$BUILD_PATH -DX86_64_SIMD=${X86_64_SIMD}"
|
||||
if [[ "X$ENABLE_COVERAGE" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_COVERAGE=ON"
|
||||
fi
|
||||
|
|
|
@ -7,11 +7,14 @@ if(ENABLE_CPU)
|
|||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/backend/kernel_compiler/cpu)
|
||||
if("${X86_64_SIMD}" STREQUAL "sse")
|
||||
add_compile_definitions(ENABLE_SSE)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -msse4.2")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -msse4.2")
|
||||
endif()
|
||||
if("${X86_64_SIMD}" STREQUAL "avx")
|
||||
add_compile_definitions(ENABLE_SSE)
|
||||
add_compile_definitions(ENABLE_AVX)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -msse4.2 -mfma -mavx -mavx2")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -msse4.2 -mfma -mavx -mavx2")
|
||||
endif()
|
||||
add_subdirectory(backend/kernel_compiler/cpu/nnacl)
|
||||
endif()
|
||||
|
|
|
@ -13,11 +13,12 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/adam_cpu_kernel.h"
|
||||
|
||||
#include <cmath>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "backend/kernel_compiler/cpu/adam_cpu_kernel.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/fp32/adam_fp32.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -25,23 +26,31 @@ namespace kernel {
|
|||
template <typename T>
|
||||
void AdamCPUKernel::LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
|
||||
size_t size) {
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
std::function<void(size_t, size_t)> task;
|
||||
if (dtype_ == kNumberTypeFloat32) {
|
||||
task = [&](size_t start, size_t end) {
|
||||
AdamFp32(var, m, v, lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_);
|
||||
};
|
||||
} else {
|
||||
task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
m[i] += (gradient[i] - m[i]) * (1 - beta1);
|
||||
v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2);
|
||||
if (use_nesterov) {
|
||||
if (use_nesterov_) {
|
||||
var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon);
|
||||
} else {
|
||||
var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
CPUKernelUtils::ParallelFor(task, size);
|
||||
}
|
||||
|
||||
void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
if (input_num != 10) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but Adam needs 10 inputs.";
|
||||
}
|
||||
|
@ -49,7 +58,7 @@ void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
if (output_num != 3) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but Adam needs 3 outputs.";
|
||||
}
|
||||
use_nesterov = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
|
||||
use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
|
||||
}
|
||||
|
||||
bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
|
@ -83,7 +92,6 @@ bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
MS_LOG(EXCEPTION) << "The beta1_power can't be set 1.";
|
||||
}
|
||||
float new_lr = lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power);
|
||||
|
||||
// multithreading
|
||||
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
|
||||
LaunchAdam<float>(var, m, v, new_lr, beta1, beta2, epsilon, gradient, lens);
|
||||
|
|
|
@ -36,7 +36,8 @@ class AdamCPUKernel : public CPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
bool use_nesterov{false};
|
||||
bool use_nesterov_{false};
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Adam,
|
||||
|
|
|
@ -13,59 +13,45 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/adam_delta_cpu_kernel.h"
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "common/thread_pool.h"
|
||||
#include "backend/kernel_compiler/cpu/adam_delta_cpu_kernel.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/fp32/adam_fp32.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t kAdamDeltaInputSize = 9;
|
||||
namespace {
|
||||
struct ComputeParam {
|
||||
float *delta_{nullptr};
|
||||
float *m_{nullptr};
|
||||
float *v_{nullptr};
|
||||
float *grad_{nullptr};
|
||||
float beta1_{0};
|
||||
float beta2_{0};
|
||||
float epsilon_{0};
|
||||
float lr_{0};
|
||||
bool use_nesterov_{0};
|
||||
};
|
||||
|
||||
void ComputeWeightDelta(const std::shared_ptr<ComputeParam> &input_params, size_t start, size_t end) {
|
||||
MS_EXCEPTION_IF_NULL(input_params);
|
||||
MS_EXCEPTION_IF_NULL(input_params->delta_);
|
||||
MS_EXCEPTION_IF_NULL(input_params->m_);
|
||||
MS_EXCEPTION_IF_NULL(input_params->v_);
|
||||
MS_EXCEPTION_IF_NULL(input_params->grad_);
|
||||
auto delta = input_params->delta_;
|
||||
auto m = input_params->m_;
|
||||
auto v = input_params->v_;
|
||||
auto lr = input_params->lr_;
|
||||
auto beta1 = input_params->beta1_;
|
||||
auto beta2 = input_params->beta2_;
|
||||
auto epsilon = input_params->epsilon_;
|
||||
auto use_nesterov = input_params->use_nesterov_;
|
||||
auto grad = input_params->grad_;
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
m[i] *= beta1;
|
||||
v[i] *= beta2;
|
||||
m[i] += (1 - beta1) * grad[i];
|
||||
v[i] += (1 - beta2) * grad[i] * grad[i];
|
||||
if (use_nesterov) {
|
||||
delta[i] = -lr * (m[i] * beta1 + (1 - beta1) * grad[i]) / (std::sqrt(v[i]) + epsilon);
|
||||
template <typename T>
|
||||
void AdamDeltaCPUKernel::LaunchAdamDelta(T *delta, T *m, T *v, float lr, float beta1, float beta2, float epsilon,
|
||||
const T *gradient, size_t size) {
|
||||
std::function<void(size_t, size_t)> task;
|
||||
if (dtype_ == kNumberTypeFloat32) {
|
||||
task = [&](size_t start, size_t end) {
|
||||
AdamDeltaFp32(delta, m, v, lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_);
|
||||
};
|
||||
} else {
|
||||
delta[i] = -lr * m[i] / (std::sqrt(v[i]) + epsilon);
|
||||
task = [&](size_t start, size_t end) {
|
||||
for (size_t c1 = start; c1 < end; ++c1) {
|
||||
m[c1] *= beta1;
|
||||
m[c1] += (1 - beta1) * gradient[c1];
|
||||
v[c1] *= beta2;
|
||||
v[c1] += (1 - beta2) * gradient[c1] * gradient[c1];
|
||||
if (use_nesterov_) {
|
||||
delta[c1] = -lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (std::sqrt(v[c1]) + epsilon);
|
||||
} else {
|
||||
delta[c1] = -lr * m[c1] / (std::sqrt(v[c1]) + epsilon);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
CPUKernelUtils::ParallelFor(task, size);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void AdamDeltaCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
|
@ -73,6 +59,7 @@ void AdamDeltaCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
std::vector<size_t> m_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
std::vector<size_t> v_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
std::vector<size_t> grad_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 8);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
if (!IsSameShape(delta_shape, m_shape)) {
|
||||
MS_LOG(EXCEPTION) << "Delta and m should have the same shape";
|
||||
}
|
||||
|
@ -134,42 +121,15 @@ bool AdamDeltaCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
auto epsilon = reinterpret_cast<float *>(inputs[7]->addr)[0];
|
||||
auto grad = reinterpret_cast<float *>(inputs[8]->addr);
|
||||
auto delta = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
|
||||
size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
|
||||
if (elem_num_ < thread_num) {
|
||||
thread_num = elem_num_;
|
||||
}
|
||||
std::vector<common::Task> tasks;
|
||||
std::vector<std::shared_ptr<ComputeParam>> thread_params;
|
||||
tasks.reserve(thread_num);
|
||||
MS_EXCEPTION_IF_NULL(m);
|
||||
MS_EXCEPTION_IF_NULL(v);
|
||||
MS_EXCEPTION_IF_NULL(grad);
|
||||
MS_EXCEPTION_IF_NULL(delta);
|
||||
|
||||
size_t end = 0;
|
||||
size_t offset = elem_num_ / thread_num;
|
||||
size_t left = elem_num_ % thread_num;
|
||||
for (size_t i = 0; i < thread_num; ++i) {
|
||||
auto params = std::make_shared<ComputeParam>();
|
||||
params->delta_ = delta;
|
||||
params->m_ = m;
|
||||
params->v_ = v;
|
||||
params->grad_ = grad;
|
||||
params->beta1_ = beta1;
|
||||
params->beta2_ = beta2;
|
||||
params->use_nesterov_ = use_nesterov_;
|
||||
params->lr_ = lr;
|
||||
params->epsilon_ = epsilon;
|
||||
size_t start = end;
|
||||
end = start + offset;
|
||||
if (i < left) {
|
||||
end += 1;
|
||||
}
|
||||
auto task = [¶ms, start, end]() {
|
||||
ComputeWeightDelta(params, start, end);
|
||||
return common::SUCCESS;
|
||||
};
|
||||
tasks.emplace_back(task);
|
||||
thread_params.emplace_back(params);
|
||||
}
|
||||
common::ThreadPool::GetInstance().SyncRun(tasks);
|
||||
lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
|
||||
// multithreading
|
||||
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
|
||||
LaunchAdamDelta<float>(delta, m, v, lr, beta1, beta2, epsilon, grad, lens);
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
|
|
|
@ -32,8 +32,12 @@ class AdamDeltaCPUKernel : public CPUKernel {
|
|||
protected:
|
||||
void CheckParams(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) const;
|
||||
template <typename T>
|
||||
void LaunchAdamDelta(T *delta, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
|
||||
size_t size);
|
||||
bool use_nesterov_{false};
|
||||
size_t elem_num_{0};
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(AdamNoUpdateParam,
|
||||
|
|
|
@ -0,0 +1,182 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifdef ENABLE_SSE
|
||||
#include <x86intrin.h>
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
#include <math.h>
|
||||
#include "nnacl/fp32/exp_fp32.h"
|
||||
#include "nnacl/fp32/adam_fp32.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
|
||||
size_t start, size_t end, bool use_nesterov) {
|
||||
size_t c1 = start;
|
||||
#ifdef ENABLE_AVX
|
||||
float coeff1 = 1 - beta1;
|
||||
float coeff2 = 1 - beta2;
|
||||
const float *m_ptr = m;
|
||||
const float *v_ptr = v;
|
||||
float *var_ptr = var;
|
||||
const float *gradient_ptr = gradient;
|
||||
size_t c8 = ((end - start) / C8NUM) * C8NUM;
|
||||
__m256 avx_r0, avx_r1, avx_r2, avx_r3, avx_r4, avx_r5, avx_r6, gradient_r;
|
||||
|
||||
for (; c1 < c8; c1 += C8NUM) {
|
||||
avx_r0 = _mm256_set1_ps(coeff1);
|
||||
gradient_r = _mm256_loadu_ps(gradient_ptr);
|
||||
avx_r2 = _mm256_loadu_ps(m_ptr);
|
||||
avx_r3 = _mm256_sub_ps(gradient_r, avx_r2);
|
||||
avx_r4 = _mm256_mul_ps(avx_r3, avx_r0);
|
||||
avx_r3 = _mm256_add_ps(avx_r4, avx_r2); // m[i]~m[i+8]
|
||||
|
||||
avx_r2 = _mm256_mul_ps(gradient_r, gradient_r);
|
||||
avx_r4 = _mm256_loadu_ps(v_ptr);
|
||||
avx_r5 = _mm256_sub_ps(avx_r2, avx_r4);
|
||||
avx_r1 = _mm256_set1_ps(coeff2);
|
||||
avx_r2 = _mm256_mul_ps(avx_r5, avx_r1);
|
||||
avx_r5 = _mm256_add_ps(avx_r4, avx_r2); // v[i]~v[i+8]
|
||||
|
||||
if (use_nesterov) {
|
||||
avx_r1 = _mm256_set1_ps(beta1);
|
||||
avx_r2 = _mm256_mul_ps(avx_r3, avx_r1);
|
||||
avx_r4 = _mm256_mul_ps(gradient_r, avx_r0);
|
||||
avx_r6 = _mm256_add_ps(avx_r2, avx_r4);
|
||||
avx_r0 = _mm256_set1_ps(lr);
|
||||
avx_r2 = _mm256_mul_ps(avx_r6, avx_r0);
|
||||
|
||||
avx_r0 = _mm256_set1_ps(epsilon);
|
||||
avx_r1 = _mm256_sqrt_ps(avx_r5);
|
||||
avx_r4 = _mm256_add_ps(avx_r0, avx_r1);
|
||||
|
||||
avx_r0 = _mm256_div_ps(avx_r2, avx_r2);
|
||||
avx_r1 = _mm256_loadu_ps(var_ptr);
|
||||
avx_r2 = _mm256_sub_ps(avx_r1, avx_r0);
|
||||
_mm256_storeu_ps(var_ptr, avx_r2);
|
||||
} else {
|
||||
avx_r0 = _mm256_set1_ps(lr);
|
||||
avx_r1 = _mm256_mul_ps(avx_r3, avx_r0);
|
||||
|
||||
avx_r0 = _mm256_set1_ps(epsilon);
|
||||
avx_r2 = _mm256_sqrt_ps(avx_r5);
|
||||
avx_r4 = _mm256_add_ps(avx_r0, avx_r2);
|
||||
|
||||
avx_r0 = _mm256_div_ps(avx_r1, avx_r4);
|
||||
avx_r1 = _mm256_loadu_ps(var_ptr);
|
||||
avx_r3 = _mm256_sub_ps(avx_r1, avx_r0);
|
||||
_mm256_storeu_ps(var_ptr, avx_r3);
|
||||
}
|
||||
m_ptr += C8NUM;
|
||||
v_ptr += C8NUM;
|
||||
var_ptr += C8NUM;
|
||||
gradient_ptr += C8NUM;
|
||||
}
|
||||
#endif
|
||||
|
||||
// remaining
|
||||
for (; c1 < end; c1++) {
|
||||
m[c1] += (gradient[c1] - m[c1]) * (1 - beta1);
|
||||
v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * (1 - beta2);
|
||||
if (use_nesterov) {
|
||||
var[c1] -= lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon);
|
||||
} else {
|
||||
var[c1] -= lr * m[c1] / (sqrt(v[c1]) + epsilon);
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
|
||||
const float *gradient, size_t start, size_t end, bool use_nesterov) {
|
||||
size_t c1 = 0;
|
||||
#ifdef ENABLE_AVX
|
||||
float coeff1 = 1 - beta1;
|
||||
float coeff2 = 1 - beta2;
|
||||
float *m_ptr = m;
|
||||
float *v_ptr = v;
|
||||
float *delta_ptr = delta;
|
||||
const float *gradient_ptr = gradient;
|
||||
size_t c8 = ((end - start) / C8NUM) * C8NUM;
|
||||
|
||||
__m256 gradient_r0, m_r1, v_r2, beta1_r3, beta2_r4, var_r5, var_r6, var_r7;
|
||||
for (; c1 < c8 + start; c1 += C8NUM) {
|
||||
gradient_r0 = _mm256_loadu_ps(gradient_ptr); // static
|
||||
beta1_r3 = _mm256_set1_ps(beta1); // static
|
||||
var_r5 = _mm256_loadu_ps(m_ptr);
|
||||
var_r6 = _mm256_mul_ps(beta1_r3, var_r5); // m[i] = m[i] * beta1
|
||||
var_r7 = _mm256_set1_ps(coeff1);
|
||||
var_r5 = _mm256_mul_ps(var_r7, gradient_r0); //
|
||||
m_r1 = _mm256_add_ps(var_r6, var_r5);
|
||||
_mm256_storeu_ps(m_ptr, m_r1);
|
||||
|
||||
beta2_r4 = _mm256_set1_ps(beta2); // static
|
||||
var_r5 = _mm256_loadu_ps(v_ptr);
|
||||
var_r6 = _mm256_mul_ps(beta2_r4, var_r5); // v[i] * beta2
|
||||
var_r7 = _mm256_set1_ps(coeff2);
|
||||
var_r5 = _mm256_mul_ps(var_r7, gradient_r0);
|
||||
var_r7 = _mm256_mul_ps(var_r5, gradient_r0);
|
||||
v_r2 = _mm256_add_ps(var_r7, var_r6);
|
||||
_mm256_storeu_ps(v_ptr, v_r2);
|
||||
|
||||
if (use_nesterov) {
|
||||
var_r5 = _mm256_mul_ps(beta1_r3, m_r1);
|
||||
var_r6 = _mm256_set1_ps(coeff1);
|
||||
var_r7 = _mm256_mul_ps(gradient_r0, var_r6);
|
||||
var_r6 = _mm256_add_ps(var_r5, var_r7); // m[i] * beta1 + (1 - beta1) * grad[i]
|
||||
var_r5 = _mm256_set1_ps(lr);
|
||||
var_r7 = _mm256_mul_ps(var_r6, var_r5);
|
||||
|
||||
var_r5 = _mm256_set1_ps(epsilon);
|
||||
var_r6 = _mm256_sqrt_ps(v_r2);
|
||||
v_r2 = _mm256_add_ps(var_r5, var_r6);
|
||||
var_r5 = _mm256_div_ps(var_r7, v_r2);
|
||||
var_r6 = _mm256_set1_ps(0.f);
|
||||
var_r7 = _mm256_sub_ps(var_r6, var_r5);
|
||||
_mm256_storeu_ps(delta_ptr, var_r7);
|
||||
} else {
|
||||
var_r5 = _mm256_set1_ps(lr);
|
||||
var_r6 = _mm256_mul_ps(var_r5, m_r1);
|
||||
|
||||
var_r7 = _mm256_set1_ps(epsilon);
|
||||
var_r5 = _mm256_sqrt_ps(v_r2);
|
||||
v_r2 = _mm256_add_ps(var_r5, var_r7);
|
||||
|
||||
var_r5 = _mm256_div_ps(var_r6, v_r2);
|
||||
var_r6 = _mm256_set1_ps(0.f);
|
||||
var_r7 = _mm256_sub_ps(var_r6, var_r5);
|
||||
_mm256_storeu_ps(delta_ptr, var_r7);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// remaining
|
||||
for (; c1 < end; ++c1) {
|
||||
m[c1] *= beta1;
|
||||
m[c1] += (1 - beta1) * gradient[c1];
|
||||
v[c1] *= beta2;
|
||||
v[c1] += (1 - beta2) * gradient[c1] * gradient[c1];
|
||||
if (use_nesterov) {
|
||||
delta[c1] = -lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon);
|
||||
} else {
|
||||
delta[c1] = -lr * m[c1] / (sqrt(v[c1]) + epsilon);
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_NNACL_ADAM_FP32_H
|
||||
#define MINDSPORE_NNACL_ADAM_FP32_H
|
||||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
|
||||
size_t start, size_t end, bool use_nesterov);
|
||||
int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
|
||||
const float *gradient, size_t start, size_t end, bool use_nesterov);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_NNACL_ADAM_FP32_H
|
|
@ -401,7 +401,7 @@ class GELU(Cell):
|
|||
TypeError: If dtype of `input_data` is neither float16 nor float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
|
||||
|
|
|
@ -809,7 +809,7 @@ class LayerNorm(Cell):
|
|||
TypeError: If `epsilon` is not a float.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
|
||||
|
|
|
@ -24,6 +24,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
|||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/)
|
||||
include_directories(${CMAKE_BINARY_DIR})
|
||||
include_directories(${CUDA_INCLUDE_DIRS})
|
||||
include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu)
|
||||
MESSAGE("check ut_test ${CMAKE_BINARY_DIR}")
|
||||
|
||||
link_directories(${MS_CCSRC_BUILD_PATH})
|
||||
|
@ -150,6 +151,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"../../../mindspore/ccsrc/transform/graph_ir/op_declare/*.cc"
|
||||
"../../../mindspore/ccsrc/ps/*.cc"
|
||||
"../../../mindspore/ccsrc/profiler/device/common/*.cc"
|
||||
"../../../mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.c"
|
||||
)
|
||||
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST
|
||||
|
|
Loading…
Reference in New Issue