dtype
This commit is contained in:
parent
65ce2587b0
commit
f1d3852931
|
@ -72,6 +72,12 @@ bool KLDivLossGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
|
||||||
|
|
||||||
std::vector<std::pair<KernelAttr, KLDivLossGradGpuKernelMod::KLDivLossLaunchFunc>>
|
std::vector<std::pair<KernelAttr, KLDivLossGradGpuKernelMod::KLDivLossLaunchFunc>>
|
||||||
KLDivLossGradGpuKernelMod::func_list_ = {
|
KLDivLossGradGpuKernelMod::func_list_ = {
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat64)
|
||||||
|
.AddInputAttr(kNumberTypeFloat64)
|
||||||
|
.AddInputAttr(kNumberTypeFloat64)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat64),
|
||||||
|
&KLDivLossGradGpuKernelMod::LaunchKernel<double>},
|
||||||
{KernelAttr()
|
{KernelAttr()
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
|
|
@ -88,3 +88,27 @@ def test_kl_div_loss_grad():
|
||||||
-0.03094601, -0.14319494]
|
-0.03094601, -0.14319494]
|
||||||
|
|
||||||
assert np.allclose(dx[0].asnumpy(), dx1_expect)
|
assert np.allclose(dx[0].asnumpy(), dx1_expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_kl_div_loss_grad_float64():
|
||||||
|
"""
|
||||||
|
Feature: Test KLDivLossGrad.
|
||||||
|
Description: Test KLDivLossGrad op with float inputs.
|
||||||
|
Expectation: The result match to expect.
|
||||||
|
"""
|
||||||
|
np.random.seed(42)
|
||||||
|
prediction = np.random.rand(20).astype(np.float64)
|
||||||
|
target = np.random.rand(20).astype(np.float64)
|
||||||
|
sens = np.random.rand(20).astype(np.float64)
|
||||||
|
grad = Grad(Net())
|
||||||
|
dx = grad(Tensor(prediction), Tensor(target), Tensor(sens))
|
||||||
|
|
||||||
|
dx1_expect = [-0.07466945, -0.06907414, -0.01004642, -0.3331403, -0.11802178, -0.52019656,
|
||||||
|
-0.06224053, -0.2674369, -0.32387912, -0.00858657, -0.58906615, -0.13217884,
|
||||||
|
-0.06111591, -0.8490888, -0.57735133, -0.7452407, -0.02695603, -0.01914206,
|
||||||
|
-0.03094601, -0.14319494]
|
||||||
|
|
||||||
|
assert np.allclose(dx[0].asnumpy(), dx1_expect)
|
||||||
|
|
Loading…
Reference in New Issue