forked from mindspore-Ecosystem/mindspore
!46418 Add sequence index and sequence mul operation and fix some infer function.
Merge pull request !46418 from LiangZhibo/list_ops
This commit is contained in:
commit
fe7acc0aa3
|
@ -126,6 +126,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext
|
||||
{"__bool__", std::string("tuple_bool")}, // C.tuple_bool
|
||||
{"count", prim::kPrimSequenceCount}, // P.sequence_count
|
||||
{"index", prim::kPrimSequenceIndex}, // P.sequence_index
|
||||
}},
|
||||
{kObjectTypeList,
|
||||
{
|
||||
|
@ -143,6 +144,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"reverse", std::string("list_reverse")}, // C.list_reverse
|
||||
{"extend", std::string("list_extend")}, // C.list_extend
|
||||
{"count", prim::kPrimSequenceCount}, // P.sequence_count
|
||||
{"index", prim::kPrimSequenceIndex}, // P.sequence_index
|
||||
}},
|
||||
{kObjectTypeDictionary,
|
||||
{
|
||||
|
|
|
@ -381,13 +381,9 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type
|
|||
if ((base_shape->isa<Shape>())) {
|
||||
auto shape = base_shape->cast<ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
auto shape_vec = shape->shape();
|
||||
// if the size of shape list is empty, return an scalar abstract
|
||||
if (shape_vec.empty() && (!type->isa<TensorType>())) {
|
||||
abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, type);
|
||||
return abs_scalar;
|
||||
}
|
||||
return MakeAbstractTensor(shape, type);
|
||||
} else if (base_shape->isa<NoShape>() && type->isa<Number>()) {
|
||||
return std::make_shared<abstract::AbstractScalar>(kAnyValue, type);
|
||||
} else if (base_shape->isa<TupleShape>() && type->isa<Tuple>()) {
|
||||
auto shape_tuple = base_shape->cast_ptr<TupleShape>();
|
||||
auto type_tuple = type->cast_ptr<Tuple>();
|
||||
|
@ -416,7 +412,8 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type
|
|||
// Return monad abstract if it is monad type.
|
||||
return MakeMonadAbstract(type->cast<MonadTypePtr>());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Evaluator return invalid shape " << base_shape->ToString() << "or type. " << type->ToString();
|
||||
MS_LOG(EXCEPTION) << "Evaluator return invalid shape " << base_shape->ToString() << " or type. "
|
||||
<< type->ToString();
|
||||
}
|
||||
}
|
||||
} // namespace abstract
|
||||
|
|
|
@ -252,6 +252,8 @@ constexpr auto kSearchSorted = "SearchSorted";
|
|||
constexpr auto kListAppend = "ListAppend";
|
||||
constexpr auto kSequenceAdd = "SequenceAdd";
|
||||
constexpr auto kSequenceCount = "SequenceCount";
|
||||
constexpr auto kSequenceIndex = "SequenceIndex";
|
||||
constexpr auto kSequenceMul = "SequenceMul";
|
||||
|
||||
// NN
|
||||
constexpr auto kFractionalMaxPoolWithFixedKsize = "FractionalMaxPoolWithFixedKsize";
|
||||
|
@ -1555,6 +1557,8 @@ GVAR_DEF(PrimitivePtr, kPrimSequenceLen, std::make_shared<Primitive>("sequence_l
|
|||
GVAR_DEF(PrimitivePtr, kPrimListAppend, std::make_shared<Primitive>(kListAppend));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSequenceAdd, std::make_shared<Primitive>(kSequenceAdd));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSequenceCount, std::make_shared<Primitive>(kSequenceCount));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSequenceIndex, std::make_shared<Primitive>(kSequenceIndex));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSequenceMul, std::make_shared<Primitive>(kSequenceMul));
|
||||
|
||||
// Other miscellaneous
|
||||
GVAR_DEF(PrimitivePtr, kPrimSampleDistortedBoundingBoxV2, std::make_shared<Primitive>(kSampleDistortedBoundingBoxV2));
|
||||
|
|
|
@ -49,8 +49,7 @@ AbstractBasePtr CheckAndGetElementType(const abstract::AbstractSequencePtr input
|
|||
return elements[0];
|
||||
}
|
||||
|
||||
AbstractBasePtr SequenceAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
AbstractBasePtr SequenceAddInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
constexpr size_t input_len = 2;
|
||||
|
@ -109,6 +108,22 @@ AbstractBasePtr SequenceAddInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
return input_2->Clone();
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(SequenceAdd, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(SequenceAdd, prim::kPrimSequenceAdd, SequenceAddInfer, nullptr, true);
|
||||
class SequenceAddInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceAddInferInner(primitive, input_args)->BuildShape();
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceAddInferInner(prim, input_args)->BuildType();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceAddInferInner(primitive, input_args);
|
||||
}
|
||||
};
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceAdd, prim::kPrimSequenceAdd, SequenceAddInfer, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,61 +26,52 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
AbstractBasePtr SequenceCountInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractScalar>(kAnyValue, kInt64);
|
||||
}
|
||||
|
||||
bool ComparesTwoValues(const ValuePtr &value_1, const ValuePtr &value_2) {
|
||||
MS_EXCEPTION_IF_NULL(value_1);
|
||||
MS_EXCEPTION_IF_NULL(value_2);
|
||||
if (!value_1->IsSameTypeId(value_2->tid())) {
|
||||
return false;
|
||||
}
|
||||
if (value_1->isa<tensor::Tensor>()) {
|
||||
auto list_tensor_value = value_2->cast_ptr<tensor::Tensor>();
|
||||
MS_EXCEPTION_IF_NULL(list_tensor_value);
|
||||
return value_1->cast_ptr<tensor::Tensor>()->ValueEqual(*list_tensor_value);
|
||||
}
|
||||
return *value_1 == *value_2;
|
||||
}
|
||||
|
||||
// Constant-folds SequenceCount: when both the sequence and the target are
// compile-time constants, returns the number of matching elements as an
// int64 Value; otherwise returns nullptr so the count is computed at runtime.
// Raises TypeError when the first input is not a tuple or list.
ValuePtr SequenceCountInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  constexpr size_t expected_input_num = 2;
  constexpr size_t seq_index = 0;
  constexpr size_t target_index = 1;
  auto prim_name = prim->name();
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, expected_input_num, prim_name);
  for (const auto &arg : input_args) {
    MS_EXCEPTION_IF_NULL(arg);
  }
  const auto &seq_arg = input_args[seq_index];
  const auto &target_arg = input_args[target_index];
  if (!seq_arg->isa<abstract::AbstractSequence>()) {
    MS_EXCEPTION(TypeError) << "For primitive '" << prim_name << "', the first input must be a list or tuple, "
                            << "but got: " << seq_arg->ToString();
  }
  auto sequence = seq_arg->cast<abstract::AbstractSequencePtr>();
  // A dynamic-length sequence has no known elements to count at compile time.
  if (sequence->dynamic_len()) {
    return nullptr;
  }
  auto target_value = target_arg->BuildValue();
  // Folding is only possible when both inputs are fully constant.
  if (sequence->BuildValue() == kAnyValue || target_value == kAnyValue) {
    return nullptr;
  }
  int64_t match_count = 0;
  for (const auto &element : sequence->elements()) {
    if (ComparesTwoValues(target_value, element->BuildValue())) {
      ++match_count;
    }
  }
  return MakeValue(match_count);
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SequenceCount, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(SequenceCount, prim::kPrimSequenceCount, SequenceCountInfer, SequenceCountInferValue,
|
||||
true);
|
||||
class SequenceCountInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return abstract::kNoShape;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return kInt64;
|
||||
}
|
||||
|
||||
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
constexpr size_t input_num = 2;
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
constexpr size_t seq_index = 0;
|
||||
constexpr size_t target_index = 1;
|
||||
auto input_abs = input_args[seq_index];
|
||||
auto target_abs = input_args[target_index];
|
||||
if (!input_abs->isa<abstract::AbstractSequence>()) {
|
||||
MS_EXCEPTION(TypeError) << "For primitive '" << prim_name << "', the first input must be a list or tuple, "
|
||||
<< "but got: " << input_abs->ToString();
|
||||
}
|
||||
auto seq_abs = input_abs->cast<abstract::AbstractSequencePtr>();
|
||||
if (seq_abs->dynamic_len()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto target_value = target_abs->BuildValue();
|
||||
if (seq_abs->BuildValue() == kAnyValue || target_value == kAnyValue) {
|
||||
return nullptr;
|
||||
}
|
||||
const auto &seq_elements = seq_abs->elements();
|
||||
int64_t count = 0;
|
||||
for (auto element : seq_elements) {
|
||||
if (CheckAndConvertUtils::CheckValueSame(target_value, element->BuildValue())) {
|
||||
++count;
|
||||
}
|
||||
}
|
||||
return MakeValue(count);
|
||||
}
|
||||
};
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceCount, prim::kPrimSequenceCount, SequenceCountInfer, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief Sequence addition operation
|
||||
/// \brief Sequence count operation.
|
||||
class MIND_API SequenceCount : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(SequenceCount);
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/sequence_index.h"
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_OPERATOR_IMPL(SequenceIndex, BaseOperator);
|
||||
class SequenceIndexInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return abstract::kNoShape;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return kInt64;
|
||||
}
|
||||
|
||||
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const size_t input_num = 2;
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
constexpr size_t seq_index = 0;
|
||||
constexpr size_t target_index = 1;
|
||||
auto input_abs = input_args[seq_index];
|
||||
auto target_abs = input_args[target_index];
|
||||
if (!input_abs->isa<abstract::AbstractSequence>()) {
|
||||
MS_EXCEPTION(TypeError) << "For primitive '" << prim_name << "', the first input must be a list or tuple, "
|
||||
<< "but got: " << input_abs->ToString();
|
||||
}
|
||||
auto seq_abs = input_abs->cast<abstract::AbstractSequencePtr>();
|
||||
if (seq_abs->dynamic_len()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto target_value = target_abs->BuildValue();
|
||||
if (seq_abs->BuildValue() == kAnyValue || target_value == kAnyValue) {
|
||||
return nullptr;
|
||||
}
|
||||
const auto &seq_elements = seq_abs->elements();
|
||||
for (size_t i = 0; i < seq_elements.size(); ++i) {
|
||||
auto element = seq_elements[i];
|
||||
if (CheckAndConvertUtils::CheckValueSame(target_value, element->BuildValue())) {
|
||||
return MakeValue(static_cast<int64_t>(i));
|
||||
}
|
||||
}
|
||||
MS_EXCEPTION(ValueError) << target_value->ToString() << " is not in " << seq_abs->ToString();
|
||||
}
|
||||
};
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceIndex, prim::kPrimSequenceIndex, SequenceIndexInfer, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SEQUENCE_INDEX_H_
|
||||
#define MINDSPORE_CORE_OPS_SEQUENCE_INDEX_H_
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief Sequence index operation.
|
||||
class MIND_API SequenceIndex : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(SequenceIndex);
|
||||
/// \brief Constructor.
|
||||
SequenceIndex() : BaseOperator(prim::kSequenceIndex) {}
|
||||
/// \brief Init function.
|
||||
void Init() const {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SEQUENCE_INDEX_H_
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/sequence_mul.h"
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
// Infers the abstract result of SequenceMul (sequence * scalar).
// The output is always a dynamic-length sequence: this primitive is only
// reached when at least one input is a variable, so the output length is
// unknown at compile time.
AbstractBasePtr SequenceMulInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  constexpr size_t input_len = 2;
  constexpr size_t seq_index = 0;
  constexpr size_t scalar_index = 1;
  auto prim_name = primitive->name();
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_len, prim_name);
  const auto &first_abs = input_args[seq_index];
  if (!first_abs->isa<abstract::AbstractSequence>()) {
    MS_EXCEPTION(TypeError) << "For '" << prim_name
                            << "', the first input should be tuple or list but got: " << first_abs->ToString();
  }
  auto seq_abs = first_abs->cast<abstract::AbstractSequencePtr>();
  const auto &scalar_abs = input_args[scalar_index];
  // The multiplier must be a 32- or 64-bit integer.
  const std::set<TypePtr> scalar_valid_types = {kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckTypeValid("scalar", scalar_abs->BuildType(), scalar_valid_types, prim_name);
  // Fully-constant inputs are folded elsewhere; reject them here.
  if (seq_abs->BuildValue() != kAnyValue && scalar_abs->BuildValue() != kAnyValue) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', at least one of the inputs should be kAnyValue, but got "
                             << "sequence input: " << seq_abs->BuildValue()
                             << " and scalar input: " << scalar_abs->BuildValue();
  }
  // An already dynamic-length abstract is unchanged by replication.
  if (seq_abs->dynamic_len()) {
    return seq_abs;
  }
  // Constant-length sequence with a variable multiplier: the result length is
  // unknown, so clone the abstract and convert it to a dynamic-length one.
  auto result = seq_abs->Clone()->cast<abstract::AbstractSequencePtr>();
  result->CheckAndConvertToDynamicLenSequence();
  return result;
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SequenceMul, BaseOperator);
|
||||
class SequenceMulInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceMulInferInner(primitive, input_args)->BuildShape();
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceMulInferInner(prim, input_args)->BuildType();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceMulInferInner(primitive, input_args);
|
||||
}
|
||||
};
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceMul, prim::kPrimSequenceMul, SequenceMulInfer, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SEQUENCE_MUL_H_
|
||||
#define MINDSPORE_CORE_OPS_SEQUENCE_MUL_H_
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief Sequence mul integer operation.
|
||||
class MIND_API SequenceMul : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(SequenceMul);
|
||||
/// \brief Constructor.
|
||||
SequenceMul() : BaseOperator(prim::kSequenceMul) {}
|
||||
/// \brief Init function.
|
||||
void Init() const {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SEQUENCE_MUL_H_
|
|
@ -27,14 +27,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_OPERATOR_IMPL(Shape, BaseOperator);
|
||||
AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
AbstractBasePtr InferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
// Only called when the input of shape is dynamic shape/rank tensor.
|
||||
// infer shape
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("shape infer", int64_t(input_args.size()), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("shape infer", static_cast<int64_t>(input_args.size()), kEqual, 1, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto in_shape = shape_map[kShape];
|
||||
|
@ -57,25 +55,43 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
}
|
||||
return abs;
|
||||
}
|
||||
|
||||
ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("shape infer", int64_t(input_args.size()), kEqual, 1, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||
(void)CheckAndConvertUtils::CheckSubClass("shape type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto inshape = shape_map[kShape];
|
||||
if (std::any_of(inshape.begin(), inshape.end(), [](ShapeValueDType shape) {
|
||||
return shape == abstract::Shape::kShapeRankAny || shape == abstract::Shape::kShapeDimAny;
|
||||
})) {
|
||||
// If the input of shape is dynamic shape/rank tensor, value can not be directly built.
|
||||
// Run infer of shape.
|
||||
return nullptr;
|
||||
MIND_API_OPERATOR_IMPL(Shape, BaseOperator);
|
||||
class ShapeInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return InferInner(primitive, input_args)->BuildShape();
|
||||
}
|
||||
return MakeValue(inshape);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer, ShapeInferValue, true);
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return InferInner(prim, input_args)->BuildType();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return InferInner(primitive, input_args);
|
||||
}
|
||||
|
||||
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("shape infer", int64_t(input_args.size()), kEqual, 1, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||
(void)CheckAndConvertUtils::CheckSubClass("shape type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
if (shape_map.count(kShape) == 0) {
|
||||
MS_LOG(EXCEPTION) << "For primitive " << op_name << " the input convert shape failed.";
|
||||
}
|
||||
const auto &inshape = shape_map[kShape];
|
||||
if (IsDynamic(inshape)) {
|
||||
// If the input of shape is dynamic shape/rank tensor, value can not be directly built.
|
||||
// Run infer of shape.
|
||||
return nullptr;
|
||||
}
|
||||
return MakeValue(inshape);
|
||||
}
|
||||
};
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(Shape, prim::kPrimShape, ShapeInfer, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -346,6 +346,19 @@ size_t CheckAndConvertUtils::CheckAbstractTypeSame(const std::vector<AbstractBas
|
|||
return 0;
|
||||
}
|
||||
|
||||
bool CheckAndConvertUtils::CheckValueSame(const ValuePtr &value_1, const ValuePtr &value_2) {
|
||||
MS_EXCEPTION_IF_NULL(value_1);
|
||||
MS_EXCEPTION_IF_NULL(value_2);
|
||||
if (!value_1->IsSameTypeId(value_2->tid())) {
|
||||
return false;
|
||||
}
|
||||
if (value_1->isa<tensor::Tensor>()) {
|
||||
auto list_tensor_value = value_2->cast_ptr<tensor::Tensor>();
|
||||
return value_1->cast_ptr<tensor::Tensor>()->ValueEqual(*list_tensor_value);
|
||||
}
|
||||
return *value_1 == *value_2;
|
||||
}
|
||||
|
||||
void CheckAndConvertUtils::ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name,
|
||||
ValuePtr *const value) {
|
||||
if (value == nullptr || *value == nullptr) {
|
||||
|
|
|
@ -329,6 +329,7 @@ class MS_CORE_API CheckAndConvertUtils {
|
|||
static void GetFormatStringVal(const PrimitivePtr &prim, std::string *format);
|
||||
static size_t CheckAbstractShapeSame(const std::vector<AbstractBasePtr> &abs_list);
|
||||
static size_t CheckAbstractTypeSame(const std::vector<AbstractBasePtr> &abs_list);
|
||||
static bool CheckValueSame(const ValuePtr &value_1, const ValuePtr &value_2);
|
||||
|
||||
private:
|
||||
static TypePtr _CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
|
||||
|
|
|
@ -19,6 +19,8 @@ from mindspore.ops.composite.multitype_ops._constexpr_utils import check_equal
|
|||
from mindspore.ops.composite import base
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common import COOTensor
|
||||
from ...operations._sequence_ops import SequenceMul
|
||||
|
||||
|
||||
mul = base.MultitypeFuncGraph("mul", True)
|
||||
"""
|
||||
|
@ -101,6 +103,10 @@ def _list_mul_scalar(x, y):
|
|||
Outputs:
|
||||
List.
|
||||
"""
|
||||
if not isinstance(y, int):
|
||||
raise TypeError(f"can't multiply sequence by non-int of type '{type(y)}'.")
|
||||
if F.is_sequence_shape_unknown(x) or not F.isconstant(y):
|
||||
return SequenceMul()(x, y)
|
||||
res = []
|
||||
i = 0
|
||||
while i < y:
|
||||
|
@ -117,6 +123,10 @@ def _scalar_mul_list(x, y):
|
|||
Outputs:
|
||||
List.
|
||||
"""
|
||||
if not isinstance(x, int):
|
||||
raise TypeError(f"can't multiply sequence by non-int of type '{type(x)}'.")
|
||||
if not F.isconstant(x) or F.is_sequence_shape_unknown(y):
|
||||
return SequenceMul()(y, x)
|
||||
res = []
|
||||
i = 0
|
||||
while i < x:
|
||||
|
@ -133,6 +143,10 @@ def _tuple_mul_scalar(x, y):
|
|||
Outputs:
|
||||
Tuple.
|
||||
"""
|
||||
if not isinstance(y, int):
|
||||
raise TypeError(f"can't multiply sequence by non-int of type '{type(y)}'.")
|
||||
if F.is_sequence_shape_unknown(x) or not F.isconstant(y):
|
||||
return SequenceMul()(x, y)
|
||||
res = ()
|
||||
i = 0
|
||||
while i < y:
|
||||
|
@ -149,6 +163,10 @@ def _scalar_mul_tuple(x, y):
|
|||
Outputs:
|
||||
Tuple.
|
||||
"""
|
||||
if not isinstance(x, int):
|
||||
raise TypeError(f"can't multiply sequence by non-int of type '{type(x)}'.")
|
||||
if not F.isconstant(x) or F.is_sequence_shape_unknown(y):
|
||||
return SequenceMul()(y, x)
|
||||
res = ()
|
||||
i = 0
|
||||
while i < x:
|
||||
|
|
|
@ -99,3 +99,31 @@ class SequenceCount(Primitive):
|
|||
def __init__(self):
|
||||
"""Initialize ListAppend"""
|
||||
self.init_prim_io_names(inputs=['sequence', 'target'], outputs=['output_data'])
|
||||
|
||||
|
||||
class SequenceMul(Primitive):
    r"""
    Support sequence multiplication operation 'seq.mul(scalar)'.

    .. note::
        This is only for internal use.
        This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.

    Inputs:
        - **sequence** (Union[List, Tuple]) - The sequence to be replicated.
        - **scalar** (Any Object) - The times to replicate the sequence.

    Outputs:
        List or tuple with 'scalar' times multiplication.

    Raises:
        TypeError: The 'sequence' is not list or tuple.
        ValueError: Both 'sequence' and 'scalar' are constant.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize SequenceMul."""
        # Two inputs: the sequence to replicate and the integer multiplier.
        self.init_prim_io_names(inputs=['sequence', 'scalar'], outputs=['output_data'])
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test index operation for dynamic sequence in graph mode"""
|
||||
from mindspore.common import mutable
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor
|
||||
from mindspore import jit
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_dynamic_sequence_index_dynamic_length_sequence_const_index():
    """
    Feature: Sequence index operation.
    Description: If sequence is dynamic length, index() will return variable integer.
    Expectation: No exception.
    """

    @jit
    def index_func():
        seq = mutable([1, 2, 3, 4], True)
        idx = seq.index(0)
        return isinstance(idx, int), F.isconstant(idx)

    is_int, is_const = index_func()
    assert is_int
    assert not is_const
|
||||
|
||||
|
||||
def test_dynamic_sequence_index_variable_element_sequence_const_index():
    """
    Feature: Sequence index operation.
    Description: If sequence has variable element, index() will return variable integer.
    Expectation: No exception.
    """

    @jit
    def index_func(t):
        seq = [t, t + 1, t + 2]
        idx = seq.index(0)
        return isinstance(idx, int), F.isconstant(idx)

    is_int, is_const = index_func(Tensor([0]))
    assert is_int
    assert not is_const
||||
|
||||
|
||||
def test_dynamic_sequence_index_constant_sequence_dynamic_index():
    """
    Feature: Sequence index operation.
    Description: If target is dynamic, index() will return variable integer.
    Expectation: No exception.
    """

    @jit
    def index_func(target):
        seq = [Tensor([1]), Tensor([2]), Tensor([3])]
        idx = seq.index(target)
        return isinstance(idx, int), F.isconstant(idx)

    is_int, is_const = index_func(Tensor([0]))
    assert is_int
    assert not is_const
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test mul operation for dynamic sequence and variable integer in graph mode"""
|
||||
from mindspore.common import mutable
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor
|
||||
from mindspore import jit
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_dynamic_length_sequence_mul_constant_scalar():
    """
    Feature: Dynamic length sequence mul operation.
    Description: Multiplying a dynamic-length sequence by a constant scalar
        yields another dynamic-length sequence.
    Expectation: No exception.
    """

    @jit
    def func():
        seq = mutable([1, 2, 3, 4], True)
        product = seq * 5
        return F.is_sequence_value_unknown(product), F.is_sequence_shape_unknown(product)

    value_unknown, shape_unknown = func()
    assert value_unknown
    assert shape_unknown
|
||||
|
||||
|
||||
def test_constant_length_sequence_mul_constant_scalar():
    """
    Feature: Dynamic length sequence mul operation.
    Description: Multiplying a constant-length sequence by a constant scalar
        yields a constant-length sequence.
    Expectation: No exception.
    """

    @jit
    def func(x):
        seq = [x, x + 1, x + 2]
        product = seq * 5
        return F.is_sequence_value_unknown(product), F.is_sequence_shape_unknown(product)

    value_unknown, shape_unknown = func(Tensor([1]))
    # Element values are variable, but the output length stays fixed.
    assert value_unknown
    assert not shape_unknown
|
||||
|
||||
|
||||
def test_constant_length_sequence_mul_variable_scalar():
    """
    Feature: Dynamic length sequence mul operation.
    Description: Constant length sequence mul variable scalar should return variable length sequence.
    Expectation: No exception.
    """
    context.set_context(grad_for_scalar=True)
    try:
        @jit
        def foo(x):
            a = [1, 2, 3, 4]
            ret = a * x
            return F.is_sequence_value_unknown(ret), F.is_sequence_shape_unknown(ret)

        ret1, ret2 = foo(5)
        assert ret1
        assert ret2
    finally:
        # Always restore the global context, even when an assertion fails,
        # so grad_for_scalar=True does not leak into subsequent tests.
        context.set_context(grad_for_scalar=False)
|
||||
|
||||
|
||||
def test_variable_length_sequence_mul_variable_scalar():
    """
    Feature: Dynamic length sequence mul operation.
    Description: Variable length sequence mul variable scalar should return variable length sequence.
    Expectation: No exception.
    """
    context.set_context(grad_for_scalar=True)
    try:
        @jit
        def foo(x):
            a = mutable([1, 2, 3, 4], True)
            ret = a * x
            return F.is_sequence_value_unknown(ret), F.is_sequence_shape_unknown(ret)

        ret1, ret2 = foo(5)
        assert ret1
        assert ret2
    finally:
        # Always restore the global context, even when an assertion fails,
        # so grad_for_scalar=True does not leak into subsequent tests.
        context.set_context(grad_for_scalar=False)
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test list index operation"""
|
||||
|
||||
import pytest
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_list_index():
    """
    Feature: list index.
    Description: support list index operation.
    Expectation: No exception.
    """
    @jit
    def locate():
        data = [1, 2, 3, 4]
        return data.index(1)

    # 1 is the first element, so its index is 0.
    assert locate() == 0
|
||||
|
||||
|
||||
def test_list_index_2():
    """
    Feature: list index.
    Description: support list index operation on mixed-type elements.
    Expectation: No exception.
    """
    @jit
    def locate():
        data = ['1', '2', 3, 4]
        return data.index('2')

    assert locate() == 1
|
||||
|
||||
|
||||
def test_list_index_3():
    """
    Feature: list index.
    Description: support list index operation with Tensor elements.
    Expectation: No exception.
    """
    @jit
    def locate():
        data = [Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])]
        return data.index(Tensor([3]))

    assert locate() == 2
|
||||
|
||||
|
||||
def test_list_index_not_found():
    """
    Feature: list index.
    Description: index() on a value absent from the list raises.
    Expectation: Raise ValueError.
    """
    @jit
    def locate():
        data = [1, 2, 3, 4]
        return data.index(5)

    with pytest.raises(ValueError) as info:
        locate()
    assert "is not in" in str(info.value)
|
||||
|
||||
|
||||
def test_list_index_not_found_2():
    """
    Feature: list index.
    Description: index() on a Tensor absent from an int list raises.
    Expectation: Raise ValueError.
    """
    @jit
    def locate():
        data = [1, 2, 3, 4]
        return data.index(Tensor(1))

    with pytest.raises(ValueError) as info:
        locate()
    assert "is not in" in str(info.value)
|
|
@ -14,10 +14,13 @@
|
|||
# ============================================================================
|
||||
""" test list mul number """
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, context
|
||||
from mindspore import Tensor, context, jit
|
||||
from mindspore import nn
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -35,8 +38,6 @@ def test_list_mul_number():
|
|||
Description: test_list_mul_number
|
||||
Expectation: the results are as expected
|
||||
"""
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
expect_ret0 = [Tensor([1, 2, 3])] * 5
|
||||
expect_ret1 = (Tensor([1, 2, 3]),) * 0
|
||||
|
@ -45,3 +46,18 @@ def test_list_mul_number():
|
|||
for i in range(len(net()[0])):
|
||||
assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy())
|
||||
assert net()[1] == expect_ret1
|
||||
|
||||
|
||||
def test_list_mul_non_integer_number():
    """
    Feature: list multiply non-integer number.
    Description: list can only multiply integer number.
    Expectation: Raise TypeError.
    """
    @jit
    def foo():
        x = [1, 2, 3, 4]
        return x * 2.0

    with pytest.raises(TypeError) as error_info:
        foo()
    # Check the exception message itself via .value (consistent with the other
    # tests in this suite); str(error_info) yields the failure location in
    # modern pytest, not the message.
    assert "can't multiply sequence by non-int of type" in str(error_info.value)
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test tuple index operation"""
|
||||
|
||||
import pytest
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_tuple_index():
    """
    Feature: tuple index.
    Description: support tuple index operation.
    Expectation: No exception.
    """
    @jit
    def locate():
        data = (1, 2, 3, 4)
        return data.index(4)

    # 4 is the last of four elements, so its index is 3.
    assert locate() == 3
|
||||
|
||||
|
||||
def test_tuple_index_2():
    """
    Feature: tuple index.
    Description: support tuple index operation on mixed-type elements.
    Expectation: No exception.
    """
    @jit
    def locate():
        data = ('1', '2', 3, 4)
        return data.index(3)

    assert locate() == 2
|
||||
|
||||
|
||||
def test_tuple_index_3():
    """
    Feature: tuple index.
    Description: support tuple index operation with Tensor elements.
    Expectation: No exception.
    """
    @jit
    def locate():
        data = (Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4]))
        return data.index(Tensor([1]))

    assert locate() == 0
