tuple operator: sequence_lt and sequence_le
This commit is contained in:
parent
b0b03870d6
commit
8d9f2ace6e
|
@ -0,0 +1,143 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/sequence/sequence_less_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <complex>
|
||||
#include <unordered_map>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "include/common/thread_pool.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr int kInputsNum = 2;
|
||||
constexpr int kOutputsNum = 1;
|
||||
constexpr auto kTupleLe = "tuple_le";
|
||||
constexpr auto kTupleLt = "tuple_lt";
|
||||
} // namespace
|
||||
|
||||
template <typename T, typename S>
|
||||
bool LessImpl(const T *in_x, const S *in_y, const size_t in_x_size, const size_t in_y_size,
|
||||
const bool is_less_equal = true) {
|
||||
size_t max_size = std::max(in_x_size, in_x_size);
|
||||
for (size_t i = 0; i < max_size; ++i) {
|
||||
if (i >= in_x_size) {
|
||||
return true;
|
||||
}
|
||||
if (i >= in_y_size) {
|
||||
return false;
|
||||
}
|
||||
if (static_cast<double>(in_x[i]) < static_cast<double>(in_y[i])) {
|
||||
return true;
|
||||
} else if (static_cast<double>(in_x[i]) > static_cast<double>(in_y[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return is_less_equal;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void LtImpl(const T *in_x, const S *in_y, bool *out, const size_t in_x_size, const size_t in_y_size) {
|
||||
*out = LessImpl(in_x, in_y, in_x_size, in_y_size, false);
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void LeImpl(const T *in_x, const S *in_y, bool *out, const size_t in_x_size, const size_t in_y_size) {
|
||||
*out = LessImpl(in_x, in_y, in_x_size, in_y_size, true);
|
||||
}
|
||||
|
||||
bool SequenceLessCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
return MatchKernelFunc(base_operator, inputs, outputs);
|
||||
}
|
||||
|
||||
int SequenceLessCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
CHECK_KERNEL_INPUTS_NUM(input_shapes_.size(), kInputsNum, kernel_name_);
|
||||
if (input_shapes_[0].empty() || input_shapes_[1].empty()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the x and y shape can't be 0, but got " << input_shapes_;
|
||||
}
|
||||
x_size_ = input_shapes_[0][0];
|
||||
y_size_ = input_shapes_[1][0];
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
bool SequenceLessCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
using InequalityImplFunc = std::function<void(const T *, const S *, bool *, const bool, const bool)>;
|
||||
std::unordered_map<std::string, InequalityImplFunc> func_map = {{kTupleLt, LtImpl<T, S>}, {kTupleLe, LeImpl<T, S>}};
|
||||
auto iter = func_map.find(kernel_name_);
|
||||
if (iter == func_map.end()) {
|
||||
MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "' don't support. Only support [Le, Lt]";
|
||||
}
|
||||
InequalityImplFunc compute_func = iter->second;
|
||||
|
||||
const auto x_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
const auto y_addr = GetDeviceAddress<S>(inputs, 1);
|
||||
bool *output_addr = GetDeviceAddress<bool>(outputs, 0);
|
||||
|
||||
compute_func(x_addr, y_addr, output_addr, x_size_, y_size_);
|
||||
return true;
|
||||
}
|
||||
|
||||
#define ADD_KERNEL(x_dtype, y_dtype, x_type, y_type) \
|
||||
{ \
|
||||
KernelAttr() \
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberType##x_dtype) \
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberType##y_dtype) \
|
||||
.AddOutputAttr(kObjectTypeNumber, kNumberTypeBool), \
|
||||
&SequenceLessCpuKernelMod::LaunchKernel<x_type, y_type> \
|
||||
}
|
||||
|
||||
const std::vector<std::pair<KernelAttr, SequenceLessCpuKernelMod::KernelRunFunc>>
|
||||
&SequenceLessCpuKernelMod::GetFuncList() const {
|
||||
static const std::vector<std::pair<KernelAttr, SequenceLessCpuKernelMod::KernelRunFunc>> func_list = {
|
||||
ADD_KERNEL(Float32, Float32, float, float), ADD_KERNEL(Float32, Float64, float, double),
|
||||
ADD_KERNEL(Float32, Int32, float, int), ADD_KERNEL(Float32, Int64, float, int64_t),
|
||||
ADD_KERNEL(Float32, Bool, float, bool), ADD_KERNEL(Float64, Float32, double, float),
|
||||
ADD_KERNEL(Float64, Bool, double, bool), ADD_KERNEL(Float64, Float64, double, double),
|
||||
ADD_KERNEL(Float64, Int32, double, int), ADD_KERNEL(Float64, Int64, double, int64_t),
|
||||
ADD_KERNEL(Int32, Float32, int, float), ADD_KERNEL(Int32, Float64, int, double),
|
||||
ADD_KERNEL(Int32, Int32, int, int), ADD_KERNEL(Int32, Int64, int, int64_t),
|
||||
ADD_KERNEL(Int32, Bool, int, bool), ADD_KERNEL(Int64, Float32, int64_t, float),
|
||||
ADD_KERNEL(Int64, Bool, int64_t, bool), ADD_KERNEL(Int64, Float64, int64_t, double),
|
||||
ADD_KERNEL(Int64, Int32, int64_t, int), ADD_KERNEL(Int64, Int64, int64_t, int64_t),
|
||||
ADD_KERNEL(Bool, Int32, bool, int), ADD_KERNEL(Bool, Int64, bool, int64_t),
|
||||
ADD_KERNEL(Bool, Bool, bool, bool), ADD_KERNEL(Bool, Float64, bool, double),
|
||||
ADD_KERNEL(Bool, Float32, bool, float)};
|
||||
return func_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, tuple_le, SequenceLessCpuKernelMod);
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, tuple_lt, SequenceLessCpuKernelMod);
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, list_le, SequenceLessCpuKernelMod);
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, list_lt, SequenceLessCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_LESS_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_LESS_CPU_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class SequenceLessCpuKernelMod : public NativeCpuKernelMod,
|
||||
public MatchKernelHelper<SequenceLessCpuKernelMod, AddressPtr> {
|
||||
public:
|
||||
SequenceLessCpuKernelMod() = default;
|
||||
~SequenceLessCpuKernelMod() override = default;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_func_);
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
template <typename T, typename S>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
|
||||
private:
|
||||
size_t x_size_ = 0;
|
||||
size_t y_size_ = 0;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_LESS_CPU_KERNEL_H_
|
|
@ -87,6 +87,10 @@ constexpr auto kScalarBool = "ScalarBool";
|
|||
constexpr auto kBoolNot = "bool_not";
|
||||
constexpr auto kScalarBitwiseAnd = "bit_and";
|
||||
constexpr auto kScalarBitwiseOr = "bit_or";
|
||||
constexpr auto kTupleLt = "tuple_lt";
|
||||
constexpr auto kListLt = "list_lt";
|
||||
constexpr auto kTupleLe = "tuple_le";
|
||||
constexpr auto kListLe = "list_le";
|
||||
constexpr auto kExp = "Exp";
|
||||
constexpr auto kEqual = "Equal";
|
||||
constexpr auto kNotEqual = "NotEqual";
|
||||
|
@ -1753,6 +1757,10 @@ GVAR_DEF(PrimitivePtr, kPrimTupleGreaterThan, std::make_shared<Primitive>("tuple
|
|||
GVAR_DEF(PrimitivePtr, kPrimListGreaterThan, std::make_shared<Primitive>("list_greater_than"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTupleGreaterEqual, std::make_shared<Primitive>("tuple_greater_equal"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimListGreaterEqual, std::make_shared<Primitive>("list_greater_equal"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTupleLessThan, std::make_shared<Primitive>(kTupleLt));
|
||||
GVAR_DEF(PrimitivePtr, kPrimListLessThan, std::make_shared<Primitive>(kListLt));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTupleLessEqual, std::make_shared<Primitive>(kTupleLe));
|
||||
GVAR_DEF(PrimitivePtr, kPrimListLessEqual, std::make_shared<Primitive>(kListLe));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMakeRange, std::make_shared<Primitive>("make_range"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimStopGradient, std::make_shared<Primitive>("StopGradient"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimDictLen, std::make_shared<Primitive>("dict_len"));
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_LIST_LE_H_
|
||||
#define MINDSPORE_CORE_OPS_LIST_LE_H_
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief list less equal operation.
|
||||
class MIND_API list_le : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(list_le);
|
||||
/// \brief Constructor.
|
||||
list_le() : BaseOperator(prim::kListLe) {}
|
||||
/// \brief Init function.
|
||||
void Init() const {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_LIST_LE_H_
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_LIST_LT_H_
|
||||
#define MINDSPORE_CORE_OPS_LIST_LT_H_
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief list less than operation.
|
||||
class MIND_API list_lt : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(list_lt);
|
||||
/// \brief Constructor.
|
||||
list_lt() : BaseOperator(prim::kListLt) {}
|
||||
/// \brief Init function.
|
||||
void Init() const {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_LIST_LT_H_
|
|
@ -0,0 +1,138 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
|
||||
#include "ops/tuple_le.h"
|
||||
#include "ops/tuple_lt.h"
|
||||
#include "ops/list_le.h"
|
||||
#include "ops/list_lt.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
AbstractBasePtr LessImpl(const AbstractBasePtrList &seqx_elements, const AbstractBasePtrList &seqy_elements,
|
||||
const std::string &prim_name, const bool is_less_equal = true) {
|
||||
size_t x_size = seqx_elements.size();
|
||||
size_t y_size = seqy_elements.size();
|
||||
size_t max_size = std::max(x_size, y_size);
|
||||
|
||||
for (size_t i = 0; i < max_size; ++i) {
|
||||
if (i >= x_size) {
|
||||
return std::make_shared<abstract::AbstractScalar>(true);
|
||||
}
|
||||
if (i >= y_size) {
|
||||
return std::make_shared<abstract::AbstractScalar>(false);
|
||||
}
|
||||
auto x_element = seqx_elements[i];
|
||||
auto y_element = seqy_elements[i];
|
||||
|
||||
if (x_element->BuildType()->type_id() == kObjectTypeTensorType ||
|
||||
y_element->BuildType()->type_id() == kObjectTypeTensorType) {
|
||||
MS_EXCEPTION(TypeError) << "For primitive tupel_equal, the input element must be scalar, but got "
|
||||
<< x_element->ToString() << " and " << y_element->ToString();
|
||||
}
|
||||
if (x_element->BuildValue() == kAnyValue || y_element->BuildValue() == kAnyValue) {
|
||||
return std::make_shared<abstract::AbstractScalar>(kAnyValue, kBool);
|
||||
}
|
||||
|
||||
auto x = GetScalarValue<double>(prim_name, x_element->BuildValue());
|
||||
auto y = GetScalarValue<double>(prim_name, y_element->BuildValue());
|
||||
if (x > y) {
|
||||
return std::make_shared<abstract::AbstractScalar>(false);
|
||||
} else if (x < y) {
|
||||
return std::make_shared<abstract::AbstractScalar>(true);
|
||||
}
|
||||
}
|
||||
return std::make_shared<abstract::AbstractScalar>(is_less_equal);
|
||||
}
|
||||
|
||||
AbstractBasePtr SequenceLessInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
|
||||
const bool is_less_equal = true) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
constexpr size_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
||||
auto x_abs = input_args[0];
|
||||
auto y_abs = input_args[1];
|
||||
if (!x_abs->isa<abstract::AbstractSequence>() || !y_abs->isa<abstract::AbstractSequence>()) {
|
||||
MS_EXCEPTION(TypeError) << "For primitive '" << prim_name << "', the input must be a list or tuple, "
|
||||
<< "but got: " << x_abs->ToString() << " and " << y_abs->ToString();
|
||||
}
|
||||
auto seqx_abs = x_abs->cast<abstract::AbstractSequencePtr>();
|
||||
auto seqy_abs = y_abs->cast<abstract::AbstractSequencePtr>();
|
||||
if (seqx_abs->dynamic_len() || seqy_abs->dynamic_len()) {
|
||||
return std::make_shared<abstract::AbstractScalar>(kAnyValue, kBool);
|
||||
}
|
||||
const auto &seqx_elements = seqx_abs->elements();
|
||||
const auto &seqy_elements = seqy_abs->elements();
|
||||
|
||||
return LessImpl(seqx_elements, seqy_elements, prim_name, is_less_equal);
|
||||
}
|
||||
|
||||
class SequenceLessThanInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceLessInferInner(primitive, input_args, false)->BuildShape();
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceLessInferInner(prim, input_args, false)->BuildType();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceLessInferInner(primitive, input_args, false);
|
||||
}
|
||||
};
|
||||
|
||||
class SequenceLessEqualInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceLessInferInner(primitive, input_args)->BuildShape();
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceLessInferInner(prim, input_args)->BuildType();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceLessInferInner(primitive, input_args);
|
||||
}
|
||||
};
|
||||
|
||||
MIND_API_OPERATOR_IMPL(tuple_le, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(tuple_lt, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(list_le, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(list_lt, BaseOperator);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(tuple_le, prim::kPrimTupleLessEqual, SequenceLessEqualInfer, false);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(list_le, prim::kPrimListLessEqual, SequenceLessEqualInfer, false);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(tuple_lt, prim::kPrimTupleLessThan, SequenceLessThanInfer, false);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(list_lt, prim::kPrimListLessThan, SequenceLessThanInfer, false);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -96,6 +96,6 @@ class SequenceMulInfer : public abstract::OpInferBase {
|
|||
|
||||
std::set<int64_t> GetValueDependArgIndices() const override { return {1}; }
|
||||
};
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceMul, prim::kPrimSequenceMul, SequenceMulInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceMul, prim::kPrimSequenceMul, SequenceMulInfer, false);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_TUPLE_LE_H_
|
||||
#define MINDSPORE_CORE_OPS_TUPLE_LE_H_
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief tuple less equal operation.
|
||||
class MIND_API tuple_le : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(tuple_le);
|
||||
/// \brief Constructor.
|
||||
tuple_le() : BaseOperator(prim::kTupleLt) {}
|
||||
/// \brief Init function.
|
||||
void Init() const {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_TUPLE_LE_H_
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_TUPLE_LT_H_
|
||||
#define MINDSPORE_CORE_OPS_TUPLE_LT_H_
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief Sequence less than operation.
|
||||
class MIND_API tuple_lt : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(tuple_lt);
|
||||
/// \brief Constructor.
|
||||
tuple_lt() : BaseOperator(prim::kTupleLt) {}
|
||||
/// \brief Init function.
|
||||
void Init() const {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_TUPLE_LT_H_
|
|
@ -100,6 +100,19 @@ def get_bprop_setitem(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register("tuple_le")
|
||||
@bprop_getters.register("tuple_lt")
|
||||
@bprop_getters.register("list_le")
|
||||
@bprop_getters.register("list_lt")
|
||||
def get_bprop_less(self):
|
||||
"""Generate bprop for SequenceLessThan and SequenceLessEqual"""
|
||||
|
||||
def bprop(x, y, out, dout):
|
||||
return zeros_like(x), zeros_like(y)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(seq.SequenceMul)
|
||||
def get_bprop_mul(self):
|
||||
"""Generate bprop for SequenceMul"""
|
||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||
from mindspore.ops.composite import base
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops.operations import _sequence_ops as _seq
|
||||
|
||||
# less_equal is a metagraph object which will determine if two objects are less_equal according to input type
|
||||
# using ".register" decorator
|
||||
|
@ -70,3 +71,33 @@ def _less_equal_tensor(x, y):
|
|||
Tensor, return value by operator P.LessEqual.
|
||||
"""
|
||||
return F.tensor_le(x, y)
|
||||
|
||||
|
||||
@less_equal.register("Tuple", "Tuple")
|
||||
def _less_equal_tuple(x, y):
|
||||
"""
|
||||
Determine whether x is less than or equal to y.
|
||||
|
||||
Args:
|
||||
x(Tuple): Tuple.
|
||||
y(Tuple): Tuple.
|
||||
|
||||
Returns:
|
||||
bool, if x <= y return true in python logic, x > y return false.
|
||||
"""
|
||||
return _seq.tuple_le()(x, y)
|
||||
|
||||
|
||||
@less_equal.register("List", "List")
|
||||
def _less_equal_list(x, y):
|
||||
"""
|
||||
Determine whether x is less than or equal to y.
|
||||
|
||||
Args:
|
||||
x(List): List.
|
||||
y(List): List.
|
||||
|
||||
Returns:
|
||||
bool, if x <= y return true in python logic, x > y return false.
|
||||
"""
|
||||
return _seq.list_le()(x, y)
|
||||
|
|
|
@ -20,6 +20,7 @@ from __future__ import division
|
|||
from mindspore.ops.composite import base
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops.operations import _sequence_ops as _seq
|
||||
|
||||
# less is a metafuncgraph object which will determine if two objects are less according to input type
|
||||
# using ".register" decorator
|
||||
|
@ -71,3 +72,33 @@ def _less_tensor(x, y):
|
|||
Tensor, return value of x and y by operation P.Less()
|
||||
"""
|
||||
return F.tensor_lt(x, y)
|
||||
|
||||
|
||||
@less.register("Tuple", "Tuple")
|
||||
def _less_tuple(x, y):
|
||||
"""
|
||||
Determine whether x is less than to y.
|
||||
|
||||
Args:
|
||||
x(Tuple): Tuple.
|
||||
y(Tuple): Tuple.
|
||||
|
||||
Returns:
|
||||
bool, if x < y return true in python logic, x >= y return false.
|
||||
"""
|
||||
return _seq.tuple_lt()(x, y)
|
||||
|
||||
|
||||
@less.register("List", "List")
|
||||
def _less_list(x, y):
|
||||
"""
|
||||
Determine whether x is less than to y.
|
||||
|
||||
Args:
|
||||
x(List): List.
|
||||
y(List): List.
|
||||
|
||||
Returns:
|
||||
bool, if x < y return true in python logic, x >= y return false.
|
||||
"""
|
||||
return _seq.list_lt()(x, y)
|
||||
|
|
|
@ -643,3 +643,115 @@ class list_greater_equal(Primitive):
|
|||
"""Initialize list_greater_equal"""
|
||||
self.init_prim_io_names(
|
||||
inputs=['input_0', 'input_1'], outputs=['output_data'])
|
||||
|
||||
|
||||
class tuple_lt(Primitive):
|
||||
r"""
|
||||
Support tuple less_than operation 'less_than(target)'.
|
||||
|
||||
.. note::
|
||||
This it is only for internal used.
|
||||
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||
|
||||
Inputs:
|
||||
- **input_0** (Union[Tuple]) - The first sequence.
|
||||
- **input_1** (Union[Tuple]) - The second sequence, dtype and shape should be same as 'input_0'.
|
||||
|
||||
Outputs:
|
||||
A bool value to indicate whether every element in 'input_0' is less than element in 'input_1' correspondingly.
|
||||
|
||||
Raises:
|
||||
TypeError: The 'input_0' or 'input_1' is not tuple.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize tuple_lt"""
|
||||
self.init_prim_io_names(
|
||||
inputs=['input_0', 'input_1'], outputs=['output_data'])
|
||||
|
||||
|
||||
class list_lt(Primitive):
|
||||
r"""
|
||||
Support list less_than operation 'less_than(target)'.
|
||||
|
||||
.. note::
|
||||
This it is only for internal used.
|
||||
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||
|
||||
Inputs:
|
||||
- **input_0** (Union[List]) - The first sequence.
|
||||
- **input_1** (Union[List]) - The second sequence, dtype and shape should be same as 'input_0'.
|
||||
|
||||
Outputs:
|
||||
A bool value to indicate whether every element in 'input_0' is less than element in 'input_1' correspondingly.
|
||||
|
||||
Raises:
|
||||
TypeError: The 'input_0' or 'input_1' is not list.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize list_lt"""
|
||||
self.init_prim_io_names(
|
||||
inputs=['input_0', 'input_1'], outputs=['output_data'])
|
||||
|
||||
|
||||
class tuple_le(Primitive):
|
||||
r"""
|
||||
Support tuple less_equal operation 'less_equal(target)'.
|
||||
|
||||
.. note::
|
||||
This it is only for internal used.
|
||||
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||
|
||||
Inputs:
|
||||
- **input_0** (Union[Tuple]) - The first sequence.
|
||||
- **input_1** (Union[Tuple]) - The second sequence, dtype and shape should be same as 'input_0'.
|
||||
|
||||
Outputs:
|
||||
A bool value to indicate whether every element in 'input_0' is less equal element in 'input_1' correspondingly.
|
||||
|
||||
Raises:
|
||||
TypeError: The 'input_0' or 'input_1' is not tuple.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize tuple_le"""
|
||||
self.init_prim_io_names(
|
||||
inputs=['input_0', 'input_1'], outputs=['output_data'])
|
||||
|
||||
|
||||
class list_le(Primitive):
|
||||
r"""
|
||||
Support list less equal operation 'less_equal(target)'.
|
||||
|
||||
.. note::
|
||||
This it is only for internal used.
|
||||
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||
|
||||
Inputs:
|
||||
- **input_0** (Union[List]) - The first sequence.
|
||||
- **input_1** (Union[List]) - The second sequence, dtype and shape should be same as 'input_0'.
|
||||
|
||||
Outputs:
|
||||
A bool value to indicate whether every element in 'input_0' is less equal element in 'input_1' correspondingly.
|
||||
|
||||
Raises:
|
||||
TypeError: The 'input_0' or 'input_1' is not list.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize list_le"""
|
||||
self.init_prim_io_names(
|
||||
inputs=['input_0', 'input_1'], outputs=['output_data'])
|
||||
|
|
|
@ -0,0 +1,138 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.ops.operations import _sequence_ops as _seq
|
||||
from mindspore.common import mutable
|
||||
from mindspore.ops.composite import GradOperation
|
||||
from sequence_help import context_prepare
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context_prepare()
|
||||
|
||||
|
||||
class NetTupleLt(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.seq_lt = _seq.tuple_lt()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.seq_lt(x, y)
|
||||
|
||||
|
||||
class NetTupleLe(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.seq_le = _seq.tuple_le()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.seq_le(x, y)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_dyn_le():
|
||||
"""
|
||||
Feature: test sequence getitem op
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
x = mutable((1, 2, 3, 4, 5, 6), True)
|
||||
y = mutable((1, 2, 3, 2, 6), True)
|
||||
expect = False
|
||||
net = NetTupleLe()
|
||||
res = net(x, y)
|
||||
assert res == expect
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_dyn_lt():
|
||||
"""
|
||||
Feature: test sequence getitem op
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
x = mutable((1, 2, 3, 4, 5, 6), True)
|
||||
y = (1, 2, 3, 4, 5, 6)
|
||||
expect = False
|
||||
net = NetTupleLt()
|
||||
res = net(x, y)
|
||||
assert res == expect
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_le():
|
||||
"""
|
||||
Feature: test sequence getitem op
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
x = (1, 2, 3, 4, 5)
|
||||
y = (True, 2, 3, 4, 5)
|
||||
expect = True
|
||||
net = NetTupleLe()
|
||||
res = net(x, y)
|
||||
assert res == expect
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_lt():
|
||||
"""
|
||||
Feature: test sequence getitem op
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
x = (1, 2, 3, 4, 5, 6)
|
||||
y = (True, 2, 3, 4, 5)
|
||||
expect = False
|
||||
net = NetTupleLt()
|
||||
res = net(x, y)
|
||||
assert res == expect
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_getitem_grad():
|
||||
"""
|
||||
Feature: test sequence getitem grad op
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
net_ms = NetTupleLe()
|
||||
x = mutable((2, 3, 4, 5, 6), True)
|
||||
y = mutable((1, 2, 3, 4, 5, 6), True)
|
||||
dout = True
|
||||
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||
print("grad out1 = ", grad_func(x, y, dout))
|
Loading…
Reference in New Issue