!46418 Add sequence index and sequence mul operation and fix some infer function.

Merge pull request !46418 from LiangZhibo/list_ops
This commit is contained in:
i-robot 2022-12-08 07:53:51 +00:00 committed by Gitee
commit fe7acc0aa3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
21 changed files with 799 additions and 93 deletions

View File

@ -126,6 +126,7 @@ BuiltInTypeMap &GetMethodMap() {
{"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext
{"__bool__", std::string("tuple_bool")}, // C.tuple_bool
{"count", prim::kPrimSequenceCount}, // P.sequence_count
{"index", prim::kPrimSequenceIndex}, // P.sequenc_index
}},
{kObjectTypeList,
{
@ -143,6 +144,7 @@ BuiltInTypeMap &GetMethodMap() {
{"reverse", std::string("list_reverse")}, // C.list_reverse
{"extend", std::string("list_extend")}, // C.list_extend
{"count", prim::kPrimSequenceCount}, // P.sequence_count
{"index", prim::kPrimSequenceIndex}, // P.sequence_index
}},
{kObjectTypeDictionary,
{

View File

@ -381,13 +381,9 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type
if ((base_shape->isa<Shape>())) {
auto shape = base_shape->cast<ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
auto shape_vec = shape->shape();
// if the size of shape list is empty, return an scalar abstract
if (shape_vec.empty() && (!type->isa<TensorType>())) {
abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, type);
return abs_scalar;
}
return MakeAbstractTensor(shape, type);
} else if (base_shape->isa<NoShape>() && type->isa<Number>()) {
return std::make_shared<abstract::AbstractScalar>(kAnyValue, type);
} else if (base_shape->isa<TupleShape>() && type->isa<Tuple>()) {
auto shape_tuple = base_shape->cast_ptr<TupleShape>();
auto type_tuple = type->cast_ptr<Tuple>();
@ -416,7 +412,8 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type
// Return monad abstract if it is monad type.
return MakeMonadAbstract(type->cast<MonadTypePtr>());
} else {
MS_LOG(EXCEPTION) << "Evaluator return invalid shape " << base_shape->ToString() << "or type. " << type->ToString();
MS_LOG(EXCEPTION) << "Evaluator return invalid shape " << base_shape->ToString() << " or type. "
<< type->ToString();
}
}
} // namespace abstract

View File

@ -252,6 +252,8 @@ constexpr auto kSearchSorted = "SearchSorted";
constexpr auto kListAppend = "ListAppend";
constexpr auto kSequenceAdd = "SequenceAdd";
constexpr auto kSequenceCount = "SequenceCount";
constexpr auto kSequenceIndex = "SequenceIndex";
constexpr auto kSequenceMul = "SequenceMul";
// NN
constexpr auto kFractionalMaxPoolWithFixedKsize = "FractionalMaxPoolWithFixedKsize";
@ -1555,6 +1557,8 @@ GVAR_DEF(PrimitivePtr, kPrimSequenceLen, std::make_shared<Primitive>("sequence_l
GVAR_DEF(PrimitivePtr, kPrimListAppend, std::make_shared<Primitive>(kListAppend));
GVAR_DEF(PrimitivePtr, kPrimSequenceAdd, std::make_shared<Primitive>(kSequenceAdd));
GVAR_DEF(PrimitivePtr, kPrimSequenceCount, std::make_shared<Primitive>(kSequenceCount));
GVAR_DEF(PrimitivePtr, kPrimSequenceIndex, std::make_shared<Primitive>(kSequenceIndex));
GVAR_DEF(PrimitivePtr, kPrimSequenceMul, std::make_shared<Primitive>(kSequenceMul));
// Other miscellaneous
GVAR_DEF(PrimitivePtr, kPrimSampleDistortedBoundingBoxV2, std::make_shared<Primitive>(kSampleDistortedBoundingBoxV2));

View File

@ -49,8 +49,7 @@ AbstractBasePtr CheckAndGetElementType(const abstract::AbstractSequencePtr input
return elements[0];
}
AbstractBasePtr SequenceAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
AbstractBasePtr SequenceAddInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
constexpr size_t input_len = 2;
@ -109,6 +108,22 @@ AbstractBasePtr SequenceAddInfer(const abstract::AnalysisEnginePtr &, const Prim
return input_2->Clone();
}
MIND_API_OPERATOR_IMPL(SequenceAdd, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(SequenceAdd, prim::kPrimSequenceAdd, SequenceAddInfer, nullptr, true);
class SequenceAddInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
return SequenceAddInferInner(primitive, input_args)->BuildShape();
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
return SequenceAddInferInner(prim, input_args)->BuildType();
}
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
return SequenceAddInferInner(primitive, input_args);
}
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceAdd, prim::kPrimSequenceAdd, SequenceAddInfer, true);
} // namespace ops
} // namespace mindspore

