forked from mindspore-Ecosystem/mindspore
Setitem for variable
This commit is contained in:
parent
0515f11dc4
commit
dea0d86f11
|
@ -20,6 +20,7 @@
|
|||
#include "abstract/utils.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
@ -169,6 +170,20 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
|
|||
return queue->elements()[index_unsigned_value];
|
||||
}
|
||||
|
||||
void CheckDynamicLengthSequenceSetItem(const std::string &op_name, const AbstractSequencePtr &queue,
|
||||
const AbstractBasePtr &target) {
|
||||
auto element_abs = queue->dynamic_len_element_abs();
|
||||
if (element_abs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Empty variable len sequence can not setitem.";
|
||||
}
|
||||
const auto precondition_log = "For " + op_name + ", when the queue is dynamic length";
|
||||
const auto standard_abs_description = "element within dynamic length sequence";
|
||||
const auto differ_abs_description = "target element";
|
||||
CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(std::vector<AbstractBasePtr>{element_abs, target},
|
||||
precondition_log, standard_abs_description,
|
||||
differ_abs_description);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list, a scalar whose value is an int64 number and an object of a subclass of AbstractBase.
|
||||
|
@ -177,39 +192,37 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
|
|||
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
auto index_type = index->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(index_type);
|
||||
if (index_type->type_id() != kInt64->type_id()) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got a "
|
||||
<< index_type->ToString() << " number.";
|
||||
}
|
||||
ValuePtr index_value = index->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(index_value);
|
||||
if (!index_value->isa<Int64Imm>()) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got "
|
||||
<< index_value->ToString();
|
||||
}
|
||||
constexpr int target_value_index = 2;
|
||||
auto target = args_spec_list[kIndex2];
|
||||
MS_EXCEPTION_IF_NULL(target);
|
||||
if (queue->dynamic_len()) {
|
||||
auto element_abs = queue->dynamic_len_element_abs();
|
||||
if (element_abs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Empty variable len sequence can not setitem.";
|
||||
} else {
|
||||
auto target = args_spec_list[target_value_index];
|
||||
auto element_abs_shape = element_abs->BuildShape();
|
||||
auto target_shape = target->BuildShape();
|
||||
if (*target_shape != *element_abs_shape) {
|
||||
MS_EXCEPTION(ValueError) << "In graph mode, when setitem for a dynamic length sequence, the new value should"
|
||||
<< " have the same type and shape as the element within the dynamic length sequence."
|
||||
<< "Now, the shape is not match, the element within the dynamic length sequence has"
|
||||
<< " shape: " << element_abs_shape->ToString()
|
||||
<< " and the new value has shape: " << target_shape->ToString();
|
||||
}
|
||||
auto element_abs_type = element_abs->BuildType();
|
||||
auto target_type = target->BuildType();
|
||||
if (*target_type != *element_abs_type) {
|
||||
MS_EXCEPTION(ValueError) << "In graph mode, when setitem for a dynamic length sequence, the new value should"
|
||||
<< " have the same type and shape as the element within the dynamic length sequence."
|
||||
<< "Now, the type is not match, the element within the dynamic length sequence has"
|
||||
<< " type: " << element_abs_type->ToString()
|
||||
<< " and the new value has type: " << target_type->ToString();
|
||||
}
|
||||
CheckDynamicLengthSequenceSetItem(op_name, queue, target);
|
||||
return queue->Clone();
|
||||
}
|
||||
if (index_value == kAnyValue) {
|
||||
// If the index is variable and the sequence is constant length, then all of the element within the sequence
|
||||
// should have the same type and shape with the target input. The element within the return sequence should
|
||||
// be all broadened.
|
||||
const auto &elements = queue->elements();
|
||||
if (elements.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "Empty sequence can not setitem.";
|
||||
}
|
||||
return queue;
|
||||
const auto precondition_log = "For " + op_name + ", when the index is variable and the queue is constant length";
|
||||
CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(elements, precondition_log);
|
||||
auto first_element = elements[kIndex0];
|
||||
const auto standard_abs_description = "element within constant length sequence";
|
||||
const auto differ_abs_description = "target element";
|
||||
CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(std::vector<AbstractBasePtr>{first_element, target},
|
||||
precondition_log, standard_abs_description,
|
||||
differ_abs_description);
|
||||
return CheckAndConvertUtils::BroadenAllSequenceElements(queue);
|
||||
}
|
||||
auto index_int64_value = GetValue<int64_t>(index_value);
|
||||
AbstractBasePtrList elements = queue->elements();
|
||||
|
@ -223,7 +236,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
|
|||
<< nelems << "," << (nelems - 1) << "].";
|
||||
}
|
||||
size_t index_unsigned_value = LongToSize(index_positive_value);
|
||||
elements[index_unsigned_value] = args_spec_list[target_value_index];
|
||||
elements[index_unsigned_value] = args_spec_list[kIndex2];
|
||||
MS_LOG(DEBUG) << "SetItem use flags, index: " << index_unsigned_value << ", for " << queue->ToString();
|
||||
return std::make_shared<T>(elements, queue->sequence_nodes());
|
||||
}
|
||||
|
|
|
@ -257,6 +257,7 @@ constexpr auto kSequenceCount = "SequenceCount";
|
|||
constexpr auto kSequenceIndex = "SequenceIndex";
|
||||
constexpr auto kSequenceMul = "SequenceMul";
|
||||
constexpr auto kSequenceSlice = "SequenceSlice";
|
||||
constexpr auto kSequenceSliceSetItem = "SequenceSliceSetItem";
|
||||
|
||||
// NN
|
||||
constexpr auto kFractionalMaxPoolWithFixedKsize = "FractionalMaxPoolWithFixedKsize";
|
||||
|
@ -1573,6 +1574,7 @@ GVAR_DEF(PrimitivePtr, kPrimSequenceCount, std::make_shared<Primitive>(kSequence
|
|||
GVAR_DEF(PrimitivePtr, kPrimSequenceIndex, std::make_shared<Primitive>(kSequenceIndex));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSequenceMul, std::make_shared<Primitive>(kSequenceMul));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSequenceSlice, std::make_shared<Primitive>(kSequenceSlice));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSequenceSliceSetItem, std::make_shared<Primitive>(kSequenceSliceSetItem));
|
||||
|
||||
// Other miscellaneous
|
||||
GVAR_DEF(PrimitivePtr, kPrimSampleDistortedBoundingBoxV2, std::make_shared<Primitive>(kSampleDistortedBoundingBoxV2));
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
// Take out the abstract of element.
|
||||
// The elements of input should have same shape and type. Dynamic length sequence already satisfies this requirement.
|
||||
// For constant length sequence, this requirement need to be checked in this function.
|
||||
|
@ -93,6 +94,7 @@ AbstractBasePtr SequenceAddInferInner(const PrimitivePtr &primitive, const std::
|
|||
}
|
||||
return input_2->Clone();
|
||||
}
|
||||
} // namespace
|
||||
MIND_API_OPERATOR_IMPL(SequenceAdd, BaseOperator);
|
||||
class SequenceAddInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
AbstractBasePtr SequenceMulInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
|
@ -56,6 +57,7 @@ AbstractBasePtr SequenceMulInferInner(const PrimitivePtr &primitive, const std::
|
|||
ret->CheckAndConvertToDynamicLenSequence();
|
||||
return ret;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SequenceMul, BaseOperator);
|
||||
class SequenceMulInfer : public abstract::OpInferBase {
|
||||
|
|
|
@ -25,7 +25,8 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
AbstractBasePtr AbstractInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
namespace {
|
||||
AbstractBasePtr SliceInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
constexpr size_t input_num = 4;
|
||||
|
@ -61,22 +62,23 @@ AbstractBasePtr AbstractInferInner(const PrimitivePtr &primitive, const std::vec
|
|||
ret->CheckAndConvertToDynamicLenSequence();
|
||||
return ret;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SequenceSlice, BaseOperator);
|
||||
class SequenceSliceInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return AbstractInferInner(primitive, input_args)->BuildShape();
|
||||
return SliceInferInner(primitive, input_args)->BuildShape();
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return AbstractInferInner(prim, input_args)->BuildType();
|
||||
return SliceInferInner(prim, input_args)->BuildType();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return AbstractInferInner(primitive, input_args);
|
||||
return SliceInferInner(primitive, input_args);
|
||||
}
|
||||
};
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceSlice, prim::kPrimSequenceSlice, SequenceSliceInfer, false);
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/sequence_slice_setitem.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
AbstractBasePtr SequenceSliceInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
constexpr size_t input_num = 5;
|
||||
constexpr size_t sequence_index = 0;
|
||||
constexpr size_t target_index = 1;
|
||||
constexpr size_t start_index = 2;
|
||||
constexpr size_t stop_index = 3;
|
||||
constexpr size_t step_index = 4;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
for (auto arg : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
}
|
||||
auto sequence_abs = dyn_cast<abstract::AbstractSequence>(input_args[sequence_index]);
|
||||
MS_EXCEPTION_IF_NULL(sequence_abs);
|
||||
auto target_abs = dyn_cast<abstract::AbstractSequence>(input_args[target_index]);
|
||||
if (target_abs == nullptr) {
|
||||
MS_EXCEPTION(TypeError) << "Can only assign an iterable.";
|
||||
}
|
||||
auto start_abs = input_args[start_index];
|
||||
auto stop_abs = input_args[stop_index];
|
||||
auto step_abs = input_args[step_index];
|
||||
if (!sequence_abs->dynamic_len() && !target_abs->dynamic_len() && start_abs->BuildValue() != kAnyValue &&
|
||||
stop_abs->BuildValue() != kAnyValue && step_abs->BuildValue() != kAnyValue) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the origin/target sequence should be dynamic length "
|
||||
<< "or one of start/stop/step should be variable.";
|
||||
}
|
||||
|
||||
if (!sequence_abs->dynamic_len()) {
|
||||
sequence_abs = sequence_abs->Clone()->cast<abstract::AbstractSequencePtr>();
|
||||
sequence_abs->CheckAndConvertToDynamicLenSequence();
|
||||
}
|
||||
if (!target_abs->dynamic_len()) {
|
||||
target_abs = target_abs->Clone()->cast<abstract::AbstractSequencePtr>();
|
||||
target_abs->CheckAndConvertToDynamicLenSequence();
|
||||
}
|
||||
auto seq_element = sequence_abs->dynamic_len_element_abs();
|
||||
auto target_element = target_abs->dynamic_len_element_abs();
|
||||
auto ret = (sequence_abs == input_args[sequence_index]) ? sequence_abs->Clone()->cast<abstract::AbstractSequencePtr>()
|
||||
: sequence_abs;
|
||||
if (target_element == nullptr) {
|
||||
return ret;
|
||||
}
|
||||
if (seq_element == nullptr) {
|
||||
ret->set_dynamic_len_element_abs(target_element);
|
||||
return ret;
|
||||
}
|
||||
const auto precondition_log = "For " + prim_name;
|
||||
const auto standard_abs_description = "element within origin sequence";
|
||||
const auto differ_abs_description = "element within target sequence";
|
||||
CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(std::vector<AbstractBasePtr>{seq_element, target_element},
|
||||
precondition_log, standard_abs_description,
|
||||
differ_abs_description);
|
||||
return ret;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SequenceSliceSetItem, BaseOperator);
|
||||
class SequenceSliceSetItemInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceSliceInferInner(primitive, input_args)->BuildShape();
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceSliceInferInner(prim, input_args)->BuildType();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return SequenceSliceInferInner(primitive, input_args);
|
||||
}
|
||||
};
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceSliceSetItem, prim::kPrimSequenceSliceSetItem, SequenceSliceSetItemInfer,
|
||||
false);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SEQUENCE_SLICE_SETITEM_H_
|
||||
#define MINDSPORE_CORE_OPS_SEQUENCE_SLICE_SETITEM_H_
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
/// \brief Sequence slice setitem operation.
|
||||
class MIND_API SequenceSliceSetItem : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(SequenceSliceSetItem);
|
||||
/// \brief Constructor.
|
||||
SequenceSliceSetItem() : BaseOperator(prim::kSequenceSliceSetItem) {}
|
||||
/// \brief Init function.
|
||||
void Init() const {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SEQUENCE_SLICE_SETITEM_H_
|
|
@ -369,6 +369,30 @@ void CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(const std::vector<Abstr
|
|||
}
|
||||
}
|
||||
|
||||
abstract::AbstractSequencePtr CheckAndConvertUtils::BroadenAllSequenceElements(
|
||||
const abstract::AbstractSequencePtr &sequence) {
|
||||
MS_EXCEPTION_IF_NULL(sequence);
|
||||
const auto &elements = sequence->elements();
|
||||
AbstractBasePtrList new_elements;
|
||||
for (auto element : elements) {
|
||||
AbstractBasePtr new_element = nullptr;
|
||||
if (element->isa<abstract::AbstractSequence>()) {
|
||||
new_element = BroadenAllSequenceElements(element->cast<abstract::AbstractSequencePtr>());
|
||||
} else {
|
||||
auto tmp_element = element->Clone();
|
||||
if (element->isa<abstract::AbstractScalar>()) {
|
||||
tmp_element->cast<abstract::AbstractScalarPtr>()->set_is_variable(true);
|
||||
}
|
||||
new_element = tmp_element->Broaden();
|
||||
}
|
||||
new_elements.push_back(new_element);
|
||||
}
|
||||
if (sequence->isa<abstract::AbstractList>()) {
|
||||
return std::make_shared<abstract::AbstractList>(new_elements, sequence->sequence_nodes());
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTuple>(new_elements, sequence->sequence_nodes());
|
||||
}
|
||||
|
||||
bool CheckAndConvertUtils::CheckValueSame(const ValuePtr &value_1, const ValuePtr &value_2) {
|
||||
MS_EXCEPTION_IF_NULL(value_1);
|
||||
MS_EXCEPTION_IF_NULL(value_2);
|
||||
|
|
|
@ -335,6 +335,7 @@ class MS_CORE_API CheckAndConvertUtils {
|
|||
const std::string &standard_abs_description = "",
|
||||
const std::string &differ_abs_description = "");
|
||||
static bool CheckValueSame(const ValuePtr &value_1, const ValuePtr &value_2);
|
||||
static abstract::AbstractSequencePtr BroadenAllSequenceElements(const abstract::AbstractSequencePtr &sequence);
|
||||
|
||||
private:
|
||||
static TypePtr _CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
|
||||
|
|
|
@ -22,12 +22,12 @@ from mindspore.ops.operations._inner_ops import SliceGetItem
|
|||
from mindspore.ops.operations import _map_tensor_ops
|
||||
from mindspore.ops.composite import base
|
||||
from mindspore.common import Tensor
|
||||
from ...operations._sequence_ops import SequenceSliceSetItem
|
||||
|
||||
DOC_URL = "https://mindspore.cn/docs/zh-CN/master/note/index_support.html"
|
||||
|
||||
setitem = base.MultitypeFuncGraph('setitem', doc_url=DOC_URL)
|
||||
|
||||
slice_get_item = SliceGetItem()
|
||||
sequence_slice_setitem = SequenceSliceSetItem()
|
||||
|
||||
|
||||
class _ListSliceSetItem(base.ListSliceSetItem_):
|
||||
|
@ -147,6 +147,11 @@ def _list_slice_setitem_with_tuple(data, slice_index, value):
|
|||
Outputs:
|
||||
list, type is the same as the element type of data.
|
||||
"""
|
||||
if F.is_sequence_shape_unknown(data) or F.is_sequence_shape_unknown(value) or not F.isconstant(slice_index):
|
||||
start = slice_get_item(slice_index, "start")
|
||||
stop = slice_get_item(slice_index, "stop")
|
||||
step = slice_get_item(slice_index, "step")
|
||||
return sequence_slice_setitem(data, value, start, stop, step)
|
||||
list_value = list(value)
|
||||
return _list_slice_set_item(data, slice_index, list_value)
|
||||
|
||||
|
@ -164,6 +169,11 @@ def _list_slice_setitem_with_list(data, slice_index, value):
|
|||
Outputs:
|
||||
list, type is the same as the element type of data.
|
||||
"""
|
||||
if F.is_sequence_shape_unknown(data) or F.is_sequence_shape_unknown(value) or not F.isconstant(slice_index):
|
||||
start = slice_get_item(slice_index, "start")
|
||||
stop = slice_get_item(slice_index, "stop")
|
||||
step = slice_get_item(slice_index, "step")
|
||||
return sequence_slice_setitem(data, value, start, stop, step)
|
||||
return _list_slice_set_item(data, slice_index, value)
|
||||
|
||||
|
||||
|
@ -181,6 +191,11 @@ def _list_slice_setitem_with_tensor(data, slice_index, value):
|
|||
list, type is the same as the element type of data.
|
||||
"""
|
||||
value_list = list(value)
|
||||
if F.is_sequence_shape_unknown(data) or F.is_sequence_shape_unknown(value_list) or not F.isconstant(slice_index):
|
||||
start = slice_get_item(slice_index, "start")
|
||||
stop = slice_get_item(slice_index, "stop")
|
||||
step = slice_get_item(slice_index, "step")
|
||||
return sequence_slice_setitem(data, value_list, start, stop, step)
|
||||
return _list_slice_set_item(data, slice_index, value_list)
|
||||
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ class SequenceSlice(Primitive):
|
|||
- **step** (int) - step of slice.
|
||||
|
||||
Outputs:
|
||||
Dynamic length sequence after addition.
|
||||
Dynamic length sequence after slice.
|
||||
|
||||
Raises:
|
||||
TypeError: The 'seq' input is neither list or tuple.
|
||||
|
@ -73,10 +73,41 @@ class SequenceSlice(Primitive):
|
|||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize SequenceCount"""
|
||||
"""Initialize SequenceSlice"""
|
||||
self.init_prim_io_names(inputs=['seq', 'start', 'stop', 'step'], outputs=['output_data'])
|
||||
|
||||
|
||||
class SequenceSliceSetItem(Primitive):
|
||||
r"""
|
||||
Sequence slice setitem operation.
|
||||
|
||||
.. note::
|
||||
This it is only for internal used. The sequence input should be dynamic length sequence or at least one of
|
||||
start/end/step should be variable.
|
||||
This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
|
||||
|
||||
Inputs:
|
||||
- **seq** (Union[List, Tuple]) - The sequence to perform slice setitem.
|
||||
- **target** (Union[List, Tuple]) - The target item to set.
|
||||
- **start** (int) - start index of slice.
|
||||
- **stop** (int) - stop index of slice.
|
||||
- **step** (int) - step of slice.
|
||||
|
||||
Outputs:
|
||||
Dynamic length sequence after slice setitem.
|
||||
|
||||
Raises:
|
||||
ValueError: The 'seq' and 'target' input is not dynamic length and none of start/end/step is variable.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize SequenceSliceSetItem"""
|
||||
self.init_prim_io_names(inputs=['seq', 'target', 'start', 'stop', 'step'], outputs=['output_data'])
|
||||
|
||||
|
||||
class SequenceAdd(Primitive):
|
||||
r"""
|
||||
Add elements of two sequence together.
|
||||
|
|
|
@ -0,0 +1,226 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test setitem operation for tuple/list with variable index or dynamic length sequence"""
|
||||
import pytest
|
||||
from mindspore.common import mutable
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import jit
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_setitem_dynamic_length_list_constant_index():
|
||||
"""
|
||||
Feature: Setitem operation including variable.
|
||||
Description: setitem for dynamic length list and constant index return dynamic length list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
a = mutable([1, 2, 3, 4], True)
|
||||
a[0] = 20
|
||||
return isinstance(a, list), F.is_sequence_shape_unknown(a)
|
||||
|
||||
ret1, ret2 = foo()
|
||||
assert ret1
|
||||
assert ret2
|
||||
|
||||
|
||||
def test_setitem_dynamic_length_list_constant_index_2():
|
||||
"""
|
||||
Feature: Setitem operation including variable.
|
||||
Description: setitem for dynamic length list and constant index return dynamic length list.
|
||||
Expectation: Raise TypeError.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
a = mutable([1, 2, 3, 4], True)
|
||||
a[0] = 1.0
|
||||
return a
|
||||
|
||||
with pytest.raises(TypeError) as ex:
|
||||
foo()
|
||||
assert "element within dynamic length sequence" in str(ex.value)
|
||||
|
||||
|
||||
def test_setitem_constant_length_list_variable_index():
|
||||
"""
|
||||
Feature: Setitem operation including variable.
|
||||
Description: setitem for constant length list and dynamic index return constant length list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
a = [1, 2]
|
||||
index = mutable(0)
|
||||
a[index] = 10
|
||||
return isinstance(a, list), F.isconstant(a[0]), F.isconstant(a[1])
|
||||
|
||||
ret1, ret2, ret3 = foo()
|
||||
assert ret1
|
||||
assert not ret2
|
||||
assert not ret3
|
||||
|
||||
|
||||
def test_setitem_constant_length_list_variable_index_2():
|
||||
"""
|
||||
Feature: Setitem operation including variable.
|
||||
Description: setitem for constant length list and dynamic index return constant length list.
|
||||
Expectation: Raise TypeError.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
a = [1, 2.0]
|
||||
index = mutable(0)
|
||||
a[index] = 10
|
||||
return a
|
||||
|
||||
with pytest.raises(TypeError) as ex:
|
||||
foo()
|
||||
assert "sequence[0] item" in str(ex.value)
|
||||
|
||||
|
||||
def test_setitem_constant_length_list_variable_index_3():
|
||||
"""
|
||||
Feature: Setitem operation including variable.
|
||||
Description: setitem for constant length list and dynamic index return constant length list.
|
||||
Expectation: Raise TypeError.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
a = [1, 2]
|
||||
index = mutable(0)
|
||||
a[index] = 1.0
|
||||
return a
|
||||
|
||||
with pytest.raises(TypeError) as ex:
|
||||
foo()
|
||||
assert "element within constant length sequence" in str(ex.value)
|
||||
|
||||
|
||||
def test_slice_setitem_dynamic_length_list():
|
||||
"""
|
||||
Feature: Slice setitem operation including variable.
|
||||
Description: setitem for dynamic length list and constant index return dynamic length list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
a = mutable([1, 2, 3, 4], True)
|
||||
a[0:2] = [2, 3, 4, 5, 6]
|
||||
return isinstance(a, list), F.is_sequence_shape_unknown(a)
|
||||
|
||||
ret1, ret2 = foo()
|
||||
assert ret1
|
||||
assert ret2
|
||||
|
||||
|
||||
def test_slice_setitem_dynamic_length_list_2():
|
||||
"""
|
||||
Feature: Slice setitem operation including variable.
|
||||
Description: setitem for dynamic length list and constant index return dynamic length list.
|
||||
Expectation: Raise ValueError.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
a = mutable([1, 2, 3, 4], True)
|
||||
a[0:2] = [2, 3, 4.0, 5]
|
||||
return a
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
foo()
|
||||
assert "The element type do not match, can not convert to dynamic length sequence." in str(ex.value)
|
||||
|
||||
|
||||
def test_slice_setitem_dynamic_length_target():
|
||||
"""
|
||||
Feature: Slice setitem operation including variable.
|
||||
Description: setitem for dynamic length list and constant index return dynamic length list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
a = [1, 2, 3, 4]
|
||||
a[0:2] = mutable([1, 2, 3, 4], True)
|
||||
return isinstance(a, list), F.is_sequence_shape_unknown(a)
|
||||
|
||||
ret1, ret2 = foo()
|
||||
assert ret1
|
||||
assert ret2
|
||||
|
||||
|
||||
def test_slice_setitem_dynamic_length_target_2():
|
||||
"""
|
||||
Feature: Slice setitem operation including variable.
|
||||
Description: setitem for dynamic length list and constant index return dynamic length list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
a = [1, 2, 3, 4.0]
|
||||
a[0:2] = mutable([1, 2, 3, 4], True)
|
||||
return a
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
foo()
|
||||
assert "The element type do not match, can not convert to dynamic length sequence." in str(ex.value)
|
||||
|
||||
|
||||
def test_slice_setitem_dynamic_slice():
|
||||
"""
|
||||
Feature: Slice setitem operation including variable.
|
||||
Description: setitem for dynamic length list and constant index return dynamic length list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
a = [1, 2, 3, 4]
|
||||
start = mutable(0)
|
||||
a[start:2] = [1, 2, 3, 4]
|
||||
return isinstance(a, list), F.is_sequence_shape_unknown(a)
|
||||
|
||||
ret1, ret2 = foo()
|
||||
assert ret1
|
||||
assert ret2
|
||||
|
||||
|
||||
def test_slice_setitem_dynamic_slice_2():
|
||||
"""
|
||||
Feature: Slice setitem operation including variable.
|
||||
Description: setitem for dynamic length list and constant index return dynamic length list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
a = [1.0, 2.0, 3.0, 4.0]
|
||||
start = mutable(0)
|
||||
a[start:2] = [1, 2, 3, 4]
|
||||
return a
|
||||
|
||||
with pytest.raises(TypeError) as ex:
|
||||
foo()
|
||||
assert "element within origin sequence" in str(ex.value)
|
|
@ -224,9 +224,9 @@ def test_dynamic_length_sequence_setitem_3():
|
|||
x = mutable([1, 2, 3, 4], True)
|
||||
x[3] = 10.0
|
||||
return x
|
||||
with pytest.raises(ValueError) as ex:
|
||||
with pytest.raises(TypeError) as ex:
|
||||
foo()
|
||||
assert "the type is not match" in str(ex.value)
|
||||
assert "when the queue is dynamic length" in str(ex.value)
|
||||
|
||||
|
||||
def test_dynamic_length_sequence_setitem_4():
|
||||
|
@ -241,9 +241,9 @@ def test_dynamic_length_sequence_setitem_4():
|
|||
x = mutable([(1, 2, 3), (2, 3, 4)], True)
|
||||
x[3] = (2, 3)
|
||||
return x
|
||||
with pytest.raises(ValueError) as ex:
|
||||
with pytest.raises(TypeError) as ex:
|
||||
foo()
|
||||
assert "the shape is not match" in str(ex.value)
|
||||
assert "when the queue is dynamic length" in str(ex.value)
|
||||
|
||||
|
||||
def test_dynamic_sequence_len():
|
||||
|
|
Loading…
Reference in New Issue