!65973 Connect broadcastto to aclnn

Merge pull request !65973 from 张栩浩/r2.3_broadcastto
i-robot 2024-03-09 12:26:00 +00:00 committed by Gitee
commit b0b92092e8
5 changed files with 249 additions and 6 deletions


@@ -10,3 +10,6 @@ broadcast_to:
   returns:
     output:
       dtype: tensor
+  view: True
+  dispatch:
+    enable: True
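For context, `view: True` registers broadcast_to as a view operation and `dispatch: enable: True` turns on kernel dispatch for it (here, the aclnn path). In view semantics the output aliases the input buffer rather than copying it, with stride 0 on every broadcast dimension. A minimal NumPy sketch of the same idea:

# NumPy illustration of the "view" semantics this yaml flag enables:
# the broadcast output aliases the input's storage instead of copying,
# with stride 0 on the broadcast dimension.
import numpy as np

x = np.arange(5, dtype=np.float32).reshape(1, 5)
y = np.broadcast_to(x, (3, 5))  # a view over x, not a copy
print(y.strides)                # (0, 4): the broadcast dim has stride 0
print(y.flags.writeable)        # False: y aliases x's storage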


@@ -114,7 +114,7 @@ TensorStorageInfoPtrList BroadCastToProcess(const PrimitivePtr &prim, const tens
   return {new_storage_info};
 }
-TensorStorageInfoPtrList BroadCastToCalc(const PrimitivePtr &prim, const std::vector<ValuePtr> &inputs) {
+TensorStorageInfoPtrList BroadcastToCalc(const PrimitivePtr &prim, const std::vector<ValuePtr> &inputs) {
   if (CheckInputsNull(inputs, kBroadCastToInputsNum) || !inputs[0]->isa<tensor::Tensor>()) {
     return {};
   }
@@ -125,5 +125,5 @@ TensorStorageInfoPtrList BroadCastToCalc(const PrimitivePtr &prim, const std::ve
   return BroadCastToProcess(prim, input_tensor, input_x);
 }
-REG_VIEW_STRIDES_CALC_FUN(BroadcastTo, BroadCastToCalc);
+REG_VIEW_STRIDES_CALC_FUN(BroadcastTo, BroadcastToCalc);
 }  // namespace mindspore::ops
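The renamed `BroadcastToCalc` computes the storage info for this view: the target shape plus strides in which every broadcast dimension (newly added, or expanded from size 1) gets stride 0, while the remaining dimensions keep the input's strides. A hedged Python sketch of that rule (`broadcast_strides` is a hypothetical helper for illustration, not the repo's API); it reproduces the expectations in the unit tests further down:

# Hypothetical sketch of the view strides rule BroadcastToCalc implements:
# align shapes from the right; a new or size-1 broadcast dim gets stride 0,
# every other dim carries the input's stride over unchanged.
def broadcast_strides(in_shape, in_strides, out_shape):
    ndiff = len(out_shape) - len(in_shape)
    strides = []
    for i, dim in enumerate(out_shape):
        j = i - ndiff
        if j < 0 or (in_shape[j] == 1 and dim != 1):
            strides.append(0)              # broadcast dim: stride 0
        else:
            strides.append(in_strides[j])  # dim kept from the input
    return strides

# Mirrors the first unit test below: a {1, 4} tensor broadcast to {2, 1, 4}.
assert broadcast_strides((1, 4), (4, 1), (2, 1, 4)) == [0, 4, 1]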


@@ -22,7 +22,7 @@
 namespace mindspore {
 namespace ops {
-TensorStorageInfoPtrList BroadCastToCalc(const PrimitivePtr &prim, const std::vector<ValuePtr> &inputs);
+MS_CORE_API TensorStorageInfoPtrList BroadcastToCalc(const PrimitivePtr &prim, const std::vector<ValuePtr> &inputs);
 TensorStorageInfoPtrList BroadCastToProcess(const PrimitivePtr &prim, const tensor::TensorPtr input_tensor,
                                             const std::vector<int64_t> &input_x);
 }  // namespace ops


@@ -0,0 +1,240 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.nn import Cell
from mindspore.common.api import _pynative_executor


@pytest.mark.level0
@pytest.mark.parametrize('context_mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_broadcast(context_mode):
"""
Feature: pyboost function.
Description: test function broadcast_to forward.
Expectation: expect correct result.
"""
context.set_context(mode=context_mode)
shape = (4, 5, 2, 3, 4, 5, 6)
x_np = np.random.rand(2, 3, 1, 5, 1).astype(np.float32)
output = P.BroadcastTo(shape)(Tensor(x_np))
expect = np.broadcast_to(x_np, shape)
assert np.allclose(output.asnumpy(), expect)
shape = (3, 5, 7, 4, 5, 6)
x_np = np.arange(20).reshape((4, 5, 1)).astype(np.int32)
output = P.BroadcastTo(shape)(Tensor(x_np))
expect = np.broadcast_to(x_np, shape)
assert np.allclose(output.asnumpy(), expect)
shape = (8, 5, 7, 4, 5, 6)
    x_np = np.arange(24).reshape((1, 4, 1, 6)).astype(np.bool_)
output = P.BroadcastTo(shape)(Tensor(x_np))
expect = np.broadcast_to(x_np, shape)
assert np.allclose(output.asnumpy(), expect)
shape = (3, 4, 5, 2, 3, 4, 5, 7)
x_np = np.random.rand(2, 3, 1, 5, 1).astype(np.float16)
output = P.BroadcastTo(shape)(Tensor(x_np))
expect = np.broadcast_to(x_np, shape)
assert np.allclose(output.asnumpy(), expect)
shape = (3, 4, 5, 6)
x_np = np.random.rand(3, 1, 5, 1).astype(np.float32)
output = P.BroadcastTo(shape)(Tensor(x_np))
expect = np.broadcast_to(x_np, shape)
assert np.allclose(output.asnumpy(), expect)
x1_np = np.random.rand(3, 1, 5, 1).astype(np.float16)
output = P.BroadcastTo(shape)(Tensor(x1_np))
expect = np.broadcast_to(x1_np, shape)
assert np.allclose(output.asnumpy(), expect)
shape = (2, 3, 4, 5)
x1_np = np.random.rand(4, 5).astype(np.float32)
output = P.BroadcastTo(shape)(Tensor(x1_np))
expect = np.broadcast_to(x1_np, shape)
assert np.allclose(output.asnumpy(), expect)


def broadcast_to_dtype(dtype):
"""
Basic function to test data type of BroadcastTo.
"""
shape = (2, 3, 4, 5)
x1_np = np.random.rand(4, 5).astype(dtype)
output = P.BroadcastTo(shape)(Tensor(x1_np))
expect = np.broadcast_to(x1_np, shape)
assert np.allclose(output.asnumpy(), expect)


@pytest.mark.level1
@pytest.mark.parametrize('context_mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_broadcast_to_dtype(context_mode):
"""
    Feature: Test supported data types of BroadcastTo.
Description: all data types
Expectation: success.
"""
context.set_context(mode=context_mode)
types = [np.float16, np.float32, np.float64, np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64, np.complex64, np.complex128]
for dtype in types:
broadcast_to_dtype(dtype=dtype)


@pytest.mark.level1
@pytest.mark.parametrize('context_mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_broadcast_dyn_init(context_mode):
"""
Feature: pyboost function.
Description: Test running the op with -1's in the init shape to support varied inputs.
Expectation: expect correct result.
"""
context.set_context(mode=context_mode)
ms_shape = (-1, -1, 5, 6)
np_shape = (3, 4, 5, 6)
x_np = np.random.rand(3, 1, 5, 1).astype(np.float32)
output = P.BroadcastTo(ms_shape)(Tensor(x_np))
expect = np.broadcast_to(x_np, np_shape)
assert np.allclose(output.asnumpy(), expect)
x1_np = np.random.rand(3, 1, 5, 1).astype(np.float16)
output = P.BroadcastTo(ms_shape)(Tensor(x1_np))
expect = np.broadcast_to(x1_np, np_shape)
assert np.allclose(output.asnumpy(), expect)
ms_shape = (2, 3, -1, -1)
np_shape = (2, 3, 4, 5)
x1_np = np.random.rand(4, 5).astype(np.float32)
output = P.BroadcastTo(ms_shape)(Tensor(x1_np))
expect = np.broadcast_to(x1_np, np_shape)
assert np.allclose(output.asnumpy(), expect)


@pytest.mark.level1
@pytest.mark.parametrize('context_mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_broadcast_dyn_invalid_init(context_mode):
"""
Feature: pyboost function.
Description: Test running the op with -1's in the init shape in incorrect positions.
Expectation: Expected to fail.
"""
context.set_context(mode=context_mode)
ms_shape = (2, -1, 4, 5)
x_np = np.random.rand(4, 5).astype(np.float32)
with pytest.raises(ValueError):
P.BroadcastTo(ms_shape)(Tensor(x_np))
_pynative_executor.sync()
ms_shape = (-1, 1, -1, -1)
x_np = np.random.rand(4, 5).astype(np.float32)
with pytest.raises(ValueError):
P.BroadcastTo(ms_shape)(Tensor(x_np))
_pynative_executor.sync()


class BroadcastToNet(Cell):
    """
    Network with a fixed target shape, used to test BroadcastTo with dynamic inputs.
"""
def __init__(self, shape):
super().__init__()
self.broadcastto = P.BroadcastTo(shape)

    def construct(self, input_x):
return self.broadcastto(input_x)


@pytest.mark.level1
@pytest.mark.parametrize('context_mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_broadcast_to_dynamic_shape(context_mode):
"""
Feature: Test dynamic shape of BroadcastTo operator
Description: dynamic input
Expectation: success.
"""
context.set_context(mode=context_mode)
shape = (2, 2, 3)
input_x_np = np.random.randn(2, 3).astype(np.float32)
input_x = Tensor(input_x_np)
input_dyn = Tensor(shape=[None, 3], dtype=input_x.dtype)
broadcast_to_net = BroadcastToNet(shape)
broadcast_to_net.set_inputs(input_dyn)
output = broadcast_to_net(input_x)
expect = np.broadcast_to(input_x_np, shape)
assert np.allclose(output.asnumpy(), expect)


@pytest.mark.level1
@pytest.mark.parametrize('context_mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_broadcast_exception(context_mode):
"""
    Feature: Test invalid input and target shape for BroadcastTo.
    Description: the input is a scalar (empty shape) while the target shape has a zero-sized dimension.
    Expectation: the raised error message matches the expected text.
"""
with pytest.raises(Exception) as info:
context.set_context(mode=context_mode)
shape = (0,)
x_np = np.random.randint(1, 4)
P.BroadcastTo(shape)(Tensor(x_np))
assert "ValueError: For 'BroadcastTo', each dimension pair, input_x shape and target shape must be equal or \
input dimension is 1 or target dimension is -1. But got input_x shape: [const vector][], target shape: \
[const vector][0]." in str(info.value)
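The two dyn_init tests above pin down the rule for -1 entries in the constructor shape: a -1 must line up, right-aligned, with an existing input dimension, whose size it then inherits; a -1 in a leading position with no matching input dimension raises ValueError. A small sketch of that rule (`resolve_target_shape` is a hypothetical helper for illustration):

# Hypothetical sketch of the -1 resolution rule the dyn_init tests exercise:
# a -1 entry must line up (right-aligned) with an existing input dim and
# takes that dim's size; a -1 with no matching input dim is rejected.
def resolve_target_shape(target, in_shape):
    ndiff = len(target) - len(in_shape)
    out = []
    for i, t in enumerate(target):
        if t == -1:
            if i < ndiff:
                raise ValueError("-1 in a dim with no matching input dim")
            out.append(in_shape[i - ndiff])  # inherit the input's size
        else:
            out.append(t)
    return tuple(out)

assert resolve_target_shape((2, 3, -1, -1), (4, 5)) == (2, 3, 4, 5)
# (2, -1, 4, 5) with a rank-2 input raises, as in test_broadcast_dyn_invalid_init.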


@@ -37,7 +37,7 @@ TEST_F(TestViewBroadcastTo, func) {
   auto input_tensor = std::make_shared<tensor::Tensor>(tensor_data, kInt64);
   input_tensor->set_shape({1, 4});
-  auto storage_list = BroadCastToCalc(prim, std::vector<ValuePtr>({input_tensor, input_perm}));
+  auto storage_list = BroadcastToCalc(prim, std::vector<ValuePtr>({input_tensor, input_perm}));
   std::vector<int64_t> expect_shape({2, 1, 4});
   std::vector<int64_t> expect_strides({0, 4, 1});
   size_t expect_size = 1;
@@ -67,7 +67,7 @@ TEST_F(TestViewBroadcastTo, BroadDim) {
                             tensor_total_length * sizeof(int64_t));
   std::vector<ValuePtr> inputs{input_tensor, input_perm};
-  auto storage_list = BroadCastToCalc(prim, inputs);
+  auto storage_list = BroadcastToCalc(prim, inputs);
   std::vector<int64_t> expect_shape({2, 1, 2, 3});
   std::vector<int64_t> expect_strides({0, 6, 3, 1});
   size_t expect_size = 1;
@@ -80,7 +80,7 @@ TEST_F(TestViewBroadcastTo, BroadDim) {
   input_perm = MakeValue(perm_data);
   inputs[kIndex1] = input_perm;
-  storage_list = BroadCastToCalc(prim, inputs);
+  storage_list = BroadcastToCalc(prim, inputs);
   std::vector<int64_t> expect_shape_2({3, 2, 3});
   std::vector<int64_t> expect_strides_2({0, 3, 1});
   ASSERT_EQ(storage_list.size(), expect_size);
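As a quick end-to-end check, the first unit test case can also be mirrored from Python; this is a smoke test assuming a MindSpore build that includes this patch, not part of the change itself:

# Smoke test mirroring TestViewBroadcastTo.func: a {1, 4} tensor
# broadcast to {2, 1, 4}, compared against NumPy's result.
import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.arange(4, dtype=np.int64).reshape(1, 4))
y = ops.broadcast_to(x, (2, 1, 4))  # functional front end of P.BroadcastTo
print(y.shape)                      # (2, 1, 4)
print(np.array_equal(y.asnumpy(), np.broadcast_to(x.asnumpy(), (2, 1, 4))))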