forked from mindspore-Ecosystem/mindspore
Fixed the Lamb cudnnReduceTensor error when the input shape is [1]
This commit is contained in:
parent
27d391ae66
commit
b8130e69aa
|
@ -281,31 +281,41 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
|
|||
MS_LOG(EXCEPTION) << "var_float or grad_float or g_hat_var is null";
|
||||
}
|
||||
|
||||
float *reduce_workspace_addr = GetPossiblyNullDeviceAddress<float>(workspaces, kReduceWorkspaceIndex);
|
||||
float *w_norm_ptr = GetDeviceAddress<float>(workspaces, kWNormIndex);
|
||||
float *g_norm_ptr = GetDeviceAddress<float>(workspaces, kGNormIndex);
|
||||
float *g_hat_norm_ptr = GetDeviceAddress<float>(workspaces, kGHatNormIndex);
|
||||
float *w_norm_ptr = nullptr;
|
||||
float *g_norm_ptr = nullptr;
|
||||
float *g_hat_norm_ptr = nullptr;
|
||||
|
||||
// Calc sum of square
|
||||
constexpr float alpha = 1;
|
||||
constexpr float beta = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
|
||||
workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, var_float, &beta,
|
||||
output_descriptor_, w_norm_ptr),
|
||||
"For " + kernel_name_ + " cudnnReduceTensor for 'var_float' failed");
|
||||
if (is_all_match_) {
|
||||
w_norm_ptr = var_float;
|
||||
g_norm_ptr = grad_float;
|
||||
g_hat_norm_ptr = g_hat_var;
|
||||
} else {
|
||||
float *reduce_workspace_addr = GetPossiblyNullDeviceAddress<float>(workspaces, kReduceWorkspaceIndex);
|
||||
w_norm_ptr = GetDeviceAddress<float>(workspaces, kWNormIndex);
|
||||
g_norm_ptr = GetDeviceAddress<float>(workspaces, kGNormIndex);
|
||||
g_hat_norm_ptr = GetDeviceAddress<float>(workspaces, kGHatNormIndex);
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
|
||||
workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, grad_float, &beta,
|
||||
output_descriptor_, g_norm_ptr),
|
||||
"For " + kernel_name_ + " cudnnReduceTensor for 'grad_float' failed");
|
||||
// Calc sum of square
|
||||
constexpr float alpha = 1;
|
||||
constexpr float beta = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
|
||||
workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, var_float, &beta,
|
||||
output_descriptor_, w_norm_ptr),
|
||||
"For " + kernel_name_ + " cudnnReduceTensor for 'var_float' failed");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
|
||||
workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, g_hat_var, &beta,
|
||||
output_descriptor_, g_hat_norm_ptr),
|
||||
"For " + kernel_name_ + " cudnnReduceTensor for 'g_hat_var' failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
|
||||
workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, grad_float, &beta,
|
||||
output_descriptor_, g_norm_ptr),
|
||||
"For " + kernel_name_ + " cudnnReduceTensor for 'grad_float' failed");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
|
||||
workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, g_hat_var, &beta,
|
||||
output_descriptor_, g_hat_norm_ptr),
|
||||
"For " + kernel_name_ + " cudnnReduceTensor for 'g_hat_var' failed");
|
||||
}
|
||||
|
||||
float w_norm = 0;
|
||||
float g_norm = 0;
|
||||
|
@ -331,9 +341,14 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
|
|||
|
||||
// Derives the shape of the reduction result and sets up the descriptors for it.
//
// input_shape:  shape of the tensor that will be reduced.
// output_shape: only its rank is used here; every extent is replaced by 1.
//
// Also records in is_all_match_ whether input_shape already equals that
// all-ones shape. Where this flag is consumed is outside this function.
void InitShapeInfo(const ShapeVector &input_shape, const ShapeVector &output_shape) {
  // Start from output_shape to keep its rank, then collapse every dimension to 1.
  ShapeVector reduce_output_shape = output_shape;
  for (auto &extent : reduce_output_shape) {
    extent = 1;
  }

  // True only when the input is already the fully reduced shape.
  is_all_match_ = (input_shape == reduce_output_shape);

  // Infer input and output descriptor.
  InferInAndOutDesc(input_shape, reduce_output_shape);
}
|
||||
|
@ -383,6 +398,7 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
|
|||
size_t trust_ratio_size_{0};
|
||||
size_t reduce_output_size_{0};
|
||||
bool is_null_input_{false};
|
||||
bool is_all_match_{false};
|
||||
|
||||
cudnnHandle_t cudnn_handle_{nullptr};
|
||||
cudnnDataType_t data_type_{CUDNN_DATA_FLOAT};
|
||||
|
|
Loading…
Reference in New Issue