View File

@ -26,61 +26,52 @@
namespace mindspore {
namespace ops {
AbstractBasePtr SequenceCountInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractScalar>(kAnyValue, kInt64);
}
bool ComparesTwoValues(const ValuePtr &value_1, const ValuePtr &value_2) {
MS_EXCEPTION_IF_NULL(value_1);
MS_EXCEPTION_IF_NULL(value_2);
if (!value_1->IsSameTypeId(value_2->tid())) {
return false;
}
if (value_1->isa<tensor::Tensor>()) {
auto list_tensor_value = value_2->cast_ptr<tensor::Tensor>();
MS_EXCEPTION_IF_NULL(list_tensor_value);
return value_1->cast_ptr<tensor::Tensor>()->ValueEqual(*list_tensor_value);
}
return *value_1 == *value_2;
}
ValuePtr SequenceCountInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
const size_t input_num = 2;
auto prim_name = prim->name();
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
constexpr size_t seq_index = 0;
constexpr size_t target_index = 1;
auto input_abs = input_args[seq_index];
auto target_abs = input_args[target_index];
if (!input_abs->isa<abstract::AbstractSequence>()) {
MS_EXCEPTION(TypeError) << "For primitive '" << prim_name << "', the first input must be a list or tuple, "
<< "but got: " << input_abs->ToString();
}
auto seq_abs = input_abs->cast<abstract::AbstractSequencePtr>();
if (seq_abs->dynamic_len()) {
return nullptr;
}
auto target_value = target_abs->BuildValue();
if (seq_abs->BuildValue() == kAnyValue || target_value == kAnyValue) {
return nullptr;
}
const auto &seq_elements = seq_abs->elements();
int64_t count = 0;
for (auto element : seq_elements) {
if (ComparesTwoValues(target_value, element->BuildValue())) {
++count;
}
}
return MakeValue(count);
}
MIND_API_OPERATOR_IMPL(SequenceCount, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(SequenceCount, prim::kPrimSequenceCount, SequenceCountInfer, SequenceCountInferValue,
true);
class SequenceCountInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
return abstract::kNoShape;
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
return kInt64;
}
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
constexpr size_t input_num = 2;
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
constexpr size_t seq_index = 0;
constexpr size_t target_index = 1;
auto input_abs = input_args[seq_index];
auto target_abs = input_args[target_index];
if (!input_abs->isa<abstract::AbstractSequence>()) {
MS_EXCEPTION(TypeError) << "For primitive '" << prim_name << "', the first input must be a list or tuple, "
<< "but got: " << input_abs->ToString();
}
auto seq_abs = input_abs->cast<abstract::AbstractSequencePtr>();
if (seq_abs->dynamic_len()) {
return nullptr;
}
auto target_value = target_abs->BuildValue();
if (seq_abs->BuildValue() == kAnyValue || target_value == kAnyValue) {
return nullptr;
}
const auto &seq_elements = seq_abs->elements();
int64_t count = 0;
for (auto element : seq_elements) {
if (CheckAndConvertUtils::CheckValueSame(target_value, element->BuildValue())) {
++count;
}
}
return MakeValue(count);
}
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceCount, prim::kPrimSequenceCount, SequenceCountInfer, true);
} // namespace ops
} // namespace mindspore

