Support arg min with value
This commit is contained in:
parent
e6767ccd0f
commit
bbd9a04b0f
|
@ -265,6 +265,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"erf", std::string("erf")}, // P.Erf()
|
||||
{"erfc", std::string("erfc")}, // P.Erfc()
|
||||
{"standard_laplace", std::string("standard_laplace")}, // P.StandardLaplace()
|
||||
{"arg_min_with_value", std::string("arg_min_with_value")}, // P.ArgMinWithValue
|
||||
}},
|
||||
{kObjectTypeRowTensorType,
|
||||
{
|
||||
|
|
|
@ -20,50 +20,95 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "mindspore/core/ops/argmax_with_value.h"
|
||||
#include "mindspore/core/ops/argmin_with_value.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/general_reduction_impl.cuh"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t kInputNum = 1;
|
||||
constexpr size_t kOutputNum = 2;
|
||||
|
||||
template <typename T, typename S>
|
||||
class ArgMaxAndMinWithValueGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
class ArgMaxAndMinWithValueGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
ArgMaxAndMinWithValueGpuKernelMod() { ResetResource(); }
|
||||
~ArgMaxAndMinWithValueGpuKernelMod() override = default;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override {
|
||||
if (!InitSize(base_operator, inputs, outputs)) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16)};
|
||||
return support_list;
|
||||
}
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(stream_ptr);
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output = GetDeviceAddress<T>(outputs, 1);
|
||||
S *index = GetDeviceAddress<S>(outputs, 0);
|
||||
CalGeneralReduction(small_, input, bound_, outerSize_, innerSize_, index, output,
|
||||
CalGeneralReduction(small_, input, bound_, outer_size_, inner_size_, index, output,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
std::string kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
small_ = (kernel_name == "ArgMinWithValue") ? true : false;
|
||||
auto shape = Convert2SizeTClipNeg(common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0));
|
||||
auto output_shape = Convert2SizeTClipNeg(common::AnfAlgo::GetOutputInferShape(kernel_node, 1));
|
||||
is_null_input_ =
|
||||
CHECK_SHAPE_NULL(shape, kernel_name, "input") || CHECK_SHAPE_NULL(output_shape, kernel_name, "output");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
if (kernel_name_ != "ArgMaxWithValue" && kernel_name_ != "ArgMinWithValue") {
|
||||
MS_EXCEPTION(ArgumentError) << "The kernel must be either ArgMaxWithValue or ArgMinWithValue.";
|
||||
}
|
||||
|
||||
// Check inputs and outputs size.
|
||||
if (inputs.size() != kInputNum) {
|
||||
MS_EXCEPTION(ArgumentError)
|
||||
<< "For kernel mod[ArgMaxAndMinWithValueGpuKernelMod], the size of input should be 1, but got "
|
||||
<< inputs.size();
|
||||
}
|
||||
if (outputs.size() != kOutputNum) {
|
||||
MS_EXCEPTION(ArgumentError)
|
||||
<< "For kernel mod[ArgMaxAndMinWithValueGpuKernelMod], the size of output should be 2, but got "
|
||||
<< outputs.size();
|
||||
}
|
||||
|
||||
if (kernel_name_ == "ArgMinWithValue") {
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::ArgMinWithValue>(base_operator);
|
||||
MS_EXCEPTION_IF_NULL(kernel_ptr);
|
||||
axis_ = kernel_ptr->axis();
|
||||
} else {
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::ArgMaxWithValue>(base_operator);
|
||||
MS_EXCEPTION_IF_NULL(kernel_ptr);
|
||||
axis_ = kernel_ptr->axis();
|
||||
}
|
||||
small_ = (kernel_name_ == "ArgMinWithValue") ? true : false;
|
||||
return InitSize(base_operator, inputs, outputs);
|
||||
}
|
||||
|
||||
bool InitSize(const BaseOperatorPtr &, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(inputs[0]);
|
||||
auto shape = Convert2SizeTClipNeg(inputs[0]->GetShapeVector());
|
||||
MS_EXCEPTION_IF_NULL(outputs[0]);
|
||||
auto output_shape = Convert2SizeTClipNeg(outputs[0]->GetShapeVector());
|
||||
int64_t dims = SizeToLong(shape.size());
|
||||
int64_t axis = GetAttr<int64_t>(kernel_node, "axis");
|
||||
if (axis < -dims || axis >= dims) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the 'axis' must be in the range [-" << dims << "," << dims
|
||||
<< "), but got " << axis;
|
||||
if (axis_ < -dims || axis_ >= dims) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << ","
|
||||
<< dims << "), but got " << axis_;
|
||||
}
|
||||
if (axis < 0) {
|
||||
axis += dims;
|
||||
if (axis_ < 0) {
|
||||
axis_ += dims;
|
||||
}
|
||||
input_size_ = sizeof(T);
|
||||
for (auto x : shape) {
|
||||
|
@ -73,50 +118,50 @@ class ArgMaxAndMinWithValueGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
for (auto x : output_shape) {
|
||||
output_size_ *= x;
|
||||
}
|
||||
bound_ = static_cast<S>(shape[axis]);
|
||||
if (shape[axis] != static_cast<size_t>(bound_)) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of shape[axis] must be "
|
||||
<< static_cast<size_t>(bound_) << ", but got " << shape[axis];
|
||||
bound_ = static_cast<S>(shape[axis_]);
|
||||
if (static_cast<S>(shape[axis_]) != bound_) {
|
||||
MS_EXCEPTION(ArgumentError) << "For '" << kernel_name_ << "', the value of shape[axis] must be "
|
||||
<< static_cast<size_t>(bound_) << ", but got " << shape[axis_];
|
||||
}
|
||||
outerSize_ = 1;
|
||||
for (int64_t i = axis - 1; i >= 0; i--) {
|
||||
outerSize_ *= shape[i];
|
||||
outer_size_ = 1;
|
||||
for (int64_t i = axis_ - 1; i >= 0; i--) {
|
||||
outer_size_ *= shape[i];
|
||||
}
|
||||
innerSize_ = 1;
|
||||
for (int64_t i = axis + 1; i < dims; i++) {
|
||||
innerSize_ *= shape[i];
|
||||
inner_size_ = 1;
|
||||
for (int64_t i = axis_ + 1; i < dims; i++) {
|
||||
inner_size_ *= shape[i];
|
||||
}
|
||||
InitSizeLists();
|
||||
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
output_size_list_.push_back(output_size_ / sizeof(S) * sizeof(T));
|
||||
return true;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
void ResetResource() noexcept {
|
||||
kernel_name_ = "";
|
||||
axis_ = 0;
|
||||
input_size_ = 0;
|
||||
output_size_ = 0;
|
||||
bound_ = 0;
|
||||
outerSize_ = 0;
|
||||
innerSize_ = 0;
|
||||
is_null_input_ = false;
|
||||
outer_size_ = 0;
|
||||
inner_size_ = 0;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
output_size_list_.push_back(output_size_ / sizeof(S) * sizeof(T));
|
||||
}
|
||||
|
||||
private:
|
||||
std::string kernel_name_;
|
||||
bool small_ = false;
|
||||
int64_t axis_;
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
S bound_;
|
||||
size_t outerSize_;
|
||||
size_t innerSize_;
|
||||
bool is_null_input_;
|
||||
size_t outer_size_;
|
||||
size_t inner_size_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,6 +25,18 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
int64_t ArgMaxWithValue::axis() const {
|
||||
auto value_ptr = GetAttr("axis");
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
bool ArgMaxWithValue::keep_dims() const {
|
||||
auto value_ptr = GetAttr("keep_dims");
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
namespace {
|
||||
abstract::TupleShapePtr ArgMaxWithValueInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
|
|
@ -29,6 +29,9 @@ class MIND_API ArgMaxWithValue : public BaseOperator {
|
|||
public:
|
||||
MIND_API_BASE_MEMBER(ArgMaxWithValue);
|
||||
ArgMaxWithValue() : BaseOperator(kNameArgMaxWithValue) { InitIOName({"input_x"}, {"index", "output_x"}); }
|
||||
|
||||
int64_t axis() const;
|
||||
bool keep_dims() const;
|
||||
};
|
||||
abstract::AbstractBasePtr ArgMaxWithValueInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/argmin_with_value.h"
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
int64_t ArgMinWithValue::axis() const {
|
||||
auto value_ptr = GetAttr("axis");
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
bool ArgMinWithValue::keep_dims() const {
|
||||
auto value_ptr = GetAttr("keep_dims");
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
namespace {
|
||||
abstract::TupleShapePtr ArgMinWithValueInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape_ptr = input_args[0]->BuildShape();
|
||||
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr);
|
||||
auto x_shape = x_shape_map[kShape];
|
||||
auto axis_value = primitive->GetAttr("axis");
|
||||
MS_EXCEPTION_IF_NULL(axis_value);
|
||||
auto axis = GetValue<int64_t>(axis_value);
|
||||
auto keep_dims_value = primitive->GetAttr("keep_dims");
|
||||
MS_EXCEPTION_IF_NULL(keep_dims_value);
|
||||
auto keep_dims = GetValue<bool>(keep_dims_value);
|
||||
auto x_rank = SizeToLong(x_shape.size());
|
||||
if (x_rank == 0) {
|
||||
if (axis != -1 && axis != 0) {
|
||||
MS_EXCEPTION(ValueError) << "For ArgMinWithValue with 0d input tensor, axis must be one of 0 or -1, but got"
|
||||
<< axis << ".";
|
||||
}
|
||||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{x_shape_ptr, x_shape_ptr});
|
||||
}
|
||||
if (axis < 0) {
|
||||
axis += x_rank;
|
||||
}
|
||||
if (axis < 0 || axis >= x_rank) {
|
||||
MS_EXCEPTION(ValueError) << "For ArgMinWithValue, axis must be in range [-x_rank, x_rank), but got" << axis << ".";
|
||||
}
|
||||
(void)primitive->AddAttr("dimension", MakeValue(axis));
|
||||
// Calculate all the shapes.
|
||||
auto cal_shape = [axis, keep_dims](ShapeVector &shape, const ShapeVector &x_shape) -> void {
|
||||
(void)shape.insert(shape.end(), x_shape.begin(), x_shape.end());
|
||||
if (keep_dims) {
|
||||
shape[axis] = 1;
|
||||
} else {
|
||||
(void)shape.erase(shape.begin() + axis);
|
||||
}
|
||||
};
|
||||
ShapeVector output_shape;
|
||||
cal_shape(output_shape, x_shape);
|
||||
auto index_and_value_shape = std::make_shared<abstract::Shape>(output_shape);
|
||||
return std::make_shared<abstract::TupleShape>(
|
||||
std::vector<abstract::BaseShapePtr>{index_and_value_shape, index_and_value_shape});
|
||||
}
|
||||
|
||||
TuplePtr ArgMinWithValueInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
TypePtr input_x_type = input_args[0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_x_type, valid_types, prim->name());
|
||||
auto index_type = std::make_shared<TensorType>(kInt32);
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{index_type, input_x_type});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr ArgMinWithValueInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto shapes = ArgMinWithValueInferShape(primitive, input_args);
|
||||
auto types = ArgMinWithValueInferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(shapes, types);
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(ArgMinWithValue, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ArgMinWithValue, prim::kPrimArgMinWithValue, ArgMinWithValueInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_ARGMIN_WITH_VALUE_H_
|
||||
#define MINDSPORE_CORE_OPS_ARGMIN_WITH_VALUE_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameArgMinWithValue = "ArgMinWithValue";
|
||||
class MIND_API ArgMinWithValue : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ArgMinWithValue);
|
||||
ArgMinWithValue() : BaseOperator(kNameArgMinWithValue) { InitIOName({"input_x"}, {"index", "output_x"}); }
|
||||
|
||||
int64_t axis() const;
|
||||
bool keep_dims() const;
|
||||
};
|
||||
abstract::AbstractBasePtr ArgMinWithValueInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_ARGMIN_WITH_VALUE_H_
|
|
@ -414,6 +414,7 @@ GVAR_DEF(PrimitivePtr, kPrimPad, std::make_shared<Primitive>("Pad"));
|
|||
GVAR_DEF(PrimitivePtr, kPrimPadding, std::make_shared<Primitive>(kPadding));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMirrorPad, std::make_shared<Primitive>(kMirrorPad));
|
||||
GVAR_DEF(PrimitivePtr, kPrimArgMaxWithValue, std::make_shared<Primitive>("ArgMaxWithValue"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimArgMinWithValue, std::make_shared<Primitive>("ArgMinWithValue"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimUnique, std::make_shared<Primitive>("Unique"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimUniqueWithPad, std::make_shared<Primitive>("UniqueWithPad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimUniqueGrad, std::make_shared<Primitive>("UniqueGrad"));
|
||||
|
@ -1139,7 +1140,6 @@ GVAR_DEF(PrimitivePtr, kPrimSpaceToDepth, std::make_shared<Primitive>("SpaceToDe
|
|||
GVAR_DEF(PrimitivePtr, kPrimPadFusion, std::make_shared<Primitive>("PadFusion"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimPowFusion, std::make_shared<Primitive>("PowFusion"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimResize, std::make_shared<Primitive>("Resize"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimArgMinWithValue, std::make_shared<Primitive>("ArgMinWithValue"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimIf, std::make_shared<Primitive>("If"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimAvgPoolFusion, std::make_shared<Primitive>("AvgPoolFusion"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMaxPoolFusion, std::make_shared<Primitive>("MaxPoolFusion"));
|
||||
|
|
|
@ -495,6 +495,13 @@ def argmax(x, axis=None):
|
|||
return P.Argmax(axis)(x)
|
||||
|
||||
|
||||
def arg_min_with_value(input_x, axis=0, keep_dims=False):
|
||||
"""
|
||||
Returns the minimum value with corresponding index.
|
||||
"""
|
||||
return F.arg_min_with_value(input_x, axis, keep_dims)
|
||||
|
||||
|
||||
def argmin(x, axis=None):
|
||||
"""
|
||||
Returns the indices of the minimum values along an axis.
|
||||
|
|
|
@ -1999,6 +1999,54 @@ class Tensor(Tensor_):
|
|||
# P.Argmin is currently not supported
|
||||
return tensor_operator_registry.get('argmax')(axis)(tensor_operator_registry.get('__neg__')(a))
|
||||
|
||||
def arg_min_with_value(self, axis=0, keep_dims=False):
|
||||
"""
|
||||
Returns the minimum value with corresponding index.
|
||||
|
||||
Note:
|
||||
In auto_parallel and semi_auto_parallel mode, the first output index can not be used.
|
||||
|
||||
.. warning::
|
||||
- If there are multiple minimum values, the index of the first minimum value is used.
|
||||
- The value range of "axis" is [-dims, dims - 1]. "dims" is the dimension length of "input_x".
|
||||
|
||||
Args:
|
||||
axis (int): The dimension to reduce. Default: 0.
|
||||
keep_dims (bool): Whether to reduce dimension, if true the output will keep the same dimension as the input,
|
||||
the output will reduce dimension if false. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input tensor, can be any dimension. Set the shape of input tensor as
|
||||
:math:`(x_1, x_2, ..., x_N)` .
|
||||
|
||||
Outputs:
|
||||
tuple (Tensor), tuple of 2 tensors, containing the corresponding index and the minimum value of the input
|
||||
tensor.
|
||||
|
||||
- index (Tensor) - The index for the minimum value of the input tensor. If `keep_dims` is true, the shape of
|
||||
output tensors is :math:`(x_1, x_2, ..., x_{axis-1}, 1, x_{axis+1}, ..., x_N)`. Otherwise, the shape is
|
||||
:math:`(x_1, x_2, ..., x_{axis-1}, x_{axis+1}, ..., x_N)` .
|
||||
- output_x (Tensor) - The minimum value of input tensor, with the same shape as index.
|
||||
|
||||
Raises:
|
||||
TypeError: If `keep_dims` is not a bool.
|
||||
TypeError: If `axis` is not an int.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([0.0, 0.4, 0.6, 0.7, 0.1]), mindspore.float32)
|
||||
>>> output = ops.arg_min_with_value(input_x)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[], dtype=Int32, value= 0), Tensor(shape=[], dtype=Float32, value= 0))
|
||||
>>> output = ops.arg_min_with_value(input_x, keep_dims=True)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[1], dtype=Int32, value= [0]), Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00]))
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('arg_min_with_value')(self, axis, keep_dims)
|
||||
|
||||
def cumsum(self, axis=None, dtype=None):
|
||||
"""
|
||||
Return the cumulative sum of the elements along a given axis.
|
||||
|
|
|
@ -102,6 +102,7 @@ from .array_func import (
|
|||
split,
|
||||
index_fill,
|
||||
max,
|
||||
min,
|
||||
)
|
||||
from .parameter_func import (
|
||||
assign,
|
||||
|
|
|
@ -3631,6 +3631,69 @@ def max(input_x, axis=0, keep_dims=False):
|
|||
return argmax_with_value_op(input_x)
|
||||
|
||||
|
||||
def min(input_x, axis=0, keep_dims=False):
|
||||
"""
|
||||
Calculates the minimum value with corresponding index, and returns indices and values.
|
||||
|
||||
Calculates the minimum value along with the given axis for the input tensor. It returns the minimum values and
|
||||
indices.
|
||||
|
||||
Note:
|
||||
In auto_parallel and semi_auto_parallel mode, the first output index can not be used.
|
||||
|
||||
.. warning::
|
||||
- If there are multiple minimum values, the index of the first minimum value is used.
|
||||
- The value range of "axis" is [-dims, dims - 1]. "dims" is the dimension length of "input_x".
|
||||
|
||||
Also see: class: `mindspore.ops.ArgMinWithValue`.
|
||||
|
||||
Args:
|
||||
input_x (Tensor) - The input tensor, can be any dimension. Set the shape of input tensor as
|
||||
:math:`(x_1, x_2, ..., x_N)` . And the data type only support mindspore.float16 or float32.
|
||||
axis (int): The dimension to reduce. Default: 0.
|
||||
keep_dims (bool): Whether to reduce dimension, if true the output will keep the same dimension as the input,
|
||||
the output will reduce dimension if false. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input tensor, can be any dimension. Set the shape of input tensor as
|
||||
:math:`(x_1, x_2, ..., x_N)` .
|
||||
|
||||
Outputs:
|
||||
tuple (Tensor), tuple of 2 tensors, containing the corresponding index and the minimum value of the input
|
||||
tensor.
|
||||
|
||||
- index (Tensor) - The index for the minimum value of the input tensor. If `keep_dims` is true, the shape of
|
||||
output tensors is :math:`(x_1, x_2, ..., x_{axis-1}, 1, x_{axis+1}, ..., x_N)`. Otherwise, the shape is
|
||||
:math:`(x_1, x_2, ..., x_{axis-1}, x_{axis+1}, ..., x_N)` .
|
||||
- output_x (Tensor) - The minimum value of input tensor, with the same shape as index.
|
||||
|
||||
Raises:
|
||||
TypeError: If `keep_dims` is not a bool.
|
||||
TypeError: If `axis` is not an int.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([0.0, 0.4, 0.6, 0.7, 0.1]), mindspore.float32)
|
||||
>>> index, output = ops.max(input_x)
|
||||
>>> print(index, output)
|
||||
3 0.7
|
||||
>>> index, output = ops.max(input_x, keep_dims=True)
|
||||
>>> print(index, output)
|
||||
[3] [0.7]
|
||||
|
||||
>>> input_x = Tensor(np.array([0.0, 0.4, 0.6, 0.7, 0.1]), mindspore.float32)
|
||||
>>> output = ops.min(input_x)
|
||||
>>> print(output)
|
||||
0 0.0
|
||||
>>> output = ops.min(input_x, keep_dims=True)
|
||||
>>> print(output)
|
||||
[0] [0.0]
|
||||
"""
|
||||
argmin_with_value_op = P.ArgMinWithValue(axis, keep_dims)
|
||||
return argmin_with_value_op(input_x)
|
||||
|
||||
__all__ = [
|
||||
'unique',
|
||||
'unique_with_pad',
|
||||
|
@ -3708,5 +3771,6 @@ __all__ = [
|
|||
'split',
|
||||
"index_fill",
|
||||
'max',
|
||||
'min',
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -1084,5 +1084,6 @@ tensor_operator_registry.register('norm', norm)
|
|||
tensor_operator_registry.register('renorm', renorm)
|
||||
tensor_operator_registry.register('adaptive_max_pool2d', AdaptiveMaxPool2D)
|
||||
tensor_operator_registry.register('coalesce', coalesce)
|
||||
tensor_operator_registry.register('arg_min_with_value', min)
|
||||
__all__ = [name for name in dir() if name[0] != "_"]
|
||||
__all__.remove('Primitive')
|
||||
|
|
|
@ -2220,6 +2220,8 @@ class ArgMinWithValue(PrimitiveWithInfer):
|
|||
- If there are multiple minimum values, the index of the first minimum value is used.
|
||||
- The value range of "axis" is [-dims, dims - 1]. "dims" is the dimension length of "input_x".
|
||||
|
||||
Also see: func: `mindspore.ops.arg_min_with_value`.
|
||||
|
||||
Args:
|
||||
axis (int): The dimension to reduce. Default: 0.
|
||||
keep_dims (bool): Whether to reduce dimension, if true the output will keep the same dimension as the input,
|
||||
|
|
|
@ -16,10 +16,12 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
class NetArgminWithValue(nn.Cell):
|
||||
|
@ -101,6 +103,15 @@ def argminwithvalue_3d(data_type, shape_x):
|
|||
assert (output[1].asnumpy() == expect2).all()
|
||||
|
||||
|
||||
def argminwithvalue_tensor(context_mode):
|
||||
context.set_context(mode=context_mode, device_target="GPU")
|
||||
x = Tensor(np.array([[1., 20., 5.],
|
||||
[67., 8., 9.],
|
||||
[130., 24., 15.],
|
||||
[0.3, -0.4, -15.]]).astype(np.float32))
|
||||
return x.arg_min_with_value(axis=-1)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -144,3 +155,72 @@ def test_argminwithvalue_3d_big_float32():
|
|||
argminwithvalue_3d(np.float32, shape_x)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
argminwithvalue_3d(np.float32, shape_x)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_argminwithvalue_functional():
|
||||
"""
|
||||
Feature: support min op functional.
|
||||
Description: test the op using functional.
|
||||
Expectation: expect correct result.
|
||||
"""
|
||||
context.set_context(device_target="GPU")
|
||||
x = Tensor(np.array([[1., 20., 5.],
|
||||
[67., 8., 9.],
|
||||
[130., 24., 15.],
|
||||
[0.3, -0.4, -15.]]).astype(np.float32))
|
||||
expect_index = np.array([3, 3, 3]).astype(np.int32)
|
||||
expect_output = np.array([0.3, -0.4, -15.]).astype(np.float32)
|
||||
index, output = F.min(x, axis=0)
|
||||
|
||||
assert (index.asnumpy() == expect_index).all()
|
||||
assert (output.asnumpy() == expect_output).all()
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_argminwithvalue_tensor():
|
||||
"""
|
||||
Feature: support tensor's arg_min_with_value op.
|
||||
Description: test the op using tensor.
|
||||
Expectation: expect correct result.
|
||||
"""
|
||||
expect_index = np.array([0, 1, 2, 2]).astype(np.int32)
|
||||
expect_output = np.array([1., 8., 15., -15.]).astype(np.float32)
|
||||
|
||||
index, output = argminwithvalue_tensor(context.GRAPH_MODE)
|
||||
assert (index.asnumpy() == expect_index).all()
|
||||
assert (output.asnumpy() == expect_output).all()
|
||||
|
||||
index, output = argminwithvalue_tensor(context.PYNATIVE_MODE)
|
||||
assert (index.asnumpy() == expect_index).all()
|
||||
assert (output.asnumpy() == expect_output).all()
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_argminwithvalue_dynamic_shape():
|
||||
"""
|
||||
Feature: support arg_min_with_value op with dynamic shape.
|
||||
Description: test the op with dynamic shape
|
||||
Expectation: expect correct result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x = Tensor(np.array([[1., 20., 5.],
|
||||
[67., 8., 9.],
|
||||
[130., 24., 15.],
|
||||
[0.3, -0.4, -15.]]).astype(np.float32))
|
||||
expect_index = np.array([0, 1, 2, 2]).astype(np.int32)
|
||||
expect_output = np.array([1., 8., 15., -15.]).astype(np.float32)
|
||||
|
||||
argmin_net = NetArgminWithValue()
|
||||
input_dynamic = Tensor(shape=[4, None], dtype=mindspore.float32)
|
||||
argmin_net.set_inputs(input_dynamic)
|
||||
output = argmin_net(x)
|
||||
|
||||
assert (output[1][0].asnumpy() == expect_index).all()
|
||||
assert (output[1][1].asnumpy() == expect_output).all()
|
||||
|
|
Loading…
Reference in New Issue