debug_sync_batch_norm

zong-shuai 2022-04-16 16:39:05 +08:00
parent cc134689a1
commit 6b7464cd60
6 changed files with 0 additions and 852 deletions

plugin/device/gpu/kernel/nccl/sync_batch_norm_gpu_kernel.cc

@@ -1,126 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/nccl/sync_batch_norm_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
SyncBatchNormGpuKernel, float, float, float)
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
SyncBatchNormGpuKernel, half, float, float)
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
SyncBatchNormGpuKernel, float, half, float)
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
SyncBatchNormGpuKernel, half, half, float)
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
SyncBatchNormGpuKernel, float, float, half)
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
SyncBatchNormGpuKernel, half, float, half)
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
SyncBatchNormGpuKernel, float, half, half)
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
SyncBatchNormGpuKernel, half, half, half)
} // namespace kernel
} // namespace mindspore
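
Taken together, the eight registrations above enumerate every float32/float16 combination of the three template parameters T (x and y), S (scale and bias), and G (running mean and variance inputs). A small Python sketch of the pattern, inferred from the attribute lists above rather than from the kernel sources:

from itertools import product

# Inputs are (x, scale, bias, running_mean, running_var); the outputs mirror
# the inputs except that the updated running statistics are emitted as T.
for T, S, G in product(["float32", "float16"], repeat=3):
    inputs = (T, S, S, G, G)
    outputs = (T, S, S, T, T)
    print(f"SyncBatchNorm<{T}, {S}, {G}>: in={inputs} out={outputs}")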

plugin/device/gpu/kernel/nccl/sync_batch_norm_gpu_kernel.h

@@ -1,232 +0,0 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GPU_KERNEL_H_
#include <dlfcn.h>
#include <stdint.h>
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/nccl/nccl_gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "include/common/utils/utils.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sync_batch_norm_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S, typename G>
class SyncBatchNormGpuKernel : public NcclGpuKernelMod {
public:
SyncBatchNormGpuKernel() { ResetResource(); }
~SyncBatchNormGpuKernel() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *x = GetDeviceAddress<T>(inputs, 0);
S *scale = GetDeviceAddress<S>(inputs, 1);
S *bias = GetDeviceAddress<S>(inputs, 2);
G *running_mean_input = GetDeviceAddress<G>(inputs, 3);
G *running_variance_input = GetDeviceAddress<G>(inputs, 4);
float *means_local = GetDeviceAddress<float>(workspace, 0); // per device
float *invstds_local = GetDeviceAddress<float>(workspace, 1);
int *counts_local = GetDeviceAddress<int>(workspace, 2);
int *counts_global = GetDeviceAddress<int>(workspace, 3); // gathered values from all devices
float *means_global = GetDeviceAddress<float>(workspace, 4);
float *invstds_global = GetDeviceAddress<float>(workspace, 5);
T *y = GetDeviceAddress<T>(outputs, 0);
S *output_scale = GetDeviceAddress<S>(outputs, 1);
S *output_bias = GetDeviceAddress<S>(outputs, 2);
T *output_running_mean = GetDeviceAddress<T>(outputs, 3);
T *output_running_variance = GetDeviceAddress<T>(outputs, 4);
// aggregate means and invstd on each device locally
CalSyncBatchNormPre(N_, C_, H_, W_, x, counts_local, means_local, invstds_local, epsilon_,
reinterpret_cast<cudaStream_t>(stream_ptr));
// gather values from all devices together
LaunchAllGather(means_local, means_global, stream_ptr);
LaunchAllGather(invstds_local, invstds_global, stream_ptr);
LaunchAllGather(counts_local, counts_global, stream_ptr);
// reduce the gathered values on each device and update the running mean and variance
CalSyncBatchNormGather(N_, C_, H_, W_, counts_global, means_global, invstds_global, counts_local, means_local,
invstds_local, output_running_mean, output_running_variance, running_mean_input,
running_variance_input, epsilon_, momentum_, group_rank_, group_size_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalSyncBatchNormPost(N_, C_, H_, W_, x, y, means_local, invstds_local, scale, bias, output_scale, output_bias,
epsilon_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
auto root_rank = prim->GetAttr(kAttrRootRank);
kernel_node_ = kernel_node;
if (root_rank) {
root_ = static_cast<int>(GetValue<int64_t>(root_rank));
}
nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 5, but got " << input_num;
}
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 5) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 5, but got " << output_num;
}
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
if (is_null_input_) {
InitSizeLists();
return true;
}
auto input_shape_dims = input_shape.size();
if (input_shape_dims != 4 && input_shape_dims != 2) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input should be 2 or 4, but got "
<< input_shape_dims;
}
input_size_ = 1;
for (auto dim : input_shape) {
input_size_ *= dim;
}
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
momentum_ = GetAttr<float>(kernel_node, "momentum");
output_size_ = input_size_;
output_size_ = output_size_ * sizeof(T);
input_size_ = input_size_ * sizeof(T);
param_count_ = input_shape[1]; // C is number of features
param_size_S_ = param_count_ * sizeof(S); // scale/bias use the second template type, S
param_size_G_input_ = param_count_ * sizeof(G); // running mean/variance inputs
param_size_G_output_ = param_count_ * sizeof(T); // updated running mean/variance are emitted as T
workspace_size_ = param_count_; // element count per buffer; byte sizes are set in InitSizeLists()
N_ = input_shape[0];
C_ = input_shape[1];
if (input_shape_dims == 2) {
// NC -> N,C,1,1 transform input dims
H_ = 1;
W_ = 1;
} else {
H_ = input_shape[2];
W_ = input_shape[3];
}
// MULTI DEVICE SPECIFICS
group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup);
MS_LOG(INFO) << common::AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_;
auto comm_stream_attr = prim->GetAttr("stream_id");
if (comm_stream_attr) {
comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr));
MS_EXCEPTION_IF_NULL(comm_stream_);
}
SelectCollectiveHandle();
// Get group size
group_size_ = device::gpu::CollectiveInitializer::instance().GetGroupSize(group_name_);
// Get device rank ID in group
group_rank_ = device::gpu::CollectiveInitializer::instance().local_rank_id();
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_size_ = 0;
output_size_ = 0;
workspace_size_ = 0;
momentum_ = 0;
epsilon_ = 10e-5;
param_size_S_ = 0;
param_size_G_input_ = 0;
param_size_G_output_ = 0;
param_count_ = 0;
N_ = 0;
C_ = 0;
H_ = 0;
W_ = 0;
root_ = 0;
collective_handle_ = nullptr;
comm_stream_ = nullptr;
nccl_reduce_type_ = ncclSum;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
is_null_input_ = false;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_); // input x
input_size_list_.push_back(param_size_S_); // scale
input_size_list_.push_back(param_size_S_); // bias
input_size_list_.push_back(param_size_G_input_); // running mean
input_size_list_.push_back(param_size_G_input_); // running variance
output_size_list_.push_back(output_size_); // y
output_size_list_.push_back(param_size_S_); // output scale
output_size_list_.push_back(param_size_S_); // output bias
output_size_list_.push_back(param_size_G_output_); // updated running mean
output_size_list_.push_back(param_size_G_output_); // updated running variance
// local mean/variance data - per device
workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // mean_local
workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // invstd_local
workspace_size_list_.push_back(workspace_size_ * sizeof(int)); // count_local
// global mean/variance data - for all devices
workspace_size_list_.push_back(workspace_size_ * sizeof(int) * group_size_); // gathered counts
workspace_size_list_.push_back(workspace_size_ * sizeof(float) * group_size_); // gathered means
workspace_size_list_.push_back(workspace_size_ * sizeof(float) * group_size_); // gathered invstds
}
private:
// GetTypeID functions return the correct typeID for input template
// Allow for a single templated LaunchAllGather function
mindspore::TypeId GetTypeID(float *input) { return kNumberTypeFloat32; }
mindspore::TypeId GetTypeID(int *input) { return kNumberTypeInt32; }
template <typename gather_type>
void LaunchAllGather(gather_type *input_addr, gather_type *output_addr, void *stream_ptr) {
cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
(void)AllGather(input_addr, output_addr, C_, nccl_dtype(GetTypeID(input_addr)), stream, group_name_);
}
size_t input_size_;
size_t output_size_;
size_t workspace_size_;
float momentum_;
float epsilon_;
size_t param_size_S_;
size_t param_size_G_input_;
size_t param_size_G_output_;
size_t param_count_;
size_t N_;
size_t C_;
size_t H_;
size_t W_;
size_t group_size_;
size_t group_rank_;
ncclRedOp_t nccl_reduce_type_;
// NCCL
string group_name_;
int root_;
cudaStream_t comm_stream_;
bool is_null_input_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GPU_KERNEL_H_
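
For orientation, a minimal NumPy sketch of the data flow that Launch drives above: per-device channel statistics (CalSyncBatchNormPre), an all-gather of those statistics, a count-weighted combination (CalSyncBatchNormGather), and normalization with the combined values (CalSyncBatchNormPost). The arithmetic below is an assumption based on the buffer names; the exact formulas live in sync_batch_norm_impl.cuh.

import numpy as np

def sync_bn_forward_sketch(x_per_rank, gamma, beta, eps=1e-5):
    # CalSyncBatchNormPre: per-rank, per-channel count/mean/variance.
    counts, means, variances = [], [], []
    for x in x_per_rank:  # each x is NCHW
        counts.append(x.shape[0] * x.shape[2] * x.shape[3])
        means.append(x.mean(axis=(0, 2, 3)))
        variances.append(x.var(axis=(0, 2, 3)))
    counts = np.asarray(counts)
    means, variances = np.stack(means), np.stack(variances)
    # LaunchAllGather: after this step every rank holds all rows above.
    # CalSyncBatchNormGather: count-weighted global mean and variance.
    n = counts.sum()
    g_mean = (counts[:, None] * means).sum(axis=0) / n
    g_var = (counts[:, None] * (variances + means**2)).sum(axis=0) / n - g_mean**2
    invstd = 1.0 / np.sqrt(g_var + eps)
    # CalSyncBatchNormPost: normalize every rank's slice with the global stats.
    mean_b, invstd_b = g_mean[None, :, None, None], invstd[None, :, None, None]
    gamma_b, beta_b = gamma[None, :, None, None], beta[None, :, None, None]
    return [gamma_b * (x - mean_b) * invstd_b + beta_b for x in x_per_rank]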

plugin/device/gpu/kernel/nccl/sync_batch_norm_grad_gpu_kernel.cc

@@ -1,110 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/nccl/sync_batch_norm_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
SyncBatchNormGradGpuKernel, float, float, float)
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
SyncBatchNormGradGpuKernel, half, float, float)
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
SyncBatchNormGradGpuKernel, float, half, float)
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
SyncBatchNormGradGpuKernel, half, half, float)
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
SyncBatchNormGradGpuKernel, float, float, half)
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
SyncBatchNormGradGpuKernel, half, float, half)
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
SyncBatchNormGradGpuKernel, float, half, half)
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
SyncBatchNormGradGpuKernel, half, half, half)
} // namespace kernel
} // namespace mindspore

plugin/device/gpu/kernel/nccl/sync_batch_norm_grad_gpu_kernel.h

@@ -1,200 +0,0 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GRAD_GPU_KERNEL_H_
#include <dlfcn.h>
#include <stdint.h>
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/nccl/nccl_gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "include/common/utils/utils.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sync_batch_norm_grad_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S, typename G>
class SyncBatchNormGradGpuKernel : public NcclGpuKernelMod {
public:
SyncBatchNormGradGpuKernel() { ResetResource(); }
~SyncBatchNormGradGpuKernel() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *dy = GetDeviceAddress<T>(inputs, 0);
T *x_input = GetDeviceAddress<T>(inputs, 1);
S *scale = GetDeviceAddress<S>(inputs, 2);
G *saved_mean = GetDeviceAddress<G>(inputs, 3);
G *saved_variance = GetDeviceAddress<G>(inputs, 4);
float *dy_sum_local = GetDeviceAddress<float>(workspace, 0);
float *dot_p_local = GetDeviceAddress<float>(workspace, 1);
float *dy_sum_red = GetDeviceAddress<float>(workspace, 2);
float *dot_p_red = GetDeviceAddress<float>(workspace, 3);
T *dx = GetDeviceAddress<T>(outputs, 0);
S *dscale = GetDeviceAddress<S>(outputs, 1);
S *dbias = GetDeviceAddress<S>(outputs, 2);
// aggregate interim values on each device locally
CalSyncBatchNormGradPre(N_, C_, H_, W_, x_input, dy, saved_mean, saved_variance, dy_sum_local, dot_p_local,
reinterpret_cast<cudaStream_t>(stream_ptr));
// reduce values across devices
LaunchAllReduce(dy_sum_local, dy_sum_red, stream_ptr);
LaunchAllReduce(dot_p_local, dot_p_red, stream_ptr);
// Aggregate and compute output
CalSyncBatchNormGradPost(N_, C_, H_, W_, x_input, dy, dx, saved_mean, saved_variance, dy_sum_red, dot_p_red, scale,
dscale, dbias, epsilon_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
auto root_rank = prim->GetAttr(kAttrRootRank);
kernel_node_ = kernel_node;
if (root_rank) {
root_ = static_cast<int>(GetValue<int64_t>(root_rank));
}
nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 5, but got " << input_num;
}
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 3) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 3, but got " << output_num;
}
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
if (is_null_input_) {
InitSizeLists();
return true;
}
auto input_shape_dims = input_shape.size();
if (input_shape_dims != 4 && input_shape_dims != 2) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input should be 2 or 4, but got "
<< input_shape_dims;
}
input_size_ = 1;
for (auto dim : input_shape) {
input_size_ *= dim;
}
output_size_ = input_size_;
output_size_ = output_size_ * sizeof(T);
input_size_ = input_size_ * sizeof(T);
param_count_ = input_shape[1];
param_size_S_ = param_count_ * sizeof(S);
param_size_G_ = param_count_ * sizeof(G);
N_ = input_shape[0];
C_ = input_shape[1];
if (input_shape_dims == 2) { // N,C,1,1 transform input
H_ = 1;
W_ = 1;
} else {
H_ = input_shape[2];
W_ = input_shape[3];
}
workspace_size_ = C_;
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
// MULTI DEVICE SPECIFICS
group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup);
MS_LOG(INFO) << common::AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_;
auto comm_stream_attr = prim->GetAttr("stream_id");
if (comm_stream_attr) {
comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr));
MS_EXCEPTION_IF_NULL(comm_stream_);
}
SelectCollectiveHandle();
// Get group size
device_count_ = device::gpu::CollectiveInitializer::instance().GetGroupSize(group_name_);
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_size_ = 0;
output_size_ = 0;
workspace_size_ = 0;
epsilon_ = 10e-5; // default
param_size_S_ = 0;
param_size_G_ = 0;
param_count_ = 0;
N_ = 0;
C_ = 0;
H_ = 0;
W_ = 0;
root_ = 0;
collective_handle_ = nullptr;
comm_stream_ = nullptr;
nccl_reduce_type_ = ncclSum;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
is_null_input_ = false;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_); // dy
input_size_list_.push_back(input_size_); // x
input_size_list_.push_back(param_size_S_); // scale
input_size_list_.push_back(param_size_G_); // saved_mean
input_size_list_.push_back(param_size_G_); // saved_variance
output_size_list_.push_back(output_size_); // dx
output_size_list_.push_back(param_size_S_); // dscale
output_size_list_.push_back(param_size_S_); // dbias
workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // dy_sum_local
workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // dot_p_local
workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // dy_sum_red (reduced across devices)
workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // dot_p_red (reduced across devices)
}
private:
template <typename reduce_type>
void LaunchAllReduce(reduce_type *input_addr, reduce_type *output_addr, void *stream_ptr) {
cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
(void)AllReduce(input_addr, output_addr, C_, nccl_dtype(kNumberTypeFloat32), nccl_reduce_type_, stream,
group_name_);
}
size_t input_size_;
size_t output_size_;
size_t workspace_size_;
float epsilon_;
size_t param_size_S_;
size_t param_size_G_;
size_t param_count_;
size_t N_;
size_t C_;
size_t H_;
size_t W_;
size_t device_count_;
ncclRedOp_t nccl_reduce_type_;
// NCCL
string group_name_;
int root_;
cudaStream_t comm_stream_;
bool is_null_input_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GRAD_GPU_KERNEL_H_
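
The gradient kernel needs only two reduced quantities per channel: the sum of dy and the sum of dy * (x - mean). Assuming saved_variance actually carries the inverse standard deviation saved by the forward pass (the variable naming in the helpers suggests this; sync_batch_norm_grad_impl.cuh is authoritative), a NumPy sketch of the standard batch-norm backward that Launch assembles:

import numpy as np

def sync_bn_backward_sketch(dy, x, gamma, saved_mean, saved_invstd, world_size=1):
    axes = (0, 2, 3)
    mean_b = saved_mean[None, :, None, None]
    # CalSyncBatchNormGradPre: local per-channel partial sums.
    dy_sum = dy.sum(axis=axes)                   # dy_sum_local
    dot_p = ((x - mean_b) * dy).sum(axis=axes)   # dot_p_local
    # LaunchAllReduce(ncclSum): with identical data on every rank, summing
    # across ranks is just a multiplication by the rank count.
    dy_sum, dot_p = dy_sum * world_size, dot_p * world_size
    n = dy.shape[0] * dy.shape[2] * dy.shape[3] * world_size
    # CalSyncBatchNormGradPost: standard batch-norm backward with global sums.
    proj = (saved_invstd**2 * dot_p / n)[None, :, None, None]
    dx = dy - (dy_sum / n)[None, :, None, None] - (x - mean_b) * proj
    dx = dx * (gamma * saved_invstd)[None, :, None, None]
    dscale = dot_p * saved_invstd
    dbias = dy_sum
    return dx, dscale, dbias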

test_nccl_sync_batch_norm_op.py

@@ -1,134 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.communication.management import init
from mindspore.ops import composite as C
# Input data and golden outputs shared by the forward and backward tests
x_fwd_input = np.array([[
[[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]],
[[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32)
expect_output_fwd = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294],
[-0.1471, 0.7706, 1.6882, 2.6059],
[0.3118, 1.6882, 2.1471, 2.1471],
[0.7706, 0.3118, 2.6059, -0.1471]],
[[0.9119, 1.8518, 1.3819, -0.0281],
[-0.0281, 0.9119, 1.3819, 1.8518],
[2.7918, 0.4419, -0.4981, 0.9119],
[1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32)
grad_back = np.array([[[[1, 2, 7, 1], [4, 2, 1, 3], [1, 6, 5, 2], [2, 4, 3, 2]],
[[9, 4, 3, 5], [1, 3, 7, 6], [5, 7, 9, 9], [1, 4, 6, 8]]]]).astype(np.float32)
expect_output_back = np.array([[[[-0.69126546, -0.32903028, 1.9651246, -0.88445705],
[0.6369296, -0.37732816, -0.93275493, -0.11168876],
[-0.7878612, 1.3614, 0.8542711, -0.52222186],
[-0.37732816, 0.5886317, -0.11168876, -0.28073236]],
[[1.6447213, -0.38968924, -1.0174079, -0.55067265],
[-2.4305856, -1.1751484, 0.86250514, 0.5502673],
[0.39576983, 0.5470243, 1.1715001, 1.6447213],
[-1.7996241, -0.7051701, 0.7080077, 0.5437813]]]]).astype(np.float32)


class Net(nn.Cell):
def __init__(self, c):
super(Net, self).__init__()
self.num_features = c
self.eps = 1e-5
self.momentum = 1
self.mode = True
self.affine = True
self.sync_bn_op = nn.SyncBatchNorm(num_features=self.num_features,
eps=self.eps,
momentum=self.momentum,
affine=self.affine,
gamma_init='ones',
beta_init='ones',
moving_mean_init='ones',
moving_var_init='ones',
use_batch_statistics=True,
process_groups=None)

def construct(self, input_data):
return self.sync_bn_op(input_data)


class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network

def construct(self, input_data, sens):
gout = self.grad(self.network)(input_data, sens)
return gout


def test_sync_batch_norm_forward_fp32_graph():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
init()
x = x_fwd_input.copy().astype(np.float32)
expect_output = expect_output_fwd.copy().astype(np.float32)
overall_shape = x.shape
error = np.ones(shape=overall_shape) * 1.0e-4
net = Net(2)
net.set_train()
output = net(Tensor(x))
diff = output.asnumpy() - expect_output
assert np.all(diff < error)
assert np.all(-diff < error)


def test_sync_batch_norm_forward_fp16_pynative():
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
init()
x = x_fwd_input.copy().astype(np.float16)
expect_output = expect_output_fwd.copy().astype(np.float16)
overall_shape = x.shape
error = np.ones(shape=overall_shape) * 1.0e-3
net = Net(2)
net.set_train()
output = net(Tensor(x))
diff = output.asnumpy() - expect_output
assert np.all(diff < error)
assert np.all(-diff < error)


def test_sync_batch_norm_backwards_fp32_graph():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
init()
x = x_fwd_input.copy().astype(np.float32)
expect_output = expect_output_back.copy().astype(np.float32)
grad = grad_back.copy().astype(np.float32)
overall_shape = x.shape
error = np.ones(shape=overall_shape) * 1.0e-5
fwd_net = Net(2)
fwd_net.set_train()
bn_grad = Grad(fwd_net)
output = bn_grad(Tensor(x), Tensor(grad))
diff = output[0].asnumpy() - expect_output
assert np.all(diff < error)
assert np.all(-diff < error)


def test_sync_batch_norm_backwards_fp16_pynative():
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
init()
x = x_fwd_input.copy().astype(np.float16)
expect_output = expect_output_back.copy().astype(np.float16)
grad = grad_back.copy().astype(np.float16)
overall_shape = x.shape
error = np.ones(shape=overall_shape) * 1.0e-3
fwd_net = Net(2)
fwd_net.set_train()
bn_grad = Grad(fwd_net)
output = bn_grad(Tensor(x), Tensor(grad))
diff = output[0].asnumpy() - expect_output
assert np.all(diff < error)
assert np.all(-diff < error)
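
The golden values above can be reproduced without a device: gamma, beta, and both moving statistics are initialized to ones, and every rank receives the same input, so the synchronized statistics equal the local ones and the forward pass reduces to per-channel normalization plus one. A quick NumPy cross-check (not part of the test file):

import numpy as np

def reference_forward(x, eps=1e-5):
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps) + 1.0  # gamma = beta = 1

# Matches to the 1e-4 tolerance used by the fp32 test:
# np.allclose(reference_forward(x_fwd_input), expect_output_fwd, atol=1e-4)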

@@ -1,50 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_nccl_sync_batch_norm_1():
cmd_str = "mpirun -n 4 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_forward_fp32_graph"
return_code = os.system(cmd_str)
assert return_code == 0


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_nccl_sync_batch_norm_2():
cmd_str = "mpirun -n 4 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_forward_fp16_pynative"
return_code = os.system(cmd_str)
assert return_code == 0


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_nccl_sync_batch_norm_3():
cmd_str = "mpirun -n 1 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_backwards_fp32_graph"
return_code = os.system(cmd_str)
assert return_code == 0


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_nccl_sync_batch_norm_4():
cmd_str = "mpirun -n 1 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_backwards_fp16_pynative"
return_code = os.system(cmd_str)
assert return_code == 0
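
The four launchers differ only in the process count and the test they select; the forward cases run on 4 ranks while the backward cases run on a single rank. A parametrized equivalent, shown as a sketch without the platform marks:

import os
import pytest

@pytest.mark.parametrize("nproc,case", [
    (4, "test_sync_batch_norm_forward_fp32_graph"),
    (4, "test_sync_batch_norm_forward_fp16_pynative"),
    (1, "test_sync_batch_norm_backwards_fp32_graph"),
    (1, "test_sync_batch_norm_backwards_fp16_pynative"),
])
def test_nccl_sync_batch_norm(nproc, case):
    cmd = f"mpirun -n {nproc} pytest -s test_nccl_sync_batch_norm_op.py::{case}"
    assert os.system(cmd) == 0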