forked from mindspore-Ecosystem/mindspore
GPU Adam Fusion
This commit is contained in:
parent 0478b7d191
commit 034d2ea2aa
kernel/gpu/cuda_impl/adam_weight_decay_impl.cu
@@ -0,0 +1,50 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "adam_weight_decay_impl.cuh"
#include "device/gpu/cuda_common.h"

template <typename T>
__global__ void AdamWeightDecayKernel(const int element_num_, const bool need_decay, const float *beta1,
                                      const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2,
                                      const float *epsilon, const float *lr, const float *weight_decay, T *m, T *v,
                                      T *param, T *gradient) {
  // Grid-stride loop: each thread updates a strided subset of the elements.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < element_num_; i += blockDim.x * gridDim.x) {
    float next_m = beta1[0] * m[i] + one_sub_beta1[0] * gradient[i];
    float next_v = beta2[0] * v[i] + one_sub_beta2[0] * gradient[i] * gradient[i];
    float update = next_m / (sqrt(next_v) + epsilon[0]);
    if (need_decay && weight_decay != nullptr) {
      update += weight_decay[0] * param[i];
    }
    param[i] -= lr[0] * update;
    m[i] = next_m;
    v[i] = next_v;
  }
}

template <typename T>
void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1,
                     const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr,
                     const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream) {
  AdamWeightDecayKernel<<<GET_BLOCKS(element_num_), GET_THREADS, 0, stream>>>(
    element_num_, need_decay, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v, param,
    gradient);
}

template void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1,
                              const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2,
                              const float *epsilon, const float *lr, const float *weight_decay, float *m, float *v,
                              float *param, float *gradient, cudaStream_t stream);
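For reference, the fused kernel above performs the standard Adam update with optional decoupled weight decay. Per element, with gradient g, parameter θ, and learning rate lr:

\begin{aligned}
m' &= \beta_1 m + (1 - \beta_1)\, g \\
v' &= \beta_2 v + (1 - \beta_2)\, g^2 \\
u  &= \frac{m'}{\sqrt{v'} + \epsilon} \;\left(+\; \lambda\,\theta \text{ when weight decay is enabled}\right) \\
\theta' &= \theta - lr \cdot u
\end{aligned}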
kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh
@@ -0,0 +1,24 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_
template <typename T>
void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1,
                     const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr,
                     const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_
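As a usage sketch only (not part of this commit), the exported float instantiation can be driven directly from host code. This assumes the two files above are compiled in-tree (they need GET_BLOCKS/GET_THREADS from cuda_common.h) and linked in; the element count and hyper-parameter values mirror the Python test at the bottom of this diff. Note that the kernel reads every hyper-parameter from device memory, so the scalars must be copied to the GPU first:

// Hypothetical standalone driver for AdamWeightDecay (illustrative only).
#include <cuda_runtime.h>
#include <cstdio>
#include "adam_weight_decay_impl.cuh"

int main() {
  const int n = 3;
  float h_param[n] = {0.1f, 0.3f, 0.5f};
  float h_m[n] = {0.1f, 0.3f, 0.5f};
  float h_v[n] = {0.1f, 0.3f, 0.5f};
  float h_grad[n] = {0.01f, 0.03f, 0.05f};
  // beta1, 1 - beta1, beta2, 1 - beta2, epsilon, lr, weight_decay
  float h_scalars[7] = {0.9f, 0.1f, 0.999f, 0.001f, 1e-6f, 0.001f, 0.001f};

  float *d_param, *d_m, *d_v, *d_grad, *d_scalars;
  cudaMalloc(&d_param, n * sizeof(float));
  cudaMalloc(&d_m, n * sizeof(float));
  cudaMalloc(&d_v, n * sizeof(float));
  cudaMalloc(&d_grad, n * sizeof(float));
  cudaMalloc(&d_scalars, 7 * sizeof(float));
  cudaMemcpy(d_param, h_param, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_m, h_m, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_v, h_v, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_grad, h_grad, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_scalars, h_scalars, 7 * sizeof(float), cudaMemcpyHostToDevice);

  // All hyper-parameters are passed as device pointers, matching the signature.
  AdamWeightDecay<float>(n, true, d_scalars + 0, d_scalars + 1, d_scalars + 2, d_scalars + 3, d_scalars + 4,
                         d_scalars + 5, d_scalars + 6, d_m, d_v, d_param, d_grad, nullptr);
  cudaDeviceSynchronize();

  cudaMemcpy(h_param, d_param, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("param after one step: %f %f %f\n", h_param[0], h_param[1], h_param[2]);
  cudaFree(d_param); cudaFree(d_m); cudaFree(d_v); cudaFree(d_grad); cudaFree(d_scalars);
  return 0;
}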
kernel/gpu/nn/fused_adam_weight_decay.cc
@@ -0,0 +1,52 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel/gpu/nn/fused_adam_weight_decay.h"

namespace mindspore {
namespace kernel {
// FusedAdamWeightDecay takes 11 float32 inputs; the 11th is weight_decay.
MS_REG_GPU_KERNEL_ONE(FusedAdamWeightDecay,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      FusedAdamWeightDecayGpuKernel, float)
// FusedAdam takes the same inputs without weight_decay (10 in total) and
// shares the same kernel class.
MS_REG_GPU_KERNEL_ONE(FusedAdam,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      FusedAdamWeightDecayGpuKernel, float)

}  // namespace kernel
}  // namespace mindspore
kernel/gpu/nn/fused_adam_weight_decay.h
@@ -0,0 +1,103 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_

#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/kernel_constants.h"
#include "kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class FusedAdamWeightDecayGpuKernel : public GpuKernel {
 public:
  FusedAdamWeightDecayGpuKernel() : element_nums_(0), weight_decay_(false) {}
  ~FusedAdamWeightDecayGpuKernel() override = default;

  bool Init(const CNodePtr &kernel_node) override {
    auto node_name = AnfAlgo::GetCNodeName(kernel_node);
    // The FusedAdamWeightDecay variant carries an extra weight_decay input (index 10).
    if (node_name == "FusedAdamWeightDecay") {
      weight_decay_ = true;
    }

    // The element count comes from the shape of input 7 (the first moment m).
    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 7);
    element_nums_ = 1;
    for (auto i : shape) {
      element_nums_ *= i;
    }

    InitSizeLists();
    return true;
  }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    float *beta1 = GetDeviceAddress<float>(inputs, 0);
    float *one_sub_beta1 = GetDeviceAddress<float>(inputs, 1);
    float *beta2 = GetDeviceAddress<float>(inputs, 2);
    float *one_sub_beta2 = GetDeviceAddress<float>(inputs, 3);
    float *epsilon = GetDeviceAddress<float>(inputs, 4);
    float *lr = GetDeviceAddress<float>(inputs, 5);
    T *param = GetDeviceAddress<T>(inputs, 6);
    T *m = GetDeviceAddress<T>(inputs, 7);
    T *v = GetDeviceAddress<T>(inputs, 8);
    T *gradient = GetDeviceAddress<T>(inputs, 9);
    // For the plain FusedAdam variant weight_decay stays nullptr, so the
    // decay branch inside the kernel is skipped.
    float *weight_decay = nullptr;
    if (weight_decay_) {
      weight_decay = GetDeviceAddress<float>(inputs, 10);
    }
    AdamWeightDecay(element_nums_, true, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v,
                    param, gradient, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

 protected:
  void InitResource() override {}
  void InitSizeLists() override {
    // One entry per input, in the order consumed by Launch above.
    input_size_list_.push_back(sizeof(float));              // beta1
    input_size_list_.push_back(sizeof(float));              // one_sub_beta1
    input_size_list_.push_back(sizeof(float));              // beta2
    input_size_list_.push_back(sizeof(float));              // one_sub_beta2
    input_size_list_.push_back(sizeof(float));              // epsilon
    input_size_list_.push_back(sizeof(float));              // lr
    input_size_list_.push_back(element_nums_ * sizeof(T));  // param
    input_size_list_.push_back(element_nums_ * sizeof(T));  // m
    input_size_list_.push_back(element_nums_ * sizeof(T));  // v
    input_size_list_.push_back(element_nums_ * sizeof(T));  // gradient
    if (weight_decay_) {
      input_size_list_.push_back(sizeof(float));            // weight_decay
    }
    output_size_list_.push_back(element_nums_ * sizeof(T));
  }

 private:
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;

  int element_nums_;
  bool weight_decay_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_
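For quick reference, Launch above consumes the kernel inputs in this fixed order, which the fusion passes later in this diff must reproduce when they build the fused node:
  0 beta1, 1 one_sub_beta1, 2 beta2, 3 one_sub_beta2, 4 epsilon, 5 lr,
  6 param, 7 m, 8 v, 9 gradient, 10 weight_decay (FusedAdamWeightDecay only).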
@@ -182,9 +182,11 @@ extern const PrimitivePtr kPrimReduceMin;
extern const PrimitivePtr kPrimNeg;
extern const PrimitivePtr kPrimSub;
extern const PrimitivePtr kPrimMul;
extern const PrimitivePtr kPrimRealDiv;
extern const PrimitivePtr kPrimMinimum;
extern const PrimitivePtr kPrimMaximum;
extern const PrimitivePtr kPrimSquare;
extern const PrimitivePtr kPrimSqrt;
extern const PrimitivePtr kPrimEqual;
extern const PrimitivePtr kPrimLess;
extern const PrimitivePtr kPrimLessEqual;
pre_activate/gpu/adam_fusion.cc
@@ -0,0 +1,112 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pre_activate/gpu/adam_fusion.h"

#include <memory>
#include <vector>
#include <string>

#include "session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "pre_activate/common/helper.h"

namespace mindspore {
namespace opt {
namespace {
// Builds kernel build info with default formats and the node's inferred
// input/output types.
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
  std::vector<std::string> inputs_format;
  std::vector<std::string> outputs_format;
  std::vector<TypeId> inputs_type;
  std::vector<TypeId> outputs_type;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;

  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
    inputs_format.push_back(kOpFormat_DEFAULT);
  }
  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
    outputs_format.push_back(kOpFormat_DEFAULT);
  }
  builder.SetInputsDeviceType(inputs_type);
  builder.SetInputsFormat(inputs_format);
  builder.SetOutputsDeviceType(outputs_type);
  builder.SetOutputsFormat(outputs_format);
  return builder.Build();
}
}  // namespace

const BaseRef AdamFusion::DefinePattern() const {
  // next_m = beta1 * m + (1 - beta1) * gradient
  VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}),
                                VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
  // next_v = beta2 * v + (1 - beta2) * gradient^2
  VectorRef next_v =
    VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
               VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
  // update = next_m / (eps + sqrt(next_v))
  VectorRef update = VectorRef(
    {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
  VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
  VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
  // The Depend/Assign chain matches the write-back of param, m, and v; the
  // last Assign stores the depend-threaded next_v into v.
  VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})});
  VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})});
  VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})});
  return depend3;
}

const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  auto beta1_input = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
  auto one_sub_beta1_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta1_]);
  auto beta2_input = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
  auto one_sub_beta2_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta2_]);
  auto eps_input = utils::cast<AnfNodePtr>((*equiv)[eps_]);
  auto lr_input = utils::cast<AnfNodePtr>((*equiv)[lr_]);
  auto param_input = utils::cast<AnfNodePtr>((*equiv)[param_]);
  auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
  auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
  auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
  MS_EXCEPTION_IF_NULL(beta1_input);
  MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
  MS_EXCEPTION_IF_NULL(beta2_input);
  MS_EXCEPTION_IF_NULL(one_sub_beta2_input);
  MS_EXCEPTION_IF_NULL(eps_input);
  MS_EXCEPTION_IF_NULL(lr_input);
  MS_EXCEPTION_IF_NULL(param_input);
  MS_EXCEPTION_IF_NULL(m_input);
  MS_EXCEPTION_IF_NULL(v_input);
  MS_EXCEPTION_IF_NULL(gradient_input);

  auto prim = std::make_shared<Primitive>(kFusedAdamName);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> inputs = {
    NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
    eps_input,          lr_input,    param_input,         m_input,     v_input,
    gradient_input};
  auto adam = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(adam);
  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get());
  adam->set_scope(node->scope());

  auto build_info = GenerateKernelBuildInfo(adam);
  AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get());
  return adam;
}
}  // namespace opt
}  // namespace mindspore
pre_activate/gpu/adam_fusion.h
@@ -0,0 +1,56 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_

#include <memory>
#include "pre_activate/common/optimizer.h"

namespace mindspore {
namespace opt {
class AdamFusion : public PatternProcessPass {
 public:
  explicit AdamFusion(bool multigraph = true) : PatternProcessPass("adam_fusion", multigraph) {
    beta1_ = std::make_shared<Var>();
    one_sub_beta1_ = std::make_shared<Var>();
    beta2_ = std::make_shared<Var>();
    one_sub_beta2_ = std::make_shared<Var>();
    eps_ = std::make_shared<Var>();
    lr_ = std::make_shared<Var>();
    param_ = std::make_shared<Var>();
    m_ = std::make_shared<Var>();
    v_ = std::make_shared<Var>();
    gradient_ = std::make_shared<Var>();
  }
  ~AdamFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  VarPtr beta1_;
  VarPtr one_sub_beta1_;
  VarPtr beta2_;
  VarPtr one_sub_beta2_;
  VarPtr eps_;
  VarPtr lr_;
  VarPtr param_;
  VarPtr m_;
  VarPtr v_;
  VarPtr gradient_;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_
pre_activate/gpu/adam_weight_decay_fusion.cc
@@ -0,0 +1,117 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pre_activate/gpu/adam_weight_decay_fusion.h"

#include <memory>
#include <vector>
#include <string>

#include "session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "pre_activate/common/helper.h"

namespace mindspore {
namespace opt {
namespace {
// Builds kernel build info with default formats and the node's inferred
// input/output types.
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
  std::vector<std::string> inputs_format;
  std::vector<std::string> outputs_format;
  std::vector<TypeId> inputs_type;
  std::vector<TypeId> outputs_type;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;

  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
    inputs_format.push_back(kOpFormat_DEFAULT);
  }
  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
    outputs_format.push_back(kOpFormat_DEFAULT);
  }
  builder.SetInputsDeviceType(inputs_type);
  builder.SetInputsFormat(inputs_format);
  builder.SetOutputsDeviceType(outputs_type);
  builder.SetOutputsFormat(outputs_format);
  return builder.Build();
}
}  // namespace

const BaseRef AdamWeightDecayFusion::DefinePattern() const {
  VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}),
                                VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
  VectorRef next_v =
    VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
               VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
  VectorRef update = VectorRef(
    {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
  // new_update = weight_decay * param + update: the decoupled weight-decay
  // term is folded in before scaling by the learning rate.
  VectorRef new_update = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update});

  VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update});
  VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
  VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})});
  VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})});
  VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})});
  return depend3;
}

const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  auto beta1_input = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
  auto one_sub_beta1_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta1_]);
  auto beta2_input = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
  auto one_sub_beta2_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta2_]);
  auto eps_input = utils::cast<AnfNodePtr>((*equiv)[eps_]);
  auto lr_input = utils::cast<AnfNodePtr>((*equiv)[lr_]);
  auto weight_decay_input = utils::cast<AnfNodePtr>((*equiv)[weight_decay_]);
  auto param_input = utils::cast<AnfNodePtr>((*equiv)[param_]);
  auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
  auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
  auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
  MS_EXCEPTION_IF_NULL(beta1_input);
  MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
  MS_EXCEPTION_IF_NULL(beta2_input);
  MS_EXCEPTION_IF_NULL(one_sub_beta2_input);
  MS_EXCEPTION_IF_NULL(eps_input);
  MS_EXCEPTION_IF_NULL(lr_input);
  MS_EXCEPTION_IF_NULL(weight_decay_input);
  MS_EXCEPTION_IF_NULL(param_input);
  MS_EXCEPTION_IF_NULL(m_input);
  MS_EXCEPTION_IF_NULL(v_input);
  MS_EXCEPTION_IF_NULL(gradient_input);

  auto prim = std::make_shared<Primitive>(kFusedAdamWeightDecayName);
  MS_EXCEPTION_IF_NULL(prim);
  // weight_decay_input goes last so it lands on kernel input index 10.
  std::vector<AnfNodePtr> inputs = {
    NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
    eps_input,          lr_input,    param_input,         m_input,     v_input,
    gradient_input,     weight_decay_input};
  auto adam_weight_decay = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(adam_weight_decay);
  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam_weight_decay.get());
  adam_weight_decay->set_scope(node->scope());

  auto build_info = GenerateKernelBuildInfo(adam_weight_decay);
  AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get());
  return adam_weight_decay;
}
}  // namespace opt
}  // namespace mindspore
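Relative to AdamFusion above, the only structural difference in the matched graph is the decoupled weight-decay term folded in before the learning-rate scaling, i.e. $u \leftarrow u + \lambda\,\theta$ (the new_update node in DefinePattern); the extra weight_decay input is then forwarded to the fused kernel as its last argument.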
pre_activate/gpu/adam_weight_decay_fusion.h
@@ -0,0 +1,58 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_

#include <memory>
#include "pre_activate/common/optimizer.h"

namespace mindspore {
namespace opt {
class AdamWeightDecayFusion : public PatternProcessPass {
 public:
  explicit AdamWeightDecayFusion(bool multigraph = true) : PatternProcessPass("adam_weight_decay_fusion", multigraph) {
    beta1_ = std::make_shared<Var>();
    one_sub_beta1_ = std::make_shared<Var>();
    beta2_ = std::make_shared<Var>();
    one_sub_beta2_ = std::make_shared<Var>();
    eps_ = std::make_shared<Var>();
    lr_ = std::make_shared<Var>();
    weight_decay_ = std::make_shared<Var>();
    param_ = std::make_shared<Var>();
    m_ = std::make_shared<Var>();
    v_ = std::make_shared<Var>();
    gradient_ = std::make_shared<Var>();
  }
  ~AdamWeightDecayFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  VarPtr beta1_;
  VarPtr one_sub_beta1_;
  VarPtr beta2_;
  VarPtr one_sub_beta2_;
  VarPtr eps_;
  VarPtr lr_;
  VarPtr weight_decay_;
  VarPtr param_;
  VarPtr m_;
  VarPtr v_;
  VarPtr gradient_;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_
@@ -23,6 +23,8 @@
#include "pre_activate/common/helper.h"
#include "pre_activate/pass/communication_op_fusion.h"
#include "pre_activate/pass/getitem_tuple.h"
#include "pre_activate/gpu/adam_weight_decay_fusion.h"
#include "pre_activate/gpu/adam_fusion.h"
#include "device/kernel_runtime_manager.h"
#include "predict/predict.h"
#include "common/utils.h"
@@ -53,6 +55,16 @@ void GPUSession::StartKernelRT() const {

void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  // Match the more specific weight-decay pattern before the plain Adam pattern.
  pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
  pm->AddPass(std::make_shared<opt::AdamFusion>());
  optimizer->AddPassManager(pm);
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
}

void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::AllReduceFusion>());
@@ -151,14 +163,16 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
   auto graph_id = graph_sum_;
   auto graph = ConstructKernelGraph(lst, outputs);
   MS_EXCEPTION_IF_NULL(graph);
+  // Optimize
+  Optimize(graph);
   // Select kernel build info
   SelectKernel(graph);
   // Convert kernel Graph to model
   predictmodel::StepConvertGraph(graph);
   // Start gpu kernel runtime
   StartKernelRT();
-  // AllReduce Optimize
-  Optimize(graph);
+  // HardwareOptimize
+  HardwareOptimize(graph);
   // Assign CUDA streams
   AssignStream(graph);
   // Hide NoOp from execution graph
@@ -51,6 +51,8 @@ class GPUSession : public SessionBasic {

  void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);

  void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);

  void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph);

  void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
@@ -161,6 +161,8 @@ constexpr auto kNMSWithMaskOpName = "NMSWithMask";
constexpr auto kSoftmaxGradExtOpName = "SoftmaxGradExt";
constexpr auto kStridedReadOpName = "StridedRead";
constexpr auto kStridedWriteOpName = "StridedWrite";
constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
constexpr auto kFusedAdamName = "FusedAdam";

// attr key name
constexpr auto kAttrInputNames = "input_names";
@@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter

context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)


class Net(nn.Cell):
    def __init__(self, decay_flag=True):
        super(Net, self).__init__()
        self.decay_flag = decay_flag
        self.op_mul = P.Mul()
        self.op_square = P.Square()
        self.op_sqrt = P.Sqrt()
        self.op_cast = P.Cast()
        self.op_reshape = P.Reshape()
        self.op_shape = P.Shape()
        self.param = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='param')
        self.m = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='m')
        self.v = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='v')

    @ms_function
    def construct(self, beta1, beta2, gradient, eps, weight_decay_tensor, lr):
        param_fp32 = self.op_cast(self.param, mstype.float32)
        m_fp32 = self.op_cast(self.m, mstype.float32)
        v_fp32 = self.op_cast(self.v, mstype.float32)
        gradient_fp32 = self.op_cast(gradient, mstype.float32)

        next_m = self.op_mul(beta1, m_fp32) + \
                 self.op_mul(self.op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
        next_v = self.op_mul(beta2, v_fp32) + \
                 self.op_mul(self.op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2,
                             self.op_square(gradient_fp32))
        update = next_m / (eps + self.op_sqrt(next_v))
        if self.decay_flag:
            update = self.op_mul(weight_decay_tensor, param_fp32) + update
        update_with_lr = self.op_mul(lr, update)
        next_param = param_fp32 - self.op_reshape(update_with_lr, self.op_shape(param_fp32))

        next_v = F.depend(next_v, F.assign(self.param, next_param))
        next_v = F.depend(next_v, F.assign(self.m, next_m))
        next_v = F.depend(next_v, F.assign(self.v, next_v))
        return next_v


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test():
    beta1 = Tensor(np.array([0.9]).astype(np.float32))
    beta2 = Tensor(np.array([0.999]).astype(np.float32))
    lr = Tensor(np.array([0.001]).astype(np.float32))
    eps = Tensor(np.array([1e-6]).astype(np.float32))
    weight_decay_tensor = Tensor(np.array([0.001]).astype(np.float32))

    gradient = Tensor(np.array([0.01, 0.03, 0.05]).astype(np.float32))
    opt = Net(True)
    _ = opt(beta1, beta2, gradient, eps, weight_decay_tensor, lr)

    param_expect = np.array([0.09971199, 0.29950103, 0.4993557]).astype(np.float32)
    m_expect = np.array([0.091, 0.273, 0.45499998]).astype(np.float32)
    v_expect = np.array([0.0999001, 0.29970092, 0.4995025]).astype(np.float32)
    assert np.allclose(opt.param.data.asnumpy(), param_expect)
    assert np.allclose(opt.m.data.asnumpy(), m_expect)
    assert np.allclose(opt.v.data.asnumpy(), v_expect)
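As a sanity check, the first element of the expected arrays follows directly from the fused update rule with decay enabled:

\begin{aligned}
m'_0 &= 0.9 \cdot 0.1 + 0.1 \cdot 0.01 = 0.091 \\
v'_0 &= 0.999 \cdot 0.1 + 0.001 \cdot 0.01^2 = 0.0999001 \\
u_0 &= 0.091 / (\sqrt{0.0999001} + 10^{-6}) + 0.001 \cdot 0.1 \approx 0.288011 \\
\theta'_0 &= 0.1 - 0.001 \cdot 0.288011 \approx 0.09971199
\end{aligned}

matching m_expect[0], v_expect[0], and param_expect[0] above.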
@@ -119,6 +119,10 @@ def test_4d_transpose_ab():
                [[5612, 5810, 6008, 6206]]]]
     assert (output.asnumpy() == expect).all()
 
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
 def test_4D_fp16():
     input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float16)
     input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float16)

@@ -126,13 +130,13 @@ def test_4D_fp16():
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
     net = BatchMatMulNet()
     output = net(input_x, input_y)
-    expect = [[[[20, 23, 26, 29]],
-               [[200, 212, 224, 236]],
-               [[596, 617, 638, 659]],
-               [[1208, 1238, 1268, 1298]]],
+    expect = np.array([[[[20, 23, 26, 29]],
+                        [[200, 212, 224, 236]],
+                        [[596, 617, 638, 659]],
+                        [[1208, 1238, 1268, 1298]]],
 
-              [[[2036, 2075, 2114, 2153]],
-               [[3080, 3128, 3176, 3224]],
-               [[4340, 4397, 4454, 4511]],
-               [[5816, 5882, 5948, 6014]]]]
+                       [[[2036, 2076, 2114, 2152]],
+                        [[3080, 3128, 3176, 3224]],
+                        [[4340, 4396, 4456, 4510]],
+                        [[5816, 5880, 5948, 6016]]]]).astype(np.float16)
     assert (output.asnumpy() == expect).all()