dtype
This commit is contained in:
parent
65ce2587b0
commit
f1d3852931
|
@ -72,6 +72,12 @@ bool KLDivLossGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
|
|||
|
||||
std::vector<std::pair<KernelAttr, KLDivLossGradGpuKernelMod::KLDivLossLaunchFunc>>
|
||||
KLDivLossGradGpuKernelMod::func_list_ = {
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&KLDivLossGradGpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
|
|
@ -88,3 +88,27 @@ def test_kl_div_loss_grad():
|
|||
-0.03094601, -0.14319494]
|
||||
|
||||
assert np.allclose(dx[0].asnumpy(), dx1_expect)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_kl_div_loss_grad_float64():
|
||||
"""
|
||||
Feature: Test KLDivLossGrad.
|
||||
Description: Test KLDivLossGrad op with float inputs.
|
||||
Expectation: The result match to expect.
|
||||
"""
|
||||
np.random.seed(42)
|
||||
prediction = np.random.rand(20).astype(np.float64)
|
||||
target = np.random.rand(20).astype(np.float64)
|
||||
sens = np.random.rand(20).astype(np.float64)
|
||||
grad = Grad(Net())
|
||||
dx = grad(Tensor(prediction), Tensor(target), Tensor(sens))
|
||||
|
||||
dx1_expect = [-0.07466945, -0.06907414, -0.01004642, -0.3331403, -0.11802178, -0.52019656,
|
||||
-0.06224053, -0.2674369, -0.32387912, -0.00858657, -0.58906615, -0.13217884,
|
||||
-0.06111591, -0.8490888, -0.57735133, -0.7452407, -0.02695603, -0.01914206,
|
||||
-0.03094601, -0.14319494]
|
||||
|
||||
assert np.allclose(dx[0].asnumpy(), dx1_expect)
|
||||
|
|
Loading…
Reference in New Issue