Adam op performance optimization

This commit is contained in:
zhaosida 2021-08-27 14:58:07 +08:00
parent b9ec533f95
commit be34ccd29f
19 changed files with 656 additions and 400 deletions

View File

@ -13,145 +13,109 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.h"
#include <cmath>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/adam_fp32.h"
#include "utils/ms_utils.h"
#include "common/thread_pool.h"
namespace mindspore {
namespace kernel {
constexpr size_t kSizeFloat16 = sizeof(float16);
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kScalarIndex = 0;
constexpr size_t kAdamWeightDecayInputSize = 9;
constexpr size_t kAdamWeightDecayOutputSize = 3;
template <typename T, typename S>
void AdamWeightDecayCPUKernel::LaunchFusedAdam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &) {
auto var = reinterpret_cast<T *>(inputs[VAR]->addr);
auto m = reinterpret_cast<T *>(inputs[M]->addr);
auto v = reinterpret_cast<T *>(inputs[V]->addr);
auto lr = reinterpret_cast<T *>(inputs[LR]->addr)[kScalarIndex];
auto beta1 = reinterpret_cast<T *>(inputs[BETA1]->addr)[kScalarIndex];
auto beta2 = reinterpret_cast<T *>(inputs[BETA2]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<T *>(inputs[EPSILON]->addr)[kScalarIndex];
auto decay = reinterpret_cast<T *>(inputs[DECAY]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<S *>(inputs[GRAD]->addr);
const auto beta1_minus = 1 - beta1;
const auto beta2_minus = 1 - beta2;
template <typename T>
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &) {
T *var = reinterpret_cast<T *>(inputs[VAR]->addr);
T *m = reinterpret_cast<T *>(inputs[M]->addr);
T *v = reinterpret_cast<T *>(inputs[V]->addr);
T lr = static_cast<T>(reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex]);
T beta1 = static_cast<T>(reinterpret_cast<float *>(inputs[BETA1]->addr)[kScalarIndex]);
T beta2 = static_cast<T>(reinterpret_cast<float *>(inputs[BETA2]->addr)[kScalarIndex]);
T epsilon = static_cast<T>(reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex]);
T decay = static_cast<T>(reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex]);
T *gradient = reinterpret_cast<T *>(inputs[GRAD]->addr);
const T one = static_cast<T>(1.0);
const T beta1_minus = one - beta1;
const T beta2_minus = one - beta2;
// multithreading
size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / sizeof(float)) : 1;
size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / sizeof(T)) : 1;
std::function<void(size_t, size_t)> task;
task = [&](size_t start, size_t end) {
size_t i =
FusedAdamFp32(var, m, v, lr, beta1, beta2, epsilon, decay, reinterpret_cast<int16_t *>(gradient16), start, end);
// remaining
for (; i < end; i++) {
auto temp = static_cast<float>(gradient16[i]);
m[i] += (temp - m[i]) * beta1_minus;
v[i] += (temp * temp - v[i]) * beta2_minus;
T update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * var[i];
for (size_t i = start; i < end; i++) {
m[i] += (gradient[i] - m[i]) * beta1_minus;
v[i] += (gradient[i] * gradient[i] - v[i]) * beta2_minus;
T sqrt_v = static_cast<T>(std::sqrt(static_cast<double>(v[i])));
auto update = m[i] / (sqrt_v + epsilon) + decay * var[i];
var[i] -= lr * update;
}
};
CPUKernelUtils::ParallelFor(task, lens);
}
template <typename T>
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &) {
auto var = reinterpret_cast<T *>(inputs[VAR]->addr);
auto m = reinterpret_cast<T *>(inputs[M]->addr);
auto v = reinterpret_cast<T *>(inputs[V]->addr);
auto lr = reinterpret_cast<T *>(inputs[LR]->addr)[kScalarIndex];
auto beta1 = reinterpret_cast<T *>(inputs[BETA1]->addr)[kScalarIndex];
auto beta2 = reinterpret_cast<T *>(inputs[BETA2]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<T *>(inputs[EPSILON]->addr)[kScalarIndex];
auto decay = reinterpret_cast<T *>(inputs[DECAY]->addr)[kScalarIndex];
auto gradient = reinterpret_cast<T *>(inputs[GRAD]->addr);
const auto beta1_minus = 1 - beta1;
const auto beta2_minus = 1 - beta2;
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecayNnacl(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &) {
auto var = reinterpret_cast<float *>(inputs[VAR]->addr);
auto m = reinterpret_cast<float *>(inputs[M]->addr);
auto v = reinterpret_cast<float *>(inputs[V]->addr);
auto lr = reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex];
auto beta1 = reinterpret_cast<float *>(inputs[BETA1]->addr)[kScalarIndex];
auto beta2 = reinterpret_cast<float *>(inputs[BETA2]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex];
auto gradient = reinterpret_cast<float *>(inputs[GRAD]->addr);
// multithreading
size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / sizeof(float)) : 1;
std::function<void(size_t, size_t)> task;
task = [&](size_t start, size_t end) {
size_t i = AdamWeightDecayFp32(var, m, v, lr, beta1, beta2, epsilon, decay, gradient, start, end);
// remaining
for (; i < end; i++) {
m[i] += (gradient[i] - m[i]) * beta1_minus;
v[i] += (gradient[i] * gradient[i] - v[i]) * beta2_minus;
T update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay * var[i];
var[i] -= lr * update;
int ret = AdamWeightDecayFp32(var, m, v, lr, beta1, beta2, epsilon, decay, gradient, start, end);
if (ret != NNACL_OK) {
MS_LOG(EXCEPTION) << "AdamWeightDecayFp32 failed.";
}
};
CPUKernelUtils::ParallelFor(task, lens);
ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
}
void AdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> var_shape = AnfAlgo::GetInputDeviceShape(kernel_node, VAR);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, VAR);
gradient_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, GRAD);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != kAdamWeightDecayInputSize) {
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
if (input_num != kAdamWeightDecayInputNum) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but AdamWeightDecay needs 9 inputs.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != kAdamWeightDecayOutputSize) {
if (output_num != kAdamWeightDecayOutputNum) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AdamWeightDecay needs 3 outputs.";
}
elem_num_ = 1;
for (size_t i : var_shape) {
elem_num_ *= i;
}
if (elem_num_ < 1) {
MS_LOG(EXCEPTION) << "Invalid parameter shape";
}
if (dtype_ != kNumberTypeFloat32) {
MS_LOG(EXCEPTION) << "The dtype of parameter must be float32!";
}
if (gradient_dtype_ != kNumberTypeFloat32 && gradient_dtype_ != kNumberTypeFloat16) {
MS_LOG(EXCEPTION) << "The dtype of gradient must be float32 or float16!";
}
}
// Validates input/output arity and buffer byte sizes before launching the kernel.
// var/m/v must be fp32-sized; the gradient may be fp16 or fp32; all hyper-parameters
// (lr, beta1, beta2, epsilon, decay) must be single fp32 scalars.
void AdamWeightDecayCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &inputs,
                                          const std::vector<kernel::AddressPtr> &outputs) {
  if (inputs.size() != kAdamWeightDecayInputSize) {
    MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs 9 inputs.";
  }
  if (outputs.size() != kAdamWeightDecayOutputSize) {
    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs 3 outputs.";
  }
  // Expected byte sizes derived from the element count captured at init time.
  const size_t param_bytes = elem_num_ * kSizeFloat32;
  const size_t grad_bytes = gradient_dtype_ == kNumberTypeFloat16 ? elem_num_ * kSizeFloat16 : param_bytes;
  const bool tensor_sizes_ok = inputs[VAR]->size == param_bytes && inputs[M]->size == param_bytes &&
                               inputs[V]->size == param_bytes && inputs[GRAD]->size == grad_bytes;
  if (!tensor_sizes_ok) {
    MS_LOG(EXCEPTION) << "Error input data size!";
  }
  const bool scalar_sizes_ok = inputs[LR]->size == kSizeFloat32 && inputs[BETA1]->size == kSizeFloat32 &&
                               inputs[BETA2]->size == kSizeFloat32 && inputs[EPSILON]->size == kSizeFloat32 &&
                               inputs[DECAY]->size == kSizeFloat32;
  if (!scalar_sizes_ok) {
    MS_LOG(EXCEPTION) << "The attribute beta, lr, epsilon and weight decay must be float!";
  }
}
bool AdamWeightDecayCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CheckParam(inputs, outputs);
if (gradient_dtype_ == kNumberTypeFloat16) {
LaunchFusedAdam<float, float16>(inputs, outputs);
} else if (gradient_dtype_ == kNumberTypeFloat32) {
LaunchAdamWeightDecay<float>(inputs, outputs);
if (inputs.size() != kAdamWeightDecayInputNum) {
MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs 9 inputs.";
}
if (outputs.size() != kAdamWeightDecayOutputNum) {
MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs 3 outputs.";
}
if (inputs[VAR]->size != inputs[M]->size || inputs[VAR]->size != inputs[V]->size ||
inputs[VAR]->size != inputs[GRAD]->size) {
MS_LOG(EXCEPTION) << "Error input data size!";
}
if (inputs[LR]->size != kSizeFloat32 || inputs[BETA1]->size != kSizeFloat32 || inputs[BETA2]->size != kSizeFloat32 ||
inputs[EPSILON]->size != kSizeFloat32 || inputs[DECAY]->size != kSizeFloat32) {
MS_LOG(EXCEPTION) << "The attribute beta, lr, epsilon and weight decay must be float!";
}
if (dtype_ == kNumberTypeFloat32) {
LaunchAdamWeightDecayNnacl(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat16) {
LaunchAdamWeightDecay<float16>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "AdamWeightDecay not support " << dtype_;
}
return true;
}

View File

@ -23,6 +23,11 @@
namespace mindspore {
namespace kernel {
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kScalarIndex = 0;
constexpr size_t kAdamWeightDecayInputNum = 9;
constexpr size_t kAdamWeightDecayOutputNum = 3;
class AdamWeightDecayCPUKernel : public CPUKernel {
public:
AdamWeightDecayCPUKernel() = default;
@ -32,48 +37,15 @@ class AdamWeightDecayCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
private:
void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T, typename S>
void LaunchFusedAdam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
void LaunchAdamWeightDecay(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
size_t elem_num_{0};
void LaunchAdamWeightDecayNnacl(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
TypeId dtype_{kTypeUnknown};
TypeId gradient_dtype_{kTypeUnknown};
enum input_list_ { VAR, M, V, LR, BETA1, BETA2, EPSILON, DECAY, GRAD };
};
MS_REG_CPU_KERNEL(AdamWeightDecay,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
AdamWeightDecayCPUKernel)
MS_REG_CPU_KERNEL(AdamWeightDecay,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
AdamWeightDecayCPUKernel)
MS_REG_CPU_KERNEL(AdamWeightDecay, KernelAttr(), AdamWeightDecayCPUKernel)
} // namespace kernel
} // namespace mindspore

View File

@ -140,11 +140,16 @@ void CPUKernelUtils::ParallelForAutoSearch(const CTask &task, size_t count, Para
ActorThreadPool *GetActorMgrInnerThreadPool() {
auto actor_manager = ActorMgr::GetActorMgrRef();
MS_EXCEPTION_IF_NULL(actor_manager);
auto thread_pool = actor_manager->GetActorThreadPool();
// Init thread_pool if env is windows or ascend, in case that it won't be init in graph_scheduler.
if (thread_pool == nullptr) {
const size_t kMaxThreadNum = 23;
size_t max_thread_num = std::thread::hardware_concurrency() - 1;
#if ENABLE_D || ENABLE_GPU
const size_t kDeviceNum = 8;
max_thread_num /= kDeviceNum;
#endif
if (max_thread_num < 1) {
max_thread_num = 1;
}

View File

@ -0,0 +1,156 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/fused_cast_adam_weight_decay_cpu_kernel.h"
#include "nnacl/fp32/adam_fp32.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
// AdamWeightDecay step for fp32 parameters with an fp16 gradient.
// The vectorized nnacl kernel consumes the bulk of each shard; a scalar tail
// loop finishes whatever elements the SIMD path left over.
void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp32(const std::vector<AddressPtr> &inputs,
                                                                const std::vector<AddressPtr> &) {
  auto *param = reinterpret_cast<float *>(inputs[VAR]->addr);
  auto *moment1 = reinterpret_cast<float *>(inputs[M]->addr);
  auto *moment2 = reinterpret_cast<float *>(inputs[V]->addr);
  // Hyper-parameters arrive as single-element fp32 buffers.
  const float lr = reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex];
  const float beta1 = reinterpret_cast<float *>(inputs[BETA1]->addr)[kScalarIndex];
  const float beta2 = reinterpret_cast<float *>(inputs[BETA2]->addr)[kScalarIndex];
  const float epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex];
  const float decay = reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex];
  auto *grad_fp16 = reinterpret_cast<float16 *>(inputs[GRAD]->addr);
  const float one_minus_beta1 = 1 - beta1;
  const float one_minus_beta2 = 1 - beta2;
  // Shard the element range across the thread pool; guard against a zero-sized buffer.
  const size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / kSizeFloat32) : 1;
  std::function<void(size_t, size_t)> task = [&](size_t start, size_t end) {
    // The nnacl kernel takes the fp16 gradient as raw int16_t bits and returns
    // the first index it did NOT process.
    size_t i = FusedCastAdamFp32(param, moment1, moment2, lr, beta1, beta2, epsilon, decay,
                                 reinterpret_cast<int16_t *>(grad_fp16), start, end);
    for (; i < end; i++) {
      const auto g = static_cast<float>(grad_fp16[i]);
      moment1[i] += (g - moment1[i]) * one_minus_beta1;
      moment2[i] += (g * g - moment2[i]) * one_minus_beta2;
      auto update = moment1[i] / (std::sqrt(moment2[i]) + epsilon);
      update += decay * param[i];
      param[i] -= lr * update;
    }
  };
  ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
}
// AdamWeightDecay step for fp16 parameters with fp16 gradients.
// Moments m and v are kept in fp32; each var element is widened to fp32 for the
// update and narrowed back to fp16 when stored.
void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp16(const std::vector<AddressPtr> &inputs,
                                                                const std::vector<AddressPtr> &) {
  auto var16 = reinterpret_cast<float16 *>(inputs[VAR]->addr);
  auto m = reinterpret_cast<float *>(inputs[M]->addr);
  auto v = reinterpret_cast<float *>(inputs[V]->addr);
  // Hyper-parameters arrive as single-element fp32 buffers.
  auto lr = reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex];
  auto beta1 = reinterpret_cast<float *>(inputs[BETA1]->addr)[kScalarIndex];
  auto beta2 = reinterpret_cast<float *>(inputs[BETA2]->addr)[kScalarIndex];
  auto epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex];
  auto decay = reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex];
  auto gradient16 = reinterpret_cast<float16 *>(inputs[GRAD]->addr);
  const auto beta1_minus = 1 - beta1;
  const auto beta2_minus = 1 - beta2;
  // multithreading: element count derived from the fp16 var buffer; guard zero size.
  size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / kSizeFloat16) : 1;
  std::function<void(size_t, size_t)> task;
  task = [&](size_t start, size_t end) {
    // The nnacl kernel takes var and gradient as raw int16_t bits and returns
    // the first index it did NOT process.
    size_t i = FusedCastAdamFp16(reinterpret_cast<int16_t *>(var16), m, v, lr, beta1, beta2, epsilon, decay,
                                 reinterpret_cast<int16_t *>(gradient16), start, end);
    // remaining tail elements, updated with scalar fp32 arithmetic
    for (; i < end; i++) {
      auto temp_var = static_cast<float>(var16[i]);
      auto temp_grad = static_cast<float>(gradient16[i]);
      m[i] += (temp_grad - m[i]) * beta1_minus;
      v[i] += (temp_grad * temp_grad - v[i]) * beta2_minus;
      auto update = m[i] / (std::sqrt(v[i]) + epsilon);
      update += decay * temp_var;
      temp_var -= lr * update;
      var16[i] = static_cast<float16>(temp_var);  // narrow the updated value back to fp16
    }
  };
  ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
}
// Reads shapes and dtypes from the graph node, validates input/output arity and
// dtype combinations, and caches the parameter element count for CheckParam.
// Fix: the exception messages named the wrong op ("AdamWeightDecay") — they now
// report "FusedCastAdamWeightDecay", matching this kernel.
void FusedCastAdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  std::vector<size_t> var_shape = AnfAlgo::GetInputDeviceShape(kernel_node, VAR);
  var_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, VAR);
  gradient_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, GRAD);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != kFusedCastAdamWeightDecayInputNum) {
    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FusedCastAdamWeightDecay needs 9 inputs.";
  }
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  if (output_num != kFusedCastAdamWeightDecayOutputNum) {
    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FusedCastAdamWeightDecay needs 3 outputs.";
  }
  // Total element count of the parameter tensor; a scalar shape yields 1.
  elem_num_ = 1;
  for (size_t i : var_shape) {
    elem_num_ *= i;
  }
  if (elem_num_ < 1) {
    MS_LOG(EXCEPTION) << "Invalid parameter shape";
  }
  // Only the fp16-gradient variants are registered for this fused op.
  if (gradient_dtype_ != kNumberTypeFloat16) {
    MS_LOG(EXCEPTION) << "The dtype of gradient must be float16!";
  }
  if (var_dtype_ != kNumberTypeFloat32 && var_dtype_ != kNumberTypeFloat16) {
    MS_LOG(EXCEPTION) << "The dtype of parameter must be float32 or float16!";
  }
}
// Validates runtime buffer byte sizes against the element count captured in
// InitKernel: var may be fp16 or fp32, m/v must be fp32, gradient must be fp16,
// and every hyper-parameter must be a single fp32 scalar.
// Fix: the arity exception messages named the wrong op ("AdamWeightDecay") —
// they now report "FusedCastAdamWeightDecay", matching this kernel.
void FusedCastAdamWeightDecayCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &inputs,
                                                   const std::vector<kernel::AddressPtr> &outputs) {
  if (inputs.size() != kFusedCastAdamWeightDecayInputNum) {
    MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but FusedCastAdamWeightDecay needs 9 inputs.";
  }
  if (outputs.size() != kFusedCastAdamWeightDecayOutputNum) {
    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but FusedCastAdamWeightDecay needs 3 outputs.";
  }
  size_t elem_size_fp32 = elem_num_ * kSizeFloat32;
  size_t elem_size_fp16 = elem_num_ * kSizeFloat16;
  size_t var_size = var_dtype_ == kNumberTypeFloat16 ? elem_size_fp16 : elem_size_fp32;
  if (inputs[VAR]->size != var_size || inputs[M]->size != elem_size_fp32 || inputs[V]->size != elem_size_fp32 ||
      inputs[GRAD]->size != elem_size_fp16) {
    MS_LOG(EXCEPTION) << "Error input data size!";
  }
  if (inputs[LR]->size != kSizeFloat32 || inputs[BETA1]->size != kSizeFloat32 || inputs[BETA2]->size != kSizeFloat32 ||
      inputs[EPSILON]->size != kSizeFloat32 || inputs[DECAY]->size != kSizeFloat32) {
    MS_LOG(EXCEPTION) << "The attribute beta, lr, epsilon and weight decay must be float!";
  }
}
// Runs one optimizer step: validates buffers, then dispatches to the fp16 or
// fp32 variant based on the parameter dtype recorded at init time.
bool FusedCastAdamWeightDecayCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                               const std::vector<kernel::AddressPtr> &,
                                               const std::vector<kernel::AddressPtr> &outputs) {
  CheckParam(inputs, outputs);
  const bool var_is_fp16 = (var_dtype_ == kNumberTypeFloat16);
  if (var_is_fp16) {
    LaunchFusedCastAdamFp16(inputs, outputs);
  } else {
    LaunchFusedCastAdamFp32(inputs, outputs);
  }
  return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,85 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_CAST_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_CAST_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kSizeFloat16 = sizeof(float16);
constexpr size_t kScalarIndex = 0;
constexpr size_t kFusedCastAdamWeightDecayInputNum = 9;
constexpr size_t kFusedCastAdamWeightDecayOutputNum = 3;
// CPU kernel for the FusedCastAdamWeightDecay op: an AdamWeightDecay update
// whose gradient input is fp16 (cast to fp32 inside the kernel). The parameter
// may itself be fp16 or fp32; moments m and v are always fp32.
class FusedCastAdamWeightDecayCPUKernel : public CPUKernel {
 public:
  FusedCastAdamWeightDecayCPUKernel() = default;
  ~FusedCastAdamWeightDecayCPUKernel() override = default;
  // Reads shapes/dtypes from the graph node and validates arity and dtypes.
  void InitKernel(const CNodePtr &kernel_node) override;
  // Performs one optimizer step; the (unnamed) workspace argument is unused.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs) override;

 private:
  // Validates runtime buffer sizes against the shape captured in InitKernel.
  void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
  // Update path for fp32 parameters with an fp16 gradient.
  void LaunchFusedCastAdamFp32(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
  // Update path for fp16 parameters with an fp16 gradient.
  void LaunchFusedCastAdamFp16(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
  size_t elem_num_{0};                   // element count of the parameter tensor
  TypeId var_dtype_{kTypeUnknown};       // dtype of input VAR
  TypeId gradient_dtype_{kTypeUnknown};  // dtype of input GRAD
  // Fixed input ordering of the fused op; used as indices into `inputs`.
  enum input_list_ { VAR, M, V, LR, BETA1, BETA2, EPSILON, DECAY, GRAD };
};
// Registration 1: fp32 parameter and moments with an fp16 gradient.
MS_REG_CPU_KERNEL(FusedCastAdamWeightDecay,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)   // var
                    .AddInputAttr(kNumberTypeFloat32)   // m
                    .AddInputAttr(kNumberTypeFloat32)   // v
                    .AddInputAttr(kNumberTypeFloat32)   // lr
                    .AddInputAttr(kNumberTypeFloat32)   // beta1
                    .AddInputAttr(kNumberTypeFloat32)   // beta2
                    .AddInputAttr(kNumberTypeFloat32)   // epsilon
                    .AddInputAttr(kNumberTypeFloat32)   // decay
                    .AddInputAttr(kNumberTypeFloat16)   // gradient
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  FusedCastAdamWeightDecayCPUKernel)
// Registration 2: fp16 parameter (fp32 moments) with an fp16 gradient.
MS_REG_CPU_KERNEL(FusedCastAdamWeightDecay,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat16)   // var
                    .AddInputAttr(kNumberTypeFloat32)   // m
                    .AddInputAttr(kNumberTypeFloat32)   // v
                    .AddInputAttr(kNumberTypeFloat32)   // lr
                    .AddInputAttr(kNumberTypeFloat32)   // beta1
                    .AddInputAttr(kNumberTypeFloat32)   // beta2
                    .AddInputAttr(kNumberTypeFloat32)   // epsilon
                    .AddInputAttr(kNumberTypeFloat32)   // decay
                    .AddInputAttr(kNumberTypeFloat16)   // gradient
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  FusedCastAdamWeightDecayCPUKernel)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_CAST_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_

View File

@ -105,7 +105,7 @@ if(ENABLE_CPU)
elseif("${X86_64_SIMD}" STREQUAL "avx")
target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX)
elseif("${X86_64_SIMD}" STREQUAL "avx512")
target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX ENABLE_AVX512)
target_compile_definitions(nnacl_mid PRIVATE ENABLE_AVX512)
target_compile_options(nnacl_mid PRIVATE -mavx512f)
endif()
target_compile_options(nnacl_mid PRIVATE -fPIC)

