tuple operator: sequence_lt and sequence_le

This commit is contained in:
hujiahui8 2023-02-24 09:54:04 +08:00
parent b0b03870d6
commit 8d9f2ace6e
14 changed files with 826 additions and 1 deletions

View File

@ -0,0 +1,143 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/sequence/sequence_less_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include <complex>
#include <unordered_map>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "utils/ms_utils.h"
#include "include/common/thread_pool.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr int kInputsNum = 2;
constexpr int kOutputsNum = 1;
constexpr auto kTupleLe = "tuple_le";
constexpr auto kTupleLt = "tuple_lt";
} // namespace
template <typename T, typename S>
bool LessImpl(const T *in_x, const S *in_y, const size_t in_x_size, const size_t in_y_size,
              const bool is_less_equal = true) {
  // Lexicographic comparison of two numeric sequences, mirroring Python's
  // tuple/list ordering: the first unequal element decides; a strict prefix
  // compares "less". When all shared elements are equal and the lengths match,
  // the result is `is_less_equal` (true for <=, false for <).
  // Fix: iterate up to the longer of the TWO sizes. The original computed
  // std::max(in_x_size, in_x_size), so the loop never ran past x's length and
  // the "x is a strict prefix of y" case wrongly fell through to the tail.
  size_t max_size = std::max(in_x_size, in_y_size);
  for (size_t i = 0; i < max_size; ++i) {
    if (i >= in_x_size) {
      // x exhausted first -> x is a strict prefix of y -> x < y.
      return true;
    }
    if (i >= in_y_size) {
      // y exhausted first -> y is a strict prefix of x -> x > y.
      return false;
    }
    // Compare in double so mixed element types (int/float/bool) are handled uniformly.
    if (static_cast<double>(in_x[i]) < static_cast<double>(in_y[i])) {
      return true;
    } else if (static_cast<double>(in_x[i]) > static_cast<double>(in_y[i])) {
      return false;
    }
  }
  return is_less_equal;
}
template <typename T, typename S>
void LtImpl(const T *in_x, const S *in_y, bool *out, const size_t in_x_size, const size_t in_y_size) {
  // Strict lexicographic less-than: equal sequences compare as false.
  constexpr bool kIsLessEqual = false;
  *out = LessImpl(in_x, in_y, in_x_size, in_y_size, kIsLessEqual);
}
template <typename T, typename S>
void LeImpl(const T *in_x, const S *in_y, bool *out, const size_t in_x_size, const size_t in_y_size) {
  // Non-strict lexicographic less-equal: equal sequences compare as true.
  constexpr bool kIsLessEqual = true;
  *out = LessImpl(in_x, in_y, in_x_size, in_y_size, kIsLessEqual);
}
// One-time setup: caches the op name (used at launch to pick lt vs. le),
// validates the 2-input/1-output arity, and binds the dtype-matched
// LaunchKernel functor via MatchKernelFunc.
bool SequenceLessCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                    const std::vector<KernelTensorPtr> &outputs) {
  MS_EXCEPTION_IF_NULL(base_operator);
  kernel_name_ = base_operator->name();
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
  return MatchKernelFunc(base_operator, inputs, outputs);
}
// Per-launch shape refresh: after the base Resize, re-reads the element count
// of each input sequence (first dim of each input shape) into x_size_/y_size_.
int SequenceLessCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                     const std::vector<KernelTensorPtr> &outputs,
                                     const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
  if (ret != 0) {
    return ret;
  }
  CHECK_KERNEL_INPUTS_NUM(input_shapes_.size(), kInputsNum, kernel_name_);
  // Rank-0 inputs are rejected: a sequence input must carry at least a length dim.
  if (input_shapes_[0].empty() || input_shapes_[1].empty()) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the x and y shape can't be 0, but got " << input_shapes_;
  }
  x_size_ = input_shapes_[0][0];
  y_size_ = input_shapes_[1][0];
  return KRET_OK;
}
template <typename T, typename S>
bool SequenceLessCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                            const std::vector<AddressPtr> &workspace,
                                            const std::vector<AddressPtr> &outputs) {
  // Dispatches to the strict (lt) or non-strict (le) comparison based on the
  // kernel name cached in Init.
  // Fix 1: the trailing parameters are element counts. The original declared
  // them as `const bool`, so std::function silently converted x_size_/y_size_
  // to 0/1 before forwarding them to LtImpl/LeImpl, corrupting the lengths.
  using InequalityImplFunc = std::function<void(const T *, const S *, bool *, const size_t, const size_t)>;
  // Fix 2: this kernel is also registered under list_lt/list_le (see the
  // factory registrations below), but the map only held the tuple names, so
  // list inputs always hit the unsupported-op exception.
  std::unordered_map<std::string, InequalityImplFunc> func_map = {{kTupleLt, LtImpl<T, S>},
                                                                  {kTupleLe, LeImpl<T, S>},
                                                                  {"list_lt", LtImpl<T, S>},
                                                                  {"list_le", LeImpl<T, S>}};
  auto iter = func_map.find(kernel_name_);
  if (iter == func_map.end()) {
    MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "' don't support. Only support [Le, Lt]";
  }
  InequalityImplFunc compute_func = iter->second;
  const auto x_addr = GetDeviceAddress<T>(inputs, 0);
  const auto y_addr = GetDeviceAddress<S>(inputs, 1);
  bool *output_addr = GetDeviceAddress<bool>(outputs, 0);
  compute_func(x_addr, y_addr, output_addr, x_size_, y_size_);
  return true;
}
// Helper macro: builds one (KernelAttr, functor) table entry mapping a pair of
// input element dtypes (tuple<x_dtype>, tuple<y_dtype> -> scalar bool) to the
// matching LaunchKernel<x_type, y_type> instantiation.
#define ADD_KERNEL(x_dtype, y_dtype, x_type, y_type) \
  { \
    KernelAttr() \
      .AddInputAttr(kObjectTypeTuple, kNumberType##x_dtype) \
      .AddInputAttr(kObjectTypeTuple, kNumberType##y_dtype) \
      .AddOutputAttr(kObjectTypeNumber, kNumberTypeBool), \
      &SequenceLessCpuKernelMod::LaunchKernel<x_type, y_type> \
  }
// Enumerates the supported dtype pairings for the two input sequences
// (float32/float64/int32/int64/bool in both positions); the output is always
// a scalar bool.
const std::vector<std::pair<KernelAttr, SequenceLessCpuKernelMod::KernelRunFunc>>
  &SequenceLessCpuKernelMod::GetFuncList() const {
  static const std::vector<std::pair<KernelAttr, SequenceLessCpuKernelMod::KernelRunFunc>> func_list = {
    ADD_KERNEL(Float32, Float32, float, float), ADD_KERNEL(Float32, Float64, float, double),
    ADD_KERNEL(Float32, Int32, float, int), ADD_KERNEL(Float32, Int64, float, int64_t),
    ADD_KERNEL(Float32, Bool, float, bool), ADD_KERNEL(Float64, Float32, double, float),
    ADD_KERNEL(Float64, Bool, double, bool), ADD_KERNEL(Float64, Float64, double, double),
    ADD_KERNEL(Float64, Int32, double, int), ADD_KERNEL(Float64, Int64, double, int64_t),
    ADD_KERNEL(Int32, Float32, int, float), ADD_KERNEL(Int32, Float64, int, double),
    ADD_KERNEL(Int32, Int32, int, int), ADD_KERNEL(Int32, Int64, int, int64_t),
    ADD_KERNEL(Int32, Bool, int, bool), ADD_KERNEL(Int64, Float32, int64_t, float),
    ADD_KERNEL(Int64, Bool, int64_t, bool), ADD_KERNEL(Int64, Float64, int64_t, double),
    ADD_KERNEL(Int64, Int32, int64_t, int), ADD_KERNEL(Int64, Int64, int64_t, int64_t),
    ADD_KERNEL(Bool, Int32, bool, int), ADD_KERNEL(Bool, Int64, bool, int64_t),
    ADD_KERNEL(Bool, Bool, bool, bool), ADD_KERNEL(Bool, Float64, bool, double),
    ADD_KERNEL(Bool, Float32, bool, float)};
  return func_list;
}
// One kernel class serves all four sequence comparison ops: tuple/list
// variants of '<' and '<='; the concrete comparison is chosen per launch
// from kernel_name_.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, tuple_le, SequenceLessCpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, tuple_lt, SequenceLessCpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, list_le, SequenceLessCpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, list_lt, SequenceLessCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,63 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_LESS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_LESS_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <utility>
#include <map>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
// CPU kernel implementing lexicographic '<' and '<=' over tuple/list inputs.
// A single class backs tuple_lt/tuple_le/list_lt/list_le; the concrete
// comparison is selected at launch time from the op name.
class SequenceLessCpuKernelMod : public NativeCpuKernelMod,
                                 public MatchKernelHelper<SequenceLessCpuKernelMod, AddressPtr> {
 public:
  SequenceLessCpuKernelMod() = default;
  ~SequenceLessCpuKernelMod() override = default;
  // Validates input/output counts and binds the dtype-matched LaunchKernel functor.
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  // Refreshes the (possibly dynamic) sequence lengths before each launch.
  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs,
             const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
  // Forwards to the functor selected in Init.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) {
    MS_EXCEPTION_IF_NULL(kernel_func_);
    return kernel_func_(this, inputs, workspace, outputs);
  }
  const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
  template <typename T, typename S>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);

 private:
  // Element counts of the two input sequences, refreshed in Resize().
  size_t x_size_ = 0;
  size_t y_size_ = 0;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_LESS_CPU_KERNEL_H_

View File

@ -87,6 +87,10 @@ constexpr auto kScalarBool = "ScalarBool";
constexpr auto kBoolNot = "bool_not";
constexpr auto kScalarBitwiseAnd = "bit_and";
constexpr auto kScalarBitwiseOr = "bit_or";
constexpr auto kTupleLt = "tuple_lt";
constexpr auto kListLt = "list_lt";
constexpr auto kTupleLe = "tuple_le";
constexpr auto kListLe = "list_le";
constexpr auto kExp = "Exp";
constexpr auto kEqual = "Equal";
constexpr auto kNotEqual = "NotEqual";
@ -1753,6 +1757,10 @@ GVAR_DEF(PrimitivePtr, kPrimTupleGreaterThan, std::make_shared<Primitive>("tuple
GVAR_DEF(PrimitivePtr, kPrimListGreaterThan, std::make_shared<Primitive>("list_greater_than"));
GVAR_DEF(PrimitivePtr, kPrimTupleGreaterEqual, std::make_shared<Primitive>("tuple_greater_equal"));
GVAR_DEF(PrimitivePtr, kPrimListGreaterEqual, std::make_shared<Primitive>("list_greater_equal"));
GVAR_DEF(PrimitivePtr, kPrimTupleLessThan, std::make_shared<Primitive>(kTupleLt));
GVAR_DEF(PrimitivePtr, kPrimListLessThan, std::make_shared<Primitive>(kListLt));
GVAR_DEF(PrimitivePtr, kPrimTupleLessEqual, std::make_shared<Primitive>(kTupleLe));
GVAR_DEF(PrimitivePtr, kPrimListLessEqual, std::make_shared<Primitive>(kListLe));
GVAR_DEF(PrimitivePtr, kPrimMakeRange, std::make_shared<Primitive>("make_range"));
GVAR_DEF(PrimitivePtr, kPrimStopGradient, std::make_shared<Primitive>("StopGradient"));
GVAR_DEF(PrimitivePtr, kPrimDictLen, std::make_shared<Primitive>("dict_len"));

View File

@ -0,0 +1,37 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_LIST_LE_H_
#define MINDSPORE_CORE_OPS_LIST_LE_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief List less-equal operation: lexicographic x <= y over two lists.
class MIND_API list_le : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(list_le);
  /// \brief Constructor. Registers the operator under the list_le primitive name.
  list_le() : BaseOperator(prim::kListLe) {}
  /// \brief Init function. The op has no attributes to initialize.
  void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_LIST_LE_H_

View File

@ -0,0 +1,37 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_LIST_LT_H_
#define MINDSPORE_CORE_OPS_LIST_LT_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief List less-than operation: lexicographic x < y over two lists.
class MIND_API list_lt : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(list_lt);
  /// \brief Constructor. Registers the operator under the list_lt primitive name.
  list_lt() : BaseOperator(prim::kListLt) {}
  /// \brief Init function. The op has no attributes to initialize.
  void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_LIST_LT_H_

View File

@ -0,0 +1,138 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include <string>
#include <algorithm>
#include "ops/tuple_le.h"
#include "ops/tuple_lt.h"
#include "ops/list_le.h"
#include "ops/list_lt.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "include/common/utils/utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
// Compile-time lexicographic comparison of two abstract sequences, mirroring
// Python tuple/list ordering. Returns an AbstractScalar(bool), or a
// value-unknown bool scalar when any compared element is not constant.
AbstractBasePtr LessImpl(const AbstractBasePtrList &seqx_elements, const AbstractBasePtrList &seqy_elements,
                         const std::string &prim_name, const bool is_less_equal = true) {
  size_t x_size = seqx_elements.size();
  size_t y_size = seqy_elements.size();
  size_t max_size = std::max(x_size, y_size);
  for (size_t i = 0; i < max_size; ++i) {
    if (i >= x_size) {
      // x is a strict prefix of y -> x < y.
      return std::make_shared<abstract::AbstractScalar>(true);
    }
    if (i >= y_size) {
      // y is a strict prefix of x -> x > y.
      return std::make_shared<abstract::AbstractScalar>(false);
    }
    auto x_element = seqx_elements[i];
    auto y_element = seqy_elements[i];
    if (x_element->BuildType()->type_id() == kObjectTypeTensorType ||
        y_element->BuildType()->type_id() == kObjectTypeTensorType) {
      // Fix: report the actual primitive name instead of the hard-coded,
      // misspelled "tupel_equal" copied from another op's infer function.
      MS_EXCEPTION(TypeError) << "For primitive '" << prim_name << "', the input element must be scalar, but got "
                              << x_element->ToString() << " and " << y_element->ToString();
    }
    if (x_element->BuildValue() == kAnyValue || y_element->BuildValue() == kAnyValue) {
      // An element value is unknown at compile time -> defer to runtime.
      return std::make_shared<abstract::AbstractScalar>(kAnyValue, kBool);
    }
    // Compare as double so mixed scalar types (int/float/bool) order consistently.
    auto x = GetScalarValue<double>(prim_name, x_element->BuildValue());
    auto y = GetScalarValue<double>(prim_name, y_element->BuildValue());
    if (x > y) {
      return std::make_shared<abstract::AbstractScalar>(false);
    } else if (x < y) {
      return std::make_shared<abstract::AbstractScalar>(true);
    }
  }
  // All compared elements equal: '<=' yields true, '<' yields false.
  return std::make_shared<abstract::AbstractScalar>(is_less_equal);
}
// Shared infer routine for tuple/list '<' and '<=': validates that both
// inputs are sequences, handles dynamic-length sequences by deferring to
// runtime, and otherwise delegates element-wise comparison to LessImpl.
AbstractBasePtr SequenceLessInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
                                       const bool is_less_equal = true) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  constexpr size_t input_num = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto x_abs = input_args[0];
  auto y_abs = input_args[1];
  if (!x_abs->isa<abstract::AbstractSequence>() || !y_abs->isa<abstract::AbstractSequence>()) {
    MS_EXCEPTION(TypeError) << "For primitive '" << prim_name << "', the input must be a list or tuple, "
                            << "but got: " << x_abs->ToString() << " and " << y_abs->ToString();
  }
  auto seqx_abs = x_abs->cast<abstract::AbstractSequencePtr>();
  auto seqy_abs = y_abs->cast<abstract::AbstractSequencePtr>();
  // Dynamic-length sequences: element count unknown, so the result is a
  // value-unknown bool resolved at runtime.
  if (seqx_abs->dynamic_len() || seqy_abs->dynamic_len()) {
    return std::make_shared<abstract::AbstractScalar>(kAnyValue, kBool);
  }
  const auto &seqx_elements = seqx_abs->elements();
  const auto &seqy_elements = seqy_abs->elements();
  return LessImpl(seqx_elements, seqy_elements, prim_name, is_less_equal);
}
// Infer implementation for the strict '<' ops (tuple_lt/list_lt):
// delegates to SequenceLessInferInner with is_less_equal = false.
class SequenceLessThanInfer : public abstract::OpInferBase {
 public:
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    return SequenceLessInferInner(primitive, input_args, false)->BuildShape();
  }
  TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
    return SequenceLessInferInner(prim, input_args, false)->BuildType();
  }
  AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) const override {
    return SequenceLessInferInner(primitive, input_args, false);
  }
};
// Infer implementation for the non-strict '<=' ops (tuple_le/list_le):
// delegates to SequenceLessInferInner with the default is_less_equal = true.
class SequenceLessEqualInfer : public abstract::OpInferBase {
 public:
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    return SequenceLessInferInner(primitive, input_args)->BuildShape();
  }
  TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
    return SequenceLessInferInner(prim, input_args)->BuildType();
  }
  AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) const override {
    return SequenceLessInferInner(primitive, input_args);
  }
};
// Bind the four operator classes to their shared infer implementations:
// the *_le ops use SequenceLessEqualInfer, the *_lt ops SequenceLessThanInfer.
MIND_API_OPERATOR_IMPL(tuple_le, BaseOperator);
MIND_API_OPERATOR_IMPL(tuple_lt, BaseOperator);
MIND_API_OPERATOR_IMPL(list_le, BaseOperator);
MIND_API_OPERATOR_IMPL(list_lt, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(tuple_le, prim::kPrimTupleLessEqual, SequenceLessEqualInfer, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(list_le, prim::kPrimListLessEqual, SequenceLessEqualInfer, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(tuple_lt, prim::kPrimTupleLessThan, SequenceLessThanInfer, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(list_lt, prim::kPrimListLessThan, SequenceLessThanInfer, false);
} // namespace ops
} // namespace mindspore

View File

@ -96,6 +96,6 @@ class SequenceMulInfer : public abstract::OpInferBase {
std::set<int64_t> GetValueDependArgIndices() const override { return {1}; }
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceMul, prim::kPrimSequenceMul, SequenceMulInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceMul, prim::kPrimSequenceMul, SequenceMulInfer, false);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_TUPLE_LE_H_
#define MINDSPORE_CORE_OPS_TUPLE_LE_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief Tuple less-equal operation: lexicographic x <= y over two tuples.
class MIND_API tuple_le : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(tuple_le);
  /// \brief Constructor.
  // Fix: register under the less-equal primitive name kTupleLe. The original
  // passed prim::kTupleLt, which made the tuple_le operator an alias of
  // tuple_lt (strict less-than) — compare with list_le, which correctly
  // uses kListLe.
  tuple_le() : BaseOperator(prim::kTupleLe) {}
  /// \brief Init function. The op has no attributes to initialize.
  void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_TUPLE_LE_H_

View File

@ -0,0 +1,37 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_TUPLE_LT_H_
#define MINDSPORE_CORE_OPS_TUPLE_LT_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief Tuple less-than operation: lexicographic x < y over two tuples.
class MIND_API tuple_lt : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(tuple_lt);
  /// \brief Constructor. Registers the operator under the tuple_lt primitive name.
  tuple_lt() : BaseOperator(prim::kTupleLt) {}
  /// \brief Init function. The op has no attributes to initialize.
  void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_TUPLE_LT_H_

View File

@ -100,6 +100,19 @@ def get_bprop_setitem(self):
return bprop
@bprop_getters.register("tuple_le")
@bprop_getters.register("tuple_lt")
@bprop_getters.register("list_le")
@bprop_getters.register("list_lt")
def get_bprop_less(self):
    """Generate bprop for the sequence less-than / less-equal comparison ops."""

    def bprop(x, y, out, dout):
        # A boolean comparison carries no gradient: both inputs get zeros.
        return zeros_like(x), zeros_like(y)

    return bprop
@bprop_getters.register(seq.SequenceMul)
def get_bprop_mul(self):
"""Generate bprop for SequenceMul"""

View File

@ -19,6 +19,7 @@ from __future__ import division
from mindspore.ops.composite import base
from mindspore.ops import functional as F
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations import _sequence_ops as _seq
# less_equal is a metagraph object which will determine if two objects are less_equal according to input type
# using ".register" decorator
@ -70,3 +71,33 @@ def _less_equal_tensor(x, y):
Tensor, return value by operator P.LessEqual.
"""
return F.tensor_le(x, y)
@less_equal.register("Tuple", "Tuple")
def _less_equal_tuple(x, y):
    """
    Determine whether tuple x is lexicographically less than or equal to tuple y.

    Args:
        x (Tuple): The first tuple.
        y (Tuple): The second tuple.

    Returns:
        bool, True if x <= y in Python-style lexicographic order, otherwise False.
    """
    return _seq.tuple_le()(x, y)
@less_equal.register("List", "List")
def _less_equal_list(x, y):
    """
    Determine whether list x is lexicographically less than or equal to list y.

    Args:
        x (List): The first list.
        y (List): The second list.

    Returns:
        bool, True if x <= y in Python-style lexicographic order, otherwise False.
    """
    return _seq.list_le()(x, y)

View File

@ -20,6 +20,7 @@ from __future__ import division
from mindspore.ops.composite import base
from mindspore.ops import functional as F
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations import _sequence_ops as _seq
# less is a metafuncgraph object which will determine if two objects are less according to input type
# using ".register" decorator
@ -71,3 +72,33 @@ def _less_tensor(x, y):
Tensor, return value of x and y by operation P.Less()
"""
return F.tensor_lt(x, y)
@less.register("Tuple", "Tuple")
def _less_tuple(x, y):
    """
    Determine whether tuple x is lexicographically less than tuple y.

    Args:
        x (Tuple): The first tuple.
        y (Tuple): The second tuple.

    Returns:
        bool, True if x < y in Python-style lexicographic order, otherwise False.
    """
    return _seq.tuple_lt()(x, y)
@less.register("List", "List")
def _less_list(x, y):
    """
    Determine whether list x is lexicographically less than list y.

    Args:
        x (List): The first list.
        y (List): The second list.

    Returns:
        bool, True if x < y in Python-style lexicographic order, otherwise False.
    """
    return _seq.list_lt()(x, y)

View File

@ -643,3 +643,115 @@ class list_greater_equal(Primitive):
"""Initialize list_greater_equal"""
self.init_prim_io_names(
inputs=['input_0', 'input_1'], outputs=['output_data'])
class tuple_lt(Primitive):
    r"""
    Support tuple less-than operation ``input_0 < input_1``.

    .. note::
        This is only for internal use.
        This primitive only has a 'CPU' implementation; on other platforms it runs heterogeneously.

    Inputs:
        - **input_0** (Union[Tuple]) - The first sequence.
        - **input_1** (Union[Tuple]) - The second sequence, dtype and shape should be same as 'input_0'.

    Outputs:
        A bool value indicating whether 'input_0' is lexicographically less than 'input_1'.

    Raises:
        TypeError: The 'input_0' or 'input_1' is not tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize tuple_lt"""
        self.init_prim_io_names(
            inputs=['input_0', 'input_1'], outputs=['output_data'])
class list_lt(Primitive):
    r"""
    Support list less-than operation ``input_0 < input_1``.

    .. note::
        This is only for internal use.
        This primitive only has a 'CPU' implementation; on other platforms it runs heterogeneously.

    Inputs:
        - **input_0** (Union[List]) - The first sequence.
        - **input_1** (Union[List]) - The second sequence, dtype and shape should be same as 'input_0'.

    Outputs:
        A bool value indicating whether 'input_0' is lexicographically less than 'input_1'.

    Raises:
        TypeError: The 'input_0' or 'input_1' is not list.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize list_lt"""
        self.init_prim_io_names(
            inputs=['input_0', 'input_1'], outputs=['output_data'])
class tuple_le(Primitive):
    r"""
    Support tuple less-equal operation ``input_0 <= input_1``.

    .. note::
        This is only for internal use.
        This primitive only has a 'CPU' implementation; on other platforms it runs heterogeneously.

    Inputs:
        - **input_0** (Union[Tuple]) - The first sequence.
        - **input_1** (Union[Tuple]) - The second sequence, dtype and shape should be same as 'input_0'.

    Outputs:
        A bool value indicating whether 'input_0' is lexicographically less than or equal to 'input_1'.

    Raises:
        TypeError: The 'input_0' or 'input_1' is not tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize tuple_le"""
        self.init_prim_io_names(
            inputs=['input_0', 'input_1'], outputs=['output_data'])
class list_le(Primitive):
    r"""
    Support list less-equal operation ``input_0 <= input_1``.

    .. note::
        This is only for internal use.
        This primitive only has a 'CPU' implementation; on other platforms it runs heterogeneously.

    Inputs:
        - **input_0** (Union[List]) - The first sequence.
        - **input_1** (Union[List]) - The second sequence, dtype and shape should be same as 'input_0'.

    Outputs:
        A bool value indicating whether 'input_0' is lexicographically less than or equal to 'input_1'.

    Raises:
        TypeError: The 'input_0' or 'input_1' is not list.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize list_le"""
        self.init_prim_io_names(
            inputs=['input_0', 'input_1'], outputs=['output_data'])

View File

@ -0,0 +1,138 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore.nn as nn
from mindspore import context
from mindspore.ops.operations import _sequence_ops as _seq
from mindspore.common import mutable
from mindspore.ops.composite import GradOperation
from sequence_help import context_prepare
context.set_context(mode=context.GRAPH_MODE)
context_prepare()
class NetTupleLt(nn.Cell):
    """Network wrapper applying the tuple_lt primitive to two sequences."""

    def __init__(self):
        super().__init__()
        self.seq_lt = _seq.tuple_lt()

    def construct(self, x, y):
        return self.seq_lt(x, y)
class NetTupleLe(nn.Cell):
    """Network wrapper applying the tuple_le primitive to two sequences."""

    def __init__(self):
        super().__init__()
        self.seq_le = _seq.tuple_le()

    def construct(self, x, y):
        return self.seq_le(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_seq_dyn_le():
    """
    Feature: test sequence less_equal op
    Description: less_equal operation on two dynamic-length tuples; the first
        unequal element (4 > 2 at index 3) decides the result.
    Expectation: the behavior is matched to python style
    """
    x = mutable((1, 2, 3, 4, 5, 6), True)
    y = mutable((1, 2, 3, 2, 6), True)
    expect = False
    net = NetTupleLe()
    res = net(x, y)
    assert res == expect
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_seq_dyn_lt():
    """
    Feature: test sequence less_than op
    Description: less_than operation with a dynamic-length tuple and a constant
        tuple of equal elements; equal sequences are not strictly less.
    Expectation: the behavior is matched to python style
    """
    x = mutable((1, 2, 3, 4, 5, 6), True)
    y = (1, 2, 3, 4, 5, 6)
    expect = False
    net = NetTupleLt()
    res = net(x, y)
    assert res == expect
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_seq_le():
    """
    Feature: test sequence less_equal op
    Description: less_equal operation on constant tuples with mixed bool/int
        elements (True compares equal to 1); equal sequences satisfy <=.
    Expectation: the behavior is matched to python style
    """
    x = (1, 2, 3, 4, 5)
    y = (True, 2, 3, 4, 5)
    expect = True
    net = NetTupleLe()
    res = net(x, y)
    assert res == expect
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_seq_lt():
    """
    Feature: test sequence less_than op
    Description: less_than operation where y is a strict prefix of x, so x is
        the greater sequence.
    Expectation: the behavior is matched to python style
    """
    x = (1, 2, 3, 4, 5, 6)
    y = (True, 2, 3, 4, 5)
    expect = False
    net = NetTupleLt()
    res = net(x, y)
    assert res == expect
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_seq_getitem_grad():
    """
    Feature: test sequence less_equal grad op
    Description: backward pass of less_equal on dynamic-length tuples; the
        bprop returns zeros for both inputs.
    Expectation: the behavior is matched to python style
    """
    net_ms = NetTupleLe()
    x = mutable((2, 3, 4, 5, 6), True)
    y = mutable((1, 2, 3, 4, 5, 6), True)
    dout = True
    grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
    print("grad out1 = ", grad_func(x, y, dout))