!48195 fix_apply_proximal_grad_descent_st_master
Merge pull request !48195 from yide12/apply_proximal_gradient_descent_master
This commit is contained in:
commit
57669ee207
|
@ -66,7 +66,7 @@ __global__ void CalApplyProximalGradientDescentKernel(const size_t input_element
|
|||
pos += gridDim.x * blockDim.x) {
|
||||
auto prox_v = var[pos];
|
||||
prox_v -= delta[pos] * alpha[0];
|
||||
output[pos] = SgnFunc(prox_v) * MaxFunc(AbsFunc(prox_v) - alpha[0] * l1[0], static_cast<T>(0.0)) /
|
||||
var[pos] = SgnFunc(prox_v) * MaxFunc(AbsFunc(prox_v) - alpha[0] * l1[0], static_cast<T>(0.0)) /
|
||||
(static_cast<T>(1) + l2[0] * alpha[0]);
|
||||
}
|
||||
} else {
|
||||
|
@ -74,7 +74,7 @@ __global__ void CalApplyProximalGradientDescentKernel(const size_t input_element
|
|||
pos += gridDim.x * blockDim.x) {
|
||||
auto prox_v = var[pos];
|
||||
prox_v -= delta[pos] * alpha[0];
|
||||
output[pos] = prox_v / (static_cast<T>(1) + l2[0] * alpha[0]);
|
||||
var[pos] = prox_v / (static_cast<T>(1) + l2[0] * alpha[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -131,7 +131,8 @@ std::vector<std::pair<KernelAttr, ApplyProximalGradientDescentGpuKernelMod::Kern
|
|||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutInRef(0, 0),
|
||||
&ApplyProximalGradientDescentGpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
|
@ -139,7 +140,8 @@ std::vector<std::pair<KernelAttr, ApplyProximalGradientDescentGpuKernelMod::Kern
|
|||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutInRef(0, 0),
|
||||
&ApplyProximalGradientDescentGpuKernelMod::LaunchKernel<half>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
|
@ -147,7 +149,8 @@ std::vector<std::pair<KernelAttr, ApplyProximalGradientDescentGpuKernelMod::Kern
|
|||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutInRef(0, 0),
|
||||
&ApplyProximalGradientDescentGpuKernelMod::LaunchKernel<double>}};
|
||||
|
||||
std::vector<KernelAttr> ApplyProximalGradientDescentGpuKernelMod::GetOpSupport() {
|
||||
|
|
|
@ -482,7 +482,7 @@ class ApplyProximalGradientDescentNet(nn.Cell):
|
|||
return self.var
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
|
Loading…
Reference in New Issue