aclnn: select/where

Remove functional API

InferShape supports broadcasting

Add Pass for Select

Bprop recovers original shapes after broadcast

Add ut & st

Revise doc
Dairenjie1 2024-03-21 17:28:26 +08:00 committed by hedongdong
parent f462ec19d1
commit 9ad846649e
16 changed files with 758 additions and 269 deletions

View File

@@ -1,31 +1,31 @@
mindspore.ops.select
====================
.. py:function:: mindspore.ops.select(cond, x, y)
.. py:function:: mindspore.ops.select(condition, input, other)
Based on the value of each element in `cond`, decides whether the corresponding element in the output is selected from `x` (if the element is True) or from `y` (if the element is False).
Based on the value of each element in `condition`, decides whether the corresponding element in the output is selected from `input` (if the element is True) or from `other` (if the element is False).
It can be defined as:
.. math::
out_i = \begin{cases}
x_i, & \text{if } cond_i \\
y_i, & \text{otherwise}
input_i, & \text{if } condition_i \\
other_i, & \text{otherwise}
\end{cases}
Args:
- **cond** (Tensor[bool]) - The condition Tensor, decides which element is chosen. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
- **x** (Union[Tensor, int, float]) - The first Tensor or number to be selected.
If `x` is a Tensor, its shape is, or can be broadcast to, :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
If `x` is an int or a float, it will be cast to int32 or float32 and broadcast to the same shape as `y`. At least one of `x` and `y` must be a Tensor.
- **y** (Union[Tensor, int, float]) - The second Tensor or number to be selected.
If `y` is a Tensor, its shape is, or can be broadcast to, :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
If `y` is an int or a float, it will be cast to int32 or float32 and broadcast to the same shape as `x`. At least one of `x` and `y` must be a Tensor.
- **condition** (Tensor[bool]) - The condition Tensor, decides which element is chosen. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
- **input** (Union[Tensor, int, float]) - The first Tensor or number to be selected.
If `input` is a Tensor, its shape is, or can be broadcast to, :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
If `input` is an int or a float, it will be cast to int32 or float32 and broadcast to the same shape as `other`. At least one of `input` and `other` must be a Tensor.
- **other** (Union[Tensor, int, float]) - The second Tensor or number to be selected.
If `other` is a Tensor, its shape is, or can be broadcast to, :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
If `other` is an int or a float, it will be cast to int32 or float32 and broadcast to the same shape as `input`. At least one of `input` and `other` must be a Tensor.
Returns:
Tensor, has the same shape as `cond`.
Tensor, has the same shape as `condition`.
Raises:
- **TypeError** - If `x` or `y` is not a Tensor, int or float.
- **TypeError** - If `input` or `other` is not a Tensor, int or float.
- **ValueError** - The shapes of the inputs cannot be broadcast.
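For reference, a minimal usage sketch of the renamed signature (output values follow from the definition above and mirror the examples in the functional docstring later in this diff):

```python
import mindspore
from mindspore import Tensor, ops

condition = Tensor([True, False])
input = Tensor([2, 3], mindspore.float32)   # shadows the builtin, matching the arg name
other = Tensor([1, 2], mindspore.float32)
print(ops.select(condition, input, other))  # [2. 2.]

# `other` may also be a scalar; it is cast and broadcast to `input`'s shape.
print(ops.select(condition, input, 2.0))    # [2. 2.]
```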

View File

@@ -1,23 +1,23 @@
mindspore.ops.where
====================
.. py:function:: mindspore.ops.where(condition, x, y)
.. py:function:: mindspore.ops.where(condition, input, other)
Returns a Tensor whose elements are selected from `x` or `y` depending on `condition`.
Returns a Tensor whose elements are selected from `input` or `other` depending on `condition`.
.. math::
output_i = \begin{cases} x_i,\quad &if\ condition_i \\ y_i,\quad &otherwise \end{cases}
output_i = \begin{cases} input_i,\quad &if\ condition_i \\ other_i,\quad &otherwise \end{cases}
Args:
- **condition** (Tensor[bool]) - If ``True``, yield `x`, otherwise yield `y`.
- **x** (Union[Tensor, Scalar]) - The values selected at indices where `condition` is ``True``.
- **y** (Union[Tensor, Scalar]) - The values selected at indices where `condition` is ``False``.
- **condition** (Tensor[bool]) - If ``True``, yield `input`, otherwise yield `other`.
- **input** (Union[Tensor, Scalar]) - The values selected at indices where `condition` is ``True``.
- **other** (Union[Tensor, Scalar]) - The values selected at indices where `condition` is ``False``.
Returns:
Tensor, with elements selected from `x` and `y`.
Tensor, with elements selected from `input` and `other`.
Raises:
- **TypeError** - If `condition` is not a Tensor.
- **TypeError** - If both `x` and `y` are scalars.
- **ValueError** - If `condition`, `x` and `y` cannot broadcast to each other.
- **TypeError** - If both `input` and `other` are scalars.
- **ValueError** - If `condition`, `input` and `other` cannot broadcast to each other.
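A corresponding sketch for `mindspore.ops.where`, matching the expected output used in the where ST later in this diff:

```python
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, ops

x = Tensor(np.arange(4).reshape((2, 2)), mstype.float32)
y = Tensor(np.ones((2, 2)), mstype.float32)
print(ops.where(x < 3, x, y))
# [[0. 1.]
#  [2. 1.]]
```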

View File

@@ -925,7 +925,10 @@ REG_BPROP_BUILDER("Select").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
auto dout = ib->GetInput(kIndex4);
auto dx = x->need_compute_grad_out() ? ib->Select(cond, dout, ib->ZerosLike(x)) : ib->OutZeros(x);
auto dy = y->need_compute_grad_out() ? ib->Select(cond, ib->ZerosLike(y), dout) : ib->OutZeros(y);
return {ib->OutZeros(cond), dx, dy};
auto bc_x = BinopGradCommon(ib, cond, x, dout, dx);
auto bc_y = BinopGradCommon(ib, cond, y, dout, dy);
auto ret = BinopGradCommon(ib, x, y, bc_x[kIndex1], bc_y[kIndex1]);
return {ib->OutZeros(cond), ret[kIndex0], ret[kIndex1]};
});
REG_BPROP_BUILDER("OnesLike").SetUnusedInputs({i0, i1, i2}).SetBody(ReturnZeros);

View File

@@ -0,0 +1,127 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/ge/broadcast_for_select.h"
#include <vector>
#include <memory>
#include <algorithm>
#include "mindspore/core/ops/array_ops.h"
#include "include/common/utils/anfalgo.h"
#include "include/backend/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
namespace {
ShapeVector GetSelectInputShape(const AnfNodePtr &input) {
MS_EXCEPTION_IF_NULL(input);
auto input_base_shape = input->Shape();
MS_EXCEPTION_IF_NULL(input_base_shape);
auto input_shape = input_base_shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(input_shape);
return input_shape->shape();
}
ShapeVector CalcBroadcastShape(AnfNodePtr cond, AnfNodePtr x, AnfNodePtr y) {
auto cond_shape = GetSelectInputShape(cond);
auto x_shape = GetSelectInputShape(x);
auto y_shape = GetSelectInputShape(y);
auto cond_size = cond_shape.size();
auto x_size = x_shape.size();
auto y_size = y_shape.size();
ShapeVector broadcast_shape =
cond_size > x_size ? cond_size > y_size ? cond_shape : y_shape : x_size > y_size ? x_shape : y_shape;
auto n = broadcast_shape.size();
for (size_t i = n; i > 0; --i) {
auto cond_i = cond_size < i ? 1 : cond_shape[cond_size - i];
auto x_i = x_size < i ? 1 : x_shape[x_size - i];
auto y_i = y_size < i ? 1 : y_shape[y_size - i];
auto broadcast_i = std::max(cond_i, std::max(x_i, y_i));
if (cond_i != broadcast_i && cond_i != 1) {
MS_EXCEPTION(ValueError) << "For select, condition input can not broadcast at index " << i;
}
if (x_i != broadcast_i && x_i != 1) {
MS_EXCEPTION(ValueError) << "For select, x input can not broadcast at index " << i;
}
if (y_i != broadcast_i && y_i != 1) {
MS_EXCEPTION(ValueError) << "For select, y input can not broadcast at index " << i;
}
broadcast_shape[n - i] = broadcast_i;
}
return broadcast_shape;
}
CNodePtr AddBroadCastToNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
const std::vector<int64_t> &broad_shape) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(input_node);
auto input_type = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
auto shape_node = opt::CreateValueNodeWithKernelInfo(func_graph, MakeValue(broad_shape));
std::vector<AnfNodePtr> broadcastto_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimBroadcastTo->name())), input_node, shape_node};
CNodePtr broadcastto_node = NewCNode(broadcastto_inputs, func_graph);
MS_EXCEPTION_IF_NULL(broadcastto_node);
broadcastto_node->set_scope(input_node->scope());
broadcastto_node->set_abstract(input_node->abstract());
common::AnfAlgo::SetOutputInferTypeAndShape({input_type}, {broad_shape}, broadcastto_node.get());
return broadcastto_node;
}
CNodePtr AddSelectNode(const FuncGraphPtr &func_graph, const CNodePtr &cond_node, const CNodePtr &x_node,
const CNodePtr &y_node, const CNodePtr &select_node, const std::vector<int64_t> &broad_shape) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cond_node);
MS_EXCEPTION_IF_NULL(x_node);
MS_EXCEPTION_IF_NULL(y_node);
MS_EXCEPTION_IF_NULL(select_node);
auto input_type = common::AnfAlgo::GetOutputInferDataType(select_node, 0);
std::vector<AnfNodePtr> select_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSelect->name())),
cond_node, x_node, y_node};
CNodePtr out_node = NewCNode(select_inputs, func_graph);
MS_EXCEPTION_IF_NULL(out_node);
out_node->set_scope(select_node->scope());
out_node->set_abstract(select_node->abstract());
common::AnfAlgo::SetOutputInferTypeAndShape({input_type}, {broad_shape}, out_node.get());
return out_node;
}
} // namespace
const BaseRef BroadCastForSelect::DefinePattern() const {
VarPtr inputs = std::make_shared<SeqVar>();
return VectorRef({prim::kPrimSelect, inputs});
}
const AnfNodePtr BroadCastForSelect::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
// Select(...) ===> inputs -> CalcBroadcastShape -> BroadCastTo -> Select(...)
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto select_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(select_node);
// get broadcast shape
auto cond = select_node->input(kIndex1);
auto x = select_node->input(kIndex2);
auto y = select_node->input(kIndex3);
auto output_shape = CalcBroadcastShape(cond, x, y);
// do BroadCast
auto new_cond = AddBroadCastToNode(graph, cond, output_shape);
auto new_x = AddBroadCastToNode(graph, x, output_shape);
auto new_y = AddBroadCastToNode(graph, y, output_shape);
auto out_node = AddSelectNode(graph, new_cond, new_x, new_y, select_node, output_shape);
return out_node;
}
} // namespace opt
} // namespace mindspore
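For readability, the shape computation in `CalcBroadcastShape` can be mirrored in a few lines of Python (a sketch of the same right-aligned, max-per-dimension rule, raising on any non-1 mismatch; the assertion reuses the shapes from the broadcast ST below):

```python
def calc_broadcast_shape(cond_shape, x_shape, y_shape):
    # Walk the dimensions from the last axis backwards, as the C++ pass does.
    n = max(len(cond_shape), len(x_shape), len(y_shape))
    out = []
    for i in range(1, n + 1):
        dims = {name: (s[-i] if i <= len(s) else 1)
                for name, s in (("condition", cond_shape), ("x", x_shape), ("y", y_shape))}
        target = max(dims.values())
        for name, d in dims.items():
            if d not in (1, target):
                raise ValueError(f"For select, {name} input can not broadcast at index {i}")
        out.append(target)
    return tuple(reversed(out))

assert calc_broadcast_shape((1, 65, 54, 12, 5, 2),
                            (5, 5, 65, 1, 12, 5, 2),
                            (65, 54, 1, 5, 2)) == (5, 5, 65, 54, 12, 5, 2)
```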

View File

@@ -0,0 +1,37 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_GE_BROADCAST_FOR_SELECT_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_GE_BROADCAST_FOR_SELECT_H_
#include <utility>
#include <string>
#include <vector>
#include <map>
#include "include/backend/optimizer/optimizer.h"
#include "ops/auto_generate/gen_ops_primitive.h"
namespace mindspore {
namespace opt {
class BroadCastForSelect : public PatternProcessPass {
public:
explicit BroadCastForSelect(bool multi_graph = true) : PatternProcessPass("broadcast_for_select", multi_graph) {}
~BroadCastForSelect() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_GE_BROADCAST_FOR_SELECT_H_

View File

@@ -63,6 +63,7 @@
#include "plugin/device/ascend/optimizer/enhancer/eliminate_maketuple_getitem.h"
#include "plugin/device/ascend/optimizer/ge/convert_pad_v3_paddings.h"
#include "plugin/device/ascend/optimizer/ir_fusion/shape_reshape_fusion.h"
#include "plugin/device/ascend/optimizer/ge/broadcast_for_select.h"
namespace mindspore {
namespace opt {
@@ -98,6 +99,7 @@ void GEBackendOptimization(const KernelGraphPtr &kernel_graph) {
opt_ge_pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>(true, true));
opt_ge_pm->AddPass(std::make_shared<opt::UnfoldNestedOutput>("unfold_nested_output"));
opt_ge_pm->AddPass(std::make_shared<opt::UnfoldMaketuple>("unfold_nested_maketuple"));
opt_ge_pm->AddPass(std::make_shared<opt::BroadCastForSelect>());
optimizer->AddPassManager(opt_ge_pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
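The net effect on GE graphs: every `Select` sees same-shaped inputs, because the pass inserts an explicit `BroadcastTo` on each operand. A NumPy sketch of the rewritten computation (using `np.broadcast_shapes` as a stand-in for `CalcBroadcastShape`):

```python
import numpy as np

def select_with_explicit_broadcast(cond, x, y):
    # What the graph computes after the pass: BroadcastTo on each input,
    # then an all-same-shape Select (np.where here).
    out_shape = np.broadcast_shapes(cond.shape, x.shape, y.shape)
    return np.where(np.broadcast_to(cond, out_shape),
                    np.broadcast_to(x, out_shape),
                    np.broadcast_to(y, out_shape))
```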

View File

@@ -1,30 +1,30 @@
select:
description: |
The conditional tensor determines whether the corresponding element in the output must be
selected from `x` (if True) or `y` (if False) based on the value of each
selected from `input` (if True) or `other` (if False) based on the value of each
element.
It can be defined as:
.. math::
out_i = \begin{cases}
x_i, & \text{if } cond_i \\
y_i, & \text{otherwise}
input_i, & \text{if } condition_i \\
other_i, & \text{otherwise}
\end{cases}
Inputs:
- **cond** (Tensor[bool]): The condition tensor, decides which element is chosen.
- **condition** (Tensor[bool]): The condition tensor, decides which element is chosen.
The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
- **x** (Tensor): The first Tensor to be selected.
- **input** (Tensor): The first Tensor to be selected.
The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
- **y** (Tensor): The second Tensor to be selected.
- **other** (Tensor): The second Tensor to be selected.
The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
Outputs:
Tensor, has the same shape as `cond`.
Tensor, has the same shape as `condition`.
Raises:
TypeError: If x or y is not a Tensor.
TypeError: If input or other is not a Tensor.
ValueError: The shapes of the inputs are different.
Supported Platforms:

View File

@@ -1,14 +1,18 @@
#operator select
#operator select/where
select:
args:
cond:
condition:
dtype: tensor
x:
input:
dtype: tensor
y:
type_cast: number
other:
dtype: tensor
type_cast: number
args_signature:
dtype_group: (condition), (input, other)
returns:
output:
dtype: tensor
function:
disable: True
dispatch:
enable: True
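As a sketch of what `type_cast: number` and the `dtype_group` signature permit, a plain Python scalar can stand in for `input` or `other` and is promoted to the tensor's dtype (matching the scalar ST later in this diff):

```python
import numpy as np
from mindspore import Tensor, ops

cond = Tensor(np.array([[True, False], [True, False]]))
x = Tensor(np.array([[12, 1], [1, 0]], dtype=np.int32))
# The scalar 2 is cast to x's dtype (int32) before dispatch.
print(ops.select(cond, x, 2))  # [[12  2] [ 1  2]]
```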

View File

@@ -46,18 +46,6 @@ namespace ops {
using float_complex = std::complex<float>;
using double_complex = std::complex<double>;
void SelectInferShapeCheck(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape,
const std::vector<int64_t> &cond_shape, size_t shape_size) {
for (size_t i = 0; i < shape_size; i++) {
if ((x_shape[i] > 0 && cond_shape[i] > 0 && x_shape[i] != cond_shape[i]) ||
(x_shape[i] > 0 && y_shape[i] > 0 && x_shape[i] != y_shape[i])) {
MS_EXCEPTION(ValueError)
<< "For 'Select', the shape of 'condition', 'x' and 'y' must be the same. But got 'condition' shape: "
<< cond_shape << ", 'x' shape: " << x_shape << ", 'y' shape: " << y_shape << ".";
}
}
}
abstract::BaseShapePtr SelectFuncImpl::InferShape(const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) const {
auto cond_shape = input_args[kSelectCondIndex]->GetShape()->GetShapeVector();
@@ -66,16 +54,9 @@ abstract::BaseShapePtr SelectFuncImpl::InferShape(const PrimitivePtr &prim,
if (IsDynamicRank(cond_shape) || IsDynamicRank(x_shape) || IsDynamicRank(y_shape)) {
return std::make_shared<abstract::TensorShape>(ShapeVector{abstract::TensorShape::kShapeRankAny});
}
auto cond_shape_size = cond_shape.size();
auto x_shape_size = x_shape.size();
auto y_shape_size = y_shape.size();
if (cond_shape_size != x_shape_size || y_shape_size != x_shape_size) {
MS_EXCEPTION(ValueError)
<< "For 'Select', the shape of 'condition', 'x' and 'y' must be the same. But got 'condition' shape: "
<< cond_shape << ", 'x' shape: " << x_shape << ", 'y' shape: " << y_shape << ".";
}
SelectInferShapeCheck(x_shape, y_shape, cond_shape, x_shape_size);
return input_args[kSelectCondIndex]->GetShape()->Clone();
auto broadcast_output_size = CalBroadCastShape(x_shape, y_shape, prim->name(), "input", "other");
auto output_size = CalBroadCastShape(cond_shape, broadcast_output_size, prim->name(), "condition", "input");
return std::make_shared<abstract::TensorShape>(output_size);
}
TypePtr SelectFuncImpl::InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const {
@@ -94,11 +75,6 @@ TypePtr SelectFuncImpl::InferType(const PrimitivePtr &prim, const std::vector<Ab
(void)CheckAndConvertUtils::CheckTensorTypeValid("y_type", y_type, common_valid_types_with_complex_and_bool,
prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("cond", cond_type, {kBool}, prim_name);
if (*x_type != *y_type) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the x_type and y_type must be the same, but got x_type: " << x_type->ToString()
<< " and y_type: " << y_type->ToString() << ".";
}
return x_type->Clone();
}
} // namespace ops
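The new `InferShape` composes two pairwise broadcasts via `CalBroadCastShape`, which also has to cope with dynamic (-1) dimensions. A hedged Python sketch of the per-dimension rule (not the C++ code), checked against shape cases from the unit test at the end of this diff:

```python
def broadcast_shapes_dyn(s1, s2):
    # -1 marks an unknown dimension: it is refined by a known peer dim,
    # while a literal 1 always yields to the peer.
    n = max(len(s1), len(s2))
    out = []
    for i in range(1, n + 1):
        a = s1[-i] if i <= len(s1) else 1
        b = s2[-i] if i <= len(s2) else 1
        if a == b or b == 1:
            out.append(a)
        elif a == 1 or a == -1:
            out.append(b)
        elif b == -1:
            out.append(a)
        else:
            raise ValueError(f"shapes {s1} and {s2} cannot broadcast")
    return tuple(reversed(out))

assert broadcast_shapes_dyn((-1, 3), (2, 3)) == (2, 3)
assert broadcast_shapes_dyn((2, -1), (2, -1)) == (2, -1)
```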

View File

@@ -18,8 +18,9 @@ from mindspore.ops.extend import *
from mindspore.ops.extend import array_func, math_func, nn_func
from mindspore.mint.nn.functional import *
from mindspore.mint.nn import functional
from mindspore.ops import where
__all__ = []
__all__ = ['where']
__all__.extend(array_func.__all__)
__all__.extend(math_func.__all__)
__all__.extend(nn_func.__all__)
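A quick sketch of the newly exported symbol (assuming this hunk is the `mindspore.mint` package `__init__`, so `where` becomes reachable from the `mint` namespace):

```python
import numpy as np
from mindspore import Tensor, mint

x = Tensor(np.arange(4).reshape((2, 2)).astype(np.float32))
y = Tensor(np.ones((2, 2)).astype(np.float32))
print(mint.where(x < 3, x, y))  # same semantics as ops.where
```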

View File

@@ -61,7 +61,7 @@ from mindspore.ops._utils.utils import ms_arrange
from mindspore.ops.auto_generate import cat, range, scatter_nd, deepcopy, masked_fill, diagonal, expand_dims, \
nonzero, flip, transpose, unsorted_segment_sum, diag, gather, gather_d, gather_nd, reshape, broadcast_to, \
strided_slice, ones, zeros, max_, min_
strided_slice, ones, zeros, max_, min_, select
from mindspore.ops.operations.manually_defined import tile, rank, scalar_cast
arg_max_with_value_ = ArgMaxWithValue()
@@ -390,25 +390,25 @@ def hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype
return out
def where(condition, x, y):
def where(condition, input, other):
r"""
Selects elements from `x` or `y` based on `condition` and returns a tensor.
Selects elements from `input` or `other` based on `condition` and returns a tensor.
.. math::
output_i = \begin{cases} x_i,\quad &if\ condition_i \\ y_i,\quad &otherwise \end{cases}
output_i = \begin{cases} input_i,\quad &if\ condition_i \\ other_i,\quad &otherwise \end{cases}
Args:
condition (Tensor[bool]): If True, yield `x`, otherwise yield `y`.
x (Union[Tensor, Scalar]): When `condition` is True, values to select from.
y (Union[Tensor, Scalar]): When `condition` is False, values to select from.
condition (Tensor[bool]): If True, yield `input`, otherwise yield `other`.
input (Union[Tensor, Scalar]): When `condition` is True, values to select from.
other (Union[Tensor, Scalar]): When `condition` is False, values to select from.
Returns:
Tensor, elements are selected from `x` and `y`.
Tensor, elements are selected from `input` and `other`.
Raises:
TypeError: If `condition` is not a Tensor.
TypeError: If both `x` and `y` are scalars.
ValueError: If `condition`, `x` and `y` can not broadcast to each other.
TypeError: If both `input` and `other` are scalars.
ValueError: If `condition`, `input` and `other` can not broadcast to each other.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@@ -425,25 +425,7 @@ def where(condition, x, y):
[[0. 1.]
[2. 1.]]
"""
if not isinstance(condition, Tensor):
raise TypeError(f"For 'where', 'condition' must be a Tensor, but got {type(condition)}.")
if isinstance(x, (int, float)):
if not isinstance(y, Tensor):
raise TypeError(
f"For 'where', at least one of 'x' and 'y' should be Tensor, but got x:{type(x)}, y:{type(y)}."
)
x = cast_(x, y.dtype)
elif isinstance(y, (int, float)):
if not isinstance(x, Tensor):
raise TypeError(
f"For 'where', at least one of 'x' and 'y' should be Tensor, but got x:{type(x)}, y:{type(y)}."
)
y = cast_(y, x.dtype)
output_shape = _calc_broadcast_shape(x.shape, y.shape, condition.shape)
condition = broadcast_to(condition, output_shape)
x = broadcast_to(x, output_shape)
y = broadcast_to(y, output_shape)
return tensor_select_(condition, x, y)
return tensor_select_(condition, input, other)
def reverse(x, axis):
@@ -1435,171 +1417,6 @@ def flatten(input, order='C', *, start_dim=1, end_dim=-1):
return reshape_(input, new_shape)
@constexpr
def _check_select_type_match(scalar, tensor_type, scalar_name, tensor_name):
if isinstance(scalar, int) and tensor_type != mstype.int32:
raise TypeError(f"For functional operator[select], the input[{scalar_name}] is int, "
f"then the input[{tensor_name}] must be a Tensor of int32.")
if isinstance(scalar, float) and tensor_type != mstype.float32:
raise TypeError(f"For functional operator[select], the input[{scalar_name}] is float, "
f"then the input[{tensor_name}] must be a Tensor of float32.")
@_primexpr
def _check_select_shape_match(input_shape, cond_shape, tensor_name):
if input_shape != cond_shape:
raise ValueError(f"For functional operator[select], the cond shape must be same as {tensor_name} shape.")
@constexpr
def _check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor):
if not is_cond_tensor:
raise TypeError(f"For functional operator[select], the input[cond] must be a Tensor.")
if is_x_scalar and not is_y_tensor:
raise TypeError(f"For functional operator[select], the input[x] is int or float, "
f"then the input[y] must be a Tensor.")
if is_y_scalar and not is_x_tensor:
raise TypeError(f"For functional operator[select], the input[y] is int or float, "
f"then the input[x] must be a Tensor.")
@constexpr
def _check_select_shape_same(cond_shape, x_shape, y_shape):
"""Check if input of select has same shape."""
return cond_shape == x_shape and x_shape == y_shape and cond_shape == y_shape
@constexpr
def get_max_value(x, y, z):
"""Get the maximum value of x, y and z."""
if x >= y and x >= z:
return x
if y >= x and y >= z:
return y
return z
@constexpr
def _calc_broadcast_shape(cond_shape, x_shape, y_shape):
"""Calculate broadcast shape for select"""
converted_shape = []
cond_reverse = cond_shape[::-1]
x_reverse = x_shape[::-1]
y_reverse = y_shape[::-1]
max_len = get_max_value(len(cond_reverse), len(x_reverse), len(y_reverse))
i = 0
while i < max_len:
cond_element = 1 if i >= len(cond_reverse) else cond_reverse[i]
x_element = 1 if i >= len(x_reverse) else x_reverse[i]
y_element = 1 if i >= len(y_reverse) else y_reverse[i]
broadcast_element = get_max_value(cond_element, x_element, y_element)
if cond_element not in (1, broadcast_element):
raise ValueError(f"For select, condition input can not broadcast at index {i}")
if x_element not in (1, broadcast_element):
raise ValueError(f"For select, x input can not broadcast at index {i}")
if y_element not in (1, broadcast_element):
raise ValueError(f"For select, y input can not broadcast at index {i}")
converted_shape.append(broadcast_element)
i = i + 1
converted_shape.reverse()
return tuple(converted_shape)
def select(cond, x, y):
r"""
The conditional tensor determines whether the corresponding element in the output must be
selected from `x` (if true) or `y` (if false) based on the value of each element.
It can be defined as:
.. math::
out_i = \begin{cases}
x_i, & \text{if } cond_i \\
y_i, & \text{otherwise}
\end{cases}
Args:
cond (Tensor[bool]): The condition tensor, decides which element is chosen.
The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
x (Union[Tensor, int, float]): The first Tensor or number to be selected.
If x is a Tensor, the shape is or can be broadcast to :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
If x is an int or a float, it will be cast to the type of int32 or float32,
and broadcast to the same shape as y. One of x and y must be a Tensor.
y (Union[Tensor, int, float]): The second Tensor or number to be selected.
If y is a Tensor, the shape is or can be broadcast to :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
If y is an int or a float, it will be cast to the type of int32 or float32,
and broadcast to the same shape as x. One of x and y must be a Tensor.
Returns:
Tensor, has the same shape as `cond`.
Raises:
TypeError: If `x` or `y` is not a Tensor, int or float.
ValueError: The shapes of inputs can not be broadcast.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> # 1) Both inputs are Tensor
>>>
>>> cond = Tensor([True, False])
>>> x = Tensor([2,3], mindspore.float32)
>>> y = Tensor([1,2], mindspore.float32)
>>> output = ops.select(cond, x, y)
>>> print(output)
[2. 2.]
>>> # 2) y is a float
>>> cond = Tensor([True, False])
>>> x = Tensor([2,3], mindspore.float32)
>>> y = 2.0
>>> output = ops.select(cond, x, y)
>>> print(output)
[2. 2.]
"""
is_x_scalar = isinstance(x, (int, float))
is_y_scalar = isinstance(y, (int, float))
is_x_tensor = isinstance(x, Tensor)
is_y_tensor = isinstance(y, Tensor)
is_cond_tensor = isinstance(cond, Tensor)
_check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor)
input_x = x
input_y = y
if is_x_scalar:
_check_select_shape_match(y.shape, cond.shape, "y")
_check_select_type_match(x, y.dtype, "x", "y")
input_x = zeros_like_(y) + x
if isinstance(x, int):
input_x = cast_(input_x, mstype.int32)
else:
input_x = cast_(input_x, mstype.float32)
if is_y_scalar:
_check_select_shape_match(x.shape, cond.shape, "x")
_check_select_type_match(y, x.dtype, "y", "x")
input_y = zeros_like_(x) + y
if isinstance(y, int):
input_y = cast_(input_y, mstype.int32)
else:
input_y = cast_(input_y, mstype.float32)
if is_x_tensor and is_y_tensor and is_cond_tensor:
x_shape = ops.shape(x)
y_shape = ops.shape(y)
cond_shape = ops.shape(cond)
all_constant = ops.isconstant(cond_shape) and ops.isconstant(x_shape) and ops.isconstant(y_shape)
if all_constant and not _check_select_shape_same(cond_shape, x_shape, y_shape):
broadcast_shape = _calc_broadcast_shape(cond_shape, x_shape, y_shape)
new_cond = ops.broadcast_to(cond, broadcast_shape)
new_x = ops.broadcast_to(x, broadcast_shape)
new_y = ops.broadcast_to(y, broadcast_shape)
return tensor_select_(new_cond, new_x, new_y)
return tensor_select_(cond, input_x, input_y)
def slice(input_x, begin, size):
r"""
Slices a tensor in the specified shape.

View File

@@ -36,3 +36,4 @@ GroupNormGrad: 'aclnnGroupNormBackward'
NotEqual: 'aclnnNeTensor'
ClampScalar: 'aclnnClamp'
OneHotExt: 'aclnnOneHot'
Select: 'aclnnSWhere'

View File

@@ -0,0 +1,310 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test select"""
import numpy as np
import pytest
import os
import mindspore.common.dtype as mstype
from mindspore.ops import select
from mindspore import ops, Tensor, jit, JitConfig, context
from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP
from tests.st.utils import test_utils
def generate_random_input(shape, dtype):
return Tensor(np.random.randn(*shape).astype(dtype))
def generate_expect_forward_output(condition, x, y):
return np.where(condition, x, y)
def generate_expect_backward_output(condition):
return np.zeros(np.shape(condition), dtype=np.bool_),\
np.where(condition, 1, 0), np.where(condition, 0, 1)
@test_utils.run_with_cell
def select_forward_func(condition, x, y):
return select(condition, x, y)
@test_utils.run_with_cell
def select_backward_func(condition, x, y):
return ops.grad(select_forward_func, (0, 1, 2))(condition, x, y)
@test_utils.run_with_cell
def select_vmap_func(condition, x, y, in_axes=0):
return ops.vmap(select_forward_func, in_axes, out_axes=0)(condition, x, y)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_select_float32(mode):
"""
Feature: Test functional select operator. Support x or y is a float32 Tensor.
Description: Operator select's inputs `x` and `y` are Tensor with float32 type.
Expectation: Assert result.
"""
context.set_context(mode=mode)
cond = np.array([[True, False], [True, False]]).astype(np.bool_)
x = np.array([[1.2, 1], [1, 0]]).astype(np.float32)
y = np.array([[1, 2], [3, 4.0]]).astype(np.float32)
output = select_forward_func(Tensor(cond), Tensor(x), Tensor(y))
print(output.asnumpy())
expect = [[1.2, 2], [1, 4.0]]
error = np.ones(shape=[2, 2]) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_select_float16(mode):
"""
Feature: Test functional select operator. Support x or y is a float16 Tensor.
Description: Operator select's inputs `x` and `y` are Tensor with float16 type.
Expectation: Assert result.
"""
context.set_context(mode=mode)
cond = np.array([[True, False], [True, False]]).astype(np.bool_)
x = np.array([[1.2, 1], [1, 0]]).astype(np.float16)
y = np.array([[1, 2], [3, 4.0]]).astype(np.float16)
output = select_forward_func(Tensor(cond), Tensor(x), Tensor(y))
print(output.asnumpy())
expect = [[1.2, 2], [1, 4.0]]
error = np.ones(shape=[2, 2]) * 1.0e-3
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_select_int32(mode):
"""
Feature: Test functional select operator. Support x or y is a int32 Tensor.
Description: Operator select's inputs `x` and `y` are Tensor with int32 type.
Expectation: Assert result.
"""
context.set_context(mode=mode)
cond = np.array([[True, False], [True, False]]).astype(np.bool_)
x = np.array([[12, 1], [1, 0]]).astype(np.int32)
y = np.array([[1, 2], [3, 4]]).astype(np.int32)
output = select_forward_func(Tensor(cond), Tensor(x), Tensor(y))
print(output.asnumpy())
expect = [[12, 2], [1, 4]]
error = np.ones(shape=[2, 2]) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_functional_select_scalar(mode):
"""
Feature: Test functional select operator. Support x or y is a int/float.
Description: Operator select's input `x` is a Tensor with int32 type, input `y` is a int.
Expectation: Assert result.
"""
context.set_context(mode=mode)
cond = np.array([[True, False], [True, False]]).astype(np.bool_)
x = np.array([[12, 1], [1, 0]]).astype(np.int32)
y = 2
output = select_forward_func(Tensor(cond), Tensor(x), y)
print(output.asnumpy())
expect = [[12, 2], [1, 2]]
error = np.ones(shape=[2, 2]) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_functional_select_broadcast(mode):
"""
Feature: Test functional select operator support broadcast input.
Description: Operator select's support broadcast input.
Expectation: Assert result.
"""
context.set_context(mode=mode)
cond = Tensor(np.random.rand(1, 65, 54, 12, 5, 2), dtype=mstype.bool_)
x = Tensor(np.random.rand(5, 5, 65, 1, 12, 5, 2).astype(np.float32))
y = Tensor(np.random.rand(65, 54, 1, 5, 2).astype(np.float32))
ret = select_forward_func(cond, x, y)
assert ret.shape == (5, 5, 65, 54, 12, 5, 2)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize('mode', ['pynative', 'KBK', 'GE'])
def test_select_ext_static_shape(mode):
"""
Feature: Test select with static shape in graph and pynative mode.
Description: call ops.select with valid input and index.
Expectation: return the correct value.
"""
x = generate_random_input((2, 3, 4, 5), np.float32)
y = generate_random_input((2, 3, 4, 5), np.float32)
cond = x > 0
if mode == 'pynative':
ms_out = select_forward_func(cond, x, y)
elif mode == 'KBK':
ms_out = (jit(select_forward_func, jit_config=JitConfig(jit_level="O0")))(cond, x, y)
else:
ms_out = (jit(select_forward_func, jit_config=JitConfig(jit_level="O2")))(cond, x, y)
expect = generate_expect_forward_output(cond.asnumpy(), x.asnumpy(), y.asnumpy())
assert np.allclose(ms_out.asnumpy(), expect, rtol=1e-4)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.parametrize('jit_level', ["O0", "O2"])
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_gpu_training
def test_select_ext_dynamic_shape(jit_level):
"""
Feature: Test select with dynamic shape in graph mode.
Description: call ops.select with valid input and index.
Expectation: return the correct value.
"""
x1 = generate_random_input((2, 3, 4, 5), np.float32)
y1 = generate_random_input((2, 3, 4, 5), np.float32)
cond1 = x1 > 0
x2 = generate_random_input((6, 7, 8), np.float32)
y2 = generate_random_input((6, 7, 8), np.float32)
cond2 = x2 > 0
TEST_OP(select_forward_func, [[cond1, x1, y1], [cond2, x2, y2]], grad=True, jit_level=jit_level)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize('graph_level', ["0", "1"])
def test_select_vmap(graph_level):
"""
Feature: Test select with vmap.
Description: call ops.select with valid input and index.
Expectation: return the correct value.
"""
def _foreach_run(condition, x, y, batch):
out = []
for i in range(condition.shape[batch]):
if batch == -1:
cond_inner = condition[..., i]
x_inner = x[..., i]
y_inner = y[..., i]
else:
cond_inner = condition[i, ...]
x_inner = x[i, ...]
y_inner = y[i, ...]
out.append(select_forward_func(cond_inner, x_inner, y_inner))
out = ops.Stack()(out)
return out
os.environ['GRAPH_OP_RUN'] = graph_level
x = generate_random_input((2, 3, 4, 5), np.float32)
y = generate_random_input((2, 3, 4, 5), np.float32)
cond = x > 0
batch_axis = -1
output = select_vmap_func(cond, x, y, batch_axis)
expect = _foreach_run(cond, x, y, batch_axis)
assert np.allclose(output.asnumpy(), expect.asnumpy(), rtol=1e-4)
batch_axis = 0
output = select_vmap_func(cond, x, y, batch_axis)
expect = _foreach_run(cond, x, y, batch_axis)
assert np.allclose(output.asnumpy(), expect.asnumpy(), rtol=1e-4)
del os.environ['GRAPH_OP_RUN']
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize("mode", ['pynative', 'GE', 'KBK'])
def test_select_ext_grad(mode):
"""
Feature: Test select with backward.
Description: call ops.select with valid input and index.
Expectation: return the correct value.
"""
x = generate_random_input((2, 3, 4, 5), np.float32)
y = generate_random_input((2, 3, 4, 5), np.float32)
cond = x > 0
if mode == 'pynative':
ms_cond, ms_x, ms_y = select_backward_func(cond, x, y)
elif mode == 'KBK':
ms_cond, ms_x, ms_y = (jit(select_backward_func, jit_config=JitConfig(jit_level="O0")))(cond, x, y)
else:
ms_cond, ms_x, ms_y = (jit(select_backward_func, jit_config=JitConfig(jit_level="O2")))(cond, x, y)
expect_cond, expect_x, expect_y = generate_expect_backward_output(cond.asnumpy())
assert np.allclose(ms_cond.asnumpy(), expect_cond, rtol=1e-4)
assert np.allclose(ms_x.asnumpy(), expect_x, rtol=1e-4)
assert np.allclose(ms_y.asnumpy(), expect_y, rtol=1e-4)

View File

@@ -1,15 +1,57 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test where"""
import numpy as np
import pytest
import os
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import context
from mindspore.ops import where
from mindspore import ops, Tensor, jit, JitConfig, context
from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP
from tests.st.utils import test_utils
class Net(nn.Cell):
def construct(self, condition, x, y):
return ops.where(condition, x, y)
def generate_random_input(shape, dtype):
return Tensor(np.random.randn(*shape).astype(dtype))
def generate_expect_forward_output(condition, x, y):
return np.where(condition, x, y)
def generate_expect_backward_output(condition):
return np.zeros(np.shape(condition), dtype=np.bool_),\
np.where(condition, 1, 0), np.where(condition, 0, 1)
@test_utils.run_with_cell
def where_forward_func(condition, x, y):
return where(condition, x, y)
@test_utils.run_with_cell
def where_backward_func(condition, x, y):
return ops.grad(where_forward_func, (0, 1, 2))(condition, x, y)
@test_utils.run_with_cell
def where_vmap_func(condition, x, y, in_axes=0):
return ops.vmap(where_forward_func, in_axes, out_axes=0)(condition, x, y)
@pytest.mark.level2
@@ -27,10 +69,135 @@ def test_ops_where(mode):
Expectation: success
"""
context.set_context(mode=mode)
net = Net()
x = Tensor(np.arange(4).reshape((2, 2)), mstype.float32)
y = Tensor(np.ones((2, 2)), mstype.float32)
condition = x < 3
output = net(condition, x, y)
output = where_forward_func(condition, x, y)
expected = np.array([[0, 1], [2, 1]], dtype=np.float32)
assert np.allclose(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize('mode', ['pynative', 'KBK', 'GE'])
def test_where_ext_static_shape(mode):
"""
Feature: Test where with static shape in graph and pynative mode.
Description: call ops.where with valid input and index.
Expectation: return the correct value.
"""
x = generate_random_input((2, 3, 4, 5), np.float32)
y = generate_random_input((2, 3, 4, 5), np.float32)
cond = x > 0
if mode == 'pynative':
ms_out = where_forward_func(cond, x, y)
elif mode == 'KBK':
ms_out = (jit(where_forward_func, jit_config=JitConfig(jit_level="O0")))(cond, x, y)
else:
ms_out = (jit(where_forward_func, jit_config=JitConfig(jit_level="O2")))(cond, x, y)
expect = generate_expect_forward_output(cond.asnumpy(), x.asnumpy(), y.asnumpy())
assert np.allclose(ms_out.asnumpy(), expect, rtol=1e-4)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.parametrize('jit_level', ["O0", "O2"])
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_gpu_training
def test_where_ext_dynamic_shape(jit_level):
"""
Feature: Test where with dynamic shape in graph mode.
Description: call ops.where with valid input and index.
Expectation: return the correct value.
"""
x1 = generate_random_input((2, 3, 4, 5), np.float32)
y1 = generate_random_input((2, 3, 4, 5), np.float32)
cond1 = x1 > 0
x2 = generate_random_input((6, 7, 8), np.float32)
y2 = generate_random_input((6, 7, 8), np.float32)
cond2 = x2 > 0
TEST_OP(where_forward_func, [[cond1, x1, y1], [cond2, x2, y2]], grad=True, jit_level=jit_level)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize('graph_level', ["0", "1"])
def test_where_vmap(graph_level):
"""
Feature: Test where with vmap.
Description: call ops.where with valid input and index.
Expectation: return the correct value.
"""
def _foreach_run(condition, x, y, batch):
out = []
for i in range(condition.shape[batch]):
if batch == -1:
cond_inner = condition[..., i]
x_inner = x[..., i]
y_inner = y[..., i]
else:
cond_inner = condition[i, ...]
x_inner = x[i, ...]
y_inner = y[i, ...]
out.append(where_forward_func(cond_inner, x_inner, y_inner))
out = ops.Stack()(out)
return out
os.environ['GRAPH_OP_RUN'] = graph_level
x = generate_random_input((2, 3, 4, 5), np.float32)
y = generate_random_input((2, 3, 4, 5), np.float32)
cond = x > 0
batch_axis = -1
output = where_vmap_func(cond, x, y, batch_axis)
expect = _foreach_run(cond, x, y, batch_axis)
assert np.allclose(output.asnumpy(), expect.asnumpy(), rtol=1e-4)
batch_axis = 0
output = where_vmap_func(cond, x, y, batch_axis)
expect = _foreach_run(cond, x, y, batch_axis)
assert np.allclose(output.asnumpy(), expect.asnumpy(), rtol=1e-4)
del os.environ['GRAPH_OP_RUN']
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize("mode", ['pynative', 'GE', 'KBK'])
def test_where_ext_grad(mode):
"""
Feature: Test where with backward.
Description: call ops.where with valid input and index.
Expectation: return the correct value.
"""
x = generate_random_input((2, 3, 4, 5), np.float32)
y = generate_random_input((2, 3, 4, 5), np.float32)
cond = x > 0
if mode == 'pynative':
ms_cond, ms_x, ms_y = where_backward_func(cond, x, y)
elif mode == 'KBK':
ms_cond, ms_x, ms_y = (jit(where_backward_func, jit_config=JitConfig(jit_level="O0")))(cond, x, y)
else:
ms_cond, ms_x, ms_y = (jit(where_backward_func, jit_config=JitConfig(jit_level="O2")))(cond, x, y)
expect_cond, expect_x, expect_y = generate_expect_backward_output(cond.asnumpy())
assert np.allclose(ms_cond.asnumpy(), expect_cond, rtol=1e-4)
assert np.allclose(ms_x.asnumpy(), expect_x, rtol=1e-4)
assert np.allclose(ms_y.asnumpy(), expect_y, rtol=1e-4)

View File

@@ -175,7 +175,7 @@ def test_collect_custom_aicpu():
profiler.analyse()
aicpu_intermediate_file_list = glob.glob(f"{tmpdir}/profiler/aicpu_intermediate_*.csv")
assert len(aicpu_intermediate_file_list) == 1
s1 = {'Select', 'Xlogy', 'Cast'}
s1 = {'Cast', 'BroadcastTo', 'Select', 'Xlogy'}
s2 = set()
with open(aicpu_intermediate_file_list[0], 'r') as fr:
reader = csv.DictReader(fr)

View File

@@ -0,0 +1,44 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "common/common_test.h"
#include "ops/ops_func_impl/select.h"
#include "ops/test_ops.h"
#include "ops/test_ops_cmp_utils.h"
#include "ops/test_value_utils.h"
namespace mindspore {
namespace ops {
OP_FUNC_IMPL_TEST_DECLARE(Select, MultiInputOpParams);
OP_FUNC_IMPL_TEST_CASES(
Select,
testing::Values(
MultiInputOpParams{{{2, 3}, {2, 3}, {2, 3}}, {kBool, kFloat32, kFloat32}, {{2, 3}}, {kFloat32}, {}},
MultiInputOpParams{{{2, 3}, {2, 3}, {2, 3}}, {kBool, kFloat32, kInt32}, {{2, 3}}, {kFloat32}, {}},
MultiInputOpParams{{{-1, 3}, {2, 3}, {2, 3}}, {kBool, kFloat32, kFloat32}, {{2, 3}}, {kFloat32}, {}},
MultiInputOpParams{{{2, -1}, {2, 3}, {2, 3}}, {kBool, kFloat32, kFloat32}, {{2, 3}}, {kFloat32}, {}},
MultiInputOpParams{{{2, -1}, {2, -1}, {2, -1}}, {kBool, kFloat32, kFloat32}, {{2, -1}}, {kFloat32}, {}},
MultiInputOpParams{{{-1, -1}, {-1, -1}, {2, -1}}, {kBool, kFloat32, kFloat32}, {{2, -1}}, {kFloat32}, {}},
MultiInputOpParams{{{-1, -1}, {-1, -1}, {-1, -1}}, {kBool, kFloat32, kFloat32}, {{-1, -1}}, {kFloat32}, {}},
MultiInputOpParams{{{4, 5, 8}, {1, 5, 8}, {4, 1, 8}}, {kBool, kFloat32, kFloat32}, {{4, 5, 8}}, {kFloat32}, {}},
MultiInputOpParams{{{1, 65, 54, 12, 5, 2}, {5, 5, 65, 1, 12, 5, 2}, {65, 54, 1, 5, 2}},
{kBool, kFloat32, kFloat32},
{{5, 5, 65, 54, 12, 5, 2}},
{kFloat32},
{}}));
} // namespace ops
} // namespace mindspore