Gpu Adam Fusion

This commit is contained in:
wilfChen 2020-06-23 20:40:45 +08:00
parent 0478b7d191
commit 034d2ea2aa
14 changed files with 693 additions and 10 deletions

View File

@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "adam_weight_decay_impl.cuh"
#include "device/gpu/cuda_common.h"
template <typename T>
__global__ void AdamWeightDecayKernel(const int element_num_, const bool need_decay, const float *beta1,
const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2,
const float *epsilon, const float *lr, const float *weight_decay, T *m, T *v,
T *param, T *gradient) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < element_num_; i += blockDim.x * gridDim.x) {
float next_m = beta1[0] * m[i] + one_sub_beta1[0] * gradient[i];
float next_v = beta2[0] * v[i] + one_sub_beta2[0] * gradient[i] * gradient[i];
float update = next_m / (sqrt(next_v) + epsilon[0]);
if (need_decay && weight_decay != nullptr) {
update += weight_decay[0] * param[i];
}
param[i] -= lr[0] * update;
m[i] = next_m;
v[i] = next_v;
}
}
template <typename T>
void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1,
const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr,
const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream) {
AdamWeightDecayKernel<<<GET_BLOCKS(element_num_), GET_THREADS, 0, stream>>>(
element_num_, need_decay, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v, param,
gradient);
}
template void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1,
const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2,
const float *epsilon, const float *lr, const float *weight_decay, float *m, float *v,
float *param, float *gradient, cudaStream_t stream);

View File

@ -0,0 +1,24 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_
template <typename T>
void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1,
const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr,
const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_

View File

