forked from mindspore-Ecosystem/mindspore
!9591 Assignment enables negative index
From: @liangzhibo Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
f6a22cb455
|
@ -330,7 +330,7 @@ AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) {
|
|||
ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int64_t depth) {
|
||||
const int64_t DEPTH_MAX = 5;
|
||||
if (depth > DEPTH_MAX) {
|
||||
MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels.";
|
||||
MS_LOG(EXCEPTION) << "List nesting is not allowed more than 6 levels.";
|
||||
}
|
||||
std::vector<ValuePtr> elements;
|
||||
for (const auto &it : value_list->value()) {
|
||||
|
|
|
@ -163,18 +163,14 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
|
|||
<< index_value->ToString();
|
||||
}
|
||||
int64_t idx_v = GetValue<int64_t>(index_value);
|
||||
if (idx_v < 0) {
|
||||
MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v
|
||||
<< ".";
|
||||
}
|
||||
|
||||
size_t uidx_v = LongToSize(idx_v);
|
||||
AbstractBasePtrList elements = queue->elements();
|
||||
std::size_t nelems = elements.size();
|
||||
if (uidx_v >= nelems) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1
|
||||
<< ".";
|
||||
int64_t idx_t = idx_v >= 0 ? idx_v : idx_v + SizeToLong(nelems);
|
||||
if (idx_t < 0 || idx_t >= SizeToLong(nelems)) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << idx_v << " to set out of range: [-" << nelems
|
||||
<< "," << nelems - 1 << "].";
|
||||
}
|
||||
size_t uidx_v = LongToSize(idx_t);
|
||||
elements[uidx_v] = args_spec_list[2];
|
||||
return std::make_shared<T>(elements);
|
||||
}
|
||||
|
|
|
@ -755,6 +755,9 @@ class PConstant : public PBase<PConstant<T> > {
|
|||
ShapeVector tensor_shape = tensor_abstract->shape()->shape();
|
||||
auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
|
||||
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
|
||||
if (new_tensor_ptr->DataSize() < tensor_ptr->DataSize()) {
|
||||
MS_EXCEPTION(ValueError) << "DataSize of new_tensor_ptr is smaller than DataSize of tensor_ptr";
|
||||
}
|
||||
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat) ||
|
||||
(tensor_type == TypeId::kNumberTypeFloat64)) {
|
||||
float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
|
||||
|
|
|
@ -37,6 +37,23 @@ def test_list_index_1D():
|
|||
assert out[2] == [3, 3, 3]
|
||||
|
||||
|
||||
def test_list_neg_index_1D():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [3, 3, 3]]
|
||||
list_[-3] = [100]
|
||||
return list_
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out[0] == [100]
|
||||
assert out[1] == [2, 2]
|
||||
assert out[2] == [3, 3, 3]
|
||||
|
||||
|
||||
def test_list_index_2D():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -55,6 +72,24 @@ def test_list_index_2D():
|
|||
assert out[2] == [3, 3, 3]
|
||||
|
||||
|
||||
def test_list_neg_index_2D():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [3, 3, 3]]
|
||||
list_[1][-2] = 200
|
||||
list_[1][-1] = 201
|
||||
return list_
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out[0] == [1]
|
||||
assert out[1] == [200, 201]
|
||||
assert out[2] == [3, 3, 3]
|
||||
|
||||
|
||||
def test_list_index_3D():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -74,6 +109,25 @@ def test_list_index_3D():
|
|||
assert out[2] == [[300, 301, 302]]
|
||||
|
||||
|
||||
def test_list_neg_index_3D():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [[3, 3, 3]]]
|
||||
list_[2][0][-3] = 300
|
||||
list_[2][0][-2] = 301
|
||||
list_[2][0][-1] = 302
|
||||
return list_
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out[0] == [1]
|
||||
assert out[1] == [2, 2]
|
||||
assert out[2] == [[300, 301, 302]]
|
||||
|
||||
|
||||
def test_list_index_1D_parameter():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue