!48195 fix_apply_proximal_grad_descent_st_master

Merge pull request !48195 from yide12/apply_proximal_gradient_descent_master
This commit is contained in:
i-robot 2023-02-01 10:04:26 +00:00 committed by Gitee
commit 57669ee207
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 9 additions and 6 deletions

View File

@ -66,7 +66,7 @@ __global__ void CalApplyProximalGradientDescentKernel(const size_t input_element
pos += gridDim.x * blockDim.x) {
auto prox_v = var[pos];
prox_v -= delta[pos] * alpha[0];
output[pos] = SgnFunc(prox_v) * MaxFunc(AbsFunc(prox_v) - alpha[0] * l1[0], static_cast<T>(0.0)) /
var[pos] = SgnFunc(prox_v) * MaxFunc(AbsFunc(prox_v) - alpha[0] * l1[0], static_cast<T>(0.0)) /
(static_cast<T>(1) + l2[0] * alpha[0]);
}
} else {
@ -74,7 +74,7 @@ __global__ void CalApplyProximalGradientDescentKernel(const size_t input_element
pos += gridDim.x * blockDim.x) {
auto prox_v = var[pos];
prox_v -= delta[pos] * alpha[0];
output[pos] = prox_v / (static_cast<T>(1) + l2[0] * alpha[0]);
var[pos] = prox_v / (static_cast<T>(1) + l2[0] * alpha[0]);
}
}
}

View File

@ -131,7 +131,8 @@ std::vector<std::pair<KernelAttr, ApplyProximalGradientDescentGpuKernelMod::Kern
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
.AddOutputAttr(kNumberTypeFloat32)
.AddOutInRef(0, 0),
&ApplyProximalGradientDescentGpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
@ -139,7 +140,8 @@ std::vector<std::pair<KernelAttr, ApplyProximalGradientDescentGpuKernelMod::Kern
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
.AddOutputAttr(kNumberTypeFloat16)
.AddOutInRef(0, 0),
&ApplyProximalGradientDescentGpuKernelMod::LaunchKernel<half>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
@ -147,7 +149,8 @@ std::vector<std::pair<KernelAttr, ApplyProximalGradientDescentGpuKernelMod::Kern
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
.AddOutputAttr(kNumberTypeFloat64)
.AddOutInRef(0, 0),
&ApplyProximalGradientDescentGpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> ApplyProximalGradientDescentGpuKernelMod::GetOpSupport() {

View File

@ -482,7 +482,7 @@ class ApplyProximalGradientDescentNet(nn.Cell):
return self.var
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training