tuple equal ops
This commit is contained in:
parent
da82dab114
commit
c47bc41af7
|
@ -68,18 +68,6 @@ struct SlideInfo {
|
||||||
int64_t stop;
|
int64_t stop;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
|
||||||
// Inputs: two tuples or two lists.
|
|
||||||
const size_t args_num = 2;
|
|
||||||
CheckArgsSize(op_name, args_spec_list, args_num);
|
|
||||||
auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
|
|
||||||
auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
|
|
||||||
ValuePtr x_value = input_x->BuildValue();
|
|
||||||
ValuePtr y_value = input_y->BuildValue();
|
|
||||||
return std::make_shared<AbstractScalar>(*x_value == *y_value);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ComputeReduceIndex(const std::vector<int64_t> &reverse_x, const std::vector<int64_t> &reverse_y,
|
void ComputeReduceIndex(const std::vector<int64_t> &reverse_x, const std::vector<int64_t> &reverse_y,
|
||||||
std::vector<int64_t> *grad_x_reduce_idx, std::vector<int64_t> *grad_y_reduce_idy) {
|
std::vector<int64_t> *grad_x_reduce_idx, std::vector<int64_t> *grad_y_reduce_idy) {
|
||||||
MS_EXCEPTION_IF_NULL(grad_x_reduce_idx);
|
MS_EXCEPTION_IF_NULL(grad_x_reduce_idx);
|
||||||
|
@ -814,16 +802,6 @@ AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const Primitive
|
||||||
return args_spec_list[0]->Clone();
|
return args_spec_list[0]->Clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
|
||||||
return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list);
|
|
||||||
}
|
|
||||||
|
|
||||||
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
|
||||||
return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list);
|
|
||||||
}
|
|
||||||
|
|
||||||
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
|
return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
|
||||||
|
@ -1089,10 +1067,8 @@ REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringGetItem, prim::kPrimStringGetItem, Infe
|
||||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr);
|
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr);
|
||||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr);
|
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr);
|
||||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr);
|
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr);
|
||||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr);
|
|
||||||
// List
|
// List
|
||||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr);
|
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr);
|
||||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr);
|
|
||||||
// Dict
|
// Dict
|
||||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr);
|
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr);
|
||||||
// Slice
|
// Slice
|
||||||
|
@ -1135,13 +1111,9 @@ void RegPrimitiveFrontEval() {
|
||||||
InferImplTupleDiv, nullptr);
|
InferImplTupleDiv, nullptr);
|
||||||
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTupleToArray,
|
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTupleToArray,
|
||||||
InferImplTuple2Array, nullptr);
|
InferImplTuple2Array, nullptr);
|
||||||
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimTupleEqual,
|
|
||||||
InferImplTupleEqual, nullptr);
|
|
||||||
// List
|
// List
|
||||||
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimListReduce,
|
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimListReduce,
|
||||||
InferImplListReduce, nullptr);
|
InferImplListReduce, nullptr);
|
||||||
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimListEqual,
|
|
||||||
InferImplListEqual, nullptr);
|
|
||||||
// Dict
|
// Dict
|
||||||
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimDictLen,
|
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimDictLen,
|
||||||
InferImplDictLen, nullptr);
|
InferImplDictLen, nullptr);
|
||||||
|
|
|
@ -36,13 +36,9 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const AbstractBasePtrList &args_spec_list);
|
|
||||||
// List
|
// List
|
||||||
AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const AbstractBasePtrList &args_spec_list);
|
|
||||||
// Dict
|
// Dict
|
||||||
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
|
|
@ -0,0 +1,110 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "plugin/device/cpu/kernel/sequence/sequence_equal_cpu_kernel.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
|
#include <complex>
|
||||||
|
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||||
|
#include "utils/ms_utils.h"
|
||||||
|
#include "include/common/thread_pool.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
namespace {
|
||||||
|
constexpr int kInputsNum = 2;
|
||||||
|
constexpr int kOutputsNum = 1;
|
||||||
|
} // namespace
|
||||||
|
bool SequenceEqualCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(base_operator);
|
||||||
|
kernel_name_ = base_operator->name();
|
||||||
|
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||||
|
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||||
|
return MatchKernelFunc(base_operator, inputs, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
int SequenceEqualCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
|
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||||
|
is_inputs_type_diff_ = false;
|
||||||
|
int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
|
||||||
|
if (ret != 0) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
CHECK_KERNEL_INPUTS_NUM(input_shapes_.size(), kInputsNum, kernel_name_);
|
||||||
|
if (input_shapes_[0].empty() || input_shapes_[1].empty()) {
|
||||||
|
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the x and y shape can't be 0, but got " << input_shapes_;
|
||||||
|
}
|
||||||
|
x_size_ = input_shapes_[0][0];
|
||||||
|
y_size_ = input_shapes_[1][0];
|
||||||
|
if (inputs[0]->GetDtype() != inputs[1]->GetDtype()) {
|
||||||
|
is_inputs_type_diff_ = true;
|
||||||
|
}
|
||||||
|
return KRET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename S>
|
||||||
|
bool SequenceEqualCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||||
|
const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs) {
|
||||||
|
const auto x_addr = GetDeviceAddress<T>(inputs, 0);
|
||||||
|
const auto y_addr = GetDeviceAddress<S>(inputs, 1);
|
||||||
|
bool *output_addr = GetDeviceAddress<bool>(outputs, 0);
|
||||||
|
if (x_size_ != y_size_ || is_inputs_type_diff_) {
|
||||||
|
*output_addr = false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < x_size_; ++i) {
|
||||||
|
if (static_cast<double>(x_addr[i]) != static_cast<double>(y_addr[i])) {
|
||||||
|
*output_addr = false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*output_addr = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define ADD_KERNEL(x_dtype, y_dtype, x_type, y_type) \
|
||||||
|
{ \
|
||||||
|
KernelAttr() \
|
||||||
|
.AddInputAttr(kObjectTypeTuple, kNumberType##x_dtype) \
|
||||||
|
.AddInputAttr(kObjectTypeTuple, kNumberType##y_dtype) \
|
||||||
|
.AddOutputAttr(kObjectTypeNumber, kNumberTypeBool), \
|
||||||
|
&SequenceEqualCpuKernelMod::LaunchKernel<x_type, y_type> \
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<std::pair<KernelAttr, SequenceEqualCpuKernelMod::KernelRunFunc>>
|
||||||
|
&SequenceEqualCpuKernelMod::GetFuncList() const {
|
||||||
|
static const std::vector<std::pair<KernelAttr, SequenceEqualCpuKernelMod::KernelRunFunc>> func_list = {
|
||||||
|
ADD_KERNEL(Float32, Float32, float, float), ADD_KERNEL(Float32, Float64, float, double),
|
||||||
|
ADD_KERNEL(Float32, Int32, float, int), ADD_KERNEL(Float32, Int64, float, int64_t),
|
||||||
|
ADD_KERNEL(Float32, Bool, float, bool), ADD_KERNEL(Float64, Float32, double, float),
|
||||||
|
ADD_KERNEL(Float64, Bool, double, bool), ADD_KERNEL(Float64, Float64, double, double),
|
||||||
|
ADD_KERNEL(Float64, Int32, double, int), ADD_KERNEL(Float64, Int64, double, int64_t),
|
||||||
|
ADD_KERNEL(Int32, Float32, int, float), ADD_KERNEL(Int32, Float64, int, double),
|
||||||
|
ADD_KERNEL(Int32, Int32, int, int), ADD_KERNEL(Int32, Int64, int, int64_t),
|
||||||
|
ADD_KERNEL(Int32, Bool, int, bool), ADD_KERNEL(Int64, Float32, int64_t, float),
|
||||||
|
ADD_KERNEL(Int64, Bool, int64_t, bool), ADD_KERNEL(Int64, Float64, int64_t, double),
|
||||||
|
ADD_KERNEL(Int64, Int32, int64_t, int), ADD_KERNEL(Int64, Int64, int64_t, int64_t),
|
||||||
|
ADD_KERNEL(Bool, Int32, bool, int), ADD_KERNEL(Bool, Int64, bool, int64_t),
|
||||||
|
ADD_KERNEL(Bool, Bool, bool, bool), ADD_KERNEL(Bool, Float64, bool, double),
|
||||||
|
ADD_KERNEL(Bool, Float32, bool, float)};
|
||||||
|
return func_list;
|
||||||
|
}
|
||||||
|
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, tuple_equal, SequenceEqualCpuKernelMod);
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,64 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_EQUAL_CPU_KERNEL_H_
|
||||||
|
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_EQUAL_CPU_KERNEL_H_
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||||
|
#include "plugin/factory/ms_factory.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
class SequenceEqualCpuKernelMod : public NativeCpuKernelMod,
|
||||||
|
public MatchKernelHelper<SequenceEqualCpuKernelMod, AddressPtr> {
|
||||||
|
public:
|
||||||
|
SequenceEqualCpuKernelMod() = default;
|
||||||
|
~SequenceEqualCpuKernelMod() override = default;
|
||||||
|
|
||||||
|
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) override;
|
||||||
|
|
||||||
|
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
|
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
|
||||||
|
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_func_);
|
||||||
|
return kernel_func_(this, inputs, workspace, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||||
|
|
||||||
|
template <typename T, typename S>
|
||||||
|
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs);
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t x_size_ = 0;
|
||||||
|
size_t y_size_ = 0;
|
||||||
|
bool is_inputs_type_diff_ = false;
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_EQUAL_CPU_KERNEL_H_
|
|
@ -85,6 +85,8 @@ constexpr auto kScalarLe = "scalar_le";
|
||||||
constexpr auto kScalarGe = "scalar_ge";
|
constexpr auto kScalarGe = "scalar_ge";
|
||||||
constexpr auto kScalarBool = "ScalarBool";
|
constexpr auto kScalarBool = "ScalarBool";
|
||||||
constexpr auto kBoolNot = "bool_not";
|
constexpr auto kBoolNot = "bool_not";
|
||||||
|
constexpr auto kTupleEqual = "tuple_equal";
|
||||||
|
constexpr auto kListEqual = "list_equal";
|
||||||
constexpr auto kScalarBitwiseAnd = "bit_and";
|
constexpr auto kScalarBitwiseAnd = "bit_and";
|
||||||
constexpr auto kScalarBitwiseOr = "bit_or";
|
constexpr auto kScalarBitwiseOr = "bit_or";
|
||||||
constexpr auto kTupleLt = "tuple_lt";
|
constexpr auto kTupleLt = "tuple_lt";
|
||||||
|
@ -1752,8 +1754,8 @@ GVAR_DEF(PrimitivePtr, kPrimReducedShape, std::make_shared<Primitive>("reduced_s
|
||||||
GVAR_DEF(PrimitivePtr, kPrimTupleDiv, std::make_shared<Primitive>("tuple_div"));
|
GVAR_DEF(PrimitivePtr, kPrimTupleDiv, std::make_shared<Primitive>("tuple_div"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimTupleToArray, std::make_shared<Primitive>("tuple_to_array"));
|
GVAR_DEF(PrimitivePtr, kPrimTupleToArray, std::make_shared<Primitive>("tuple_to_array"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimShapeMul, std::make_shared<Primitive>("shape_mul"));
|
GVAR_DEF(PrimitivePtr, kPrimShapeMul, std::make_shared<Primitive>("shape_mul"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimTupleEqual, std::make_shared<Primitive>("tuple_equal"));
|
GVAR_DEF(PrimitivePtr, kPrimTupleEqual, std::make_shared<Primitive>(kTupleEqual));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimListEqual, std::make_shared<Primitive>("list_equal"));
|
GVAR_DEF(PrimitivePtr, kPrimListEqual, std::make_shared<Primitive>(kListEqual));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimTupleGreaterThan, std::make_shared<Primitive>("tuple_greater_than"));
|
GVAR_DEF(PrimitivePtr, kPrimTupleGreaterThan, std::make_shared<Primitive>("tuple_greater_than"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimListGreaterThan, std::make_shared<Primitive>("list_greater_than"));
|
GVAR_DEF(PrimitivePtr, kPrimListGreaterThan, std::make_shared<Primitive>("list_greater_than"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimTupleGreaterEqual, std::make_shared<Primitive>("tuple_greater_equal"));
|
GVAR_DEF(PrimitivePtr, kPrimTupleGreaterEqual, std::make_shared<Primitive>("tuple_greater_equal"));
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_LIST_EQUAL_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_LIST_EQUAL_H_
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
/// \brief bool_not op is used to calculate the input true or false.
|
||||||
|
class MIND_API list_equal : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(list_equal);
|
||||||
|
/// \brief Constructor.
|
||||||
|
list_equal() : BaseOperator(prim::kListEqual) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_LIST_EQUAL_H_
|
|
@ -0,0 +1,77 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "ops/tuple_equal.h"
|
||||||
|
#include "ops/list_equal.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "include/common/utils/utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
AbstractBasePtr SequenceEqualInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
constexpr size_t input_num = 2;
|
||||||
|
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
constexpr size_t x_index = 0;
|
||||||
|
constexpr size_t y_index = 1;
|
||||||
|
auto x_abs = input_args[x_index];
|
||||||
|
auto y_abs = input_args[y_index];
|
||||||
|
if (!x_abs->isa<abstract::AbstractSequence>() && !y_abs->isa<abstract::AbstractSequence>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For primitive '" << prim_name << "', the input must be a list or tuple, "
|
||||||
|
<< "but got: " << x_abs->ToString() << " and " << y_abs->ToString();
|
||||||
|
}
|
||||||
|
auto seqx_abs = x_abs->cast<abstract::AbstractSequencePtr>();
|
||||||
|
auto seqy_abs = y_abs->cast<abstract::AbstractSequencePtr>();
|
||||||
|
if (seqx_abs->dynamic_len() || seqy_abs->dynamic_len() || seqx_abs->BuildValue() == kAnyValue ||
|
||||||
|
seqy_abs->BuildValue() == kAnyValue) {
|
||||||
|
return std::make_shared<abstract::AbstractScalar>(kAnyValue, kBool);
|
||||||
|
}
|
||||||
|
return std::make_shared<abstract::AbstractScalar>(*seqx_abs->BuildValue() == *seqy_abs->BuildValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
class SequenceEqualInfer : public abstract::OpInferBase {
|
||||||
|
public:
|
||||||
|
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
return SequenceEqualInferInner(primitive, input_args)->BuildShape();
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
return SequenceEqualInferInner(prim, input_args)->BuildType();
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
return SequenceEqualInferInner(primitive, input_args);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MIND_API_OPERATOR_IMPL(tuple_equal, BaseOperator);
|
||||||
|
MIND_API_OPERATOR_IMPL(list_equal, BaseOperator);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(tuple_equal, prim::kPrimTupleEqual, SequenceEqualInfer, false);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(list_equal, prim::kPrimListEqual, SequenceEqualInfer, false);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_TUPLE_EQUAL_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_TUPLE_EQUAL_H_
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
/// \brief bool_not op is used to calculate the input true or false.
|
||||||
|
class MIND_API tuple_equal : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(tuple_equal);
|
||||||
|
/// \brief Constructor.
|
||||||
|
tuple_equal() : BaseOperator(prim::kTupleEqual) {}
|
||||||
|
/// \brief Init.
|
||||||
|
void Init() const {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_TUPLE_EQUAL_H_
|
|
@ -86,6 +86,17 @@ def get_bprop_index(self):
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register("tuple_equal")
|
||||||
|
@bprop_getters.register("list_equal")
|
||||||
|
def get_bprop_seq_equal(self):
|
||||||
|
"""Generate bprop for tuple_equal and list_equal"""
|
||||||
|
|
||||||
|
def bprop(x, y, out, dout):
|
||||||
|
return (zeros_like(x), zeros_like(y))
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register("tuple_setitem")
|
@bprop_getters.register("tuple_setitem")
|
||||||
@bprop_getters.register("list_setitem")
|
@bprop_getters.register("list_setitem")
|
||||||
def get_bprop_setitem(self):
|
def get_bprop_setitem(self):
|
||||||
|
|
|
@ -47,6 +47,9 @@ class ScalarDiv(Primitive):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize ScalarDiv"""
|
"""Initialize ScalarDiv"""
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x / y
|
||||||
|
|
||||||
|
|
||||||
class ScalarFloordiv(Primitive):
|
class ScalarFloordiv(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
@ -79,6 +82,9 @@ class ScalarFloordiv(Primitive):
|
||||||
"""Initialize ScalarFloordiv"""
|
"""Initialize ScalarFloordiv"""
|
||||||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x // y
|
||||||
|
|
||||||
|
|
||||||
class ScalarAdd(Primitive):
|
class ScalarAdd(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
@ -105,6 +111,9 @@ class ScalarAdd(Primitive):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize ScalarAdd"""
|
"""Initialize ScalarAdd"""
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
|
||||||
class ScalarSub(Primitive):
|
class ScalarSub(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
@ -131,6 +140,9 @@ class ScalarSub(Primitive):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize ScalarSub"""
|
"""Initialize ScalarSub"""
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x - y
|
||||||
|
|
||||||
|
|
||||||
class ScalarMul(Primitive):
|
class ScalarMul(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
@ -157,6 +169,9 @@ class ScalarMul(Primitive):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize ScalarMul"""
|
"""Initialize ScalarMul"""
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x * y
|
||||||
|
|
||||||
|
|
||||||
class scalar_eq(Primitive):
|
class scalar_eq(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
@ -181,7 +196,10 @@ class scalar_eq(Primitive):
|
||||||
"""
|
"""
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize ScalarMul"""
|
"""Initialize ScalarEq"""
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x == y
|
||||||
|
|
||||||
|
|
||||||
class scalar_gt(Primitive):
|
class scalar_gt(Primitive):
|
||||||
|
@ -209,6 +227,9 @@ class scalar_gt(Primitive):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize scalar_gt"""
|
"""Initialize scalar_gt"""
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x > y
|
||||||
|
|
||||||
|
|
||||||
class scalar_lt(Primitive):
|
class scalar_lt(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
@ -235,6 +256,9 @@ class scalar_lt(Primitive):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize scalar_lt"""
|
"""Initialize scalar_lt"""
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x < y
|
||||||
|
|
||||||
|
|
||||||
class scalar_ge(Primitive):
|
class scalar_ge(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
@ -261,6 +285,9 @@ class scalar_ge(Primitive):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize scalar_ge"""
|
"""Initialize scalar_ge"""
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x >= y
|
||||||
|
|
||||||
|
|
||||||
class scalar_le(Primitive):
|
class scalar_le(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
@ -287,6 +314,9 @@ class scalar_le(Primitive):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize scalar_le"""
|
"""Initialize scalar_le"""
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x <= y
|
||||||
|
|
||||||
|
|
||||||
class ScalarMod(Primitive):
|
class ScalarMod(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
@ -317,6 +347,9 @@ class ScalarMod(Primitive):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize ScalarMod"""
|
"""Initialize ScalarMod"""
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x % y
|
||||||
|
|
||||||
|
|
||||||
class ScalarBool(Primitive):
|
class ScalarBool(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
@ -367,6 +400,9 @@ class bool_not(Primitive):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize bool_not"""
|
"""Initialize bool_not"""
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return not x
|
||||||
|
|
||||||
|
|
||||||
class bit_and(Primitive):
|
class bit_and(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
|
|
@ -454,6 +454,67 @@ class make_range(Primitive):
|
||||||
"""Initialize make_range"""
|
"""Initialize make_range"""
|
||||||
self.init_prim_io_names(inputs=['start', 'limit', 'delta'], outputs=['output_data'])
|
self.init_prim_io_names(inputs=['start', 'limit', 'delta'], outputs=['output_data'])
|
||||||
|
|
||||||
|
def __call__(self, start, limit, delta):
|
||||||
|
return range(start, limit, delta)
|
||||||
|
|
||||||
|
|
||||||
|
class tuple_equal(Primitive):
|
||||||
|
r"""
|
||||||
|
Support sequence equal operation 'equal(target)'.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
This it is only for internal used.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Union[Tuple]) - The tuple.
|
||||||
|
- **y** (Union[Tuple]) - The tuple.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Bool.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: The 'x' is not tuple.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize tuple_equal"""
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x == y
|
||||||
|
|
||||||
|
|
||||||
|
class list_equal(Primitive):
|
||||||
|
r"""
|
||||||
|
Support sequence equal operation 'equal(target)'.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
This it is only for internal used.
|
||||||
|
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Union[List]) - The list.
|
||||||
|
- **y** (Union[List]) - The list.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Bool.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: The 'x' is not list.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize list_equal"""
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x == y
|
||||||
|
|
||||||
|
|
||||||
class sequence_len(Primitive):
|
class sequence_len(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
@ -480,6 +541,9 @@ class sequence_len(Primitive):
|
||||||
"""Initialize sequence_len"""
|
"""Initialize sequence_len"""
|
||||||
self.init_prim_io_names(inputs=['sequence'], outputs=['output_data'])
|
self.init_prim_io_names(inputs=['sequence'], outputs=['output_data'])
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return len(x)
|
||||||
|
|
||||||
|
|
||||||
class SequenceMax(Primitive):
|
class SequenceMax(Primitive):
|
||||||
r"""
|
r"""
|
||||||
|
|
|
@ -0,0 +1,110 @@
|
||||||
|
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.ops.operations import _sequence_ops as _seq
|
||||||
|
from mindspore.common import mutable
|
||||||
|
from mindspore.ops.composite import GradOperation
|
||||||
|
from sequence_help import context_prepare
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
context_prepare()
|
||||||
|
|
||||||
|
|
||||||
|
class NetTupleEqual(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.seq_equal = _seq.tuple_equal()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.seq_equal(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_dyn_equal():
|
||||||
|
"""
|
||||||
|
Feature: test sequence equal op
|
||||||
|
Description: equal operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
x = mutable((1, 2, 3, 4, 5, 6), True)
|
||||||
|
y = mutable((1, 2, 3, 2, 6), True)
|
||||||
|
expect = False
|
||||||
|
net = NetTupleEqual()
|
||||||
|
res = net(x, y)
|
||||||
|
assert res == expect
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_dyn_equal1():
|
||||||
|
"""
|
||||||
|
Feature: test sequence equal op
|
||||||
|
Description: equal operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
x = mutable((1, 2, 3, 4, 5, 6), True)
|
||||||
|
y = (1, 2, 3, 4, 5, 6)
|
||||||
|
expect = True
|
||||||
|
net = NetTupleEqual()
|
||||||
|
res = net(x, y)
|
||||||
|
assert res == expect
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_equal():
|
||||||
|
"""
|
||||||
|
Feature: test sequence equal op
|
||||||
|
Description: equal operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
x = (1, 2, 3, 4, 5)
|
||||||
|
y = (True, 2, 3, 4, 5)
|
||||||
|
expect = False
|
||||||
|
net = NetTupleEqual()
|
||||||
|
res = net(x, y)
|
||||||
|
assert res == expect
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_equal_grad():
|
||||||
|
"""
|
||||||
|
Feature: test sequence equal grad op
|
||||||
|
Description: equal operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
net_ms = NetTupleEqual()
|
||||||
|
x = mutable((1, 2, 3, 4, 5, 6), True)
|
||||||
|
y = mutable((1, 2, 3, 4, 5, 6), True)
|
||||||
|
dout = True
|
||||||
|
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||||
|
print("grad out1 = ", grad_func(x, y, dout))
|
Loading…
Reference in New Issue