tuple equal ops

This commit is contained in:
huoxinyou 2023-02-23 16:03:18 +08:00
parent da82dab114
commit c47bc41af7
12 changed files with 549 additions and 35 deletions

View File

@ -68,18 +68,6 @@ struct SlideInfo {
int64_t stop; int64_t stop;
}; };
template <typename T>
AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: two tuples or two lists.
const size_t args_num = 2;
CheckArgsSize(op_name, args_spec_list, args_num);
auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
ValuePtr x_value = input_x->BuildValue();
ValuePtr y_value = input_y->BuildValue();
return std::make_shared<AbstractScalar>(*x_value == *y_value);
}
void ComputeReduceIndex(const std::vector<int64_t> &reverse_x, const std::vector<int64_t> &reverse_y, void ComputeReduceIndex(const std::vector<int64_t> &reverse_x, const std::vector<int64_t> &reverse_y,
std::vector<int64_t> *grad_x_reduce_idx, std::vector<int64_t> *grad_y_reduce_idy) { std::vector<int64_t> *grad_x_reduce_idx, std::vector<int64_t> *grad_y_reduce_idy) {
MS_EXCEPTION_IF_NULL(grad_x_reduce_idx); MS_EXCEPTION_IF_NULL(grad_x_reduce_idx);
@ -814,16 +802,6 @@ AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const Primitive
return args_spec_list[0]->Clone(); return args_spec_list[0]->Clone();
} }
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list); return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
@ -1089,10 +1067,8 @@ REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringGetItem, prim::kPrimStringGetItem, Infe
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr); REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr); REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr); REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr);
// List // List
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr); REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr);
// Dict // Dict
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr); REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr);
// Slice // Slice
@ -1135,13 +1111,9 @@ void RegPrimitiveFrontEval() {
InferImplTupleDiv, nullptr); InferImplTupleDiv, nullptr);
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTupleToArray, abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTupleToArray,
InferImplTuple2Array, nullptr); InferImplTuple2Array, nullptr);
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTupleEqual,
InferImplTupleEqual, nullptr);
// List // List
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimListReduce, abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimListReduce,
InferImplListReduce, nullptr); InferImplListReduce, nullptr);
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimListEqual,
InferImplListEqual, nullptr);
// Dict // Dict
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimDictLen, abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimDictLen,
InferImplDictLen, nullptr); InferImplDictLen, nullptr);

View File

@ -36,13 +36,9 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
// List // List
AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive, AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
// Dict // Dict
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);

View File

