gpu: add combine momentum fusion

VectorSL 2020-09-24 11:16:29 +08:00
parent 2bac83ba1b
commit 8dca80036a
9 changed files with 525 additions and 18 deletions

View File

@@ -99,7 +99,7 @@ template <typename T, typename S>
__global__ void FusedMomentumScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation,
const T *learning_rate, const S *gradient, const T *momentum) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
-accumulation[i] = momentum[0] * accumulation[i] + static_cast<T>(gradient[i]);
+accumulation[i] = momentum[0] * accumulation[i] + static_cast<T>(gradient[i]) * scale[0];
variable[i] -= learning_rate[0] * accumulation[i];
}
}
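Note: the one-line change above folds the loss-scale factor into the gradient before it enters the momentum accumulation, so per element the fused update is now accumulation[i] = momentum * accumulation[i] + scale * gradient[i], followed by variable[i] -= learning_rate * accumulation[i].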
@@ -113,6 +113,56 @@ void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accu
element_num, scale, variable, accumulation, learning_rate, gradient, momentum);
}
// CombineFusedScaleMomentum
template <typename T, typename S>
__global__ void CombineFusedMomentumScaleMomentum(const size_t num, const size_t *element_num,
T **scale, T **variable, T **accumulation,
T **learning_rate, S **gradient, T **momentum) {
for (size_t idx = 0; idx < num; idx++) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num[idx]); i += blockDim.x * gridDim.x) {
accumulation[idx][i] = momentum[idx][0] * accumulation[idx][i] + static_cast<T>(gradient[idx][i]) * scale[idx][0];
variable[idx][i] -= learning_rate[idx][0] * accumulation[idx][i];
}
}
}
template <typename T, typename S>
void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, T **scale,
T **variable, T **accumulation, T **learning_rate, S **gradient,
T **momentum, cudaStream_t cuda_stream) {
size_t thread_per_block = 256;
size_t block_per_grid = (max + thread_per_block - 1) / thread_per_block;
CombineFusedMomentumScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
num, elements, scale, variable, accumulation, learning_rate, gradient, momentum);
}
// end CombineFusedScaleMomentum
// CombineFusedWeightDecayScaleMomentum
template <typename T, typename S>
__global__ void CombineFusedMomentumWeightDecayScaleMomentum(const size_t num, const size_t *element_num,
T **weight_decay, T **scale, T **variable,
T **accumulation, T **learning_rate, S **gradient,
T **momentum) {
for (size_t idx = 0; idx < num; idx++) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num[idx]); i += blockDim.x * gridDim.x) {
T grad = (variable[idx][i] * weight_decay[idx][0] + static_cast<T>(gradient[idx][i])) * scale[idx][0];
accumulation[idx][i] = momentum[idx][0] * accumulation[idx][i] + grad;
variable[idx][i] -= learning_rate[idx][0] * accumulation[idx][i];
}
}
}
template <typename T, typename S>
void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *element_num,
T **weight_decay, T **scale, T **variable, T **accumulation,
T **learning_rate, S **gradient, T **momentum,
cudaStream_t cuda_stream) {
size_t thread_per_block = 256;
size_t block_per_grid = (max + thread_per_block - 1) / thread_per_block;
CombineFusedMomentumWeightDecayScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
num, element_num, weight_decay, scale, variable, accumulation, learning_rate, gradient, momentum);
}
// end CombineFusedWeightDecayScaleMomentum
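Note: the weight-decay variant builds the effective gradient as (variable[i] * weight_decay + gradient[i]) * scale, i.e. L2 regularization is folded in before the loss-scale correction; the accumulation and variable updates then match the plain combined kernel above.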
template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation,
const float *learning_rate, const float *gradient,
const float *momentum, bool use_nesterov,
@@ -142,3 +192,17 @@ template void FusedScaleMomentum(const size_t element_num, float *scale, float *
template void FusedScaleMomentum(const size_t element_num, float *scale, float *variable, float *accumulation,
const float *learning_rate, const half *gradient, const float *momentum,
cudaStream_t cuda_stream);
template void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *elements,
float **weight_decay, float **scale, float **variable,
float **accumulation, float **learning_rate, float **gradient,
float **momentum, cudaStream_t cuda_stream);
template void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *elements,
float **weight_decay, float **scale, float **variable,
float **accumulation, float **learning_rate, half **gradient,
float **momentum, cudaStream_t cuda_stream);
template void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, float **scale,
float **variable, float **accumulation, float **learning_rate,
float **gradient, float **momentum, cudaStream_t cuda_stream);
template void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, float **scale,
float **variable, float **accumulation, float **learning_rate,
half **gradient, float **momentum, cudaStream_t cuda_stream);
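For intuition, here is a minimal single-threaded host reference of what CombineFusedScaleMomentum computes. This is an editorial sketch, not part of the diff; the function name is hypothetical, and the device kernel above additionally grid-strides each tensor:

#include <cstddef>

// Editorial reference (hypothetical helper, not in the commit): `num`
// independent momentum updates; tensor idx holds element_num[idx] elements,
// and scale / learning_rate / momentum are per-tensor scalars at index 0.
template <typename T, typename S>
void CombineFusedScaleMomentumRef(size_t num, const size_t *element_num, T **scale, T **variable,
T **accumulation, T **learning_rate, S **gradient, T **momentum) {
for (size_t idx = 0; idx < num; idx++) {
for (size_t i = 0; i < element_num[idx]; i++) {
accumulation[idx][i] = momentum[idx][0] * accumulation[idx][i] + static_cast<T>(gradient[idx][i]) * scale[idx][0];
variable[idx][i] -= learning_rate[idx][0] * accumulation[idx][i];
}
}
}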

View File

@@ -28,5 +28,12 @@ void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T
template <typename T, typename S>
void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate,
const S *gradient, const T *momentum, cudaStream_t cuda_stream);
template <typename T, typename S>
void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *element, T **weight_decay,
T **scale, T **variable, T **accumulation, T **learning_rate, S **gradient,
T **momentum, cudaStream_t cuda_stream);
template <typename T, typename S>
void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *element, T **scale, T **variable,
T **accumulation, T **learning_rate, S **gradient, T **momentum,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_

View File

@@ -40,22 +40,12 @@ void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const Kernel
std::vector<std::pair<KernelAttr, GpuKernelCreater>> *iter_second,
size_t attr_index) {
if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) {
-if (iter_second->at(attr_index).first.GetAllSame()) {
-auto dtype = iter_second->at(attr_index).first.GetInputAttr(0).first;
-for (size_t attr = 1; attr < kernel_info->GetInputNum(); ++attr) {
-(void)iter_second->at(attr_index).first.AddInputAttr(dtype);
-}
-} else {
+if (!iter_second->at(attr_index).first.GetAllSame()) {
MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!";
}
}
if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) {
-if (iter_second->at(attr_index).first.GetAllSame()) {
-auto dtype = iter_second->at(attr_index).first.GetOutputAttr(0).first;
-for (size_t attr = 1; attr < kernel_info->GetOutputNum(); ++attr) {
-(void)iter_second->at(attr_index).first.AddOutputAttr(dtype);
-}
-} else {
+if (!iter_second->at(attr_index).first.GetAllSame()) {
MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!";
}
}
@@ -99,6 +89,7 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string &
for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) {
CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index);
bool flag = true;
+auto attr_size = (&(iter->second))->at(attr_index).first.GetInputSize();
// data type matching check of all input parameters of kernel
for (size_t input_index = 0; input_index < kernel_info->GetInputNum(); input_index++) {
if (marjor_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) {
@@ -110,7 +101,7 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string &
<< ", but the current device's computing capacity is " << marjor_sm;
}
if (kernel_info->GetInputDeviceType(input_index) !=
-(iter->second)[attr_index].first.GetInputAttr(input_index).first) {
+(iter->second)[attr_index].first.GetInputAttr(input_index % attr_size).first) {
flag = false;
break;
}
@@ -118,10 +109,11 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string &
if (!flag) {
continue;
}
+attr_size = (&(iter->second))->at(attr_index).first.GetOutputSize();
// data type matching check of all output parameters of kernel
for (size_t output_index = 0; output_index < kernel_info->GetOutputNum(); output_index++) {
if (kernel_info->GetOutputDeviceType(output_index) !=
-(iter->second)[attr_index].first.GetOutputAttr(output_index).first) {
+(iter->second)[attr_index].first.GetOutputAttr(output_index % attr_size).first) {
flag = false;
break;
}
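The two `% attr_size` changes above make attribute matching cyclic: a kernel registered with AddAllSameAttr(true) and a single input pattern of k dtypes now matches a node carrying n * k inputs, without mutating the registered attr as the deleted branches used to. A standalone sketch of the resulting matching rule (editorial illustration, names hypothetical):

#include <cstddef>
#include <vector>

// Editorial illustration (hypothetical helper, not in the commit): returns
// true when `actual` (the node's input dtypes) equals the registered pattern
// repeated end to end, which is the effect of indexing with
// input_index % attr_size in the factory.
bool MatchesRepeatedPattern(const std::vector<int> &actual, const std::vector<int> &registered) {
if (registered.empty()) return false;
for (size_t i = 0; i < actual.size(); i++) {
if (actual[i] != registered[i % registered.size()]) return false;
}
return true;
}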

View File

@@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/combine_momentum_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(CombineMomentum,
KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // variable
.AddInputAttr(kNumberTypeFloat32) // accumulation
.AddInputAttr(kNumberTypeFloat32) // learning_rate
.AddInputAttr(kNumberTypeFloat32) // gradient
.AddInputAttr(kNumberTypeFloat32) // momentum
.AddOutputAttr(kNumberTypeFloat32),
CombineMomentumGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(CombineMomentum,
KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // variable
.AddInputAttr(kNumberTypeFloat32) // accumulation
.AddInputAttr(kNumberTypeFloat32) // learning_rate
.AddInputAttr(kNumberTypeFloat16) // gradient
.AddInputAttr(kNumberTypeFloat32) // momentum
.AddOutputAttr(kNumberTypeFloat32),
CombineMomentumGpuKernel, float, half)
MS_REG_GPU_KERNEL_TWO(CombineMomentumWeight,
KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeFloat32) // weight decay
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // variable
.AddInputAttr(kNumberTypeFloat32) // accumulation
.AddInputAttr(kNumberTypeFloat32) // learning_rate
.AddInputAttr(kNumberTypeFloat32) // gradient
.AddInputAttr(kNumberTypeFloat32) // momentum
.AddOutputAttr(kNumberTypeFloat32),
CombineMomentumGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(CombineMomentumWeight,
KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeFloat32) // weight decay
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // variable
.AddInputAttr(kNumberTypeFloat32) // accumulation
.AddInputAttr(kNumberTypeFloat32) // learning_rate
.AddInputAttr(kNumberTypeFloat16) // gradient
.AddInputAttr(kNumberTypeFloat32) // momentum
.AddOutputAttr(kNumberTypeFloat32),
CombineMomentumGpuKernel, float, half)
} // namespace kernel
} // namespace mindspore
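These registrations declare a single repetition of the per-tensor pattern and mark it AddAllSameAttr(true); combined with the cyclic dtype matching added in gpu_kernel_factory.cc above, one pattern of 6 dtypes (7 for the weight-decay variant) matches a fused node carrying n such input groups.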

View File

@@ -0,0 +1,207 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_COMBINE_MOMENTUM_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_COMBINE_MOMENTUM_GPU_KERNEL_H_
#include <vector>
#include <memory>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class CombineMomentumGpuKernel : public GpuKernel {
public:
CombineMomentumGpuKernel() : element_num_(1), num_(0), max_(0), input_num_(6) {}
~CombineMomentumGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &workspace, void *stream_ptr) override {
const cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
auto weight_decay = std::make_unique<T *[]>(input_num_ * num_);
auto scale = std::make_unique<T *[]>(input_num_ * num_);
auto variable = std::make_unique<T *[]>(input_num_ * num_);
auto accumulation = std::make_unique<T *[]>(input_num_ * num_);
auto learning_rate = std::make_unique<T *[]>(input_num_ * num_);
auto gradient = std::make_unique<S *[]>(input_num_ * num_);
auto momentum = std::make_unique<T *[]>(input_num_ * num_);
if (input_num_ == 6) {
LaunchCombineMom(inputs, workspace, stream, scale, variable, accumulation, learning_rate, gradient, momentum);
} else {
LaunchCombineMomWeightDecay(inputs, workspace, stream, weight_decay, scale, variable, accumulation, learning_rate,
gradient, momentum);
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
num_ = GetAttr<size_t>(kernel_node, "n");
elements_ = std::make_unique<size_t[]>(num_);
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == "CombineMomentum") {
input_num_ = 6;
} else {
input_num_ = 7;
workspace_size_list_.push_back(sizeof(T *) * num_);
}
for (size_t i = 0; i < num_; i++) {
element_num_ = 1;
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i * input_num_ + input_num_ - 4);
for (size_t j = 0; j < variable_shape.size(); j++) {
element_num_ *= variable_shape[j];
}
if (max_ < element_num_) {
max_ = element_num_;
}
elements_[i] = element_num_;
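// Append this tensor group's per-buffer sizes (InitSizeLists runs once per group).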
InitSizeLists();
}
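// Workspace layout: one device pointer table per operand (scale, variable,
// accumulation, learning_rate, gradient, momentum) plus the per-tensor
// element counts; the weight-decay variant already pushed its extra table
// above.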
workspace_size_list_.push_back(sizeof(T *) * num_);
workspace_size_list_.push_back(sizeof(T *) * num_);
workspace_size_list_.push_back(sizeof(T *) * num_);
workspace_size_list_.push_back(sizeof(T *) * num_);
workspace_size_list_.push_back(sizeof(S *) * num_);
workspace_size_list_.push_back(sizeof(T *) * num_);
workspace_size_list_.push_back(sizeof(size_t) * num_);
return true;
}
protected:
void InitSizeLists() override {
if (input_num_ == 7) {
input_size_list_.push_back(sizeof(T));
}
input_size_list_.push_back(sizeof(T));
input_size_list_.push_back(element_num_ * sizeof(T));
input_size_list_.push_back(element_num_ * sizeof(T));
input_size_list_.push_back(sizeof(T));
input_size_list_.push_back(element_num_ * sizeof(S));
input_size_list_.push_back(sizeof(T));
output_size_list_.push_back(element_num_ * sizeof(T));
}
private:
void LaunchCombineMom(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const cudaStream_t &stream, const std::unique_ptr<T *[]> &scale,
const std::unique_ptr<T *[]> &variable, const std::unique_ptr<T *[]> &accumulation,
const std::unique_ptr<T *[]> &learning_rate, const std::unique_ptr<S *[]> &gradient,
const std::unique_ptr<T *[]> &momentum) {
for (size_t i = 0; i < num_; i++) {
scale[i] = GetDeviceAddress<T>(inputs, i * input_num_);
variable[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 1);
accumulation[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 2);
learning_rate[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 3);
gradient[i] = GetDeviceAddress<S>(inputs, i * input_num_ + 4);
momentum[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 5);
}
T **scale_dev = GetDeviceAddress<T *>(workspace, 0);
T **variable_dev = GetDeviceAddress<T *>(workspace, 1);
T **accumulation_dev = GetDeviceAddress<T *>(workspace, 2);
T **learning_rate_dev = GetDeviceAddress<T *>(workspace, 3);
S **gradient_dev = GetDeviceAddress<S *>(workspace, 4);
T **momentum_dev = GetDeviceAddress<T *>(workspace, 5);
size_t *elements_dev = GetDeviceAddress<size_t>(workspace, 6);
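// Stage the host-side pointer tables into workspace. The copies and the
// kernel launch below are issued on the same stream, so stream ordering
// makes an explicit synchronization unnecessary.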
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(scale_dev, scale.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), "cudaMemcpyAsync failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(variable_dev, variable.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemcpyAsync failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(accumulation_dev, accumulation.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemcpyAsync failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(learning_rate_dev, learning_rate.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemcpyAsync failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(gradient_dev, gradient.get(), sizeof(S *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemcpyAsync failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(momentum_dev, momentum.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemcpyAsync failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(elements_dev, elements_.get(), sizeof(size_t) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemcpyAsync failed")
CombineFusedScaleMomentum(max_, num_, elements_dev, scale_dev, variable_dev, accumulation_dev, learning_rate_dev,
gradient_dev, momentum_dev, stream);
}
void LaunchCombineMomWeightDecay(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const cudaStream_t &stream, const std::unique_ptr<T *[]> &weight_decay,
const std::unique_ptr<T *[]> &scale, const std::unique_ptr<T *[]> &variable,
const std::unique_ptr<T *[]> &accumulation,
const std::unique_ptr<T *[]> &learning_rate, const std::unique_ptr<S *[]> &gradient,
const std::unique_ptr<T *[]> &momentum) {
for (size_t i = 0; i < num_; i++) {
weight_decay[i] = GetDeviceAddress<T>(inputs, i * input_num_);
scale[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 1);
variable[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 2);
accumulation[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 3);
learning_rate[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 4);
gradient[i] = GetDeviceAddress<S>(inputs, i * input_num_ + 5);
momentum[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 6);
}
T **weight_decay_dev = GetDeviceAddress<T *>(workspace, 0);
T **scale_dev = GetDeviceAddress<T *>(workspace, 1);
T **variable_dev = GetDeviceAddress<T *>(workspace, 2);
T **accumulation_dev = GetDeviceAddress<T *>(workspace, 3);
T **learning_rate_dev = GetDeviceAddress<T *>(workspace, 4);
S **gradient_dev = GetDeviceAddress<S *>(workspace, 5);
T **momentum_dev = GetDeviceAddress<T *>(workspace, 6);
size_t *elements_dev = GetDeviceAddress<size_t>(workspace, 7);
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(weight_decay_dev, weight_decay.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemcpyAsync failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(scale_dev, scale.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), "cudaMemcpyAsync failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(variable_dev, variable.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemcpyAsync failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(accumulation_dev, accumulation.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemcpyAsync failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(learning_rate_dev, learning_rate.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemcpyAsync failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(gradient_dev, gradient.get(), sizeof(S *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemcpyAsync failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(momentum_dev, momentum.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemcpyAsync failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(elements_dev, elements_.get(), sizeof(size_t) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemcpyAsync failed")
CombineFusedWeightDecayScaleMomentum(max_, num_, elements_dev, weight_decay_dev, scale_dev, variable_dev,
accumulation_dev, learning_rate_dev, gradient_dev, momentum_dev, stream);
}
size_t element_num_;
std::unique_ptr<size_t[]> elements_;
size_t num_;
size_t max_;
int input_num_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_COMBINE_MOMENTUM_GPU_KERNEL_H_

View File

@@ -52,7 +52,7 @@ class FusedScaleMomentumGpuKernel : public GpuKernel {
return false;
}
-auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
for (size_t i = 0; i < variable_shape.size(); i++) {
element_num_ *= variable_shape[i];
}

View File

@@ -53,7 +53,7 @@ class FusedWeightDecayScaleMomentumGpuKernel : public GpuKernel {
return false;
}
-auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
for (size_t i = 0; i < variable_shape.size(); i++) {
element_num_ *= variable_shape[i];
}
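Note: both shape fixes above follow the fused ops' input layout: scale occupies input 0 of FusedScaleApplyMomentum, so variable is input 1; weight_decay and scale precede it in FusedWeightScaleApplyMomentum, so variable is input 2.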

View File

@@ -0,0 +1,133 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/combine_momentum_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
std::vector<std::string> inputs_device_format;
std::vector<std::string> outputs_device_format;
std::vector<TypeId> inputs_device_type;
std::vector<TypeId> outputs_device_type;
std::vector<std::vector<size_t>> outputs_shape;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t idx = 0; idx < node_list.size(); ++idx) {
auto cnode = utils::cast<CNodePtr>(node_list[idx]);
MS_EXCEPTION_IF_NULL(cnode);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
inputs_device_format.push_back(kOpFormat_DEFAULT);
inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
outputs_device_format.push_back(kOpFormat_DEFAULT);
outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
}
}
builder.SetInputsFormat(inputs_device_format);
builder.SetOutputsFormat(outputs_device_format);
builder.SetInputsDeviceType(inputs_device_type);
builder.SetOutputsDeviceType(outputs_device_type);
return builder.Build();
}
bool GetDealList(const std::vector<AnfNodePtr> &node_list, std::vector<std::vector<AnfNodePtr>> *deal_list) {
std::vector<AnfNodePtr> momentum;
std::vector<AnfNodePtr> momentum_decay;
for (auto &momentum_node : node_list) {
if (momentum_node != nullptr && momentum_node->isa<CNode>()) {
if (AnfAlgo::GetCNodeName(momentum_node) == kFusedScaleApplyMomentum) {
momentum.push_back(momentum_node);
} else if (AnfAlgo::GetCNodeName(momentum_node) == kFusedWeightScaleApplyMomentum) {
momentum_decay.push_back(momentum_node);
}
}
}
if (momentum.size() <= 1 && momentum_decay.size() <= 1) {
return false;
}
if (momentum.size() > 1) {
deal_list->push_back(momentum);
}
if (momentum_decay.size() > 1) {
deal_list->push_back(momentum_decay);
}
return true;
}
bool CombineMomentumFusion::Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
// 1. collect the fusable FusedScaleApplyMomentum / FusedWeightScaleApplyMomentum nodes
std::vector<std::vector<AnfNodePtr>> deal_list;
if (!GetDealList(node_list, &deal_list)) {
return false;
}
for (auto momentums : deal_list) {
// 2. create the combined momentum node
std::vector<AnfNodePtr> inputs = {};
if (AnfAlgo::GetCNodeName(momentums[0]) == kFusedScaleApplyMomentum) {
auto prim = std::make_shared<Primitive>("CombineMomentum");
MS_EXCEPTION_IF_NULL(prim);
inputs.push_back(NewValueNode(prim));
} else {
auto prim = std::make_shared<Primitive>("CombineMomentumWeight");
MS_EXCEPTION_IF_NULL(prim);
inputs.push_back(NewValueNode(prim));
}
// set inputs for momentum
size_t input_num = AnfAlgo::GetInputTensorNum(momentums[0]);
for (auto mom : momentums) {
for (size_t i = 0; i < input_num; i++) {
inputs.push_back(AnfAlgo::GetInputNode(utils::cast<CNodePtr>(mom), i));
}
}
auto combine_mom = graph->NewCNode(inputs);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
combine_mom->set_kernel_info(kernel_info);
AbstractBasePtrList abstract_list;
for (size_t idx = 0; idx < momentums.size(); ++idx) {
auto cnode = utils::cast<CNodePtr>(momentums[idx]);
MS_EXCEPTION_IF_NULL(cnode);
abstract_list.push_back(cnode->abstract());
}
auto kernel_build_info = GenerateKernelBuildInfo(momentums);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, combine_mom.get());
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
MS_EXCEPTION_IF_NULL(abstract_tuple);
combine_mom->set_abstract(abstract_tuple);
AnfAlgo::SetNodeAttr("n", MakeValue(momentums.size()), combine_mom);
// 3. replace the original momentum nodes with the combined node
for (size_t idx = 0; idx < momentums.size(); ++idx) {
if (!manager->Replace(momentums[idx], combine_mom)) {
MS_LOG(EXCEPTION) << "manager replace node failed";
}
}
}
return true;
}
} // namespace opt
} // namespace mindspore
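Net effect of the pass, sketched for a hypothetical graph with three FusedScaleApplyMomentum nodes (editorial illustration, names invented):

// Before: three independent fused launches, 6 inputs each.
//   m0 = FusedScaleApplyMomentum(s0, v0, a0, lr0, g0, mom0)
//   m1 = FusedScaleApplyMomentum(s1, v1, a1, lr1, g1, mom1)
//   m2 = FusedScaleApplyMomentum(s2, v2, a2, lr2, g2, mom2)
// After: one CombineMomentum node with attr n=3 and the 18 inputs
// concatenated in order; every use of m0/m1/m2 is re-pointed at the
// combined node via manager->Replace, so a single kernel launch updates
// all three parameter tensors.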

View File

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_COMBINE_MOMENTUM_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_COMBINE_MOMENTUM_FUSION_H_
#include <memory>
#include <string>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class CombineMomentumFusion : public Pass {
public:
explicit CombineMomentumFusion(const std::string &name) : Pass("combine_momentum") {}
~CombineMomentumFusion() override = default;
bool Run(const FuncGraphPtr &graph) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_COMBINE_MOMENTUM_FUSION_H_