View File

@ -152,102 +152,44 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
return NNACL_OK;
}
size_t AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const float *gradient, size_t start, size_t end) {
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const float *gradient, size_t start, size_t end) {
size_t c1 = start;
const float beta1_minus = 1 - beta1;
const float beta2_minus = 1 - beta2;
#ifdef ENABLE_AVX512
struct AVX_Data beta1_r, beta2_r, beta1_minus_r, beta2_minus_r, lr_neg_r, epsilon_r, decay_r;
beta1_r.data = _mm512_set1_ps(beta1);
beta2_r.data = _mm512_set1_ps(beta2);
beta1_minus_r.data = _mm512_set1_ps(1.0f - beta1);
beta2_minus_r.data = _mm512_set1_ps(1.0f - beta2);
lr_neg_r.data = _mm512_set1_ps(-lr);
epsilon_r.data = _mm512_set1_ps(epsilon);
decay_r.data = _mm512_set1_ps(decay);
__m512 beta1_r = _mm512_set1_ps(beta1);
__m512 beta2_r = _mm512_set1_ps(beta2);
__m512 beta1_minus_r = _mm512_set1_ps(beta1_minus);
__m512 beta2_minus_r = _mm512_set1_ps(beta2_minus);
__m512 lr_neg_r = _mm512_set1_ps(-lr);
__m512 epsilon_r = _mm512_set1_ps(epsilon);
__m512 decay_r = _mm512_set1_ps(decay);
size_t c16 = ((end - start) / C16NUM) * C16NUM + start;
size_t c64 = ((end - start) / C64NUM) * C64NUM + start;
const float *gradient_ptr = gradient + start;
float *var_ptr = var + start;
float *m_ptr = m + start;
float *v_ptr = v + start;
for (; c1 < c64; c1 += C64NUM) {
struct AVX_Data g_r[kUnrollSize], var_r[kUnrollSize], m_r[kUnrollSize], v_r[kUnrollSize];
LoadStep4(g_r, gradient_ptr, kUnrollSize);
LoadStep4(var_r, var_ptr, kUnrollSize);
LoadStep4(m_r, m_ptr, kUnrollSize);
LoadStep4(v_r, v_ptr, kUnrollSize);
m_r[0].data = _mm512_mul_ps(m_r[0].data, beta1_r.data);
m_r[1].data = _mm512_mul_ps(m_r[1].data, beta1_r.data);
m_r[2].data = _mm512_mul_ps(m_r[2].data, beta1_r.data);
m_r[3].data = _mm512_mul_ps(m_r[3].data, beta1_r.data);
m_r[0].data = _mm512_fmadd_ps(g_r[0].data, beta1_minus_r.data, m_r[0].data);
m_r[1].data = _mm512_fmadd_ps(g_r[1].data, beta1_minus_r.data, m_r[1].data);
m_r[2].data = _mm512_fmadd_ps(g_r[2].data, beta1_minus_r.data, m_r[2].data);
m_r[3].data = _mm512_fmadd_ps(g_r[3].data, beta1_minus_r.data, m_r[3].data);
v_r[0].data = _mm512_mul_ps(v_r[0].data, beta2_r.data);
v_r[1].data = _mm512_mul_ps(v_r[1].data, beta2_r.data);
v_r[2].data = _mm512_mul_ps(v_r[2].data, beta2_r.data);
v_r[3].data = _mm512_mul_ps(v_r[3].data, beta2_r.data);
g_r[0].data = _mm512_mul_ps(g_r[0].data, g_r[0].data);
g_r[1].data = _mm512_mul_ps(g_r[1].data, g_r[1].data);
g_r[2].data = _mm512_mul_ps(g_r[2].data, g_r[2].data);
g_r[3].data = _mm512_mul_ps(g_r[3].data, g_r[3].data);
v_r[0].data = _mm512_fmadd_ps(g_r[0].data, beta2_minus_r.data, v_r[0].data);
v_r[1].data = _mm512_fmadd_ps(g_r[1].data, beta2_minus_r.data, v_r[1].data);
v_r[2].data = _mm512_fmadd_ps(g_r[2].data, beta2_minus_r.data, v_r[2].data);
v_r[3].data = _mm512_fmadd_ps(g_r[3].data, beta2_minus_r.data, v_r[3].data);
g_r[0].data = _mm512_sqrt_ps(v_r[0].data);
g_r[1].data = _mm512_sqrt_ps(v_r[1].data);
g_r[2].data = _mm512_sqrt_ps(v_r[2].data);
g_r[3].data = _mm512_sqrt_ps(v_r[3].data);
g_r[0].data = _mm512_div_ps(m_r[0].data, _mm512_add_ps(g_r[0].data, epsilon_r.data));
g_r[1].data = _mm512_div_ps(m_r[1].data, _mm512_add_ps(g_r[1].data, epsilon_r.data));
g_r[2].data = _mm512_div_ps(m_r[2].data, _mm512_add_ps(g_r[2].data, epsilon_r.data));
g_r[3].data = _mm512_div_ps(m_r[3].data, _mm512_add_ps(g_r[3].data, epsilon_r.data));
g_r[0].data = _mm512_fmadd_ps(var_r[0].data, decay_r.data, g_r[0].data);
g_r[1].data = _mm512_fmadd_ps(var_r[1].data, decay_r.data, g_r[1].data);
g_r[2].data = _mm512_fmadd_ps(var_r[2].data, decay_r.data, g_r[2].data);
g_r[3].data = _mm512_fmadd_ps(var_r[3].data, decay_r.data, g_r[3].data);
var_r[0].data = _mm512_fmadd_ps(g_r[0].data, lr_neg_r.data, var_r[0].data);
var_r[1].data = _mm512_fmadd_ps(g_r[1].data, lr_neg_r.data, var_r[1].data);
var_r[2].data = _mm512_fmadd_ps(g_r[2].data, lr_neg_r.data, var_r[2].data);
var_r[3].data = _mm512_fmadd_ps(g_r[3].data, lr_neg_r.data, var_r[3].data);
StoreStep4(var_ptr, var_r, kUnrollSize);
StoreStep4(m_ptr, m_r, kUnrollSize);
StoreStep4(v_ptr, v_r, kUnrollSize);
gradient_ptr += C64NUM;
var_ptr += C64NUM;
m_ptr += C64NUM;
v_ptr += C64NUM;
}
for (; c1 < c16; c1 += C16NUM) {
struct AVX_Data g_r, var_r, m_r, v_r;
g_r.data = _mm512_loadu_ps(gradient_ptr);
var_r.data = _mm512_loadu_ps(var_ptr);
m_r.data = _mm512_loadu_ps(m_ptr);
v_r.data = _mm512_loadu_ps(v_ptr);
__m512 g_r = _mm512_loadu_ps(gradient_ptr);
__m512 var_r = _mm512_loadu_ps(var_ptr);
__m512 m_r = _mm512_loadu_ps(m_ptr);
__m512 v_r = _mm512_loadu_ps(v_ptr);
m_r.data = _mm512_mul_ps(m_r.data, beta1_r.data);
v_r.data = _mm512_mul_ps(v_r.data, beta2_r.data);
struct AVX_Data avx_r0;
avx_r0.data = _mm512_mul_ps(g_r.data, g_r.data);
m_r.data = _mm512_fmadd_ps(g_r.data, beta1_minus_r.data, m_r.data);
v_r.data = _mm512_fmadd_ps(avx_r0.data, beta2_minus_r.data, v_r.data);
avx_r0.data = _mm512_sqrt_ps(v_r.data);
avx_r0.data = _mm512_div_ps(m_r.data, _mm512_add_ps(avx_r0.data, epsilon_r.data));
avx_r0.data = _mm512_fmadd_ps(var_r.data, decay_r.data, avx_r0.data);
var_r.data = _mm512_fmadd_ps(avx_r0.data, lr_neg_r.data, var_r.data);
_mm512_storeu_ps(var_ptr, var_r.data);
_mm512_storeu_ps(m_ptr, m_r.data);
_mm512_storeu_ps(v_ptr, v_r.data);
m_r = _mm512_mul_ps(m_r, beta1_r);
v_r = _mm512_mul_ps(v_r, beta2_r);
__m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
m_r = _mm512_fmadd_ps(g_r, beta1_minus_r, m_r);
v_r = _mm512_fmadd_ps(avx_r0, beta2_minus_r, v_r);
avx_r0 = _mm512_sqrt_ps(v_r);
avx_r0 = _mm512_div_ps(m_r, _mm512_add_ps(avx_r0, epsilon_r));
avx_r0 = _mm512_fmadd_ps(var_r, decay_r, avx_r0);
var_r = _mm512_fmadd_ps(avx_r0, lr_neg_r, var_r);
_mm512_storeu_ps(var_ptr, var_r);
_mm512_storeu_ps(m_ptr, m_r);
_mm512_storeu_ps(v_ptr, v_r);
gradient_ptr += C16NUM;
var_ptr += C16NUM;
@ -255,109 +197,51 @@ size_t AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1
v_ptr += C16NUM;
}
#endif
return c1;
// remaining
for (; c1 < end; c1++) {
m[c1] += (gradient[c1] - m[c1]) * beta1_minus;
v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * beta2_minus;
var[c1] -= lr * (m[c1] / (sqrt(v[c1]) + epsilon) + decay * var[c1]);
}
return NNACL_OK;
}
size_t FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, size_t start, size_t end) {
size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, size_t start, size_t end) {
size_t c1 = start;
#ifdef ENABLE_AVX512
struct AVX_Data beta1_r, beta2_r, beta1_minus_r, beta2_minus_r, lr_neg_r, epsilon_r, decay_r;
beta1_r.data = _mm512_set1_ps(beta1);
beta2_r.data = _mm512_set1_ps(beta2);
beta1_minus_r.data = _mm512_set1_ps(1.0f - beta1);
beta2_minus_r.data = _mm512_set1_ps(1.0f - beta2);
lr_neg_r.data = _mm512_set1_ps(-lr);
epsilon_r.data = _mm512_set1_ps(epsilon);
decay_r.data = _mm512_set1_ps(decay);
__m512 beta1_r = _mm512_set1_ps(beta1);
__m512 beta2_r = _mm512_set1_ps(beta2);
__m512 beta1_minus_r = _mm512_set1_ps(1.0f - beta1);
__m512 beta2_minus_r = _mm512_set1_ps(1.0f - beta2);
__m512 lr_neg_r = _mm512_set1_ps(-lr);
__m512 epsilon_r = _mm512_set1_ps(epsilon);
__m512 decay_r = _mm512_set1_ps(decay);
size_t c16 = ((end - start) / C16NUM) * C16NUM + start;
size_t c64 = ((end - start) / C64NUM) * C64NUM + start;
const int16_t *gradient16_ptr = gradient16 + start;
float *var_ptr = var + start;
float *m_ptr = m + start;
float *v_ptr = v + start;
for (; c1 < c64; c1 += C64NUM) {
struct AVX_Data g_r[kUnrollSize], var_r[kUnrollSize], m_r[kUnrollSize], v_r[kUnrollSize];
g_r[0].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
g_r[1].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr + C16NUM)));
g_r[2].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr + C16NUM * 2)));
g_r[3].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr + C16NUM * 3)));
LoadStep4(var_r, var_ptr, kUnrollSize);
LoadStep4(m_r, m_ptr, kUnrollSize);
LoadStep4(v_r, v_ptr, kUnrollSize);
m_r[0].data = _mm512_mul_ps(m_r[0].data, beta1_r.data);
m_r[1].data = _mm512_mul_ps(m_r[1].data, beta1_r.data);
m_r[2].data = _mm512_mul_ps(m_r[2].data, beta1_r.data);
m_r[3].data = _mm512_mul_ps(m_r[3].data, beta1_r.data);
m_r[0].data = _mm512_fmadd_ps(g_r[0].data, beta1_minus_r.data, m_r[0].data);
m_r[1].data = _mm512_fmadd_ps(g_r[1].data, beta1_minus_r.data, m_r[1].data);
m_r[2].data = _mm512_fmadd_ps(g_r[2].data, beta1_minus_r.data, m_r[2].data);
m_r[3].data = _mm512_fmadd_ps(g_r[3].data, beta1_minus_r.data, m_r[3].data);
v_r[0].data = _mm512_mul_ps(v_r[0].data, beta2_r.data);
v_r[1].data = _mm512_mul_ps(v_r[1].data, beta2_r.data);
v_r[2].data = _mm512_mul_ps(v_r[2].data, beta2_r.data);
v_r[3].data = _mm512_mul_ps(v_r[3].data, beta2_r.data);
g_r[0].data = _mm512_mul_ps(g_r[0].data, g_r[0].data);
g_r[1].data = _mm512_mul_ps(g_r[1].data, g_r[1].data);
g_r[2].data = _mm512_mul_ps(g_r[2].data, g_r[2].data);
g_r[3].data = _mm512_mul_ps(g_r[3].data, g_r[3].data);
v_r[0].data = _mm512_fmadd_ps(g_r[0].data, beta2_minus_r.data, v_r[0].data);
v_r[1].data = _mm512_fmadd_ps(g_r[1].data, beta2_minus_r.data, v_r[1].data);
v_r[2].data = _mm512_fmadd_ps(g_r[2].data, beta2_minus_r.data, v_r[2].data);
v_r[3].data = _mm512_fmadd_ps(g_r[3].data, beta2_minus_r.data, v_r[3].data);
g_r[0].data = _mm512_sqrt_ps(v_r[0].data);
g_r[1].data = _mm512_sqrt_ps(v_r[1].data);
g_r[2].data = _mm512_sqrt_ps(v_r[2].data);
g_r[3].data = _mm512_sqrt_ps(v_r[3].data);
g_r[0].data = _mm512_div_ps(m_r[0].data, _mm512_add_ps(g_r[0].data, epsilon_r.data));
g_r[1].data = _mm512_div_ps(m_r[1].data, _mm512_add_ps(g_r[1].data, epsilon_r.data));
g_r[2].data = _mm512_div_ps(m_r[2].data, _mm512_add_ps(g_r[2].data, epsilon_r.data));
g_r[3].data = _mm512_div_ps(m_r[3].data, _mm512_add_ps(g_r[3].data, epsilon_r.data));
g_r[0].data = _mm512_fmadd_ps(var_r[0].data, decay_r.data, g_r[0].data);
g_r[1].data = _mm512_fmadd_ps(var_r[1].data, decay_r.data, g_r[1].data);
g_r[2].data = _mm512_fmadd_ps(var_r[2].data, decay_r.data, g_r[2].data);
g_r[3].data = _mm512_fmadd_ps(var_r[3].data, decay_r.data, g_r[3].data);
var_r[0].data = _mm512_fmadd_ps(g_r[0].data, lr_neg_r.data, var_r[0].data);
var_r[1].data = _mm512_fmadd_ps(g_r[1].data, lr_neg_r.data, var_r[1].data);
var_r[2].data = _mm512_fmadd_ps(g_r[2].data, lr_neg_r.data, var_r[2].data);
var_r[3].data = _mm512_fmadd_ps(g_r[3].data, lr_neg_r.data, var_r[3].data);
StoreStep4(var_ptr, var_r, kUnrollSize);
StoreStep4(m_ptr, m_r, kUnrollSize);
StoreStep4(v_ptr, v_r, kUnrollSize);
gradient16_ptr += C64NUM;
var_ptr += C64NUM;
m_ptr += C64NUM;
v_ptr += C64NUM;
}
for (; c1 < c16; c1 += C16NUM) {
struct AVX_Data g_r, var_r, m_r, v_r;
g_r.data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
var_r.data = _mm512_loadu_ps(var_ptr);
m_r.data = _mm512_loadu_ps(m_ptr);
v_r.data = _mm512_loadu_ps(v_ptr);
__m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
__m512 var_r = _mm512_loadu_ps(var_ptr);
__m512 m_r = _mm512_loadu_ps(m_ptr);
__m512 v_r = _mm512_loadu_ps(v_ptr);
m_r.data = _mm512_mul_ps(m_r.data, beta1_r.data);
v_r.data = _mm512_mul_ps(v_r.data, beta2_r.data);
struct AVX_Data avx_r0;
avx_r0.data = _mm512_mul_ps(g_r.data, g_r.data);
m_r.data = _mm512_fmadd_ps(g_r.data, beta1_minus_r.data, m_r.data);
v_r.data = _mm512_fmadd_ps(avx_r0.data, beta2_minus_r.data, v_r.data);
avx_r0.data = _mm512_sqrt_ps(v_r.data);
avx_r0.data = _mm512_div_ps(m_r.data, _mm512_add_ps(avx_r0.data, epsilon_r.data));
avx_r0.data = _mm512_fmadd_ps(var_r.data, decay_r.data, avx_r0.data);
var_r.data = _mm512_fmadd_ps(avx_r0.data, lr_neg_r.data, var_r.data);
_mm512_storeu_ps(var_ptr, var_r.data);
_mm512_storeu_ps(m_ptr, m_r.data);
_mm512_storeu_ps(v_ptr, v_r.data);
m_r = _mm512_mul_ps(m_r, beta1_r);
v_r = _mm512_mul_ps(v_r, beta2_r);
__m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
m_r = _mm512_fmadd_ps(g_r, beta1_minus_r, m_r);
v_r = _mm512_fmadd_ps(avx_r0, beta2_minus_r, v_r);
avx_r0 = _mm512_sqrt_ps(v_r);
avx_r0 = _mm512_div_ps(m_r, _mm512_add_ps(avx_r0, epsilon_r));
avx_r0 = _mm512_fmadd_ps(var_r, decay_r, avx_r0);
var_r = _mm512_fmadd_ps(avx_r0, lr_neg_r, var_r);
_mm512_storeu_ps(var_ptr, var_r);
_mm512_storeu_ps(m_ptr, m_r);
_mm512_storeu_ps(v_ptr, v_r);
gradient16_ptr += C16NUM;
var_ptr += C16NUM;
@ -367,3 +251,49 @@ size_t FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, floa
#endif
return c1;
}
// Fused AdamWeightDecay step for float16 parameters: var16 and gradient16 hold
// fp16 values (as raw int16_t bit patterns) while the moment vectors m/v stay
// fp32. Processes elements [start, end) in 16-wide AVX512 chunks and returns
// the index of the first element NOT processed, so the caller can finish the
// scalar tail (which is everything when ENABLE_AVX512 is not compiled in).
size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
                         float decay, const int16_t *gradient16, size_t start, size_t end) {
  size_t c1 = start;
#ifdef ENABLE_AVX512
  // Broadcast every scalar hyper-parameter into a 512-bit register once, outside the loop.
  __m512 beta1_r = _mm512_set1_ps(beta1);
  __m512 beta2_r = _mm512_set1_ps(beta2);
  __m512 beta1_minus_r = _mm512_set1_ps(1.0f - beta1);
  __m512 beta2_minus_r = _mm512_set1_ps(1.0f - beta2);
  __m512 lr_neg_r = _mm512_set1_ps(-lr);  // negated so the final update is a single fmadd
  __m512 epsilon_r = _mm512_set1_ps(epsilon);
  __m512 decay_r = _mm512_set1_ps(decay);
  // End of the largest 16-element-aligned span inside [start, end).
  size_t c16 = ((end - start) / C16NUM) * C16NUM + start;
  const int16_t *gradient16_ptr = gradient16 + start;
  int16_t *var16_ptr = var16 + start;
  float *m_ptr = m + start;
  float *v_ptr = v + start;
  for (; c1 < c16; c1 += C16NUM) {
    // Widen the fp16 gradient and parameter lanes to fp32 for the arithmetic.
    __m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
    __m512 var_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(var16_ptr)));
    __m512 m_r = _mm512_loadu_ps(m_ptr);
    __m512 v_r = _mm512_loadu_ps(v_ptr);
    // m = beta1 * m + (1 - beta1) * g
    m_r = _mm512_mul_ps(m_r, beta1_r);
    // v = beta2 * v + (1 - beta2) * g * g
    v_r = _mm512_mul_ps(v_r, beta2_r);
    __m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
    m_r = _mm512_fmadd_ps(g_r, beta1_minus_r, m_r);
    v_r = _mm512_fmadd_ps(avx_r0, beta2_minus_r, v_r);
    // var -= lr * (m / (sqrt(v) + epsilon) + decay * var)
    avx_r0 = _mm512_sqrt_ps(v_r);
    avx_r0 = _mm512_div_ps(m_r, _mm512_add_ps(avx_r0, epsilon_r));
    avx_r0 = _mm512_fmadd_ps(var_r, decay_r, avx_r0);
    var_r = _mm512_fmadd_ps(avx_r0, lr_neg_r, var_r);
    _mm512_storeu_ps(m_ptr, m_r);
    _mm512_storeu_ps(v_ptr, v_r);
    // Narrow the updated parameters back to fp16 (rounding control 0 = round to nearest).
    _mm256_storeu_si256((__m256i *)var16_ptr, _mm512_cvtps_ph(var_r, 0));
    gradient16_ptr += C16NUM;
    var16_ptr += C16NUM;
    m_ptr += C16NUM;
    v_ptr += C16NUM;
  }