@ -0,0 +1,110 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/sequence/sequence_equal_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include <complex>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "utils/ms_utils.h"
#include "include/common/thread_pool.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr int kInputsNum = 2;
constexpr int kOutputsNum = 1;
} // namespace
bool SequenceEqualCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
return MatchKernelFunc(base_operator, inputs, outputs);
}
int SequenceEqualCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
is_inputs_type_diff_ = false;
int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
if (ret != 0) {
return ret;
}
CHECK_KERNEL_INPUTS_NUM(input_shapes_.size(), kInputsNum, kernel_name_);
if (input_shapes_[0].empty() || input_shapes_[1].empty()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the x and y shape can't be 0, but got " << input_shapes_;
}
x_size_ = input_shapes_[0][0];
y_size_ = input_shapes_[1][0];
if (inputs[0]->GetDtype() != inputs[1]->GetDtype()) {
is_inputs_type_diff_ = true;
}
return KRET_OK;
}
template <typename T, typename S>
bool SequenceEqualCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
const auto x_addr = GetDeviceAddress<T>(inputs, 0);
const auto y_addr = GetDeviceAddress<S>(inputs, 1);
bool *output_addr = GetDeviceAddress<bool>(outputs, 0);
if (x_size_ != y_size_ || is_inputs_type_diff_) {
*output_addr = false;
return true;
}
for (size_t i = 0; i < x_size_; ++i) {
if (static_cast<double>(x_addr[i]) != static_cast<double>(y_addr[i])) {
*output_addr = false;
return true;
}
}
*output_addr = true;
return true;
}
#define ADD_KERNEL(x_dtype, y_dtype, x_type, y_type) \
{ \
KernelAttr() \
.AddInputAttr(kObjectTypeTuple, kNumberType##x_dtype) \
.AddInputAttr(kObjectTypeTuple, kNumberType##y_dtype) \
.AddOutputAttr(kObjectTypeNumber, kNumberTypeBool), \
&SequenceEqualCpuKernelMod::LaunchKernel<x_type, y_type> \
}
const std::vector<std::pair<KernelAttr, SequenceEqualCpuKernelMod::KernelRunFunc>>
&SequenceEqualCpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, SequenceEqualCpuKernelMod::KernelRunFunc>> func_list = {
ADD_KERNEL(Float32, Float32, float, float), ADD_KERNEL(Float32, Float64, float, double),
ADD_KERNEL(Float32, Int32, float, int), ADD_KERNEL(Float32, Int64, float, int64_t),
ADD_KERNEL(Float32, Bool, float, bool), ADD_KERNEL(Float64, Float32, double, float),
ADD_KERNEL(Float64, Bool, double, bool), ADD_KERNEL(Float64, Float64, double, double),
ADD_KERNEL(Float64, Int32, double, int), ADD_KERNEL(Float64, Int64, double, int64_t),
ADD_KERNEL(Int32, Float32, int, float), ADD_KERNEL(Int32, Float64, int, double),
ADD_KERNEL(Int32, Int32, int, int), ADD_KERNEL(Int32, Int64, int, int64_t),
ADD_KERNEL(Int32, Bool, int, bool), ADD_KERNEL(Int64, Float32, int64_t, float),
ADD_KERNEL(Int64, Bool, int64_t, bool), ADD_KERNEL(Int64, Float64, int64_t, double),
ADD_KERNEL(Int64, Int32, int64_t, int), ADD_KERNEL(Int64, Int64, int64_t, int64_t),
ADD_KERNEL(Bool, Int32, bool, int), ADD_KERNEL(Bool, Int64, bool, int64_t),
ADD_KERNEL(Bool, Bool, bool, bool), ADD_KERNEL(Bool, Float64, bool, double),
ADD_KERNEL(Bool, Float32, bool, float)};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, tuple_equal, SequenceEqualCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_EQUAL_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_EQUAL_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <utility>
#include <map>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class SequenceEqualCpuKernelMod : public NativeCpuKernelMod,
public MatchKernelHelper<SequenceEqualCpuKernelMod, AddressPtr> {
public:
SequenceEqualCpuKernelMod() = default;
~SequenceEqualCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
MS_EXCEPTION_IF_NULL(kernel_func_);
return kernel_func_(this, inputs, workspace, outputs);
}
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
template <typename T, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
private:
size_t x_size_ = 0;
size_t y_size_ = 0;
bool is_inputs_type_diff_ = false;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_EQUAL_CPU_KERNEL_H_

View File

@ -85,6 +85,8 @@ constexpr auto kScalarLe = "scalar_le";
constexpr auto kScalarGe = "scalar_ge"; constexpr auto kScalarGe = "scalar_ge";
constexpr auto kScalarBool = "ScalarBool"; constexpr auto kScalarBool = "ScalarBool";
constexpr auto kBoolNot = "bool_not"; constexpr auto kBoolNot = "bool_not";
constexpr auto kTupleEqual = "tuple_equal";
constexpr auto kListEqual = "list_equal";
constexpr auto kScalarBitwiseAnd = "bit_and"; constexpr auto kScalarBitwiseAnd = "bit_and";
constexpr auto kScalarBitwiseOr = "bit_or"; constexpr auto kScalarBitwiseOr = "bit_or";
constexpr auto kTupleLt = "tuple_lt"; constexpr auto kTupleLt = "tuple_lt";
@ -1752,8 +1754,8 @@ GVAR_DEF(PrimitivePtr, kPrimReducedShape, std::make_shared<Primitive>("reduced_s
GVAR_DEF(PrimitivePtr, kPrimTupleDiv, std::make_shared<Primitive>("tuple_div")); GVAR_DEF(PrimitivePtr, kPrimTupleDiv, std::make_shared<Primitive>("tuple_div"));
GVAR_DEF(PrimitivePtr, kPrimTupleToArray, std::make_shared<Primitive>("tuple_to_array")); GVAR_DEF(PrimitivePtr, kPrimTupleToArray, std::make_shared<Primitive>("tuple_to_array"));
GVAR_DEF(PrimitivePtr, kPrimShapeMul, std::make_shared<Primitive>("shape_mul")); GVAR_DEF(PrimitivePtr, kPrimShapeMul, std::make_shared<Primitive>("shape_mul"));
GVAR_DEF(PrimitivePtr, kPrimTupleEqual, std::make_shared<Primitive>("tuple_equal")); GVAR_DEF(PrimitivePtr, kPrimTupleEqual, std::make_shared<Primitive>(kTupleEqual));
GVAR_DEF(PrimitivePtr, kPrimListEqual, std::make_shared<Primitive>("list_equal")); GVAR_DEF(PrimitivePtr, kPrimListEqual, std::make_shared<Primitive>(kListEqual));
GVAR_DEF(PrimitivePtr, kPrimTupleGreaterThan, std::make_shared<Primitive>("tuple_greater_than")); GVAR_DEF(PrimitivePtr, kPrimTupleGreaterThan, std::make_shared<Primitive>("tuple_greater_than"));
GVAR_DEF(PrimitivePtr, kPrimListGreaterThan, std::make_shared<Primitive>("list_greater_than")); GVAR_DEF(PrimitivePtr, kPrimListGreaterThan, std::make_shared<Primitive>("list_greater_than"));
GVAR_DEF(PrimitivePtr, kPrimTupleGreaterEqual, std::make_shared<Primitive>("tuple_greater_equal")); GVAR_DEF(PrimitivePtr, kPrimTupleGreaterEqual, std::make_shared<Primitive>("tuple_greater_equal"));

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_LIST_EQUAL_H_
#define MINDSPORE_CORE_OPS_LIST_EQUAL_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief bool_not op is used to calculate the input true or false.
class MIND_API list_equal : public BaseOperator {
public:
MIND_API_BASE_MEMBER(list_equal);
/// \brief Constructor.
list_equal() : BaseOperator(prim::kListEqual) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_LIST_EQUAL_H_

View File

@ -0,0 +1,77 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include <string>
#include <algorithm>
#include "ops/tuple_equal.h"
#include "ops/list_equal.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "include/common/utils/utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
AbstractBasePtr SequenceEqualInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
constexpr size_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
constexpr size_t x_index = 0;
constexpr size_t y_index = 1;
auto x_abs = input_args[x_index];
auto y_abs = input_args[y_index];
if (!x_abs->isa<abstract::AbstractSequence>() && !y_abs->isa<abstract::AbstractSequence>()) {
MS_EXCEPTION(TypeError) << "For primitive '" << prim_name << "', the input must be a list or tuple, "
<< "but got: " << x_abs->ToString() << " and " << y_abs->ToString();
}
auto seqx_abs = x_abs->cast<abstract::AbstractSequencePtr>();
auto seqy_abs = y_abs->cast<abstract::AbstractSequencePtr>();
if (seqx_abs->dynamic_len() || seqy_abs->dynamic_len() || seqx_abs->BuildValue() == kAnyValue ||
seqy_abs->BuildValue() == kAnyValue) {
return std::make_shared<abstract::AbstractScalar>(kAnyValue, kBool);
}
return std::make_shared<abstract::AbstractScalar>(*seqx_abs->BuildValue() == *seqy_abs->BuildValue());
}
class SequenceEqualInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
return SequenceEqualInferInner(primitive, input_args)->BuildShape();
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
return SequenceEqualInferInner(prim, input_args)->BuildType();
}
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
return SequenceEqualInferInner(primitive, input_args);
}
};
MIND_API_OPERATOR_IMPL(tuple_equal, BaseOperator);
MIND_API_OPERATOR_IMPL(list_equal, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(tuple_equal, prim::kPrimTupleEqual, SequenceEqualInfer, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(list_equal, prim::kPrimListEqual, SequenceEqualInfer, false);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_TUPLE_EQUAL_H_
#define MINDSPORE_CORE_OPS_TUPLE_EQUAL_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief bool_not op is used to calculate the input true or false.
class MIND_API tuple_equal : public BaseOperator {
public:
MIND_API_BASE_MEMBER(tuple_equal);
/// \brief Constructor.
tuple_equal() : BaseOperator(prim::kTupleEqual) {}
/// \brief Init.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_TUPLE_EQUAL_H_

View File

@ -86,6 +86,17 @@ def get_bprop_index(self):
return bprop return bprop
@bprop_getters.register("tuple_equal")
@bprop_getters.register("list_equal")
def get_bprop_seq_equal(self):
"""Generate bprop for tuple_equal and list_equal"""
def bprop(x, y, out, dout):
return (zeros_like(x), zeros_like(y))
return bprop
@bprop_getters.register("tuple_setitem") @bprop_getters.register("tuple_setitem")
@bprop_getters.register("list_setitem") @bprop_getters.register("list_setitem")
def get_bprop_setitem(self): def get_bprop_setitem(self):

View File

@ -47,6 +47,9 @@ class ScalarDiv(Primitive):
def __init__(self): def __init__(self):
"""Initialize ScalarDiv""" """Initialize ScalarDiv"""
def __call__(self, x, y):
return x / y
class ScalarFloordiv(Primitive): class ScalarFloordiv(Primitive):
r""" r"""
@ -79,6 +82,9 @@ class ScalarFloordiv(Primitive):
"""Initialize ScalarFloordiv""" """Initialize ScalarFloordiv"""
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
def __call__(self, x, y):
return x // y
class ScalarAdd(Primitive): class ScalarAdd(Primitive):
r""" r"""
@ -105,6 +111,9 @@ class ScalarAdd(Primitive):
def __init__(self): def __init__(self):
"""Initialize ScalarAdd""" """Initialize ScalarAdd"""
def __call__(self, x, y):
return x + y
class ScalarSub(Primitive): class ScalarSub(Primitive):
r""" r"""
@ -131,6 +140,9 @@ class ScalarSub(Primitive):
def __init__(self): def __init__(self):
"""Initialize ScalarSub""" """Initialize ScalarSub"""
def __call__(self, x, y):
return x - y
class ScalarMul(Primitive): class ScalarMul(Primitive):
r""" r"""
@ -157,6 +169,9 @@ class ScalarMul(Primitive):
def __init__(self): def __init__(self):
"""Initialize ScalarMul""" """Initialize ScalarMul"""
def __call__(self, x, y):
return x * y
class scalar_eq(Primitive): class scalar_eq(Primitive):
r""" r"""
@ -181,7 +196,10 @@ class scalar_eq(Primitive):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""Initialize ScalarMul""" """Initialize ScalarEq"""
def __call__(self, x, y):
return x == y
class scalar_gt(Primitive): class scalar_gt(Primitive):
@ -209,6 +227,9 @@ class scalar_gt(Primitive):
def __init__(self): def __init__(self):
"""Initialize scalar_gt""" """Initialize scalar_gt"""
def __call__(self, x, y):
return x > y
class scalar_lt(Primitive): class scalar_lt(Primitive):
r""" r"""
@ -235,6 +256,9 @@ class scalar_lt(Primitive):
def __init__(self): def __init__(self):
"""Initialize scalar_lt""" """Initialize scalar_lt"""
def __call__(self, x, y):
return x < y
class scalar_ge(Primitive): class scalar_ge(Primitive):
r""" r"""
@ -261,6 +285,9 @@ class scalar_ge(Primitive):
def __init__(self): def __init__(self):
"""Initialize scalar_ge""" """Initialize scalar_ge"""
def __call__(self, x, y):
return x >= y
class scalar_le(Primitive): class scalar_le(Primitive):
r""" r"""
@ -287,6 +314,9 @@ class scalar_le(Primitive):
def __init__(self): def __init__(self):
"""Initialize scalar_le""" """Initialize scalar_le"""
def __call__(self, x, y):
return x <= y
class ScalarMod(Primitive): class ScalarMod(Primitive):
r""" r"""
@ -317,6 +347,9 @@ class ScalarMod(Primitive):
def __init__(self): def __init__(self):
"""Initialize ScalarMod""" """Initialize ScalarMod"""
def __call__(self, x, y):
return x % y
class ScalarBool(Primitive): class ScalarBool(Primitive):
r""" r"""
@ -367,6 +400,9 @@ class bool_not(Primitive):
def __init__(self): def __init__(self):
"""Initialize bool_not""" """Initialize bool_not"""
def __call__(self, x):
return not x
class bit_and(Primitive): class bit_and(Primitive):
r""" r"""

View File

@ -454,6 +454,67 @@ class make_range(Primitive):
"""Initialize make_range""" """Initialize make_range"""
self.init_prim_io_names(inputs=['start', 'limit', 'delta'], outputs=['output_data']) self.init_prim_io_names(inputs=['start', 'limit', 'delta'], outputs=['output_data'])
def __call__(self, start, limit, delta):
return range(start, limit, delta)
class tuple_equal(Primitive):
r"""
Support sequence equal operation 'equal(target)'.
.. note::
This it is only for internal used.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Union[Tuple]) - The tuple.
- **y** (Union[Tuple]) - The tuple.
Outputs:
Bool.
Raises:
TypeError: The 'x' is not tuple.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize tuple_equal"""
def __call__(self, x, y):
return x == y
class list_equal(Primitive):
r"""
Support sequence equal operation 'equal(target)'.
.. note::
This it is only for internal used.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **x** (Union[List]) - The list.
- **y** (Union[List]) - The list.
Outputs:
Bool.
Raises:
TypeError: The 'x' is not list.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize list_equal"""
def __call__(self, x, y):
return x == y
class sequence_len(Primitive): class sequence_len(Primitive):
r""" r"""
@ -480,6 +541,9 @@ class sequence_len(Primitive):
"""Initialize sequence_len""" """Initialize sequence_len"""
self.init_prim_io_names(inputs=['sequence'], outputs=['output_data']) self.init_prim_io_names(inputs=['sequence'], outputs=['output_data'])
def __call__(self, x):
return len(x)
class SequenceMax(Primitive): class SequenceMax(Primitive):
r""" r"""

View File

@ -0,0 +1,110 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore.nn as nn
from mindspore import context
from mindspore.ops.operations import _sequence_ops as _seq
from mindspore.common import mutable
from mindspore.ops.composite import GradOperation
from sequence_help import context_prepare
context.set_context(mode=context.GRAPH_MODE)
context_prepare()
class NetTupleEqual(nn.Cell):
def __init__(self):
super().__init__()
self.seq_equal = _seq.tuple_equal()
def construct(self, x, y):
return self.seq_equal(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_seq_dyn_equal():
"""
Feature: test sequence equal op
Description: equal operation on tuple type
Expectation: the behavior is matched to python style
"""
x = mutable((1, 2, 3, 4, 5, 6), True)
y = mutable((1, 2, 3, 2, 6), True)
expect = False
net = NetTupleEqual()
res = net(x, y)
assert res == expect
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_seq_dyn_equal1():
"""
Feature: test sequence equal op
Description: equal operation on tuple type
Expectation: the behavior is matched to python style
"""
x = mutable((1, 2, 3, 4, 5, 6), True)
y = (1, 2, 3, 4, 5, 6)
expect = True
net = NetTupleEqual()
res = net(x, y)
assert res == expect
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_seq_equal():
"""
Feature: test sequence equal op
Description: equal operation on tuple type
Expectation: the behavior is matched to python style
"""
x = (1, 2, 3, 4, 5)
y = (True, 2, 3, 4, 5)
expect = False
net = NetTupleEqual()
res = net(x, y)
assert res == expect
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_seq_equal_grad():
"""
Feature: test sequence equal grad op
Description: equal operation on tuple type
Expectation: the behavior is matched to python style
"""
net_ms = NetTupleEqual()
x = mutable((1, 2, 3, 4, 5, 6), True)
y = mutable((1, 2, 3, 4, 5, 6), True)
dout = True
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
print("grad out1 = ", grad_func(x, y, dout))