forked from mindspore-Ecosystem/mindspore
add tupletotensor test case & modify some scalar ops name
This commit is contained in:
parent
fedc32491c
commit
5519cfd9f0
|
@ -35,7 +35,7 @@ PrimToFunction::PrimToFunction()
|
|||
{kScalarLt, kPrimTypeNumTwoArgs}, {"scalar_ne", kPrimTypeNumTwoArgs},
|
||||
{kScalarMod, kPrimTypeNumTwoArgs}, {kScalarMul, kPrimTypeNumTwoArgs},
|
||||
{kScalarPow, kPrimTypeNumTwoArgs}, {kScalarSub, kPrimTypeNumTwoArgs},
|
||||
{kScalarFloorDiv, kPrimTypeNumTwoArgs}, {kScalarBitwiseAnd, kPrimTypeNumTwoArgs},
|
||||
{kScalarFloordiv, kPrimTypeNumTwoArgs}, {kScalarBitwiseAnd, kPrimTypeNumTwoArgs},
|
||||
{kScalarBitwiseOr, kPrimTypeNumTwoArgs}, {"bit_xor", kPrimTypeNumTwoArgs},
|
||||
{"bit_left_shift", kPrimTypeNumTwoArgs}, {"bit_right_shift", kPrimTypeNumTwoArgs},
|
||||
{kStringNot, kPrimTypeStrOneArg}, {kStringConcat, kPrimTypeStrTwoArgs},
|
||||
|
|
|
@ -781,7 +781,7 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
|||
}
|
||||
|
||||
// From:
|
||||
// ListGetItem(list, key)
|
||||
// list_getitem(list, key)
|
||||
// To:
|
||||
// TupleGetItem(list, key)
|
||||
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
|
||||
|
|
|
@ -178,6 +178,17 @@ REG_ASCEND_VM_OP_ADAPTATION_INFO(kBatchToSpaceOpName)
|
|||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kBatchToSpaceNDOpName).set_backend_op_name(kBatchToSpaceNDDOpName);
|
||||
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kCastOpName).set_target_op_name(kCastOpName).set_input_attr_info(1, "int");
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kScalarToTensorOpName)
|
||||
.set_target_op_name(kScalarToTensorOpName)
|
||||
.set_input_attr_info(1, "int");
|
||||
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kTupleToTensorOpName)
|
||||
.set_target_op_name(kTupleToTensorOpName)
|
||||
.set_input_attr_info(1, "int");
|
||||
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kListToTensorOpName)
|
||||
.set_target_op_name(kListToTensorOpName)
|
||||
.set_input_attr_info(1, "int");
|
||||
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kCentralizationOpName)
|
||||
.set_target_op_name(kCentralizationOpName)
|
||||
|
|
|
@ -34,13 +34,13 @@ constexpr auto kScalarAdd = "ScalarAdd";
|
|||
constexpr auto kScalarSub = "ScalarSub";
|
||||
constexpr auto kScalarMul = "ScalarMul";
|
||||
constexpr auto kScalarDiv = "ScalarDiv";
|
||||
constexpr auto kScalarFloorDiv = "ScalarFloorDiv";
|
||||
constexpr auto kScalarFloordiv = "ScalarFloordiv";
|
||||
constexpr auto kScalarMod = "ScalarMod";
|
||||
constexpr auto kScalarGt = "ScalarGreater";
|
||||
constexpr auto kScalarGe = "ScalarGreaterEqual";
|
||||
constexpr auto kScalarLt = "ScalarLess";
|
||||
constexpr auto kScalarLe = "ScalarLessEqual";
|
||||
constexpr auto kScalarEq = "ScalarEqual";
|
||||
constexpr auto kScalarGt = "scalar_gt";
|
||||
constexpr auto kScalarGe = "scalar_ge";
|
||||
constexpr auto kScalarLt = "scalar_lt";
|
||||
constexpr auto kScalarLe = "scalar_le";
|
||||
constexpr auto kScalarEq = "scalar_eq";
|
||||
constexpr size_t kInputNum = 2;
|
||||
constexpr size_t kInputx = 0;
|
||||
constexpr size_t kInputy = 1;
|
||||
|
@ -194,9 +194,6 @@ bool ScalarArithmeticCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
|||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
if (kernel_name_ != kernel_type_) {
|
||||
MS_LOG(EXCEPTION) << "Suppose to be " << kernel_type_ << " but got " << kernel_name_;
|
||||
}
|
||||
if (inputs.size() != kInputNum) {
|
||||
MS_LOG(EXCEPTION) << "For kernel '" << kernel_type_ << "' input_num must be 2, but got " << inputs.size();
|
||||
}
|
||||
|
@ -245,7 +242,7 @@ bool ScalarArithmeticCpuKernelMod::LaunchKernel(const std::vector<KernelTensorPt
|
|||
{kScalarLt, LtImpl<T, S, N>},
|
||||
{kScalarGe, GeImpl<T, S, N>},
|
||||
{kScalarLe, LeImpl<T, S, N>},
|
||||
{kScalarFloorDiv, FloorDivImpl<T, S, N>}};
|
||||
{kScalarFloordiv, FloorDivImpl<T, S, N>}};
|
||||
auto iter = func_map.find(kernel_name_);
|
||||
if (iter == func_map.end()) {
|
||||
MS_EXCEPTION(TypeError) << "For '" << kernel_name_
|
||||
|
@ -347,19 +344,19 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarMul,
|
|||
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarMul); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarDiv,
|
||||
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarDiv); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarFloorDiv,
|
||||
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarFloorDiv); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarFloordiv,
|
||||
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarFloordiv); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarMod,
|
||||
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarMod); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarEqual,
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, scalar_eq,
|
||||
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarEq); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarGreater,
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, scalar_gt,
|
||||
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarGt); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarGreaterEqual,
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, scalar_ge,
|
||||
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarGe); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarLess,
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, scalar_lt,
|
||||
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarLt); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarLessEqual,
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, scalar_le,
|
||||
[]() { return std::make_shared<ScalarArithmeticCpuKernelMod>(kScalarLe); });
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,8 +26,8 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr auto kScalarBitwiseAnd = "ScalarBitwiseAnd";
|
||||
constexpr auto kScalarBitwiseOr = "ScalarBitwiseOr";
|
||||
constexpr auto kScalarBitwiseAnd = "bit_and";
|
||||
constexpr auto kScalarBitwiseOr = "bit_or";
|
||||
constexpr size_t kInputNum = 2;
|
||||
constexpr size_t kInputx = 0;
|
||||
constexpr size_t kInputy = 1;
|
||||
|
@ -54,9 +54,6 @@ bool ScalarBitwiseCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
|
|||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
if (kernel_name_ != kernel_type_) {
|
||||
MS_LOG(EXCEPTION) << "Suppose to be " << kernel_type_ << " but got " << kernel_name_;
|
||||
}
|
||||
if (inputs.size() != kInputNum) {
|
||||
MS_LOG(EXCEPTION) << "For kernel '" << kernel_type_ << "' input_num must be 2, but got " << inputs.size();
|
||||
}
|
||||
|
@ -127,9 +124,9 @@ std::vector<KernelAttr> ScalarBitwiseCpuKernelMod::GetOpSupport() {
|
|||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarBitwiseAnd,
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, bit_and,
|
||||
[]() { return std::make_shared<ScalarBitwiseCpuKernelMod>(kScalarBitwiseAnd); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScalarBitwiseOr,
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, bit_or,
|
||||
[]() { return std::make_shared<ScalarBitwiseCpuKernelMod>(kScalarBitwiseOr); });
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -151,7 +151,6 @@ int TransposeGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
|
|||
std::vector<int64_t> perm;
|
||||
if (TryGetIntValue(inputs, kAxisIndex1st, kernel_name_, &perm)) {
|
||||
GetPermValue(perm);
|
||||
get_dynamic_perm_value_ = true;
|
||||
}
|
||||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
|
|
|
@ -64,7 +64,6 @@ class TransposeGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelpe
|
|||
size_t workspace_size_{0};
|
||||
bool is_null_input_;
|
||||
bool is_dynamic_perm_{false};
|
||||
bool get_dynamic_perm_value_{false};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,6 +25,10 @@ namespace mindspore::opt {
|
|||
|
||||
RER_GPU_DYNAMIC_CONST_TO_ATTR(kCastOpName, 1);
|
||||
RER_GPU_DYNAMIC_CONST_TO_ATTR(kFillOpName, 0);
|
||||
RER_GPU_DYNAMIC_CONST_TO_ATTR(kScalarToTensorOpName, 1);
|
||||
RER_GPU_DYNAMIC_CONST_TO_ATTR(kTupleToTensorOpName, 1);
|
||||
RER_GPU_DYNAMIC_CONST_TO_ATTR(kListToTensorOpName, 1);
|
||||
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kCastOpName, 1);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kCOO2CSROpName, 1);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kCSR2COOOpName, 1);
|
||||
|
@ -35,6 +39,9 @@ RER_GPU_STATIC_CONST_TO_ATTR(kCSRMulOpName, 3);
|
|||
RER_GPU_STATIC_CONST_TO_ATTR(kCSRMVOpName, 3);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kCSRReduceSumOpName, 3, 4);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kFillOpName, 0);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kScalarToTensorOpName, 1);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kTupleToTensorOpName, 1);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kListToTensorOpName, 1);
|
||||
} // namespace mindspore::opt
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_GPU_OPTIMIZER_REG_GPU_CONST_INPUT_TO_ATTR_H_
|
||||
|
|
|
@ -69,21 +69,21 @@ constexpr auto kScalarAdd = "ScalarAdd";
|
|||
constexpr auto kScalarSub = "ScalarSub";
|
||||
constexpr auto kScalarMul = "ScalarMul";
|
||||
constexpr auto kScalarDiv = "ScalarDiv";
|
||||
constexpr auto kScalarFloorDiv = "ScalarFloorDiv";
|
||||
constexpr auto kScalarFloordiv = "ScalarFloordiv";
|
||||
constexpr auto kScalarMod = "ScalarMod";
|
||||
constexpr auto kScalarPow = "ScalarPow";
|
||||
constexpr auto kScalarTrunc = "ScalarTrunc";
|
||||
constexpr auto kScalarFloor = "ScalarFloor";
|
||||
constexpr auto kScalarUadd = "ScalarUadd";
|
||||
constexpr auto kScalarUsub = "ScalarUsub";
|
||||
constexpr auto kScalarEq = "ScalarEqual";
|
||||
constexpr auto kScalarLt = "ScalarLess";
|
||||
constexpr auto kScalarGt = "ScalarGreater";
|
||||
constexpr auto kScalarLe = "ScalarLessEqual";
|
||||
constexpr auto kScalarGe = "ScalarGreaterEqual";
|
||||
constexpr auto kScalarEq = "scalar_eq";
|
||||
constexpr auto kScalarLt = "scalar_lt";
|
||||
constexpr auto kScalarGt = "scalar_gt";
|
||||
constexpr auto kScalarLe = "scalar_le";
|
||||
constexpr auto kScalarGe = "scalar_ge";
|
||||
constexpr auto kScalarBool = "ScalarBool";
|
||||
constexpr auto kScalarBitwiseAnd = "ScalarBitwiseAnd";
|
||||
constexpr auto kScalarBitwiseOr = "ScalarBitwiseOr";
|
||||
constexpr auto kScalarBitwiseAnd = "bit_and";
|
||||
constexpr auto kScalarBitwiseOr = "bit_or";
|
||||
constexpr auto kExp = "Exp";
|
||||
constexpr auto kEqual = "Equal";
|
||||
constexpr auto kNotEqual = "NotEqual";
|
||||
|
@ -181,7 +181,7 @@ constexpr auto kLogNormalReverse = "LogNormalReverse";
|
|||
constexpr auto kUnstack = "Unstack";
|
||||
constexpr auto kUnpack = "Unpack";
|
||||
constexpr auto kTupleGetItem = "TupleGetItem";
|
||||
constexpr auto kListGetItem = "ListGetItem";
|
||||
constexpr auto kListGetItem = "list_getitem";
|
||||
constexpr auto kSliceGetItem = "SliceGetItem";
|
||||
constexpr auto kGeLU = "GeLU";
|
||||
constexpr auto kUnravelIndex = "UnravelIndex";
|
||||
|
@ -491,7 +491,7 @@ GVAR_DEF(PrimitivePtr, kPrimScalarAdd, std::make_shared<Primitive>(kScalarAdd));
|
|||
GVAR_DEF(PrimitivePtr, kPrimScalarSub, std::make_shared<Primitive>(kScalarSub));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScalarMul, std::make_shared<Primitive>(kScalarMul));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScalarDiv, std::make_shared<Primitive>(kScalarDiv));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScalarFloorDiv, std::make_shared<Primitive>(kScalarFloorDiv));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScalarFloorDiv, std::make_shared<Primitive>(kScalarFloordiv));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScalarMod, std::make_shared<Primitive>(kScalarMod));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScalarPow, std::make_shared<Primitive>(kScalarPow));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScalarTrunc, std::make_shared<Primitive>(kScalarTrunc));
|
||||
|
|
|
@ -21,13 +21,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief ListGetItem op is added to the multi-output node to describe which output of the node, which is only used
|
||||
/// \brief list_getitem op is added to the multi-output node to describe which output of the node, which is only used
|
||||
/// in FuncGraph.
|
||||
class MIND_API ListGetItem : public BaseOperator {
|
||||
class MIND_API list_getitem : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ListGetItem);
|
||||
MIND_API_BASE_MEMBER(list_getitem);
|
||||
/// \brief Constructor.
|
||||
ListGetItem() : BaseOperator(prim::kListGetItem) { InitIOName({"input", "index"}, {"output"}); }
|
||||
list_getitem() : BaseOperator(prim::kListGetItem) { InitIOName({"input", "index"}, {"output"}); }
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -224,7 +224,7 @@ MathImplFunc ChooseFunc(const std::string &prim_name) {
|
|||
{prim::kScalarLt, LtImpl<T>},
|
||||
{prim::kScalarGe, GeImpl<T>},
|
||||
{prim::kScalarLe, LeImpl<T>},
|
||||
{prim::kScalarFloorDiv, FloorDivImpl<T>}};
|
||||
{prim::kScalarFloordiv, FloorDivImpl<T>}};
|
||||
auto iter = infer_value_func_map.find(prim_name);
|
||||
if (iter == infer_value_func_map.end()) {
|
||||
MS_EXCEPTION(TypeError) << "For '" << prim_name
|
||||
|
@ -333,23 +333,23 @@ MIND_API_OPERATOR_IMPL(ScalarAdd, BaseOperator);
|
|||
MIND_API_OPERATOR_IMPL(ScalarSub, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(ScalarMul, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(ScalarDiv, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(ScalarFloorDiv, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(ScalarFloordiv, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(ScalarMod, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(ScalarEqual, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(ScalarGreater, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(ScalarGreaterEqual, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(ScalarLess, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(ScalarLessEqual, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(scalar_eq, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(scalar_gt, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(scalar_ge, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(scalar_lt, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(scalar_le, BaseOperator);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarAdd, prim::kPrimScalarAdd, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarSub, prim::kPrimScalarSub, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarMul, prim::kPrimScalarMul, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarDiv, prim::kPrimScalarDiv, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarFloorDiv, prim::kPrimScalarFloorDiv, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarFloordiv, prim::kPrimScalarFloorDiv, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarMod, prim::kPrimScalarMod, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarEqual, prim::kPrimScalarEq, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarGreater, prim::kPrimScalarGt, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarGreaterEqual, prim::kPrimScalarGe, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarLess, prim::kPrimScalarLt, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarLessEqual, prim::kPrimScalarLe, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(scalar_eq, prim::kPrimScalarEq, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(scalar_gt, prim::kPrimScalarGt, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(scalar_ge, prim::kPrimScalarGe, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(scalar_lt, prim::kPrimScalarLt, ScalarArithmeticInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(scalar_le, prim::kPrimScalarLe, ScalarArithmeticInfer, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -115,9 +115,9 @@ class ScalarBitwiseInfer : public abstract::OpInferBase {
|
|||
return res;
|
||||
}
|
||||
};
|
||||
MIND_API_OPERATOR_IMPL(ScalarBitwiseOr, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(ScalarBitwiseAnd, BaseOperator);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarBitwiseOr, prim::kPrimScalarBitwiseOr, ScalarBitwiseInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ScalarBitwiseAnd, prim::kPrimScalarBitwiseAnd, ScalarBitwiseInfer, true);
|
||||
MIND_API_OPERATOR_IMPL(bit_or, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(bit_and, BaseOperator);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(bit_or, prim::kPrimScalarBitwiseOr, ScalarBitwiseInfer, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(bit_and, prim::kPrimScalarBitwiseAnd, ScalarBitwiseInfer, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,11 +22,11 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief
|
||||
class MIND_API ScalarBitwiseAnd : public BaseOperator {
|
||||
class MIND_API bit_and : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScalarBitwiseAnd);
|
||||
MIND_API_BASE_MEMBER(bit_and);
|
||||
/// \brief Constructor.
|
||||
ScalarBitwiseAnd() : BaseOperator(prim::kScalarBitwiseAnd) {}
|
||||
bit_and() : BaseOperator(prim::kScalarBitwiseAnd) {}
|
||||
/// \brief Init.
|
||||
void Init() const {}
|
||||
};
|
||||
|
|
|
@ -22,11 +22,11 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief
|
||||
class MIND_API ScalarBitwiseOr : public BaseOperator {
|
||||
class MIND_API bit_or : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScalarBitwiseOr);
|
||||
MIND_API_BASE_MEMBER(bit_or);
|
||||
/// \brief Constructor.
|
||||
ScalarBitwiseOr() : BaseOperator(prim::kScalarBitwiseOr) {}
|
||||
bit_or() : BaseOperator(prim::kScalarBitwiseOr) {}
|
||||
/// \brief Init.
|
||||
void Init() const {}
|
||||
};
|
||||
|
|
|
@ -21,12 +21,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief ScalarEqual op is used to judge equal between variable scalar.
|
||||
class MIND_API ScalarEqual : public BaseOperator {
|
||||
/// \brief scalar_eq op is used to judge equal between variable scalar.
|
||||
class MIND_API scalar_eq : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScalarEqual);
|
||||
MIND_API_BASE_MEMBER(scalar_eq);
|
||||
/// \brief Constructor.
|
||||
ScalarEqual() : BaseOperator(prim::kScalarEq) {}
|
||||
scalar_eq() : BaseOperator(prim::kScalarEq) {}
|
||||
/// \brief Init.
|
||||
void Init() const {}
|
||||
};
|
||||
|
|
|
@ -22,11 +22,11 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief ScalarFloorDiv op is used to div between variable scalar.
|
||||
class MIND_API ScalarFloorDiv : public BaseOperator {
|
||||
class MIND_API ScalarFloordiv : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScalarFloorDiv);
|
||||
MIND_API_BASE_MEMBER(ScalarFloordiv);
|
||||
/// \brief Constructor.
|
||||
ScalarFloorDiv() : BaseOperator(prim::kScalarFloorDiv) {}
|
||||
ScalarFloordiv() : BaseOperator(prim::kScalarFloordiv) {}
|
||||
/// \brief Init.
|
||||
void Init() const {}
|
||||
};
|
||||
|
|
|
@ -21,12 +21,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief ScalarGreaterEqual op is used to judge greaterEqual between variable scalar.
|
||||
class MIND_API ScalarGreaterEqual : public BaseOperator {
|
||||
/// \brief scalar_ge op is used to judge greaterEqual between variable scalar.
|
||||
class MIND_API scalar_ge : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScalarGreaterEqual);
|
||||
MIND_API_BASE_MEMBER(scalar_ge);
|
||||
/// \brief Constructor.
|
||||
ScalarGreaterEqual() : BaseOperator(prim::kScalarGe) {}
|
||||
scalar_ge() : BaseOperator(prim::kScalarGe) {}
|
||||
/// \brief Init.
|
||||
void Init() const {}
|
||||
};
|
||||
|
|
|
@ -21,12 +21,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief ScalarGreater op is used to judge greater between variable scalar.
|
||||
class MIND_API ScalarGreater : public BaseOperator {
|
||||
/// \brief scalar_gt op is used to judge greater between variable scalar.
|
||||
class MIND_API scalar_gt : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScalarGreater);
|
||||
MIND_API_BASE_MEMBER(scalar_gt);
|
||||
/// \brief Constructor.
|
||||
ScalarGreater() : BaseOperator(prim::kScalarGt) {}
|
||||
scalar_gt() : BaseOperator(prim::kScalarGt) {}
|
||||
/// \brief Init.
|
||||
void Init() const {}
|
||||
};
|
||||
|
|
|
@ -21,12 +21,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief ScalarLessEqual op is used to judge lessEqual between variable scalar.
|
||||
class MIND_API ScalarLessEqual : public BaseOperator {
|
||||
/// \brief scalar_le op is used to judge lessEqual between variable scalar.
|
||||
class MIND_API scalar_le : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScalarLessEqual);
|
||||
MIND_API_BASE_MEMBER(scalar_le);
|
||||
/// \brief Constructor.
|
||||
ScalarLessEqual() : BaseOperator(prim::kScalarLe) {}
|
||||
scalar_le() : BaseOperator(prim::kScalarLe) {}
|
||||
/// \brief Init.
|
||||
void Init() const {}
|
||||
};
|
||||
|
|
|
@ -21,12 +21,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief ScalarLess op is used to judge less between variable scalar.
|
||||
class MIND_API ScalarLess : public BaseOperator {
|
||||
/// \brief scalar_lt op is used to judge less between variable scalar.
|
||||
class MIND_API scalar_lt : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScalarLess);
|
||||
MIND_API_BASE_MEMBER(scalar_lt);
|
||||
/// \brief Constructor.
|
||||
ScalarLess() : BaseOperator(prim::kScalarLt) {}
|
||||
scalar_lt() : BaseOperator(prim::kScalarLt) {}
|
||||
/// \brief Init.
|
||||
void Init() const {}
|
||||
};
|
||||
|
|
|
@ -102,9 +102,9 @@ class SequenceGetItemInfer : public abstract::OpInferBase {
|
|||
};
|
||||
MIND_API_OPERATOR_IMPL(TupleGetItem, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(RealTupleGetItem, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(ListGetItem, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(list_getitem, BaseOperator);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(TupleGetItem, prim::kPrimTupleGetItem, SequenceGetItemInfer, false);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(RealTupleGetItem, prim::kPrimRealTupleGetItem, SequenceGetItemInfer, false);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(ListGetItem, prim::kPrimListGetItem, SequenceGetItemInfer, false);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(list_getitem, prim::kPrimListGetItem, SequenceGetItemInfer, false);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
namespace mindspore {
|
||||
// clang-format off
|
||||
#ifndef ENABLE_SECURITY
|
||||
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "ListGetItem",
|
||||
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
|
||||
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
|
||||
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
|
||||
"make_dict", "make_slice", "string_eq", "VirtualLoss", "Return", "env_getitem",
|
||||
|
@ -33,7 +33,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
|
|||
"InvertPermutation", "DropoutGenMask", "StatelessDropOutGenMask", "embed", "create_instance", "RefToEmbed",
|
||||
"StopGradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
|
||||
#else
|
||||
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "ListGetItem",
|
||||
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
|
||||
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
|
||||
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
|
||||
"make_dict", "make_slice", "string_eq", "VirtualLoss", "Return", "env_getitem",
|
||||
|
|
|
@ -65,7 +65,7 @@ def TupleGetItem(x, index):
|
|||
return x[index]
|
||||
|
||||
|
||||
def ScalarGreater(x, y):
|
||||
def scalar_gt(x, y):
|
||||
"""Implement `scalar_gt`."""
|
||||
return x > y
|
||||
|
||||
|
@ -75,17 +75,17 @@ def scalar_ne(x, y):
|
|||
return x != y
|
||||
|
||||
|
||||
def ScalarEqual(x, y):
|
||||
def scalar_eq(x, y):
|
||||
"""Implement `scalar_eq`."""
|
||||
return x == y
|
||||
|
||||
|
||||
def ScalarLessEqual(x, y):
|
||||
def scalar_le(x, y):
|
||||
"""Implement `scalar_le`."""
|
||||
return x <= y
|
||||
|
||||
|
||||
def ScalarLess(x, y):
|
||||
def scalar_lt(x, y):
|
||||
"""Implement `scalar_lt`."""
|
||||
return x < y
|
||||
|
||||
|
@ -117,7 +117,7 @@ def Switch(c, x, y):
|
|||
return x if c else y
|
||||
|
||||
|
||||
def ListGetItem(data, item):
|
||||
def list_getitem(data, item):
|
||||
"""Implement `list_getitem`."""
|
||||
return data[item]
|
||||
|
||||
|
|
|
@ -72,7 +72,7 @@ def bprop_tuple_getitem(data, idx, out, dout):
|
|||
return F.tuple_setitem(zeros_like(data), idx, dout), zeros_like(idx)
|
||||
|
||||
|
||||
@bprops.register("ListGetItem")
|
||||
@bprops.register("list_getitem")
|
||||
def bprop_list_getitem(data, idx, out, dout):
|
||||
"""Backpropagator for primitive `list_getitem`."""
|
||||
return F.list_setitem(zeros_like(data), idx, dout), zeros_like(idx)
|
||||
|
|
|
@ -64,7 +64,7 @@ def get_bprop_scalar_div(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(_scalar_ops.ScalarFloorDiv)
|
||||
@bprop_getters.register(_scalar_ops.ScalarFloordiv)
|
||||
def get_bprop_scalar_floordiv(self):
|
||||
"""Grad definition for `ScalarFloorDiv` operation."""
|
||||
|
||||
|
@ -86,13 +86,13 @@ def get_bprop_scalar_mod(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(_scalar_ops.ScalarEqual)
|
||||
@bprop_getters.register(_scalar_ops.ScalarLessEqual)
|
||||
@bprop_getters.register(_scalar_ops.ScalarLess)
|
||||
@bprop_getters.register(_scalar_ops.ScalarGreaterEqual)
|
||||
@bprop_getters.register(_scalar_ops.ScalarGreater)
|
||||
@bprop_getters.register(_scalar_ops.ScalarBitwiseAnd)
|
||||
@bprop_getters.register(_scalar_ops.ScalarBitwiseOr)
|
||||
@bprop_getters.register(_scalar_ops.scalar_eq)
|
||||
@bprop_getters.register(_scalar_ops.scalar_le)
|
||||
@bprop_getters.register(_scalar_ops.scalar_lt)
|
||||
@bprop_getters.register(_scalar_ops.scalar_ge)
|
||||
@bprop_getters.register(_scalar_ops.scalar_gt)
|
||||
@bprop_getters.register(_scalar_ops.bit_and)
|
||||
@bprop_getters.register(_scalar_ops.bit_or)
|
||||
def get_bprop_scalar_logic(self):
|
||||
"""Grad definition for `ScalarLogicOps` operation."""
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ def get_identity_vmap_rule(prim, axis_size):
|
|||
return vmap_rule
|
||||
|
||||
|
||||
@vmap_rules_getters.register("ListGetItem")
|
||||
@vmap_rules_getters.register("list_getitem")
|
||||
@vmap_rules_getters.register("TupleGetItem")
|
||||
def get_seq_get_item_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for `list_getitem` or `TupleGetItem` operation."""
|
||||
|
|
|
@ -73,7 +73,7 @@ class _ListSlice(base.SequenceSliceGetItem_):
|
|||
|
||||
def __init__(self, name):
|
||||
"""Initialize _TupleSlice."""
|
||||
base.SequenceSliceGetItem_.__init__(self, name, "make_list", "ListGetItem")
|
||||
base.SequenceSliceGetItem_.__init__(self, name, "make_list", "list_getitem")
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
|
|
@ -61,16 +61,16 @@ scalar_mod = _scalar_ops.ScalarMod()
|
|||
scalar_add = _scalar_ops.ScalarAdd()
|
||||
scalar_mul = _scalar_ops.ScalarMul()
|
||||
scalar_sub = _scalar_ops.ScalarSub()
|
||||
scalar_gt = _scalar_ops.ScalarGreater()
|
||||
scalar_ge = _scalar_ops.ScalarGreaterEqual()
|
||||
scalar_le = _scalar_ops.ScalarLessEqual()
|
||||
scalar_lt = _scalar_ops.ScalarLess()
|
||||
scalar_eq = _scalar_ops.ScalarEqual()
|
||||
scalar_floordiv = _scalar_ops.ScalarFloorDiv()
|
||||
scalar_gt = _scalar_ops.scalar_gt()
|
||||
scalar_ge = _scalar_ops.scalar_ge()
|
||||
scalar_le = _scalar_ops.scalar_le()
|
||||
scalar_lt = _scalar_ops.scalar_lt()
|
||||
scalar_eq = _scalar_ops.scalar_eq()
|
||||
scalar_floordiv = _scalar_ops.ScalarFloordiv()
|
||||
|
||||
tuple_setitem = Primitive('tuple_setitem')
|
||||
tuple_getitem = Primitive(_constants.kTupleGetItem)
|
||||
list_getitem = Primitive('ListGetItem')
|
||||
list_getitem = Primitive('list_getitem')
|
||||
list_setitem = Primitive('list_setitem')
|
||||
dict_getitem = Primitive('dict_getitem')
|
||||
dict_setitem = Primitive('dict_setitem')
|
||||
|
|
|
@ -21,7 +21,7 @@ import numpy as np
|
|||
from mindspore.common import Tensor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops.operations.array_ops import Cast
|
||||
from mindspore.ops.operations._scalar_ops import ScalarBitwiseOr, ScalarBitwiseAnd
|
||||
from mindspore.ops.operations._scalar_ops import bit_or, bit_and
|
||||
from mindspore.ops import signature as sig
|
||||
from mindspore.ops.operations.math_ops import _infer_shape_reduce
|
||||
from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, _run_op
|
||||
|
@ -36,8 +36,8 @@ from mindspore.common._register_for_adapter import ms_adapter_registry
|
|||
|
||||
|
||||
# Bit operation
|
||||
bit_and = ScalarBitwiseAnd()
|
||||
bit_or = ScalarBitwiseOr()
|
||||
bit_and = bit_and()
|
||||
bit_or = bit_or()
|
||||
bit_xor = Primitive("bit_xor")
|
||||
bit_left_shift = Primitive("bit_left_shift")
|
||||
bit_right_shift = Primitive("bit_right_shift")
|
||||
|
|
|
@ -48,7 +48,7 @@ class ScalarDiv(Primitive):
|
|||
"""Initialize ScalarDiv"""
|
||||
|
||||
|
||||
class ScalarFloorDiv(Primitive):
|
||||
class ScalarFloordiv(Primitive):
|
||||
r"""
|
||||
Computes the quotient of dividing the first input scalar by the second input scalar element-wise.
|
||||
|
||||
|
@ -76,7 +76,7 @@ class ScalarFloorDiv(Primitive):
|
|||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize ScalarFloorDiv"""
|
||||
"""Initialize ScalarFloordiv"""
|
||||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||
|
||||
|
||||
|
@ -158,7 +158,7 @@ class ScalarMul(Primitive):
|
|||
"""Initialize ScalarMul"""
|
||||
|
||||
|
||||
class ScalarEqual(Primitive):
|
||||
class scalar_eq(Primitive):
|
||||
r"""
|
||||
Computes the equivalence between two Scalars.
|
||||
|
||||
|
@ -184,7 +184,7 @@ class ScalarEqual(Primitive):
|
|||
"""Initialize ScalarMul"""
|
||||
|
||||
|
||||
class ScalarGreater(Primitive):
|
||||
class scalar_gt(Primitive):
|
||||
r"""
|
||||
Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
|
||||
|
||||
|
@ -207,10 +207,10 @@ class ScalarGreater(Primitive):
|
|||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize ScalarGreater"""
|
||||
"""Initialize scalar_gt"""
|
||||
|
||||
|
||||
class ScalarLess(Primitive):
|
||||
class scalar_lt(Primitive):
|
||||
r"""
|
||||
Computes the boolean value of :math:`x < y`.
|
||||
|
||||
|
@ -233,10 +233,10 @@ class ScalarLess(Primitive):
|
|||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize ScalarLess"""
|
||||
"""Initialize scalar_lt"""
|
||||
|
||||
|
||||
class ScalarGreaterEqual(Primitive):
|
||||
class scalar_ge(Primitive):
|
||||
r"""
|
||||
Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
|
||||
|
||||
|
@ -259,10 +259,10 @@ class ScalarGreaterEqual(Primitive):
|
|||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize ScalarGreaterEqual"""
|
||||
"""Initialize scalar_ge"""
|
||||
|
||||
|
||||
class ScalarLessEqual(Primitive):
|
||||
class scalar_le(Primitive):
|
||||
r"""
|
||||
Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
|
||||
|
||||
|
@ -285,7 +285,7 @@ class ScalarLessEqual(Primitive):
|
|||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize ScalarLessEqual"""
|
||||
"""Initialize scalar_le"""
|
||||
|
||||
|
||||
class ScalarMod(Primitive):
|
||||
|
@ -343,7 +343,7 @@ class ScalarBool(Primitive):
|
|||
"""Initialize ScalarBool"""
|
||||
|
||||
|
||||
class ScalarBitwiseAnd(Primitive):
|
||||
class bit_and(Primitive):
|
||||
r"""
|
||||
Returns bitwise `and` of two scalars.
|
||||
|
||||
|
@ -373,7 +373,7 @@ class ScalarBitwiseAnd(Primitive):
|
|||
"""Initialize ScalarMod"""
|
||||
|
||||
|
||||
class ScalarBitwiseOr(Primitive):
|
||||
class bit_or(Primitive):
|
||||
r"""
|
||||
Returns bitwise `or` of two scalars.
|
||||
|
||||
|
|
|
@ -20,8 +20,9 @@ from tuple_help import TupleFactory
|
|||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scalar_add():
|
||||
"""
|
||||
|
@ -45,8 +46,9 @@ def test_scalar_add():
|
|||
fact.grad_impl()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scalar_sub():
|
||||
"""
|
||||
|
@ -71,8 +73,9 @@ def test_scalar_sub():
|
|||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scalar_mul():
|
||||
"""
|
||||
|
@ -98,6 +101,7 @@ def test_scalar_mul():
|
|||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scalar_div():
|
||||
"""
|
||||
|
@ -121,8 +125,9 @@ def test_scalar_div():
|
|||
fact.grad_impl()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scalar_mod():
|
||||
"""
|
||||
|
@ -146,8 +151,9 @@ def test_scalar_mod():
|
|||
fact.grad_impl()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scalar_floordiv():
|
||||
"""
|
||||
|
@ -171,12 +177,13 @@ def test_scalar_floordiv():
|
|||
fact.grad_impl()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scalar_eq():
|
||||
"""
|
||||
Feature: test ScalarEqual.
|
||||
Feature: test scalar_eq.
|
||||
Description: inputs is dynamic scalar.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
|
@ -196,12 +203,13 @@ def test_scalar_eq():
|
|||
fact.grad_impl()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scalar_ge():
|
||||
"""
|
||||
Feature: test ScalarGreaterEqual.
|
||||
Feature: test scalar_ge.
|
||||
Description: inputs is dynamic scalar.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
|
@ -223,10 +231,11 @@ def test_scalar_ge():
|
|||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scalar_gt():
|
||||
"""
|
||||
Feature: test ScalarGreater.
|
||||
Feature: test scalar_gt.
|
||||
Description: inputs is dynamic scalar.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
|
@ -248,10 +257,11 @@ def test_scalar_gt():
|
|||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scalar_le():
|
||||
"""
|
||||
Feature: test ScalarLessEqual.
|
||||
Feature: test scalar_le.
|
||||
Description: inputs is dynamic scalar.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
|
@ -273,10 +283,11 @@ def test_scalar_le():
|
|||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scalar_lt():
|
||||
"""
|
||||
Feature: test ScalarLess.
|
||||
Feature: test scalar_lt.
|
||||
Description: inputs is dynamic scalar.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops.operations import _sequence_ops as seq
|
||||
import mindspore.ops as ops
|
||||
from mindspore import context
|
||||
from mindspore.common import mutable
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tuple_to_tensor = seq.TupleToTensor()
|
||||
self.scalar_to_tensor = ops.ScalarToTensor()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.tuple_to_tensor(x, mstype.float32), self.scalar_to_tensor(y, mstype.int64)
|
||||
|
||||
|
||||
def dyn_case():
|
||||
x = mutable((1, 2, 3), True)
|
||||
y = mutable(3)
|
||||
expect_x = np.array([1, 2, 3], dtype=np.float32)
|
||||
expect_y = np.array(3, dtype=np.int64)
|
||||
net = Net()
|
||||
res_x, res_y = net(x, y)
|
||||
rtol = 1.e-4
|
||||
atol = 1.e-4
|
||||
assert np.allclose(res_x.asnumpy(), expect_x, rtol, atol, equal_nan=True)
|
||||
assert np.allclose(res_y.asnumpy(), expect_y, rtol, atol, equal_nan=True)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_to_tensor():
|
||||
"""
|
||||
Feature: test xxToTensor.
|
||||
Description: inputs is dynamic sequence or scalar.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
dyn_case()
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops.operations import _sequence_ops as seq
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, axis=0):
|
||||
super().__init__()
|
||||
self.tensor_to_tuple = seq.TensorToTuple()
|
||||
self.tensor_to_scalar = seq.TensorToScalar()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.tensor_to_tuple(x), self.tensor_to_scalar(y)
|
||||
|
||||
|
||||
def dyn_case():
|
||||
x = Tensor([1, 2, 3], mstype.int64)
|
||||
y = Tensor(1, mstype.float32)
|
||||
expect_x = (1, 2, 3)
|
||||
expect_y = 1.0
|
||||
net = Net()
|
||||
res_x, res_y = net(x, y)
|
||||
assert expect_x == res_x
|
||||
assert expect_y == res_y
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_to_tensor():
|
||||
"""
|
||||
Feature: test TensorToxx.
|
||||
Description: inputs is dynamic sequence or scalar.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
dyn_case()
|
|
@ -263,7 +263,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
|
|||
/// Description: The second input is a scalar
|
||||
/// Expectation: Throw type error
|
||||
TEST_F(TestComposite, test_ListSlice_arg_one_number) {
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "ListGetItem");
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
|
||||
FuncGraphPtr list_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 3);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -298,7 +298,7 @@ TEST_F(TestComposite, test_ListSlice_arg_one_number) {
|
|||
/// Expectation: No Expectation
|
||||
TEST_F(TestComposite, test_ListSlice_arg_slice) {
|
||||
std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "ListGetItem");
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
|
||||
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -327,7 +327,7 @@ TEST_F(TestComposite, test_ListSlice_arg_slice) {
|
|||
/// Description: Test List slice the step is none
|
||||
/// Expectation: No Expectation
|
||||
TEST_F(TestComposite, test_ListSlice_arg_slice_step_none) {
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "ListGetItem");
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
|
||||
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -356,7 +356,7 @@ TEST_F(TestComposite, test_ListSlice_arg_slice_step_none) {
|
|||
/// Description: Test List slice the step is negative
|
||||
/// Expectation: No Expectation
|
||||
TEST_F(TestComposite, test_ListSlice_arg_slice_step_negative) {
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "ListGetItem");
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
|
||||
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -385,7 +385,7 @@ TEST_F(TestComposite, test_ListSlice_arg_slice_step_negative) {
|
|||
/// Description: Test List slice the step is positive
|
||||
/// Expectation: No Expectation
|
||||
TEST_F(TestComposite, test_ListSlice_arg_slice_step_positive) {
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "ListGetItem");
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
|
||||
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
|
|
@ -189,7 +189,7 @@ TEST_F(TestOps, TupleGetItemTest) {
|
|||
}
|
||||
|
||||
TEST_F(TestOps, ListGetItemTest) {
|
||||
auto prim = std::make_shared<Primitive>("ListGetItem");
|
||||
auto prim = std::make_shared<Primitive>("list_getitem");
|
||||
ASSERT_EQ(prim->name(), kPrimListGetItem->name());
|
||||
}
|
||||
|
||||
|
|
|
@ -682,7 +682,7 @@ TEST_F(TestPrim, test_list_getitem) {
|
|||
args_spec_list.push_back(abstract_v1);
|
||||
args_spec_list.push_back(abstract_v2);
|
||||
|
||||
auto prim = std::make_shared<Primitive>("ListGetItem");
|
||||
auto prim = std::make_shared<Primitive>("list_getitem");
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
|
||||
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
|
||||
|
|
Loading…
Reference in New Issue