forked from mindspore-Ecosystem/mindspore
!8088 add Adam CPU operator
Merge pull request !8088 from zhaoting/adam
This commit is contained in:
commit
491efd8aa4
|
@ -0,0 +1,108 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/adam_cpu_kernel.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <thread>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
template <typename T>
|
||||
void AdamCPUKernel::LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
|
||||
size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
m[i] += (gradient[i] - m[i]) * (1 - beta1);
|
||||
v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2);
|
||||
if (use_nesterov) {
|
||||
var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon);
|
||||
} else {
|
||||
var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 10) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but Adam needs 10 inputs.";
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 3) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but Adam needs 3 outputs.";
|
||||
}
|
||||
use_nesterov = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
|
||||
}
|
||||
|
||||
bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (inputs.size() != 10) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but Adam needs 10 inputs.";
|
||||
}
|
||||
if (outputs.size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but Adam needs 3 outputs.";
|
||||
}
|
||||
if (inputs[0]->size != inputs[1]->size || inputs[0]->size != inputs[2]->size || inputs[0]->size != inputs[9]->size) {
|
||||
MS_LOG(EXCEPTION) << "Error input data size!";
|
||||
}
|
||||
size_t f_size = sizeof(float);
|
||||
if (inputs[3]->size != f_size || inputs[4]->size != f_size || inputs[5]->size != f_size ||
|
||||
inputs[6]->size != f_size || inputs[7]->size != f_size || inputs[8]->size != f_size) {
|
||||
MS_LOG(EXCEPTION) << "The attribute beta_power, beta, lr and epsilon must be float!";
|
||||
}
|
||||
auto var = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto m = reinterpret_cast<float *>(inputs[1]->addr);
|
||||
auto v = reinterpret_cast<float *>(inputs[2]->addr);
|
||||
float beta1_power = reinterpret_cast<float *>(inputs[3]->addr)[0];
|
||||
float beta2_power = reinterpret_cast<float *>(inputs[4]->addr)[0];
|
||||
float lr = reinterpret_cast<float *>(inputs[5]->addr)[0];
|
||||
float beta1 = reinterpret_cast<float *>(inputs[6]->addr)[0];
|
||||
float beta2 = reinterpret_cast<float *>(inputs[7]->addr)[0];
|
||||
float epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
|
||||
auto gradient = reinterpret_cast<float *>(inputs[9]->addr);
|
||||
if (beta1_power == 1) {
|
||||
MS_LOG(EXCEPTION) << "The beta1_power can't be set 1.";
|
||||
}
|
||||
float new_lr = lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power);
|
||||
|
||||
// multithreading
|
||||
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
|
||||
auto max_thread_num = std::thread::hardware_concurrency();
|
||||
size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
|
||||
MS_LOG(INFO) << "lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(thread_num);
|
||||
size_t start = 0;
|
||||
size_t once_compute_size = (lens + thread_num - 1) / thread_num;
|
||||
while (start < lens) {
|
||||
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
|
||||
threads.emplace_back(std::thread(&AdamCPUKernel::LaunchAdam<float>, this, var, m, v, new_lr, beta1, beta2, epsilon,
|
||||
gradient, start, end));
|
||||
start += once_compute_size;
|
||||
}
|
||||
for (size_t i = 0; i < threads.size(); ++i) {
|
||||
threads[i].join();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAM_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAM_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class AdamCPUKernel : public CPUKernel {
|
||||
public:
|
||||
AdamCPUKernel() = default;
|
||||
~AdamCPUKernel() override = default;
|
||||
template <typename T>
|
||||
void LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
|
||||
size_t start, size_t end);
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
bool use_nesterov{false};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Adam,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
AdamCPUKernel)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAM_CPU_KERNEL_H_
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Dense
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
||||
class NetAdam(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetAdam, self).__init__()
|
||||
self.batch_size = 1
|
||||
self.reshape = P.Reshape()
|
||||
weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01)
|
||||
self.fc1 = Dense(16, 10, weight_init=weight)
|
||||
|
||||
def construct(self, input_x):
|
||||
output = self.reshape(input_x, (self.batch_size, -1))
|
||||
output = self.fc1(output)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_adam():
|
||||
epoch = 3
|
||||
net = NetAdam()
|
||||
optimizer = Adam(filter(lambda x: x.requires_grad,
|
||||
net.get_parameters()), learning_rate=0.01)
|
||||
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
net_with_criterion = WithLossCell(net, criterion)
|
||||
train_network = TrainOneStepCell(
|
||||
net_with_criterion, optimizer)
|
||||
train_network.set_train()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
losses1 = []
|
||||
for _ in range(epoch):
|
||||
data = Tensor(np.arange(0, 16).reshape(
|
||||
1, 1, 4, 4).astype(np.float32) * 0.01)
|
||||
label = Tensor(np.array([0]).astype(np.int32))
|
||||
loss = train_network(data, label)
|
||||
losses1.append(loss.asnumpy())
|
||||
assert losses1[0] > losses1[1]
|
||||
assert losses1[1] > losses1[2]
|
Loading…
Reference in New Issue