forked from mindspore-Ecosystem/mindspore
!37716 [MS][ASCED]add smooth_l1_loss ai core
Merge pull request !37716 from mengyuanli/smooth_l1_loss_ascend
This commit is contained in:
commit
641642b9e2
|
@ -26,7 +26,7 @@ mindspore.nn.SmoothL1Loss
|
|||
其中,:math:`{\beta}` 代表阈值 `beta` 。
|
||||
|
||||
.. note::
|
||||
- 在Ascend上, 目前不支持将 `reduction` 设定成'sum'或'mean'。
|
||||
- 在Ascend上,目前不支持 `logits` 的数据类型是float64。
|
||||
- SmoothL1Loss可以看成 :class:`mindspore.nn.L1Loss` 的修改版本,也可以看成 :class:`mindspore.nn.L1Loss` 和 :class:`mindspore.ops.L2Loss` 的组合。
|
||||
- :class:`mindspore.nn.L1Loss` 计算两个输入Tensor之间的绝对误差,而 :class:`mindspore.ops.L2Loss` 计算两个输入Tensor之间的平方误差。
|
||||
- :class:`mindspore.ops.L2Loss` 通常更快收敛,但对离群值的鲁棒性较差。该损失函数具有较好的鲁棒性。
|
||||
|
@ -36,7 +36,7 @@ mindspore.nn.SmoothL1Loss
|
|||
- **reduction** (str) - 缩减输出的方法。默认值:'none'。其他选项:'mean'和'sum'。
|
||||
|
||||
输入:
|
||||
- **logits** (Tensor) - 预测值,任意维度Tensor。数据类型为float16、float32或float64。
|
||||
- **logits** (Tensor) - 预测值,任意维度Tensor。数据类型为float16或float32, CPU和GPU后端还支持float64。
|
||||
- **labels** (Tensor) - 目标值,数据类型和shape与 `logits` 相同的Tensor。
|
||||
|
||||
输出:
|
||||
|
@ -48,5 +48,6 @@ mindspore.nn.SmoothL1Loss
|
|||
- **TypeError** - `logits` 或 `labels` 不是Tensor。
|
||||
- **TypeError** - `logits` 或 `labels` 的数据类型不是float16,float32和float64中的任一者。
|
||||
- **TypeError** - `logits` 的数据类型与 `labels` 不同。
|
||||
- **ValueError** - `beta` 小于或等于0。
|
||||
- **ValueError** - `beta` 小于0。
|
||||
- **ValueError** - `logits` 的shape与 `labels` 不同。
|
||||
- **TypeError** - Ascend后端不支持数据类型是float64的 `logits` 输入。
|
||||
|
|
|
@ -28,10 +28,10 @@ mindspore.ops.smooth_l1_loss
|
|||
其中, :math:`\beta` 代表阈值 `beta` 。 :math:`N` 为batch size。
|
||||
|
||||
.. note::
|
||||
在Ascend上,目前不支持将 `reduction` 设定成'sum'或'mean'。
|
||||
在Ascend上,目前不支持 `logits` 的数据类型是float64。
|
||||
|
||||
参数:
|
||||
- **logits** (Tensor) - shape: :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度。数据类型支持float16、float32或float64。
|
||||
- **logits** (Tensor) - shape: :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度。数据类型为float16或float32, CPU和GPU后端还支持float64。
|
||||
- **labels** (Tensor) - shape: :math:`(N, *)` ,与 `logits` 的shape和数据类型相同。
|
||||
- **beta** (float) - 控制损失函数在L1Loss和L2Loss间变换的阈值。默认值:1.0。
|
||||
- **reduction** (str) - 缩减输出的方法。默认值:'none'。 其他选项:'mean'和'sum'。
|
||||
|
@ -43,5 +43,6 @@ mindspore.ops.smooth_l1_loss
|
|||
- **TypeError** - `beta` 不是float类型。
|
||||
- **ValueError** - `reduction` 不是'none','mean'和'sum'中的任一者。
|
||||
- **TypeError** - `logits` 或 `labels` 的数据类型不是float16,float32和float64中的任一者。
|
||||
- **ValueError** - `beta` 小于或等于0。
|
||||
- **ValueError** - `beta` 小于0。
|
||||
- **ValueError** - `logits` 与 `labels` 的shape不同。
|
||||
- **TypeError** - Ascend后端不支持数据类型是float64的 `logits` 输入。
|
||||
|
|
|
@ -481,7 +481,7 @@ class SmoothL1Loss(LossBase):
|
|||
\end{cases}
|
||||
|
||||
.. note::
|
||||
For Ascend platform, the 'reduction' is not support set to 'sum' or 'mean'.
|
||||
For Ascend platform, the float64 data type of `logits` is not support now.
|
||||
SmoothL1Loss can be regarded as modified version of L1Loss or a combination of L1Loss and L2Loss.
|
||||
L1Loss computes the element-wise absolute difference between two input tensors while L2Loss computes the
|
||||
squared difference between two input tensors. L2Loss often leads to faster convergence but it is less
|
||||
|
@ -510,6 +510,7 @@ class SmoothL1Loss(LossBase):
|
|||
TypeError: If dtype of `logits` is not the same as `labels`.
|
||||
ValueError: If `beta` is less than or equal to 0.
|
||||
ValueError: If shape of `logits` is not the same as `labels`.
|
||||
ValueError: The float64 data type of `logits` is support on Ascend platform.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
@ -526,17 +527,11 @@ class SmoothL1Loss(LossBase):
|
|||
def __init__(self, beta=1.0, reduction='none'):
|
||||
"""Initialize SmoothL1Loss."""
|
||||
super(SmoothL1Loss, self).__init__(reduction)
|
||||
target = context.get_context("device_target")
|
||||
if reduction != 'none' and target.lower() == "ascend":
|
||||
raise ValueError(f"Currently Ascend device_target only support `reduction`='none', "
|
||||
f"but got {reduction}")
|
||||
self.beta = beta
|
||||
self.reduction = reduction
|
||||
self.smooth_l1_loss = P.SmoothL1Loss(self.beta, self.reduction)
|
||||
|
||||
def construct(self, logits, labels):
|
||||
_check_is_tensor('logits', logits, self.cls_name)
|
||||
_check_is_tensor('labels', labels, self.cls_name)
|
||||
return self.smooth_l1_loss(logits, labels)
|
||||
|
||||
|
||||
|
|
|
@ -19,11 +19,12 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|||
smooth_l1_loss_op_info = TBERegOp("SmoothL1Loss") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("smooth_l1_loss.so") \
|
||||
.binfile_name("smooth_l1_loss_v2.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("smooth_l1_loss") \
|
||||
.kernel_name("smooth_l1_loss_v2") \
|
||||
.partial_flag(True) \
|
||||
.attr("beta", "required", "float", "all") \
|
||||
.attr("beta", "optional", "float", "all") \
|
||||
.attr("reduction", "optional", "str", "all") \
|
||||
.input(0, "predict", False, "required", "all") \
|
||||
.input(1, "label", False, "required", "all") \
|
||||
.output(0, "loss", False, "required", "all") \
|
||||
|
|
|
@ -19,12 +19,13 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|||
smooth_l1_loss_op_info = TBERegOp("SmoothL1Loss") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("smooth_l1_loss.so") \
|
||||
.binfile_name("smooth_l1_loss_v2.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("smooth_l1_loss") \
|
||||
.kernel_name("smooth_l1_loss_v2") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_shape(True) \
|
||||
.attr("beta", "required", "float", "all") \
|
||||
.attr("beta", "optional", "float", "all") \
|
||||
.attr("reduction", "optional", "str", "all") \
|
||||
.input(0, "predict", False, "required", "all") \
|
||||
.input(1, "label", False, "required", "all") \
|
||||
.output(0, "loss", False, "required", "all") \
|
||||
|
|
|
@ -19,11 +19,12 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|||
smooth_l1_loss_grad_op_info = TBERegOp("SmoothL1LossGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("smooth_l1_loss_grad.so") \
|
||||
.binfile_name("smooth_l1_loss_grad_v2.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("smooth_l1_loss_grad") \
|
||||
.kernel_name("smooth_l1_loss_grad_v2") \
|
||||
.partial_flag(True) \
|
||||
.attr("beta", "required", "float", "all") \
|
||||
.attr("beta", "optional", "float", "all") \
|
||||
.attr("reduction", "optional", "str", "all") \
|
||||
.input(0, "predict", False, "required", "all") \
|
||||
.input(1, "label", False, "required", "all") \
|
||||
.input(2, "dout", False, "required", "all") \
|
||||
|
|
|
@ -19,12 +19,13 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|||
smooth_l1_loss_grad_op_info = TBERegOp("SmoothL1LossGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("smooth_l1_loss_grad.so") \
|
||||
.binfile_name("smooth_l1_loss_grad_v2.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("smooth_l1_loss_grad") \
|
||||
.kernel_name("smooth_l1_loss_grad_v2") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_shape(True) \
|
||||
.attr("beta", "required", "float", "all") \
|
||||
.attr("beta", "optional", "float", "all") \
|
||||
.attr("reduction", "optional", "str", "all") \
|
||||
.input(0, "predict", False, "required", "all") \
|
||||
.input(1, "label", False, "required", "all") \
|
||||
.input(2, "dout", False, "required", "all") \
|
||||
|
|
|
@ -1548,7 +1548,7 @@ def smooth_l1_loss(logits, labels, beta=1.0, reduction='none'):
|
|||
Its default value is 1.0. :math:`N` is the batch size.
|
||||
|
||||
Note:
|
||||
For Ascend platform, the 'reduction' is not support set to 'sum' or 'mean' for now.
|
||||
For Ascend platform, the float64 data type of `logits` is not support now.
|
||||
|
||||
Args:
|
||||
logits (Tensor): Tensor of shape :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
|
||||
|
@ -1567,6 +1567,7 @@ def smooth_l1_loss(logits, labels, beta=1.0, reduction='none'):
|
|||
TypeError: If dtype of `logits` or `labels` is neither float16 nor float32.
|
||||
ValueError: If `beta` is less than or equal to 0.
|
||||
ValueError: If shape of `logits` is not the same as `labels`.
|
||||
TypeError: The float64 data type of `logits` is support on Ascend platform.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
|
|
@ -2957,10 +2957,6 @@ class SmoothL1Loss(Primitive):
|
|||
validator.check_string(
|
||||
reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
|
||||
self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])
|
||||
target = context.get_context("device_target")
|
||||
if reduction != 'none' and target.lower() == "ascend":
|
||||
raise ValueError(f"Currently Ascend device_target only support `reduction`='none', "
|
||||
f"but got {reduction}")
|
||||
|
||||
|
||||
class MultiMarginLoss(Primitive):
|
||||
|
|
|
@ -14,31 +14,62 @@
|
|||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, sigma=1.0):
|
||||
super(Net, self).__init__()
|
||||
self.SmoothL1Loss = P.SmoothL1Loss(sigma)
|
||||
def smoothl1loss(beta, reduction):
|
||||
np.random.seed(42)
|
||||
prediction = np.random.randn(20).astype(np.float32)
|
||||
target = np.random.randn(20).astype(np.float32)
|
||||
|
||||
def construct(self, pred, gt):
|
||||
return self.SmoothL1Loss(pred, gt)
|
||||
net = nn.SmoothL1Loss(beta, reduction)
|
||||
return net(Tensor(prediction), Tensor(target))
|
||||
|
||||
|
||||
def test_net():
|
||||
pred = np.random.randn(2, 4).astype(np.float32)
|
||||
gt = np.random.randn(2, 4).astype(np.float32)
|
||||
smooth_l1_loss = Net()
|
||||
loss = smooth_l1_loss(Tensor(pred), Tensor(gt))
|
||||
print("------------- input ---------------")
|
||||
print("predict:\n", pred)
|
||||
print("grount truth:\n", gt)
|
||||
print("------------- output ---------------")
|
||||
print("loss:\n", loss.asnumpy())
|
||||
def verify_forward(reduction, loss, expect):
|
||||
if reduction == 'none':
|
||||
np.testing.assert_array_almost_equal(loss, expect)
|
||||
elif reduction == "sum":
|
||||
expect_sum = np.sum(expect)
|
||||
np.testing.assert_array_almost_equal(loss, expect_sum, decimal=5)
|
||||
elif reduction == "mean":
|
||||
expect_mean = np.mean(expect)
|
||||
np.testing.assert_array_almost_equal(loss, expect_mean)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize("reduction", ['none', 'mean', 'sum'])
|
||||
def test_smoothl1loss(reduction):
|
||||
"""
|
||||
Feature: SmoothL1Loss cpu kernel.
|
||||
Description: test the rightness of SmoothL1Loss cpu kernel.
|
||||
Expectation: the output is same as expect.
|
||||
"""
|
||||
|
||||
beta = 1.0
|
||||
loss = smoothl1loss(beta, reduction)
|
||||
expect = np.array([0.46941718, 0.00382918, 0.16829303, 2.447778, 0.04812113, 0.05953304,
|
||||
2.2302065, 0.07672881, 0.00860204, 0.34798968, 0.00956192, 1.818008,
|
||||
0.03262977, 0.36599946, 2.047463, 0.2168481, 0.7216947, 1.7739174,
|
||||
0.08826803, 1.109165])
|
||||
|
||||
verify_forward(reduction, loss.asnumpy(), expect)
|
||||
|
||||
beta = 1 / 9
|
||||
loss = smoothl1loss(beta, reduction)
|
||||
expect = np.array([0.9133791, 0.03446258, 0.5246048, 2.8922224, 0.2546738, 0.289504,
|
||||
2.674651, 0.33618113, 0.07560876, 0.7786982, 0.08273339, 2.2624524,
|
||||
0.19990394, 0.8000138, 2.4919074, 0.6030006, 1.1661391, 2.2183619,
|
||||
0.3646064, 1.5536094])
|
||||
|
||||
verify_forward(reduction, loss.asnumpy(), expect)
|
||||
|
|
|
@ -14,44 +14,118 @@
|
|||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.composite import GradOperation
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, sigma=1.0):
|
||||
super(Net, self).__init__()
|
||||
self.SmoothL1Loss = P.SmoothL1Loss(sigma)
|
||||
|
||||
def construct(self, pred, gt):
|
||||
return self.SmoothL1Loss(pred, gt)
|
||||
|
||||
|
||||
class Grad(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(Grad, self).__init__()
|
||||
self.grad = GradOperation(get_all=True, sens_param=True)
|
||||
self.network = network
|
||||
|
||||
def construct(self, pred, gt, dout):
|
||||
return self.grad(self.network)(pred, gt, dout)
|
||||
def construct(self, x1, x2, sens):
|
||||
gout = self.grad(self.network)(x1, x2, sens)
|
||||
return gout
|
||||
|
||||
|
||||
def test_net():
|
||||
pred = np.random.randn(2, 4).astype(np.float32)
|
||||
gt = np.random.randn(2, 4).astype(np.float32)
|
||||
dout = np.random.randn(2, 4).astype(np.float32)
|
||||
smooth_l1_loss_grad = Grad(Net())
|
||||
output = smooth_l1_loss_grad(Tensor(pred), Tensor(gt), Tensor(dout))
|
||||
print("------------- input ---------------")
|
||||
print("predict:\n", pred)
|
||||
print("grount truth:\n", gt)
|
||||
print("dout:\n", dout)
|
||||
print("------------- output ---------------")
|
||||
print("predict grad:\n", output[0].asnumpy())
|
||||
def smoothl1loss_grad(beta):
|
||||
np.random.seed(42)
|
||||
prediction = np.random.randn(20).astype(np.float32)
|
||||
target = np.random.randn(20).astype(np.float32)
|
||||
sens = np.random.randn(20).astype(np.float32)
|
||||
|
||||
net = nn.SmoothL1Loss(beta)
|
||||
grad = Grad(net)
|
||||
return grad(Tensor(prediction), Tensor(target), Tensor(sens))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_smoothl1loss_grad_no_reduce():
|
||||
"""
|
||||
Feature: SmoothL1LossGrad cpu kernel.
|
||||
Description: test the rightness of SmoothL1LossGrad cpu kernel.
|
||||
Expectation: the output is same as expect.
|
||||
"""
|
||||
|
||||
epsilon = 1e-6
|
||||
|
||||
beta = 1.0
|
||||
dx = smoothl1loss_grad(beta)
|
||||
dx1_expect = np.array([-0.71552587, 0.01499678, -0.06709455, -0.30110368, -0.45868093,
|
||||
0.24838912, -0.46063876, 0.41411355, 0.04507046, -1.4708229,
|
||||
0.04481723, 0.38508227, -0.17292616, -0.52333146, -1.0309995,
|
||||
0.61330026, 0.83921754, -0.3092124, 0.1391843, -0.9755451], dtype=np.float32)
|
||||
|
||||
dx2_expect = -dx1_expect
|
||||
|
||||
diff1 = np.absolute(dx[0].asnumpy() - dx1_expect)
|
||||
diff2 = np.absolute(dx[1].asnumpy() - dx2_expect)
|
||||
assert(diff1 < epsilon).all()
|
||||
assert(diff2 < epsilon).all()
|
||||
|
||||
beta = 1 / 9
|
||||
dx = smoothl1loss_grad(beta)
|
||||
dx1_expect = np.array([-0.73846656, 0.13497104, -0.11564828, -0.30110368, -1.478522,
|
||||
0.7198442, -0.46063876, 1.0571222, 0.3436183, -1.7630402,
|
||||
0.32408398, 0.38508227, -0.676922, -0.6116763, -1.0309995,
|
||||
0.93128014, 0.83921754, -0.3092124, 0.33126342, -0.9755451], dtype=np.float32)
|
||||
|
||||
dx2_expect = -dx1_expect
|
||||
|
||||
diff1 = np.absolute(dx[0].asnumpy() - np.array(dx1_expect))
|
||||
diff2 = np.absolute(dx[1].asnumpy() - np.array(dx2_expect))
|
||||
assert(diff1 < epsilon).all()
|
||||
assert(diff2 < epsilon).all()
|
||||
|
||||
|
||||
def smoothl1loss_grad_2(beta, reduction):
|
||||
prediction = np.array([1, 2, 3, 4, 5, 6], dtype=np.float32)
|
||||
target = np.array([100, 2, 7, 32, 34, 1], dtype=np.float32)
|
||||
sens = np.array([9], dtype=np.float32)
|
||||
|
||||
net = nn.SmoothL1Loss(beta, reduction)
|
||||
grad = Grad(net)
|
||||
return grad(Tensor(prediction), Tensor(target), Tensor(sens))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize("reduction", ['mean', 'sum'])
|
||||
def test_smoothl1loss_grad_sum(reduction):
|
||||
"""
|
||||
Feature: SmoothL1LossGrad cpu kernel, reduction = sum.
|
||||
Description: test the rightness of SmoothL1LossGrad cpu kernel.
|
||||
Expectation: the output is same as expect.
|
||||
"""
|
||||
|
||||
beta = 1.0
|
||||
dx = smoothl1loss_grad_2(beta, reduction)
|
||||
|
||||
sum_dx1_expect = np.array([-9, 0, -9, -9, -9, 9], dtype=np.float32)
|
||||
sum_dx2_expect = -sum_dx1_expect
|
||||
|
||||
mean_dx1_expect = np.array(
|
||||
[-1.5, 0, -1.5, -1.5, -1.5, 1.5], dtype=np.float32)
|
||||
mean_dx2_expect = -mean_dx1_expect
|
||||
|
||||
print("dx[0].asnumpy()", dx[0].asnumpy())
|
||||
print("dx[1].asnumpy()", dx[1].asnumpy())
|
||||
|
||||
if reduction == 'sum':
|
||||
np.testing.assert_array_almost_equal(dx[0].asnumpy(), sum_dx1_expect)
|
||||
np.testing.assert_array_almost_equal(dx[1].asnumpy(), sum_dx2_expect)
|
||||
if reduction == 'mean':
|
||||
np.testing.assert_array_almost_equal(dx[0].asnumpy(), mean_dx1_expect)
|
||||
np.testing.assert_array_almost_equal(dx[1].asnumpy(), mean_dx2_expect)
|
||||
|
|
Loading…
Reference in New Issue