@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/nn/fused_adam_weight_decay.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(FusedAdamWeightDecay,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedAdamWeightDecayGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FusedAdam,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedAdamWeightDecayGpuKernel, float)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,103 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/kernel_constants.h"
#include "kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class FusedAdamWeightDecayGpuKernel : public GpuKernel {
public:
FusedAdamWeightDecayGpuKernel() : element_nums_(0), weight_decay_(false) {}
~FusedAdamWeightDecayGpuKernel() override = default;
bool Init(const CNodePtr &kernel_node) override {
auto node_name = AnfAlgo::GetCNodeName(kernel_node);
if (node_name == "AdamWeighDecay") {
weight_decay_ = true;
}
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 7);
element_nums_ = 1;
for (auto i : shape) {
element_nums_ *= i;
}
InitSizeLists();
return true;
}
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
float *beta1 = GetDeviceAddress<float>(inputs, 0);
float *one_sub_beta1 = GetDeviceAddress<float>(inputs, 1);
float *beta2 = GetDeviceAddress<float>(inputs, 2);
float *one_sub_beta2 = GetDeviceAddress<float>(inputs, 3);
float *epsilon = GetDeviceAddress<float>(inputs, 4);
float *lr = GetDeviceAddress<float>(inputs, 5);
T *param = GetDeviceAddress<T>(inputs, 6);
T *m = GetDeviceAddress<T>(inputs, 7);
T *v = GetDeviceAddress<T>(inputs, 8);
T *gradient = GetDeviceAddress<T>(inputs, 9);
float *weight_decay = nullptr;
if (weight_decay_) {
weight_decay = GetDeviceAddress<float>(inputs, 10);
}
AdamWeightDecay(element_nums_, true, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v,
param, gradient, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
protected:
void InitResource() override{};
void InitSizeLists() override {
input_size_list_.push_back(sizeof(float));
input_size_list_.push_back(sizeof(float));
input_size_list_.push_back(sizeof(float));
input_size_list_.push_back(sizeof(float));
input_size_list_.push_back(element_nums_ * sizeof(T));
input_size_list_.push_back(sizeof(float));
input_size_list_.push_back(sizeof(float));
input_size_list_.push_back(element_nums_ * sizeof(T));
if (weight_decay_) {
input_size_list_.push_back(sizeof(float));
}
output_size_list_.push_back(element_nums_ * sizeof(T));
}
private:
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int element_nums_;
bool weight_decay_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_

View File

@ -182,9 +182,11 @@ extern const PrimitivePtr kPrimReduceMin;
extern const PrimitivePtr kPrimNeg;
extern const PrimitivePtr kPrimSub;
extern const PrimitivePtr kPrimMul;
extern const PrimitivePtr kPrimRealDiv;
extern const PrimitivePtr kPrimMinimum;
extern const PrimitivePtr kPrimMaximum;
extern const PrimitivePtr kPrimSquare;
extern const PrimitivePtr kPrimSqrt;
extern const PrimitivePtr kPrimEqual;
extern const PrimitivePtr kPrimLess;
extern const PrimitivePtr kPrimLessEqual;

View File

@ -0,0 +1,112 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/gpu/adam_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "pre_activate/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
std::vector<std::string> inputs_format;
std::vector<std::string> outputs_format;
std::vector<TypeId> inputs_type;
std::vector<TypeId> outputs_type;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
inputs_format.push_back(kOpFormat_DEFAULT);
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
outputs_format.push_back(kOpFormat_DEFAULT);
}
builder.SetInputsDeviceType(inputs_type);
builder.SetInputsFormat(inputs_format);
builder.SetOutputsDeviceType(outputs_type);
builder.SetOutputsFormat(outputs_format);
return builder.Build();
}
} // namespace
const BaseRef AdamFusion::DefinePattern() const {
VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}),
VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
VectorRef next_v =
VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
VectorRef update = VectorRef(
{prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})});
VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})});
VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})});
return depend3;
}
const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto beta1_input = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
auto one_sub_beta1_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta1_]);
auto beta2_input = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
auto one_sub_beta2_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta2_]);
auto eps_input = utils::cast<AnfNodePtr>((*equiv)[eps_]);
auto lr_input = utils::cast<AnfNodePtr>((*equiv)[lr_]);
auto param_input = utils::cast<AnfNodePtr>((*equiv)[param_]);
auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
MS_EXCEPTION_IF_NULL(beta1_input);
MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
MS_EXCEPTION_IF_NULL(beta2_input);
MS_EXCEPTION_IF_NULL(one_sub_beta2_input);
MS_EXCEPTION_IF_NULL(eps_input);
MS_EXCEPTION_IF_NULL(lr_input);
MS_EXCEPTION_IF_NULL(param_input);
MS_EXCEPTION_IF_NULL(m_input);
MS_EXCEPTION_IF_NULL(v_input);
MS_EXCEPTION_IF_NULL(gradient_input);
auto prim = std::make_shared<Primitive>(kFusedAdamName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {
NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
eps_input, lr_input, param_input, m_input, v_input,
gradient_input};
auto adam = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(adam);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get());
adam->set_scope(node->scope());
auto build_info = GenerateKernelBuildInfo(adam);
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get());
return adam;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_
#include <memory>
#include "pre_activate/common/optimizer.h"
namespace mindspore {
namespace opt {
class AdamFusion : public PatternProcessPass {
public:
explicit AdamFusion(bool multigraph = true) : PatternProcessPass("adam_fusion", multigraph) {
beta1_ = std::make_shared<Var>();
one_sub_beta1_ = std::make_shared<Var>();
beta2_ = std::make_shared<Var>();
one_sub_beta2_ = std::make_shared<Var>();
eps_ = std::make_shared<Var>();
lr_ = std::make_shared<Var>();
param_ = std::make_shared<Var>();
m_ = std::make_shared<Var>();
v_ = std::make_shared<Var>();
gradient_ = std::make_shared<Var>();
}
~AdamFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
VarPtr beta1_;
VarPtr one_sub_beta1_;
VarPtr beta2_;
VarPtr one_sub_beta2_;
VarPtr eps_;
VarPtr lr_;
VarPtr param_;
VarPtr m_;
VarPtr v_;
VarPtr gradient_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_

View File

@ -0,0 +1,117 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/gpu/adam_weight_decay_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "pre_activate/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
std::vector<std::string> inputs_format;
std::vector<std::string> outputs_format;
std::vector<TypeId> inputs_type;
std::vector<TypeId> outputs_type;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
inputs_format.push_back(kOpFormat_DEFAULT);
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
outputs_format.push_back(kOpFormat_DEFAULT);
}
builder.SetInputsDeviceType(inputs_type);
builder.SetInputsFormat(inputs_format);
builder.SetOutputsDeviceType(outputs_type);
builder.SetOutputsFormat(outputs_format);
return builder.Build();
}
} // namespace
const BaseRef AdamWeightDecayFusion::DefinePattern() const {
VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}),
VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
VectorRef next_v =
VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
VectorRef update = VectorRef(
{prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
VectorRef new_update = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update});
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update});
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})});
VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})});
VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})});
return depend3;
}
const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto beta1_input = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
auto one_sub_beta1_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta1_]);
auto beta2_input = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
auto one_sub_beta2_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta2_]);
auto eps_input = utils::cast<AnfNodePtr>((*equiv)[eps_]);
auto lr_input = utils::cast<AnfNodePtr>((*equiv)[lr_]);
auto weight_decay_input = utils::cast<AnfNodePtr>((*equiv)[weight_decay_]);
auto param_input = utils::cast<AnfNodePtr>((*equiv)[param_]);
auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
MS_EXCEPTION_IF_NULL(beta1_input);
MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
MS_EXCEPTION_IF_NULL(beta2_input);
MS_EXCEPTION_IF_NULL(one_sub_beta2_input);
MS_EXCEPTION_IF_NULL(eps_input);
MS_EXCEPTION_IF_NULL(lr_input);
MS_EXCEPTION_IF_NULL(weight_decay_input);
MS_EXCEPTION_IF_NULL(param_input);
MS_EXCEPTION_IF_NULL(m_input);
MS_EXCEPTION_IF_NULL(v_input);
MS_EXCEPTION_IF_NULL(gradient_input);
auto prim = std::make_shared<Primitive>(kFusedAdamWeightDecayName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {
NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
eps_input, lr_input, param_input, m_input, v_input,
gradient_input, weight_decay_input};
auto adam_weight_decay = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(adam_weight_decay);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam_weight_decay.get());
adam_weight_decay->set_scope(node->scope());
auto build_info = GenerateKernelBuildInfo(adam_weight_decay);
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get());
return adam_weight_decay;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_
#include <memory>
#include "pre_activate/common/optimizer.h"
namespace mindspore {
namespace opt {
class AdamWeightDecayFusion : public PatternProcessPass {
public:
explicit AdamWeightDecayFusion(bool multigraph = true) : PatternProcessPass("adam_weight_decay_fusion", multigraph) {
beta1_ = std::make_shared<Var>();
one_sub_beta1_ = std::make_shared<Var>();
beta2_ = std::make_shared<Var>();
one_sub_beta2_ = std::make_shared<Var>();
eps_ = std::make_shared<Var>();
lr_ = std::make_shared<Var>();
weight_decay_ = std::make_shared<Var>();
param_ = std::make_shared<Var>();
m_ = std::make_shared<Var>();
v_ = std::make_shared<Var>();
gradient_ = std::make_shared<Var>();
}
~AdamWeightDecayFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
VarPtr beta1_;
VarPtr one_sub_beta1_;
VarPtr beta2_;
VarPtr one_sub_beta2_;
VarPtr eps_;
VarPtr lr_;
VarPtr weight_decay_;
VarPtr param_;
VarPtr m_;
VarPtr v_;
VarPtr gradient_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_

View File

@ -23,6 +23,8 @@
#include "pre_activate/common/helper.h"
#include "pre_activate/pass/communication_op_fusion.h"
#include "pre_activate/pass/getitem_tuple.h"
#include "pre_activate/gpu/adam_weight_decay_fusion.h"
#include "pre_activate/gpu/adam_fusion.h"
#include "device/kernel_runtime_manager.h"
#include "predict/predict.h"
#include "common/utils.h"
@ -53,6 +55,16 @@ void GPUSession::StartKernelRT() const {
void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
pm->AddPass(std::make_shared<opt::AdamFusion>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
}
void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AllReduceFusion>());
@ -151,14 +163,16 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
auto graph_id = graph_sum_;
auto graph = ConstructKernelGraph(lst, outputs);
MS_EXCEPTION_IF_NULL(graph);
// Optimize
Optimize(graph);
// Select kernel build info
SelectKernel(graph);
// Convert kernel Graph to model
predictmodel::StepConvertGraph(graph);
// Start gpu kernel runtime
StartKernelRT();
// AllReduce Optimize
Optimize(graph);
// HardwareOptimize
HardwareOptimize(graph);
// Assign CUDA streams
AssignStream(graph);
// Hide NoOp from execution graph

View File

@ -51,6 +51,8 @@ class GPUSession : public SessionBasic {
void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph);
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;

View File

@ -161,6 +161,8 @@ constexpr auto kNMSWithMaskOpName = "NMSWithMask";
constexpr auto kSoftmaxGradExtOpName = "SoftmaxGradExt";
constexpr auto kStridedReadOpName = "StridedRead";
constexpr auto kStridedWriteOpName = "StridedWrite";
constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
constexpr auto kFusedAdamName = "FusedAdam";
// attr key name
constexpr auto kAttrInputNames = "input_names";

View File

@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
class Net(nn.Cell):
def __init__(self, decay_flag=True):
super(Net, self).__init__()
self.decay_flag = decay_flag
self.op_mul = P.Mul()
self.op_square = P.Square()
self.op_sqrt = P.Sqrt()
self.op_cast = P.Cast()
self.op_reshape = P.Reshape()
self.op_shape = P.Shape()
self.param = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='param')
self.m = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='m')
self.v = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='v')
@ms_function
def construct(self, beta1, beta2, gradient, eps, weight_decay_tensor, lr):
param_fp32 = self.op_cast(self.param, mstype.float32)
m_fp32 = self.op_cast(self.m, mstype.float32)
v_fp32 = self.op_cast(self.v, mstype.float32)
gradient_fp32 = self.op_cast(gradient, mstype.float32)
next_m = self.op_mul(beta1, m_fp32) + \
self.op_mul(self.op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
next_v = self.op_mul(beta2, v_fp32) + self.op_mul(self.op_cast(F.tuple_to_array((1.0,)), mstype.float32) - \
beta2, self.op_square(gradient_fp32))
update = next_m / (eps + self.op_sqrt(next_v))
if self.decay_flag:
update = self.op_mul(weight_decay_tensor, param_fp32) + update
update_with_lr = self.op_mul(lr, update)
next_param = param_fp32 - self.op_reshape(update_with_lr, self.op_shape(param_fp32))
next_v = F.depend(next_v, F.assign(self.param, next_param))
next_v = F.depend(next_v, F.assign(self.m, next_m))
next_v = F.depend(next_v, F.assign(self.v, next_v))
return next_v
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test():
beta1 = Tensor(np.array([0.9]).astype(np.float32))
beta2 = Tensor(np.array([0.999]).astype(np.float32))
lr = Tensor(np.array([0.001]).astype(np.float32))
eps = Tensor(np.array([1e-6]).astype(np.float32))
weight_decay_tensor = Tensor(np.array([0.001]).astype(np.float32))
gradient = Tensor(np.array([0.01, 0.03, 0.05]).astype(np.float32))
opt = Net(True)
_ = opt(beta1, beta2, gradient, eps, weight_decay_tensor, lr)
param_expect = np.array([0.09971199, 0.29950103, 0.4993557]).astype(np.float32)
m_expect = np.array([0.091, 0.273, 0.45499998]).astype(np.float32)
v_expect = np.array([0.0999001, 0.29970092, 0.4995025]).astype(np.float32)
assert np.allclose(opt.param.data.asnumpy(), param_expect)
assert np.allclose(opt.m.data.asnumpy(), m_expect)
assert np.allclose(opt.v.data.asnumpy(), v_expect)

View File

@ -119,6 +119,10 @@ def test_4d_transpose_ab():
[[5612, 5810, 6008, 6206]]]]
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_4D_fp16():
input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float16)
input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float16)
@ -126,13 +130,13 @@ def test_4D_fp16():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = BatchMatMulNet()
output = net(input_x, input_y)
expect = [[[[20, 23, 26, 29]],
[[200, 212, 224, 236]],
[[596, 617, 638, 659]],
[[1208, 1238, 1268, 1298]]],
expect = np.array([[[[20, 23, 26, 29]],
[[200, 212, 224, 236]],
[[596, 617, 638, 659]],
[[1208, 1238, 1268, 1298]]],
[[[2036, 2075, 2114, 2153]],
[[3080, 3128, 3176, 3224]],
[[4340, 4397, 4454, 4511]],
[[5816, 5882, 5948, 6014]]]]
[[[2036, 2076, 2114, 2152]],
[[3080, 3128, 3176, 3224]],
[[4340, 4396, 4456, 4510]],
[[5816, 5880, 5948, 6016]]]]).astype(np.float16)
assert (output.asnumpy() == expect).all()