add TupleToTensor test case & rename some scalar ops

huoxinyou 2023-02-06 16:38:55 +08:00
parent fedc32491c
commit 5519cfd9f0
36 changed files with 287 additions and 150 deletions

View File

@@ -35,7 +35,7 @@ PrimToFunction::PrimToFunction()
{kScalarLt, kPrimTypeNumTwoArgs}, {"scalar_ne", kPrimTypeNumTwoArgs},
{kScalarMod, kPrimTypeNumTwoArgs}, {kScalarMul, kPrimTypeNumTwoArgs},
{kScalarPow, kPrimTypeNumTwoArgs}, {kScalarSub, kPrimTypeNumTwoArgs},
{kScalarFloorDiv, kPrimTypeNumTwoArgs}, {kScalarBitwiseAnd, kPrimTypeNumTwoArgs},
{kScalarFloordiv, kPrimTypeNumTwoArgs}, {kScalarBitwiseAnd, kPrimTypeNumTwoArgs},
{kScalarBitwiseOr, kPrimTypeNumTwoArgs}, {"bit_xor", kPrimTypeNumTwoArgs},
{"bit_left_shift", kPrimTypeNumTwoArgs}, {"bit_right_shift", kPrimTypeNumTwoArgs},
{kStringNot, kPrimTypeStrOneArg}, {kStringConcat, kPrimTypeStrTwoArgs},
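
For context, PrimToFunction maps each primitive name to a call prototype with a fixed arity (kPrimTypeNumTwoArgs means two numeric arguments); this hunk only swaps in the renamed kScalarFloordiv constant. A toy Python sketch of the same name-to-implementation lookup, with illustrative entries rather than MindSpore internals:

# Toy model of a primitive-name -> two-argument implementation table
PRIM_TO_FUNC = {
    "ScalarMul": lambda x, y: x * y,
    "ScalarFloordiv": lambda x, y: x // y,  # renamed spelling from this commit
    "bit_and": lambda x, y: x & y,
}

print(PRIM_TO_FUNC["ScalarFloordiv"](7, 2))  # 3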

View File

@@ -781,7 +781,7 @@ class CleanAfterOptARewriter : public BaseRewriter {
}
// From:
// ListGetItem(list, key)
// list_getitem(list, key)
// To:
// TupleGetItem(list, key)
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
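
The comment above captures the whole pass: once lists have been lowered to tuples, every list_getitem node can be rewritten as TupleGetItem. A toy sketch of that rewrite, assuming a simplified (op_name, args) node encoding instead of real AnfNode structures:

# Toy rewrite: list_getitem(list, key) -> TupleGetItem(list, key)
def convert_list_getitem(node):
    op, args = node  # hypothetical (op_name, args) node encoding
    if op == "list_getitem":
        return ("TupleGetItem", args)
    return node

print(convert_list_getitem(("list_getitem", ("lst", 0))))  # ('TupleGetItem', ('lst', 0))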

View File

@@ -178,6 +178,17 @@ REG_ASCEND_VM_OP_ADAPTATION_INFO(kBatchToSpaceOpName)
REG_ASCEND_VM_OP_ADAPTATION_INFO(kBatchToSpaceNDOpName).set_backend_op_name(kBatchToSpaceNDDOpName);
REG_ASCEND_VM_OP_ADAPTATION_INFO(kCastOpName).set_target_op_name(kCastOpName).set_input_attr_info(1, "int");
REG_ASCEND_VM_OP_ADAPTATION_INFO(kScalarToTensorOpName)
.set_target_op_name(kScalarToTensorOpName)
.set_input_attr_info(1, "int");
REG_ASCEND_VM_OP_ADAPTATION_INFO(kTupleToTensorOpName)
.set_target_op_name(kTupleToTensorOpName)
.set_input_attr_info(1, "int");
REG_ASCEND_VM_OP_ADAPTATION_INFO(kListToTensorOpName)
.set_target_op_name(kListToTensorOpName)
.set_input_attr_info(1, "int");
REG_ASCEND_VM_OP_ADAPTATION_INFO(kCentralizationOpName)
.set_target_op_name(kCentralizationOpName)
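
These registrations tell the Ascend backend to fold the second input of ScalarToTensor/TupleToTensor/ListToTensor (the dtype, carried as an int type id) into an operator attribute. From Python the dtype is still passed as an ordinary input. A minimal sketch using the public ScalarToTensor op and the internal TupleToTensor op exercised by the new test below; eager evaluation here is an assumption:

import mindspore.ops as ops
import mindspore.common.dtype as mstype
from mindspore.ops.operations import _sequence_ops as seq

scalar_to_tensor = ops.ScalarToTensor()
tuple_to_tensor = seq.TupleToTensor()
print(scalar_to_tensor(3, mstype.int64))           # Tensor(3, int64)
print(tuple_to_tensor((1, 2, 3), mstype.float32))  # Tensor([1. 2. 3.], float32)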

View File

@@ -34,13 +34,13 @@ constexpr auto kScalarAdd = "ScalarAdd";
constexpr auto kScalarSub = "ScalarSub";
constexpr auto kScalarMul = "ScalarMul";
constexpr auto kScalarDiv = "ScalarDiv";
constexpr auto kScalarFloorDiv = "ScalarFloorDiv";
constexpr auto kScalarFloordiv = "ScalarFloordiv";
constexpr auto kScalarMod = "ScalarMod";
constexpr auto kScalarGt = "ScalarGreater";
constexpr auto kScalarGe = "ScalarGreaterEqual";
constexpr auto kScalarLt = "ScalarLess";
constexpr auto kScalarLe = "ScalarLessEqual";
constexpr auto kScalarEq = "ScalarEqual";
constexpr auto kScalarGt = "scalar_gt";
constexpr auto kScalarGe = "scalar_ge";
constexpr auto kScalarLt = "scalar_lt";
constexpr auto kScalarLe = "scalar_le";
constexpr auto kScalarEq = "scalar_eq";
constexpr size_t kInputNum = 2;
constexpr size_t kInputx = 0;
constexpr size_t kInputy = 1;
@@ -194,9 +194,6 @@ bool ScalarArithmeticCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
if (kernel_name_ != kernel_type_) {
MS_LOG(EXCEPTION) << "Suppose to be " << kernel_type_ << " but got " << kernel_name_;
}
if (inputs.size() != kInputNum) {
MS_LOG(EXCEPTION) << "For kernel '" << kernel_type_ << "' input_num must be 2, but got " << inputs.size();
}
@@ -245,7 +242,7 @@ bool ScalarArithmeticCpuKernelMod::LaunchKernel(const std::vector<KernelTensorPt
{kScalarLt, LtImpl<T, S, N>},
{kScalarGe, GeImpl<T, S, N>},
{kScalarLe, LeImpl<T, S, N>},
{kScalarFloorDiv, FloorDivImpl<T, S, N>}};
{kScalarFloordiv, FloorDivImpl<T, S, N>}};
auto iter = func_map.find(kernel_name_);
if (iter == func_map.end()) {
MS_EXCEPTION(TypeError) << "For '" << kernel_name_
@@ -347,19 +344,19 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarMul,
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarMul); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarDiv,
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarDiv); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarFloorDiv,
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarFloorDiv); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarFloordiv,
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarFloordiv); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarMod,
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarMod); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarEqual,
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, scalar_eq,
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarEq); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarGreater,
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, scalar_gt,
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarGt); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarGreaterEqual,
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, scalar_ge,
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarGe); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarLess,
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, scalar_lt,
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarLt); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarLessEqual,
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, scalar_le,
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarLe); });
} // namespace kernel
} // namespace mindspore
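
With the factory registrations above, the CPU kernels are now looked up under the snake_case primitive names (scalar_eq, scalar_gt, and so on) while the arithmetic ops keep their CamelCase names. A usage sketch in the style of this commit's tests, assuming that comparison and floor-division on mutable scalars lower to these kernels in graph mode:

import mindspore as ms
from mindspore import nn, mutable

ms.set_context(mode=ms.GRAPH_MODE)

class CompareNet(nn.Cell):
    def construct(self, x, y):
        # expected to dispatch to the scalar_le and ScalarFloordiv kernels
        return x <= y, x // y

net = CompareNet()
print(net(mutable(7), mutable(2)))  # (False, 3)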

View File

@@ -26,8 +26,8 @@
namespace mindspore {
namespace kernel {
namespace {
constexpr auto kScalarBitwiseAnd = "ScalarBitwiseAnd";
constexpr auto kScalarBitwiseOr = "ScalarBitwiseOr";
constexpr auto kScalarBitwiseAnd = "bit_and";
constexpr auto kScalarBitwiseOr = "bit_or";
constexpr size_t kInputNum = 2;
constexpr size_t kInputx = 0;
constexpr size_t kInputy = 1;
@@ -54,9 +54,6 @@ bool ScalarBitwiseCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
if (kernel_name_ != kernel_type_) {
MS_LOG(EXCEPTION) << "Suppose to be " << kernel_type_ << " but got " << kernel_name_;
}
if (inputs.size() != kInputNum) {
MS_LOG(EXCEPTION) << "For kernel '" << kernel_type_ << "' input_num must be 2, but got " << inputs.size();
}
@@ -127,9 +124,9 @@ std::vector<KernelAttr> ScalarBitwiseCpuKernelMod::GetOpSupport() {
return support_list;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarBitwiseAnd,
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, bit_and,
[]() { return std::make_shared<ScalarBitwiseCpuKernelMod>(kScalarBitwiseAnd); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarBitwiseOr,
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, bit_or,
[]() { return std::make_shared<ScalarBitwiseCpuKernelMod>(kScalarBitwiseOr); });
} // namespace kernel
} // namespace mindspore
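
The bitwise kernels follow the same pattern, registering under bit_and and bit_or so the kernel name matches the Python-level primitive (note that later in this commit, `bit_and = bit_and()` rebinds the imported class name to an instance). A sketch, assuming the primitives can be invoked from a Cell on mutable scalars:

from mindspore import nn, mutable
from mindspore.ops.operations._scalar_ops import bit_and, bit_or

class BitNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.band = bit_and()
        self.bor = bit_or()

    def construct(self, x, y):
        return self.band(x, y), self.bor(x, y)

print(BitNet()(mutable(5), mutable(3)))  # expected (1, 7)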

View File

@@ -151,7 +151,6 @@ int TransposeGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
std::vector<int64_t> perm;
if (TryGetIntValue(inputs, kAxisIndex1st, kernel_name_, &perm)) {
GetPermValue(perm);
get_dynamic_perm_value_ = true;
}
if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;

View File

@@ -64,7 +64,6 @@ class TransposeGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelpe
size_t workspace_size_{0};
bool is_null_input_;
bool is_dynamic_perm_{false};
bool get_dynamic_perm_value_{false};
};
} // namespace kernel
} // namespace mindspore

View File

@@ -25,6 +25,10 @@ namespace mindspore::opt {
RER_GPU_DYNAMIC_CONST_TO_ATTR(kCastOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kFillOpName, 0);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kScalarToTensorOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kTupleToTensorOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kListToTensorOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kCastOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kCOO2CSROpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kCSR2COOOpName, 1);
@@ -35,6 +39,9 @@ RER_GPU_STATIC_CONST_TO_ATTR(kCSRMulOpName, 3);
RER_GPU_STATIC_CONST_TO_ATTR(kCSRMVOpName, 3);
RER_GPU_STATIC_CONST_TO_ATTR(kCSRReduceSumOpName, 3, 4);
RER_GPU_STATIC_CONST_TO_ATTR(kFillOpName, 0);
RER_GPU_STATIC_CONST_TO_ATTR(kScalarToTensorOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kTupleToTensorOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kListToTensorOpName, 1);
} // namespace mindspore::opt
#endif // MINDSPORE_CCSRC_PLUGIN_GPU_OPTIMIZER_REG_GPU_CONST_INPUT_TO_ATTR_H_

View File

@@ -69,21 +69,21 @@ constexpr auto kScalarAdd = "ScalarAdd";
constexpr auto kScalarSub = "ScalarSub";
constexpr auto kScalarMul = "ScalarMul";
constexpr auto kScalarDiv = "ScalarDiv";
constexpr auto kScalarFloorDiv = "ScalarFloorDiv";
constexpr auto kScalarFloordiv = "ScalarFloordiv";
constexpr auto kScalarMod = "ScalarMod";
constexpr auto kScalarPow = "ScalarPow";
constexpr auto kScalarTrunc = "ScalarTrunc";
constexpr auto kScalarFloor = "ScalarFloor";
constexpr auto kScalarUadd = "ScalarUadd";
constexpr auto kScalarUsub = "ScalarUsub";
constexpr auto kScalarEq = "ScalarEqual";
constexpr auto kScalarLt = "ScalarLess";
constexpr auto kScalarGt = "ScalarGreater";
constexpr auto kScalarLe = "ScalarLessEqual";
constexpr auto kScalarGe = "ScalarGreaterEqual";
constexpr auto kScalarEq = "scalar_eq";
constexpr auto kScalarLt = "scalar_lt";
constexpr auto kScalarGt = "scalar_gt";
constexpr auto kScalarLe = "scalar_le";
constexpr auto kScalarGe = "scalar_ge";
constexpr auto kScalarBool = "ScalarBool";
constexpr auto kScalarBitwiseAnd = "ScalarBitwiseAnd";
constexpr auto kScalarBitwiseOr = "ScalarBitwiseOr";
constexpr auto kScalarBitwiseAnd = "bit_and";
constexpr auto kScalarBitwiseOr = "bit_or";
constexpr auto kExp = "Exp";
constexpr auto kEqual = "Equal";
constexpr auto kNotEqual = "NotEqual";
@@ -181,7 +181,7 @@ constexpr auto kLogNormalReverse = "LogNormalReverse";
constexpr auto kUnstack = "Unstack";
constexpr auto kUnpack = "Unpack";
constexpr auto kTupleGetItem = "TupleGetItem";
constexpr auto kListGetItem = "ListGetItem";
constexpr auto kListGetItem = "list_getitem";
constexpr auto kSliceGetItem = "SliceGetItem";
constexpr auto kGeLU = "GeLU";
constexpr auto kUnravelIndex = "UnravelIndex";
@@ -491,7 +491,7 @@ GVAR_DEF(PrimitivePtr, kPrimScalarAdd, std::make_shared<Primitive>(kScalarAdd));
GVAR_DEF(PrimitivePtr, kPrimScalarSub, std::make_shared<Primitive>(kScalarSub));
GVAR_DEF(PrimitivePtr, kPrimScalarMul, std::make_shared<Primitive>(kScalarMul));
GVAR_DEF(PrimitivePtr, kPrimScalarDiv, std::make_shared<Primitive>(kScalarDiv));
GVAR_DEF(PrimitivePtr, kPrimScalarFloorDiv, std::make_shared<Primitive>(kScalarFloorDiv));
GVAR_DEF(PrimitivePtr, kPrimScalarFloorDiv, std::make_shared<Primitive>(kScalarFloordiv));
GVAR_DEF(PrimitivePtr, kPrimScalarMod, std::make_shared<Primitive>(kScalarMod));
GVAR_DEF(PrimitivePtr, kPrimScalarPow, std::make_shared<Primitive>(kScalarPow));
GVAR_DEF(PrimitivePtr, kPrimScalarTrunc, std::make_shared<Primitive>(kScalarTrunc));

View File

@@ -21,13 +21,13 @@
namespace mindspore {
namespace ops {
/// \brief ListGetItem op is added to the multi-output node to describe which output of the node, which is only used
/// \brief list_getitem op is added to the multi-output node to describe which output of the node, which is only used
/// in FuncGraph.
class MIND_API ListGetItem : public BaseOperator {
class MIND_API list_getitem : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ListGetItem);
MIND_API_BASE_MEMBER(list_getitem);
/// \brief Constructor.
ListGetItem() : BaseOperator(prim::kListGetItem) { InitIOName({"input", "index"}, {"output"}); }
list_getitem() : BaseOperator(prim::kListGetItem) { InitIOName({"input", "index"}, {"output"}); }
};
} // namespace ops
} // namespace mindspore

View File

@@ -224,7 +224,7 @@ MathImplFunc ChooseFunc(const std::string &prim_name) {
{prim::kScalarLt, LtImpl<T>},
{prim::kScalarGe, GeImpl<T>},
{prim::kScalarLe, LeImpl<T>},
{prim::kScalarFloorDiv, FloorDivImpl<T>}};
{prim::kScalarFloordiv, FloorDivImpl<T>}};
auto iter = infer_value_func_map.find(prim_name);
if (iter == infer_value_func_map.end()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
@@ -333,23 +333,23 @@ MIND_API_OPERATOR_IMPL(ScalarAdd, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarSub, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarMul, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarDiv, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarFloorDiv, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarFloordiv, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarMod, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarEqual, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarGreater, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarGreaterEqual, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarLess, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarLessEqual, BaseOperator);
MIND_API_OPERATOR_IMPL(scalar_eq, BaseOperator);
MIND_API_OPERATOR_IMPL(scalar_gt, BaseOperator);
MIND_API_OPERATOR_IMPL(scalar_ge, BaseOperator);
MIND_API_OPERATOR_IMPL(scalar_lt, BaseOperator);
MIND_API_OPERATOR_IMPL(scalar_le, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarAdd, prim::kPrimScalarAdd, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarSub, prim::kPrimScalarSub, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarMul, prim::kPrimScalarMul, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarDiv, prim::kPrimScalarDiv, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarFloorDiv, prim::kPrimScalarFloorDiv, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarFloordiv, prim::kPrimScalarFloorDiv, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarMod, prim::kPrimScalarMod, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarEqual, prim::kPrimScalarEq, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarGreater, prim::kPrimScalarGt, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarGreaterEqual, prim::kPrimScalarGe, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarLess, prim::kPrimScalarLt, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarLessEqual, prim::kPrimScalarLe, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(scalar_eq, prim::kPrimScalarEq, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(scalar_gt, prim::kPrimScalarGt, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(scalar_ge, prim::kPrimScalarGe, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(scalar_lt, prim::kPrimScalarLt, ScalarArithmeticInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(scalar_le, prim::kPrimScalarLe, ScalarArithmeticInfer, true);
} // namespace ops
} // namespace mindspore

View File

@@ -115,9 +115,9 @@ class ScalarBitwiseInfer : public abstract::OpInferBase {
return res;
}
};
MIND_API_OPERATOR_IMPL(ScalarBitwiseOr, BaseOperator);
MIND_API_OPERATOR_IMPL(ScalarBitwiseAnd, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarBitwiseOr, prim::kPrimScalarBitwiseOr, ScalarBitwiseInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarBitwiseAnd, prim::kPrimScalarBitwiseAnd, ScalarBitwiseInfer, true);
MIND_API_OPERATOR_IMPL(bit_or, BaseOperator);
MIND_API_OPERATOR_IMPL(bit_and, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(bit_or, prim::kPrimScalarBitwiseOr, ScalarBitwiseInfer, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(bit_and, prim::kPrimScalarBitwiseAnd, ScalarBitwiseInfer, true);
} // namespace ops
} // namespace mindspore

View File

@@ -22,11 +22,11 @@
namespace mindspore {
namespace ops {
/// \brief
class MIND_API ScalarBitwiseAnd : public BaseOperator {
class MIND_API bit_and : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarBitwiseAnd);
MIND_API_BASE_MEMBER(bit_and);
/// \brief Constructor.
ScalarBitwiseAnd() : BaseOperator(prim::kScalarBitwiseAnd) {}
bit_and() : BaseOperator(prim::kScalarBitwiseAnd) {}
/// \brief Init.
void Init() const {}
};

View File

@@ -22,11 +22,11 @@
namespace mindspore {
namespace ops {
/// \brief
class MIND_API ScalarBitwiseOr : public BaseOperator {
class MIND_API bit_or : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarBitwiseOr);
MIND_API_BASE_MEMBER(bit_or);
/// \brief Constructor.
ScalarBitwiseOr() : BaseOperator(prim::kScalarBitwiseOr) {}
bit_or() : BaseOperator(prim::kScalarBitwiseOr) {}
/// \brief Init.
void Init() const {}
};

View File

@@ -21,12 +21,12 @@
namespace mindspore {
namespace ops {
/// \brief ScalarEqual op is used to judge equal between variable scalar.
class MIND_API ScalarEqual : public BaseOperator {
/// \brief scalar_eq op is used to judge equal between variable scalar.
class MIND_API scalar_eq : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarEqual);
MIND_API_BASE_MEMBER(scalar_eq);
/// \brief Constructor.
ScalarEqual() : BaseOperator(prim::kScalarEq) {}
scalar_eq() : BaseOperator(prim::kScalarEq) {}
/// \brief Init.
void Init() const {}
};

View File

@@ -22,11 +22,11 @@
namespace mindspore {
namespace ops {
/// \brief ScalarFloorDiv op is used to div between variable scalar.
class MIND_API ScalarFloorDiv : public BaseOperator {
class MIND_API ScalarFloordiv : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarFloorDiv);
MIND_API_BASE_MEMBER(ScalarFloordiv);
/// \brief Constructor.
ScalarFloorDiv() : BaseOperator(prim::kScalarFloorDiv) {}
ScalarFloordiv() : BaseOperator(prim::kScalarFloordiv) {}
/// \brief Init.
void Init() const {}
};

View File

@@ -21,12 +21,12 @@
namespace mindspore {
namespace ops {
/// \brief ScalarGreaterEqual op is used to judge greaterEqual between variable scalar.
class MIND_API ScalarGreaterEqual : public BaseOperator {
/// \brief scalar_ge op is used to judge greaterEqual between variable scalar.
class MIND_API scalar_ge : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarGreaterEqual);
MIND_API_BASE_MEMBER(scalar_ge);
/// \brief Constructor.
ScalarGreaterEqual() : BaseOperator(prim::kScalarGe) {}
scalar_ge() : BaseOperator(prim::kScalarGe) {}
/// \brief Init.
void Init() const {}
};

View File

@@ -21,12 +21,12 @@
namespace mindspore {
namespace ops {
/// \brief ScalarGreater op is used to judge greater between variable scalar.
class MIND_API ScalarGreater : public BaseOperator {
/// \brief scalar_gt op is used to judge greater between variable scalar.
class MIND_API scalar_gt : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarGreater);
MIND_API_BASE_MEMBER(scalar_gt);
/// \brief Constructor.
ScalarGreater() : BaseOperator(prim::kScalarGt) {}
scalar_gt() : BaseOperator(prim::kScalarGt) {}
/// \brief Init.
void Init() const {}
};

View File

@@ -21,12 +21,12 @@
namespace mindspore {
namespace ops {
/// \brief ScalarLessEqual op is used to judge lessEqual between variable scalar.
class MIND_API ScalarLessEqual : public BaseOperator {
/// \brief scalar_le op is used to judge lessEqual between variable scalar.
class MIND_API scalar_le : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarLessEqual);
MIND_API_BASE_MEMBER(scalar_le);
/// \brief Constructor.
ScalarLessEqual() : BaseOperator(prim::kScalarLe) {}
scalar_le() : BaseOperator(prim::kScalarLe) {}
/// \brief Init.
void Init() const {}
};

View File

@@ -21,12 +21,12 @@
namespace mindspore {
namespace ops {
/// \brief ScalarLess op is used to judge less between variable scalar.
class MIND_API ScalarLess : public BaseOperator {
/// \brief scalar_lt op is used to judge less between variable scalar.
class MIND_API scalar_lt : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScalarLess);
MIND_API_BASE_MEMBER(scalar_lt);
/// \brief Constructor.
ScalarLess() : BaseOperator(prim::kScalarLt) {}
scalar_lt() : BaseOperator(prim::kScalarLt) {}
/// \brief Init.
void Init() const {}
};

View File

@@ -102,9 +102,9 @@ class SequenceGetItemInfer : public abstract::OpInferBase {
};
MIND_API_OPERATOR_IMPL(TupleGetItem, BaseOperator);
MIND_API_OPERATOR_IMPL(RealTupleGetItem, BaseOperator);
MIND_API_OPERATOR_IMPL(ListGetItem, BaseOperator);
MIND_API_OPERATOR_IMPL(list_getitem, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(TupleGetItem, prim::kPrimTupleGetItem, SequenceGetItemInfer, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(RealTupleGetItem, prim::kPrimRealTupleGetItem, SequenceGetItemInfer, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(ListGetItem, prim::kPrimListGetItem, SequenceGetItemInfer, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(list_getitem, prim::kPrimListGetItem, SequenceGetItemInfer, false);
} // namespace ops
} // namespace mindspore

View File

@@ -23,7 +23,7 @@
namespace mindspore {
// clang-format off
#ifndef ENABLE_SECURITY
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "ListGetItem",
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
"make_dict", "make_slice", "string_eq", "VirtualLoss", "Return", "env_getitem",
@@ -33,7 +33,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
"InvertPermutation", "DropoutGenMask", "StatelessDropOutGenMask", "embed", "create_instance", "RefToEmbed",
"StopGradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
#else
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "ListGetItem",
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
"make_dict", "make_slice", "string_eq", "VirtualLoss", "Return", "env_getitem",

View File

@@ -65,7 +65,7 @@ def TupleGetItem(x, index):
return x[index]
def ScalarGreater(x, y):
def scalar_gt(x, y):
"""Implement `scalar_gt`."""
return x > y
@@ -75,17 +75,17 @@ def scalar_ne(x, y):
return x != y
def ScalarEqual(x, y):
def scalar_eq(x, y):
"""Implement `scalar_eq`."""
return x == y
def ScalarLessEqual(x, y):
def scalar_le(x, y):
"""Implement `scalar_le`."""
return x <= y
def ScalarLess(x, y):
def scalar_lt(x, y):
"""Implement `scalar_lt`."""
return x < y
@@ -117,7 +117,7 @@ def Switch(c, x, y):
return x if c else y
def ListGetItem(data, item):
def list_getitem(data, item):
"""Implement `list_getitem`."""
return data[item]

View File

@@ -72,7 +72,7 @@ def bprop_tuple_getitem(data, idx, out, dout):
return F.tuple_setitem(zeros_like(data), idx, dout), zeros_like(idx)
@bprops.register("ListGetItem")
@bprops.register("list_getitem")
def bprop_list_getitem(data, idx, out, dout):
"""Backpropagator for primitive `list_getitem`."""
return F.list_setitem(zeros_like(data), idx, dout), zeros_like(idx)
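
The registered backpropagator scatters the incoming gradient into the accessed slot and leaves zeros elsewhere. A plain-Python analogue of what list_setitem(zeros_like(data), idx, dout) computes:

# Gradient of list_getitem: dout lands at position idx, zeros elsewhere
def bprop_list_getitem_toy(data, idx, dout):
    grad_data = [0] * len(data)
    grad_data[idx] = dout
    return grad_data, 0  # the integer index gets a zero gradient

print(bprop_list_getitem_toy([1.0, 2.0, 3.0], 1, 5.0))  # ([0, 5.0, 0], 0)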

View File

@@ -64,7 +64,7 @@ def get_bprop_scalar_div(self):
return bprop
@bprop_getters.register(_scalar_ops.ScalarFloorDiv)
@bprop_getters.register(_scalar_ops.ScalarFloordiv)
def get_bprop_scalar_floordiv(self):
"""Grad definition for `ScalarFloorDiv` operation."""
@@ -86,13 +86,13 @@ def get_bprop_scalar_mod(self):
return bprop
@bprop_getters.register(_scalar_ops.ScalarEqual)
@bprop_getters.register(_scalar_ops.ScalarLessEqual)
@bprop_getters.register(_scalar_ops.ScalarLess)
@bprop_getters.register(_scalar_ops.ScalarGreaterEqual)
@bprop_getters.register(_scalar_ops.ScalarGreater)
@bprop_getters.register(_scalar_ops.ScalarBitwiseAnd)
@bprop_getters.register(_scalar_ops.ScalarBitwiseOr)
@bprop_getters.register(_scalar_ops.scalar_eq)
@bprop_getters.register(_scalar_ops.scalar_le)
@bprop_getters.register(_scalar_ops.scalar_lt)
@bprop_getters.register(_scalar_ops.scalar_ge)
@bprop_getters.register(_scalar_ops.scalar_gt)
@bprop_getters.register(_scalar_ops.bit_and)
@bprop_getters.register(_scalar_ops.bit_or)
def get_bprop_scalar_logic(self):
"""Grad definition for `ScalarLogicOps` operation."""

View File

@@ -50,7 +50,7 @@ def get_identity_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register("ListGetItem")
@vmap_rules_getters.register("list_getitem")
@vmap_rules_getters.register("TupleGetItem")
def get_seq_get_item_vmap_rule(prim, axis_size):
"""VmapRule for `list_getitem` or `TupleGetItem` operation."""

View File

@@ -73,7 +73,7 @@ class _ListSlice(base.SequenceSliceGetItem_):
def __init__(self, name):
"""Initialize _TupleSlice."""
base.SequenceSliceGetItem_.__init__(self, name, "make_list", "ListGetItem")
base.SequenceSliceGetItem_.__init__(self, name, "make_list", "list_getitem")
def __call__(self, *args):
pass

View File

@@ -61,16 +61,16 @@ scalar_mod = _scalar_ops.ScalarMod()
scalar_add = _scalar_ops.ScalarAdd()
scalar_mul = _scalar_ops.ScalarMul()
scalar_sub = _scalar_ops.ScalarSub()
scalar_gt = _scalar_ops.ScalarGreater()
scalar_ge = _scalar_ops.ScalarGreaterEqual()
scalar_le = _scalar_ops.ScalarLessEqual()
scalar_lt = _scalar_ops.ScalarLess()
scalar_eq = _scalar_ops.ScalarEqual()
scalar_floordiv = _scalar_ops.ScalarFloorDiv()
scalar_gt = _scalar_ops.scalar_gt()
scalar_ge = _scalar_ops.scalar_ge()
scalar_le = _scalar_ops.scalar_le()
scalar_lt = _scalar_ops.scalar_lt()
scalar_eq = _scalar_ops.scalar_eq()
scalar_floordiv = _scalar_ops.ScalarFloordiv()
tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive(_constants.kTupleGetItem)
list_getitem = Primitive('ListGetItem')
list_getitem = Primitive('list_getitem')
list_setitem = Primitive('list_setitem')
dict_getitem = Primitive('dict_getitem')
dict_setitem = Primitive('dict_setitem')

View File

@@ -21,7 +21,7 @@ import numpy as np
from mindspore.common import Tensor
from mindspore.ops import composite as C
from mindspore.ops.operations.array_ops import Cast
from mindspore.ops.operations._scalar_ops import ScalarBitwiseOr, ScalarBitwiseAnd
from mindspore.ops.operations._scalar_ops import bit_or, bit_and
from mindspore.ops import signature as sig
from mindspore.ops.operations.math_ops import _infer_shape_reduce
from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, _run_op
@@ -36,8 +36,8 @@ from mindspore.common._register_for_adapter import ms_adapter_registry
# Bit operation
bit_and = ScalarBitwiseAnd()
bit_or = ScalarBitwiseOr()
bit_and = bit_and()
bit_or = bit_or()
bit_xor = Primitive("bit_xor")
bit_left_shift = Primitive("bit_left_shift")
bit_right_shift = Primitive("bit_right_shift")

View File

@@ -48,7 +48,7 @@ class ScalarDiv(Primitive):
"""Initialize ScalarDiv"""
class ScalarFloorDiv(Primitive):
class ScalarFloordiv(Primitive):
r"""
Computes the quotient of dividing the first input scalar by the second input scalar element-wise.
@@ -76,7 +76,7 @@ class ScalarFloorDiv(Primitive):
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarFloorDiv"""
"""Initialize ScalarFloordiv"""
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
@@ -158,7 +158,7 @@ class ScalarMul(Primitive):
"""Initialize ScalarMul"""
class ScalarEqual(Primitive):
class scalar_eq(Primitive):
r"""
Computes the equivalence between two Scalars.
@@ -184,7 +184,7 @@ class ScalarEqual(Primitive):
"""Initialize ScalarMul"""
class ScalarGreater(Primitive):
class scalar_gt(Primitive):
r"""
Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
@@ -207,10 +207,10 @@ class ScalarGreater(Primitive):
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarGreater"""
"""Initialize scalar_gt"""
class ScalarLess(Primitive):
class scalar_lt(Primitive):
r"""
Computes the boolean value of :math:`x < y`.
@@ -233,10 +233,10 @@ class ScalarLess(Primitive):
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarLess"""
"""Initialize scalar_lt"""
class ScalarGreaterEqual(Primitive):
class scalar_ge(Primitive):
r"""
Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
@@ -259,10 +259,10 @@ class ScalarGreaterEqual(Primitive):
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarGreaterEqual"""
"""Initialize scalar_ge"""
class ScalarLessEqual(Primitive):
class scalar_le(Primitive):
r"""
Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
@@ -285,7 +285,7 @@ class ScalarLessEqual(Primitive):
"""
@prim_attr_register
def __init__(self):
"""Initialize ScalarLessEqual"""
"""Initialize scalar_le"""
class ScalarMod(Primitive):
@@ -343,7 +343,7 @@ class ScalarBool(Primitive):
"""Initialize ScalarBool"""
class ScalarBitwiseAnd(Primitive):
class bit_and(Primitive):
r"""
Returns bitwise `and` of two scalars.
@ -373,7 +373,7 @@ class ScalarBitwiseAnd(Primitive):
"""Initialize ScalarMod"""
class ScalarBitwiseOr(Primitive):
class bit_or(Primitive):
r"""
Returns bitwise `or` of two scalars.

View File

@@ -20,8 +20,9 @@ from tuple_help import TupleFactory
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scalar_add():
"""
@@ -45,8 +46,9 @@ def test_scalar_add():
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scalar_sub():
"""
@@ -71,8 +73,9 @@ def test_scalar_sub():
@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scalar_mul():
"""
@@ -98,6 +101,7 @@ def test_scalar_mul():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scalar_div():
"""
@@ -121,8 +125,9 @@ def test_scalar_div():
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scalar_mod():
"""
@@ -146,8 +151,9 @@ def test_scalar_mod():
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scalar_floordiv():
"""
@@ -171,12 +177,13 @@ def test_scalar_floordiv():
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scalar_eq():
"""
Feature: test ScalarEqual.
Feature: test scalar_eq.
Description: inputs is dynamic scalar.
Expectation: the result match with numpy result
"""
@@ -196,12 +203,13 @@ def test_scalar_eq():
fact.grad_impl()
@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scalar_ge():
"""
Feature: test ScalarGreaterEqual.
Feature: test scalar_ge.
Description: inputs is dynamic scalar.
Expectation: the result match with numpy result
"""
@@ -223,10 +231,11 @@ def test_scalar_ge():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scalar_gt():
"""
Feature: test ScalarGreater.
Feature: test scalar_gt.
Description: inputs is dynamic scalar.
Expectation: the result match with numpy result
"""
@@ -248,10 +257,11 @@ def test_scalar_gt():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scalar_le():
"""
Feature: test ScalarLessEqual.
Feature: test scalar_le.
Description: inputs is dynamic scalar.
Expectation: the result match with numpy result
"""
@@ -273,10 +283,11 @@ def test_scalar_le():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scalar_lt():
"""
Feature: test ScalarLess.
Feature: test scalar_lt.
Description: inputs is dynamic scalar.
Expectation: the result match with numpy result
"""

View File

@@ -0,0 +1,60 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops.operations import _sequence_ops as seq
import mindspore.ops as ops
from mindspore import context
from mindspore.common import mutable
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.tuple_to_tensor = seq.TupleToTensor()
        self.scalar_to_tensor = ops.ScalarToTensor()

    def construct(self, x, y):
        return self.tuple_to_tensor(x, mstype.float32), self.scalar_to_tensor(y, mstype.int64)


def dyn_case():
    x = mutable((1, 2, 3), True)
    y = mutable(3)
    expect_x = np.array([1, 2, 3], dtype=np.float32)
    expect_y = np.array(3, dtype=np.int64)
    net = Net()
    res_x, res_y = net(x, y)
    rtol = 1.e-4
    atol = 1.e-4
    assert np.allclose(res_x.asnumpy(), expect_x, rtol, atol, equal_nan=True)
    assert np.allclose(res_y.asnumpy(), expect_y, rtol, atol, equal_nan=True)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seq_to_tensor():
    """
    Feature: test xxToTensor.
    Description: inputs is dynamic sequence or scalar.
    Expectation: the result match with numpy result
    """
    dyn_case()
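
For reference, the second argument to mutable in dyn_case marks the tuple as dynamic-length, which is what pushes TupleToTensor down its dynamic-input path (a reading of this test's usage, not a documented guarantee):

from mindspore import mutable

x = mutable((1, 2, 3), True)  # True: treat the tuple as a dynamic-length sequence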

View File

@@ -0,0 +1,56 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops.operations import _sequence_ops as seq
from mindspore import context
from mindspore import Tensor
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
    def __init__(self, axis=0):
        super().__init__()
        self.tensor_to_tuple = seq.TensorToTuple()
        self.tensor_to_scalar = seq.TensorToScalar()

    def construct(self, x, y):
        return self.tensor_to_tuple(x), self.tensor_to_scalar(y)


def dyn_case():
    x = Tensor([1, 2, 3], mstype.int64)
    y = Tensor(1, mstype.float32)
    expect_x = (1, 2, 3)
    expect_y = 1.0
    net = Net()
    res_x, res_y = net(x, y)
    assert expect_x == res_x
    assert expect_y == res_y


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seq_to_tensor():
    """
    Feature: test TensorToxx.
    Description: inputs is dynamic sequence or scalar.
    Expectation: the result match with numpy result
    """
    dyn_case()

View File

@@ -263,7 +263,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
/// Description: The second input is a scalar
/// Expectation: Throw type error
TEST_F(TestComposite, test_ListSlice_arg_one_number) {
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "ListGetItem");
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
FuncGraphPtr list_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 3);
AbstractBasePtrList eles;
@@ -298,7 +298,7 @@ TEST_F(TestComposite, test_ListSlice_arg_one_number) {
/// Expectation: No Expectation
TEST_F(TestComposite, test_ListSlice_arg_slice) {
std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "ListGetItem");
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
AbstractBasePtrList eles;
@@ -327,7 +327,7 @@ TEST_F(TestComposite, test_ListSlice_arg_slice) {
/// Description: Test List slice the step is none
/// Expectation: No Expectation
TEST_F(TestComposite, test_ListSlice_arg_slice_step_none) {
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "ListGetItem");
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
AbstractBasePtrList eles;
@@ -356,7 +356,7 @@ TEST_F(TestComposite, test_ListSlice_arg_slice_step_none) {
/// Description: Test List slice the step is negative
/// Expectation: No Expectation
TEST_F(TestComposite, test_ListSlice_arg_slice_step_negative) {
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "ListGetItem");
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
AbstractBasePtrList eles;
@@ -385,7 +385,7 @@ TEST_F(TestComposite, test_ListSlice_arg_slice_step_negative) {
/// Description: Test List slice the step is positive
/// Expectation: No Expectation
TEST_F(TestComposite, test_ListSlice_arg_slice_step_positive) {
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "ListGetItem");
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
AbstractBasePtrList eles;

View File

@@ -189,7 +189,7 @@ TEST_F(TestOps, TupleGetItemTest) {
}
TEST_F(TestOps, ListGetItemTest) {
auto prim = std::make_shared<Primitive>("ListGetItem");
auto prim = std::make_shared<Primitive>("list_getitem");
ASSERT_EQ(prim->name(), kPrimListGetItem->name());
}

View File

@@ -682,7 +682,7 @@ TEST_F(TestPrim, test_list_getitem) {
args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2);
auto prim = std::make_shared<Primitive>("ListGetItem");
auto prim = std::make_shared<Primitive>("list_getitem");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();