#endif
  return c1;
}

View File

@ -26,44 +26,9 @@
#include <x86intrin.h>
#endif
#endif
#ifdef ENABLE_AVX
#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
#include <immintrin.h>
#endif
#ifdef ENABLE_AVX512
// Number of 512-bit registers handled per unrolled iteration (4 x 16 floats = 64 elements).
const size_t kUnrollSize = 4;
// Thin wrapper around a 512-bit register so arrays of vectors can be passed to helpers.
struct AVX_Data {
  __m512 data;
};
// Loads kUnrollSize (4) consecutive 512-bit float vectors from inp1 into inp0[0..3].
// Returns NNACL_ERR on a wrong unroll length, NNACL_NULL_PTR on null arguments,
// NNACL_OK on success.
static inline int LoadStep4(struct AVX_Data *inp0, const float *inp1, const size_t arrLen) {
  if (arrLen != kUnrollSize) {
    return NNACL_ERR;
  }
  if (inp0 == NULL || inp1 == NULL) {
    return NNACL_NULL_PTR;
  }
  for (size_t idx = 0; idx < kUnrollSize; ++idx) {
    inp0[idx].data = _mm512_loadu_ps(inp1 + idx * C16NUM);
  }
  return NNACL_OK;
}
// Stores kUnrollSize (4) consecutive 512-bit float vectors from inp1[0..3] to inp0.
// Returns NNACL_ERR on a wrong unroll length, NNACL_NULL_PTR on null arguments,
// NNACL_OK on success.
static inline int StoreStep4(float *inp0, struct AVX_Data *inp1, const size_t arrLen) {
  if (arrLen != kUnrollSize) {
    return NNACL_ERR;
  }
  if (inp0 == NULL || inp1 == NULL) {
    return NNACL_NULL_PTR;
  }
  for (size_t idx = 0; idx < kUnrollSize; ++idx) {
    _mm512_storeu_ps(inp0 + idx * C16NUM, inp1[idx].data);
  }
  return NNACL_OK;
}
#endif
#ifdef __cplusplus
extern "C" {
#endif
@ -71,10 +36,12 @@ int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2,
size_t start, size_t end, bool use_nesterov);
int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
const float *gradient, size_t start, size_t end, bool use_nesterov);
size_t AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const float *gradient, size_t start, size_t end);
size_t FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, size_t start, size_t end);
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const float *gradient, size_t start, size_t end);
size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, size_t start, size_t end);
size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
float decay, const int16_t *gradient16, size_t start, size_t end);
#ifdef __cplusplus
}
#endif

