add_unify_clip_grad_norm_op_on_gpu
This commit is contained in:
parent
db7d28f5c8
commit
7a9fd2d7df
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/clip_grad_norm_impl.cuh"
|
||||
|
||||
// The implement of ScalingGradOp
|
||||
template <typename T>
|
||||
__global__ void ScalingGradKernel(const size_t size, const T *x, const float *scaling_factor, float *scaling_out_addr) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||
scaling_out_addr[i] = x[i] * (1.0 / scaling_factor[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void ScalingGradKernel(const size_t size, const half *x, const float *scaling_factor,
|
||||
float *scaling_out_addr) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||
scaling_out_addr[i] = __half2float(x[i]) * (1.0 / scaling_factor[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ScalingGradOp(const size_t size, const T *x, const float *scaling_factor, float *scaling_out_addr,
|
||||
cudaStream_t cuda_stream) {
|
||||
ScalingGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, x, scaling_factor, scaling_out_addr);
|
||||
}
|
||||
|
||||
template void ScalingGradOp<float>(const size_t size, const float *x, const float *scaling_factor,
|
||||
float *scaling_out_addr, cudaStream_t cuda_stream);
|
||||
|
||||
template void ScalingGradOp<half>(const size_t size, const half *x, const float *scaling_factor,
|
||||
float *scaling_out_addr, cudaStream_t cuda_stream);
|
||||
|
||||
// The implement of ClipGradNormOp
|
||||
template <typename T>
|
||||
__global__ void ClipGradNormKernel(const size_t size, const float *x, const T *clip_norm, const float *reduce_sum_value,
|
||||
float *output_addr) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||
if (reduce_sum_value[0] > clip_norm[0]) {
|
||||
output_addr[i] = x[i] * clip_norm[0] / reduce_sum_value[0];
|
||||
} else {
|
||||
output_addr[i] = x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void ClipGradNormKernel(const size_t size, const float *x, const half *clip_norm,
|
||||
const float *reduce_sum_value, float *output_addr) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||
const float clip_norm_float = __half2float(clip_norm[0]);
|
||||
if (reduce_sum_value[0] > clip_norm_float) {
|
||||
output_addr[i] = x[i] * clip_norm_float / reduce_sum_value[0];
|
||||
} else {
|
||||
output_addr[i] = x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ClipGradNormOp(const size_t size, const float *x, const T *clip_norm, const float *reduce_sum_value,
|
||||
float *output_addr, cudaStream_t cuda_stream) {
|
||||
ClipGradNormKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, x, clip_norm, reduce_sum_value,
|
||||
output_addr);
|
||||
}
|
||||
|
||||
template void ClipGradNormOp<float>(const size_t size, const float *x, const float *clip_norm,
|
||||
const float *reduce_sum_value, float *output_addr, cudaStream_t cuda_stream);
|
||||
|
||||
template void ClipGradNormOp<half>(const size_t size, const float *x, const half *clip_norm,
|
||||
const float *reduce_sum_value, float *output_addr, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CLIP_GRAD_NORM_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CLIP_GRAD_NORM_IMPL_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void ScalingGradOp(const size_t size, const T *x, const float *scaling_factor, float *scaling_out_addr,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void ClipGradNormOp(const size_t size, const float *x, const T *clip_norm, const float *reduce_sum_value,
|
||||
float *output_addr, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CLIP_GRAD_NORM_IMPL_H_
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/clip_grad_norm_gpu_kernel.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(ClipGradNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
ClipGradNormGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(ClipGradNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
ClipGradNormGpuKernel, half)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,331 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CLIP_GRAD_NORM_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CLIP_GRAD_NORM_GPU_KERNEL_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "utils/log_adapter.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/clip_grad_norm_impl.cuh"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
constexpr size_t kArgMaxDim = 7;
|
||||
|
||||
template <typename T>
|
||||
class ClipGradNormGpuKernel : public GpuKernel {
|
||||
public:
|
||||
ClipGradNormGpuKernel()
|
||||
: cudnn_handle_(nullptr),
|
||||
data_type_(CUDNN_DATA_FLOAT),
|
||||
nan_prop_(CUDNN_NOT_PROPAGATE_NAN),
|
||||
reduce_indices_(CUDNN_REDUCE_TENSOR_NO_INDICES),
|
||||
reduce_tensor_descriptor_(nullptr),
|
||||
input_descriptor_(nullptr),
|
||||
output_descriptor_(nullptr),
|
||||
all_match_(false),
|
||||
is_null_input_(false),
|
||||
x_size_(0),
|
||||
clip_norm_size_(0),
|
||||
scaling_factor_size_(0),
|
||||
output_size_(0),
|
||||
workspace_size_(0) {}
|
||||
|
||||
~ClipGradNormGpuKernel() override { DestroyResource(); }
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
// Get address
|
||||
constexpr size_t input_num_expected = 3;
|
||||
constexpr size_t workspace_num_expected = 3;
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() == input_num_expected, "Size not equal");
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(workspace.size() == workspace_num_expected, "Size not equal");
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(outputs.size() == 1, "Size not equal");
|
||||
constexpr size_t scaling_factor_index = 2;
|
||||
constexpr size_t reduce_out_index = 2;
|
||||
T *x_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *clip_norm_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
float *scaling_factor_addr = GetDeviceAddress<float>(inputs, scaling_factor_index);
|
||||
float *scaling_out_addr = GetDeviceAddress<float>(workspace, 0);
|
||||
float *reduce_workspace_addr = GetPossiblyNullDeviceAddress<float>(workspace, 1);
|
||||
float *reduce_out_addr = GetDeviceAddress<float>(workspace, reduce_out_index);
|
||||
float *output_addr = GetDeviceAddress<float>(outputs, 0);
|
||||
|
||||
// Run gradient tensor scaling.
|
||||
ScalingGradOp(x_size_ / sizeof(T), x_addr, scaling_factor_addr, scaling_out_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
// Run reduce sum operation(keep_dims=True) for gradient tensor.
|
||||
constexpr size_t alpha = 1;
|
||||
constexpr size_t beta = 0;
|
||||
if (all_match_) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudaMemcpyAsync(reduce_out_addr, scaling_out_addr, workspace_size_list_[reduce_out_index],
|
||||
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync for 'ClipGradNormGpuKernel' failed");
|
||||
} else {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
|
||||
workspace_size_list_[1], &alpha, input_descriptor_, scaling_out_addr, &beta,
|
||||
output_descriptor_, reduce_out_addr),
|
||||
"cudnnReduceTensor for 'ClipGradNormGpuKernel' failed");
|
||||
}
|
||||
// Update gradient tensor by argument 'clip_norm'
|
||||
ClipGradNormOp(output_size_ / sizeof(float), scaling_out_addr, clip_norm_addr, reduce_out_addr, output_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(kernel_name == "ClipGradNorm", "Kernel name is not ClipGradNorm");
|
||||
kernel_node_ = kernel_node;
|
||||
// Init resource for cudnnreducetensor operation.
|
||||
InitResource();
|
||||
if (!CheckIONumber(kernel_node)) {
|
||||
return false;
|
||||
}
|
||||
// Check input and output shape
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
size_t input_dim = input_shape.size();
|
||||
if (!CheckValidShape(input_shape, output_shape, input_dim)) {
|
||||
return true;
|
||||
}
|
||||
// Init member variables.
|
||||
InitAxis(kernel_node, output_shape, SizeToInt(input_dim));
|
||||
clip_norm_size_ = sizeof(T);
|
||||
scaling_factor_size_ = sizeof(float);
|
||||
x_size_ = sizeof(T);
|
||||
output_size_ = sizeof(float);
|
||||
std::for_each(output_shape.begin(), output_shape.end(), [this](const size_t &v) {
|
||||
x_size_ *= v;
|
||||
output_size_ *= v;
|
||||
});
|
||||
InitShapeInfo(input_shape, output_shape);
|
||||
// Determine the reduce operation.
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetReduceTensorDescriptor(reduce_tensor_descriptor_, CUDNN_REDUCE_TENSOR_NORM2, CUDNN_DATA_FLOAT, nan_prop_,
|
||||
reduce_indices_, CUDNN_32BIT_INDICES),
|
||||
"cudnnSetReduceTensorDescriptor failed");
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitResource() override {
|
||||
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateReduceTensorDescriptor(&reduce_tensor_descriptor_),
|
||||
"cudnnCreateReduceTensorDescriptor failed.");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&input_descriptor_),
|
||||
"cudnnCreateTensorDescriptor failed.");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&output_descriptor_),
|
||||
"cudnnCreateTensorDescriptor failed.");
|
||||
}
|
||||
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.emplace_back(x_size_);
|
||||
input_size_list_.emplace_back(clip_norm_size_);
|
||||
input_size_list_.emplace_back(scaling_factor_size_);
|
||||
output_size_list_.emplace_back(output_size_);
|
||||
// Init workspace size for gradient tensor scaling calculate.
|
||||
workspace_size_list_.emplace_back(output_size_);
|
||||
// Init workspace size for gradient tensor reduce sum calculate.
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudnnGetReductionWorkspaceSize(cudnn_handle_, reduce_tensor_descriptor_,
|
||||
input_descriptor_, output_descriptor_, &workspace_size_),
|
||||
"cudnnGetReductionWorkspaceSize failed.");
|
||||
workspace_size_list_.emplace_back(workspace_size_);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(output_descriptor_, &workspace_size_),
|
||||
"cudnnGetTensorSizeInBytes failed.");
|
||||
workspace_size_list_.emplace_back(workspace_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
void DestroyResource() noexcept {
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyReduceTensorDescriptor(reduce_tensor_descriptor_),
|
||||
"cudnnDestroyReduceTensorDescriptor failed.");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(input_descriptor_),
|
||||
"cudnnDestroyTensorDescriptor failed.");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(output_descriptor_),
|
||||
"cudnnDestroyTensorDescriptor failed.");
|
||||
}
|
||||
|
||||
bool CheckIONumber(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
constexpr size_t input_num_expected = 3;
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != input_num_expected) {
|
||||
MS_LOG(ERROR) << "The input number of kernel node [" << kernel_node->DebugString() << "] should be "
|
||||
<< input_num_expected << ", but got " << input_num;
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "The output number of kernel node [" << kernel_node->DebugString() << "] should be " << 1
|
||||
<< ", but got " << output_num;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CheckValidShape(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape,
|
||||
size_t input_dim) {
|
||||
is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(output_shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "For 'ClipGradNormGpuKernel', input or output is null.";
|
||||
InitSizeLists();
|
||||
return false;
|
||||
}
|
||||
if (input_shape.size() != output_shape.size()) {
|
||||
MS_LOG(EXCEPTION) << "The size of input shape: " << input_shape.size()
|
||||
<< " and the size of output shape: " << output_shape.size() << " are different.";
|
||||
}
|
||||
if (input_dim > kArgMaxDim) {
|
||||
MS_LOG(EXCEPTION) << "Broadcast operation is not supported when dim exceeds than " << kArgMaxDim;
|
||||
}
|
||||
CheckTensorSize({input_shape, output_shape});
|
||||
return true;
|
||||
}
|
||||
|
||||
void InitAxis(const CNodePtr &kernel_node, const std::vector<size_t> &output_shape, int input_dim) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim->GetAttr("axis")->isa<ValueTuple>() || prim->GetAttr("axis")->isa<ValueList>()) {
|
||||
std::vector<int64_t> attr_axis = GetAttr<std::vector<int64_t>>(kernel_node, "axis");
|
||||
if (!attr_axis.empty()) {
|
||||
std::vector<int> attr_axis_int;
|
||||
(void)std::transform(attr_axis.begin(), attr_axis.end(), std::back_inserter(attr_axis_int),
|
||||
[](const int64_t &v) { return LongToInt(v); });
|
||||
for (const auto &v : attr_axis_int) {
|
||||
v < 0 ? axis_.emplace_back(v + input_dim) : axis_.emplace_back(v);
|
||||
}
|
||||
std::sort(axis_.begin(), axis_.end());
|
||||
auto multiple_ops = std::unique(axis_.begin(), axis_.end());
|
||||
(void)axis_.erase(multiple_ops, axis_.end());
|
||||
}
|
||||
} else if (prim->GetAttr("axis")->isa<Int64Imm>()) {
|
||||
int axis = LongToInt(GetAttr<int64_t>(kernel_node, "axis"));
|
||||
axis < 0 ? axis_.emplace_back(axis + input_dim) : axis_.emplace_back(axis);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The attribute axis type is invalid.";
|
||||
}
|
||||
|
||||
bool exceed_bound =
|
||||
std::any_of(axis_.begin(), axis_.end(), [&input_dim](const int &v) { return v < 0 || v >= input_dim; });
|
||||
if (exceed_bound) {
|
||||
MS_LOG(EXCEPTION) << "For 'ClipGradNormGpuKernel', the value of axis should be in range of [-" << input_dim
|
||||
<< ", " << (input_dim - 1) << "].";
|
||||
}
|
||||
}
|
||||
|
||||
void InitShapeInfo(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) {
|
||||
// Determine which dimension will be reduced.
|
||||
std::vector<size_t> reduce_output_shape = output_shape;
|
||||
if (axis_.empty()) {
|
||||
std::fill(reduce_output_shape.begin(), reduce_output_shape.end(), 1);
|
||||
} else {
|
||||
std::for_each(axis_.begin(), axis_.end(), [&reduce_output_shape](const int &v) { reduce_output_shape[v] = 1; });
|
||||
}
|
||||
// Whether is all matched.
|
||||
all_match_ = true;
|
||||
input_shape_.resize(kArgMaxDim, 1);
|
||||
output_shape_.resize(kArgMaxDim, 1);
|
||||
reduce_output_shape_.resize(kArgMaxDim, 1);
|
||||
for (size_t i = 0; i < output_shape.size(); ++i) {
|
||||
input_shape_[i] = input_shape[i];
|
||||
output_shape_[i] = output_shape[i];
|
||||
reduce_output_shape_[i] = reduce_output_shape[i];
|
||||
if (input_shape_[i] != reduce_output_shape_[i]) {
|
||||
all_match_ = false;
|
||||
}
|
||||
}
|
||||
// Infer input and output descriptor.
|
||||
InferInAndOutDesc(input_shape, reduce_output_shape);
|
||||
}
|
||||
|
||||
void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &reduce_output_shape) {
|
||||
constexpr size_t split_dim = 4;
|
||||
constexpr size_t dim_idx_two = 2;
|
||||
constexpr size_t dim_idx_three = 3;
|
||||
if (input_shape.size() <= split_dim) {
|
||||
std::vector<size_t> new_input_shape;
|
||||
ShapeNdTo4d(input_shape, &new_input_shape);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, data_type_, new_input_shape[0],
|
||||
new_input_shape[1], new_input_shape[dim_idx_two], new_input_shape[dim_idx_three]),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
} else {
|
||||
CudnnSetTensorNdDescriptor(input_shape, input_descriptor_, data_type_, kernel_node_);
|
||||
}
|
||||
if (reduce_output_shape.size() <= split_dim) {
|
||||
std::vector<size_t> new_reduce_output_shape;
|
||||
ShapeNdTo4d(reduce_output_shape, &new_reduce_output_shape);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, data_type_, new_reduce_output_shape[0],
|
||||
new_reduce_output_shape[1], new_reduce_output_shape[dim_idx_two],
|
||||
new_reduce_output_shape[dim_idx_three]),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
} else {
|
||||
CudnnSetTensorNdDescriptor(reduce_output_shape, output_descriptor_, data_type_, kernel_node_);
|
||||
}
|
||||
}
|
||||
|
||||
cudnnHandle_t cudnn_handle_;
|
||||
cudnnDataType_t data_type_;
|
||||
cudnnNanPropagation_t nan_prop_;
|
||||
cudnnReduceTensorIndices_t reduce_indices_;
|
||||
cudnnReduceTensorDescriptor_t reduce_tensor_descriptor_;
|
||||
cudnnTensorDescriptor_t input_descriptor_;
|
||||
cudnnTensorDescriptor_t output_descriptor_;
|
||||
|
||||
bool all_match_{false};
|
||||
bool is_null_input_{false};
|
||||
size_t x_size_;
|
||||
size_t clip_norm_size_;
|
||||
size_t scaling_factor_size_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
std::vector<int> axis_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
// Used for broadcast operation.
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
std::vector<size_t> reduce_output_shape_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CLIP_GRAD_NORM_GPU_KERNEL_H_
|
|
@ -1475,3 +1475,83 @@ class Cummin(Primitive):
|
|||
def __init__(self, axis):
|
||||
"""Initialize Cummin"""
|
||||
validator.check_value_type('axis', axis, [int], self.name)
|
||||
|
||||
|
||||
class ClipGradNorm(PrimitiveWithInfer):
|
||||
r"""
|
||||
Clips gradient tensor to a maximum :math:`L_2`-norm.
|
||||
|
||||
First, the input gradient tensor will be scaled as:
|
||||
|
||||
.. math::
|
||||
\text{output}(x) = x * \frac{1}{\text{scaling_factor}},
|
||||
|
||||
Then, the output gradient tensor remains the same if the :math:`L_2`-norm of the input gradient tensor
|
||||
is not greater than the argument clip_norm. Otherwise the input gradient tensor will be normalized as:
|
||||
|
||||
.. math::
|
||||
\text{output}(x) = \frac{x * \text{clip_norm}}{L_2(x)},
|
||||
|
||||
where :math:`L_2(x)` is the :math:`L_2`-norm of :math:`x`.
|
||||
|
||||
Args:
|
||||
axis (Union[None, int, tuple(int)]): Computes the L2-norm along the specific dimension.
|
||||
Default: None, all dimensions to calculate.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The gradient tensor of shape N-D. The type must be float32 or float16.
|
||||
- **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`. The type keeps same with 'x'
|
||||
- **scaling_factor** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`, whose type is float32.
|
||||
|
||||
Outputs:
|
||||
Tensor, clipped gradient tensor with the same shape as the `x`, whose type is float32.
|
||||
|
||||
Raises:
|
||||
TypeError: If `axis` is not one of None, int, and tuple(int).
|
||||
TypeError: If dtype of `x` is neither float16 nor float32.
|
||||
TypeError: If shape of 'clip_norm' is neither '()' nor '(1)'.
|
||||
TypeError: If shape of 'scaling_factor' is neither '()' nor '(1)'.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> clip_grad_norm = ops.ClipGradNorm()
|
||||
>>> x = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
|
||||
>>> clip_norm = Tensor(np.array([100]).astype(np.float32))
|
||||
>>> scaling_factor = Tensor(np.array([2]).astype(np.float32))
|
||||
>>> output = clip_grad_norm(x, clip_norm, scaling_factor)
|
||||
>>> print(output.shape)
|
||||
(4, 16)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, axis=None):
|
||||
"""Initialize ClipGradNorm"""
|
||||
if axis is None:
|
||||
axis = ()
|
||||
axis = (axis,) if isinstance(axis, int) else axis
|
||||
validator.check_value_type('axis', axis, [tuple], self.name)
|
||||
if axis:
|
||||
raise TypeError(f"For '{self.name}', the length of 'axis' only supports 0 now, but got {len(axis)}, "
|
||||
f"It will be supported later!")
|
||||
self.add_prim_attr('axis', axis)
|
||||
self.init_attrs['axis'] = axis
|
||||
self.axis = axis
|
||||
|
||||
def infer_shape(self, x_shape, clip_norm_shape, scaling_factor_shape):
|
||||
x_dim = len(x_shape)
|
||||
for _, value in enumerate(self.axis):
|
||||
validator.check_int_range(value, -x_dim, x_dim, Rel.INC_LEFT, 'axis value', self.name)
|
||||
clip_norm_dim = len(clip_norm_shape)
|
||||
validator.check_int_range(clip_norm_dim, 0, 1, Rel.INC_BOTH, 'clip norm dim', self.name)
|
||||
scaling_factor_dim = len(scaling_factor_shape)
|
||||
validator.check_int_range(scaling_factor_dim, 0, 1, Rel.INC_BOTH, 'scaling factor dim', self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type, clip_norm_type, scaling_factor_type):
|
||||
validator.check_tensor_dtype_valid("x_type", x_type, [mstype.float16, mstype.float32], self.name)
|
||||
validator.check_tensor_dtype_valid("clip_norm_type", clip_norm_type,
|
||||
[mstype.float16, mstype.float32], self.name)
|
||||
validator.check_tensor_dtype_valid("scaling_factor_type", scaling_factor_type, [mstype.float32], self.name)
|
||||
return mstype.float32
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _inner_ops
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_clip_grad_norm():
|
||||
"""
|
||||
Feature: ClipGradNorm Operation function verification.
|
||||
Description: The calculation results of 'ops.ClipGradNorm' should be same with the 'nn.ClipByNorm'.
|
||||
Expectation: Normal output without assert wrong.
|
||||
"""
|
||||
clip_norm = Tensor(np.array([1.0]), ms.float16)
|
||||
scaling_factor = Tensor(np.array([65536]), ms.float32)
|
||||
|
||||
# test input arg with shape(32, 3, 224, 224)
|
||||
x = np.random.rand(32, 3, 224, 224) * 100 - 50
|
||||
x = Tensor(x, ms.float16)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
scaling_out = x * P.Reciprocal()(scaling_factor)
|
||||
clip_by_norm_out = nn.ClipByNorm()(scaling_out, clip_norm)
|
||||
clip_grad_norm_out = _inner_ops.ClipGradNorm()(x, clip_norm, scaling_factor)
|
||||
assert np.allclose(clip_by_norm_out.asnumpy(), clip_grad_norm_out.asnumpy(), 0.00000001, 0.00000001)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
scaling_out = x * P.Reciprocal()(scaling_factor)
|
||||
clip_by_norm_out = nn.ClipByNorm()(scaling_out, clip_norm)
|
||||
clip_grad_norm_out = _inner_ops.ClipGradNorm()(x, clip_norm, scaling_factor)
|
||||
assert np.allclose(clip_by_norm_out.asnumpy(), clip_grad_norm_out.asnumpy(), 0.00000001, 0.00000001)
|
||||
|
||||
# test input arg with shape(60, 224, 224)
|
||||
x = np.random.rand(60, 224, 224) * 100 - 50
|
||||
x = Tensor(x, ms.float16)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
scaling_out = x * P.Reciprocal()(scaling_factor)
|
||||
clip_by_norm_out = nn.ClipByNorm()(scaling_out, clip_norm)
|
||||
clip_grad_norm_out = _inner_ops.ClipGradNorm()(x, clip_norm, scaling_factor)
|
||||
assert np.allclose(clip_by_norm_out.asnumpy(), clip_grad_norm_out.asnumpy(), 0.00000001, 0.00000001)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
scaling_out = x * P.Reciprocal()(scaling_factor)
|
||||
clip_by_norm_out = nn.ClipByNorm()(scaling_out, clip_norm)
|
||||
clip_grad_norm_out = _inner_ops.ClipGradNorm()(x, clip_norm, scaling_factor)
|
||||
assert np.allclose(clip_by_norm_out.asnumpy(), clip_grad_norm_out.asnumpy(), 0.00000001, 0.00000001)
|
||||
|
||||
# test input arg with shape(21128, 60)
|
||||
x = np.random.rand(21128, 60) * 100 - 50
|
||||
x = Tensor(x, ms.float16)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
scaling_out = x * P.Reciprocal()(scaling_factor)
|
||||
clip_by_norm_out = nn.ClipByNorm()(scaling_out, clip_norm)
|
||||
clip_grad_norm_out = _inner_ops.ClipGradNorm()(x, clip_norm, scaling_factor)
|
||||
assert np.allclose(clip_by_norm_out.asnumpy(), clip_grad_norm_out.asnumpy(), 0.00000001, 0.00000001)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
scaling_out = x * P.Reciprocal()(scaling_factor)
|
||||
clip_by_norm_out = nn.ClipByNorm()(scaling_out, clip_norm)
|
||||
clip_grad_norm_out = _inner_ops.ClipGradNorm()(x, clip_norm, scaling_factor)
|
||||
assert np.allclose(clip_by_norm_out.asnumpy(), clip_grad_norm_out.asnumpy(), 0.00000001, 0.00000001)
|
||||
|
||||
# test input args with shape(60)
|
||||
x = np.random.rand(60) * 100 - 50
|
||||
x = Tensor(x, ms.float16)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
scaling_out = x * P.Reciprocal()(scaling_factor)
|
||||
clip_by_norm_out = nn.ClipByNorm()(scaling_out, clip_norm)
|
||||
clip_grad_norm_out = _inner_ops.ClipGradNorm()(x, clip_norm, scaling_factor)
|
||||
assert np.allclose(clip_by_norm_out.asnumpy(), clip_grad_norm_out.asnumpy(), 0.00000001, 0.00000001)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
scaling_out = x * P.Reciprocal()(scaling_factor)
|
||||
clip_by_norm_out = nn.ClipByNorm()(scaling_out, clip_norm)
|
||||
clip_grad_norm_out = _inner_ops.ClipGradNorm()(x, clip_norm, scaling_factor)
|
||||
assert np.allclose(clip_by_norm_out.asnumpy(), clip_grad_norm_out.asnumpy(), 0.00000001, 0.00000001)
|
Loading…
Reference in New Issue