View File

@ -22,7 +22,7 @@
namespace mindspore {
namespace ops {
/// \brief Sequence addition operation
/// \brief Sequence count operation.
class MIND_API SequenceCount : public BaseOperator {
public:
MIND_API_BASE_MEMBER(SequenceCount);

View File

@ -0,0 +1,78 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/sequence_index.h"
#include <vector>
#include <memory>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "include/common/utils/utils.h"
#include "mindapi/src/helper.h"
#include "abstract/ops/primitive_infer_map.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(SequenceIndex, BaseOperator);
class SequenceIndexInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
return abstract::kNoShape;
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
return kInt64;
}
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
const size_t input_num = 2;
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
constexpr size_t seq_index = 0;
constexpr size_t target_index = 1;
auto input_abs = input_args[seq_index];
auto target_abs = input_args[target_index];
if (!input_abs->isa<abstract::AbstractSequence>()) {
MS_EXCEPTION(TypeError) << "For primitive '" << prim_name << "', the first input must be a list or tuple, "
<< "but got: " << input_abs->ToString();
}
auto seq_abs = input_abs->cast<abstract::AbstractSequencePtr>();
if (seq_abs->dynamic_len()) {
return nullptr;
}
auto target_value = target_abs->BuildValue();
if (seq_abs->BuildValue() == kAnyValue || target_value == kAnyValue) {
return nullptr;
}
const auto &seq_elements = seq_abs->elements();
for (size_t i = 0; i < seq_elements.size(); ++i) {
auto element = seq_elements[i];
if (CheckAndConvertUtils::CheckValueSame(target_value, element->BuildValue())) {
return MakeValue(static_cast<int64_t>(i));
}
}
MS_EXCEPTION(ValueError) << target_value->ToString() << " is not in " << seq_abs->ToString();
}
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceIndex, prim::kPrimSequenceIndex, SequenceIndexInfer, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SEQUENCE_INDEX_H_
#define MINDSPORE_CORE_OPS_SEQUENCE_INDEX_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief Sequence index operation.
class MIND_API SequenceIndex : public BaseOperator {
public:
MIND_API_BASE_MEMBER(SequenceIndex);
/// \brief Constructor.
SequenceIndex() : BaseOperator(prim::kSequenceIndex) {}
/// \brief Init function.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SEQUENCE_INDEX_H_

View File

@ -0,0 +1,79 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/sequence_mul.h"
#include <vector>
#include <memory>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "include/common/utils/utils.h"
#include "mindapi/src/helper.h"
#include "abstract/ops/primitive_infer_map.h"
namespace mindspore {
namespace ops {
AbstractBasePtr SequenceMulInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
constexpr size_t input_len = 2;
constexpr size_t seq_index = 0;
constexpr size_t scalar_index = 1;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_len, prim_name);
auto first_abs = input_args[seq_index];
if (!first_abs->isa<abstract::AbstractSequence>()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the first input should be tuple or list but got: " << first_abs->ToString();
}
auto seq_abs = first_abs->cast<abstract::AbstractSequencePtr>();
auto scalar_abs = input_args[scalar_index];
const std::set<TypePtr> scalar_valid_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTypeValid("scalar", scalar_abs->BuildType(), scalar_valid_types, prim_name);
if (seq_abs->BuildValue() != kAnyValue && scalar_abs->BuildValue() != kAnyValue) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', at least one of the inputs should be kAnyValue, but got "
<< "sequence input: " << seq_abs->BuildValue()
<< " and scalar input: " << scalar_abs->BuildValue();
}
if (seq_abs->dynamic_len()) {
return seq_abs;
}
auto ret = seq_abs->Clone()->cast<abstract::AbstractSequencePtr>();
ret->CheckAndConvertToDynamicLenSequence();
return ret;
}
MIND_API_OPERATOR_IMPL(SequenceMul, BaseOperator);
class SequenceMulInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
return SequenceMulInferInner(primitive, input_args)->BuildShape();
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
return SequenceMulInferInner(prim, input_args)->BuildType();
}
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
return SequenceMulInferInner(primitive, input_args);
}
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceMul, prim::kPrimSequenceMul, SequenceMulInfer, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SEQUENCE_MUL_H_
#define MINDSPORE_CORE_OPS_SEQUENCE_MUL_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief Sequence mul integer operation.
class MIND_API SequenceMul : public BaseOperator {
public:
MIND_API_BASE_MEMBER(SequenceMul);
/// \brief Constructor.
SequenceMul() : BaseOperator(prim::kSequenceMul) {}
/// \brief Init function.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SEQUENCE_MUL_H_

View File

@ -27,14 +27,12 @@
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(Shape, BaseOperator);
AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
AbstractBasePtr InferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
// Only called when the input of shape is dynamic shape/rank tensor.
// infer shape
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("shape infer", int64_t(input_args.size()), kEqual, 1, op_name);
(void)CheckAndConvertUtils::CheckInteger("shape infer", static_cast<int64_t>(input_args.size()), kEqual, 1, op_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto in_shape = shape_map[kShape];
@ -57,25 +55,43 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
}
return abs;
}
ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("shape infer", int64_t(input_args.size()), kEqual, 1, op_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
std::set<TypePtr> valid_params_types = {kTensorType};
(void)CheckAndConvertUtils::CheckSubClass("shape type", input_args[0]->BuildType(), valid_params_types, op_name);
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto inshape = shape_map[kShape];
if (std::any_of(inshape.begin(), inshape.end(), [](ShapeValueDType shape) {
return shape == abstract::Shape::kShapeRankAny || shape == abstract::Shape::kShapeDimAny;
})) {
// If the input of shape is dynamic shape/rank tensor, value can not be directly built.
// Run infer of shape.
return nullptr;
MIND_API_OPERATOR_IMPL(Shape, BaseOperator);
class ShapeInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
return InferInner(primitive, input_args)->BuildShape();
}
return MakeValue(inshape);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer, ShapeInferValue, true);
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
return InferInner(prim, input_args)->BuildType();
}
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
return InferInner(primitive, input_args);
}
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("shape infer", int64_t(input_args.size()), kEqual, 1, op_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
std::set<TypePtr> valid_params_types = {kTensorType};
(void)CheckAndConvertUtils::CheckSubClass("shape type", input_args[0]->BuildType(), valid_params_types, op_name);
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
if (shape_map.count(kShape) == 0) {
MS_LOG(EXCEPTION) << "For primitive " << op_name << " the input convert shape failed.";
}
const auto &inshape = shape_map[kShape];
if (IsDynamic(inshape)) {
// If the input of shape is dynamic shape/rank tensor, value can not be directly built.
// Run infer of shape.
return nullptr;
}
return MakeValue(inshape);
}
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(Shape, prim::kPrimShape, ShapeInfer, true);
} // namespace ops
} // namespace mindspore

