From c740a1d2ea61e6faa0bfe8ac4df1eb4878db931f Mon Sep 17 00:00:00 2001 From: linqingke Date: Mon, 21 Mar 2022 16:38:20 +0800 Subject: [PATCH] add custom aicpu node st. --- .../ascend/kernel/aicpu/aicpu_kernel_load.cc | 2 +- .../test_random_choice_with_mask.py | 73 +++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) create mode 100644 tests/st/ops/ascend/test_aicpu_ops/test_random_choice_with_mask.py diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_kernel_load.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_kernel_load.cc index 02ed9bfa376..9dee2e70d91 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_kernel_load.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_kernel_load.cc @@ -105,7 +105,7 @@ bool AicpuOpKernelLoad::GetSoNeedLoadPath(const std::string &so_name, std::strin MS_LOG(ERROR) << "Current path [" << cust_kernel_so_path << "] is invalid."; return false; } - auto real_cust_kernel_so_path = cust_kernel_so_path.substr(0, pos) + "/lib/"; + auto real_cust_kernel_so_path = cust_kernel_so_path.substr(0, pos) + "/"; if (real_cust_kernel_so_path.size() > PATH_MAX) { MS_LOG(ERROR) << "Current path [" << real_cust_kernel_so_path << "] is too long."; diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_random_choice_with_mask.py b/tests/st/ops/ascend/test_aicpu_ops/test_random_choice_with_mask.py new file mode 100644 index 00000000000..3beef3e6dda --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_random_choice_with_mask.py @@ -0,0 +1,73 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class RandomChoiceWithMaskNet(nn.Cell): + def __init__(self): + super(RandomChoiceWithMaskNet, self).__init__() + self.random_choice_with_mask = P.RandomChoiceWithMask(count=4, seed=1) + self.random_choice_with_mask.add_prim_attr("cust_aicpu", "mindspore_aicpu_kernels") + + def construct(self, x): + return self.random_choice_with_mask(x) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_random_choice_with_mask_graph(): + """ + Feature: Custom aicpu feature. + Description: Test random_choice_with_mask kernel in graph mode. + Expectation: No exception. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], + [0, 0, 0, 1]]).astype(np.bool)) + expect1 = (4, 2) + expect2 = (4,) + net = RandomChoiceWithMaskNet() + output1, output2 = net(input_tensor) + assert output1.shape == expect1 + assert output2.shape == expect2 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_random_choice_with_mask_pynative(): + """ + Feature: Custom aicpu feature. + Description: Test random_choice_with_mask kernel in pynative mode. + Expectation: No exception. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], + [0, 0, 0, 1]]).astype(np.bool)) + expect1 = (4, 2) + expect2 = (4,) + net = RandomChoiceWithMaskNet() + output1, output2 = net(input_tensor) + assert output1.shape == expect1 + assert output2.shape == expect2