1. Optimize the error messages in operators. 2. Add operator test cases.
parent a9fbb25530
commit 7b93e4e87c
@@ -125,7 +125,7 @@ T InnerScalarMul(T x, T y) {
 template <typename T>
 float InnerScalarDiv(T x, T y) {
   if (y == 0) {
-    MS_LOG(EXCEPTION) << "Divisor could not be zero";
+    MS_LOG(EXCEPTION) << "The divisor could not be zero.";
   }
   if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::DIV)) {
     MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x)
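Note: the guarded conditions above have a direct plain-Python reading. A minimal sketch of the two failure modes InnerScalarDiv rejects (illustrative only, not MindSpore API):

INT64_MIN = -(2 ** 63)

def inner_scalar_div(x, y):
    # Division by zero is rejected up front, with the reworded message.
    if y == 0:
        raise ValueError("The divisor could not be zero.")
    # Signed-integer overflow: INT64_MIN / -1 has no representable result,
    # which is what IsSignedIntOverflow(x, y, OpType::DIV) detects.
    if isinstance(x, int) and isinstance(y, int) and x == INT64_MIN and y == -1:
        raise OverflowError("Overflow of the div of two signed number x: " + str(x))
    return x / y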
@@ -199,7 +199,7 @@ bool InnerScalarGe(T x, U y) {
   ValuePtr Scalar##op_t(const ValuePtrList &list) {  \
     do {  \
      if (list.size() < 2) {  \
-       MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2.";  \
+       MS_LOG(EXCEPTION) << "The length of input list for Scalar" << #op_t << " is less than 2.";  \
      }  \
      ValuePtr x = list[0];  \
      ValuePtr y = list[1];  \
@@ -267,73 +267,75 @@ SCALAR_OP(Mod)
 SCALAR_OP(Pow)
 SCALAR_OP(Floordiv)
 
 #define LOGIC_OP(op_t)  \
   ValuePtr Scalar##op_t(const ValuePtrList &list) {  \
-    if (list.size() < 2) {  \
-      MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2.";  \
+    constexpr size_t kListInputSize = 2;  \
+    if (list.size() < kListInputSize) {  \
+      MS_LOG(EXCEPTION) << "The length of input list for Scalar" << #op_t << " is less than 2.";  \
     }  \
     ValuePtr x = list[0];  \
     ValuePtr y = list[1];  \
     MS_EXCEPTION_IF_NULL(x);  \
     MS_EXCEPTION_IF_NULL(y);  \
     if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y));  \
       return MakeValue(sum);  \
     }  \
     if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y));  \
       return MakeValue(sum);  \
     }  \
     if (x->isa<FP64Imm>() && y->isa<FP32Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<float>(y));  \
       return MakeValue(sum);  \
     }  \
     if (x->isa<FP32Imm>() && y->isa<FP64Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<double>(y));  \
       return MakeValue(sum);  \
     }  \
     if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y));  \
       return MakeValue(sum);  \
     }  \
     if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int>(y));  \
       return MakeValue(sum);  \
     }  \
     if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int64_t>(y));  \
       return MakeValue(sum);  \
     }  \
     if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<float>(y));  \
       return MakeValue(sum);  \
     }  \
     if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<float>(y));  \
       return MakeValue(sum);  \
     }  \
     if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y));  \
       return MakeValue(sum);  \
     }  \
     if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<int64_t>(y));  \
       return MakeValue(sum);  \
     }  \
     if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<double>(y));  \
       return MakeValue(sum);  \
     }  \
     if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int>(y));  \
       return MakeValue(sum);  \
     }  \
     if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) {  \
       bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int64_t>(y));  \
       return MakeValue(sum);  \
     }  \
-    MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString()  \
-                      << ", y: " << y->ToString() << ".";  \
+    MS_LOG(EXCEPTION) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name()  \
+                      << ", value of x:" << x->ToString() << ", type of y:" << y->type_name()  \
+                      << ", value of y:" << y->ToString();  \
   }
 
 LOGIC_OP(Eq)
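Note: LOGIC_OP is a type-pair dispatch: each isa<> branch selects the concrete GetValue<T>() pair, and anything unmatched falls through to the final error, which now reports types as well as values. A rough Python analogue of that shape (illustrative, not MindSpore code):

def scalar_logic_op(name, op, x, y):
    # Mirrors the chain of isa<FP64Imm>/isa<Int32Imm>/... checks: only
    # numeric scalar pairs are dispatched to the comparison.
    if isinstance(x, (int, float)) and isinstance(y, (int, float)):
        return op(x, y)
    raise TypeError("Unsupported input type for Scalar{}, type of x:{}, value of x:{}, "
                    "type of y:{}, value of y:{}".format(name, type(x).__name__, x,
                                                         type(y).__name__, y))

# Example: scalar_logic_op("Eq", lambda a, b: a == b, 1, 1.0) returns True.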
@@ -372,12 +374,12 @@ ValuePtr ScalarUSub(const ValuePtrList &list) {
     return MakeValue(sum);
   }
 
-  MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << ".";
+  MS_LOG(EXCEPTION) << "Unsupported Value for ScalarUSub, x: " << x->ToString() << ".";
 }
 
 ValuePtr ScalarLog(const ValuePtrList &list) {
-  if (list.empty()) {
-    MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty.";
+  if (list.size() != 1) {
+    MS_LOG(EXCEPTION) << "Input number of ScalarLog should be 1, but got " << list.size();
   }
   ValuePtr x = list[0];
   MS_EXCEPTION_IF_NULL(x);
@@ -391,12 +393,12 @@ ValuePtr ScalarLog(const ValuePtrList &list) {
     return MakeValue(v);
   }
 
-  MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString();
+  MS_LOG(EXCEPTION) << "Unsupported Value for ScalarLog, x: " << x->ToString();
 }
 
 ValuePtr BoolNot(const ValuePtrList &list) {
-  if (list.empty()) {
-    MS_LOG(EXCEPTION) << "value list of BoolNot is empty";
+  if (list.size() != 1) {
+    MS_LOG(EXCEPTION) << "Input number of BoolNot should be 1, but got " << list.size();
   }
   ValuePtr x = list[0];
   MS_EXCEPTION_IF_NULL(x);
@@ -407,12 +409,13 @@ ValuePtr BoolNot(const ValuePtrList &list) {
     return MakeValue(res);
   }
 
-  MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString();
+  MS_LOG(EXCEPTION) << "Unsupported Value for BoolNot, x: " << x->ToString();
 }
 
 ValuePtr BoolAnd(const ValuePtrList &list) {
-  if (list.size() < 2) {
-    MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2.";
+  constexpr size_t kListInputSize = 2;
+  if (list.size() != kListInputSize) {
+    MS_LOG(EXCEPTION) << "Input number of BoolAnd must be 2, but got " << list.size();
   }
   ValuePtr x = list[0];
   ValuePtr y = list[1];
@@ -426,12 +429,13 @@ ValuePtr BoolAnd(const ValuePtrList &list) {
     return MakeValue(res);
   }
 
-  MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << ".";
+  MS_LOG(EXCEPTION) << "Unsupported Value for BoolAnd, x: " << x->ToString() << " y: " << y->ToString() << ".";
 }
 
 ValuePtr BoolOr(const ValuePtrList &list) {
-  if (list.size() < 2) {
-    MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2.";
+  constexpr size_t kListInputSize = 2;
+  if (list.size() != kListInputSize) {
+    MS_LOG(EXCEPTION) << "Input number of BoolOr must be 2, but got " << list.size();
   }
   ValuePtr x = list[0];
   ValuePtr y = list[1];
@@ -445,12 +449,13 @@ ValuePtr BoolOr(const ValuePtrList &list) {
     return MakeValue(res);
   }
 
-  MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << ".";
+  MS_LOG(EXCEPTION) << "Unsupported Value for BoolOr, x: " << x->ToString() << " y: " << y->ToString() << ".";
 }
 
 ValuePtr BoolEq(const ValuePtrList &list) {
-  if (list.size() < 2) {
-    MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2.";
+  constexpr size_t kListInputSize = 2;
+  if (list.size() != kListInputSize) {
+    MS_LOG(EXCEPTION) << "Input number of BoolEq must be 2, but got " << list.size();
   }
   ValuePtr x = list[0];
   ValuePtr y = list[1];
@@ -464,7 +469,7 @@ ValuePtr BoolEq(const ValuePtrList &list) {
     return MakeValue(res);
   }
 
-  MS_LOG(EXCEPTION) << "Unsported Value for BoolEq, x: " << x->ToString() << ".";
+  MS_LOG(EXCEPTION) << "Unsupported Value for BoolEq, x: " << x->ToString() << " y: " << y->ToString() << ".";
 }
 }  // namespace prim
 }  // namespace mindspore
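Note: BoolAnd, BoolOr and BoolEq all moved from a "fewer than 2" check to an exact-arity check against a named constant, and the message now reports the actual count. The pattern, sketched in Python (names illustrative):

K_LIST_INPUT_SIZE = 2  # analogue of the new constexpr kListInputSize

def bool_and(args):
    # Exact arity: surplus arguments are now rejected too, not only missing ones.
    if len(args) != K_LIST_INPUT_SIZE:
        raise ValueError("Input number of BoolAnd must be 2, but got %d" % len(args))
    x, y = args
    return x and y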
@@ -131,7 +131,7 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraph
     return false;
   });
   if (is_not_same) {
-    MS_LOG(EXCEPTION) << "List in HyperMap should have same length";
+    MS_LOG(EXCEPTION) << "List in HyperMap should have same length.";
   }
 
   // cannot use shared_from_base() also known as this, as it will make a reference cycle on
@@ -189,7 +189,7 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGrap
     return false;
   });
   if (is_not_same) {
-    MS_LOG(EXCEPTION) << "Tuple in HyperMap should have same length";
+    MS_LOG(EXCEPTION) << "Tuple in HyperMap should have same length.";
   }
 
   // cannot use shared_from_base() also known as this, as it will make a reference cycle on
@@ -469,7 +469,7 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list)
     return GenerateSequeueFuncGraph(a->cast<abstract::AbstractSequeuePtr>());
   }
 
-  MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but: " << a->ToString();
+  MS_LOG(EXCEPTION) << "'Tail' arg0 must be AbstractTuple or AbstractList, but: " << a->ToString();
 }
 
 REGISTER_PYBIND_DEFINE(
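Note: 'Tail' builds a graph returning every element of a tuple or list except the first. Assuming that drop-first semantics, a one-line Python reference for the behavior behind the sharpened message:

def tail(seq):
    if not isinstance(seq, (tuple, list)):
        raise TypeError("'Tail' arg0 must be AbstractTuple or AbstractList, but: %r" % (seq,))
    return seq[1:]

assert tail((1, 2, 3)) == (2, 3)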
@@ -237,7 +237,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
       if ((arg_value->isa<TensorType>()) && arg_type_id == it->second) {
         continue;
       }
-      MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
+      MS_LOG(DEBUG) << "Do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
                     << " to " << it->second;
       (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
     }
@@ -339,7 +339,7 @@ void RaiseExceptionForConvertRefDtype(const std::string &func_name, const std::s
                                       const std::string &target_type) {
   MS_LOG(EXCEPTION) << "In op '" << func_name << "', \n"
                     << "the type of writable argument is '" << ref_type << "', "
-                    << "but the largest type in the same SignatureEumDtype is '" << target_type
+                    << "but the largest type in the same SignatureEnumDType is '" << target_type
                     << "'. The writable arg type is not equal to the largest type, "
                     << "so can not cast automatically.";
 }
@@ -88,7 +88,7 @@ AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphP
     return false;
   });
   if (is_not_same) {
-    MS_LOG(EXCEPTION) << "List in Map should have same length";
+    MS_LOG(EXCEPTION) << "List in Map should have same length.";
   }
 
   constexpr size_t kPrimHoldLen = 1;
@@ -147,7 +147,7 @@ AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGrap
     return false;
   });
   if (is_not_same) {
-    MS_LOG(EXCEPTION) << "tuple in Map should have same length";
+    MS_LOG(EXCEPTION) << "Tuple in Map should have same length.";
   }
 
   constexpr size_t kPrimHoldLen = 1;
@@ -227,7 +227,7 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGrap
 
 AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
   if (arg_pairs.empty()) {
-    MS_EXCEPTION(TypeError) << "map() must have at least two arguments";
+    MS_EXCEPTION(TypeError) << "The map operator must have at least two arguments.";
   }
   bool found = false;
   TypeId id = kObjectTypeEnd;
@@ -85,7 +85,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
                         {NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), dict_get_item});
     });
   } else {
-    MS_LOG(EXCEPTION) << op_name << " require args should be tuple or dict, but got "
+    MS_LOG(EXCEPTION) << op_name << " require args should be tuple, list or dict, but got "
                       << args_spec_list[index]->ToString();
   }
 }
@@ -43,12 +43,12 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spe
     MS_LOG(EXCEPTION) << "For 'zip', there is at least one input.";
   }
 
-  auto is_all_sequeue =
+  auto all_is_sequence =
     std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool {
       MS_EXCEPTION_IF_NULL(abs);
       return abs->isa<AbstractSequeue>();
     });
-  if (!is_all_sequeue) {
+  if (!all_is_sequence) {
     MS_LOG(EXCEPTION) << "For 'zip', all inputs must be sequence.";
   }
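Note: the renamed all_is_sequence flag is just std::all_of over the abstract inputs. The same validation reads naturally in Python (illustrative sketch):

def check_zip_inputs(args):
    # Mirror of the two checks above: at least one input, and every
    # input must be a sequence (tuple or list).
    if not args:
        raise ValueError("For 'zip', there is at least one input.")
    if not all(isinstance(a, (tuple, list)) for a in args):
        raise TypeError("For 'zip', all inputs must be sequence.")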
@@ -62,7 +62,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spe
     (void)ret_graph->add_parameter();
   }
 
-  // generate tuple output of ziped arguments input
+  // generate tuple output of zipped arguments input
   std::vector<AnfNodePtr> make_tuple_nodes;
   make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
   for (size_t idx = 0; idx < (*min_abs)->cast<AbstractSequeuePtr>()->size(); idx++) {
@@ -26,7 +26,7 @@ ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name
   ValuePtr node = nullptr;
   bool succ = parse::ConvertData(obj, &node, use_signature);
   if (!succ) {
-    MS_LOG(EXCEPTION) << "get Python op " << op_name << " from " << module_name << " fail";
+    MS_LOG(EXCEPTION) << "Get Python op " << op_name << " from " << module_name << " fail.";
   }
   return node;
 }
@@ -64,7 +64,7 @@ void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide)
     MS_EXCEPTION_IF_NULL(args_spec_list[0]);
     auto arg_value = args_spec_list[0]->BuildValue();
     if (!arg_value->isa<Int64Imm>()) {
-      MS_LOG(EXCEPTION) << "Only supported input an int64 number.";
+      MS_LOG(EXCEPTION) << "The type of inputs of make_range operator only support int64 number.";
     }
     arg1 = GetValue<int64_t>(arg_value);
   }
@@ -73,7 +73,7 @@ void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide)
     MS_EXCEPTION_IF_NULL(args_spec_list[1]);
     auto arg_value = args_spec_list[1]->BuildValue();
     if (!arg_value->isa<Int64Imm>()) {
-      MS_LOG(EXCEPTION) << "Only supported input an int64 number.";
+      MS_LOG(EXCEPTION) << "The type of inputs of make_range operator only support int64 number.";
     }
     arg2 = GetValue<int64_t>(arg_value);
   }
@@ -82,7 +82,7 @@ void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide)
     MS_EXCEPTION_IF_NULL(args_spec_list[2]);
     auto arg_value = args_spec_list[2]->BuildValue();
     if (!arg_value->isa<Int64Imm>()) {
-      MS_LOG(EXCEPTION) << "Only supported input an int64 number.";
+      MS_LOG(EXCEPTION) << "The type of inputs of make_range operator only support int64 number.";
     }
     slide->step = GetValue<int64_t>(arg_value);
     slide->start = arg1;
@@ -183,11 +183,12 @@ AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
 
 AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
+  MS_EXCEPTION_IF_NULL(primitive);
   // Inputs: a pointer to an AbstractBase object and a pointer to a Type
   const std::string op_name = primitive->name();
   CheckArgsSize(op_name, args_spec_list, 2);
   AbstractTypePtr abs_type = CheckArg<AbstractType>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(abs_type);
   auto mode_v = abs_type->GetValueTrack();
   MS_EXCEPTION_IF_NULL(mode_v);
   if (!mode_v->isa<Type>()) {
@@ -229,7 +230,7 @@ AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValueP
     int64_t e_value = CheckAxis(primitive->name(), elem, -SizeToLong(x_rank), SizeToLong(x_rank) - 1);
     (void)axis_set.insert(e_value);
   }
+  MS_EXCEPTION_IF_NULL(x_shp_value->cast<ValueTuplePtr>());
   auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value();
   if (x_shp_data.size() < x_rank) {
     MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank;
@@ -254,6 +255,7 @@ AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const
   // this primitive get the index that need to reduce
   // input: x's shape and y's shape, inputs should be tuple
   // output: tuple of x and y 's reduce index, reduce index should be a tuple
+  MS_EXCEPTION_IF_NULL(primitive);
   const std::string op_name = primitive->name();
   const size_t inputs_size = 2;
   CheckArgsSize(op_name, args_spec_list, inputs_size);
@@ -289,6 +291,7 @@ AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const Primitiv
                                  const AbstractBasePtrList &args_spec_list) {
   // Inputs: fn, list1, list2, ...
   MS_EXCEPTION_IF_NULL(engine);
+  MS_EXCEPTION_IF_NULL(primitive);
   if (args_spec_list.size() <= 1) {
     MS_LOG(EXCEPTION) << "List_map requires at least 1 list. while the input size is " << args_spec_list.size() << ".";
   }
@@ -317,11 +320,13 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi
                                     const AbstractBasePtrList &args_spec_list) {
   // Inputs: a fn, a list and an object of a subclass of a AbstractBase.
   MS_EXCEPTION_IF_NULL(engine);
+  MS_EXCEPTION_IF_NULL(primitive);
   const std::string op_name = primitive->name();
   const size_t inputs_size = 3;
   CheckArgsSize(op_name, args_spec_list, inputs_size);
   AbstractFunctionPtr fn = CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
   AbstractListPtr lst = CheckArg<AbstractList>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(lst);
   AbstractBasePtr dflt = args_spec_list[2];
 
   AbstractBasePtr list_type = AbstractJoin(lst->elements());
@@ -337,10 +342,11 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi
 AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tuple
+  MS_EXCEPTION_IF_NULL(primitive);
   const std::string op_name = primitive->name();
   CheckArgsSize(op_name, args_spec_list, 1);
   AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(input);
   auto tuple_elements = input->elements();
   AbstractBasePtrList elem_list;
   (void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
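Note: tuple_reversed simply reverses element order, as the std::transform over rbegin()/rend() shows; in Python terms:

# tuple_reversed: (1, 2, 3) -> (3, 2, 1)
assert tuple(reversed((1, 2, 3))) == (3, 2, 1)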
@@ -351,10 +357,12 @@ AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const Primitiv
 AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list) {
   // Inputs: x_shape, axis
+  MS_EXCEPTION_IF_NULL(primitive);
   const std::string op_name = primitive->name();
   constexpr size_t arg_size = 2;
   CheckArgsSize(op_name, args_spec_list, arg_size);
   AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(shape_x);
   MS_EXCEPTION_IF_NULL(args_spec_list[1]);
 
   auto x_shp_value = shape_x->BuildValue();
@@ -391,44 +399,48 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP
 AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
   // Inputs: two tuples.
+  MS_EXCEPTION_IF_NULL(primitive);
   const std::string op_name = primitive->name();
   constexpr size_t arg_size = 2;
   CheckArgsSize(op_name, args_spec_list, arg_size);
   AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
   AbstractTuplePtr div_shp = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
-  MS_LOG(INFO) << "DivShape input:" << shape_x->ToString() << ", div:" << div_shp->ToString();
+  MS_EXCEPTION_IF_NULL(shape_x);
+  MS_EXCEPTION_IF_NULL(div_shp);
+  MS_LOG(INFO) << "The shape of dividend:" << shape_x->ToString() << ", the shape of divisor:" << div_shp->ToString();
 
   auto div_shp_value = div_shp->BuildValue();
   if (div_shp_value->isa<AnyValue>()) {
-    MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[0]->ToString();
+    MS_LOG(EXCEPTION) << "The shape's data field can't be anything: " << args_spec_list[0]->ToString();
   }
 
-  auto shpx_value = shape_x->BuildValue();
-  if (shpx_value->isa<AnyValue>()) {
-    MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[1]->ToString();
+  auto shape_x_value = shape_x->BuildValue();
+  if (shape_x_value->isa<AnyValue>()) {
+    MS_LOG(EXCEPTION) << "The shape's data field can't be anything: " << args_spec_list[1]->ToString();
   }
 
   if (div_shp->size() != shape_x->size()) {
-    MS_LOG(EXCEPTION) << "tileshape elems shape must the same div_shp: " << div_shp->size()
-                      << ", shapex: " << shape_x->size() << ".";
+    MS_LOG(EXCEPTION) << "The size of inputs of tuple_div operator must be same, but the size of divisor tuple is"
+                      << div_shp->size() << ", the size of dividend tuple is " << shape_x->size() << ".";
   }
 
-  auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
-  auto div_shp_data = div_shp_value->cast<ValueTuplePtr>()->value();
+  auto shape_x_data = shape_x_value->cast<ValueTuplePtr>()->value();
+  auto div_shape_data = div_shp_value->cast<ValueTuplePtr>()->value();
   AbstractBasePtrList values;
 
-  for (size_t i = 0; i < div_shp_data.size(); i++) {
-    if (div_shp_data[i]->cast<Int64ImmPtr>() == nullptr) {
+  for (size_t i = 0; i < div_shape_data.size(); i++) {
+    if (div_shape_data[i]->cast<Int64ImmPtr>() == nullptr) {
       MS_LOG(EXCEPTION) << "div_shp_shape data should be an int64 number, but it's " << args_spec_list[1]->ToString();
     }
-    int64_t shapex_value = GetValue<int64_t>(shpx_data[i]);
-    int64_t div_value = GetValue<int64_t>(div_shp_data[i]);
+    int64_t shapex_value = GetValue<int64_t>(shape_x_data[i]);
+    int64_t div_value = GetValue<int64_t>(div_shape_data[i]);
     MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value;
     if (div_value == 0) {
-      MS_LOG(EXCEPTION) << "error: division value should not be 0!";
+      MS_LOG(EXCEPTION) << "The divisor value should not be 0!";
     }
     if ((shapex_value % div_value) != 0) {
-      MS_LOG(EXCEPTION) << "div_shp_shape data shapex must div int64_t:" << shapex_value << " div_value: " << div_value;
+      MS_LOG(EXCEPTION) << "The inputs of tuple_div is not divisible, the dividend is :" << shapex_value
+                        << ", the divisor is: " << div_value << ".";
     }
 
     int64_t result = shapex_value / div_value;
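Note: tuple_div divides two shape tuples elementwise, rejecting size mismatches, zero divisors and non-divisible pairs. A small Python model of the logic above (names illustrative):

def tuple_div(dividend, divisor):
    if len(dividend) != len(divisor):
        raise ValueError("The size of inputs of tuple_div operator must be same, but the size "
                         "of divisor tuple is %d, the size of dividend tuple is %d."
                         % (len(divisor), len(dividend)))
    result = []
    for a, b in zip(dividend, divisor):
        if b == 0:
            raise ValueError("The divisor value should not be 0!")
        if a % b != 0:
            raise ValueError("The inputs of tuple_div is not divisible, the dividend is :%d, "
                             "the divisor is: %d." % (a, b))
        result.append(a // b)
    return tuple(result)

assert tuple_div((8, 12), (2, 3)) == (4, 4)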
@@ -445,13 +457,13 @@ AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitiveP
   const std::string op_name = primitive->name();
   CheckArgsSize(op_name, args_spec_list, 1);
   AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(input);
   py::tuple data_tuple = ValueToPyData(input->BuildValue());
   py::array data = py::array(data_tuple);
   auto tensor = tensor::TensorPy::MakeTensor(data);
   auto ret = tensor->ToAbstract();
   ret->set_value(tensor);
-  MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString();
+  MS_LOG(DEBUG) << "Tuple2array result AbstractTensor: " << ret->ToString();
   return ret;
 }
@@ -465,7 +477,7 @@ AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr
 
   auto shpx_value = shape_x->BuildValue();
   if (shpx_value->isa<AnyValue>()) {
-    MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << shape_x->ToString();
+    MS_LOG(EXCEPTION) << "The shape's data field can't be anything: " << shape_x->ToString();
   }
 
   auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
@@ -477,7 +489,7 @@ AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr
   }
 
   auto result_v = MakeValue(result);
-  MS_LOG(DEBUG) << "shape mul result:" << result_v->ToString();
+  MS_LOG(DEBUG) << "The result of shape_mul:" << result_v->ToString();
   return std::make_shared<AbstractScalar>(result_v, result_v->type());
 }
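Note: shape_mul reduces a shape tuple to its total element count; the accumulation loop elided from this view multiplies the dimensions. In Python:

from functools import reduce

def shape_mul(shape):
    # Product of all dimensions, e.g. (2, 3, 4) -> 24.
    return reduce(lambda acc, dim: acc * dim, shape, 1)

assert shape_mul((2, 3, 4)) == 24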
@@ -523,8 +535,8 @@ AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr
         ValuePtr scalar_index = MakeValue(static_cast<int64_t>(scalar_value->cast<BoolImmPtr>()->value()));
         slice_args.push_back(scalar_index->ToAbstract());
       } else {
-        MS_EXCEPTION(TypeError) << "MakeSlice eval " << index
-                                << " the input scalar type should be int or bool, but got " << scalar_value->ToString();
+        MS_EXCEPTION(TypeError) << "The " << index << "th input of scalar should be int or bool, but got "
+                                << scalar_value->ToString();
       }
     } else if (args_spec_list[index]->isa<AbstractTensor>()) {
       auto arg = args_spec_list[index]->cast<AbstractTensorPtr>();
@@ -552,7 +564,7 @@ AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr
       slice_args.push_back(args_spec_list[index]);
     }
   } else {
-    MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " inputs should scalar, None or Tensor, but got"
+    MS_EXCEPTION(TypeError) << "The " << index << "th input of MakeSlice should be scalar, none or tensor, but got"
                             << args_spec_list[index]->ToString();
   }
 }
@@ -563,19 +575,19 @@ AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr
 AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &,
                                    const AbstractBasePtrList &args_spec_list) {
   if (args_spec_list.empty()) {
-    MS_LOG(EXCEPTION) << "Cannot make range from empty input.";
+    MS_LOG(EXCEPTION) << "The inputs of make_range operator could not be empty.";
   }
 
   constexpr size_t max_args_size = 3;
   if (args_spec_list.size() > max_args_size) {
-    MS_LOG(EXCEPTION) << "Error args size of make range operational.";
+    MS_LOG(EXCEPTION) << "The size of inputs of make_range operator could not exceed 3.";
   }
 
   SlideInfo slide = {0, 1, 0};
   CalcSlidePara(args_spec_list, &slide);
 
   if (slide.step == 0) {
-    MS_LOG(EXCEPTION) << "Error, step value is 0.";
+    MS_LOG(EXCEPTION) << "The step value of make_range operator could not be 0.";
   }
 
   AbstractBasePtrList args;
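Note: make_range appears to follow Python's range() convention: one argument is the stop, two are start/stop, three add the step, with SlideInfo starting as {start = 0, step = 1, stop = 0} and CalcSlidePara filling it in. A Python model under that assumption, collecting the reworded checks in one place:

def make_range(*args):
    if not args:
        raise ValueError("The inputs of make_range operator could not be empty.")
    if len(args) > 3:
        raise ValueError("The size of inputs of make_range operator could not exceed 3.")
    for a in args:
        if not isinstance(a, int):
            raise TypeError("The type of inputs of make_range operator only support int64 number.")
    start, step = 0, 1
    if len(args) == 1:
        (stop,) = args
    elif len(args) == 2:
        start, stop = args
    else:
        start, stop, step = args
    if step == 0:
        raise ValueError("The step value of make_range operator could not be 0.")
    return range(start, stop, step)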
@@ -25,7 +25,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 class Net(nn.Cell):
     def __init__(self):
         super(Net, self).__init__()
-        self.args = _inner_ops.BroadcastGradientArgs()
+        self.args = _inner_ops.DynamicBroadcastGradientArgs()
 
     def construct(self, s0, s1):
         return self.args(s0, s1)
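Note: for context, DynamicBroadcastGradientArgs returns, for each input shape, the axes that were broadcast and therefore must be reduced in the gradient. A pure-Python reference model of those semantics (the op's exact output container type is not shown in this view):

def broadcast_gradient_args(s0, s1):
    # Right-align the shapes, pad with 1s, and collect the axes where an
    # input has size 1 but the broadcast output does not.
    rank = max(len(s0), len(s1))
    p0 = (1,) * (rank - len(s0)) + tuple(s0)
    p1 = (1,) * (rank - len(s1)) + tuple(s1)
    out = tuple(max(a, b) for a, b in zip(p0, p1))
    r0 = [i for i, (a, o) in enumerate(zip(p0, out)) if a == 1 and o != 1]
    r1 = [i for i, (b, o) in enumerate(zip(p1, out)) if b == 1 and o != 1]
    return r0, r1

assert broadcast_gradient_args((8, 1, 6, 1), (7, 1, 5)) == ([1, 3], [0, 2])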
@@ -0,0 +1,189 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import pytest
+import numpy as np
+
+from mindspore import Tensor
+from mindspore.common.api import ms_function
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from mindspore.nn import Cell
+
+add = P.Add()
+hyper_map = C.HyperMap()
+
+
+@ms_function
+def main_noleaf(x, y):
+    return hyper_map(add, x, y)
+
+
+def test_hypermap_noleaf_tuple_list_mix():
+    """
+    Feature: Check the types of inputs of HyperMap.
+    Description: The types of all inputs of HyperMap must be the same.
+    Expectation: Raise an exception reporting that HyperMap cannot match up the input types.
+    """
+    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
+    tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
+    with pytest.raises(Exception, match="HyperMap cannot match up all input types of arguments."):
+        main_noleaf((tensor1, 1), [tensor2, 2])
+
+
+def test_hypermap_noleaf_tuple_length():
+    """
+    Feature: Check the length of tuple inputs of HyperMap.
+    Description: The tuple inputs of HyperMap must have the same length.
+    Expectation: Raise an exception reporting that tuples in HyperMap should have the same length.
+    """
+    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
+    tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
+    with pytest.raises(Exception, match="Tuple in HyperMap should have same length."):
+        main_noleaf((tensor1, 1), (tensor2, 2, 2))
+
+
+def test_hypermap_noleaf_list_length():
+    """
+    Feature: Check the length of list inputs of HyperMap.
+    Description: The list inputs of HyperMap must have the same length.
+    Expectation: Raise an exception reporting that lists in HyperMap should have the same length.
+    """
+    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
+    tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
+    with pytest.raises(Exception, match="List in HyperMap should have same length."):
+        main_noleaf([tensor1], [tensor2, tensor2])
+
+
+def test_hypermap_noleaf_list_tuple():
+    """
+    Feature: Check the types of inputs of HyperMap.
+    Description: The types of all inputs of HyperMap must be the same.
+    Expectation: Raise an exception reporting that HyperMap cannot match up the input types.
+    """
+    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
+    tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
+    with pytest.raises(Exception, match="HyperMap cannot match up all input types of arguments."):
+        main_noleaf([tensor1], (tensor2, tensor2))
+
+
+def test_tuple_slice_stop_index():
+    """
+    Feature: Check the type of the stop index of a slice.
+    Description: The stop index of a slice must be a scalar, None or Tensor.
+    Expectation: Raise a TypeError when the stop index is a string.
+    """
+    class TupleSliceNet(Cell):
+        def __init__(self):
+            super(TupleSliceNet, self).__init__()
+            self.addN = P.AddN()
+            self.index_0 = Tensor(3)
+
+        def construct(self, tensor_tuple):
+            tensor_tuple_slice0 = tensor_tuple[:]
+            tensor_tuple_slice1 = tensor_tuple[self.index_0:"str"]  # the stop index should be a scalar or None, not a string
+            sum0 = self.addN(tensor_tuple_slice0)
+            sum1 = self.addN(tensor_tuple_slice1)
+            ret = sum0 + sum1
+            return ret
+
+    data = (Tensor(np.ones([2, 3, 4], np.int32)),
+            Tensor(np.zeros([2, 3, 4], np.int32)),
+            Tensor(np.ones([2, 3, 4], np.int32)),
+            Tensor(np.ones([2, 3, 4], np.int32)),
+            Tensor(np.zeros([2, 3, 4], np.int32)),
+            Tensor(np.ones([2, 3, 4], np.int32)))
+
+    net = TupleSliceNet()
+    with pytest.raises(Exception, match="The 1th input of MakeSlice should be scalar, none or tensor, but got str"):
+        output = net(data)
+        print("output:", output)
+
+
+def test_tuple_slice_start_index():
+    """
+    Feature: Check the type of the start index of a slice.
+    Description: The start index of a slice must be a scalar, None or Tensor.
+    Expectation: Raise a TypeError when the start index is a string.
+    """
+    class TupleSliceNet(Cell):
+        def __init__(self):
+            super(TupleSliceNet, self).__init__()
+            self.addN = P.AddN()
+            self.index_0 = Tensor(3)
+            self.index_1 = Tensor([5])
+            self.index_3 = Tensor([True])
+
+        def construct(self, tensor_tuple):
+            tensor_tuple_slice0 = tensor_tuple[:]
+            tensor_tuple_slice1 = tensor_tuple["str":self.index_0]
+            tensor_tuple_slice2 = tensor_tuple[self.index_3:]
+            tensor_tuple_slice3 = tensor_tuple[2:self.index_1:]
+            sum0 = self.addN(tensor_tuple_slice0)
+            sum1 = self.addN(tensor_tuple_slice1)
+            sum2 = self.addN(tensor_tuple_slice2)
+            sum3 = self.addN(tensor_tuple_slice3)
+            ret = sum0 + sum1 + sum2 + sum3
+            return ret
+
+    data = (Tensor(np.ones([2, 3, 4], np.int32)),
+            Tensor(np.zeros([2, 3, 4], np.int32)),
+            Tensor(np.ones([2, 3, 4], np.int32)),
+            Tensor(np.ones([2, 3, 4], np.int32)),
+            Tensor(np.zeros([2, 3, 4], np.int32)),
+            Tensor(np.ones([2, 3, 4], np.int32)))
+
+    net = TupleSliceNet()
+    with pytest.raises(Exception, match="The 0th input of MakeSlice should be scalar, none or tensor, but got str"):
+        output = net(data)
+        print("output:", output)
+
+
+def test_tuple_slice_step():
+    """
+    Feature: Check the step value of a slice.
+    Description: The step value of a slice must not be 0.
+    Expectation: Raise an exception reporting that the step value could not be 0.
+    """
+    class TupleSliceNet(Cell):
+        def __init__(self):
+            super(TupleSliceNet, self).__init__()
+            self.addN = P.AddN()
+            self.index_0 = Tensor(3)
+            self.index_1 = Tensor([5])
+            self.index_3 = Tensor([True])
+
+        def construct(self, tensor_tuple):
+            tensor_tuple_slice0 = tensor_tuple[:]
+            tensor_tuple_slice1 = tensor_tuple[:self.index_0]
+            tensor_tuple_slice2 = tensor_tuple[self.index_3:]
+            tensor_tuple_slice3 = tensor_tuple[2:self.index_1:0]
+            sum0 = self.addN(tensor_tuple_slice0)
+            sum1 = self.addN(tensor_tuple_slice1)
+            sum2 = self.addN(tensor_tuple_slice2)
+            sum3 = self.addN(tensor_tuple_slice3)
+            ret = sum0 + sum1 + sum2 + sum3
+            return ret
+
+    data = (Tensor(np.ones([2, 3, 4], np.int32)),
+            Tensor(np.zeros([2, 3, 4], np.int32)),
+            Tensor(np.ones([2, 3, 4], np.int32)),
+            Tensor(np.ones([2, 3, 4], np.int32)),
+            Tensor(np.zeros([2, 3, 4], np.int32)),
+            Tensor(np.ones([2, 3, 4], np.int32)))
+
+    net = TupleSliceNet()
+    with pytest.raises(Exception, match="TupleSlice require the step value could not be 0, but got 0."):
+        output = net(data)
+        print("output:", output)
@@ -0,0 +1,170 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import pytest
+import numpy as np
+
+from mindspore import Tensor, nn, Parameter
+from mindspore.nn import Cell
+import mindspore as ms
+
+
+def test_map_args_size():
+    """
+    Feature: Check the size of inputs of map.
+    Description: The size of inputs of map must be greater than 1.
+    Expectation: Raise an exception reporting that map must have at least two arguments.
+    """
+    class MapNet(Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = nn.ReLU()
+
+        def mul(self, x=2, y=4):
+            return x * y
+
+        def construct(self, x):
+            if map(self.mul) == 8:
+                x = self.relu(x)
+            return x
+
+    input_np_x = np.random.randn(2, 3, 4, 5).astype(np.float32)
+    input_me_x = Tensor(input_np_x)
+
+    net = MapNet()
+    with pytest.raises(Exception, match="The map operator must have at least two arguments."):
+        ret = net(input_me_x)
+        print("ret:", ret)
+
+
+def test_map_args_type():
+    """
+    Feature: Check the type of inputs of Map().
+    Description: The type of inputs of Map() must be list, tuple or class.
+    Expectation: Raise an exception reporting that Map can only be applied to list, tuple and class.
+    """
+    class MapNet(Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = nn.ReLU()
+
+        def mul(self, x=2, y=4):
+            return x * y
+
+        def construct(self, x):
+            if map(self.mul, 3, 4) == 8:
+                x = self.relu(x)
+            return x
+
+    input_np_x = np.random.randn(2, 3, 4, 5).astype(np.float32)
+    input_me_x = Tensor(input_np_x)
+
+    net = MapNet()
+    with pytest.raises(Exception, match="Map can only be applied to list, tuple and class"):
+        ret = net(input_me_x)
+        print("ret:", ret)
+
+
+def test_map_args_full_make_list():
+    """
+    Feature: Check the types of all inputs of Map.
+    Description: The types of all inputs of Map must be the same.
+    Expectation: Raise an exception reporting that Map cannot match up the input types.
+    """
+    class MapNet(Cell):
+        def mul(self, x=2, y=4):
+            return x * y
+
+        def construct(self, x, y):
+            if map(self.mul, x, y) == [8]:
+                x = y
+            return x
+
+    input_me_x = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    input_me_y = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+
+    net = MapNet()
+    with pytest.raises(Exception, match="Map cannot match up all input types of arguments."):
+        ret = net([input_me_x], (input_me_y))
+        print("ret:", ret)
+
+
+def test_map_args_full_make_list_same_length():
+    """
+    Feature: Check the length of list inputs of Map.
+    Description: The lists in Map should have the same length.
+    Expectation: Raise an exception reporting that lists in Map should have the same length.
+    """
+    class MapNet(Cell):
+        def mul(self, x=2, y=4):
+            return x * y
+
+        def construct(self, x, y):
+            if map(self.mul, x, y) == [8]:
+                x = y
+            return x
+
+    input_me_x = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    input_me_y = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+
+    net = MapNet()
+    with pytest.raises(Exception, match="List in Map should have same length."):
+        ret = net([input_me_x], [input_me_y, input_me_y])
+        print("ret:", ret)
+
+
+def test_map_args_full_make_tuple_same_length():
+    """
+    Feature: Check the length of tuple inputs of Map.
+    Description: The tuples in Map should have the same length.
+    Expectation: Raise an exception reporting that tuples in Map should have the same length.
+    """
+    class MapNet(Cell):
+        def mul(self, x=2, y=4):
+            return x * y
+
+        def construct(self, x, y):
+            if map(self.mul, x, y) == [8]:
+                x = y
+            return x
+
+    input_me_x = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    input_me_y = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+
+    net = MapNet()
+    with pytest.raises(Exception, match="Tuple in Map should have same length."):
+        ret = net((input_me_x, input_me_x), (input_me_y, input_me_y, input_me_y))
+        print("ret:", ret)
+
+
+def test_map_param_cast():
+    """
+    Feature: Check the ref type when inserting an automatic cast.
+    Description: Assigning a float64 tensor to a float32 parameter cannot be cast automatically.
+    Expectation: Raise an exception reporting the writable argument type mismatch in 'S-Prim-Assign'.
+    """
+    class MapNet(Cell):
+        def __init__(self):
+            super().__init__()
+            self.param = Parameter(Tensor(5, ms.float32), name="param_b")
+
+        def construct(self, x):
+            self.param = x
+            return self.param
+
+    input_me_x = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float64))
+
+    net = MapNet()
+    with pytest.raises(Exception, match="In op 'S-Prim-Assign', the type of writable argument is 'float32'"):
+        ret = net(input_me_x)
+        print("ret:", ret)
@ -0,0 +1,240 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import sys
import pytest

from mindspore import Tensor, context, Parameter
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn import Cell
import mindspore as ms


def test_inner_scalar_divisor():
    """
    Feature: Check whether the divisor of an inner scalar operation is zero.
    Description: The divisor of an inner scalar division must not be zero.
    Expectation: A proper exception is raised when dividing by zero.
    """
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.param_a = Parameter(Tensor(5, ms.int32), name="param_a")
            self.param_b = Parameter(Tensor(5, ms.int32), name="param_b")

        def construct(self, x):
            return x + self.param_a + 5 / 0

    context.set_context(device_target="GPU")
    x = Tensor(2, dtype=ms.int32)
    net = Net()
    with pytest.raises(Exception, match="The divisor could not be zero."):
        ret = net(x)
        print("ret:", ret)
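

# Editorial note: pytest applies the `match` argument with re.search, so the message
# strings in these tests are regular expressions; the trailing "." in
# "The divisor could not be zero." is a regex wildcard that also matches the literal
# period. A tiny self-contained illustration (uses only pytest, imported above):
def _demo_pytest_match():
    with pytest.raises(ZeroDivisionError, match="division by zero"):
        _ = 1 / 0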


def test_inner_scalar_mod():
    """
    Feature: Check the inputs of the inner scalar mod operation.
    Description: The divisor of inner scalar mod must not be zero.
    Expectation: A proper exception is raised when taking mod by zero.
    """
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.param_a = Parameter(Tensor(5, ms.int32), name="param_a")

        def construct(self, x):
            return x + self.param_a + 5 % 0

    x = Tensor(2, dtype=ms.int32)
    net = Net()
    with pytest.raises(Exception, match="Could not mod to zero."):
        ret = net(x)
        print("ret:", ret)


def test_inner_scalar_mod_args_length():
    """
    Feature: Check the number of inputs of the inner scalar mod operation.
    Description: The input list of inner scalar mod should contain at least two values.
    Expectation: A proper exception is raised when only one input is given.
    """
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.param_a = Parameter(Tensor(5, ms.int32), name="param_a")
            self.mod = P.Mod()

        def construct(self, x):
            return x + self.param_a + self.mod(5)

    x = Tensor(2, dtype=ms.int32)
    net = Net()
    with pytest.raises(Exception, match="Function S-Prim-Mod's input length is not equal to Signature length."):
        ret = net(x)
        print("ret:", ret)


def test_make_range_input_is_empty():
    """
    Feature: Check the number of inputs of the make_range operator.
    Description: The inputs of the make_range operator could not be empty.
    Expectation: A proper exception is raised when make_range is called without inputs.
    """
    class Net(Cell):
        def construct(self, x, y):
            for _ in F.make_range():
                x += y
            return x

    x = Tensor(2, dtype=ms.int32)
    y = Tensor(4, dtype=ms.int32)
    net = Net()
    with pytest.raises(Exception, match="The inputs of make_range operator could not be empty."):
        ret = net(x, y)
        print("ret:", ret)


def test_make_range_input_type():
    """
    Feature: Check the type of inputs of the make_range operator.
    Description: The inputs of the make_range operator must be int64 numbers.
    Expectation: A proper exception is raised for a float input.
    """
    class Net(Cell):
        def construct(self, x, y):
            for _ in F.make_range(0, 0.02):
                x += y
            return x

    x = Tensor(2, dtype=ms.int32)
    y = Tensor(4, dtype=ms.int32)
    net = Net()
    with pytest.raises(Exception, match="The type of inputs of make_range operator only support int64 number."):
        ret = net(x, y)
        print("ret:", ret)


def test_make_range_input_size():
    """
    Feature: Check the number of inputs of the make_range operator.
    Description: The number of inputs of the make_range operator could not exceed 3.
    Expectation: A proper exception is raised when four inputs are given.
    """
    class Net(Cell):
        def construct(self, x, y):
            for _ in F.make_range(1, 2, 3, 4):
                x += y
            return x

    x = Tensor(2, dtype=ms.int32)
    y = Tensor(4, dtype=ms.int32)
    net = Net()
    with pytest.raises(Exception, match="The size of inputs of make_range operator could not exceed 3."):
        ret = net(x, y)
        print("ret:", ret)


def test_make_range_overflow():
    """
    Feature: Check the cycle number of the make_range operator.
    Description: The required cycle number of make_range must not exceed the max cycle number.
    Expectation: A proper exception is raised when the range would overflow the cycle limit.
    """
    class Net(Cell):
        def construct(self, x, y):
            max_index = sys.maxsize
            for _ in F.make_range(max_index - 1, max_index, 3):
                x += y
            return x

    x = Tensor(2, dtype=ms.int32)
    y = Tensor(4, dtype=ms.int32)
    net = Net()
    with pytest.raises(Exception, match="For make range, the required cycles number is greater than max cycles number"):
        ret = net(x, y)
        print("ret:", ret)
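

# Editorial sketch: a plain-Python approximation of the checks the four make_range
# tests above exercise, inferred from the matched error messages. `make_range_sketch`
# and `max_cycles` are hypothetical names; the real checks live in the graph frontend.
def make_range_sketch(*args, max_cycles=1000000):
    if not args:
        raise ValueError("The inputs of make_range operator could not be empty.")
    if len(args) > 3:
        raise ValueError("The size of inputs of make_range operator could not exceed 3.")
    if not all(isinstance(a, int) for a in args):
        raise TypeError("The type of inputs of make_range operator only support int64 number.")
    start, stop, step = {1: (0, args[0], 1), 2: (args[0], args[1], 1), 3: args}[len(args)]
    # The loop is unrolled at compile time, so cap the number of required cycles.
    cycles = max(0, -(-(stop - start) // step)) if step > 0 else 0
    if cycles > max_cycles:
        raise RuntimeError("For make range, the required cycles number is greater than max cycles number")
    return range(start, stop, step)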


def test_typeof():
    """
    Feature: Check the number of inputs of the typeof operator.
    Description: The typeof operator takes exactly one input.
    Expectation: A proper exception is raised when two inputs are given.
    """
    class Net(Cell):
        def construct(self, x):
            return F.typeof(x, x)

    x = Tensor([2, 3, 4, 5], dtype=ms.int32)
    net = Net()
    with pytest.raises(Exception, match="Typeof evaluator requires 1 parameter, while the input size is 2."):
        ret = net(x)
        print("ret:", ret)


def test_tuple_div():
    """
    Feature: Check the size of inputs of the tuple_div operator.
    Description: The two tuple inputs of tuple_div must have the same size.
    Expectation: A proper exception is raised when the sizes differ.
    """
    class Net(Cell):
        def construct(self, x, y):
            return F.tuple_div(x, y)

    x = (8, 14, 20)
    y = (2, 2)
    net = Net()
    with pytest.raises(Exception, match="The size of inputs of tuple_div operator must be same"):
        ret = net(x, y)
        print("ret:", ret)


def test_tuple_div_input_is_not_divisible():
    """
    Feature: Check whether the inputs of tuple_div are divisible.
    Description: Each element of the first tuple must be divisible by the corresponding element of the second.
    Expectation: A proper exception is raised when an element is not divisible.
    """
    class Net(Cell):
        def construct(self, x, y):
            return F.tuple_div(x, y)

    x = (8, 14)
    y = (2, 3)
    net = Net()
    with pytest.raises(Exception, match="The inputs of tuple_div is not divisible"):
        ret = net(x, y)
        print("ret:", ret)
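

# Editorial sketch: a plain-Python approximation of the two tuple_div checks above;
# `tuple_div_sketch` is a hypothetical helper mirroring the matched error messages.
def tuple_div_sketch(x, y):
    # Both tuples must have the same size, and division must be exact elementwise.
    if len(x) != len(y):
        raise ValueError("The size of inputs of tuple_div operator must be same")
    if any(a % b != 0 for a, b in zip(x, y)):
        raise ValueError("The inputs of tuple_div is not divisible")
    return tuple(a // b for a, b in zip(x, y))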


def test_make_slice_scalar():
    """
    Feature: Check whether a scalar input of make_slice is int or bool.
    Description: A scalar input of make_slice must be int or bool.
    Expectation: A proper exception is raised for a float scalar.
    """
    class Net(Cell):
        def construct(self, data):
            return data[F.make_slice(1.01, None, None)]

    x = Tensor((8, 10, 12), dtype=ms.int32)
    net = Net()
    with pytest.raises(Exception, match="The 0th input of scalar should be int or bool"):
        ret = net(x)
        print("ret:", ret)
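

# Editorial sketch: a plain-Python analogue of the scalar check the test above
# exercises, assuming make_slice validates each scalar component before building the
# slice; `make_slice_sketch` is a hypothetical helper, not MindSpore API.
def make_slice_sketch(start, stop, step):
    for i, item in enumerate((start, stop, step)):
        if item is not None and not isinstance(item, (int, bool)):
            raise TypeError(f"The {i}th input of scalar should be int or bool")
    return slice(start, stop, step)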
@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
import numpy as np

from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.nn import Cell
import mindspore as ms


def test_zip_operation_args_size():
    """
    Feature: Check the number of inputs of ZipOperation.
    Description: The inputs of ZipOperation must not be empty.
    Expectation: A proper exception is raised when zip is called without inputs.
    """
    class AssignInZipLoop(Cell):
        def __init__(self):
            super().__init__()
            self.conv1 = ms.nn.Conv2d(3, 2, 1, weight_init="zero")
            self.conv2 = ms.nn.Conv2d(3, 2, 1, weight_init="zero")
            self.params1 = self.conv1.trainable_params()
            self.params2 = self.conv2.trainable_params()

        def construct(self, x):
            for p1, p2 in zip():
                P.Assign()(p2, p1 + x)

            out = 0
            for p1, p2 in zip(self.params1, self.params2):
                out = p1 + p2

            return out

    x = Tensor.from_numpy(np.ones([1], np.float32))
    net = AssignInZipLoop()
    with pytest.raises(Exception, match="For 'zip', there is at least one input."):
        out = net(x)
        assert np.all(out.asnumpy() == 1)


def test_zip_operation_args_type():
    """
    Feature: Check the type of inputs of ZipOperation.
    Description: Check whether all inputs of zip are sequences.
    Expectation: A proper exception is raised when a non-sequence input is given.
    """
    class AssignInZipLoop(Cell):
        def __init__(self):
            super().__init__()
            self.conv1 = ms.nn.Conv2d(3, 2, 1, weight_init="zero")
            self.conv2 = ms.nn.Conv2d(3, 2, 1, weight_init="zero")
            self.params1 = self.conv1.trainable_params()
            self.params2 = self.conv2.trainable_params()
            self.param = Parameter(Tensor(5, ms.float32), name="param")

        def construct(self, x):
            for p1, p2 in zip(self.params1, self.params2, self.param):
                P.Assign()(p2, p1 + x)

            out = 0
            for p1, p2 in zip(self.params1, self.params2):
                out = p1 + p2

            return out

    x = Tensor.from_numpy(np.ones([1], np.float32))
    net = AssignInZipLoop()
    with pytest.raises(Exception, match="For 'zip', all inputs must be sequence."):
        out = net(x)
        assert np.all(out.asnumpy() == 1)
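

# Editorial sketch: a plain-Python approximation of the two ZipOperation checks above;
# `checked_zip` is a hypothetical helper that only mirrors the matched error messages.
def checked_zip(*sequences):
    if not sequences:
        raise ValueError("For 'zip', there is at least one input.")
    if not all(isinstance(s, (list, tuple)) for s in sequences):
        raise TypeError("For 'zip', all inputs must be sequence.")
    return zip(*sequences)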
@ -191,7 +191,7 @@ TEST_F(TestImplementations, ScalarDivTest) {
       ScalarDiv(list);
       FAIL();
     } catch (std::runtime_error const &err) {
-      ASSERT_TRUE(std::string(err.what()).find("Divisor could not be zero") != std::string::npos);
+      ASSERT_TRUE(std::string(err.what()).find("The divisor could not be zero.") != std::string::npos);
     }
     list.clear();