|
||||
|
||||
|
||||
def test_tuple_index_not_found():
    """
    Feature: tuple index.
    Description: index() on a value absent from the tuple raises.
    Expectation: Raise ValueError.
    """
    @jit
    def locate():
        data = (1, 2, 3, 4)
        return data.index(5)

    with pytest.raises(ValueError) as info:
        locate()
    assert "is not in" in str(info.value)
|
||||
|
||||
|
||||
def test_tuple_index_not_found_2():
    """
    Feature: tuple index.
    Description: index() on a Tensor absent from an int tuple raises.
    Expectation: Raise ValueError.
    """
    @jit
    def locate():
        data = (1, 2, 3, 4)
        return data.index(Tensor(2))

    with pytest.raises(ValueError) as info:
        locate()
    assert "is not in" in str(info.value)
|
|
@ -14,8 +14,9 @@
|
|||
# ============================================================================
|
||||
""" test tuple mul number """
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, context
|
||||
from mindspore import Tensor, context, jit
|
||||
from mindspore import nn
|
||||
|
||||
|
||||
|
@ -45,3 +46,18 @@ def test_tuple_mul_number():
|
|||
for i in range(len(net()[0])):
|
||||
assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy())
|
||||
assert net()[1] == expect_ret1
|
||||
|
||||
|
||||
def test_tuple_mul_non_integer_number():
    """
    Feature: tuple multiply non-integer number.
    Description: tuple can only multiply integer number.
    Expectation: Raise TypeError.
    """
    @jit
    def foo():
        x = (1, 2, 3, 4)
        return x * 2.0

    with pytest.raises(TypeError) as error_info:
        foo()
    # Check the exception message itself via .value (consistent with the other
    # tests in this suite); str(error_info) yields the failure location in
    # modern pytest, not the message.
    assert "can't multiply sequence by non-int of type" in str(error_info.value)
|
||||
|
|
Loading…
Reference in New Issue