Adam op performance optimization
parent b9ec533f95
commit be34ccd29f
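For reference, every kernel touched below applies the same per-element AdamWeightDecay update; the AVX-512 paths vectorize it and the remainder loops repeat it in scalar form. The following is a minimal scalar sketch of that update, written only as an illustration for this summary (the function name and plain-array interface are not part of the commit):

#include <cmath>
#include <cstddef>

// Minimal scalar sketch of the per-element AdamWeightDecay update performed by
// the kernels in this commit (illustrative helper, not part of the diff).
void AdamWeightDecayScalar(float *var, float *m, float *v, const float *grad, size_t len,
                           float lr, float beta1, float beta2, float epsilon, float decay) {
  const float beta1_minus = 1.0f - beta1;
  const float beta2_minus = 1.0f - beta2;
  for (size_t i = 0; i < len; ++i) {
    m[i] += (grad[i] - m[i]) * beta1_minus;            // m = beta1 * m + (1 - beta1) * g
    v[i] += (grad[i] * grad[i] - v[i]) * beta2_minus;  // v = beta2 * v + (1 - beta2) * g * g
    float update = m[i] / (std::sqrt(v[i]) + epsilon) + decay * var[i];
    var[i] -= lr * update;                             // var -= lr * (m / (sqrt(v) + eps) + decay * var)
  }
}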
@@ -13,145 +13,109 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.h"

#include <cmath>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/adam_fp32.h"
#include "utils/ms_utils.h"
#include "common/thread_pool.h"

namespace mindspore {
namespace kernel {
constexpr size_t kSizeFloat16 = sizeof(float16);
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kScalarIndex = 0;
constexpr size_t kAdamWeightDecayInputSize = 9;
constexpr size_t kAdamWeightDecayOutputSize = 3;

template <typename T, typename S>
void AdamWeightDecayCPUKernel::LaunchFusedAdam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &) {
auto var = reinterpret_cast<T *>(inputs[VAR]->addr);
auto m = reinterpret_cast<T *>(inputs[M]->addr);
auto v = reinterpret_cast<T *>(inputs[V]->addr);
auto lr = reinterpret_cast<T *>(inputs[LR]->addr)[kScalarIndex];
auto beta1 = reinterpret_cast<T *>(inputs[BETA1]->addr)[kScalarIndex];
auto beta2 = reinterpret_cast<T *>(inputs[BETA2]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<T *>(inputs[EPSILON]->addr)[kScalarIndex];
auto decay = reinterpret_cast<T *>(inputs[DECAY]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<S *>(inputs[GRAD]->addr);
const auto beta1_minus = 1 - beta1;
const auto beta2_minus = 1 - beta2;
template <typename T>
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &) {
T *var = reinterpret_cast<T *>(inputs[VAR]->addr);
T *m = reinterpret_cast<T *>(inputs[M]->addr);
T *v = reinterpret_cast<T *>(inputs[V]->addr);
T lr = static_cast<T>(reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex]);
T beta1 = static_cast<T>(reinterpret_cast<float *>(inputs[BETA1]->addr)[kScalarIndex]);
T beta2 = static_cast<T>(reinterpret_cast<float *>(inputs[BETA2]->addr)[kScalarIndex]);
T epsilon = static_cast<T>(reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex]);
T decay = static_cast<T>(reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex]);
T *gradient = reinterpret_cast<T *>(inputs[GRAD]->addr);
const T one = static_cast<T>(1.0);
const T beta1_minus = one - beta1;
const T beta2_minus = one - beta2;

// multithreading
size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / sizeof(float)) : 1;
size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / sizeof(T)) : 1;
std::function<void(size_t, size_t)> task;

task = [&](size_t start, size_t end) {
size_t i =
FusedAdamFp32(var, m, v, lr, beta1, beta2, epsilon, decay, reinterpret_cast<int16_t *>(gradient16), start, end);
// remaining
for (; i < end; i++) {
auto temp = static_cast<float>(gradient16[i]);
m[i] += (temp - m[i]) * beta1_minus;
v[i] += (temp * temp - v[i]) * beta2_minus;
T update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * var[i];
for (size_t i = start; i < end; i++) {
m[i] += (gradient[i] - m[i]) * beta1_minus;
v[i] += (gradient[i] * gradient[i] - v[i]) * beta2_minus;
T sqrt_v = static_cast<T>(std::sqrt(static_cast<double>(v[i])));
auto update = m[i] / (sqrt_v + epsilon) + decay * var[i];
var[i] -= lr * update;
}
};
CPUKernelUtils::ParallelFor(task, lens);
}

template <typename T>
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &) {
auto var = reinterpret_cast<T *>(inputs[VAR]->addr);
auto m = reinterpret_cast<T *>(inputs[M]->addr);
auto v = reinterpret_cast<T *>(inputs[V]->addr);
auto lr = reinterpret_cast<T *>(inputs[LR]->addr)[kScalarIndex];
auto beta1 = reinterpret_cast<T *>(inputs[BETA1]->addr)[kScalarIndex];
auto beta2 = reinterpret_cast<T *>(inputs[BETA2]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<T *>(inputs[EPSILON]->addr)[kScalarIndex];
auto decay = reinterpret_cast<T *>(inputs[DECAY]->addr)[kScalarIndex];
auto gradient = reinterpret_cast<T *>(inputs[GRAD]->addr);
const auto beta1_minus = 1 - beta1;
const auto beta2_minus = 1 - beta2;
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecayNnacl(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &) {
auto var = reinterpret_cast<float *>(inputs[VAR]->addr);
auto m = reinterpret_cast<float *>(inputs[M]->addr);
auto v = reinterpret_cast<float *>(inputs[V]->addr);
auto lr = reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex];
auto beta1 = reinterpret_cast<float *>(inputs[BETA1]->addr)[kScalarIndex];
auto beta2 = reinterpret_cast<float *>(inputs[BETA2]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex];
auto gradient = reinterpret_cast<float *>(inputs[GRAD]->addr);

// multithreading
size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / sizeof(float)) : 1;
std::function<void(size_t, size_t)> task;
task = [&](size_t start, size_t end) {
size_t i = AdamWeightDecayFp32(var, m, v, lr, beta1, beta2, epsilon, decay, gradient, start, end);
// remaining
for (; i < end; i++) {
m[i] += (gradient[i] - m[i]) * beta1_minus;
v[i] += (gradient[i] * gradient[i] - v[i]) * beta2_minus;
T update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * var[i];
var[i] -= lr * update;
int ret = AdamWeightDecayFp32(var, m, v, lr, beta1, beta2, epsilon, decay, gradient, start, end);
if (ret != NNACL_OK) {
MS_LOG(EXCEPTION) << "AdamWeightDecayFp32 failed.";
}
};
CPUKernelUtils::ParallelFor(task, lens);
ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
}

void AdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> var_shape = AnfAlgo::GetInputDeviceShape(kernel_node, VAR);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, VAR);
gradient_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, GRAD);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != kAdamWeightDecayInputSize) {
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
if (input_num != kAdamWeightDecayInputNum) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but AdamWeightDecay needs 9 inputs.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != kAdamWeightDecayOutputSize) {
if (output_num != kAdamWeightDecayOutputNum) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AdamWeightDecay needs 3 outputs.";
}
elem_num_ = 1;
for (size_t i : var_shape) {
elem_num_ *= i;
}
if (elem_num_ < 1) {
MS_LOG(EXCEPTION) << "Invalid parameter shape";
}
if (dtype_ != kNumberTypeFloat32) {
MS_LOG(EXCEPTION) << "The dtype of parameter must be float32!";
}
if (gradient_dtype_ != kNumberTypeFloat32 && gradient_dtype_ != kNumberTypeFloat16) {
MS_LOG(EXCEPTION) << "The dtype of gradient must be float32 or float16!";
}
}

void AdamWeightDecayCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() != kAdamWeightDecayInputSize) {
MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs 9 inputs.";
}
if (outputs.size() != kAdamWeightDecayOutputSize) {
MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs 3 outputs.";
}
size_t elem1_size = elem_num_ * kSizeFloat32;
size_t elem2_size = gradient_dtype_ == kNumberTypeFloat16 ? elem_num_ * kSizeFloat16 : elem1_size;
if (inputs[VAR]->size != elem1_size || inputs[M]->size != elem1_size || inputs[V]->size != elem1_size ||
inputs[GRAD]->size != elem2_size) {
MS_LOG(EXCEPTION) << "Error input data size!";
}
if (inputs[LR]->size != kSizeFloat32 || inputs[BETA1]->size != kSizeFloat32 || inputs[BETA2]->size != kSizeFloat32 ||
inputs[EPSILON]->size != kSizeFloat32 || inputs[DECAY]->size != kSizeFloat32) {
MS_LOG(EXCEPTION) << "The attribute beta, lr, epsilon and weight decay must be float!";
}
}

bool AdamWeightDecayCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CheckParam(inputs, outputs);
if (gradient_dtype_ == kNumberTypeFloat16) {
LaunchFusedAdam<float, float16>(inputs, outputs);
} else if (gradient_dtype_ == kNumberTypeFloat32) {
LaunchAdamWeightDecay<float>(inputs, outputs);
if (inputs.size() != kAdamWeightDecayInputNum) {
MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs 9 inputs.";
}
if (outputs.size() != kAdamWeightDecayOutputNum) {
MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs 3 outputs.";
}
if (inputs[VAR]->size != inputs[M]->size || inputs[VAR]->size != inputs[V]->size ||
inputs[VAR]->size != inputs[GRAD]->size) {
MS_LOG(EXCEPTION) << "Error input data size!";
}
if (inputs[LR]->size != kSizeFloat32 || inputs[BETA1]->size != kSizeFloat32 || inputs[BETA2]->size != kSizeFloat32 ||
inputs[EPSILON]->size != kSizeFloat32 || inputs[DECAY]->size != kSizeFloat32) {
MS_LOG(EXCEPTION) << "The attribute beta, lr, epsilon and weight decay must be float!";
}

if (dtype_ == kNumberTypeFloat32) {
LaunchAdamWeightDecayNnacl(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat16) {
LaunchAdamWeightDecay<float16>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "AdamWeightDecay not support " << dtype_;
}
return true;
}

@@ -23,6 +23,11 @@

namespace mindspore {
namespace kernel {
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kScalarIndex = 0;
constexpr size_t kAdamWeightDecayInputNum = 9;
constexpr size_t kAdamWeightDecayOutputNum = 3;

class AdamWeightDecayCPUKernel : public CPUKernel {
public:
AdamWeightDecayCPUKernel() = default;

@@ -32,48 +37,15 @@ class AdamWeightDecayCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;

private:
void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T, typename S>
void LaunchFusedAdam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
void LaunchAdamWeightDecay(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
size_t elem_num_{0};
void LaunchAdamWeightDecayNnacl(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

TypeId dtype_{kTypeUnknown};
TypeId gradient_dtype_{kTypeUnknown};
enum input_list_ { VAR, M, V, LR, BETA1, BETA2, EPSILON, DECAY, GRAD };
};

MS_REG_CPU_KERNEL(AdamWeightDecay,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
AdamWeightDecayCPUKernel)

MS_REG_CPU_KERNEL(AdamWeightDecay,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
AdamWeightDecayCPUKernel)
MS_REG_CPU_KERNEL(AdamWeightDecay, KernelAttr(), AdamWeightDecayCPUKernel)
} // namespace kernel
} // namespace mindspore

@@ -140,11 +140,16 @@ void CPUKernelUtils::ParallelForAutoSearch(const CTask &task, size_t count, Para

ActorThreadPool *GetActorMgrInnerThreadPool() {
auto actor_manager = ActorMgr::GetActorMgrRef();
MS_EXCEPTION_IF_NULL(actor_manager);
auto thread_pool = actor_manager->GetActorThreadPool();
// Init thread_pool if env is windows or ascend, in case that it won't be init in graph_scheduler.
if (thread_pool == nullptr) {
const size_t kMaxThreadNum = 23;
size_t max_thread_num = std::thread::hardware_concurrency() - 1;
#if ENABLE_D || ENABLE_GPU
const size_t kDeviceNum = 8;
max_thread_num /= kDeviceNum;
#endif
if (max_thread_num < 1) {
max_thread_num = 1;
}

@@ -0,0 +1,156 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/fused_cast_adam_weight_decay_cpu_kernel.h"
#include "nnacl/fp32/adam_fp32.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace kernel {
void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp32(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &) {
auto var = reinterpret_cast<float *>(inputs[VAR]->addr);
auto m = reinterpret_cast<float *>(inputs[M]->addr);
auto v = reinterpret_cast<float *>(inputs[V]->addr);
auto lr = reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex];
auto beta1 = reinterpret_cast<float *>(inputs[BETA1]->addr)[kScalarIndex];
auto beta2 = reinterpret_cast<float *>(inputs[BETA2]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<float16 *>(inputs[GRAD]->addr);
const auto beta1_minus = 1 - beta1;
const auto beta2_minus = 1 - beta2;

// multithreading
size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / kSizeFloat32) : 1;
std::function<void(size_t, size_t)> task;

task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp32(var, m, v, lr, beta1, beta2, epsilon, decay, reinterpret_cast<int16_t *>(gradient16),
start, end);
// remaining
for (; i < end; i++) {
auto temp = static_cast<float>(gradient16[i]);
m[i] += (temp - m[i]) * beta1_minus;
v[i] += (temp * temp - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * var[i];
var[i] -= lr * update;
}
};
ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
}

void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp16(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &) {
auto var16 = reinterpret_cast<float16 *>(inputs[VAR]->addr);
auto m = reinterpret_cast<float *>(inputs[M]->addr);
auto v = reinterpret_cast<float *>(inputs[V]->addr);
auto lr = reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex];
auto beta1 = reinterpret_cast<float *>(inputs[BETA1]->addr)[kScalarIndex];
auto beta2 = reinterpret_cast<float *>(inputs[BETA2]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<float16 *>(inputs[GRAD]->addr);
const auto beta1_minus = 1 - beta1;
const auto beta2_minus = 1 - beta2;

// multithreading
size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / kSizeFloat16) : 1;
std::function<void(size_t, size_t)> task;

task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp16(reinterpret_cast<int16_t *>(var16), m, v, lr, beta1, beta2, epsilon, decay,
reinterpret_cast<int16_t *>(gradient16), start, end);
// remaining
for (; i < end; i++) {
auto temp_var = static_cast<float>(var16[i]);
auto temp_grad = static_cast<float>(gradient16[i]);
m[i] += (temp_grad - m[i]) * beta1_minus;
v[i] += (temp_grad * temp_grad - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * temp_var;
temp_var -= lr * update;
var16[i] = static_cast<float16>(temp_var);
}
};
ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
}

void FusedCastAdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> var_shape = AnfAlgo::GetInputDeviceShape(kernel_node, VAR);
var_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, VAR);
gradient_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, GRAD);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != kFusedCastAdamWeightDecayInputNum) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but AdamWeightDecay needs 9 inputs.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != kFusedCastAdamWeightDecayOutputNum) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AdamWeightDecay needs 3 outputs.";
}
elem_num_ = 1;
for (size_t i : var_shape) {
elem_num_ *= i;
}
if (elem_num_ < 1) {
MS_LOG(EXCEPTION) << "Invalid parameter shape";
}
if (gradient_dtype_ != kNumberTypeFloat16) {
MS_LOG(EXCEPTION) << "The dtype of gradient must be float16!";
}
if (var_dtype_ != kNumberTypeFloat32 && var_dtype_ != kNumberTypeFloat16) {
MS_LOG(EXCEPTION) << "The dtype of parameter must be float32 or float16!";
}
}

void FusedCastAdamWeightDecayCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() != kFusedCastAdamWeightDecayInputNum) {
MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs 9 inputs.";
}
if (outputs.size() != kFusedCastAdamWeightDecayOutputNum) {
MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs 3 outputs.";
}
size_t elem_size_fp32 = elem_num_ * kSizeFloat32;
size_t elem_size_fp16 = elem_num_ * kSizeFloat16;
size_t var_size = var_dtype_ == kNumberTypeFloat16 ? elem_size_fp16 : elem_size_fp32;
if (inputs[VAR]->size != var_size || inputs[M]->size != elem_size_fp32 || inputs[V]->size != elem_size_fp32 ||
inputs[GRAD]->size != elem_size_fp16) {
MS_LOG(EXCEPTION) << "Error input data size!";
}
if (inputs[LR]->size != kSizeFloat32 || inputs[BETA1]->size != kSizeFloat32 || inputs[BETA2]->size != kSizeFloat32 ||
inputs[EPSILON]->size != kSizeFloat32 || inputs[DECAY]->size != kSizeFloat32) {
MS_LOG(EXCEPTION) << "The attribute beta, lr, epsilon and weight decay must be float!";
}
}

bool FusedCastAdamWeightDecayCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CheckParam(inputs, outputs);
if (var_dtype_ == kNumberTypeFloat16) {
LaunchFusedCastAdamFp16(inputs, outputs);
} else {
LaunchFusedCastAdamFp32(inputs, outputs);
}
return true;
}
} // namespace kernel
} // namespace mindspore

@@ -0,0 +1,85 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_CAST_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_CAST_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_

#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kSizeFloat16 = sizeof(float16);
constexpr size_t kScalarIndex = 0;
constexpr size_t kFusedCastAdamWeightDecayInputNum = 9;
constexpr size_t kFusedCastAdamWeightDecayOutputNum = 3;

class FusedCastAdamWeightDecayCPUKernel : public CPUKernel {
public:
FusedCastAdamWeightDecayCPUKernel() = default;
~FusedCastAdamWeightDecayCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;

private:
void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
void LaunchFusedCastAdamFp32(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
void LaunchFusedCastAdamFp16(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

size_t elem_num_{0};
TypeId var_dtype_{kTypeUnknown};
TypeId gradient_dtype_{kTypeUnknown};
enum input_list_ { VAR, M, V, LR, BETA1, BETA2, EPSILON, DECAY, GRAD };
};

MS_REG_CPU_KERNEL(FusedCastAdamWeightDecay,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedCastAdamWeightDecayCPUKernel)

MS_REG_CPU_KERNEL(FusedCastAdamWeightDecay,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedCastAdamWeightDecayCPUKernel)
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_CAST_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_

@@ -105,7 +105,7 @@ if(ENABLE_CPU)
elseif("${X86_64_SIMD}" STREQUAL "avx")
target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX)
elseif("${X86_64_SIMD}" STREQUAL "avx512")
target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX ENABLE_AVX512)
target_compile_definitions(nnacl_mid PRIVATE ENABLE_AVX512)
target_compile_options(nnacl_mid PRIVATE -mavx512f)
endif()
target_compile_options(nnacl_mid PRIVATE -fPIC)
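With X86_64_SIMD set to "avx512", nnacl_mid is built with ENABLE_AVX512 defined and -mavx512f enabled, which is what lets nnacl/fp32/adam_fp32.c compile its _mm512_* blocks; without the define those blocks drop out and only the scalar loops remain. A hedged sketch of that guard pattern (the helper name and body are illustrative, not actual nnacl code):

// Sketch of how the ENABLE_AVX512 definition gates AVX-512 intrinsics;
// the fallback path is plain scalar C++ compiled on every target.
#ifdef ENABLE_AVX512
#include <immintrin.h>
static inline void Scale16(float *dst, float s) {
  __m512 v = _mm512_loadu_ps(dst);  // requires -mavx512f at compile time
  _mm512_storeu_ps(dst, _mm512_mul_ps(v, _mm512_set1_ps(s)));
}
#else
static inline void Scale16(float *dst, float s) {
  for (int i = 0; i < 16; ++i) dst[i] *= s;  // scalar fallback
}
#endif
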
@@ -152,102 +152,44 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
return NNACL_OK;
}

size_t AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const float *gradient, size_t start, size_t end) {
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const float *gradient, size_t start, size_t end) {
size_t c1 = start;
const float beta1_minus = 1 - beta1;
const float beta2_minus = 1 - beta2;
#ifdef ENABLE_AVX512
struct AVX_Data beta1_r, beta2_r, beta1_minus_r, beta2_minus_r, lr_neg_r, epsilon_r, decay_r;
beta1_r.data = _mm512_set1_ps(beta1);
beta2_r.data = _mm512_set1_ps(beta2);
beta1_minus_r.data = _mm512_set1_ps(1.0f - beta1);
beta2_minus_r.data = _mm512_set1_ps(1.0f - beta2);
lr_neg_r.data = _mm512_set1_ps(-lr);
epsilon_r.data = _mm512_set1_ps(epsilon);
decay_r.data = _mm512_set1_ps(decay);
__m512 beta1_r = _mm512_set1_ps(beta1);
__m512 beta2_r = _mm512_set1_ps(beta2);
__m512 beta1_minus_r = _mm512_set1_ps(beta1_minus);
__m512 beta2_minus_r = _mm512_set1_ps(beta2_minus);
__m512 lr_neg_r = _mm512_set1_ps(-lr);
__m512 epsilon_r = _mm512_set1_ps(epsilon);
__m512 decay_r = _mm512_set1_ps(decay);
size_t c16 = ((end - start) / C16NUM) * C16NUM + start;
size_t c64 = ((end - start) / C64NUM) * C64NUM + start;

const float *gradient_ptr = gradient + start;
float *var_ptr = var + start;
float *m_ptr = m + start;
float *v_ptr = v + start;

for (; c1 < c64; c1 += C64NUM) {
struct AVX_Data g_r[kUnrollSize], var_r[kUnrollSize], m_r[kUnrollSize], v_r[kUnrollSize];
LoadStep4(g_r, gradient_ptr, kUnrollSize);
LoadStep4(var_r, var_ptr, kUnrollSize);
LoadStep4(m_r, m_ptr, kUnrollSize);
LoadStep4(v_r, v_ptr, kUnrollSize);

m_r[0].data = _mm512_mul_ps(m_r[0].data, beta1_r.data);
m_r[1].data = _mm512_mul_ps(m_r[1].data, beta1_r.data);
m_r[2].data = _mm512_mul_ps(m_r[2].data, beta1_r.data);
m_r[3].data = _mm512_mul_ps(m_r[3].data, beta1_r.data);
m_r[0].data = _mm512_fmadd_ps(g_r[0].data, beta1_minus_r.data, m_r[0].data);
m_r[1].data = _mm512_fmadd_ps(g_r[1].data, beta1_minus_r.data, m_r[1].data);
m_r[2].data = _mm512_fmadd_ps(g_r[2].data, beta1_minus_r.data, m_r[2].data);
m_r[3].data = _mm512_fmadd_ps(g_r[3].data, beta1_minus_r.data, m_r[3].data);

v_r[0].data = _mm512_mul_ps(v_r[0].data, beta2_r.data);
v_r[1].data = _mm512_mul_ps(v_r[1].data, beta2_r.data);
v_r[2].data = _mm512_mul_ps(v_r[2].data, beta2_r.data);
v_r[3].data = _mm512_mul_ps(v_r[3].data, beta2_r.data);
g_r[0].data = _mm512_mul_ps(g_r[0].data, g_r[0].data);
g_r[1].data = _mm512_mul_ps(g_r[1].data, g_r[1].data);
g_r[2].data = _mm512_mul_ps(g_r[2].data, g_r[2].data);
g_r[3].data = _mm512_mul_ps(g_r[3].data, g_r[3].data);
v_r[0].data = _mm512_fmadd_ps(g_r[0].data, beta2_minus_r.data, v_r[0].data);
v_r[1].data = _mm512_fmadd_ps(g_r[1].data, beta2_minus_r.data, v_r[1].data);
v_r[2].data = _mm512_fmadd_ps(g_r[2].data, beta2_minus_r.data, v_r[2].data);
v_r[3].data = _mm512_fmadd_ps(g_r[3].data, beta2_minus_r.data, v_r[3].data);

g_r[0].data = _mm512_sqrt_ps(v_r[0].data);
g_r[1].data = _mm512_sqrt_ps(v_r[1].data);
g_r[2].data = _mm512_sqrt_ps(v_r[2].data);
g_r[3].data = _mm512_sqrt_ps(v_r[3].data);
g_r[0].data = _mm512_div_ps(m_r[0].data, _mm512_add_ps(g_r[0].data, epsilon_r.data));
g_r[1].data = _mm512_div_ps(m_r[1].data, _mm512_add_ps(g_r[1].data, epsilon_r.data));
g_r[2].data = _mm512_div_ps(m_r[2].data, _mm512_add_ps(g_r[2].data, epsilon_r.data));
g_r[3].data = _mm512_div_ps(m_r[3].data, _mm512_add_ps(g_r[3].data, epsilon_r.data));
g_r[0].data = _mm512_fmadd_ps(var_r[0].data, decay_r.data, g_r[0].data);
g_r[1].data = _mm512_fmadd_ps(var_r[1].data, decay_r.data, g_r[1].data);
g_r[2].data = _mm512_fmadd_ps(var_r[2].data, decay_r.data, g_r[2].data);
g_r[3].data = _mm512_fmadd_ps(var_r[3].data, decay_r.data, g_r[3].data);

var_r[0].data = _mm512_fmadd_ps(g_r[0].data, lr_neg_r.data, var_r[0].data);
var_r[1].data = _mm512_fmadd_ps(g_r[1].data, lr_neg_r.data, var_r[1].data);
var_r[2].data = _mm512_fmadd_ps(g_r[2].data, lr_neg_r.data, var_r[2].data);
var_r[3].data = _mm512_fmadd_ps(g_r[3].data, lr_neg_r.data, var_r[3].data);

StoreStep4(var_ptr, var_r, kUnrollSize);
StoreStep4(m_ptr, m_r, kUnrollSize);
StoreStep4(v_ptr, v_r, kUnrollSize);

gradient_ptr += C64NUM;
var_ptr += C64NUM;
m_ptr += C64NUM;
v_ptr += C64NUM;
}
for (; c1 < c16; c1 += C16NUM) {
struct AVX_Data g_r, var_r, m_r, v_r;
g_r.data = _mm512_loadu_ps(gradient_ptr);
var_r.data = _mm512_loadu_ps(var_ptr);
m_r.data = _mm512_loadu_ps(m_ptr);
v_r.data = _mm512_loadu_ps(v_ptr);
__m512 g_r = _mm512_loadu_ps(gradient_ptr);
__m512 var_r = _mm512_loadu_ps(var_ptr);
__m512 m_r = _mm512_loadu_ps(m_ptr);
__m512 v_r = _mm512_loadu_ps(v_ptr);

m_r.data = _mm512_mul_ps(m_r.data, beta1_r.data);
v_r.data = _mm512_mul_ps(v_r.data, beta2_r.data);
struct AVX_Data avx_r0;
avx_r0.data = _mm512_mul_ps(g_r.data, g_r.data);
m_r.data = _mm512_fmadd_ps(g_r.data, beta1_minus_r.data, m_r.data);
v_r.data = _mm512_fmadd_ps(avx_r0.data, beta2_minus_r.data, v_r.data);
avx_r0.data = _mm512_sqrt_ps(v_r.data);
avx_r0.data = _mm512_div_ps(m_r.data, _mm512_add_ps(avx_r0.data, epsilon_r.data));
avx_r0.data = _mm512_fmadd_ps(var_r.data, decay_r.data, avx_r0.data);
var_r.data = _mm512_fmadd_ps(avx_r0.data, lr_neg_r.data, var_r.data);
_mm512_storeu_ps(var_ptr, var_r.data);
_mm512_storeu_ps(m_ptr, m_r.data);
_mm512_storeu_ps(v_ptr, v_r.data);
m_r = _mm512_mul_ps(m_r, beta1_r);
v_r = _mm512_mul_ps(v_r, beta2_r);
__m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
m_r = _mm512_fmadd_ps(g_r, beta1_minus_r, m_r);
v_r = _mm512_fmadd_ps(avx_r0, beta2_minus_r, v_r);
avx_r0 = _mm512_sqrt_ps(v_r);
avx_r0 = _mm512_div_ps(m_r, _mm512_add_ps(avx_r0, epsilon_r));
avx_r0 = _mm512_fmadd_ps(var_r, decay_r, avx_r0);
var_r = _mm512_fmadd_ps(avx_r0, lr_neg_r, var_r);
_mm512_storeu_ps(var_ptr, var_r);
_mm512_storeu_ps(m_ptr, m_r);
_mm512_storeu_ps(v_ptr, v_r);

gradient_ptr += C16NUM;
var_ptr += C16NUM;

@@ -255,109 +197,51 @@ size_t AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1
v_ptr += C16NUM;
}
#endif
return c1;
// remaining
for (; c1 < end; c1++) {
m[c1] += (gradient[c1] - m[c1]) * beta1_minus;
v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * beta2_minus;
var[c1] -= lr * (m[c1] / (sqrt(v[c1]) + epsilon) + decay * var[c1]);
}
return NNACL_OK;
}

size_t FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, size_t start, size_t end) {
size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, size_t start, size_t end) {
size_t c1 = start;
#ifdef ENABLE_AVX512
struct AVX_Data beta1_r, beta2_r, beta1_minus_r, beta2_minus_r, lr_neg_r, epsilon_r, decay_r;
beta1_r.data = _mm512_set1_ps(beta1);
beta2_r.data = _mm512_set1_ps(beta2);
beta1_minus_r.data = _mm512_set1_ps(1.0f - beta1);
beta2_minus_r.data = _mm512_set1_ps(1.0f - beta2);
lr_neg_r.data = _mm512_set1_ps(-lr);
epsilon_r.data = _mm512_set1_ps(epsilon);
decay_r.data = _mm512_set1_ps(decay);
__m512 beta1_r = _mm512_set1_ps(beta1);
__m512 beta2_r = _mm512_set1_ps(beta2);
__m512 beta1_minus_r = _mm512_set1_ps(1.0f - beta1);
__m512 beta2_minus_r = _mm512_set1_ps(1.0f - beta2);
__m512 lr_neg_r = _mm512_set1_ps(-lr);
__m512 epsilon_r = _mm512_set1_ps(epsilon);
__m512 decay_r = _mm512_set1_ps(decay);
size_t c16 = ((end - start) / C16NUM) * C16NUM + start;
size_t c64 = ((end - start) / C64NUM) * C64NUM + start;

const int16_t *gradient16_ptr = gradient16 + start;
float *var_ptr = var + start;
float *m_ptr = m + start;
float *v_ptr = v + start;

for (; c1 < c64; c1 += C64NUM) {
struct AVX_Data g_r[kUnrollSize], var_r[kUnrollSize], m_r[kUnrollSize], v_r[kUnrollSize];
g_r[0].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
g_r[1].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr + C16NUM)));
g_r[2].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr + C16NUM * 2)));
g_r[3].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr + C16NUM * 3)));

LoadStep4(var_r, var_ptr, kUnrollSize);
LoadStep4(m_r, m_ptr, kUnrollSize);
LoadStep4(v_r, v_ptr, kUnrollSize);

m_r[0].data = _mm512_mul_ps(m_r[0].data, beta1_r.data);
m_r[1].data = _mm512_mul_ps(m_r[1].data, beta1_r.data);
m_r[2].data = _mm512_mul_ps(m_r[2].data, beta1_r.data);
m_r[3].data = _mm512_mul_ps(m_r[3].data, beta1_r.data);
m_r[0].data = _mm512_fmadd_ps(g_r[0].data, beta1_minus_r.data, m_r[0].data);
m_r[1].data = _mm512_fmadd_ps(g_r[1].data, beta1_minus_r.data, m_r[1].data);
m_r[2].data = _mm512_fmadd_ps(g_r[2].data, beta1_minus_r.data, m_r[2].data);
m_r[3].data = _mm512_fmadd_ps(g_r[3].data, beta1_minus_r.data, m_r[3].data);

v_r[0].data = _mm512_mul_ps(v_r[0].data, beta2_r.data);
v_r[1].data = _mm512_mul_ps(v_r[1].data, beta2_r.data);
v_r[2].data = _mm512_mul_ps(v_r[2].data, beta2_r.data);
v_r[3].data = _mm512_mul_ps(v_r[3].data, beta2_r.data);
g_r[0].data = _mm512_mul_ps(g_r[0].data, g_r[0].data);
g_r[1].data = _mm512_mul_ps(g_r[1].data, g_r[1].data);
g_r[2].data = _mm512_mul_ps(g_r[2].data, g_r[2].data);
g_r[3].data = _mm512_mul_ps(g_r[3].data, g_r[3].data);
v_r[0].data = _mm512_fmadd_ps(g_r[0].data, beta2_minus_r.data, v_r[0].data);
v_r[1].data = _mm512_fmadd_ps(g_r[1].data, beta2_minus_r.data, v_r[1].data);
v_r[2].data = _mm512_fmadd_ps(g_r[2].data, beta2_minus_r.data, v_r[2].data);
v_r[3].data = _mm512_fmadd_ps(g_r[3].data, beta2_minus_r.data, v_r[3].data);

g_r[0].data = _mm512_sqrt_ps(v_r[0].data);
g_r[1].data = _mm512_sqrt_ps(v_r[1].data);
g_r[2].data = _mm512_sqrt_ps(v_r[2].data);
g_r[3].data = _mm512_sqrt_ps(v_r[3].data);
g_r[0].data = _mm512_div_ps(m_r[0].data, _mm512_add_ps(g_r[0].data, epsilon_r.data));
g_r[1].data = _mm512_div_ps(m_r[1].data, _mm512_add_ps(g_r[1].data, epsilon_r.data));
g_r[2].data = _mm512_div_ps(m_r[2].data, _mm512_add_ps(g_r[2].data, epsilon_r.data));
g_r[3].data = _mm512_div_ps(m_r[3].data, _mm512_add_ps(g_r[3].data, epsilon_r.data));
g_r[0].data = _mm512_fmadd_ps(var_r[0].data, decay_r.data, g_r[0].data);
g_r[1].data = _mm512_fmadd_ps(var_r[1].data, decay_r.data, g_r[1].data);
g_r[2].data = _mm512_fmadd_ps(var_r[2].data, decay_r.data, g_r[2].data);
g_r[3].data = _mm512_fmadd_ps(var_r[3].data, decay_r.data, g_r[3].data);

var_r[0].data = _mm512_fmadd_ps(g_r[0].data, lr_neg_r.data, var_r[0].data);
var_r[1].data = _mm512_fmadd_ps(g_r[1].data, lr_neg_r.data, var_r[1].data);
var_r[2].data = _mm512_fmadd_ps(g_r[2].data, lr_neg_r.data, var_r[2].data);
var_r[3].data = _mm512_fmadd_ps(g_r[3].data, lr_neg_r.data, var_r[3].data);

StoreStep4(var_ptr, var_r, kUnrollSize);
StoreStep4(m_ptr, m_r, kUnrollSize);
StoreStep4(v_ptr, v_r, kUnrollSize);

gradient16_ptr += C64NUM;
var_ptr += C64NUM;
m_ptr += C64NUM;
v_ptr += C64NUM;
}
for (; c1 < c16; c1 += C16NUM) {
struct AVX_Data g_r, var_r, m_r, v_r;
g_r.data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
var_r.data = _mm512_loadu_ps(var_ptr);
m_r.data = _mm512_loadu_ps(m_ptr);
v_r.data = _mm512_loadu_ps(v_ptr);
__m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
__m512 var_r = _mm512_loadu_ps(var_ptr);
__m512 m_r = _mm512_loadu_ps(m_ptr);
__m512 v_r = _mm512_loadu_ps(v_ptr);

m_r.data = _mm512_mul_ps(m_r.data, beta1_r.data);
v_r.data = _mm512_mul_ps(v_r.data, beta2_r.data);
struct AVX_Data avx_r0;
avx_r0.data = _mm512_mul_ps(g_r.data, g_r.data);
m_r.data = _mm512_fmadd_ps(g_r.data, beta1_minus_r.data, m_r.data);
v_r.data = _mm512_fmadd_ps(avx_r0.data, beta2_minus_r.data, v_r.data);
avx_r0.data = _mm512_sqrt_ps(v_r.data);
avx_r0.data = _mm512_div_ps(m_r.data, _mm512_add_ps(avx_r0.data, epsilon_r.data));
avx_r0.data = _mm512_fmadd_ps(var_r.data, decay_r.data, avx_r0.data);
var_r.data = _mm512_fmadd_ps(avx_r0.data, lr_neg_r.data, var_r.data);
_mm512_storeu_ps(var_ptr, var_r.data);
_mm512_storeu_ps(m_ptr, m_r.data);
_mm512_storeu_ps(v_ptr, v_r.data);
m_r = _mm512_mul_ps(m_r, beta1_r);
v_r = _mm512_mul_ps(v_r, beta2_r);
__m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
m_r = _mm512_fmadd_ps(g_r, beta1_minus_r, m_r);
v_r = _mm512_fmadd_ps(avx_r0, beta2_minus_r, v_r);
avx_r0 = _mm512_sqrt_ps(v_r);
avx_r0 = _mm512_div_ps(m_r, _mm512_add_ps(avx_r0, epsilon_r));
avx_r0 = _mm512_fmadd_ps(var_r, decay_r, avx_r0);
var_r = _mm512_fmadd_ps(avx_r0, lr_neg_r, var_r);
_mm512_storeu_ps(var_ptr, var_r);
_mm512_storeu_ps(m_ptr, m_r);
_mm512_storeu_ps(v_ptr, v_r);

gradient16_ptr += C16NUM;
var_ptr += C16NUM;

@@ -367,3 +251,49 @@ size_t FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, floa
#endif
return c1;
}

size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const int16_t *gradient16, size_t start, size_t end) {
size_t c1 = start;
#ifdef ENABLE_AVX512
__m512 beta1_r = _mm512_set1_ps(beta1);
__m512 beta2_r = _mm512_set1_ps(beta2);
__m512 beta1_minus_r = _mm512_set1_ps(1.0f - beta1);
__m512 beta2_minus_r = _mm512_set1_ps(1.0f - beta2);
__m512 lr_neg_r = _mm512_set1_ps(-lr);
__m512 epsilon_r = _mm512_set1_ps(epsilon);
__m512 decay_r = _mm512_set1_ps(decay);
size_t c16 = ((end - start) / C16NUM) * C16NUM + start;

const int16_t *gradient16_ptr = gradient16 + start;
int16_t *var16_ptr = var16 + start;
float *m_ptr = m + start;
float *v_ptr = v + start;

for (; c1 < c16; c1 += C16NUM) {
__m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
__m512 var_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(var16_ptr)));
__m512 m_r = _mm512_loadu_ps(m_ptr);
__m512 v_r = _mm512_loadu_ps(v_ptr);

m_r = _mm512_mul_ps(m_r, beta1_r);
v_r = _mm512_mul_ps(v_r, beta2_r);
__m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
m_r = _mm512_fmadd_ps(g_r, beta1_minus_r, m_r);
v_r = _mm512_fmadd_ps(avx_r0, beta2_minus_r, v_r);
avx_r0 = _mm512_sqrt_ps(v_r);
avx_r0 = _mm512_div_ps(m_r, _mm512_add_ps(avx_r0, epsilon_r));
avx_r0 = _mm512_fmadd_ps(var_r, decay_r, avx_r0);
var_r = _mm512_fmadd_ps(avx_r0, lr_neg_r, var_r);
_mm512_storeu_ps(m_ptr, m_r);
_mm512_storeu_ps(v_ptr, v_r);
_mm256_storeu_si256((__m256i *)var16_ptr, _mm512_cvtps_ph(var_r, 0));

gradient16_ptr += C16NUM;
var16_ptr += C16NUM;
m_ptr += C16NUM;
v_ptr += C16NUM;
}
#endif
return c1;
}
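Note the contract shared by FusedCastAdamFp32 and FusedCastAdamFp16: each processes only the largest 16-aligned prefix of [start, end) and returns the index of the first element it did not touch, leaving the caller to finish the tail in scalar code, as the LaunchFusedCastAdam* lambdas above do. A minimal sketch of that calling convention follows; the wrapper name and the Float16ToFloat helper are illustrative additions, not part of nnacl, and the converter ignores denormal/Inf/NaN inputs.

#include <cmath>
#include <cstdint>
#include <cstring>

// Declared in nnacl/fp32/adam_fp32.h (see the header hunk below).
extern "C" size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
                                    float decay, const int16_t *gradient16, size_t start, size_t end);

// Illustrative fp16 -> fp32 conversion (normal numbers and zero only); the real kernels use mindspore::float16.
static float Float16ToFloat(int16_t raw) {
  uint16_t h = static_cast<uint16_t>(raw);
  uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
  uint32_t exp = (h >> 10) & 0x1Fu;
  uint32_t frac = h & 0x3FFu;
  if (exp == 0 && frac == 0) return sign ? -0.0f : 0.0f;
  uint32_t bits = sign | ((exp + 112u) << 23) | (frac << 13);
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

// Sketch: combine the vectorized helper with a scalar tail over one range of elements.
void ApplyFusedCastAdamFp32(float *var, float *m, float *v, const int16_t *grad16, size_t start, size_t end,
                            float lr, float beta1, float beta2, float epsilon, float decay) {
  size_t i = FusedCastAdamFp32(var, m, v, lr, beta1, beta2, epsilon, decay, grad16, start, end);
  for (; i < end; ++i) {  // scalar remainder, mirroring the kernel's lambda
    float g = Float16ToFloat(grad16[i]);
    m[i] += (g - m[i]) * (1.0f - beta1);
    v[i] += (g * g - v[i]) * (1.0f - beta2);
    var[i] -= lr * (m[i] / (std::sqrt(v[i]) + epsilon) + decay * var[i]);
  }
}
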
@@ -26,44 +26,9 @@
#include <x86intrin.h>
#endif
#endif

#ifdef ENABLE_AVX
#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
#include <immintrin.h>
#endif
#ifdef ENABLE_AVX512
const size_t kUnrollSize = 4;
struct AVX_Data {
__m512 data;
};

static inline int LoadStep4(struct AVX_Data *inp0, const float *inp1, const size_t arrLen) {
if (arrLen != kUnrollSize) {
return NNACL_ERR;
}
if (inp0 == NULL || inp1 == NULL) {
return NNACL_NULL_PTR;
}
inp0[0].data = _mm512_loadu_ps(inp1);
inp0[1].data = _mm512_loadu_ps(inp1 + C16NUM);
inp0[2].data = _mm512_loadu_ps(inp1 + C16NUM * 2);
inp0[3].data = _mm512_loadu_ps(inp1 + C16NUM * 3);
return NNACL_OK;
}

static inline int StoreStep4(float *inp0, struct AVX_Data *inp1, const size_t arrLen) {
if (arrLen != kUnrollSize) {
return NNACL_ERR;
}
if (inp0 == NULL || inp1 == NULL) {
return NNACL_NULL_PTR;
}
_mm512_storeu_ps(inp0, inp1[0].data);
_mm512_storeu_ps(inp0 + C16NUM, inp1[1].data);
_mm512_storeu_ps(inp0 + C16NUM * 2, inp1[2].data);
_mm512_storeu_ps(inp0 + C16NUM * 3, inp1[3].data);
return NNACL_OK;
}
#endif
#ifdef __cplusplus
extern "C" {
#endif

@@ -71,10 +36,12 @@ int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2,
size_t start, size_t end, bool use_nesterov);
int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
const float *gradient, size_t start, size_t end, bool use_nesterov);
size_t AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const float *gradient, size_t start, size_t end);
size_t FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, size_t start, size_t end);
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const float *gradient, size_t start, size_t end);
size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, size_t start, size_t end);
size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const int16_t *gradient16, size_t start, size_t end);
#ifdef __cplusplus
}
#endif

@@ -205,6 +205,7 @@ constexpr auto kStridedReadOpName = "StridedRead";
constexpr auto kStridedWriteOpName = "StridedWrite";
constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
constexpr auto kAdamWeightDecayName = "AdamWeightDecay";
constexpr auto kFusedCastAdamWeightDecayName = "FusedCastAdamWeightDecay";
constexpr auto kFusedAdamName = "FusedAdam";
constexpr auto kFusedSparseAdamName = "FusedSparseAdam";
constexpr auto kFusedMatMulBiasAddName = "FusedMatMulBiasAdd";

@@ -599,6 +600,7 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
kAdamApplyOneWithDecayAssignOpName,
kFusedAdamWeightDecayName,
kAdamWeightDecayName,
kFusedCastAdamWeightDecayName,
kFusedAdamName,
kFusedSparseAdamName,
kFusedMulApplyMomentumOpName,

@@ -22,6 +22,7 @@ from .div import _div_cpu
from .concat import _concat_cpu
from .split import _split_cpu
from .adam import _adam_cpu
from .adam_weight_decay import _adam_weight_decay_cpu
from .arg_max import _arg_max_cpu
from .arg_min_with_value import _arg_min_with_value_cpu
from .arg_max_with_value import _arg_max_with_value_cpu

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""AdamWeightDecay op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType

adam_weight_decay_op_info = CpuRegOp("AdamWeightDecay") \
.input(0, "var", "required") \
.input(1, "m", "required") \
.input(2, "v", "required") \
.input(3, "lr", "required") \
.input(4, "beta1", "required") \
.input(5, "beta2", "required") \
.input(6, "epsilon", "required") \
.input(7, "decay", "required") \
.input(8, "gradient", "required") \
.output(0, "output0", "required") \
.output(1, "output1", "required") \
.output(2, "output2", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.get_op_info()


@op_info_register(adam_weight_decay_op_info)
def _adam_weight_decay_cpu():
    """AdamWeightDecay cpu register"""
    return

@@ -44,7 +44,7 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm
from .control_ops import GeSwitch, Merge
from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign,
MakeRefKey,
FusedWeightScaleApplyMomentum, AdamWeightDecay)
FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay)

from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
BitwiseAnd, BitwiseOr,

@@ -65,8 +65,8 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, U
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
LogUniformCandidateSampler)

from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum,
BatchNorm, BiasAdd, Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose,
from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam,
ApplyMomentum, BatchNorm, BiasAdd, Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose,
DepthwiseConv2dNative,
DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten,
InstanceNorm, BNTrainingReduce, BNTrainingUpdate,

@@ -453,9 +453,11 @@ class FusedWeightScaleApplyMomentum(PrimitiveWithInfer):
return v_dtype


class AdamWeightDecay(PrimitiveWithInfer):
class FusedCastAdamWeightDecay(PrimitiveWithInfer):
r"""
Updates gradients by the Adaptive Moment Estimation (AdamWeightDecay) algorithm with weight decay.
Updates gradients by the Adaptive Moment Estimation (AdamWeightDecay) algorithm with weight decay. This operator
updates parameters of float16 using gradients of float16 when parameters are initialized with dtype of float16 and
computes with dtype of float16.

The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

@@ -479,15 +481,14 @@ class AdamWeightDecay(PrimitiveWithInfer):
If false, the result is unpredictable. Default: False.

Inputs:
- **var** (Tensor) - Weights to be updated.
- **m** (Tensor) - The 1st moment vector in the updating formula, has the same type as `var`.
- **v** (Tensor) - the 2nd moment vector in the updating formula.
Mean square gradients with the same type as `var`.
- **lr** (float) - :math:`l` in the updating formula.
- **var** (Tensor) - Weights to be updated with the type float16 or float32.
- **m** (Tensor) - The 1st moment vector in the updating formula with the type float32.
- **v** (Tensor) - the 2nd moment vector in the updating formula with the type float32.
- **lr** (float) - :math:`lr` in the updating formula.
- **beta1** (float) - The exponential decay rate for the 1st moment estimations.
- **beta2** (float) - The exponential decay rate for the 2nd moment estimations.
- **epsilon** (float) - Term added to the denominator to improve numerical stability.
- **gradient** (Tensor) - Gradient, has the same type as `var`.
- **gradient** (Tensor) - Gradient, has the type float16.

Outputs:
Tuple of 3 Tensor, the updated parameters.

@@ -497,27 +498,10 @@ class AdamWeightDecay(PrimitiveWithInfer):
- **v** (Tensor) - The same shape and data type as `v`.

Supported Platforms:
``GPU`` ``CPU``
``CPU``

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter, ops
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.adam_weight_decay = ops.AdamWeightDecay()
...         self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var")
...         self.m = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="m")
...         self.v = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="v")
...     def construct(self, lr, beta1, beta2, epsilon, decay, grad):
...         out = self.adam_weight_decay(self.var, self.m, self.v, lr, beta1, beta2,
...                                      epsilon, decay, grad)
...         return out
>>> np.random.seed(0)
>>> net = Net()
>>> gradient = Tensor(np.random.rand(2, 2).astype(np.float32))
>>> output = net(0.9, 0.9, 0.999, 1e-8, 1e-5, gradient)

"""

@prim_attr_register

@@ -534,11 +518,11 @@

def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
epsilon_dtype, decay, grad_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype}
args = {"m": m_dtype, "v": v_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({"var": var_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16], self.name)

args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
"decay": decay}
args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype, "decay": decay}
validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
return var_dtype, m_dtype, v_dtype

@ -4465,6 +4465,100 @@ class Adam(PrimitiveWithInfer):
|
|||
return var_dtype, m_dtype, v_dtype
|
||||
|
||||
|
||||
class AdamWeightDecay(PrimitiveWithInfer):
|
||||
r"""
|
||||
Updates gradients by the Adaptive Moment Estimation (AdamWeightDecay) algorithm with weight decay.
|
||||
|
||||
The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
|
||||
|
||||
The updating formulas are as follows,
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
m = \beta_1 * m + (1 - \beta_1) * g \\
|
||||
v = \beta_2 * v + (1 - \beta_2) * g * g \\
|
||||
w = w - lr * \frac{m}{\sqrt{v} + \epsilon}
|
||||
\end{array}
|
||||
|
||||
:math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
|
||||
`gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
|
||||
:math:`\lr` represents `learning_rate`, :math:`w` represents `var`, :math:`\epsilon` represents
|
||||
`epsilon`.
|
||||
|
||||
Args:
|
||||
use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
|
||||
If true, updates of the var, m, and v tensors will be protected by a lock.
|
||||
If false, the result is unpredictable. Default: False.
|
||||
|
||||
Inputs:
- **var** (Tensor) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means
any number of additional dimensions. The data type can be float16 or float32.
- **m** (Tensor) - The 1st moment vector in the updating formula,
the shape and data type should be the same as `var`.
- **v** (Tensor) - The 2nd moment vector in the updating formula,
the shape and data type should be the same as `var`.
- **lr** (float) - :math:`lr` in the updating formula. The paper suggested value is :math:`10^{-3}`.
The data type should be float32.
- **beta1** (float) - The exponential decay rate for the 1st moment estimations.
The paper suggested value is :math:`0.9`. The data type should be float32.
- **beta2** (float) - The exponential decay rate for the 2nd moment estimations.
The paper suggested value is :math:`0.999`. The data type should be float32.
- **epsilon** (float) - Term added to the denominator to improve numerical stability.
The data type should be float32.
- **decay** (float) - The weight decay value. The data type should be float32.
- **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.

Outputs:
Tuple of 3 Tensors, the updated parameters.

- **var** (Tensor) - The same shape and data type as `var`.
- **m** (Tensor) - The same shape and data type as `m`.
- **v** (Tensor) - The same shape and data type as `v`.

Supported Platforms:
``GPU`` ``CPU``

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter, ops
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.adam_weight_decay = ops.AdamWeightDecay()
...         self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var")
...         self.m = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="m")
...         self.v = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="v")
...     def construct(self, lr, beta1, beta2, epsilon, decay, grad):
...         out = self.adam_weight_decay(self.var, self.m, self.v, lr, beta1, beta2,
...                                      epsilon, decay, grad)
...         return out
>>> np.random.seed(0)
>>> net = Net()
>>> gradient = Tensor(np.random.rand(2, 2).astype(np.float32))
>>> output = net(0.9, 0.9, 0.999, 1e-8, 1e-5, gradient)
"""

@prim_attr_register
def __init__(self, use_locking=False):
"""Initialize AdamWeightDecay."""
self.add_prim_attr('side_effect_mem', True)
validator.check_value_type("use_locking", use_locking, [bool], self.name)

def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
epsilon_shape, decay_shape, grad_shape):
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
return var_shape, m_shape, v_shape

def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
epsilon_dtype, decay, grad_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)

args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype, "decay": decay}
validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
return var_dtype, m_dtype, v_dtype
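
Note: a minimal NumPy sketch of the element-wise update this primitive performs (a reference illustration of the docstring formula, not the fused kernel itself):

import numpy as np

def adam_weight_decay_step(var, m, v, lr, beta1, beta2, epsilon, decay, grad):
    """One AdamWeightDecay step, matching the docstring formula element-wise."""
    m = beta1 * m + (1 - beta1) * grad           # update the 1st moment
    v = beta2 * v + (1 - beta2) * grad * grad    # update the 2nd moment
    update = m / (np.sqrt(v) + epsilon) + decay * var
    var = var - lr * update
    return var, m, v

var = np.ones([2, 2], np.float32)
m = np.zeros([2, 2], np.float32)
v = np.zeros([2, 2], np.float32)
grad = np.full([2, 2], 0.5, np.float32)
var, m, v = adam_weight_decay_step(var, m, v, 0.001, 0.9, 0.999, 1e-8, 1e-5, grad)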


class AdamNoUpdateParam(PrimitiveWithInfer):
r"""
Updates gradients by the Adaptive Moment Estimation (Adam) algorithm. This operator does not update the parameter, but
|
@@ -19,7 +19,9 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import ParameterTuple
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer
|
@@ -28,6 +30,11 @@ _adam_opt = C.MultitypeFuncGraph("adam_opt")
|
|||
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)

op_assign = P.Assign()
op_assign.add_prim_attr("primitive_target", "CPU")
op_cast = P.Cast()
op_cast.add_prim_attr("primitive_target", "CPU")


@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
|
@@ -127,16 +134,23 @@ class AdamWeightDecayOp(Optimizer):
|
|||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
param_init_type=mstype.float32):
super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
_check_param_value(beta1, beta2, eps, self.cls_name)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
self.enable_init_fp16 = (param_init_type == mstype.float16)
if self.enable_init_fp16:
self.moments1 = self.clone_param32(prefix="adam_m", init='zeros')
self.moments2 = self.clone_param32(prefix="adam_v", init='zeros')
self.opt = P.FusedCastAdamWeightDecay()
else:
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
self.opt = P.AdamWeightDecay()
self.hyper_map = C.HyperMap()
self.opt = P.AdamWeightDecay()
self.opt.add_prim_attr("primitive_target", "CPU")

def construct(self, gradients):
|
@@ -158,3 +172,27 @@ class AdamWeightDecayOp(Optimizer):
|
|||
if self.use_parallel:
self.broadcast_params(optim_result)
return optim_result

def clone_param32(self, prefix, init=None):
"""
Clone the parameters in ParameterTuple element-wise to generate a new ParameterTuple with float32 data type.

Inputs:
prefix (str): The prefix name of the cloned parameters.
init (Union[Tensor, str, numbers.Number]): Initialize the shape and dtype of the cloned parameters.
The definition of `init` is the same as in the `Parameter` API. If `init` is None, the
initializer of the original parameter is reused. Default: None.

Returns:
ParameterTuple, the cloned parameters in float32.
"""
new = []
for old_param in self.parameters:
param_init = init
if init is None:
param_init = old_param.init
new_state = old_param.clone()
new_state.set_dtype(mstype.float32)
new_state.set_data(initializer(param_init, shape=old_param.shape, dtype=mstype.float32))
new_state.name = prefix + '.' + new_state.name
new.append(new_state)
return ParameterTuple(new)
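
Note: a minimal usage sketch of the extended constructor (illustrative only; the import path for AdamWeightDecayOp is assumed, and the network is a stand-in):

import mindspore.nn as nn
from mindspore.common import dtype as mstype
from src.adam import AdamWeightDecayOp  # hypothetical path; the class lives in this repo's src package

net = nn.Dense(16, 16)
optimizer = AdamWeightDecayOp(net.trainable_params(), learning_rate=1e-3, eps=1e-8,
                              beta1=0.9, beta2=0.95, weight_decay=0.01,
                              param_init_type=mstype.float16)
# With param_init_type=mstype.float16 the optimizer clones float32 copies of the moments via
# clone_param32 and dispatches to P.FusedCastAdamWeightDecay; otherwise it keeps the original
# AdamWeightDecay path.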
|
@@ -934,7 +934,8 @@ class PanguAlpha_Model(nn.Cell):
|
|||
self.is_pipeline = (config.stage_num > 1)
if self.is_pipeline:
self.top_query_embedding_table = Parameter(initializer(TruncatedNormal(0.02),
[config.seq_length, config.embedding_size]),
[config.seq_length, config.embedding_size],
config.param_init_type),
name='embedding_table', parallel_optimizer=False)
self.top_query_embedding = EmbeddingLookup()
for i in range(config.num_layers):
|
@@ -1084,7 +1085,8 @@ class PanguAlpha(nn.Cell):
|
|||
raise ValueError(f"{embedding_path} file not exits, "
|
||||
f"please check whether word_embedding file exist.")
|
||||
else:
|
||||
self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
|
||||
self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size],
|
||||
config.param_init_type),
|
||||
name="embedding_table", parallel_optimizer=False)
|
||||
|
||||
def construct(self, input_ids, input_mask, input_position, attention_mask,
|
||||
|
@@ -42,7 +42,8 @@ class PANGUALPHAConfig:
|
|||
micro_size=32,
load_ckpt_path=None,
use_top_query_attention=True,
param_init_type=mstype.float32):
param_init_type=mstype.float32,
enable_offload=False):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
|
@@ -67,6 +68,7 @@ class PANGUALPHAConfig:
|
|||
self.load_ckpt_path = load_ckpt_path
self.use_top_query_attention = use_top_query_attention
self.param_init_type = param_init_type
self.enable_offload = enable_offload
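
Note: a minimal construction sketch for the new flag (illustrative only; it assumes every other constructor argument keeps the default shown in this class, and the import path is hypothetical):

from mindspore.common import dtype as mstype
from src.pangu_alpha_config import PANGUALPHAConfig  # hypothetical module path within this repo

# fp16 parameter initialization plus optimizer offload; this is the combination that the
# gradient-clipping and optimizer code below keys on.
config = PANGUALPHAConfig(param_init_type=mstype.float16, enable_offload=True)
print(config)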

def __str__(self):
info = "[PANGUALPHAConfig]" + '===' * 10 + '\n'
|
@@ -83,9 +83,12 @@ def _get_square_sum(grad, value):
|
|||
apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")


@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
grad = grad * clip_norm / global_norm
@apply_global_norm.register("Bool", "Tensor", "Tensor", "Tensor")
def _apply_global_norm(enable_grad_fp16, clip_norm, global_norm, grad):
if enable_grad_fp16:
grad = P.Cast()(grad * clip_norm / global_norm, mstype.float16)
else:
grad = grad * clip_norm / global_norm
return grad

|
@@ -188,13 +191,17 @@ class ClipByGlobalNorm(nn.Cell):
|
|||
self.global_norm = GlobalNorm(params, config)
self.clip_norm = Tensor([clip_norm], mstype.float32)
self.hyper_map = C.HyperMap()
if config.param_init_type == mstype.float16 and config.enable_offload:
self.enable_grad_fp16 = True
else:
self.enable_grad_fp16 = False

def construct(self, grads):
"""Clip grads by global norm construct"""
global_norm_value = self.global_norm(grads)
cond = P.GreaterEqual()(global_norm_value, self.clip_norm)
global_norm = F.select(cond, global_norm_value, self.clip_norm)
grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
grads = self.hyper_map(F.partial(apply_global_norm, self.enable_grad_fp16, self.clip_norm, global_norm), grads)
return grads, global_norm_value
|
@@ -40,7 +40,6 @@ from src.callbacks import EvalCallBack, LossCallBack
|
|||
from src.metrics import PPLMetric


project_root = os.path.abspath(
os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..")
print('project_root:', project_root)
|
@@ -110,7 +109,7 @@ def run_train(args_opt):
|
|||
stage_num=args_opt.stage_num, micro_size=args_opt.micro_size,
eod_reset=bool(args_opt.eod_reset), load_ckpt_path=args_opt.load_ckpt_path,
param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
word_emb_dp=bool(args_opt.word_emb_dp))
word_emb_dp=bool(args_opt.word_emb_dp), enable_offload=bool(args_opt.opt_offload))
print("===config is: ", config, flush=True)
# Define network
pangu_alpha = PanguAlpha(config)
|
@@ -126,7 +125,8 @@ def run_train(args_opt):
|
|||
if args_opt.optimizer == "lamb":
optimizer = nn.Lamb(group_params, learning_rate=lr)
elif args_opt.opt_offload:
optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95,
param_init_type=config.param_init_type)
else:
optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
# Initial scaling sens
|
@@ -219,7 +219,7 @@ def run_train_pipeline(args_opt):
|
|||
use_past=False,
stage_num=args_opt.stage_num,
micro_size=args_opt.micro_size,
word_emb_dp=bool(args_opt.word_emb_dp))
word_emb_dp=bool(args_opt.word_emb_dp), enable_offload=bool(args_opt.opt_offload))
print("===config is: ", config, flush=True)
pangu_alpha = PanguAlpha(config)
loss = CrossEntropyLoss(config)
|
@@ -233,7 +233,8 @@ def run_train_pipeline(args_opt):
|
|||
if args_opt.optimizer == "lamb":
optimizer = nn.Lamb(group_params, learning_rate=lr)
elif args_opt.opt_offload:
optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95,
param_init_type=config.param_init_type)
else:
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)