View File

@ -205,6 +205,7 @@ constexpr auto kStridedReadOpName = "StridedRead";
constexpr auto kStridedWriteOpName = "StridedWrite";
constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
constexpr auto kAdamWeightDecayName = "AdamWeightDecay";
constexpr auto kFusedCastAdamWeightDecayName = "FusedCastAdamWeightDecay";
constexpr auto kFusedAdamName = "FusedAdam";
constexpr auto kFusedSparseAdamName = "FusedSparseAdam";
constexpr auto kFusedMatMulBiasAddName = "FusedMatMulBiasAdd";
@ -599,6 +600,7 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
kAdamApplyOneWithDecayAssignOpName,
kFusedAdamWeightDecayName,
kAdamWeightDecayName,
kFusedCastAdamWeightDecayName,
kFusedAdamName,
kFusedSparseAdamName,
kFusedMulApplyMomentumOpName,

View File

@ -22,6 +22,7 @@ from .div import _div_cpu
from .concat import _concat_cpu
from .split import _split_cpu
from .adam import _adam_cpu
from .adam_weight_decay import _adam_weight_decay_cpu
from .arg_max import _arg_max_cpu
from .arg_min_with_value import _arg_min_with_value_cpu
from .arg_max_with_value import _arg_max_with_value_cpu

