diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nccl/sync_batch_norm_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nccl/sync_batch_norm_gpu_kernel.cc deleted file mode 100644 index ca88da1b4fd..00000000000 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nccl/sync_batch_norm_gpu_kernel.cc +++ /dev/null @@ -1,126 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/nccl/sync_batch_norm_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SyncBatchNormGpuKernel, float, float, float) -MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - SyncBatchNormGpuKernel, half, float, float) -MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SyncBatchNormGpuKernel, float, half, float) -MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - SyncBatchNormGpuKernel, half, half, float) -MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SyncBatchNormGpuKernel, float, float, half) -MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - 
.AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - SyncBatchNormGpuKernel, half, float, half) -MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SyncBatchNormGpuKernel, float, half, half) -MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - SyncBatchNormGpuKernel, half, half, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nccl/sync_batch_norm_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nccl/sync_batch_norm_gpu_kernel.h deleted file mode 100644 index 60b0f8e416c..00000000000 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nccl/sync_batch_norm_gpu_kernel.h +++ /dev/null @@ -1,232 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/nccl/nccl_gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/kernel_constants.h" -#include "include/common/utils/utils.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sync_batch_norm_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SyncBatchNormGpuKernel : public NcclGpuKernelMod { - public: - SyncBatchNormGpuKernel() { ResetResource(); } - ~SyncBatchNormGpuKernel() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *x = GetDeviceAddress(inputs, 0); - S *scale = GetDeviceAddress(inputs, 1); - S *bias = GetDeviceAddress(inputs, 2); - G *running_mean_input = GetDeviceAddress(inputs, 3); - G *running_variance_input = GetDeviceAddress(inputs, 4); - - float *means_local = GetDeviceAddress(workspace, 0); // per device - float *invstds_local = GetDeviceAddress(workspace, 1); - int *counts_local = GetDeviceAddress(workspace, 2); - int *counts_global = GetDeviceAddress(workspace, 3); // gathered values from all devices - float *means_global = GetDeviceAddress(workspace, 4); - float *invstds_global = GetDeviceAddress(workspace, 5); - - T *y = GetDeviceAddress(outputs, 0); - S *output_scale = GetDeviceAddress(outputs, 1); - S *output_bias = GetDeviceAddress(outputs, 2); - T *output_running_mean = GetDeviceAddress(outputs, 3); - T *output_running_variance = GetDeviceAddress(outputs, 4); - - // aggregate means and invstd on each device locally - CalSyncBatchNormPre(N_, C_, H_, W_, x, counts_local, means_local, invstds_local, epsilon_, - reinterpret_cast(stream_ptr)); - // gather values from all devices together - LaunchAllGather(means_local, means_global, stream_ptr); - LaunchAllGather(invstds_local, invstds_global, stream_ptr); - LaunchAllGather(counts_local, counts_global, stream_ptr); - // reducing gathered values on each device and deal with running means and variance - CalSyncBatchNormGather(N_, C_, H_, W_, counts_global, means_global, invstds_global, counts_local, means_local, - invstds_local, output_running_mean, output_running_variance, running_mean_input, - running_variance_input, epsilon_, momentum_, group_rank_, group_size_, - reinterpret_cast(stream_ptr)); - CalSyncBatchNormPost(N_, C_, H_, W_, x, y, means_local, invstds_local, scale, bias, output_scale, output_bias, - epsilon_, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node); - auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node); - MS_EXCEPTION_IF_NULL(prim); - auto root_rank = prim->GetAttr(kAttrRootRank); - kernel_node_ = kernel_node; - if (root_rank) { - root_ = static_cast(GetValue(root_rank)); - } - nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)); - group_name_ = GetAttr(kernel_node, kAttrGroup); - size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 5) { - MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 5, but got " << input_num; - } - size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num 
!= 5) { - MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 5, but got " << output_num; - } - auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input"); - if (is_null_input_) { - InitSizeLists(); - return true; - } - auto input_shape_dims = input_shape.size(); - if (input_shape_dims != 4 && input_shape_dims != 2) { - MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input only should be 2 or 4, but got " - << input_shape_dims; - } - input_size_ = 1; - for (auto dim : input_shape) { - input_size_ *= dim; - } - epsilon_ = GetAttr(kernel_node, "epsilon"); - momentum_ = GetAttr(kernel_node, "momentum"); - output_size_ = input_size_; - output_size_ = output_size_ * sizeof(T); - input_size_ = input_size_ * sizeof(T); - param_count_ = input_shape[1]; // C is number of features - param_size_S_ = param_count_ * sizeof(S); // will be second/third template - param_size_G_input_ = param_count_ * sizeof(G); - param_size_G_output_ = param_count_ * sizeof(T); - workspace_size_ = param_count_; // specific size computed in InitSizeLists() - N_ = input_shape[0]; - C_ = input_shape[1]; - if (input_shape_dims == 2) { - // NC -> N,C,1,1 transform input dims - H_ = 1; - W_ = 1; - } else { - H_ = input_shape[2]; - W_ = input_shape[3]; - } - // MULTI DEVICE SPECIFICS - group_name_ = GetAttr(kernel_node, kAttrGroup); - MS_LOG(INFO) << common::AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_; - auto comm_stream_attr = prim->GetAttr("stream_id"); - if (comm_stream_attr) { - comm_stream_ = reinterpret_cast(GetValue(comm_stream_attr)); - MS_EXCEPTION_IF_NULL(comm_stream_); - } - SelectCollectiveHandle(); - // Get group size - group_size_ = device::gpu::CollectiveInitializer::instance().GetGroupSize(group_name_); - // // Get device rank ID in group - group_rank_ = device::gpu::CollectiveInitializer::instance().local_rank_id(); - InitSizeLists(); - return true; - } - - void ResetResource() noexcept override { - input_size_ = 0; - output_size_ = 0; - workspace_size_ = 0; - momentum_ = 0; - epsilon_ = 10e-5; - param_size_S_ = 0; - param_size_G_input_ = 0; - param_size_G_output_ = 0; - param_count_ = 0; - N_ = 0; - C_ = 0; - H_ = 0; - W_ = 0; - root_ = 0; - collective_handle_ = nullptr; - comm_stream_ = nullptr; - nccl_reduce_type_ = ncclSum; - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); - is_null_input_ = false; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); // input x - input_size_list_.push_back(param_size_S_); // scale - input_size_list_.push_back(param_size_S_); // bias - input_size_list_.push_back(param_size_G_input_); // running mean - input_size_list_.push_back(param_size_G_input_); // running variance - output_size_list_.push_back(output_size_); // output - output_size_list_.push_back(param_size_S_); // save scale - output_size_list_.push_back(param_size_S_); // reserve space - output_size_list_.push_back(param_size_G_output_); // save mean - output_size_list_.push_back(param_size_G_output_); // save variance - // local mean/variance data - per device - workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // mean_local - workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // invstd_local - workspace_size_list_.push_back(workspace_size_ * sizeof(int)); // count_local - // global mean/variance data - for all devices - 
workspace_size_list_.push_back(workspace_size_ * sizeof(int) * group_size_); // gathered mean - workspace_size_list_.push_back(workspace_size_ * sizeof(float) * group_size_); // gathered invstd - workspace_size_list_.push_back(workspace_size_ * sizeof(float) * group_size_); // gathered count - } - - private: - // GetTypeID functions return the correct typeID for input template - // Allow for a single templated LaunchAllGather function - mindspore::TypeId GetTypeID(float *input) { return kNumberTypeFloat32; } - mindspore::TypeId GetTypeID(int *input) { return kNumberTypeInt32; } - template - void LaunchAllGather(gather_type *input_addr, gather_type *output_addr, void *stream_ptr) { - cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast(stream_ptr); - (void)AllGather(input_addr, output_addr, C_, nccl_dtype(GetTypeID(input_addr)), stream, group_name_); - } - - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - float momentum_; - float epsilon_; - size_t param_size_S_; - size_t param_size_G_input_; - size_t param_size_G_output_; - size_t param_count_; - size_t N_; - size_t C_; - size_t H_; - size_t W_; - size_t group_size_; - size_t group_rank_; - ncclRedOp_t nccl_reduce_type_; - - // NCCL - string group_name_; - int root_; - cudaStream_t comm_stream_; - bool is_null_input_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nccl/sync_batch_norm_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nccl/sync_batch_norm_grad_gpu_kernel.cc deleted file mode 100644 index 0b2f155c414..00000000000 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nccl/sync_batch_norm_grad_gpu_kernel.cc +++ /dev/null @@ -1,110 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "plugin/device/gpu/kernel/nccl/sync_batch_norm_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SyncBatchNormGradGpuKernel, float, float, float) -MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SyncBatchNormGradGpuKernel, half, float, float) -MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - SyncBatchNormGradGpuKernel, float, half, float) -MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - SyncBatchNormGradGpuKernel, half, half, float) -MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SyncBatchNormGradGpuKernel, float, float, half) -MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SyncBatchNormGradGpuKernel, half, float, half) -MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - SyncBatchNormGradGpuKernel, float, half, half) -MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - SyncBatchNormGradGpuKernel, half, half, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nccl/sync_batch_norm_grad_gpu_kernel.h 
b/mindspore/ccsrc/plugin/device/gpu/kernel/nccl/sync_batch_norm_grad_gpu_kernel.h deleted file mode 100644 index c6a0b0a0fb6..00000000000 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nccl/sync_batch_norm_grad_gpu_kernel.h +++ /dev/null @@ -1,200 +0,0 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GRAD_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/nccl/nccl_gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/kernel_constants.h" -#include "include/common/utils/utils.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sync_batch_norm_grad_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SyncBatchNormGradGpuKernel : public NcclGpuKernelMod { - public: - SyncBatchNormGradGpuKernel() { ResetResource(); } - ~SyncBatchNormGradGpuKernel() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *dy = GetDeviceAddress(inputs, 0); - T *x_input = GetDeviceAddress(inputs, 1); - S *scale = GetDeviceAddress(inputs, 2); - G *saved_mean = GetDeviceAddress(inputs, 3); - G *saved_variance = GetDeviceAddress(inputs, 4); - float *dy_sum_local = GetDeviceAddress(workspace, 0); - float *dot_p_local = GetDeviceAddress(workspace, 1); - float *dy_sum_red = GetDeviceAddress(workspace, 2); - float *dot_p_red = GetDeviceAddress(workspace, 3); - T *dx = GetDeviceAddress(outputs, 0); - S *dscale = GetDeviceAddress(outputs, 1); - S *dbias = GetDeviceAddress(outputs, 2); - // aggregate interim values on each device locally - CalSyncBatchNormGradPre(N_, C_, H_, W_, x_input, dy, saved_mean, saved_variance, dy_sum_local, dot_p_local, - reinterpret_cast(stream_ptr)); - // reduce values across devices - LaunchAllReduce(dy_sum_local, dy_sum_red, stream_ptr); - LaunchAllReduce(dot_p_local, dot_p_red, stream_ptr); - // Aggregate and compute output - CalSyncBatchNormGradPost(N_, C_, H_, W_, x_input, dy, dx, saved_mean, saved_variance, dy_sum_red, dot_p_red, scale, - dscale, dbias, epsilon_, reinterpret_cast(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node); - auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node); - MS_EXCEPTION_IF_NULL(prim); - auto root_rank = prim->GetAttr(kAttrRootRank); - kernel_node_ = kernel_node; - if (root_rank) { - root_ = static_cast(GetValue(root_rank)); - } - nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)); - group_name_ = GetAttr(kernel_node, kAttrGroup); - size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 5) { - 
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 5, but got " << input_num; - } - size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 3) { - MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 3, but got " << output_num; - } - auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input"); - if (is_null_input_) { - InitSizeLists(); - return true; - } - auto input_shape_dims = input_shape.size(); - if (input_shape_dims != 4 && input_shape_dims != 2) { - MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input only should be 2 or 4, but got " - << input_shape_dims; - } - input_size_ = 1; - for (auto dim : input_shape) { - input_size_ *= dim; - } - output_size_ = input_size_; - output_size_ = output_size_ * sizeof(T); - input_size_ = input_size_ * sizeof(T); - param_count_ = input_shape[1]; - param_size_S_ = param_count_ * sizeof(S); - param_size_G_ = param_count_ * sizeof(G); - N_ = input_shape[0]; - C_ = input_shape[1]; - if (input_shape_dims == 2) { // N,C,1,1 transform input - H_ = 1; - W_ = 1; - } else { - H_ = input_shape[2]; - W_ = input_shape[3]; - } - workspace_size_ = C_; - epsilon_ = GetAttr(kernel_node, "epsilon"); - // MULTIDEVICE SPECIFICS - group_name_ = GetAttr(kernel_node, kAttrGroup); - MS_LOG(INFO) << common::AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_; - auto comm_stream_attr = prim->GetAttr("stream_id"); - if (comm_stream_attr) { - comm_stream_ = reinterpret_cast(GetValue(comm_stream_attr)); - MS_EXCEPTION_IF_NULL(comm_stream_); - } - SelectCollectiveHandle(); - // Get group size - device_count_ = device::gpu::CollectiveInitializer::instance().GetGroupSize(group_name_); - InitSizeLists(); - return true; - } - - void ResetResource() noexcept override { - input_size_ = 0; - output_size_ = 0; - workspace_size_ = 0; - epsilon_ = 10e-5; // default - param_size_S_ = 0; - param_size_G_ = 0; - param_count_ = 0; - N_ = 0; - C_ = 0; - H_ = 0; - W_ = 0; - root_ = 0; - collective_handle_ = nullptr; - comm_stream_ = nullptr; - nccl_reduce_type_ = ncclSum; - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); - is_null_input_ = false; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); // dy - input_size_list_.push_back(input_size_); // x - input_size_list_.push_back(param_size_S_); // scale - input_size_list_.push_back(param_size_G_); // saved_mean - input_size_list_.push_back(param_size_G_); // saved_variance - output_size_list_.push_back(output_size_); // dx - output_size_list_.push_back(param_size_S_); // dscale - output_size_list_.push_back(param_size_S_); // dbias - workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // sum_dy - workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // sum_dy_xmu - workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // reduced sum_dy - workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // reduced sum_dy_xmu - } - - private: - template - void LaunchAllReduce(reduce_type *input_addr, reduce_type *output_addr, void *stream_ptr) { - cudaStream_t stream = comm_stream_ ? 
comm_stream_ : reinterpret_cast(stream_ptr); - (void)AllReduce(input_addr, output_addr, C_, nccl_dtype(kNumberTypeFloat32), nccl_reduce_type_, stream, - group_name_); - } - - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - float epsilon_; - size_t param_size_S_; - size_t param_size_G_; - size_t param_count_; - size_t N_; - size_t C_; - size_t H_; - size_t W_; - size_t device_count_; - ncclRedOp_t nccl_reduce_type_; - - // NCCL - string group_name_; - int root_; - cudaStream_t comm_stream_; - bool is_null_input_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GRAD_GPU_KERNEL_H_ diff --git a/tests/st/nccl/test_nccl_sync_batch_norm_op.py b/tests/st/nccl/test_nccl_sync_batch_norm_op.py deleted file mode 100644 index 65d47246e43..00000000000 --- a/tests/st/nccl/test_nccl_sync_batch_norm_op.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.communication.management import init -from mindspore.ops import composite as C - -# define target and input values here -x_fwd_input = np.array([[ - [[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]], - [[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32) -expect_output_fwd = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294], - [-0.1471, 0.7706, 1.6882, 2.6059], - [0.3118, 1.6882, 2.1471, 2.1471], - [0.7706, 0.3118, 2.6059, -0.1471]], - [[0.9119, 1.8518, 1.3819, -0.0281], - [-0.0281, 0.9119, 1.3819, 1.8518], - [2.7918, 0.4419, -0.4981, 0.9119], - [1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32) -grad_back = np.array([[[[1, 2, 7, 1], [4, 2, 1, 3], [1, 6, 5, 2], [2, 4, 3, 2]], - [[9, 4, 3, 5], [1, 3, 7, 6], [5, 7, 9, 9], [1, 4, 6, 8]]]]).astype(np.float32) -expect_output_back = np.array([[[[-0.69126546, -0.32903028, 1.9651246, -0.88445705], - [0.6369296, -0.37732816, -0.93275493, -0.11168876], - [-0.7878612, 1.3614, 0.8542711, -0.52222186], - [-0.37732816, 0.5886317, -0.11168876, -0.28073236]], - [[1.6447213, -0.38968924, -1.0174079, -0.55067265], - [-2.4305856, -1.1751484, 0.86250514, 0.5502673], - [0.39576983, 0.5470243, 1.1715001, 1.6447213], - [-1.7996241, -0.7051701, 0.7080077, 0.5437813]]]]).astype(np.float32) - -class Net(nn.Cell): - def __init__(self, c): - super(Net, self).__init__() - self.num_features = c - self.eps = 1e-5 - self.momentum = 1 - self.mode = True - self.affine = True - self.sync_bn_op = nn.SyncBatchNorm(num_features=self.num_features, - eps=self.eps, - momentum=self.momentum, - affine=self.affine, - gamma_init='ones', - beta_init='ones', - moving_mean_init='ones', - moving_var_init='ones', - use_batch_statistics=True, - process_groups=None) - def construct(self, input_data): - return self.sync_bn_op(input_data) - 
-class Grad(nn.Cell):
-    def __init__(self, network):
-        super(Grad, self).__init__()
-        self.grad = C.GradOperation(get_all=True, sens_param=True)
-        self.network = network
-
-    def construct(self, input_data, sens):
-        gout = self.grad(self.network)(input_data, sens)
-        return gout
-
-def test_sync_batch_norm_forward_fp32_graph():
-    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    init()
-    x = x_fwd_input.copy().astype(np.float32)
-    expect_output = expect_output_fwd.copy().astype(np.float32)
-    overall_shape = x.shape
-    error = np.ones(shape=overall_shape) * 1.0e-4
-    net = Net(2)
-    net.set_train()
-    output = net(Tensor(x))
-    diff = output.asnumpy() - expect_output
-    assert np.all(diff < error)
-    assert np.all(-diff < error)
-
-def test_sync_batch_norm_forward_fp16_pynative():
-    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    init()
-    x = x_fwd_input.copy().astype(np.float16)
-    expect_output = expect_output_fwd.copy().astype(np.float16)
-    overall_shape = x.shape
-    error = np.ones(shape=overall_shape) * 1.0e-3
-    net = Net(2)
-    net.set_train()
-    output = net(Tensor(x))
-    diff = output.asnumpy() - expect_output
-    assert np.all(diff < error)
-    assert np.all(-diff < error)
-
-def test_sync_batch_norm_backwards_fp32_graph():
-    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    init()
-    x = x_fwd_input.copy().astype(np.float32)
-    expect_output = expect_output_back.copy().astype(np.float32)
-    grad = grad_back.copy().astype(np.float32)
-    overall_shape = x.shape
-    error = np.ones(shape=overall_shape) * 1.0e-5
-    fwd_net = Net(2)
-    fwd_net.set_train()
-    bn_grad = Grad(fwd_net)
-    output = bn_grad(Tensor(x), Tensor(grad))
-    diff = output[0].asnumpy() - expect_output
-    assert np.all(diff < error)
-    assert np.all(-diff < error)
-
-def test_sync_batch_norm_backwards_fp16_pynative():
-    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    init()
-    x = x_fwd_input.copy().astype(np.float16)
-    expect_output = expect_output_back.copy().astype(np.float16)
-    grad = grad_back.copy().astype(np.float16)
-    overall_shape = x.shape
-    error = np.ones(shape=overall_shape) * 1.0e-3
-    fwd_net = Net(2)
-    fwd_net.set_train()
-    bn_grad = Grad(fwd_net)
-    output = bn_grad(Tensor(x), Tensor(grad))
-    diff = output[0].asnumpy() - expect_output
-    assert np.all(diff < error)
-    assert np.all(-diff < error)
diff --git a/tests/st/nccl/test_nccl_sync_batch_norm_op_all.py b/tests/st/nccl/test_nccl_sync_batch_norm_op_all.py
deleted file mode 100644
index d52d7b55b28..00000000000
--- a/tests/st/nccl/test_nccl_sync_batch_norm_op_all.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-
-import os
-import pytest
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_single
-def test_nccl_sync_batch_norm_1():
-    cmd_str = "mpirun -n 4 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_forward_fp32_graph"
-    return_code = os.system(cmd_str)
-    assert return_code == 0
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_single
-def test_nccl_sync_batch_norm_2():
-    cmd_str = "mpirun -n 4 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_forward_fp16_pynative"
-    return_code = os.system(cmd_str)
-    assert return_code == 0
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_single
-def test_nccl_sync_batch_norm_3():
-    cmd_str = "mpirun -n 1 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_backwards_fp32_graph"
-    return_code = os.system(cmd_str)
-    assert return_code == 0
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_single
-def test_nccl_sync_batch_norm_4():
-    cmd_str = "mpirun -n 1 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_backwards_fp16_pynative"
-    return_code = os.system(cmd_str)
-    assert return_code == 0
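
Note on what the removed forward kernel did: `SyncBatchNormGpuKernel::Launch` ran in three phases — per-device channel statistics (`CalSyncBatchNormPre`), an NCCL `AllGather` of every device's means, inverse standard deviations and element counts, and a combine-plus-normalize step with the global statistics (`CalSyncBatchNormGather` / `CalSyncBatchNormPost`). The sketch below is a minimal NumPy model of that data flow, assuming the conventional count-weighted merge of per-device moments; it omits the running-mean/variance update the real kernel applies with `momentum_`, so treat it as an illustration of the algorithm being deleted rather than a faithful port of `sync_batch_norm_impl.cuh`.

```python
import numpy as np

def _local_stats(x, eps):
    # Per-device, per-channel statistics over N, H, W (mirrors CalSyncBatchNormPre).
    count = x.shape[0] * x.shape[2] * x.shape[3]
    mean = x.mean(axis=(0, 2, 3))
    invstd = 1.0 / np.sqrt(x.var(axis=(0, 2, 3)) + eps)
    return mean, invstd, count

def sync_batch_norm_forward(shards, gamma, beta, eps=1e-5):
    # shards: one (N, C, H, W) array per device in the process group.
    stats = [_local_stats(x, eps) for x in shards]
    # "AllGather": after this step every device holds every device's statistics.
    means = np.stack([m for m, _, _ in stats])            # (group_size, C)
    invstds = np.stack([s for _, s, _ in stats])          # (group_size, C)
    counts = np.array([n for _, _, n in stats], float)    # (group_size,)
    # Count-weighted combine into global statistics (mirrors CalSyncBatchNormGather).
    total = counts.sum()
    g_mean = (counts[:, None] * means).sum(axis=0) / total
    local_var = 1.0 / invstds ** 2 - eps
    g_var = (counts[:, None] * (local_var + (means - g_mean) ** 2)).sum(axis=0) / total
    g_invstd = 1.0 / np.sqrt(g_var + eps)
    # Normalize every shard with the *global* statistics (mirrors CalSyncBatchNormPost).
    bc = lambda v: v[None, :, None, None]
    return [bc(gamma) * (x - bc(g_mean)) * bc(g_invstd) + bc(beta) for x in shards]
```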
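The removed backward kernel followed the same pattern: `CalSyncBatchNormGradPre` accumulated two per-channel partial sums on each device (the gradient sum `dy_sum` and the product sum `dot_p`), `LaunchAllReduce` summed them across the group with `ncclSum`, and `CalSyncBatchNormGradPost` produced `dx`, `dscale` and `dbias` from the reduced values. The hedged sketch below spells out the standard training-mode batch-norm backward in terms of those reduced sums; it assumes `dot_p` is the sum of `dy * (x - saved_mean)` and that the saved inputs are the mean and inverse standard deviation from the forward pass — the exact definitions live in `sync_batch_norm_grad_impl.cuh` and may differ in detail.

```python
import numpy as np

def sync_batch_norm_backward(shards, grads, gamma, saved_mean, saved_invstd):
    # shards[i], grads[i]: (N, C, H, W) on device i; gamma, saved_mean, saved_invstd: (C,).
    bc = lambda v: v[None, :, None, None]
    # Per-device partial sums (mirrors CalSyncBatchNormGradPre).
    dy_sum_local = [dy.sum(axis=(0, 2, 3)) for dy in grads]
    dot_p_local = [(dy * (x - bc(saved_mean))).sum(axis=(0, 2, 3))
                   for x, dy in zip(shards, grads)]
    # "AllReduce" with ncclSum: every device ends up with the same global sums.
    dy_sum = np.sum(dy_sum_local, axis=0)
    dot_p = np.sum(dot_p_local, axis=0)
    m = sum(x.shape[0] * x.shape[2] * x.shape[3] for x in shards)  # global element count per channel
    # Standard training-mode batch-norm backward, expressed with the reduced sums.
    dscale = dot_p * saved_invstd
    dbias = dy_sum
    k = dot_p * saved_invstd ** 2 / m
    dx = [bc(gamma) * bc(saved_invstd) *
          (dy - bc(dy_sum) / m - (x - bc(saved_mean)) * bc(k))
          for x, dy in zip(shards, grads)]
    return dx, dscale, dbias
```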