diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.cc new file mode 100644 index 00000000000..8e1b4d66f45 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/sgd_cpu_kernel.h" + +#include +#include + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kInputSize = 6; +constexpr size_t kOutputSize = 1; +} // namespace +template +void SGDCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + dampening_ = AnfAlgo::GetNodeAttr(kernel_node, "dampening"); + weight_decay_ = AnfAlgo::GetNodeAttr(kernel_node, "weight_decay"); + nesterov_ = AnfAlgo::GetNodeAttr(kernel_node, "nesterov"); +} + +template +void SGDCPUKernel::CheckParam(const std::vector &inputs, const std::vector &outputs) { + // inputs: params, grad, lr, accum, momentum, stat + if (inputs.size() != kInputSize) { + MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but SGD needs 6 inputs."; + } + + // output: param + if (outputs.size() != kOutputSize) { + MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but SGD needs 1 outputs."; + } +} + +template +bool SGDCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, + const std::vector &outputs) { + CheckParam(inputs, outputs); + + auto param = reinterpret_cast(inputs[0]->addr); + auto grad = reinterpret_cast(inputs[1]->addr); + auto lr = reinterpret_cast(inputs[2]->addr); + auto accum = reinterpret_cast(inputs[3]->addr); + auto momentum = reinterpret_cast(inputs[4]->addr); + auto stat = reinterpret_cast(inputs[5]->addr); + size_t elem_num = inputs[0]->size / sizeof(float); + + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + T grad_new = grad[i]; + if (weight_decay_ > 0) { + grad_new += param[i] * static_cast(weight_decay_); + } + if (momentum[0] > static_cast(0)) { + if (stat[i] > static_cast(0)) { + accum[i] = grad_new; + stat[i] = static_cast(0); + } else { + accum[i] = accum[i] * momentum[0] + static_cast(1.0 - dampening_) * grad_new; + } + if (nesterov_) { + grad_new += accum[i] * momentum[0]; + } else { + grad_new = accum[i]; + } + } + param[i] -= lr[0] * grad_new; + } + }; + CPUKernelUtils::ParallelFor(task, elem_num); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.h new file mode 100644 index 00000000000..93f25d1b657 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.h @@ -0,0 +1,67 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SGD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SGD_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class SGDCPUKernel : public CPUKernel { + public: + SGDCPUKernel() = default; + ~SGDCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector & /*workspace*/, + const std::vector &outputs) override; + + private: + static void CheckParam(const std::vector &inputs, const std::vector &outputs); + float dampening_; + float weight_decay_; + bool nesterov_{true}; +}; + +MS_REG_CPU_KERNEL_T(SGD, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SGDCPUKernel, float); + +MS_REG_CPU_KERNEL_T(SGD, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + SGDCPUKernel, float16); +} // namespace kernel +} // namespace mindspore +#endif diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 9c94a17aeb5..a8b7cac0961 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2704,7 +2704,7 @@ class SGD(PrimitiveWithCheck): float16 nor float32. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> sgd = ops.SGD() diff --git a/tests/st/ops/cpu/test_sgd_op.py b/tests/st/ops/cpu/test_sgd_op.py new file mode 100644 index 00000000000..37eb6e42f24 --- /dev/null +++ b/tests/st/ops/cpu/test_sgd_op.py @@ -0,0 +1,72 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.nn import Dense +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import SGD +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + +class NetSGD(nn.Cell): + def __init__(self): + super(NetSGD, self).__init__() + self.batch_size = 1 + self.reshape = P.Reshape() + weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01) + self.fc1 = Dense(16, 10, weight_init=weight) + + def construct(self, input_x): + output = self.reshape(input_x, (self.batch_size, -1)) + output = self.fc1(output) + return output + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_SGD(): + epoch = 3 + net = NetSGD() + learning_rate = 0.1 + momentum = 0.9 + dampening = 0.0 + weight_decay = 0.0 + nesterov = True + loss_scale = 1.0 + + optimizer = SGD(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum, dampening, + weight_decay, nesterov, loss_scale) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer + train_network.set_train() + losses = [] + for _ in range(epoch): + data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01) + label = Tensor(np.array([0]).astype(np.int32)) + loss = train_network(data, label) + losses.append(loss.asnumpy()) + + last_loss = 100.0 + for loss in losses: + assert last_loss > loss + last_loss = loss