View File

@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AdamWeightDecay op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
# Kernel registration metadata for the CPU AdamWeightDecay primitive.
# Nine required inputs (var, m, v, lr, beta1, beta2, epsilon, decay, gradient)
# and three required outputs (the updated var, m, v). Two dtype patterns are
# registered: an all-float32 one, and a variant where var, m, v and the outputs
# are float16 while the scalar hyper-parameters and the gradient stay float32.
adam_weight_decay_op_info = (CpuRegOp("AdamWeightDecay")
                             .input(0, "var", "required")
                             .input(1, "m", "required")
                             .input(2, "v", "required")
                             .input(3, "lr", "required")
                             .input(4, "beta1", "required")
                             .input(5, "beta2", "required")
                             .input(6, "epsilon", "required")
                             .input(7, "decay", "required")
                             .input(8, "gradient", "required")
                             .output(0, "output0", "required")
                             .output(1, "output1", "required")
                             .output(2, "output2", "required")
                             .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                                           DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                                           DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                                           DataType.F32_Default, DataType.F32_Default, DataType.F32_Default)
                             .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                                           DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                                           DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                                           DataType.F16_Default, DataType.F16_Default, DataType.F16_Default)
                             .get_op_info())


@op_info_register(adam_weight_decay_op_info)
def _adam_weight_decay_cpu():
    """AdamWeightDecay CPU kernel registration stub; the decorator does the work."""
    return

