aclnn: selece/where
Remove functional API InferShape support broadcast Add Pass for Select Bprop recover shapes from broadcast Add ut & st Revise doc
This commit is contained in:
parent
f462ec19d1
commit
9ad846649e
|
@ -1,31 +1,31 @@
|
|||
mindspore.ops.select
|
||||
====================
|
||||
|
||||
.. py:function:: mindspore.ops.select(cond, x, y)
|
||||
.. py:function:: mindspore.ops.select(condition, input, other)
|
||||
|
||||
根据条件判断Tensor中的元素的值,来决定输出中的相应元素是从 `x` (如果元素值为True)还是从 `y` (如果元素值为False)中选择。
|
||||
根据条件判断Tensor中的元素的值,来决定输出中的相应元素是从 `input` (如果元素值为True)还是从 `other` (如果元素值为False)中选择。
|
||||
|
||||
该算法可以被定义为:
|
||||
|
||||
.. math::
|
||||
|
||||
out_i = \begin{cases}
|
||||
x_i, & \text{if } cond_i \\
|
||||
y_i, & \text{otherwise}
|
||||
input_i, & \text{if } condition_i \\
|
||||
other_i, & \text{otherwise}
|
||||
\end{cases}
|
||||
|
||||
参数:
|
||||
- **cond** (Tensor[bool]) - 条件Tensor,决定选择哪一个元素,shape是 :math:`(x_1, x_2, ..., x_N, ..., x_R)`。
|
||||
- **x** (Union[Tensor, int, float]) - 第一个被选择的Tensor或者数字。
|
||||
如果x是一个Tensor,那么shape是或者可以被广播为 :math:`(x_1, x_2, ..., x_N, ..., x_R)`。
|
||||
如果x是int或者float,那么将会被转化为int32或者float32类型,并且被广播为与y相同的shape。x和y中至少要有一个Tensor。
|
||||
- **y** (Union[Tensor, int, float]) - 第二个被选择的Tensor或者数字。
|
||||
如果y是一个Tensor,那么shape是或者可以被广播为 :math:`(x_1, x_2, ..., x_N, ..., x_R)`。
|
||||
如果y是int或者float,那么将会被转化为int32或者float32类型,并且被广播为与x相同的shape。x和y中至少要有一个Tensor。
|
||||
- **condition** (Tensor[bool]) - 条件Tensor,决定选择哪一个元素,shape是 :math:`(x_1, x_2, ..., x_N, ..., x_R)`。
|
||||
- **input** (Union[Tensor, int, float]) - 第一个被选择的Tensor或者数字。
|
||||
如果input是一个Tensor,那么shape是或者可以被广播为 :math:`(x_1, x_2, ..., x_N, ..., x_R)`。
|
||||
如果input是int或者float,那么将会被转化为int32或者float32类型,并且被广播为与y相同的shape。x和y中至少要有一个Tensor。
|
||||
- **other** (Union[Tensor, int, float]) - 第二个被选择的Tensor或者数字。
|
||||
如果other是一个Tensor,那么shape是或者可以被广播为 :math:`(x_1, x_2, ..., x_N, ..., x_R)`。
|
||||
如果other是int或者float,那么将会被转化为int32或者float32类型,并且被广播为与x相同的shape。x和y中至少要有一个Tensor。
|
||||
|
||||
返回:
|
||||
Tensor,与 `cond` 的shape相同。
|
||||
Tensor,与 `condition` 的shape相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 和 `y` 不是Tensor、int或者float。
|
||||
- **TypeError** - `input` 和 `other` 不是Tensor、int或者float。
|
||||
- **ValueError** - 输入的shape不能被广播。
|
||||
|
|
|
@ -1,23 +1,23 @@
|
|||
mindspore.ops.where
|
||||
====================
|
||||
|
||||
.. py:function:: mindspore.ops.where(condition, x, y)
|
||||
.. py:function:: mindspore.ops.where(condition, input, other)
|
||||
|
||||
返回一个Tensor,Tensor的元素从 `x` 或 `y` 中根据 `condition` 选择。
|
||||
返回一个Tensor,Tensor的元素从 `input` 或 `other` 中根据 `condition` 选择。
|
||||
|
||||
.. math::
|
||||
|
||||
output_i = \begin{cases} x_i,\quad &if\ condition_i \\ y_i,\quad &otherwise \end{cases}
|
||||
output_i = \begin{cases} input_i,\quad &if\ condition_i \\ other_i,\quad &otherwise \end{cases}
|
||||
|
||||
参数:
|
||||
- **condition** (Tensor[bool]) - 如果是 ``True`` ,选取 `x` 中的元素,否则选取 `y` 中的元素。
|
||||
- **x** (Union[Tensor, Scalar]) - 在 `condition` 为 ``True`` 的索引处选择的值。
|
||||
- **y** (Union[Tensor, Scalar]) - 当 `condition` 为 ``False`` 的索引处选择的值。
|
||||
- **condition** (Tensor[bool]) - 如果是 ``True`` ,选取 `input` 中的元素,否则选取 `other` 中的元素。
|
||||
- **input** (Union[Tensor, Scalar]) - 在 `condition` 为 ``True`` 的索引处选择的值。
|
||||
- **other** (Union[Tensor, Scalar]) - 当 `condition` 为 ``False`` 的索引处选择的值。
|
||||
|
||||
返回:
|
||||
Tensor,其中的元素从 `x` 和 `y` 中选取。
|
||||
Tensor,其中的元素从 `input` 和 `other` 中选取。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `condition` 不是Tensor。
|
||||
- **TypeError** - 如果 `x` 和 `y` 都是常量。
|
||||
- **ValueError** - `condition` 、 `x` 和 `y` 不能互相广播。
|
||||
- **TypeError** - 如果 `input` 和 `other` 都是常量。
|
||||
- **ValueError** - `condition` 、 `input` 和 `other` 不能互相广播。
|
||||
|
|
|
@ -925,7 +925,10 @@ REG_BPROP_BUILDER("Select").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
|||
auto dout = ib->GetInput(kIndex4);
|
||||
auto dx = x->need_compute_grad_out() ? ib->Select(cond, dout, ib->ZerosLike(x)) : ib->OutZeros(x);
|
||||
auto dy = x->need_compute_grad_out() ? ib->Select(cond, ib->ZerosLike(y), dout) : ib->OutZeros(y);
|
||||
return {ib->OutZeros(cond), dx, dy};
|
||||
auto bc_x = BinopGradCommon(ib, cond, x, dout, dx);
|
||||
auto bc_y = BinopGradCommon(ib, cond, y, dout, dy);
|
||||
auto ret = BinopGradCommon(ib, x, y, bc_x[kIndex1], bc_y[kIndex1]);
|
||||
return {ib->OutZeros(cond), ret[kIndex0], ret[kIndex1]};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("OnesLike").SetUnusedInputs({i0, i1, i2}).SetBody(ReturnZeros);
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "plugin/device/ascend/optimizer/ge/broadcast_for_select.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "mindspore/core/ops/array_ops.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "include/backend/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
ShapeVector GetSelectInputShape(const AnfNodePtr &input) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
auto input_base_shape = input->Shape();
|
||||
MS_EXCEPTION_IF_NULL(input_base_shape);
|
||||
auto input_shape = input_base_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_shape);
|
||||
return input_shape->shape();
|
||||
}
|
||||
|
||||
ShapeVector CalcBroadcastShape(AnfNodePtr cond, AnfNodePtr x, AnfNodePtr y) {
|
||||
auto cond_shape = GetSelectInputShape(cond);
|
||||
auto x_shape = GetSelectInputShape(x);
|
||||
auto y_shape = GetSelectInputShape(y);
|
||||
auto cond_size = cond_shape.size();
|
||||
auto x_size = x_shape.size();
|
||||
auto y_size = y_shape.size();
|
||||
ShapeVector broadcast_shape =
|
||||
cond_size > x_size ? cond_size > y_size ? cond_shape : y_shape : x_size > y_size ? x_shape : y_shape;
|
||||
auto n = broadcast_shape.size();
|
||||
for (size_t i = n; i > 0; --i) {
|
||||
auto cond_i = cond_size < i ? 1 : cond_shape[cond_size - i];
|
||||
auto x_i = x_size < i ? 1 : x_shape[x_size - i];
|
||||
auto y_i = y_size < i ? 1 : y_shape[y_size - i];
|
||||
auto broadcost_i = std::max(cond_i, std::max(x_i, y_i));
|
||||
if (cond_i != broadcost_i && cond_i != 1) {
|
||||
MS_EXCEPTION(ValueError) << "For select, condition input can not broadcast at index " << i;
|
||||
}
|
||||
if (x_i != broadcost_i && x_i != 1) {
|
||||
MS_EXCEPTION(ValueError) << "For select, x input can not broadcast at index " << i;
|
||||
}
|
||||
if (y_i != broadcost_i && y_i != 1) {
|
||||
MS_EXCEPTION(ValueError) << "For select, y input can not broadcast at index " << i;
|
||||
}
|
||||
broadcast_shape[n - i] = broadcost_i;
|
||||
}
|
||||
return broadcast_shape;
|
||||
}
|
||||
|
||||
CNodePtr AddBroadCastToNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
|
||||
const std::vector<int64_t> &broad_shape) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
auto input_type = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
|
||||
auto shape_node = opt::CreateValueNodeWithKernelInfo(func_graph, MakeValue(broad_shape));
|
||||
|
||||
std::vector<AnfNodePtr> broadcastto_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(prim::kPrimBroadcastTo->name())), input_node, shape_node};
|
||||
CNodePtr broadcastto_node = NewCNode(broadcastto_inputs, func_graph);
|
||||
MS_EXCEPTION_IF_NULL(broadcastto_node);
|
||||
broadcastto_node->set_scope(input_node->scope());
|
||||
broadcastto_node->set_abstract(input_node->abstract());
|
||||
common::AnfAlgo::SetOutputInferTypeAndShape({input_type}, {broad_shape}, broadcastto_node.get());
|
||||
return broadcastto_node;
|
||||
}
|
||||
|
||||
CNodePtr AddSelectNode(const FuncGraphPtr &func_graph, const CNodePtr &cond_node, const CNodePtr &x_node,
|
||||
const CNodePtr &y_node, const CNodePtr &select_node, const std::vector<int64_t> &broad_shape) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(cond_node);
|
||||
MS_EXCEPTION_IF_NULL(x_node);
|
||||
MS_EXCEPTION_IF_NULL(y_node);
|
||||
MS_EXCEPTION_IF_NULL(select_node);
|
||||
auto input_type = common::AnfAlgo::GetOutputInferDataType(select_node, 0);
|
||||
|
||||
std::vector<AnfNodePtr> select_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSelect->name())),
|
||||
cond_node, x_node, y_node};
|
||||
CNodePtr out_node = NewCNode(select_inputs, func_graph);
|
||||
MS_EXCEPTION_IF_NULL(out_node);
|
||||
out_node->set_scope(select_node->scope());
|
||||
out_node->set_abstract(select_node->abstract());
|
||||
common::AnfAlgo::SetOutputInferTypeAndShape({input_type}, {broad_shape}, out_node.get());
|
||||
return out_node;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef BroadCastForSelect::DefinePattern() const {
|
||||
VarPtr inputs = std::make_shared<SeqVar>();
|
||||
return VectorRef({prim::kPrimSelect, inputs});
|
||||
}
|
||||
|
||||
const AnfNodePtr BroadCastForSelect::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
// Select(...) ===> inputs -> CalcBroadcastShape -> BroadCastTo -> Select(...)
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto select_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(select_node);
|
||||
// get broadcast shape
|
||||
auto cond = select_node->input(kIndex1);
|
||||
auto x = select_node->input(kIndex2);
|
||||
auto y = select_node->input(kIndex3);
|
||||
auto output_shape = CalcBroadcastShape(cond, x, y);
|
||||
// do BroadCast
|
||||
auto new_cond = AddBroadCastToNode(graph, cond, output_shape);
|
||||
auto new_x = AddBroadCastToNode(graph, x, output_shape);
|
||||
auto new_y = AddBroadCastToNode(graph, y, output_shape);
|
||||
auto out_node = AddSelectNode(graph, new_cond, new_x, new_y, select_node, output_shape);
|
||||
return out_node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_GE_BROADCAST_FOR_SELECT_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_GE_BROADCAST_FOR_SELECT_H_
|
||||
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "include/backend/optimizer/optimizer.h"
|
||||
#include "ops/auto_generate/gen_ops_primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class BroadCastForSelect : public PatternProcessPass {
|
||||
public:
|
||||
explicit BroadCastForSelect(bool multi_graph = true) : PatternProcessPass("broadcast_for_select", multi_graph) {}
|
||||
~BroadCastForSelect() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_GE_BROADCAST_FOR_SELECT_H_
|
|
@ -63,6 +63,7 @@
|
|||
#include "plugin/device/ascend/optimizer/enhancer/eliminate_maketuple_getitem.h"
|
||||
#include "plugin/device/ascend/optimizer/ge/convert_pad_v3_paddings.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fusion/shape_reshape_fusion.h"
|
||||
#include "plugin/device/ascend/optimizer/ge/broadcast_for_select.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -98,6 +99,7 @@ void GEBackendOptimization(const KernelGraphPtr &kernel_graph) {
|
|||
opt_ge_pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>(true, true));
|
||||
opt_ge_pm->AddPass(std::make_shared<opt::UnfoldNestedOutput>("unfold_nested_output"));
|
||||
opt_ge_pm->AddPass(std::make_shared<opt::UnfoldMaketuple>("unfold_nested_maketuple"));
|
||||
opt_ge_pm->AddPass(std::make_shared<opt::BroadCastForSelect>());
|
||||
optimizer->AddPassManager(opt_ge_pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
|
|
|
@ -1,30 +1,30 @@
|
|||
select:
|
||||
description: |
|
||||
The conditional tensor determines whether the corresponding element in the output must be
|
||||
selected from `x` (if True) or `y` (if False) based on the value of each
|
||||
selected from `input` (if True) or `other` (if False) based on the value of each
|
||||
element.
|
||||
|
||||
It can be defined as:
|
||||
|
||||
.. math::
|
||||
out_i = \begin{cases}
|
||||
x_i, & \text{if } cond_i \\
|
||||
y_i, & \text{otherwise}
|
||||
input_i, & \text{if } condition_i \\
|
||||
other_i, & \text{otherwise}
|
||||
\end{cases}
|
||||
|
||||
Inputs:
|
||||
- **cond** (Tensor[bool]): The condition tensor, decides which element is chosen.
|
||||
- **condition** (Tensor[bool]): The condition tensor, decides which element is chosen.
|
||||
The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
- **x** (Tensor): The first Tensor to be selected.
|
||||
- **input** (Tensor): The first Tensor to be selected.
|
||||
The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
- **y** (Tensor): The second Tensor to be selected.
|
||||
- **other** (Tensor): The second Tensor to be selected.
|
||||
The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape as `cond`.
|
||||
Tensor, has the same shape as `condition`.
|
||||
|
||||
Raises:
|
||||
TypeError: If x or y is not a Tensor.
|
||||
TypeError: If input or other is not a Tensor.
|
||||
ValueError: The shape of inputs are different.
|
||||
|
||||
Supported Platforms:
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
#operator select
|
||||
#operator select/where
|
||||
select:
|
||||
args:
|
||||
cond:
|
||||
condition:
|
||||
dtype: tensor
|
||||
x:
|
||||
input:
|
||||
dtype: tensor
|
||||
y:
|
||||
type_cast: number
|
||||
other:
|
||||
dtype: tensor
|
||||
type_cast: number
|
||||
args_signature:
|
||||
dtype_group: (condition), (input, other)
|
||||
returns:
|
||||
output:
|
||||
dtype: tensor
|
||||
function:
|
||||
disable: True
|
||||
dispatch:
|
||||
enable: True
|
||||
|
|
|
@ -46,18 +46,6 @@ namespace ops {
|
|||
using float_complex = std::complex<float>;
|
||||
using double_complex = std::complex<double>;
|
||||
|
||||
void SelectInferShapeCheck(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape,
|
||||
const std::vector<int64_t> &cond_shape, size_t shape_size) {
|
||||
for (size_t i = 0; i < shape_size; i++) {
|
||||
if ((x_shape[i] > 0 && cond_shape[i] > 0 && x_shape[i] != cond_shape[i]) ||
|
||||
(x_shape[i] > 0 && y_shape[i] > 0 && x_shape[i] != y_shape[i])) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "For 'Select', the shape of 'condition', 'x' and 'y' must be the same. But got 'condition' shape: "
|
||||
<< cond_shape << ", 'x' shape: " << x_shape << ", 'y' shape: " << y_shape << ".";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
abstract::BaseShapePtr SelectFuncImpl::InferShape(const PrimitivePtr &prim,
|
||||
const std::vector<AbstractBasePtr> &input_args) const {
|
||||
auto cond_shape = input_args[kSelectCondIndex]->GetShape()->GetShapeVector();
|
||||
|
@ -66,16 +54,9 @@ abstract::BaseShapePtr SelectFuncImpl::InferShape(const PrimitivePtr &prim,
|
|||
if (IsDynamicRank(cond_shape) || IsDynamicRank(x_shape) || IsDynamicRank(y_shape)) {
|
||||
return std::make_shared<abstract::TensorShape>(ShapeVector{abstract::TensorShape::kShapeRankAny});
|
||||
}
|
||||
auto cond_shape_size = cond_shape.size();
|
||||
auto x_shape_size = x_shape.size();
|
||||
auto y_shape_size = y_shape.size();
|
||||
if (cond_shape_size != x_shape_size || y_shape_size != x_shape_size) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "For 'Select', the shape of 'condition', 'x' and 'y' must be the same. But got 'condition' shape: "
|
||||
<< cond_shape << ", 'x' shape: " << x_shape << ", 'y' shape: " << y_shape << ".";
|
||||
}
|
||||
SelectInferShapeCheck(x_shape, y_shape, cond_shape, x_shape_size);
|
||||
return input_args[kSelectCondIndex]->GetShape()->Clone();
|
||||
auto broadcast_output_size = CalBroadCastShape(x_shape, y_shape, prim->name(), "input", "other");
|
||||
auto output_size = CalBroadCastShape(cond_shape, broadcast_output_size, prim->name(), "condition", "input");
|
||||
return std::make_shared<abstract::TensorShape>(output_size);
|
||||
}
|
||||
|
||||
TypePtr SelectFuncImpl::InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const {
|
||||
|
@ -94,11 +75,6 @@ TypePtr SelectFuncImpl::InferType(const PrimitivePtr &prim, const std::vector<Ab
|
|||
(void)CheckAndConvertUtils::CheckTensorTypeValid("y_type", y_type, common_valid_types_with_complex_and_bool,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("cond", cond_type, {kBool}, prim_name);
|
||||
if (*x_type != *y_type) {
|
||||
MS_EXCEPTION(TypeError) << "For '" << prim_name
|
||||
<< "', the x_type and y_type must be the same, but got x_type: " << x_type->ToString()
|
||||
<< " and y_type: " << y_type->ToString() << ".";
|
||||
}
|
||||
return x_type->Clone();
|
||||
}
|
||||
} // namespace ops
|
||||
|
|
|
@ -18,8 +18,9 @@ from mindspore.ops.extend import *
|
|||
from mindspore.ops.extend import array_func, math_func, nn_func
|
||||
from mindspore.mint.nn.functional import *
|
||||
from mindspore.mint.nn import functional
|
||||
from mindspore.ops import where
|
||||
|
||||
__all__ = []
|
||||
__all__ = ['where']
|
||||
__all__.extend(array_func.__all__)
|
||||
__all__.extend(math_func.__all__)
|
||||
__all__.extend(nn_func.__all__)
|
||||
|
|
|
@ -61,7 +61,7 @@ from mindspore.ops._utils.utils import ms_arrange
|
|||
|
||||
from mindspore.ops.auto_generate import cat, range, scatter_nd, deepcopy, masked_fill, diagonal, expand_dims, \
|
||||
nonzero, flip, transpose, unsorted_segment_sum, diag, gather, gather_d, gather_nd, reshape, broadcast_to, \
|
||||
strided_slice, ones, zeros, max_, min_
|
||||
strided_slice, ones, zeros, max_, min_, select
|
||||
from mindspore.ops.operations.manually_defined import tile, rank, scalar_cast
|
||||
|
||||
arg_max_with_value_ = ArgMaxWithValue()
|
||||
|
@ -390,25 +390,25 @@ def hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype
|
|||
return out
|
||||
|
||||
|
||||
def where(condition, x, y):
|
||||
def where(condition, input, other):
|
||||
r"""
|
||||
Selects elements from `x` or `y` based on `condition` and returns a tensor.
|
||||
Selects elements from `input` or `other` based on `condition` and returns a tensor.
|
||||
|
||||
.. math::
|
||||
output_i = \begin{cases} x_i,\quad &if\ condition_i \\ y_i,\quad &otherwise \end{cases}
|
||||
output_i = \begin{cases} input_i,\quad &if\ condition_i \\ other_i,\quad &otherwise \end{cases}
|
||||
|
||||
Args:
|
||||
condition (Tensor[bool]): If True, yield `x`, otherwise yield `y`.
|
||||
x (Union[Tensor, Scalar]): When `condition` is True, values to select from.
|
||||
y (Union[Tensor, Scalar]): When `condition` is False, values to select from.
|
||||
condition (Tensor[bool]): If True, yield `input`, otherwise yield `other`.
|
||||
input (Union[Tensor, Scalar]): When `condition` is True, values to select from.
|
||||
other (Union[Tensor, Scalar]): When `condition` is False, values to select from.
|
||||
|
||||
Returns:
|
||||
Tensor, elements are selected from `x` and `y`.
|
||||
Tensor, elements are selected from `input` and `other`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `condition` is not a Tensor.
|
||||
TypeError: If both `x` and `y` are scalars.
|
||||
ValueError: If `condition`, `x` and `y` can not broadcast to each other.
|
||||
TypeError: If both `input` and `other` are scalars.
|
||||
ValueError: If `condition`, `input` and `other` can not broadcast to each other.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
@ -425,25 +425,7 @@ def where(condition, x, y):
|
|||
[[0. 1.]
|
||||
[2. 1.]]
|
||||
"""
|
||||
if not isinstance(condition, Tensor):
|
||||
raise TypeError(f"For 'where', 'condition' must be a Tensor, but got {type(condition)}.")
|
||||
if isinstance(x, (int, float)):
|
||||
if not isinstance(y, Tensor):
|
||||
raise TypeError(
|
||||
f"For 'where', at least one of 'x' and 'y' should be Tensor, but got x:{type(x)}, y:{type(y)}."
|
||||
)
|
||||
x = cast_(x, y.dtype)
|
||||
elif isinstance(y, (int, float)):
|
||||
if not isinstance(x, Tensor):
|
||||
raise TypeError(
|
||||
f"For 'where', at least one of 'x' and 'y' should be Tensor, but got x:{type(x)}, y:{type(y)}."
|
||||
)
|
||||
y = cast_(y, x.dtype)
|
||||
output_shape = _calc_broadcast_shape(x.shape, y.shape, condition.shape)
|
||||
condition = broadcast_to(condition, output_shape)
|
||||
x = broadcast_to(x, output_shape)
|
||||
y = broadcast_to(y, output_shape)
|
||||
return tensor_select_(condition, x, y)
|
||||
return tensor_select_(condition, input, other)
|
||||
|
||||
|
||||
def reverse(x, axis):
|
||||
|
@ -1435,171 +1417,6 @@ def flatten(input, order='C', *, start_dim=1, end_dim=-1):
|
|||
return reshape_(input, new_shape)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_select_type_match(scalar, tensor_type, scalar_name, tensor_name):
|
||||
if isinstance(scalar, int) and tensor_type != mstype.int32:
|
||||
raise TypeError(f"For functional operator[select], the input[{scalar_name}] is int, "
|
||||
f"then the input[{tensor_name}] must be a Tensor of int32.")
|
||||
if isinstance(scalar, float) and tensor_type != mstype.float32:
|
||||
raise TypeError(f"For functional operator[select], the input[{scalar_name}] is float, "
|
||||
f"then the input[{tensor_name}] must be a Tensor of float32.")
|
||||
|
||||
|
||||
@_primexpr
|
||||
def _check_select_shape_match(input_shape, cond_shape, tensor_name):
|
||||
if input_shape != cond_shape:
|
||||
raise ValueError(f"For functional operator[select], the cond shape must be same as {tensor_name} shape.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor):
|
||||
if not is_cond_tensor:
|
||||
raise TypeError(f"For functional operator[select], the input[cond] must be a Tensor.")
|
||||
if is_x_scalar and not is_y_tensor:
|
||||
raise TypeError(f"For functional operator[select], the input[x] is int or float, "
|
||||
f"then the input[y] must be a Tensor.")
|
||||
if is_y_scalar and not is_x_tensor:
|
||||
raise TypeError(f"For functional operator[select], the input[y] is int or float, "
|
||||
f"then the input[x] must be a Tensor.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_select_shape_same(cond_shape, x_shape, y_shape):
|
||||
"""Check if input of select has same shape."""
|
||||
return cond_shape == x_shape and x_shape == y_shape and cond_shape == y_shape
|
||||
|
||||
|
||||
@constexpr
|
||||
def get_max_value(x, y, z):
|
||||
"""Get the maximum value of x, y and z."""
|
||||
if x >= y and x >= z:
|
||||
return x
|
||||
if y >= x and y >= z:
|
||||
return y
|
||||
return z
|
||||
|
||||
|
||||
@constexpr
|
||||
def _calc_broadcast_shape(cond_shape, x_shape, y_shape):
|
||||
"""Calculate broadcast shape for select"""
|
||||
converted_shape = []
|
||||
cond_reverse = cond_shape[::-1]
|
||||
x_reverse = x_shape[::-1]
|
||||
y_reverse = y_shape[::-1]
|
||||
max_len = get_max_value(len(cond_reverse), len(x_reverse), len(y_reverse))
|
||||
i = 0
|
||||
while i < max_len:
|
||||
cond_element = 1 if i >= len(cond_reverse) else cond_reverse[i]
|
||||
x_element = 1 if i >= len(x_reverse) else x_reverse[i]
|
||||
y_element = 1 if i >= len(y_reverse) else y_reverse[i]
|
||||
broadcast_element = get_max_value(cond_element, x_element, y_element)
|
||||
if cond_element not in (1, broadcast_element):
|
||||
raise ValueError(f"For select, condition input can not broadcast at index {i}")
|
||||
if x_element not in (1, broadcast_element):
|
||||
raise ValueError(f"For select, x input can not broadcast at index {i}")
|
||||
if y_element not in (1, broadcast_element):
|
||||
raise ValueError(f"For select, y input can not broadcast at index {i}")
|
||||
converted_shape.append(broadcast_element)
|
||||
i = i + 1
|
||||
converted_shape.reverse()
|
||||
return tuple(converted_shape)
|
||||
|
||||
|
||||
def select(cond, x, y):
|
||||
r"""
|
||||
The conditional tensor determines whether the corresponding element in the output must be
|
||||
selected from `x` (if true) or `y` (if false) based on the value of each element.
|
||||
|
||||
It can be defined as:
|
||||
|
||||
.. math::
|
||||
out_i = \begin{cases}
|
||||
x_i, & \text{if } cond_i \\
|
||||
y_i, & \text{otherwise}
|
||||
\end{cases}
|
||||
|
||||
Args:
|
||||
cond (Tensor[bool]): The condition tensor, decides which element is chosen.
|
||||
The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
x (Union[Tensor, int, float]): The first Tensor or number to be selected.
|
||||
If x is a Tensor, the shape is or can be broadcadt to :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
If x is an int or a float, it will be cast to the type of int32 or float32,
|
||||
and broadcast to the same shape as y. One of x and y must be a Tensor.
|
||||
y (Union[Tensor, int, float]): The second Tensor or number to be selected.
|
||||
If y is a Tensor, The shape is or can be broadcadt to :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
If y is an int or a float, it will be cast to the type of int32 or float32,
|
||||
and broadcast to the same shape as x. One of x and y must be a Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape as `cond`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` or `y` is not a Tensor, int or float.
|
||||
ValueError: The shapes of inputs can not be broadcast.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> from mindspore import Tensor, ops
|
||||
>>> # 1) Both inputs are Tensor
|
||||
>>>
|
||||
>>> cond = Tensor([True, False])
|
||||
>>> x = Tensor([2,3], mindspore.float32)
|
||||
>>> y = Tensor([1,2], mindspore.float32)
|
||||
>>> output = ops.select(cond, x, y)
|
||||
>>> print(output)
|
||||
[2. 2.]
|
||||
>>> # 2) y is a float
|
||||
>>> cond = Tensor([True, False])
|
||||
>>> x = Tensor([2,3], mindspore.float32)
|
||||
>>> y = 2.0
|
||||
>>> output = ops.select(cond, x, y)
|
||||
>>> print(output)
|
||||
[2. 2.]
|
||||
"""
|
||||
is_x_scalar = isinstance(x, (int, float))
|
||||
is_y_scalar = isinstance(y, (int, float))
|
||||
is_x_tensor = isinstance(x, Tensor)
|
||||
is_y_tensor = isinstance(y, Tensor)
|
||||
is_cond_tensor = isinstance(cond, Tensor)
|
||||
_check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor)
|
||||
input_x = x
|
||||
input_y = y
|
||||
if is_x_scalar:
|
||||
_check_select_shape_match(y.shape, cond.shape, "y")
|
||||
_check_select_type_match(x, y.dtype, "x", "y")
|
||||
input_x = zeros_like_(y) + x
|
||||
if isinstance(x, int):
|
||||
input_x = cast_(input_x, mstype.int32)
|
||||
else:
|
||||
input_x = cast_(input_x, mstype.float32)
|
||||
|
||||
if is_y_scalar:
|
||||
_check_select_shape_match(x.shape, cond.shape, "x")
|
||||
_check_select_type_match(y, x.dtype, "y", "x")
|
||||
input_y = zeros_like_(x) + y
|
||||
if isinstance(y, int):
|
||||
input_y = cast_(input_y, mstype.int32)
|
||||
else:
|
||||
input_y = cast_(input_y, mstype.float32)
|
||||
|
||||
if is_x_tensor and is_y_tensor and is_cond_tensor:
|
||||
x_shape = ops.shape(x)
|
||||
y_shape = ops.shape(y)
|
||||
cond_shape = ops.shape(cond)
|
||||
all_constant = ops.isconstant(cond_shape) and ops.isconstant(x_shape) and ops.isconstant(y_shape)
|
||||
if all_constant and not _check_select_shape_same(cond_shape, x_shape, y_shape):
|
||||
broadcast_shape = _calc_broadcast_shape(cond_shape, x_shape, y_shape)
|
||||
new_cond = ops.broadcast_to(cond, broadcast_shape)
|
||||
new_x = ops.broadcast_to(x, broadcast_shape)
|
||||
new_y = ops.broadcast_to(y, broadcast_shape)
|
||||
return tensor_select_(new_cond, new_x, new_y)
|
||||
|
||||
return tensor_select_(cond, input_x, input_y)
|
||||
|
||||
|
||||
def slice(input_x, begin, size):
|
||||
r"""
|
||||
Slices a tensor in the specified shape.
|
||||
|
|
|
@ -36,3 +36,4 @@ GroupNormGrad: 'aclnnGroupNormBackward'
|
|||
NotEqual: 'aclnnNeTensor'
|
||||
ClampScalar: 'aclnnClamp'
|
||||
OneHotExt: 'aclnnOneHot'
|
||||
Select: 'aclnnSWhere'
|
||||
|
|
|
@ -0,0 +1,310 @@
|
|||
# Copyright 2024 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
"""test select"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
import os
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from mindspore.ops import select
|
||||
from mindspore import ops, Tensor, jit, JitConfig, context
|
||||
from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP
|
||||
from tests.st.utils import test_utils
|
||||
|
||||
|
||||
def generate_random_input(shape, dtype):
|
||||
return Tensor(np.random.randn(*shape).astype(dtype))
|
||||
|
||||
|
||||
def generate_expect_forward_output(condition, x, y):
|
||||
return np.where(condition, x, y)
|
||||
|
||||
|
||||
def generate_expect_backward_output(condition):
|
||||
return np.zeros(np.shape(condition), dtype=np.bool_),\
|
||||
np.where(condition, 1, 0), np.where(condition, 0, 1)
|
||||
|
||||
|
||||
@test_utils.run_with_cell
|
||||
def select_forward_func(condition, x, y):
|
||||
return select(condition, x, y)
|
||||
|
||||
|
||||
@test_utils.run_with_cell
|
||||
def select_backward_func(condition, x, y):
|
||||
return ops.grad(select_forward_func, (0, 1, 2))(condition, x, y)
|
||||
|
||||
|
||||
@test_utils.run_with_cell
|
||||
def select_vmap_func(condition, x, y, in_axes=0):
|
||||
return ops.vmap(select_forward_func, in_axes, out_axes=0)(condition, x, y)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_select_float32(mode):
|
||||
"""
|
||||
Feature: Test functional select operator. Support x or y is a float32 Tensor.
|
||||
Description: Operator select's inputs `x` and `y` are Tensor with float32 type.
|
||||
Expectation: Assert result.
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
cond = np.array([[True, False], [True, False]]).astype(np.bool)
|
||||
x = np.array([[1.2, 1], [1, 0]]).astype(np.float32)
|
||||
y = np.array([[1, 2], [3, 4.0]]).astype(np.float32)
|
||||
output = select_forward_func(Tensor(cond), Tensor(x), Tensor(y))
|
||||
print(output.asnumpy())
|
||||
expect = [[1.2, 2], [1, 4.0]]
|
||||
error = np.ones(shape=[2, 2]) * 1.0e-6
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_select_float16(mode):
|
||||
"""
|
||||
Feature: Test functional select operator. Support x or y is a float16 Tensor.
|
||||
Description: Operator select's inputs `x` and `y` are Tensor with float16 type.
|
||||
Expectation: Assert result.
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
cond = np.array([[True, False], [True, False]]).astype(np.bool)
|
||||
x = np.array([[1.2, 1], [1, 0]]).astype(np.float16)
|
||||
y = np.array([[1, 2], [3, 4.0]]).astype(np.float16)
|
||||
output = select_forward_func(Tensor(cond), Tensor(x), Tensor(y))
|
||||
print(output.asnumpy())
|
||||
expect = [[1.2, 2], [1, 4.0]]
|
||||
error = np.ones(shape=[2, 2]) * 1.0e-3
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_select_int32(mode):
|
||||
"""
|
||||
Feature: Test functional select operator. Support x or y is a int32 Tensor.
|
||||
Description: Operator select's inputs `x` and `y` are Tensor with int32 type.
|
||||
Expectation: Assert result.
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
cond = np.array([[True, False], [True, False]]).astype(np.bool)
|
||||
x = np.array([[12, 1], [1, 0]]).astype(np.int32)
|
||||
y = np.array([[1, 2], [3, 4]]).astype(np.int32)
|
||||
output = select_forward_func(Tensor(cond), Tensor(x), Tensor(y))
|
||||
print(output.asnumpy())
|
||||
expect = [[12, 2], [1, 4]]
|
||||
error = np.ones(shape=[2, 2]) * 1.0e-6
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_functional_select_scalar(mode):
|
||||
"""
|
||||
Feature: Test functional select operator. Support x or y is a int/float.
|
||||
Description: Operator select's input `x` is a Tensor with int32 type, input `y` is a int.
|
||||
Expectation: Assert result.
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
cond = np.array([[True, False], [True, False]]).astype(np.bool)
|
||||
x = np.array([[12, 1], [1, 0]]).astype(np.int32)
|
||||
y = 2
|
||||
output = select_forward_func(Tensor(cond), Tensor(x), y)
|
||||
print(output.asnumpy())
|
||||
expect = [[12, 2], [1, 2]]
|
||||
error = np.ones(shape=[2, 2]) * 1.0e-6
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_functional_select_broadcast(mode):
|
||||
"""
|
||||
Feature: Test functional select operator support broadcast input.
|
||||
Description: Operator select's support broadcast input.
|
||||
Expectation: Assert result.
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
cond = Tensor(np.random.rand(1, 65, 54, 12, 5, 2), dtype=mstype.bool_)
|
||||
x = Tensor(np.random.rand(5, 5, 65, 1, 12, 5, 2).astype(np.float32))
|
||||
y = Tensor(np.random.rand(65, 54, 1, 5, 2).astype(np.float32))
|
||||
ret = select_forward_func(cond, x, y)
|
||||
assert ret.shape == (5, 5, 65, 54, 12, 5, 2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.parametrize('mode', ['pynative', 'KBK', 'GE'])
|
||||
def test_select_ext_static_shape(mode):
|
||||
"""
|
||||
Feature: Test select with static shape in graph and pynative mode.
|
||||
Description: call ops.select with valid input and index.
|
||||
Expectation: return the correct value.
|
||||
"""
|
||||
x = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
y = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
cond = x > 0
|
||||
|
||||
if mode == 'pynative':
|
||||
ms_out = select_forward_func(cond, x, y)
|
||||
elif mode == 'KBK':
|
||||
ms_out = (jit(select_forward_func, jit_config=JitConfig(jit_level="O0")))(cond, x, y)
|
||||
else:
|
||||
ms_out = (jit(select_forward_func, jit_config=JitConfig(jit_level="O2")))(cond, x, y)
|
||||
|
||||
expect = generate_expect_forward_output(cond.asnumpy(), x.asnumpy(), y.asnumpy())
|
||||
assert np.allclose(ms_out.asnumpy(), expect, rtol=1e-4)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('jit_level', ["O0", "O2"])
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
def test_select_ext_dynamic_shape(jit_level):
|
||||
"""
|
||||
Feature: Test select with dynamic shape in graph mode.
|
||||
Description: call ops.select with valid input and index.
|
||||
Expectation: return the correct value.
|
||||
"""
|
||||
x1 = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
y1 = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
cond1 = x1 > 0
|
||||
|
||||
x2 = generate_random_input((6, 7, 8), np.float32)
|
||||
y2 = generate_random_input((6, 7, 8), np.float32)
|
||||
cond2 = x2 > 0
|
||||
TEST_OP(select_forward_func, [[cond1, x1, y1], [cond2, x2, y2]], grad=True, jit_level=jit_level)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.parametrize('graph_level', ["0", "1"])
|
||||
def test_select_vmap(graph_level):
|
||||
"""
|
||||
Feature: Test select with vmap.
|
||||
Description: call ops.select with valid input and index.
|
||||
Expectation: return the correct value.
|
||||
"""
|
||||
def _foreach_run(condition, x, y, batch):
|
||||
out = []
|
||||
for i in range(condition.shape[batch]):
|
||||
if batch == -1:
|
||||
cond_inner = condition[..., i]
|
||||
x_inner = x[..., i]
|
||||
y_inner = y[..., i]
|
||||
else:
|
||||
cond_inner = condition[i, ...]
|
||||
x_inner = x[i, ...]
|
||||
y_inner = y[i, ...]
|
||||
out.append(select_forward_func(cond_inner, x_inner, y_inner))
|
||||
out = ops.Stack()(out)
|
||||
return out
|
||||
|
||||
os.environ['GRAPH_OP_RUN'] = graph_level
|
||||
x = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
y = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
cond = x > 0
|
||||
|
||||
batch_axis = -1
|
||||
output = select_vmap_func(cond, x, y, batch_axis)
|
||||
expect = _foreach_run(cond, x, y, batch_axis)
|
||||
assert np.allclose(output.asnumpy(), expect.asnumpy(), rtol=1e-4)
|
||||
|
||||
batch_axis = 0
|
||||
output = select_vmap_func(cond, x, y, batch_axis)
|
||||
expect = _foreach_run(cond, x, y, batch_axis)
|
||||
assert np.allclose(output.asnumpy(), expect.asnumpy(), rtol=1e-4)
|
||||
|
||||
del os.environ['GRAPH_OP_RUN']
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.parametrize("mode", ['pynative', 'GE', 'KBK'])
|
||||
def test_select_ext_grad(mode):
|
||||
"""
|
||||
Feature: Test select with backward.
|
||||
Description: call ops.select with valid input and index.
|
||||
Expectation: return the correct value.
|
||||
"""
|
||||
x = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
y = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
cond = x > 0
|
||||
|
||||
if mode == 'pynative':
|
||||
ms_cond, ms_x, ms_y = select_backward_func(cond, x, y)
|
||||
elif mode == 'KBK':
|
||||
ms_cond, ms_x, ms_y = (jit(select_backward_func, jit_config=JitConfig(jit_level="O0")))(cond, x, y)
|
||||
else:
|
||||
ms_cond, ms_x, ms_y = (jit(select_backward_func, jit_config=JitConfig(jit_level="O2")))(cond, x, y)
|
||||
expect_cond, expect_x, expect_y = generate_expect_backward_output(cond.asnumpy())
|
||||
assert np.allclose(ms_cond.asnumpy(), expect_cond, rtol=1e-4)
|
||||
assert np.allclose(ms_x.asnumpy(), expect_x, rtol=1e-4)
|
||||
assert np.allclose(ms_y.asnumpy(), expect_y, rtol=1e-4)
|
|
@ -1,15 +1,57 @@
|
|||
# Copyright 2024 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
"""test where"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
import os
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
from mindspore.ops import where
|
||||
from mindspore import ops, Tensor, jit, JitConfig, context
|
||||
from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP
|
||||
from tests.st.utils import test_utils
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, condition, x, y):
|
||||
return ops.where(condition, x, y)
|
||||
def generate_random_input(shape, dtype):
|
||||
return Tensor(np.random.randn(*shape).astype(dtype))
|
||||
|
||||
|
||||
def generate_expect_forward_output(condition, x, y):
|
||||
return np.where(condition, x, y)
|
||||
|
||||
|
||||
def generate_expect_backward_output(condition):
|
||||
return np.zeros(np.shape(condition), dtype=np.bool_),\
|
||||
np.where(condition, 1, 0), np.where(condition, 0, 1)
|
||||
|
||||
|
||||
@test_utils.run_with_cell
|
||||
def where_forward_func(condition, x, y):
|
||||
return where(condition, x, y)
|
||||
|
||||
|
||||
@test_utils.run_with_cell
|
||||
def where_backward_func(condition, x, y):
|
||||
return ops.grad(where_forward_func, (0, 1, 2))(condition, x, y)
|
||||
|
||||
|
||||
@test_utils.run_with_cell
|
||||
def where_vmap_func(condition, x, y, in_axes=0):
|
||||
return ops.vmap(where_forward_func, in_axes, out_axes=0)(condition, x, y)
|
||||
|
||||
|
||||
@pytest.mark.level2
|
||||
|
@ -27,10 +69,135 @@ def test_ops_where(mode):
|
|||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = Tensor(np.arange(4).reshape((2, 2)), mstype.float32)
|
||||
y = Tensor(np.ones((2, 2)), mstype.float32)
|
||||
condition = x < 3
|
||||
output = net(condition, x, y)
|
||||
output = where_forward_func(condition, x, y)
|
||||
expected = np.array([[0, 1], [2, 1]], dtype=np.float32)
|
||||
assert np.allclose(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.parametrize('mode', ['pynative', 'KBK', 'GE'])
|
||||
def test_where_ext_static_shape(mode):
|
||||
"""
|
||||
Feature: Test where with static shape in graph and pynative mode.
|
||||
Description: call ops.where with valid input and index.
|
||||
Expectation: return the correct value.
|
||||
"""
|
||||
x = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
y = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
cond = x > 0
|
||||
|
||||
if mode == 'pynative':
|
||||
ms_out = where_forward_func(cond, x, y)
|
||||
elif mode == 'KBK':
|
||||
ms_out = (jit(where_forward_func, jit_config=JitConfig(jit_level="O0")))(cond, x, y)
|
||||
else:
|
||||
ms_out = (jit(where_forward_func, jit_config=JitConfig(jit_level="O2")))(cond, x, y)
|
||||
|
||||
expect = generate_expect_forward_output(cond.asnumpy(), x.asnumpy(), y.asnumpy())
|
||||
assert np.allclose(ms_out.asnumpy(), expect, rtol=1e-4)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('jit_level', ["O0", "O2"])
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
def test_where_ext_dynamic_shape(jit_level):
|
||||
"""
|
||||
Feature: Test where with dynamic shape in graph mode.
|
||||
Description: call ops.where with valid input and index.
|
||||
Expectation: return the correct value.
|
||||
"""
|
||||
x1 = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
y1 = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
cond1 = x1 > 0
|
||||
|
||||
x2 = generate_random_input((6, 7, 8), np.float32)
|
||||
y2 = generate_random_input((6, 7, 8), np.float32)
|
||||
cond2 = x2 > 0
|
||||
TEST_OP(where_forward_func, [[cond1, x1, y1], [cond2, x2, y2]], grad=True, jit_level=jit_level)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.parametrize('graph_level', ["0", "1"])
|
||||
def test_where_vmap(graph_level):
|
||||
"""
|
||||
Feature: Test where with vmap.
|
||||
Description: call ops.where with valid input and index.
|
||||
Expectation: return the correct value.
|
||||
"""
|
||||
def _foreach_run(condition, x, y, batch):
|
||||
out = []
|
||||
for i in range(condition.shape[batch]):
|
||||
if batch == -1:
|
||||
cond_inner = condition[..., i]
|
||||
x_inner = x[..., i]
|
||||
y_inner = y[..., i]
|
||||
else:
|
||||
cond_inner = condition[i, ...]
|
||||
x_inner = x[i, ...]
|
||||
y_inner = y[i, ...]
|
||||
out.append(where_forward_func(cond_inner, x_inner, y_inner))
|
||||
out = ops.Stack()(out)
|
||||
return out
|
||||
|
||||
os.environ['GRAPH_OP_RUN'] = graph_level
|
||||
x = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
y = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
cond = x > 0
|
||||
|
||||
batch_axis = -1
|
||||
output = where_vmap_func(cond, x, y, batch_axis)
|
||||
expect = _foreach_run(cond, x, y, batch_axis)
|
||||
assert np.allclose(output.asnumpy(), expect.asnumpy(), rtol=1e-4)
|
||||
|
||||
batch_axis = 0
|
||||
output = where_vmap_func(cond, x, y, batch_axis)
|
||||
expect = _foreach_run(cond, x, y, batch_axis)
|
||||
assert np.allclose(output.asnumpy(), expect.asnumpy(), rtol=1e-4)
|
||||
|
||||
del os.environ['GRAPH_OP_RUN']
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.parametrize("mode", ['pynative', 'GE', 'KBK'])
|
||||
def test_where_ext_grad(mode):
|
||||
"""
|
||||
Feature: Test where with backward.
|
||||
Description: call ops.where with valid input and index.
|
||||
Expectation: return the correct value.
|
||||
"""
|
||||
x = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
y = generate_random_input((2, 3, 4, 5), np.float32)
|
||||
cond = x > 0
|
||||
|
||||
if mode == 'pynative':
|
||||
ms_cond, ms_x, ms_y = where_backward_func(cond, x, y)
|
||||
elif mode == 'KBK':
|
||||
ms_cond, ms_x, ms_y = (jit(where_backward_func, jit_config=JitConfig(jit_level="O0")))(cond, x, y)
|
||||
else:
|
||||
ms_cond, ms_x, ms_y = (jit(where_backward_func, jit_config=JitConfig(jit_level="O2")))(cond, x, y)
|
||||
expect_cond, expect_x, expect_y = generate_expect_backward_output(cond.asnumpy())
|
||||
assert np.allclose(ms_cond.asnumpy(), expect_cond, rtol=1e-4)
|
||||
assert np.allclose(ms_x.asnumpy(), expect_x, rtol=1e-4)
|
||||
assert np.allclose(ms_y.asnumpy(), expect_y, rtol=1e-4)
|
||||
|
|
|
@ -175,7 +175,7 @@ def test_collect_custom_aicpu():
|
|||
profiler.analyse()
|
||||
aicpu_intermediate_file_list = glob.glob(f"{tmpdir}/profiler/aicpu_intermediate_*.csv")
|
||||
assert len(aicpu_intermediate_file_list) == 1
|
||||
s1 = {'Select', 'Xlogy', 'Cast'}
|
||||
s1 = {'Cast', 'BroadcastTo', 'Select', 'Xlogy'}
|
||||
s2 = set()
|
||||
with open(aicpu_intermediate_file_list[0], 'r') as fr:
|
||||
reader = csv.DictReader(fr)
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
#include "common/common_test.h"
|
||||
#include "ops/ops_func_impl/select.h"
|
||||
#include "ops/test_ops.h"
|
||||
#include "ops/test_ops_cmp_utils.h"
|
||||
#include "ops/test_value_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
OP_FUNC_IMPL_TEST_DECLARE(Select, MultiInputOpParams);
|
||||
|
||||
OP_FUNC_IMPL_TEST_CASES(
|
||||
Select,
|
||||
testing::Values(
|
||||
MultiInputOpParams{{{2, 3}, {2, 3}, {2, 3}}, {kBool, kFloat32, kFloat32}, {{2, 3}}, {kFloat32}, {}},
|
||||
MultiInputOpParams{{{2, 3}, {2, 3}, {2, 3}}, {kBool, kFloat32, kInt32}, {{2, 3}}, {kFloat32}, {}},
|
||||
MultiInputOpParams{{{-1, 3}, {2, 3}, {2, 3}}, {kBool, kFloat32, kFloat32}, {{2, 3}}, {kFloat32}, {}},
|
||||
MultiInputOpParams{{{2, -1}, {2, 3}, {2, 3}}, {kBool, kFloat32, kFloat32}, {{2, 3}}, {kFloat32}, {}},
|
||||
MultiInputOpParams{{{2, -1}, {2, -1}, {2, -1}}, {kBool, kFloat32, kFloat32}, {{2, -1}}, {kFloat32}, {}},
|
||||
MultiInputOpParams{{{-1, -1}, {-1, -1}, {2, -1}}, {kBool, kFloat32, kFloat32}, {{2, -1}}, {kFloat32}, {}},
|
||||
MultiInputOpParams{{{-1, -1}, {-1, -1}, {-1, -1}}, {kBool, kFloat32, kFloat32}, {{-1, -1}}, {kFloat32}, {}},
|
||||
MultiInputOpParams{{{4, 5, 8}, {1, 5, 8}, {4, 1, 8}}, {kBool, kFloat32, kFloat32}, {{4, 5, 8}}, {kFloat32}, {}},
|
||||
MultiInputOpParams{{{1, 65, 54, 12, 5, 2}, {5, 5, 65, 1, 12, 5, 2}, {65, 54, 1, 5, 2}},
|
||||
{kBool, kFloat32, kFloat32},
|
||||
{{5, 5, 65, 54, 12, 5, 2}},
|
||||
{kFloat32},
|
||||
{}}));
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue