diff --git a/docs/api/api_python/ops/mindspore.ops.func_select.rst b/docs/api/api_python/ops/mindspore.ops.func_select.rst index 1eb0ab28841..d7f2149c28a 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_select.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_select.rst @@ -1,31 +1,31 @@ mindspore.ops.select ==================== -.. py:function:: mindspore.ops.select(cond, x, y) +.. py:function:: mindspore.ops.select(condition, input, other) - 根据条件判断Tensor中的元素的值,来决定输出中的相应元素是从 `x` (如果元素值为True)还是从 `y` (如果元素值为False)中选择。 + 根据条件判断Tensor中的元素的值,来决定输出中的相应元素是从 `input` (如果元素值为True)还是从 `other` (如果元素值为False)中选择。 该算法可以被定义为: .. math:: out_i = \begin{cases} - x_i, & \text{if } cond_i \\ - y_i, & \text{otherwise} + input_i, & \text{if } condition_i \\ + other_i, & \text{otherwise} \end{cases} 参数: - - **cond** (Tensor[bool]) - 条件Tensor,决定选择哪一个元素,shape是 :math:`(x_1, x_2, ..., x_N, ..., x_R)`。 - - **x** (Union[Tensor, int, float]) - 第一个被选择的Tensor或者数字。 - 如果x是一个Tensor,那么shape是或者可以被广播为 :math:`(x_1, x_2, ..., x_N, ..., x_R)`。 - 如果x是int或者float,那么将会被转化为int32或者float32类型,并且被广播为与y相同的shape。x和y中至少要有一个Tensor。 - - **y** (Union[Tensor, int, float]) - 第二个被选择的Tensor或者数字。 - 如果y是一个Tensor,那么shape是或者可以被广播为 :math:`(x_1, x_2, ..., x_N, ..., x_R)`。 - 如果y是int或者float,那么将会被转化为int32或者float32类型,并且被广播为与x相同的shape。x和y中至少要有一个Tensor。 + - **condition** (Tensor[bool]) - 条件Tensor,决定选择哪一个元素,shape是 :math:`(x_1, x_2, ..., x_N, ..., x_R)`。 + - **input** (Union[Tensor, int, float]) - 第一个被选择的Tensor或者数字。 + 如果input是一个Tensor,那么shape是或者可以被广播为 :math:`(x_1, x_2, ..., x_N, ..., x_R)`。 + 如果input是int或者float,那么将会被转化为int32或者float32类型,并且被广播为与y相同的shape。x和y中至少要有一个Tensor。 + - **other** (Union[Tensor, int, float]) - 第二个被选择的Tensor或者数字。 + 如果other是一个Tensor,那么shape是或者可以被广播为 :math:`(x_1, x_2, ..., x_N, ..., x_R)`。 + 如果other是int或者float,那么将会被转化为int32或者float32类型,并且被广播为与x相同的shape。x和y中至少要有一个Tensor。 返回: - Tensor,与 `cond` 的shape相同。 + Tensor,与 `condition` 的shape相同。 异常: - - **TypeError** - `x` 和 `y` 不是Tensor、int或者float。 + - **TypeError** - `input` 和 `other` 不是Tensor、int或者float。 - **ValueError** - 输入的shape不能被广播。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_where.rst b/docs/api/api_python/ops/mindspore.ops.func_where.rst index 4f4c6b9f99d..e72c94acd12 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_where.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_where.rst @@ -1,23 +1,23 @@ mindspore.ops.where ==================== -.. py:function:: mindspore.ops.where(condition, x, y) +.. py:function:: mindspore.ops.where(condition, input, other) - 返回一个Tensor,Tensor的元素从 `x` 或 `y` 中根据 `condition` 选择。 + 返回一个Tensor,Tensor的元素从 `input` 或 `other` 中根据 `condition` 选择。 .. math:: - output_i = \begin{cases} x_i,\quad &if\ condition_i \\ y_i,\quad &otherwise \end{cases} + output_i = \begin{cases} input_i,\quad &if\ condition_i \\ other_i,\quad &otherwise \end{cases} 参数: - - **condition** (Tensor[bool]) - 如果是 ``True`` ,选取 `x` 中的元素,否则选取 `y` 中的元素。 - - **x** (Union[Tensor, Scalar]) - 在 `condition` 为 ``True`` 的索引处选择的值。 - - **y** (Union[Tensor, Scalar]) - 当 `condition` 为 ``False`` 的索引处选择的值。 + - **condition** (Tensor[bool]) - 如果是 ``True`` ,选取 `input` 中的元素,否则选取 `other` 中的元素。 + - **input** (Union[Tensor, Scalar]) - 在 `condition` 为 ``True`` 的索引处选择的值。 + - **other** (Union[Tensor, Scalar]) - 当 `condition` 为 ``False`` 的索引处选择的值。 返回: - Tensor,其中的元素从 `x` 和 `y` 中选取。 + Tensor,其中的元素从 `input` 和 `other` 中选取。 异常: - **TypeError** - 如果 `condition` 不是Tensor。 - - **TypeError** - 如果 `x` 和 `y` 都是常量。 - - **ValueError** - `condition` 、 `x` 和 `y` 不能互相广播。 + - **TypeError** - 如果 `input` 和 `other` 都是常量。 + - **ValueError** - `condition` 、 `input` 和 `other` 不能互相广播。 diff --git a/mindspore/ccsrc/frontend/expander/bprop/grad_ops/grad_array_ops.cc b/mindspore/ccsrc/frontend/expander/bprop/grad_ops/grad_array_ops.cc index 397451b3eba..58d7ca52f1c 100644 --- a/mindspore/ccsrc/frontend/expander/bprop/grad_ops/grad_array_ops.cc +++ b/mindspore/ccsrc/frontend/expander/bprop/grad_ops/grad_array_ops.cc @@ -925,7 +925,10 @@ REG_BPROP_BUILDER("Select").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) { auto dout = ib->GetInput(kIndex4); auto dx = x->need_compute_grad_out() ? ib->Select(cond, dout, ib->ZerosLike(x)) : ib->OutZeros(x); auto dy = x->need_compute_grad_out() ? ib->Select(cond, ib->ZerosLike(y), dout) : ib->OutZeros(y); - return {ib->OutZeros(cond), dx, dy}; + auto bc_x = BinopGradCommon(ib, cond, x, dout, dx); + auto bc_y = BinopGradCommon(ib, cond, y, dout, dy); + auto ret = BinopGradCommon(ib, x, y, bc_x[kIndex1], bc_y[kIndex1]); + return {ib->OutZeros(cond), ret[kIndex0], ret[kIndex1]}; }); REG_BPROP_BUILDER("OnesLike").SetUnusedInputs({i0, i1, i2}).SetBody(ReturnZeros); diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/broadcast_for_select.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/broadcast_for_select.cc new file mode 100644 index 00000000000..a694870b693 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/broadcast_for_select.cc @@ -0,0 +1,127 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/ascend/optimizer/ge/broadcast_for_select.h" +#include +#include +#include +#include "mindspore/core/ops/array_ops.h" +#include "include/common/utils/anfalgo.h" +#include "include/backend/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +ShapeVector GetSelectInputShape(const AnfNodePtr &input) { + MS_EXCEPTION_IF_NULL(input); + auto input_base_shape = input->Shape(); + MS_EXCEPTION_IF_NULL(input_base_shape); + auto input_shape = input_base_shape->cast(); + MS_EXCEPTION_IF_NULL(input_shape); + return input_shape->shape(); +} + +ShapeVector CalcBroadcastShape(AnfNodePtr cond, AnfNodePtr x, AnfNodePtr y) { + auto cond_shape = GetSelectInputShape(cond); + auto x_shape = GetSelectInputShape(x); + auto y_shape = GetSelectInputShape(y); + auto cond_size = cond_shape.size(); + auto x_size = x_shape.size(); + auto y_size = y_shape.size(); + ShapeVector broadcast_shape = + cond_size > x_size ? cond_size > y_size ? cond_shape : y_shape : x_size > y_size ? x_shape : y_shape; + auto n = broadcast_shape.size(); + for (size_t i = n; i > 0; --i) { + auto cond_i = cond_size < i ? 1 : cond_shape[cond_size - i]; + auto x_i = x_size < i ? 1 : x_shape[x_size - i]; + auto y_i = y_size < i ? 1 : y_shape[y_size - i]; + auto broadcost_i = std::max(cond_i, std::max(x_i, y_i)); + if (cond_i != broadcost_i && cond_i != 1) { + MS_EXCEPTION(ValueError) << "For select, condition input can not broadcast at index " << i; + } + if (x_i != broadcost_i && x_i != 1) { + MS_EXCEPTION(ValueError) << "For select, x input can not broadcast at index " << i; + } + if (y_i != broadcost_i && y_i != 1) { + MS_EXCEPTION(ValueError) << "For select, y input can not broadcast at index " << i; + } + broadcast_shape[n - i] = broadcost_i; + } + return broadcast_shape; +} + +CNodePtr AddBroadCastToNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, + const std::vector &broad_shape) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(input_node); + auto input_type = common::AnfAlgo::GetOutputInferDataType(input_node, 0); + auto shape_node = opt::CreateValueNodeWithKernelInfo(func_graph, MakeValue(broad_shape)); + + std::vector broadcastto_inputs = { + NewValueNode(std::make_shared(prim::kPrimBroadcastTo->name())), input_node, shape_node}; + CNodePtr broadcastto_node = NewCNode(broadcastto_inputs, func_graph); + MS_EXCEPTION_IF_NULL(broadcastto_node); + broadcastto_node->set_scope(input_node->scope()); + broadcastto_node->set_abstract(input_node->abstract()); + common::AnfAlgo::SetOutputInferTypeAndShape({input_type}, {broad_shape}, broadcastto_node.get()); + return broadcastto_node; +} + +CNodePtr AddSelectNode(const FuncGraphPtr &func_graph, const CNodePtr &cond_node, const CNodePtr &x_node, + const CNodePtr &y_node, const CNodePtr &select_node, const std::vector &broad_shape) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(cond_node); + MS_EXCEPTION_IF_NULL(x_node); + MS_EXCEPTION_IF_NULL(y_node); + MS_EXCEPTION_IF_NULL(select_node); + auto input_type = common::AnfAlgo::GetOutputInferDataType(select_node, 0); + + std::vector select_inputs = {NewValueNode(std::make_shared(prim::kPrimSelect->name())), + cond_node, x_node, y_node}; + CNodePtr out_node = NewCNode(select_inputs, func_graph); + MS_EXCEPTION_IF_NULL(out_node); + out_node->set_scope(select_node->scope()); + out_node->set_abstract(select_node->abstract()); + common::AnfAlgo::SetOutputInferTypeAndShape({input_type}, {broad_shape}, out_node.get()); + return out_node; +} +} // namespace + +const BaseRef BroadCastForSelect::DefinePattern() const { + VarPtr inputs = std::make_shared(); + return VectorRef({prim::kPrimSelect, inputs}); +} + +const AnfNodePtr BroadCastForSelect::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + // Select(...) ===> inputs -> CalcBroadcastShape -> BroadCastTo -> Select(...) + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto select_node = node->cast(); + MS_EXCEPTION_IF_NULL(select_node); + // get broadcast shape + auto cond = select_node->input(kIndex1); + auto x = select_node->input(kIndex2); + auto y = select_node->input(kIndex3); + auto output_shape = CalcBroadcastShape(cond, x, y); + // do BroadCast + auto new_cond = AddBroadCastToNode(graph, cond, output_shape); + auto new_x = AddBroadCastToNode(graph, x, output_shape); + auto new_y = AddBroadCastToNode(graph, y, output_shape); + auto out_node = AddSelectNode(graph, new_cond, new_x, new_y, select_node, output_shape); + return out_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/broadcast_for_select.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/broadcast_for_select.h new file mode 100644 index 00000000000..ef562562c7f --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/broadcast_for_select.h @@ -0,0 +1,37 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_GE_BROADCAST_FOR_SELECT_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_GE_BROADCAST_FOR_SELECT_H_ + +#include +#include +#include +#include +#include "include/backend/optimizer/optimizer.h" +#include "ops/auto_generate/gen_ops_primitive.h" + +namespace mindspore { +namespace opt { +class BroadCastForSelect : public PatternProcessPass { + public: + explicit BroadCastForSelect(bool multi_graph = true) : PatternProcessPass("broadcast_for_select", multi_graph) {} + ~BroadCastForSelect() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_GE_BROADCAST_FOR_SELECT_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge_backend_optimization.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge_backend_optimization.cc index d794df17cf9..a9c882e734e 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge_backend_optimization.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge_backend_optimization.cc @@ -63,6 +63,7 @@ #include "plugin/device/ascend/optimizer/enhancer/eliminate_maketuple_getitem.h" #include "plugin/device/ascend/optimizer/ge/convert_pad_v3_paddings.h" #include "plugin/device/ascend/optimizer/ir_fusion/shape_reshape_fusion.h" +#include "plugin/device/ascend/optimizer/ge/broadcast_for_select.h" namespace mindspore { namespace opt { @@ -98,6 +99,7 @@ void GEBackendOptimization(const KernelGraphPtr &kernel_graph) { opt_ge_pm->AddPass(std::make_shared(true, true)); opt_ge_pm->AddPass(std::make_shared("unfold_nested_output")); opt_ge_pm->AddPass(std::make_shared("unfold_nested_maketuple")); + opt_ge_pm->AddPass(std::make_shared()); optimizer->AddPassManager(opt_ge_pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); diff --git a/mindspore/core/ops/ops_def/doc/select_doc.yaml b/mindspore/core/ops/ops_def/doc/select_doc.yaml index 3f0de7d4835..a23e687b4f3 100644 --- a/mindspore/core/ops/ops_def/doc/select_doc.yaml +++ b/mindspore/core/ops/ops_def/doc/select_doc.yaml @@ -1,30 +1,30 @@ select: description: | The conditional tensor determines whether the corresponding element in the output must be - selected from `x` (if True) or `y` (if False) based on the value of each + selected from `input` (if True) or `other` (if False) based on the value of each element. It can be defined as: .. math:: out_i = \begin{cases} - x_i, & \text{if } cond_i \\ - y_i, & \text{otherwise} + input_i, & \text{if } condition_i \\ + other_i, & \text{otherwise} \end{cases} Inputs: - - **cond** (Tensor[bool]): The condition tensor, decides which element is chosen. + - **condition** (Tensor[bool]): The condition tensor, decides which element is chosen. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`. - - **x** (Tensor): The first Tensor to be selected. + - **input** (Tensor): The first Tensor to be selected. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`. - - **y** (Tensor): The second Tensor to be selected. + - **other** (Tensor): The second Tensor to be selected. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`. Outputs: - Tensor, has the same shape as `cond`. + Tensor, has the same shape as `condition`. Raises: - TypeError: If x or y is not a Tensor. + TypeError: If input or other is not a Tensor. ValueError: The shape of inputs are different. Supported Platforms: diff --git a/mindspore/core/ops/ops_def/select_op.yaml b/mindspore/core/ops/ops_def/select_op.yaml index 976169269cc..4322fc41ed5 100644 --- a/mindspore/core/ops/ops_def/select_op.yaml +++ b/mindspore/core/ops/ops_def/select_op.yaml @@ -1,14 +1,18 @@ -#operator select +#operator select/where select: args: - cond: + condition: dtype: tensor - x: + input: dtype: tensor - y: + type_cast: number + other: dtype: tensor + type_cast: number + args_signature: + dtype_group: (condition), (input, other) returns: output: dtype: tensor - function: - disable: True + dispatch: + enable: True diff --git a/mindspore/core/ops/ops_func_impl/select.cc b/mindspore/core/ops/ops_func_impl/select.cc index cbe7a16129b..3fb375763dd 100644 --- a/mindspore/core/ops/ops_func_impl/select.cc +++ b/mindspore/core/ops/ops_func_impl/select.cc @@ -46,18 +46,6 @@ namespace ops { using float_complex = std::complex; using double_complex = std::complex; -void SelectInferShapeCheck(const std::vector &x_shape, const std::vector &y_shape, - const std::vector &cond_shape, size_t shape_size) { - for (size_t i = 0; i < shape_size; i++) { - if ((x_shape[i] > 0 && cond_shape[i] > 0 && x_shape[i] != cond_shape[i]) || - (x_shape[i] > 0 && y_shape[i] > 0 && x_shape[i] != y_shape[i])) { - MS_EXCEPTION(ValueError) - << "For 'Select', the shape of 'condition', 'x' and 'y' must be the same. But got 'condition' shape: " - << cond_shape << ", 'x' shape: " << x_shape << ", 'y' shape: " << y_shape << "."; - } - } -} - abstract::BaseShapePtr SelectFuncImpl::InferShape(const PrimitivePtr &prim, const std::vector &input_args) const { auto cond_shape = input_args[kSelectCondIndex]->GetShape()->GetShapeVector(); @@ -66,16 +54,9 @@ abstract::BaseShapePtr SelectFuncImpl::InferShape(const PrimitivePtr &prim, if (IsDynamicRank(cond_shape) || IsDynamicRank(x_shape) || IsDynamicRank(y_shape)) { return std::make_shared(ShapeVector{abstract::TensorShape::kShapeRankAny}); } - auto cond_shape_size = cond_shape.size(); - auto x_shape_size = x_shape.size(); - auto y_shape_size = y_shape.size(); - if (cond_shape_size != x_shape_size || y_shape_size != x_shape_size) { - MS_EXCEPTION(ValueError) - << "For 'Select', the shape of 'condition', 'x' and 'y' must be the same. But got 'condition' shape: " - << cond_shape << ", 'x' shape: " << x_shape << ", 'y' shape: " << y_shape << "."; - } - SelectInferShapeCheck(x_shape, y_shape, cond_shape, x_shape_size); - return input_args[kSelectCondIndex]->GetShape()->Clone(); + auto broadcast_output_size = CalBroadCastShape(x_shape, y_shape, prim->name(), "input", "other"); + auto output_size = CalBroadCastShape(cond_shape, broadcast_output_size, prim->name(), "condition", "input"); + return std::make_shared(output_size); } TypePtr SelectFuncImpl::InferType(const PrimitivePtr &prim, const std::vector &input_args) const { @@ -94,11 +75,6 @@ TypePtr SelectFuncImpl::InferType(const PrimitivePtr &prim, const std::vectorToString() - << " and y_type: " << y_type->ToString() << "."; - } return x_type->Clone(); } } // namespace ops diff --git a/mindspore/python/mindspore/mint/__init__.py b/mindspore/python/mindspore/mint/__init__.py index 87996adfd86..37084af94cc 100644 --- a/mindspore/python/mindspore/mint/__init__.py +++ b/mindspore/python/mindspore/mint/__init__.py @@ -18,8 +18,9 @@ from mindspore.ops.extend import * from mindspore.ops.extend import array_func, math_func, nn_func from mindspore.mint.nn.functional import * from mindspore.mint.nn import functional +from mindspore.ops import where -__all__ = [] +__all__ = ['where'] __all__.extend(array_func.__all__) __all__.extend(math_func.__all__) __all__.extend(nn_func.__all__) diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index 07895edd669..b150527258b 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -61,7 +61,7 @@ from mindspore.ops._utils.utils import ms_arrange from mindspore.ops.auto_generate import cat, range, scatter_nd, deepcopy, masked_fill, diagonal, expand_dims, \ nonzero, flip, transpose, unsorted_segment_sum, diag, gather, gather_d, gather_nd, reshape, broadcast_to, \ - strided_slice, ones, zeros, max_, min_ + strided_slice, ones, zeros, max_, min_, select from mindspore.ops.operations.manually_defined import tile, rank, scalar_cast arg_max_with_value_ = ArgMaxWithValue() @@ -390,25 +390,25 @@ def hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype return out -def where(condition, x, y): +def where(condition, input, other): r""" - Selects elements from `x` or `y` based on `condition` and returns a tensor. + Selects elements from `input` or `other` based on `condition` and returns a tensor. .. math:: - output_i = \begin{cases} x_i,\quad &if\ condition_i \\ y_i,\quad &otherwise \end{cases} + output_i = \begin{cases} input_i,\quad &if\ condition_i \\ other_i,\quad &otherwise \end{cases} Args: - condition (Tensor[bool]): If True, yield `x`, otherwise yield `y`. - x (Union[Tensor, Scalar]): When `condition` is True, values to select from. - y (Union[Tensor, Scalar]): When `condition` is False, values to select from. + condition (Tensor[bool]): If True, yield `input`, otherwise yield `other`. + input (Union[Tensor, Scalar]): When `condition` is True, values to select from. + other (Union[Tensor, Scalar]): When `condition` is False, values to select from. Returns: - Tensor, elements are selected from `x` and `y`. + Tensor, elements are selected from `input` and `other`. Raises: TypeError: If `condition` is not a Tensor. - TypeError: If both `x` and `y` are scalars. - ValueError: If `condition`, `x` and `y` can not broadcast to each other. + TypeError: If both `input` and `other` are scalars. + ValueError: If `condition`, `input` and `other` can not broadcast to each other. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -425,25 +425,7 @@ def where(condition, x, y): [[0. 1.] [2. 1.]] """ - if not isinstance(condition, Tensor): - raise TypeError(f"For 'where', 'condition' must be a Tensor, but got {type(condition)}.") - if isinstance(x, (int, float)): - if not isinstance(y, Tensor): - raise TypeError( - f"For 'where', at least one of 'x' and 'y' should be Tensor, but got x:{type(x)}, y:{type(y)}." - ) - x = cast_(x, y.dtype) - elif isinstance(y, (int, float)): - if not isinstance(x, Tensor): - raise TypeError( - f"For 'where', at least one of 'x' and 'y' should be Tensor, but got x:{type(x)}, y:{type(y)}." - ) - y = cast_(y, x.dtype) - output_shape = _calc_broadcast_shape(x.shape, y.shape, condition.shape) - condition = broadcast_to(condition, output_shape) - x = broadcast_to(x, output_shape) - y = broadcast_to(y, output_shape) - return tensor_select_(condition, x, y) + return tensor_select_(condition, input, other) def reverse(x, axis): @@ -1435,171 +1417,6 @@ def flatten(input, order='C', *, start_dim=1, end_dim=-1): return reshape_(input, new_shape) -@constexpr -def _check_select_type_match(scalar, tensor_type, scalar_name, tensor_name): - if isinstance(scalar, int) and tensor_type != mstype.int32: - raise TypeError(f"For functional operator[select], the input[{scalar_name}] is int, " - f"then the input[{tensor_name}] must be a Tensor of int32.") - if isinstance(scalar, float) and tensor_type != mstype.float32: - raise TypeError(f"For functional operator[select], the input[{scalar_name}] is float, " - f"then the input[{tensor_name}] must be a Tensor of float32.") - - -@_primexpr -def _check_select_shape_match(input_shape, cond_shape, tensor_name): - if input_shape != cond_shape: - raise ValueError(f"For functional operator[select], the cond shape must be same as {tensor_name} shape.") - - -@constexpr -def _check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor): - if not is_cond_tensor: - raise TypeError(f"For functional operator[select], the input[cond] must be a Tensor.") - if is_x_scalar and not is_y_tensor: - raise TypeError(f"For functional operator[select], the input[x] is int or float, " - f"then the input[y] must be a Tensor.") - if is_y_scalar and not is_x_tensor: - raise TypeError(f"For functional operator[select], the input[y] is int or float, " - f"then the input[x] must be a Tensor.") - - -@constexpr -def _check_select_shape_same(cond_shape, x_shape, y_shape): - """Check if input of select has same shape.""" - return cond_shape == x_shape and x_shape == y_shape and cond_shape == y_shape - - -@constexpr -def get_max_value(x, y, z): - """Get the maximum value of x, y and z.""" - if x >= y and x >= z: - return x - if y >= x and y >= z: - return y - return z - - -@constexpr -def _calc_broadcast_shape(cond_shape, x_shape, y_shape): - """Calculate broadcast shape for select""" - converted_shape = [] - cond_reverse = cond_shape[::-1] - x_reverse = x_shape[::-1] - y_reverse = y_shape[::-1] - max_len = get_max_value(len(cond_reverse), len(x_reverse), len(y_reverse)) - i = 0 - while i < max_len: - cond_element = 1 if i >= len(cond_reverse) else cond_reverse[i] - x_element = 1 if i >= len(x_reverse) else x_reverse[i] - y_element = 1 if i >= len(y_reverse) else y_reverse[i] - broadcast_element = get_max_value(cond_element, x_element, y_element) - if cond_element not in (1, broadcast_element): - raise ValueError(f"For select, condition input can not broadcast at index {i}") - if x_element not in (1, broadcast_element): - raise ValueError(f"For select, x input can not broadcast at index {i}") - if y_element not in (1, broadcast_element): - raise ValueError(f"For select, y input can not broadcast at index {i}") - converted_shape.append(broadcast_element) - i = i + 1 - converted_shape.reverse() - return tuple(converted_shape) - - -def select(cond, x, y): - r""" - The conditional tensor determines whether the corresponding element in the output must be - selected from `x` (if true) or `y` (if false) based on the value of each element. - - It can be defined as: - - .. math:: - out_i = \begin{cases} - x_i, & \text{if } cond_i \\ - y_i, & \text{otherwise} - \end{cases} - - Args: - cond (Tensor[bool]): The condition tensor, decides which element is chosen. - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`. - x (Union[Tensor, int, float]): The first Tensor or number to be selected. - If x is a Tensor, the shape is or can be broadcadt to :math:`(x_1, x_2, ..., x_N, ..., x_R)`. - If x is an int or a float, it will be cast to the type of int32 or float32, - and broadcast to the same shape as y. One of x and y must be a Tensor. - y (Union[Tensor, int, float]): The second Tensor or number to be selected. - If y is a Tensor, The shape is or can be broadcadt to :math:`(x_1, x_2, ..., x_N, ..., x_R)`. - If y is an int or a float, it will be cast to the type of int32 or float32, - and broadcast to the same shape as x. One of x and y must be a Tensor. - - Returns: - Tensor, has the same shape as `cond`. - - Raises: - TypeError: If `x` or `y` is not a Tensor, int or float. - ValueError: The shapes of inputs can not be broadcast. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> import mindspore - >>> from mindspore import Tensor, ops - >>> # 1) Both inputs are Tensor - >>> - >>> cond = Tensor([True, False]) - >>> x = Tensor([2,3], mindspore.float32) - >>> y = Tensor([1,2], mindspore.float32) - >>> output = ops.select(cond, x, y) - >>> print(output) - [2. 2.] - >>> # 2) y is a float - >>> cond = Tensor([True, False]) - >>> x = Tensor([2,3], mindspore.float32) - >>> y = 2.0 - >>> output = ops.select(cond, x, y) - >>> print(output) - [2. 2.] - """ - is_x_scalar = isinstance(x, (int, float)) - is_y_scalar = isinstance(y, (int, float)) - is_x_tensor = isinstance(x, Tensor) - is_y_tensor = isinstance(y, Tensor) - is_cond_tensor = isinstance(cond, Tensor) - _check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor) - input_x = x - input_y = y - if is_x_scalar: - _check_select_shape_match(y.shape, cond.shape, "y") - _check_select_type_match(x, y.dtype, "x", "y") - input_x = zeros_like_(y) + x - if isinstance(x, int): - input_x = cast_(input_x, mstype.int32) - else: - input_x = cast_(input_x, mstype.float32) - - if is_y_scalar: - _check_select_shape_match(x.shape, cond.shape, "x") - _check_select_type_match(y, x.dtype, "y", "x") - input_y = zeros_like_(x) + y - if isinstance(y, int): - input_y = cast_(input_y, mstype.int32) - else: - input_y = cast_(input_y, mstype.float32) - - if is_x_tensor and is_y_tensor and is_cond_tensor: - x_shape = ops.shape(x) - y_shape = ops.shape(y) - cond_shape = ops.shape(cond) - all_constant = ops.isconstant(cond_shape) and ops.isconstant(x_shape) and ops.isconstant(y_shape) - if all_constant and not _check_select_shape_same(cond_shape, x_shape, y_shape): - broadcast_shape = _calc_broadcast_shape(cond_shape, x_shape, y_shape) - new_cond = ops.broadcast_to(cond, broadcast_shape) - new_x = ops.broadcast_to(x, broadcast_shape) - new_y = ops.broadcast_to(y, broadcast_shape) - return tensor_select_(new_cond, new_x, new_y) - - return tensor_select_(cond, input_x, input_y) - - def slice(input_x, begin, size): r""" Slices a tensor in the specified shape. diff --git a/mindspore/python/mindspore/ops_generate/aclnn_config.yaml b/mindspore/python/mindspore/ops_generate/aclnn_config.yaml index 892d31a795d..55529d79d39 100644 --- a/mindspore/python/mindspore/ops_generate/aclnn_config.yaml +++ b/mindspore/python/mindspore/ops_generate/aclnn_config.yaml @@ -36,3 +36,4 @@ GroupNormGrad: 'aclnnGroupNormBackward' NotEqual: 'aclnnNeTensor' ClampScalar: 'aclnnClamp' OneHotExt: 'aclnnOneHot' +Select: 'aclnnSWhere' diff --git a/tests/st/ops/test_ops_select.py b/tests/st/ops/test_ops_select.py new file mode 100644 index 00000000000..d04b5e65510 --- /dev/null +++ b/tests/st/ops/test_ops_select.py @@ -0,0 +1,310 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + + +"""test select""" +import numpy as np +import pytest +import os +import mindspore.common.dtype as mstype + +from mindspore.ops import select +from mindspore import ops, Tensor, jit, JitConfig, context +from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP +from tests.st.utils import test_utils + + +def generate_random_input(shape, dtype): + return Tensor(np.random.randn(*shape).astype(dtype)) + + +def generate_expect_forward_output(condition, x, y): + return np.where(condition, x, y) + + +def generate_expect_backward_output(condition): + return np.zeros(np.shape(condition), dtype=np.bool_),\ + np.where(condition, 1, 0), np.where(condition, 0, 1) + + +@test_utils.run_with_cell +def select_forward_func(condition, x, y): + return select(condition, x, y) + + +@test_utils.run_with_cell +def select_backward_func(condition, x, y): + return ops.grad(select_forward_func, (0, 1, 2))(condition, x, y) + + +@test_utils.run_with_cell +def select_vmap_func(condition, x, y, in_axes=0): + return ops.vmap(select_forward_func, in_axes, out_axes=0)(condition, x, y) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_select_float32(mode): + """ + Feature: Test functional select operator. Support x or y is a float32 Tensor. + Description: Operator select's inputs `x` and `y` are Tensor with float32 type. + Expectation: Assert result. + """ + context.set_context(mode=mode) + cond = np.array([[True, False], [True, False]]).astype(np.bool) + x = np.array([[1.2, 1], [1, 0]]).astype(np.float32) + y = np.array([[1, 2], [3, 4.0]]).astype(np.float32) + output = select_forward_func(Tensor(cond), Tensor(x), Tensor(y)) + print(output.asnumpy()) + expect = [[1.2, 2], [1, 4.0]] + error = np.ones(shape=[2, 2]) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_select_float16(mode): + """ + Feature: Test functional select operator. Support x or y is a float16 Tensor. + Description: Operator select's inputs `x` and `y` are Tensor with float16 type. + Expectation: Assert result. + """ + context.set_context(mode=mode) + cond = np.array([[True, False], [True, False]]).astype(np.bool) + x = np.array([[1.2, 1], [1, 0]]).astype(np.float16) + y = np.array([[1, 2], [3, 4.0]]).astype(np.float16) + output = select_forward_func(Tensor(cond), Tensor(x), Tensor(y)) + print(output.asnumpy()) + expect = [[1.2, 2], [1, 4.0]] + error = np.ones(shape=[2, 2]) * 1.0e-3 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_select_int32(mode): + """ + Feature: Test functional select operator. Support x or y is a int32 Tensor. + Description: Operator select's inputs `x` and `y` are Tensor with int32 type. + Expectation: Assert result. + """ + context.set_context(mode=mode) + cond = np.array([[True, False], [True, False]]).astype(np.bool) + x = np.array([[12, 1], [1, 0]]).astype(np.int32) + y = np.array([[1, 2], [3, 4]]).astype(np.int32) + output = select_forward_func(Tensor(cond), Tensor(x), Tensor(y)) + print(output.asnumpy()) + expect = [[12, 2], [1, 4]] + error = np.ones(shape=[2, 2]) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_functional_select_scalar(mode): + """ + Feature: Test functional select operator. Support x or y is a int/float. + Description: Operator select's input `x` is a Tensor with int32 type, input `y` is a int. + Expectation: Assert result. + """ + context.set_context(mode=mode) + cond = np.array([[True, False], [True, False]]).astype(np.bool) + x = np.array([[12, 1], [1, 0]]).astype(np.int32) + y = 2 + output = select_forward_func(Tensor(cond), Tensor(x), y) + print(output.asnumpy()) + expect = [[12, 2], [1, 2]] + error = np.ones(shape=[2, 2]) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_functional_select_broadcast(mode): + """ + Feature: Test functional select operator support broadcast input. + Description: Operator select's support broadcast input. + Expectation: Assert result. + """ + context.set_context(mode=mode) + cond = Tensor(np.random.rand(1, 65, 54, 12, 5, 2), dtype=mstype.bool_) + x = Tensor(np.random.rand(5, 5, 65, 1, 12, 5, 2).astype(np.float32)) + y = Tensor(np.random.rand(65, 54, 1, 5, 2).astype(np.float32)) + ret = select_forward_func(cond, x, y) + assert ret.shape == (5, 5, 65, 54, 12, 5, 2) + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.parametrize('mode', ['pynative', 'KBK', 'GE']) +def test_select_ext_static_shape(mode): + """ + Feature: Test select with static shape in graph and pynative mode. + Description: call ops.select with valid input and index. + Expectation: return the correct value. + """ + x = generate_random_input((2, 3, 4, 5), np.float32) + y = generate_random_input((2, 3, 4, 5), np.float32) + cond = x > 0 + + if mode == 'pynative': + ms_out = select_forward_func(cond, x, y) + elif mode == 'KBK': + ms_out = (jit(select_forward_func, jit_config=JitConfig(jit_level="O0")))(cond, x, y) + else: + ms_out = (jit(select_forward_func, jit_config=JitConfig(jit_level="O2")))(cond, x, y) + + expect = generate_expect_forward_output(cond.asnumpy(), x.asnumpy(), y.asnumpy()) + assert np.allclose(ms_out.asnumpy(), expect, rtol=1e-4) + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.parametrize('jit_level', ["O0", "O2"]) +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu_training +@pytest.mark.platform_x86_gpu_training +def test_select_ext_dynamic_shape(jit_level): + """ + Feature: Test select with dynamic shape in graph mode. + Description: call ops.select with valid input and index. + Expectation: return the correct value. + """ + x1 = generate_random_input((2, 3, 4, 5), np.float32) + y1 = generate_random_input((2, 3, 4, 5), np.float32) + cond1 = x1 > 0 + + x2 = generate_random_input((6, 7, 8), np.float32) + y2 = generate_random_input((6, 7, 8), np.float32) + cond2 = x2 > 0 + TEST_OP(select_forward_func, [[cond1, x1, y1], [cond2, x2, y2]], grad=True, jit_level=jit_level) + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.parametrize('graph_level', ["0", "1"]) +def test_select_vmap(graph_level): + """ + Feature: Test select with vmap. + Description: call ops.select with valid input and index. + Expectation: return the correct value. + """ + def _foreach_run(condition, x, y, batch): + out = [] + for i in range(condition.shape[batch]): + if batch == -1: + cond_inner = condition[..., i] + x_inner = x[..., i] + y_inner = y[..., i] + else: + cond_inner = condition[i, ...] + x_inner = x[i, ...] + y_inner = y[i, ...] + out.append(select_forward_func(cond_inner, x_inner, y_inner)) + out = ops.Stack()(out) + return out + + os.environ['GRAPH_OP_RUN'] = graph_level + x = generate_random_input((2, 3, 4, 5), np.float32) + y = generate_random_input((2, 3, 4, 5), np.float32) + cond = x > 0 + + batch_axis = -1 + output = select_vmap_func(cond, x, y, batch_axis) + expect = _foreach_run(cond, x, y, batch_axis) + assert np.allclose(output.asnumpy(), expect.asnumpy(), rtol=1e-4) + + batch_axis = 0 + output = select_vmap_func(cond, x, y, batch_axis) + expect = _foreach_run(cond, x, y, batch_axis) + assert np.allclose(output.asnumpy(), expect.asnumpy(), rtol=1e-4) + + del os.environ['GRAPH_OP_RUN'] + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.parametrize("mode", ['pynative', 'GE', 'KBK']) +def test_select_ext_grad(mode): + """ + Feature: Test select with backward. + Description: call ops.select with valid input and index. + Expectation: return the correct value. + """ + x = generate_random_input((2, 3, 4, 5), np.float32) + y = generate_random_input((2, 3, 4, 5), np.float32) + cond = x > 0 + + if mode == 'pynative': + ms_cond, ms_x, ms_y = select_backward_func(cond, x, y) + elif mode == 'KBK': + ms_cond, ms_x, ms_y = (jit(select_backward_func, jit_config=JitConfig(jit_level="O0")))(cond, x, y) + else: + ms_cond, ms_x, ms_y = (jit(select_backward_func, jit_config=JitConfig(jit_level="O2")))(cond, x, y) + expect_cond, expect_x, expect_y = generate_expect_backward_output(cond.asnumpy()) + assert np.allclose(ms_cond.asnumpy(), expect_cond, rtol=1e-4) + assert np.allclose(ms_x.asnumpy(), expect_x, rtol=1e-4) + assert np.allclose(ms_y.asnumpy(), expect_y, rtol=1e-4) diff --git a/tests/st/ops/test_ops_where.py b/tests/st/ops/test_ops_where.py index 5d60033b1c9..706114330a3 100644 --- a/tests/st/ops/test_ops_where.py +++ b/tests/st/ops/test_ops_where.py @@ -1,15 +1,57 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + + +"""test where""" import numpy as np import pytest +import os import mindspore.common.dtype as mstype -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor -from mindspore import context + +from mindspore.ops import where +from mindspore import ops, Tensor, jit, JitConfig, context +from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP +from tests.st.utils import test_utils -class Net(nn.Cell): - def construct(self, condition, x, y): - return ops.where(condition, x, y) +def generate_random_input(shape, dtype): + return Tensor(np.random.randn(*shape).astype(dtype)) + + +def generate_expect_forward_output(condition, x, y): + return np.where(condition, x, y) + + +def generate_expect_backward_output(condition): + return np.zeros(np.shape(condition), dtype=np.bool_),\ + np.where(condition, 1, 0), np.where(condition, 0, 1) + + +@test_utils.run_with_cell +def where_forward_func(condition, x, y): + return where(condition, x, y) + + +@test_utils.run_with_cell +def where_backward_func(condition, x, y): + return ops.grad(where_forward_func, (0, 1, 2))(condition, x, y) + + +@test_utils.run_with_cell +def where_vmap_func(condition, x, y, in_axes=0): + return ops.vmap(where_forward_func, in_axes, out_axes=0)(condition, x, y) @pytest.mark.level2 @@ -27,10 +69,135 @@ def test_ops_where(mode): Expectation: success """ context.set_context(mode=mode) - net = Net() x = Tensor(np.arange(4).reshape((2, 2)), mstype.float32) y = Tensor(np.ones((2, 2)), mstype.float32) condition = x < 3 - output = net(condition, x, y) + output = where_forward_func(condition, x, y) expected = np.array([[0, 1], [2, 1]], dtype=np.float32) assert np.allclose(output.asnumpy(), expected) + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.parametrize('mode', ['pynative', 'KBK', 'GE']) +def test_where_ext_static_shape(mode): + """ + Feature: Test where with static shape in graph and pynative mode. + Description: call ops.where with valid input and index. + Expectation: return the correct value. + """ + x = generate_random_input((2, 3, 4, 5), np.float32) + y = generate_random_input((2, 3, 4, 5), np.float32) + cond = x > 0 + + if mode == 'pynative': + ms_out = where_forward_func(cond, x, y) + elif mode == 'KBK': + ms_out = (jit(where_forward_func, jit_config=JitConfig(jit_level="O0")))(cond, x, y) + else: + ms_out = (jit(where_forward_func, jit_config=JitConfig(jit_level="O2")))(cond, x, y) + + expect = generate_expect_forward_output(cond.asnumpy(), x.asnumpy(), y.asnumpy()) + assert np.allclose(ms_out.asnumpy(), expect, rtol=1e-4) + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.parametrize('jit_level', ["O0", "O2"]) +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu_training +@pytest.mark.platform_x86_gpu_training +def test_where_ext_dynamic_shape(jit_level): + """ + Feature: Test where with dynamic shape in graph mode. + Description: call ops.where with valid input and index. + Expectation: return the correct value. + """ + x1 = generate_random_input((2, 3, 4, 5), np.float32) + y1 = generate_random_input((2, 3, 4, 5), np.float32) + cond1 = x1 > 0 + + x2 = generate_random_input((6, 7, 8), np.float32) + y2 = generate_random_input((6, 7, 8), np.float32) + cond2 = x2 > 0 + TEST_OP(where_forward_func, [[cond1, x1, y1], [cond2, x2, y2]], grad=True, jit_level=jit_level) + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.parametrize('graph_level', ["0", "1"]) +def test_where_vmap(graph_level): + """ + Feature: Test where with vmap. + Description: call ops.where with valid input and index. + Expectation: return the correct value. + """ + def _foreach_run(condition, x, y, batch): + out = [] + for i in range(condition.shape[batch]): + if batch == -1: + cond_inner = condition[..., i] + x_inner = x[..., i] + y_inner = y[..., i] + else: + cond_inner = condition[i, ...] + x_inner = x[i, ...] + y_inner = y[i, ...] + out.append(where_forward_func(cond_inner, x_inner, y_inner)) + out = ops.Stack()(out) + return out + + os.environ['GRAPH_OP_RUN'] = graph_level + x = generate_random_input((2, 3, 4, 5), np.float32) + y = generate_random_input((2, 3, 4, 5), np.float32) + cond = x > 0 + + batch_axis = -1 + output = where_vmap_func(cond, x, y, batch_axis) + expect = _foreach_run(cond, x, y, batch_axis) + assert np.allclose(output.asnumpy(), expect.asnumpy(), rtol=1e-4) + + batch_axis = 0 + output = where_vmap_func(cond, x, y, batch_axis) + expect = _foreach_run(cond, x, y, batch_axis) + assert np.allclose(output.asnumpy(), expect.asnumpy(), rtol=1e-4) + + del os.environ['GRAPH_OP_RUN'] + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.parametrize("mode", ['pynative', 'GE', 'KBK']) +def test_where_ext_grad(mode): + """ + Feature: Test where with backward. + Description: call ops.where with valid input and index. + Expectation: return the correct value. + """ + x = generate_random_input((2, 3, 4, 5), np.float32) + y = generate_random_input((2, 3, 4, 5), np.float32) + cond = x > 0 + + if mode == 'pynative': + ms_cond, ms_x, ms_y = where_backward_func(cond, x, y) + elif mode == 'KBK': + ms_cond, ms_x, ms_y = (jit(where_backward_func, jit_config=JitConfig(jit_level="O0")))(cond, x, y) + else: + ms_cond, ms_x, ms_y = (jit(where_backward_func, jit_config=JitConfig(jit_level="O2")))(cond, x, y) + expect_cond, expect_x, expect_y = generate_expect_backward_output(cond.asnumpy()) + assert np.allclose(ms_cond.asnumpy(), expect_cond, rtol=1e-4) + assert np.allclose(ms_x.asnumpy(), expect_x, rtol=1e-4) + assert np.allclose(ms_y.asnumpy(), expect_y, rtol=1e-4) diff --git a/tests/st/profiler/test_ascend_profiler.py b/tests/st/profiler/test_ascend_profiler.py index caf5d8c9d1e..7df58775088 100644 --- a/tests/st/profiler/test_ascend_profiler.py +++ b/tests/st/profiler/test_ascend_profiler.py @@ -175,7 +175,7 @@ def test_collect_custom_aicpu(): profiler.analyse() aicpu_intermediate_file_list = glob.glob(f"{tmpdir}/profiler/aicpu_intermediate_*.csv") assert len(aicpu_intermediate_file_list) == 1 - s1 = {'Select', 'Xlogy', 'Cast'} + s1 = {'Cast', 'BroadcastTo', 'Select', 'Xlogy'} s2 = set() with open(aicpu_intermediate_file_list[0], 'r') as fr: reader = csv.DictReader(fr) diff --git a/tests/ut/cpp/ops/test_ops_select.cc b/tests/ut/cpp/ops/test_ops_select.cc new file mode 100644 index 00000000000..8a853dece9e --- /dev/null +++ b/tests/ut/cpp/ops/test_ops_select.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "common/common_test.h" +#include "ops/ops_func_impl/select.h" +#include "ops/test_ops.h" +#include "ops/test_ops_cmp_utils.h" +#include "ops/test_value_utils.h" + +namespace mindspore { +namespace ops { +OP_FUNC_IMPL_TEST_DECLARE(Select, MultiInputOpParams); + +OP_FUNC_IMPL_TEST_CASES( + Select, + testing::Values( + MultiInputOpParams{{{2, 3}, {2, 3}, {2, 3}}, {kBool, kFloat32, kFloat32}, {{2, 3}}, {kFloat32}, {}}, + MultiInputOpParams{{{2, 3}, {2, 3}, {2, 3}}, {kBool, kFloat32, kInt32}, {{2, 3}}, {kFloat32}, {}}, + MultiInputOpParams{{{-1, 3}, {2, 3}, {2, 3}}, {kBool, kFloat32, kFloat32}, {{2, 3}}, {kFloat32}, {}}, + MultiInputOpParams{{{2, -1}, {2, 3}, {2, 3}}, {kBool, kFloat32, kFloat32}, {{2, 3}}, {kFloat32}, {}}, + MultiInputOpParams{{{2, -1}, {2, -1}, {2, -1}}, {kBool, kFloat32, kFloat32}, {{2, -1}}, {kFloat32}, {}}, + MultiInputOpParams{{{-1, -1}, {-1, -1}, {2, -1}}, {kBool, kFloat32, kFloat32}, {{2, -1}}, {kFloat32}, {}}, + MultiInputOpParams{{{-1, -1}, {-1, -1}, {-1, -1}}, {kBool, kFloat32, kFloat32}, {{-1, -1}}, {kFloat32}, {}}, + MultiInputOpParams{{{4, 5, 8}, {1, 5, 8}, {4, 1, 8}}, {kBool, kFloat32, kFloat32}, {{4, 5, 8}}, {kFloat32}, {}}, + MultiInputOpParams{{{1, 65, 54, 12, 5, 2}, {5, 5, 65, 1, 12, 5, 2}, {65, 54, 1, 5, 2}}, + {kBool, kFloat32, kFloat32}, + {{5, 5, 65, 54, 12, 5, 2}}, + {kFloat32}, + {}})); +} // namespace ops +} // namespace mindspore