View File

@ -44,7 +44,7 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm
from .control_ops import GeSwitch, Merge
from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign,
MakeRefKey,
FusedWeightScaleApplyMomentum, AdamWeightDecay)
FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay)
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
BitwiseAnd, BitwiseOr,
@ -65,8 +65,8 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, U
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
LogUniformCandidateSampler)
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum,
BatchNorm, BiasAdd, Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose,
from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam,
ApplyMomentum, BatchNorm, BiasAdd, Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose,
DepthwiseConv2dNative,
DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten,
InstanceNorm, BNTrainingReduce, BNTrainingUpdate,

View File

@ -453,9 +453,11 @@ class FusedWeightScaleApplyMomentum(PrimitiveWithInfer):
return v_dtype
class AdamWeightDecay(PrimitiveWithInfer):
class FusedCastAdamWeightDecay(PrimitiveWithInfer):
r"""
Updates gradients by the Adaptive Moment Estimation (AdamWeightDecay) algorithm with weight decay.
Updates gradients by the Adaptive Moment Estimation (AdamWeightDecay) algorithm with weight decay. When the
parameters are initialized with dtype float16, this operator updates the float16 parameters with float16
gradients, and the computation is performed in float16.
The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
@ -479,15 +481,14 @@ class AdamWeightDecay(PrimitiveWithInfer):
If false, the result is unpredictable. Default: False.
Inputs:
- **var** (Tensor) - Weights to be updated.
- **m** (Tensor) - The 1st moment vector in the updating formula, has the same type as `var`.
- **v** (Tensor) - the 2nd moment vector in the updating formula.
Mean square gradients with the same type as `var`.
- **lr** (float) - :math:`l` in the updating formula.
- **var** (Tensor) - Weights to be updated with the type float16 or float32.
- **m** (Tensor) - The 1st moment vector in the updating formula with the type float32.
- **v** (Tensor) - the 2nd moment vector in the updating formula with the type float32.
- **lr** (float) - :math:`lr` in the updating formula.
- **beta1** (float) - The exponential decay rate for the 1st moment estimations.
- **beta2** (float) - The exponential decay rate for the 2nd moment estimations.
- **epsilon** (float) - Term added to the denominator to improve numerical stability.
- **gradient** (Tensor) - Gradient, has the same type as `var`.
- **gradient** (Tensor) - Gradient, has the type float16.
Outputs:
Tuple of 3 Tensor, the updated parameters.
@ -497,27 +498,10 @@ class AdamWeightDecay(PrimitiveWithInfer):
- **v** (Tensor) - The same shape and data type as `v`.
Supported Platforms:
``GPU`` ``CPU``
``CPU``
Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter, ops
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.adam_weight_decay = ops.AdamWeightDecay()
... self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var")
... self.m = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="m")
... self.v = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="v")
... def construct(self, lr, beta1, beta2, epsilon, decay, grad):
... out = self.adam_weight_decay(self.var, self.m, self.v, lr, beta1, beta2,
... epsilon, decay, grad)
... return out
>>> np.random.seed(0)
>>> net = Net()
>>> gradient = Tensor(np.random.rand(2, 2).astype(np.float32))
>>> output = net(0.9, 0.9, 0.999, 1e-8, 1e-5, gradient)
"""
@prim_attr_register
@ -534,11 +518,11 @@ class AdamWeightDecay(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
epsilon_dtype, decay, grad_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype}
args = {"m": m_dtype, "v": v_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({"var": var_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16], self.name)
args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
"decay": decay}
args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype, "decay": decay}
validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
return var_dtype, m_dtype, v_dtype

