optimize cpu adam op

parent 2955f9e84a
commit ee1b803416
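This commit vectorizes the CPU Adam optimizer kernels. build.sh and the ccsrc CMake gain an X86_64_SIMD switch that defines ENABLE_SSE/ENABLE_AVX and sets matching compiler flags; the Adam and AdamNoUpdateParam CPU kernels dispatch float32 inputs to new nnacl C routines (AdamFp32, AdamDeltaFp32) with an AVX fast path and a scalar tail, keeping the generic loop as a fallback for other dtypes; and the nn.GELU / nn.LayerNorm docstrings now list CPU as a supported platform.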
build.sh
@@ -422,7 +422,7 @@ build_mindspore()
     echo "start build mindspore project."
     mkdir -pv "${BUILD_PATH}/mindspore"
     cd "${BUILD_PATH}/mindspore"
-    CMAKE_ARGS="-DDEBUG_MODE=$DEBUG_MODE -DBUILD_PATH=$BUILD_PATH"
+    CMAKE_ARGS="-DDEBUG_MODE=$DEBUG_MODE -DBUILD_PATH=$BUILD_PATH -DX86_64_SIMD=${X86_64_SIMD}"
     if [[ "X$ENABLE_COVERAGE" = "Xon" ]]; then
         CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_COVERAGE=ON"
     fi
mindspore/ccsrc/CMakeLists.txt
@@ -7,11 +7,14 @@ if(ENABLE_CPU)
     include_directories(${CMAKE_CURRENT_SOURCE_DIR}/backend/kernel_compiler/cpu)
     if("${X86_64_SIMD}" STREQUAL "sse")
         add_compile_definitions(ENABLE_SSE)
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -msse4.2")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -msse4.2")
     endif()
     if("${X86_64_SIMD}" STREQUAL "avx")
         add_compile_definitions(ENABLE_SSE)
         add_compile_definitions(ENABLE_AVX)
-        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2")
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -msse4.2 -mfma -mavx -mavx2")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -msse4.2 -mfma -mavx -mavx2")
     endif()
     add_subdirectory(backend/kernel_compiler/cpu/nnacl)
 endif()
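Together with the build.sh change above, selecting X86_64_SIMD=sse or X86_64_SIMD=avx now both defines the ENABLE_SSE/ENABLE_AVX macros that gate the intrinsics below and applies the matching -msse4.1/-mfma/-mavx2 code-generation flags to C sources as well as C++ ones. Previously only CMAKE_CXX_FLAGS was set, which would leave a C file like the new nnacl/fp32/adam_fp32.c compiled without AVX enabled.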
mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc
@@ -13,11 +13,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "backend/kernel_compiler/cpu/adam_cpu_kernel.h"
-
 #include <cmath>
 #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
 #include "runtime/device/cpu/cpu_device_address.h"
+#include "backend/kernel_compiler/cpu/adam_cpu_kernel.h"
+#include "nnacl/errorcode.h"
+#include "nnacl/fp32/adam_fp32.h"
 #include "utils/ms_utils.h"

 namespace mindspore {
@@ -25,23 +26,31 @@ namespace kernel {
 template <typename T>
 void AdamCPUKernel::LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
                                size_t size) {
-  auto task = [&](size_t start, size_t end) {
-    for (size_t i = start; i < end; i++) {
-      m[i] += (gradient[i] - m[i]) * (1 - beta1);
-      v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2);
-      if (use_nesterov) {
-        var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon);
-      } else {
-        var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon);
-      }
-    }
-  };
+  std::function<void(size_t, size_t)> task;
+  if (dtype_ == kNumberTypeFloat32) {
+    task = [&](size_t start, size_t end) {
+      AdamFp32(var, m, v, lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_);
+    };
+  } else {
+    task = [&](size_t start, size_t end) {
+      for (size_t i = start; i < end; i++) {
+        m[i] += (gradient[i] - m[i]) * (1 - beta1);
+        v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2);
+        if (use_nesterov_) {
+          var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon);
+        } else {
+          var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon);
+        }
+      }
+    };
+  }
   CPUKernelUtils::ParallelFor(task, size);
 }

 void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
   if (input_num != 10) {
     MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but Adam needs 10 inputs.";
   }
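For reference, the AdamFp32 fast path and the scalar fallback above compute the same textbook Adam step. Writing $g$ for the gradient and noting that the kernel's beta1_power/beta2_power inputs carry $\beta_1^t$ and $\beta_2^t$, the update is

$$m \leftarrow m + (1-\beta_1)(g - m), \qquad v \leftarrow v + (1-\beta_2)(g^2 - v),$$
$$\widehat{lr} = lr \cdot \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}, \qquad var \leftarrow var - \widehat{lr} \cdot \frac{\tilde{m}}{\sqrt{v}+\epsilon},$$

where $\tilde{m} = m$ for the plain update and $\tilde{m} = \beta_1 m + (1-\beta_1)g$ when use_nesterov is set; $\widehat{lr}$ is the new_lr that Launch folds in below.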
@@ -49,7 +58,7 @@ void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   if (output_num != 3) {
     MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but Adam needs 3 outputs.";
   }
-  use_nesterov = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
+  use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
 }

 bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
@@ -83,7 +92,6 @@ bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
     MS_LOG(EXCEPTION) << "The beta1_power can't be set 1.";
   }
   float new_lr = lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power);
-
   // multithreading
   size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
   LaunchAdam<float>(var, m, v, new_lr, beta1, beta2, epsilon, gradient, lens);
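Both kernels now hand a range-based task to CPUKernelUtils::ParallelFor instead of managing threads directly. A minimal sketch of the contract the lambdas above rely on — the helper name, thread count, and chunking policy here are assumptions for illustration, not MindSpore's actual implementation:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Hypothetical stand-in for CPUKernelUtils::ParallelFor: split [0, size) into
// disjoint [start, end) chunks and run `task` on each from its own thread.
void ParallelForSketch(const std::function<void(size_t, size_t)> &task, size_t size, size_t thread_num = 8) {
  thread_num = std::max<size_t>(1, std::min(thread_num, size));
  const size_t chunk = (size + thread_num - 1) / thread_num;  // ceil(size / thread_num)
  std::vector<std::thread> workers;
  for (size_t start = 0; start < size; start += chunk) {
    workers.emplace_back(task, start, std::min(start + chunk, size));
  }
  for (auto &w : workers) w.join();  // synchronous: all chunks done on return
}

Because every chunk is a disjoint [start, end) slice, the lambdas can write m, v, and var (or delta) without synchronization.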
mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.h
@@ -36,7 +36,8 @@ class AdamCPUKernel : public CPUKernel {
                     const std::vector<AddressPtr> &outputs) override;

  private:
-  bool use_nesterov{false};
+  bool use_nesterov_{false};
+  TypeId dtype_{kTypeUnknown};
 };

 MS_REG_CPU_KERNEL(Adam,
mindspore/ccsrc/backend/kernel_compiler/cpu/adam_delta_cpu_kernel.cc
@@ -13,59 +13,45 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "backend/kernel_compiler/cpu/adam_delta_cpu_kernel.h"
 #include <thread>
 #include <vector>
 #include <string>
 #include <memory>
 #include "backend/kernel_compiler/common_utils.h"
 #include "runtime/device/cpu/cpu_device_address.h"
-#include "common/thread_pool.h"
+#include "backend/kernel_compiler/cpu/adam_delta_cpu_kernel.h"
+#include "nnacl/errorcode.h"
+#include "nnacl/fp32/adam_fp32.h"
+#include "utils/ms_utils.h"

 namespace mindspore {
 namespace kernel {
 constexpr size_t kAdamDeltaInputSize = 9;
-namespace {
-struct ComputeParam {
-  float *delta_{nullptr};
-  float *m_{nullptr};
-  float *v_{nullptr};
-  float *grad_{nullptr};
-  float beta1_{0};
-  float beta2_{0};
-  float epsilon_{0};
-  float lr_{0};
-  bool use_nesterov_{0};
-};
-
-void ComputeWeightDelta(const std::shared_ptr<ComputeParam> &input_params, size_t start, size_t end) {
-  MS_EXCEPTION_IF_NULL(input_params);
-  MS_EXCEPTION_IF_NULL(input_params->delta_);
-  MS_EXCEPTION_IF_NULL(input_params->m_);
-  MS_EXCEPTION_IF_NULL(input_params->v_);
-  MS_EXCEPTION_IF_NULL(input_params->grad_);
-  auto delta = input_params->delta_;
-  auto m = input_params->m_;
-  auto v = input_params->v_;
-  auto lr = input_params->lr_;
-  auto beta1 = input_params->beta1_;
-  auto beta2 = input_params->beta2_;
-  auto epsilon = input_params->epsilon_;
-  auto use_nesterov = input_params->use_nesterov_;
-  auto grad = input_params->grad_;
-  for (size_t i = start; i < end; ++i) {
-    m[i] *= beta1;
-    v[i] *= beta2;
-    m[i] += (1 - beta1) * grad[i];
-    v[i] += (1 - beta2) * grad[i] * grad[i];
-    if (use_nesterov) {
-      delta[i] = -lr * (m[i] * beta1 + (1 - beta1) * grad[i]) / (std::sqrt(v[i]) + epsilon);
-    } else {
-      delta[i] = -lr * m[i] / (std::sqrt(v[i]) + epsilon);
-    }
-  }
-}
-}  // namespace
+template <typename T>
+void AdamDeltaCPUKernel::LaunchAdamDelta(T *delta, T *m, T *v, float lr, float beta1, float beta2, float epsilon,
+                                         const T *gradient, size_t size) {
+  std::function<void(size_t, size_t)> task;
+  if (dtype_ == kNumberTypeFloat32) {
+    task = [&](size_t start, size_t end) {
+      AdamDeltaFp32(delta, m, v, lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_);
+    };
+  } else {
+    task = [&](size_t start, size_t end) {
+      for (size_t c1 = start; c1 < end; ++c1) {
+        m[c1] *= beta1;
+        m[c1] += (1 - beta1) * gradient[c1];
+        v[c1] *= beta2;
+        v[c1] += (1 - beta2) * gradient[c1] * gradient[c1];
+        if (use_nesterov_) {
+          delta[c1] = -lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (std::sqrt(v[c1]) + epsilon);
+        } else {
+          delta[c1] = -lr * m[c1] / (std::sqrt(v[c1]) + epsilon);
+        }
+      }
+    };
+  }
+  CPUKernelUtils::ParallelFor(task, size);
+}

 void AdamDeltaCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
@@ -73,6 +59,7 @@ void AdamDeltaCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   std::vector<size_t> m_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
   std::vector<size_t> v_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
   std::vector<size_t> grad_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 8);
+  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
   if (!IsSameShape(delta_shape, m_shape)) {
     MS_LOG(EXCEPTION) << "Delta and m should have the same shape";
   }
@@ -134,42 +121,15 @@ bool AdamDeltaCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
   auto epsilon = reinterpret_cast<float *>(inputs[7]->addr)[0];
   auto grad = reinterpret_cast<float *>(inputs[8]->addr);
   auto delta = reinterpret_cast<float *>(outputs[0]->addr);
-  lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
-  size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
-  if (elem_num_ < thread_num) {
-    thread_num = elem_num_;
-  }
-  std::vector<common::Task> tasks;
-  std::vector<std::shared_ptr<ComputeParam>> thread_params;
-  tasks.reserve(thread_num);
-
-  size_t end = 0;
-  size_t offset = elem_num_ / thread_num;
-  size_t left = elem_num_ % thread_num;
-  for (size_t i = 0; i < thread_num; ++i) {
-    auto params = std::make_shared<ComputeParam>();
-    params->delta_ = delta;
-    params->m_ = m;
-    params->v_ = v;
-    params->grad_ = grad;
-    params->beta1_ = beta1;
-    params->beta2_ = beta2;
-    params->use_nesterov_ = use_nesterov_;
-    params->lr_ = lr;
-    params->epsilon_ = epsilon;
-    size_t start = end;
-    end = start + offset;
-    if (i < left) {
-      end += 1;
-    }
-    auto task = [&params, start, end]() {
-      ComputeWeightDelta(params, start, end);
-      return common::SUCCESS;
-    };
-    tasks.emplace_back(task);
-    thread_params.emplace_back(params);
-  }
-  common::ThreadPool::GetInstance().SyncRun(tasks);
+  MS_EXCEPTION_IF_NULL(m);
+  MS_EXCEPTION_IF_NULL(v);
+  MS_EXCEPTION_IF_NULL(grad);
+  MS_EXCEPTION_IF_NULL(delta);
+
+  lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
+  // multithreading
+  size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
+  LaunchAdamDelta<float>(delta, m, v, lr, beta1, beta2, epsilon, grad, lens);
   return true;
 }
 }  // namespace kernel
mindspore/ccsrc/backend/kernel_compiler/cpu/adam_delta_cpu_kernel.h
@@ -32,8 +32,12 @@ class AdamDeltaCPUKernel : public CPUKernel {
  protected:
   void CheckParams(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs) const;
+  template <typename T>
+  void LaunchAdamDelta(T *delta, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
+                       size_t size);
   bool use_nesterov_{false};
   size_t elem_num_{0};
+  TypeId dtype_{kTypeUnknown};
 };

 MS_REG_CPU_KERNEL(AdamNoUpdateParam,
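Unlike the Adam kernel, AdamDeltaCPUKernel (registered for the AdamNoUpdateParam op) does not touch the weights: it writes the would-be update -lr * m / (sqrt(v) + epsilon) into a separate delta output. The refactor drops the hand-rolled ComputeParam/common::ThreadPool plumbing and routes the same arithmetic through LaunchAdamDelta and AdamDeltaFp32, mirroring the dtype dispatch added to the Adam kernel.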
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.c (new file)
@@ -0,0 +1,182 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_SSE
+#include <x86intrin.h>
+#endif
+
+#ifdef ENABLE_AVX
+#include <immintrin.h>
+#endif
+
+#include <math.h>
+#include "nnacl/fp32/exp_fp32.h"
+#include "nnacl/fp32/adam_fp32.h"
+#include "nnacl/op_base.h"
+
+int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
+             size_t start, size_t end, bool use_nesterov) {
+  size_t c1 = start;
+#ifdef ENABLE_AVX
+  float coeff1 = 1 - beta1;
+  float coeff2 = 1 - beta2;
+  float *m_ptr = m + start;
+  float *v_ptr = v + start;
+  float *var_ptr = var + start;
+  const float *gradient_ptr = gradient + start;
+  size_t c8 = start + ((end - start) / C8NUM) * C8NUM;
+  __m256 avx_r0, avx_r1, avx_r2, avx_r3, avx_r4, avx_r5, avx_r6, gradient_r;
+
+  for (; c1 < c8; c1 += C8NUM) {
+    avx_r0 = _mm256_set1_ps(coeff1);
+    gradient_r = _mm256_loadu_ps(gradient_ptr);
+    avx_r2 = _mm256_loadu_ps(m_ptr);
+    avx_r3 = _mm256_sub_ps(gradient_r, avx_r2);
+    avx_r4 = _mm256_mul_ps(avx_r3, avx_r0);
+    avx_r3 = _mm256_add_ps(avx_r4, avx_r2);  // m[i]~m[i+8]
+    _mm256_storeu_ps(m_ptr, avx_r3);         // persist updated moment
+
+    avx_r2 = _mm256_mul_ps(gradient_r, gradient_r);
+    avx_r4 = _mm256_loadu_ps(v_ptr);
+    avx_r5 = _mm256_sub_ps(avx_r2, avx_r4);
+    avx_r1 = _mm256_set1_ps(coeff2);
+    avx_r2 = _mm256_mul_ps(avx_r5, avx_r1);
+    avx_r5 = _mm256_add_ps(avx_r4, avx_r2);  // v[i]~v[i+8]
+    _mm256_storeu_ps(v_ptr, avx_r5);         // persist updated second moment
+
+    if (use_nesterov) {
+      avx_r1 = _mm256_set1_ps(beta1);
+      avx_r2 = _mm256_mul_ps(avx_r3, avx_r1);
+      avx_r4 = _mm256_mul_ps(gradient_r, avx_r0);
+      avx_r6 = _mm256_add_ps(avx_r2, avx_r4);
+      avx_r0 = _mm256_set1_ps(lr);
+      avx_r2 = _mm256_mul_ps(avx_r6, avx_r0);  // lr * (m * beta1 + (1 - beta1) * g)
+
+      avx_r0 = _mm256_set1_ps(epsilon);
+      avx_r1 = _mm256_sqrt_ps(avx_r5);
+      avx_r4 = _mm256_add_ps(avx_r0, avx_r1);  // sqrt(v) + epsilon
+
+      avx_r0 = _mm256_div_ps(avx_r2, avx_r4);
+      avx_r1 = _mm256_loadu_ps(var_ptr);
+      avx_r2 = _mm256_sub_ps(avx_r1, avx_r0);
+      _mm256_storeu_ps(var_ptr, avx_r2);
+    } else {
+      avx_r0 = _mm256_set1_ps(lr);
+      avx_r1 = _mm256_mul_ps(avx_r3, avx_r0);  // lr * m
+
+      avx_r0 = _mm256_set1_ps(epsilon);
+      avx_r2 = _mm256_sqrt_ps(avx_r5);
+      avx_r4 = _mm256_add_ps(avx_r0, avx_r2);  // sqrt(v) + epsilon
+
+      avx_r0 = _mm256_div_ps(avx_r1, avx_r4);
+      avx_r1 = _mm256_loadu_ps(var_ptr);
+      avx_r3 = _mm256_sub_ps(avx_r1, avx_r0);
+      _mm256_storeu_ps(var_ptr, avx_r3);
+    }
+    m_ptr += C8NUM;
+    v_ptr += C8NUM;
+    var_ptr += C8NUM;
+    gradient_ptr += C8NUM;
+  }
+#endif
+
+  // remaining
+  for (; c1 < end; c1++) {
+    m[c1] += (gradient[c1] - m[c1]) * (1 - beta1);
+    v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * (1 - beta2);
+    if (use_nesterov) {
+      var[c1] -= lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon);
+    } else {
+      var[c1] -= lr * m[c1] / (sqrt(v[c1]) + epsilon);
+    }
+  }
+  return NNACL_OK;
+}
+
+int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
+                  const float *gradient, size_t start, size_t end, bool use_nesterov) {
+  size_t c1 = start;
+#ifdef ENABLE_AVX
+  float coeff1 = 1 - beta1;
+  float coeff2 = 1 - beta2;
+  float *m_ptr = m + start;
+  float *v_ptr = v + start;
+  float *delta_ptr = delta + start;
+  const float *gradient_ptr = gradient + start;
+  size_t c8 = start + ((end - start) / C8NUM) * C8NUM;
+
+  __m256 gradient_r0, m_r1, v_r2, beta1_r3, beta2_r4, var_r5, var_r6, var_r7;
+  for (; c1 < c8; c1 += C8NUM) {
+    gradient_r0 = _mm256_loadu_ps(gradient_ptr);
+    beta1_r3 = _mm256_set1_ps(beta1);
+    var_r5 = _mm256_loadu_ps(m_ptr);
+    var_r6 = _mm256_mul_ps(beta1_r3, var_r5);  // m[i] * beta1
+    var_r7 = _mm256_set1_ps(coeff1);
+    var_r5 = _mm256_mul_ps(var_r7, gradient_r0);  // (1 - beta1) * grad[i]
+    m_r1 = _mm256_add_ps(var_r6, var_r5);
+    _mm256_storeu_ps(m_ptr, m_r1);
+
+    beta2_r4 = _mm256_set1_ps(beta2);
+    var_r5 = _mm256_loadu_ps(v_ptr);
+    var_r6 = _mm256_mul_ps(beta2_r4, var_r5);  // v[i] * beta2
+    var_r7 = _mm256_set1_ps(coeff2);
+    var_r5 = _mm256_mul_ps(var_r7, gradient_r0);
+    var_r7 = _mm256_mul_ps(var_r5, gradient_r0);  // (1 - beta2) * grad[i] * grad[i]
+    v_r2 = _mm256_add_ps(var_r7, var_r6);
+    _mm256_storeu_ps(v_ptr, v_r2);
+
+    if (use_nesterov) {
+      var_r5 = _mm256_mul_ps(beta1_r3, m_r1);
+      var_r6 = _mm256_set1_ps(coeff1);
+      var_r7 = _mm256_mul_ps(gradient_r0, var_r6);
+      var_r6 = _mm256_add_ps(var_r5, var_r7);  // m[i] * beta1 + (1 - beta1) * grad[i]
+      var_r5 = _mm256_set1_ps(lr);
+      var_r7 = _mm256_mul_ps(var_r6, var_r5);
+
+      var_r5 = _mm256_set1_ps(epsilon);
+      var_r6 = _mm256_sqrt_ps(v_r2);
+      v_r2 = _mm256_add_ps(var_r5, var_r6);  // sqrt(v) + epsilon
+      var_r5 = _mm256_div_ps(var_r7, v_r2);
+      var_r6 = _mm256_set1_ps(0.f);
+      var_r7 = _mm256_sub_ps(var_r6, var_r5);  // negate
+      _mm256_storeu_ps(delta_ptr, var_r7);
+    } else {
+      var_r5 = _mm256_set1_ps(lr);
+      var_r6 = _mm256_mul_ps(var_r5, m_r1);
+
+      var_r7 = _mm256_set1_ps(epsilon);
+      var_r5 = _mm256_sqrt_ps(v_r2);
+      v_r2 = _mm256_add_ps(var_r5, var_r7);  // sqrt(v) + epsilon
+
+      var_r5 = _mm256_div_ps(var_r6, v_r2);
+      var_r6 = _mm256_set1_ps(0.f);
+      var_r7 = _mm256_sub_ps(var_r6, var_r5);  // negate
+      _mm256_storeu_ps(delta_ptr, var_r7);
+    }
+    m_ptr += C8NUM;
+    v_ptr += C8NUM;
+    delta_ptr += C8NUM;
+    gradient_ptr += C8NUM;
+  }
+#endif
+
+  // remaining
+  for (; c1 < end; ++c1) {
+    m[c1] *= beta1;
+    m[c1] += (1 - beta1) * gradient[c1];
+    v[c1] *= beta2;
+    v[c1] += (1 - beta2) * gradient[c1] * gradient[c1];
+    if (use_nesterov) {
+      delta[c1] = -lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon);
+    } else {
+      delta[c1] = -lr * m[c1] / (sqrt(v[c1]) + epsilon);
+    }
+  }
+  return NNACL_OK;
+}
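Both routines share one shape: an AVX loop that consumes C8NUM (eight) floats per iteration — the width of one __m256 register — and a scalar loop that finishes the tail of [start, end) left over when the range is not a multiple of eight; on builds without ENABLE_AVX the scalar loop simply handles the whole range. Because each call operates on an explicit [start, end) window, ParallelFor can shard one flat parameter buffer across threads with no locking.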
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.h (new file)
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_ADAM_FP32_H
+#define MINDSPORE_NNACL_ADAM_FP32_H
+
+#include <math.h>
+#include "nnacl/op_base.h"
+#include "nnacl/errorcode.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
+             size_t start, size_t end, bool use_nesterov);
+int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
+                  const float *gradient, size_t start, size_t end, bool use_nesterov);
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_ADAM_FP32_H
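As a usage sketch, the exported C entry point can be driven directly; the hyper-parameter values and the single-threaded call below are illustrative, not taken from the commit:

#include <cmath>
#include <cstddef>
#include "nnacl/fp32/adam_fp32.h"  // declares AdamFp32 / AdamDeltaFp32

// One single-threaded Adam step over an n-element parameter buffer.
void AdamStep(float *var, float *m, float *v, const float *grad, size_t n, int step) {
  const float lr = 1e-3f, beta1 = 0.9f, beta2 = 0.999f, eps = 1e-8f;
  // Fold the bias corrections into the learning rate, as AdamCPUKernel::Launch does.
  const float beta1_power = static_cast<float>(std::pow(beta1, step));
  const float beta2_power = static_cast<float>(std::pow(beta2, step));
  const float new_lr = lr * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
  // Whole range [0, n) on this thread; ParallelFor would pass sub-ranges instead.
  AdamFp32(var, m, v, new_lr, beta1, beta2, eps, grad, 0, n, false);
}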
@@ -401,7 +401,7 @@ class GELU(Cell):
         TypeError: If dtype of `input_data` is neither float16 nor float32.

     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
@@ -809,7 +809,7 @@ class LayerNorm(Cell):
         TypeError: If `epsilon` is not a float.

     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
tests/ut/cpp/CMakeLists.txt
@@ -24,6 +24,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/)
 include_directories(${CMAKE_BINARY_DIR})
 include_directories(${CUDA_INCLUDE_DIRS})
+include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu)
 MESSAGE("check ut_test ${CMAKE_BINARY_DIR}")

 link_directories(${MS_CCSRC_BUILD_PATH})
@@ -150,6 +151,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
     "../../../mindspore/ccsrc/transform/graph_ir/op_declare/*.cc"
     "../../../mindspore/ccsrc/ps/*.cc"
     "../../../mindspore/ccsrc/profiler/device/common/*.cc"
+    "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.c"
 )

 list(REMOVE_ITEM MINDSPORE_SRC_LIST
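The unit-test build picks the new code up as well: the added include path makes the nnacl headers visible to the tests, and compiling adam_fp32.c into the ut binary lets the kernel objects under test resolve the AdamFp32/AdamDeltaFp32 symbols.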