View File

@ -346,6 +346,19 @@ size_t CheckAndConvertUtils::CheckAbstractTypeSame(const std::vector<AbstractBas
return 0;
}
bool CheckAndConvertUtils::CheckValueSame(const ValuePtr &value_1, const ValuePtr &value_2) {
MS_EXCEPTION_IF_NULL(value_1);
MS_EXCEPTION_IF_NULL(value_2);
if (!value_1->IsSameTypeId(value_2->tid())) {
return false;
}
if (value_1->isa<tensor::Tensor>()) {
auto list_tensor_value = value_2->cast_ptr<tensor::Tensor>();
return value_1->cast_ptr<tensor::Tensor>()->ValueEqual(*list_tensor_value);
}
return *value_1 == *value_2;
}
void CheckAndConvertUtils::ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name,
ValuePtr *const value) {
if (value == nullptr || *value == nullptr) {

View File

@ -329,6 +329,7 @@ class MS_CORE_API CheckAndConvertUtils {
static void GetFormatStringVal(const PrimitivePtr &prim, std::string *format);
static size_t CheckAbstractShapeSame(const std::vector<AbstractBasePtr> &abs_list);
static size_t CheckAbstractTypeSame(const std::vector<AbstractBasePtr> &abs_list);
static bool CheckValueSame(const ValuePtr &value_1, const ValuePtr &value_2);
private:
static TypePtr _CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,

View File

@ -19,6 +19,8 @@ from mindspore.ops.composite.multitype_ops._constexpr_utils import check_equal
from mindspore.ops.composite import base
from mindspore.ops import functional as F
from mindspore.common import COOTensor
from ...operations._sequence_ops import SequenceMul
mul = base.MultitypeFuncGraph("mul", True)
"""
@ -101,6 +103,10 @@ def _list_mul_scalar(x, y):
Outputs:
List.
"""
if not isinstance(y, int):
raise TypeError(f"can't multiply sequence by non-int of type '{type(y)}'.")
if F.is_sequence_shape_unknown(x) or not F.isconstant(y):
return SequenceMul()(x, y)
res = []
i = 0
while i < y:
@ -117,6 +123,10 @@ def _scalar_mul_list(x, y):
Outputs:
List.
"""
if not isinstance(x, int):
raise TypeError(f"can't multiply sequence by non-int of type '{type(x)}'.")
if not F.isconstant(x) or F.is_sequence_shape_unknown(y):
return SequenceMul()(y, x)
res = []
i = 0
while i < x:
@ -133,6 +143,10 @@ def _tuple_mul_scalar(x, y):
Outputs:
Tuple.
"""
if not isinstance(y, int):
raise TypeError(f"can't multiply sequence by non-int of type '{type(y)}'.")
if F.is_sequence_shape_unknown(x) or not F.isconstant(y):
return SequenceMul()(x, y)
res = ()
i = 0
while i < y:
@ -149,6 +163,10 @@ def _scalar_mul_tuple(x, y):
Outputs:
Tuple.
"""
if not isinstance(x, int):
raise TypeError(f"can't multiply sequence by non-int of type '{type(x)}'.")
if not F.isconstant(x) or F.is_sequence_shape_unknown(y):
return SequenceMul()(y, x)
res = ()
i = 0
while i < x:

View File

@ -99,3 +99,31 @@ class SequenceCount(Primitive):
def __init__(self):
"""Initialize ListAppend"""
self.init_prim_io_names(inputs=['sequence', 'target'], outputs=['output_data'])
class SequenceMul(Primitive):
r"""
Support sequence multiplication operation 'seq.mul(scalar)'.
.. note::
This it is only for internal used.
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
Inputs:
- **sequence** (Union[List, Tuple]) - The sequence to count elements.
- **scalar** (Any Object) - The times to replicate the sequence.
Outputs:
List or tuple with 'scalar' times multiplication.
Raises:
TypeError: The 'sequence' is not list or tuple.
ValueError: Both 'sequence' and 'scalar' is constant.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize ListAppend"""
self.init_prim_io_names(inputs=['sequence', 'scalar'], outputs=['output_data'])

View File

@ -0,0 +1,76 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test index operation for dynamic sequence in graph mode"""
from mindspore.common import mutable
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore import jit
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
def test_dynamic_sequence_index_dynamic_length_sequence_const_index():
"""
Feature: Sequence index operation.
Description: If sequence is dynamic length, index() will return variable integer.
Expectation: No exception.
"""
@jit
def foo():
a = mutable([1, 2, 3, 4], True)
ret = a.index(0)
return isinstance(ret, int), F.isconstant(ret)
ret1, ret2 = foo()
assert ret1
assert not ret2
def test_dynamic_sequence_index_variable_element_sequence_const_index():
"""
Feature: Sequence index operation.
Description: If sequence has variable element, index() will return variable integer.
Expectation: No exception.
"""
@jit
def foo(x):
a = [x, x+1, x+2]
ret = a.index(0)
return isinstance(ret, int), F.isconstant(ret)
ret1, ret2 = foo(Tensor([0]))
assert ret1
assert not ret2
def test_dynamic_sequence_index_constant_sequence_dynamic_index():
"""
Feature: Sequence index operation.
Description: If target is dynamic, index() will return variable integer.
Expectation: No exception.
"""
@jit
def foo(x):
a = [Tensor([1]), Tensor([2]), Tensor([3])]
ret = a.index(x)
return isinstance(ret, int), F.isconstant(ret)
ret1, ret2 = foo(Tensor([0]))
assert ret1
assert not ret2

View File

@ -0,0 +1,98 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test mul operation for dynamic sequence and variable integer in graph mode"""
from mindspore.common import mutable
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore import jit
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
def test_dynamic_length_sequence_mul_constant_scalar():
"""
Feature: Dynamic length sequence mul operation.
Description: Dynamic length sequence mul constant scalar should return dynamic length sequence.
Expectation: No exception.
"""
@jit
def foo():
a = mutable([1, 2, 3, 4], True)
ret = a * 5
return F.is_sequence_value_unknown(ret), F.is_sequence_shape_unknown(ret)
ret1, ret2 = foo()
assert ret1
assert ret2
def test_constant_length_sequence_mul_constant_scalar():
"""
Feature: Dynamic length sequence mul operation.
Description: Constant length sequence mul constant scalar should return constant length sequence.
Expectation: No exception.
"""
@jit
def foo(x):
a = [x, x + 1, x + 2]
ret = a * 5
return F.is_sequence_value_unknown(ret), F.is_sequence_shape_unknown(ret)
ret1, ret2 = foo(Tensor([1]))
assert ret1
assert not ret2
def test_constant_length_sequence_mul_variable_scalar():
"""
Feature: Dynamic length sequence mul operation.
Description: Constant length sequence mul variable scalar should return variable length sequence.
Expectation: No exception.
"""
context.set_context(grad_for_scalar=True)
@jit
def foo(x):
a = [1, 2, 3, 4]
ret = a * x
return F.is_sequence_value_unknown(ret), F.is_sequence_shape_unknown(ret)
ret1, ret2 = foo(5)
assert ret1
assert ret2
context.set_context(grad_for_scalar=False)
def test_variable_length_sequence_mul_variable_scalar():
"""
Feature: Dynamic length sequence mul operation.
Description: Constant length sequence mul variable scalar should return variable length sequence.
Expectation: No exception.
"""
context.set_context(grad_for_scalar=True)
@jit
def foo(x):
a = mutable([1, 2, 3, 4], True)
ret = a * x
return F.is_sequence_value_unknown(ret), F.is_sequence_shape_unknown(ret)
ret1, ret2 = foo(5)
assert ret1
assert ret2
context.set_context(grad_for_scalar=False)

View File

@ -0,0 +1,92 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test list index operation"""
import pytest
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
def test_list_index():
"""
Feature: list index.
Description: support list index operation.
Expectation: No exception.
"""
@jit
def foo():
x = [1, 2, 3, 4]
return x.index(1)
assert foo() == 0
def test_list_index_2():
"""
Feature: list index.
Description: support list index operation.
Expectation: No exception.
"""
@jit
def foo():
x = ['1', '2', 3, 4]
return x.index('2')
assert foo() == 1
def test_list_index_3():
"""
Feature: list index.
Description: support list index operation.
Expectation: No exception.
"""
@jit
def foo():
x = [Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])]
return x.index(Tensor([3]))
assert foo() == 2
def test_list_index_not_found():
"""
Feature: list index.
Description: support list index operation.
Expectation: Raise ValueError.
"""
@jit
def foo():
x = [1, 2, 3, 4]
return x.index(5)
with pytest.raises(ValueError) as info:
foo()
assert "is not in" in str(info.value)
def test_list_index_not_found_2():
"""
Feature: list index.
Description: support list index operation.
Expectation: Raise ValueError.
"""
@jit
def foo():
x = [1, 2, 3, 4]
return x.index(Tensor(1))
with pytest.raises(ValueError) as info:
foo()
assert "is not in" in str(info.value)

View File

@ -14,10 +14,13 @@
# ============================================================================
""" test list mul number """
import pytest
import numpy as np
from mindspore import Tensor, context
from mindspore import Tensor, context, jit
from mindspore import nn
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self):
@ -35,8 +38,6 @@ def test_list_mul_number():
Description: test_list_mul_number
Expectation: the results are as expected
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net()
expect_ret0 = [Tensor([1, 2, 3])] * 5
expect_ret1 = (Tensor([1, 2, 3]),) * 0
@ -45,3 +46,18 @@ def test_list_mul_number():
for i in range(len(net()[0])):
assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy())
assert net()[1] == expect_ret1
def test_list_mul_non_integer_number():
"""
Feature: list multiple non-integet number.
Description: list can only multiply integet number.
Expectation: Raise TypeError.
"""
@jit
def foo():
x = [1, 2, 3, 4]
return x * 2.0
with pytest.raises(TypeError) as error_info:
foo()
assert "can't multiply sequence by non-int of type" in str(error_info)