View File

@ -4465,6 +4465,100 @@ class Adam(PrimitiveWithInfer):
return var_dtype, m_dtype, v_dtype
class AdamWeightDecay(PrimitiveWithInfer):
    r"""
    Updates gradients by the Adaptive Moment Estimation (AdamWeightDecay) algorithm with weight decay.

    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m = \beta_1 * m + (1 - \beta_1) * g \\
            v = \beta_2 * v + (1 - \beta_2) * g * g \\
            w = w - lr * (\frac{m}{\sqrt{v} + \epsilon} + decay * w)
        \end{array}

    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
    `gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, :math:`lr` represents `learning_rate`,
    :math:`w` represents `var`, :math:`decay` represents `decay` and :math:`\epsilon` represents `epsilon`.

    Args:
        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
            If true, updates of the var, m, and v tensors will be protected by a lock.
            If false, the result is unpredictable. Default: False.

    Inputs:
        - **var** (Tensor) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means
          any number of additional dimensions. The data type can be float16 or float32.
        - **m** (Tensor) - The 1st moment vector in the updating formula,
          the shape and data type should be the same as `var`.
        - **v** (Tensor) - The 2nd moment vector in the updating formula,
          the shape and data type should be the same as `var`.
        - **lr** (float) - :math:`lr` in the updating formula. The paper suggested value is :math:`10^{-3}`.
        - **beta1** (float) - The exponential decay rate for the 1st moment estimations.
          The paper suggested value is :math:`0.9`.
        - **beta2** (float) - The exponential decay rate for the 2nd moment estimations.
          The paper suggested value is :math:`0.999`.
        - **epsilon** (float) - Term added to the denominator to improve numerical stability.
          The paper suggested value is :math:`10^{-8}`.
        - **decay** (float) - The weight decay coefficient applied to `var`.
        - **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.

    Outputs:
        Tuple of 3 Tensor, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **m** (Tensor) - The same shape and data type as `m`.
        - **v** (Tensor) - The same shape and data type as `v`.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, Parameter, ops
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.adam_weight_decay = ops.AdamWeightDecay()
        ...         self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var")
        ...         self.m = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="m")
        ...         self.v = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="v")
        ...     def construct(self, lr, beta1, beta2, epsilon, decay, grad):
        ...         out = self.adam_weight_decay(self.var, self.m, self.v, lr, beta1, beta2,
        ...                                      epsilon, decay, grad)
        ...         return out
        >>> np.random.seed(0)
        >>> net = Net()
        >>> gradient = Tensor(np.random.rand(2, 2).astype(np.float32))
        >>> output = net(0.9, 0.9, 0.999, 1e-8, 1e-5, gradient)
    """

    @prim_attr_register
    def __init__(self, use_locking=False):
        """Initialize AdamWeightDecay."""
        self.add_prim_attr('side_effect_mem', True)
        validator.check_value_type("use_locking", use_locking, [bool], self.name)

    def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
                    epsilon_shape, decay_shape, grad_shape):
        """var, m, v and gradient must share one shape; the scalar inputs are unconstrained here."""
        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
        return var_shape, m_shape, v_shape

    def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
                    epsilon_dtype, decay, grad_dtype):
        """var/m/v/gradient must share one numeric dtype; the scalar hyper-parameters must be float32."""
        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype, "decay": decay}
        validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
        return var_dtype, m_dtype, v_dtype
