!48961 fixed lamb cudnnReduceTensor error when shape is [1]

Merge pull request !48961 from wanghenchang/master_0211
i-robot 2023-02-17 02:46:18 +00:00 committed by Gitee
commit 7cf3697fa6
1 changed file with 38 additions and 22 deletions
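
This change addresses a failure in LambGpuKernelMod: the squared norms of var_float, grad_float, and g_hat_var are computed by reducing each tensor to a single element with cudnnReduceTensor, but when a parameter's shape is [1] the reduce output shape (every dimension set to 1) is identical to the input shape, and the cuDNN call fails instead of degenerating to a copy. The fix records this case as is_all_match_ in InitShapeInfo and, when it holds, aliases the norm pointers directly to the input buffers so cuDNN is never called.

The sketch below illustrates the same guard outside MindSpore. It is a minimal example under stated assumptions, not the kernel's code: the ReduceSum name and CHECK_CUDNN macro are invented here, and only standard cuDNN entry points (cudnnSetReduceTensorDescriptor, cudnnGetReductionWorkspaceSize, cudnnReduceTensor) are used. With an additive reduction, reducing a one-element tensor is the identity, so returning the input pointer unchanged is exact.

#include <cudnn.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CHECK_CUDNN(expr)                                                           \
  do {                                                                              \
    cudnnStatus_t status_ = (expr);                                                 \
    if (status_ != CUDNN_STATUS_SUCCESS) {                                          \
      std::fprintf(stderr, "%s failed: %s\n", #expr, cudnnGetErrorString(status_)); \
      std::exit(EXIT_FAILURE);                                                      \
    }                                                                               \
  } while (0)

// Reduce an n-element device vector to a scalar sum. Returns the device pointer
// holding the result: d_out in the general case, or d_in itself when n == 1,
// mirroring the is_all_match_ short-circuit introduced by this commit.
const float *ReduceSum(cudnnHandle_t handle, const float *d_in, float *d_out, int n) {
  if (n == 1) {
    return d_in;  // [1] -> [1] reduction is a no-op; skip cudnnReduceTensor entirely
  }

  cudnnReduceTensorDescriptor_t reduce_desc;
  cudnnTensorDescriptor_t in_desc, out_desc;
  CHECK_CUDNN(cudnnCreateReduceTensorDescriptor(&reduce_desc));
  CHECK_CUDNN(cudnnSetReduceTensorDescriptor(reduce_desc, CUDNN_REDUCE_TENSOR_ADD,
                                             CUDNN_DATA_FLOAT, CUDNN_PROPAGATE_NAN,
                                             CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES));
  CHECK_CUDNN(cudnnCreateTensorDescriptor(&in_desc));
  CHECK_CUDNN(cudnnCreateTensorDescriptor(&out_desc));
  // Describe the vector as an NCHW tensor [1, 1, 1, n] reduced to [1, 1, 1, 1].
  CHECK_CUDNN(cudnnSetTensor4dDescriptor(in_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 1, 1, n));
  CHECK_CUDNN(cudnnSetTensor4dDescriptor(out_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 1, 1, 1));

  size_t ws_bytes = 0;
  CHECK_CUDNN(cudnnGetReductionWorkspaceSize(handle, reduce_desc, in_desc, out_desc, &ws_bytes));
  void *d_ws = nullptr;
  if (ws_bytes > 0) {
    cudaMalloc(&d_ws, ws_bytes);
  }

  const float alpha = 1.0f, beta = 0.0f;
  CHECK_CUDNN(cudnnReduceTensor(handle, reduce_desc, nullptr, 0, d_ws, ws_bytes, &alpha, in_desc,
                                d_in, &beta, out_desc, d_out));

  if (d_ws != nullptr) {
    cudaFree(d_ws);
  }
  CHECK_CUDNN(cudnnDestroyTensorDescriptor(out_desc));
  CHECK_CUDNN(cudnnDestroyTensorDescriptor(in_desc));
  CHECK_CUDNN(cudnnDestroyReduceTensorDescriptor(reduce_desc));
  return d_out;
}

A caller creates the handle once with cudnnCreate, allocates d_in and d_out with cudaMalloc, and reads the scalar back with cudaMemcpy. Because the returned pointer may alias d_in, callers must use the return value rather than assuming d_out was written — which is why the kernel in the diff below switches from fixed workspace pointers to nullptr-initialized ones that are assigned per case.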


@@ -281,31 +281,41 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
       MS_LOG(EXCEPTION) << "var_float or grad_float or g_hat_var is null";
     }
-    float *reduce_workspace_addr = GetPossiblyNullDeviceAddress<float>(workspaces, kReduceWorkspaceIndex);
-    float *w_norm_ptr = GetDeviceAddress<float>(workspaces, kWNormIndex);
-    float *g_norm_ptr = GetDeviceAddress<float>(workspaces, kGNormIndex);
-    float *g_hat_norm_ptr = GetDeviceAddress<float>(workspaces, kGHatNormIndex);
+    float *w_norm_ptr = nullptr;
+    float *g_norm_ptr = nullptr;
+    float *g_hat_norm_ptr = nullptr;
-    // Calc sum of square
-    constexpr float alpha = 1;
-    constexpr float beta = 0;
-    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
-      cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
-                        workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, var_float, &beta,
-                        output_descriptor_, w_norm_ptr),
-      "For " + kernel_name_ + " cudnnReduceTensor for 'var_float' failed");
+    if (is_all_match_) {
+      w_norm_ptr = var_float;
+      g_norm_ptr = grad_float;
+      g_hat_norm_ptr = g_hat_var;
+    } else {
+      float *reduce_workspace_addr = GetPossiblyNullDeviceAddress<float>(workspaces, kReduceWorkspaceIndex);
+      w_norm_ptr = GetDeviceAddress<float>(workspaces, kWNormIndex);
+      g_norm_ptr = GetDeviceAddress<float>(workspaces, kGNormIndex);
+      g_hat_norm_ptr = GetDeviceAddress<float>(workspaces, kGHatNormIndex);
-    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
-      cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
-                        workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, grad_float, &beta,
-                        output_descriptor_, g_norm_ptr),
-      "For " + kernel_name_ + " cudnnReduceTensor for 'grad_float' failed");
+      // Calc sum of square
+      constexpr float alpha = 1;
+      constexpr float beta = 0;
+      CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+        cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
+                          workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, var_float, &beta,
+                          output_descriptor_, w_norm_ptr),
+        "For " + kernel_name_ + " cudnnReduceTensor for 'var_float' failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
-      cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
-                        workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, g_hat_var, &beta,
-                        output_descriptor_, g_hat_norm_ptr),
-      "For " + kernel_name_ + " cudnnReduceTensor for 'g_hat_var' failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+        cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
+                          workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, grad_float, &beta,
+                          output_descriptor_, g_norm_ptr),
+        "For " + kernel_name_ + " cudnnReduceTensor for 'grad_float' failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+        cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
+                          workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, g_hat_var, &beta,
+                          output_descriptor_, g_hat_norm_ptr),
+        "For " + kernel_name_ + " cudnnReduceTensor for 'g_hat_var' failed");
+    }
     float w_norm = 0;
     float g_norm = 0;
@@ -331,9 +341,14 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
   void InitShapeInfo(const ShapeVector &input_shape, const ShapeVector &output_shape) {
     // Determine which dimension will be reduced.
+    is_all_match_ = false;
     ShapeVector reduce_output_shape = output_shape;
     std::fill(reduce_output_shape.begin(), reduce_output_shape.end(), 1);
+    if (input_shape == reduce_output_shape) {
+      is_all_match_ = true;
+    }
     // Infer input and output descriptor.
     InferInAndOutDesc(input_shape, reduce_output_shape);
   }
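
For concreteness, the detection added in the hunk above compares the input shape against an all-ones shape of the output's rank: for a parameter of shape [1], reduce_output_shape becomes {1}, which equals input_shape, so is_all_match_ is set and the cuDNN reduction path is skipped. A standalone restatement of the check (the IsAllMatch helper name is hypothetical, not from the source):

#include <algorithm>
#include <cstdint>
#include <vector>

using ShapeVector = std::vector<int64_t>;

// True exactly when reducing input_shape down to an all-ones shape is a no-op
// (e.g. input_shape == {1}), i.e. when cudnnReduceTensor can be bypassed.
bool IsAllMatch(const ShapeVector &input_shape, const ShapeVector &output_shape) {
  ShapeVector reduce_output_shape = output_shape;
  std::fill(reduce_output_shape.begin(), reduce_output_shape.end(), 1);
  return input_shape == reduce_output_shape;
}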
@@ -383,6 +398,7 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
   size_t trust_ratio_size_{0};
   size_t reduce_output_size_{0};
   bool is_null_input_{false};
+  bool is_all_match_{false};
   cudnnHandle_t cudnn_handle_{nullptr};
   cudnnDataType_t data_type_{CUDNN_DATA_FLOAT};