View File

@ -0,0 +1,92 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test tuple index operation"""
import pytest
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
def test_tuple_index():
"""
Feature: tuple index.
Description: support tuple index operation.
Expectation: No exception.
"""
@jit
def foo():
x = (1, 2, 3, 4)
return x.index(4)
assert foo() == 3
def test_tuple_index_2():
"""
Feature: tuple index.
Description: support tuple index operation.
Expectation: No exception.
"""
@jit
def foo():
x = ('1', '2', 3, 4)
return x.index(3)
assert foo() == 2
def test_tuple_index_3():
"""
Feature: tuple index.
Description: support tuple index operation.
Expectation: No exception.
"""
@jit
def foo():
x = (Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4]))
return x.index(Tensor([1]))
assert foo() == 0
def test_tuple_index_not_found():
"""
Feature: tuple index.
Description: support tuple index operation.
Expectation: Raise ValueError.
"""
@jit
def foo():
x = (1, 2, 3, 4)
return x.index(5)
with pytest.raises(ValueError) as info:
foo()
assert "is not in" in str(info.value)
def test_tuple_index_not_found_2():
"""
Feature: tuple index.
Description: support tuple index operation.
Expectation: Raise ValueError.
"""
@jit
def foo():
x = (1, 2, 3, 4)
return x.index(Tensor(2))
with pytest.raises(ValueError) as info:
foo()
assert "is not in" in str(info.value)

View File

@ -14,8 +14,9 @@
# ============================================================================
""" test tuple mul number """
import pytest
import numpy as np
from mindspore import Tensor, context
from mindspore import Tensor, context, jit
from mindspore import nn
@ -45,3 +46,18 @@ def test_tuple_mul_number():
for i in range(len(net()[0])):
assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy())
assert net()[1] == expect_ret1
def test_tuple_mul_non_integer_number():
"""
Feature: tuple multiple non-integer number.
Description: tuple can only multiply integer number.
Expectation: Raise TypeError.
"""
@jit
def foo():
x = (1, 2, 3, 4)
return x * 2.0
with pytest.raises(TypeError) as error_info:
foo()
assert "can't multiply sequence by non-int of type" in str(error_info)