class AdamNoUpdateParam(PrimitiveWithInfer):
r"""
Updates gradients by Adaptive Moment Estimation (Adam) algorithm. This operator do not update the parameter, but

View File

@ -19,7 +19,9 @@ from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import ParameterTuple
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer
@ -28,6 +30,11 @@ _adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
op_assign = P.Assign()
op_assign.add_prim_attr("primitive_target", "CPU")
op_cast = P.Cast()
op_cast.add_prim_attr("primitive_target", "CPU")
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
@ -127,16 +134,23 @@ class AdamWeightDecayOp(Optimizer):
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
param_init_type=mstype.float32):
super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
_check_param_value(beta1, beta2, eps, self.cls_name)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
self.enable_init_fp16 = (param_init_type == mstype.float16)
if self.enable_init_fp16:
self.moments1 = self.clone_param32(prefix="adam_m", init='zeros')
self.moments2 = self.clone_param32(prefix="adam_v", init='zeros')
self.opt = P.FusedCastAdamWeightDecay()
else:
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
self.opt = P.AdamWeightDecay()
self.hyper_map = C.HyperMap()
self.opt = P.AdamWeightDecay()
self.opt.add_prim_attr("primitive_target", "CPU")
def construct(self, gradients):
@ -158,3 +172,27 @@ class AdamWeightDecayOp(Optimizer):
if self.use_parallel:
self.broadcast_params(optim_result)
return optim_result
def clone_param32(self, prefix, init=None):
"""
Clone the parameters in ParameterTuple element-wisely to generate a new ParameterTuple with float32 data type.
Inputs:
prefix (str): The prefix name of the parameters.
init (Union[Tensor, str, numbers.Number]): Initialize the shape and dtype of the parameters.
The definition of `init` is the same as in `Parameter` API. If `init` is 'same', the
parameters in the new parameter tuple are the same as those in the original parameter tuple.
Default: 'same'.
Returns:
Tuple, the new Parameter tuple.
"""
new = []
for old_param in self.parameters:
param_init = init
if init is None:
param_init = old_param.init
new_state = old_param.clone()
new_state.set_dtype(mstype.float32)
new_state.set_data(initializer(param_init, shape=old_param.shape, dtype=mstype.float32))
new_state.name = prefix + '.' + new_state.name
new.append(new_state)
return ParameterTuple(new)

View File

@ -934,7 +934,8 @@ class PanguAlpha_Model(nn.Cell):
self.is_pipeline = (config.stage_num > 1)
if self.is_pipeline:
self.top_query_embedding_table = Parameter(initializer(TruncatedNormal(0.02),
[config.seq_length, config.embedding_size]),
[config.seq_length, config.embedding_size],
config.param_init_type),
name='embedding_table', parallel_optimizer=False)
self.top_query_embedding = EmbeddingLookup()
for i in range(config.num_layers):
@ -1084,7 +1085,8 @@ class PanguAlpha(nn.Cell):
raise ValueError(f"{embedding_path} file not exits, "
f"please check whether word_embedding file exist.")
else:
self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size],
config.param_init_type),
name="embedding_table", parallel_optimizer=False)
def construct(self, input_ids, input_mask, input_position, attention_mask,

View File

@ -42,7 +42,8 @@ class PANGUALPHAConfig:
micro_size=32,
load_ckpt_path=None,
use_top_query_attention=True,
param_init_type=mstype.float32):
param_init_type=mstype.float32,
enable_offload=False):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
@ -67,6 +68,7 @@ class PANGUALPHAConfig:
self.load_ckpt_path = load_ckpt_path
self.use_top_query_attention = use_top_query_attention
self.param_init_type = param_init_type
self.enable_offload = enable_offload
def __str__(self):
info = "[PANGUALPHAConfig]" + '===' * 10 + '\n'

View File

@ -83,9 +83,12 @@ def _get_square_sum(grad, value):
apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
grad = grad * clip_norm / global_norm
@apply_global_norm.register("Bool", "Tensor", "Tensor", "Tensor")
def _apply_global_norm(enable_grad_fp16, clip_norm, global_norm, grad):
    """Scale `grad` by clip_norm / global_norm, casting the result to float16 when enable_grad_fp16 is set."""
    scaled = grad * clip_norm / global_norm
    return P.Cast()(scaled, mstype.float16) if enable_grad_fp16 else scaled
@ -188,13 +191,17 @@ class ClipByGlobalNorm(nn.Cell):
self.global_norm = GlobalNorm(params, config)
self.clip_norm = Tensor([clip_norm], mstype.float32)
self.hyper_map = C.HyperMap()
if config.param_init_type == mstype.float16 and config.enable_offload:
self.enable_grad_fp16 = True
else:
self.enable_grad_fp16 = False
def construct(self, grads):
"""Clip grads by global norm construct"""
global_norm_value = self.global_norm(grads)
cond = P.GreaterEqual()(global_norm_value, self.clip_norm)
global_norm = F.select(cond, global_norm_value, self.clip_norm)
grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
grads = self.hyper_map(F.partial(apply_global_norm, self.enable_grad_fp16, self.clip_norm, global_norm), grads)
return grads, global_norm_value

View File

@ -40,7 +40,6 @@ from src.callbacks import EvalCallBack, LossCallBack
from src.metrics import PPLMetric
project_root = os.path.abspath(
os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..")
print('project_root:', project_root)
@ -110,7 +109,7 @@ def run_train(args_opt):
stage_num=args_opt.stage_num, micro_size=args_opt.micro_size,
eod_reset=bool(args_opt.eod_reset), load_ckpt_path=args_opt.load_ckpt_path,
param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
word_emb_dp=bool(args_opt.word_emb_dp))
word_emb_dp=bool(args_opt.word_emb_dp), enable_offload=bool(args_opt.opt_offload))
print("===config is: ", config, flush=True)
# Define network
pangu_alpha = PanguAlpha(config)
@ -126,7 +125,8 @@ def run_train(args_opt):
if args_opt.optimizer == "lamb":
optimizer = nn.Lamb(group_params, learning_rate=lr)
elif args_opt.opt_offload:
optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95,
param_init_type=config.param_init_type)
else:
optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
# Initial scaling sens
@ -219,7 +219,7 @@ def run_train_pipeline(args_opt):
use_past=False,
stage_num=args_opt.stage_num,
micro_size=args_opt.micro_size,
word_emb_dp=bool(args_opt.word_emb_dp))
word_emb_dp=bool(args_opt.word_emb_dp), enable_offload=bool(args_opt.opt_offload))
print("===config is: ", config, flush=True)
pangu_alpha = PanguAlpha(config)
loss = CrossEntropyLoss(config)
@ -233,7 +233,8 @@ def run_train_pipeline(args_opt):
if args_opt.optimizer == "lamb":
optimizer = nn.Lamb(group_params, learning_rate=lr)
elif args_opt.opt_offload:
optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95,
param_init_type=config.param_init_type)
else:
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)