From 5c9791a802507a7a387810845c65a36cd2a1b551 Mon Sep 17 00:00:00 2001
From: liuxiao
Date: Wed, 15 Apr 2020 17:50:34 +0800
Subject: [PATCH] Add Abs/AbsGrad/Sign/SmoothL1Loss/SmoothL1LossGrad and
 rename TopKV2 -> TopK for VM

---
 mindspore/ccsrc/kernel/tbe/tbe_adapter.cc      |  1 -
 mindspore/ops/_op_impl/tbe/__init__.py         |  7 ++-
 mindspore/ops/_op_impl/tbe/abs.py              | 41 ++++++++++++++
 mindspore/ops/_op_impl/tbe/abs_grad.py         | 44 +++++++++++++++
 mindspore/ops/_op_impl/tbe/sign.py             | 41 ++++++++++++++
 mindspore/ops/_op_impl/tbe/smooth_l1_loss.py   | 44 +++++++++++++++
 .../ops/_op_impl/tbe/smooth_l1_loss_grad.py    | 45 +++++++++++++++
 .../ops/_op_impl/tbe/{topkv2.py => top_k.py}   | 14 ++---
 mindspore/ops/op_info_register.py              |  1 +
 .../test_tbe_ops/test_smooth_l1_loss.py        | 42 ++++++++++++++
 .../test_tbe_ops/test_smooth_l1_loss_grad.py   | 55 +++++++++++++++++++
 .../{test_topkv2.py => test_topk.py}           |  2 +-
 12 files changed, 327 insertions(+), 10 deletions(-)
 create mode 100644 mindspore/ops/_op_impl/tbe/abs.py
 create mode 100644 mindspore/ops/_op_impl/tbe/abs_grad.py
 create mode 100644 mindspore/ops/_op_impl/tbe/sign.py
 create mode 100644 mindspore/ops/_op_impl/tbe/smooth_l1_loss.py
 create mode 100644 mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py
 rename mindspore/ops/_op_impl/tbe/{topkv2.py => top_k.py} (86%)
 create mode 100644 tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss.py
 create mode 100644 tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss_grad.py
 rename tests/st/ops/davinci/test_tbe_ops/{test_topkv2.py => test_topk.py} (97%)

diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
index 9e4553e0578..3fda5547590 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
+++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
@@ -42,7 +42,6 @@ static std::map<std::string, std::string> tbe_func_adapter_map = {
   {"depthwise_conv2d_native", "depthwise_conv2d"},
   {"depthwise_conv2d_native_backprop_filter", "depthwise_conv2d_backprop_filter_d"},
   {"depthwise_conv2d_native_backprop_input", "depthwise_conv2d_backprop_input_d"},
-  {"top_kv2", "top_k"},
   {"scatter_nd", "scatter_nd_d"},
   {"tile", "tile_d"},
   {"gather_v2", "gather_v2_d"},
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 340cf9efe34..2cffc374912 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -14,6 +14,8 @@
 # ============================================================================
 
 """tbe ops"""
+from .abs import _abs_tbe
+from .abs_grad import _abs_grad_tbe
 from .adam_apply_one_with_decay import _adam_apply_one_with_decay_tbe
 from .add import _add_tbe
 from .add_n import _add_n_tbe
@@ -49,7 +51,7 @@ from .sigmoid_cross_entropy_with_logits import _sigmoid_cross_entropy_with_logit
 from .sigmoid_cross_entropy_with_logits_grad import _sigmoid_cross_entropy_with_logits_grad_tbe
 from .tensor_add import _tensor_add_tbe
 from .trans_data import _trans_data_tbe
-from .topkv2 import _topk_v2_tbe
+from .top_k import _top_k_tbe
 from .matmul import _matmul_tbe
 from .sub import _sub_tbe
 from .reduce_mean_d import _reduce_mean_d_tbe
@@ -107,6 +109,7 @@ from .minimum_grad import _minimum_grad_tbe
 from .maximum_grad import _maximum_grad_tbe
 from .concat import _concat_tbe
 from .slice import _slice_tbe
+from .sign import _sign_tbe
 from .greater import _greater_tbe
 from .clip_by_norm_no_div_sum import _clip_by_norm_no_div_sum_tbe
 from .clip_by_value import _clip_by_value_tbe
@@ -130,6 +133,8 @@ from .resize_nearest_neighbor_grad_d import _resize_nearest_neighbor_grad_d_tbe
 from .pad_d import _pad_d_tbe
 from .arg_max_with_value import _arg_max_with_value_tbe
 from .arg_min_with_value import _arg_min_with_value_tbe
+from .smooth_l1_loss import _smooth_l1_loss_tbe
+from .smooth_l1_loss_grad import _smooth_l1_loss_grad_tbe
 from .fused_mul_add import _fused_mul_add_tbe
 from .fused_mul_add_n import _fused_mul_add_n_tbe
 from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe
diff --git a/mindspore/ops/_op_impl/tbe/abs.py b/mindspore/ops/_op_impl/tbe/abs.py
new file mode 100644
index 00000000000..30a75812bde
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/abs.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Abs op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+abs_op_info = TBERegOp("Abs") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("abs.so") \
+    .compute_cost(10) \
+    .kernel_name("abs") \
+    .partial_flag(True) \
+    .op_pattern("formatAgnostic") \
+    .input(0, "x", None, "required", None) \
+    .output(0, "y", True, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
+    .get_op_info()
+
+
+@op_info_register(abs_op_info)
+def _abs_tbe():
+    """Abs TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/abs_grad.py b/mindspore/ops/_op_impl/tbe/abs_grad.py
new file mode 100644
index 00000000000..ba630f6570f
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/abs_grad.py
@@ -0,0 +1,44 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""AbsGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+abs_grad_op_info = TBERegOp("AbsGrad") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("abs_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("abs_grad") \
+    .partial_flag(True) \
+    .op_pattern("formatAgnostic") \
+    .input(0, "y", None, "required", None) \
+    .input(1, "dy", None, "required", None) \
+    .output(0, "z", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(abs_grad_op_info)
+def _abs_grad_tbe():
+    """AbsGrad TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/sign.py b/mindspore/ops/_op_impl/tbe/sign.py
new file mode 100644
index 00000000000..823715aa9f5
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/sign.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Sign op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+sign_op_info = TBERegOp("Sign") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("sign.so") \
+    .compute_cost(10) \
+    .kernel_name("sign") \
+    .partial_flag(True) \
+    .op_pattern("formatAgnostic") \
+    .input(0, "x", None, "required", None) \
+    .output(0, "y", True, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
+    .get_op_info()
+
+
+@op_info_register(sign_op_info)
+def _sign_tbe():
+    """Sign TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/smooth_l1_loss.py b/mindspore/ops/_op_impl/tbe/smooth_l1_loss.py
new file mode 100644
index 00000000000..3723b30c043
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/smooth_l1_loss.py
@@ -0,0 +1,44 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SmoothL1Loss op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+smooth_l1_loss_op_info = TBERegOp("SmoothL1Loss") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("smooth_l1_loss.so") \
+    .compute_cost(10) \
+    .kernel_name("smooth_l1_loss") \
+    .partial_flag(True) \
+    .attr("sigma", "required", "float", "all") \
+    .input(0, "predict", False, "required", "all") \
+    .input(1, "label", False, "required", "all") \
+    .output(0, "loss", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .get_op_info()
+
+
+@op_info_register(smooth_l1_loss_op_info)
+def _smooth_l1_loss_tbe():
+    """SmoothL1Loss TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py b/mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py
new file mode 100644
index 00000000000..fa1ae1ec34d
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py
@@ -0,0 +1,45 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SmoothL1LossGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+smooth_l1_loss_grad_op_info = TBERegOp("SmoothL1LossGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("smooth_l1_loss_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("smooth_l1_loss_grad") \
+    .partial_flag(True) \
+    .attr("sigma", "required", "float", "all") \
+    .input(0, "predict", False, "required", "all") \
+    .input(1, "label", False, "required", "all") \
+    .input(2, "dout", False, "required", "all") \
+    .output(0, "loss", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .get_op_info()
+
+
+@op_info_register(smooth_l1_loss_grad_op_info)
+def _smooth_l1_loss_grad_tbe():
+    """SmoothL1LossGrad TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/topkv2.py b/mindspore/ops/_op_impl/tbe/top_k.py
similarity index 86%
rename from mindspore/ops/_op_impl/tbe/topkv2.py
rename to mindspore/ops/_op_impl/tbe/top_k.py
index a03871f8b74..92733bbf463 100644
--- a/mindspore/ops/_op_impl/tbe/topkv2.py
+++ b/mindspore/ops/_op_impl/tbe/top_k.py
@@ -13,15 +13,15 @@
 # limitations under the License.
 # ============================================================================
 
-"""TopKV2 op"""
+"""TopK op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 
-top_k_v2_op_info = TBERegOp("TopKV2") \
+top_k_op_info = TBERegOp("TopK") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("top_k_v2.so") \
+    .binfile_name("top_k.so") \
     .compute_cost(10) \
-    .kernel_name("top_k_v2") \
+    .kernel_name("top_k") \
     .partial_flag(True) \
     .attr("k", "required", "int", "all")\
     .attr("sorted", "required", "bool", "all")\
@@ -33,7 +33,7 @@ top_k_v2_op_info = TBERegOp("TopKV2") \
     .get_op_info()
 
 
-@op_info_register(top_k_v2_op_info)
-def _topk_v2_tbe():
-    """TopKV2 TBE register"""
+@op_info_register(top_k_op_info)
+def _top_k_tbe():
+    """TopK TBE register"""
     return
diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py
index 28821b621e8..e4b0bfdbfed 100644
--- a/mindspore/ops/op_info_register.py
+++ b/mindspore/ops/op_info_register.py
@@ -599,3 +599,4 @@ class DataType:
     F32_NCHW = ("float32", "NCHW")
     F32_NHWC = ("float32", "NHWC")
     F32_HWCN = ("float32", "HWCN")
+
\ No newline at end of file
diff --git a/tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss.py b/tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss.py
new file mode 100644
index 00000000000..cc0c0e0fc27
--- /dev/null
+++ b/tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss.py
@@ -0,0 +1,42 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import mindspore.nn as nn
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.ops import operations as P
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self, sigma=1.0):
+        super(Net, self).__init__()
+        self.SmoothL1Loss = P.SmoothL1Loss(sigma)
+
+    def construct(self, pred, gt):
+        return self.SmoothL1Loss(pred, gt)
+
+
+def test_net():
+    pred = np.random.randn(2, 4).astype(np.float32)
+    gt = np.random.randn(2, 4).astype(np.float32)
+    smooth_l1_loss = Net()
+    loss = smooth_l1_loss(Tensor(pred), Tensor(gt))
+    print("------------- input ---------------")
+    print("predict:\n", pred)
+    print("ground truth:\n", gt)
+    print("------------- output ---------------")
+    print("loss:\n", loss.asnumpy())
diff --git a/tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss_grad.py b/tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss_grad.py
new file mode 100644
index 00000000000..1ab9d998a16
--- /dev/null
+++ b/tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss_grad.py
@@ -0,0 +1,55 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import mindspore.nn as nn
+import mindspore.context as context
+from mindspore.ops.composite import GradOperation
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self, sigma=1.0):
+        super(Net, self).__init__()
+        self.SmoothL1Loss = P.SmoothL1Loss(sigma)
+
+    def construct(self, pred, gt):
+        return self.SmoothL1Loss(pred, gt)
+
+
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = GradOperation(name="get_all", get_all=True, sens_param=True)
+        self.network = network
+
+    def construct(self, pred, gt, dout):
+        return self.grad(self.network)(pred, gt, dout)
+
+
+def test_net():
+    pred = np.random.randn(2, 4).astype(np.float32)
+    gt = np.random.randn(2, 4).astype(np.float32)
+    dout = np.random.randn(2, 4).astype(np.float32)
+    smooth_l1_loss_grad = Grad(Net())
+    output = smooth_l1_loss_grad(Tensor(pred), Tensor(gt), Tensor(dout))
+    print("------------- input ---------------")
+    print("predict:\n", pred)
+    print("ground truth:\n", gt)
+    print("dout:\n", dout)
+    print("------------- output ---------------")
+    print("predict grad:\n", output[0].asnumpy())
diff --git a/tests/st/ops/davinci/test_tbe_ops/test_topkv2.py b/tests/st/ops/davinci/test_tbe_ops/test_topk.py
similarity index 97%
rename from tests/st/ops/davinci/test_tbe_ops/test_topkv2.py
rename to tests/st/ops/davinci/test_tbe_ops/test_topk.py
index a5058656372..275ef50038f 100644
--- a/tests/st/ops/davinci/test_tbe_ops/test_topkv2.py
+++ b/tests/st/ops/davinci/test_tbe_ops/test_topk.py
@@ -24,7 +24,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 class Net(nn.Cell):
     def __init__(self, k):
         super(Net, self).__init__()
-        self.topk = P.TopK()
+        self.topk = P.TopK(True)
         self.k = k
 
     def construct(self, x):
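
Note on the expected numerics: the two new SmoothL1Loss tests above only print random inputs and kernel outputs. For reference while reviewing, here is a minimal NumPy sketch of the math the SmoothL1Loss/SmoothL1LossGrad TBE kernels are expected to implement, assuming "sigma" is the usual quadratic/linear switch point of smooth L1; the function names are illustrative, not part of the patch:

    import numpy as np

    def smooth_l1_loss_ref(predict, label, sigma=1.0):
        # elementwise: 0.5 * d^2 / sigma where |d| < sigma, else |d| - 0.5 * sigma
        d = predict - label
        return np.where(np.abs(d) < sigma,
                        0.5 * d * d / sigma,
                        np.abs(d) - 0.5 * sigma)

    def smooth_l1_loss_grad_ref(predict, label, dout, sigma=1.0):
        # derivative w.r.t. predict, scaled by the incoming gradient dout:
        # d / sigma in the quadratic region, sign(d) in the linear region
        d = predict - label
        return np.where(np.abs(d) < sigma, d / sigma, np.sign(d)) * dout

Comparing loss.asnumpy() and output[0].asnumpy() against these references (e.g. with np.allclose and a loose tolerance for float16) would let the tests assert correctness instead of only printing.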