!48358 add sequence test

Merge pull request !48358 from NaCN/test_seq
This commit is contained in:
i-robot 2023-02-07 04:33:09 +00:00 committed by Gitee
commit 1309df9bb7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
33 changed files with 546 additions and 37 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -76,25 +76,20 @@ bool SequenceAddOffsetCpuKernelMod::Launch(const std::vector<KernelTensorPtr> &i
std::vector<std::pair<KernelAttr, SequenceAddOffsetCpuKernelMod::SequenceAddOffsetFunc>>
SequenceAddOffsetCpuKernelMod::func_list_ = {{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
&SequenceAddOffsetCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat64)
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
&SequenceAddOffsetCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kObjectTypeList, kNumberTypeInt32)
.AddInputAttr(kObjectTypeList, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
&SequenceAddOffsetCpuKernelMod::LaunchKernel<int>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
&SequenceAddOffsetCpuKernelMod::LaunchKernel<int>},
{KernelAttr()
.AddInputAttr(kObjectTypeList, kNumberTypeInt64)
.AddInputAttr(kObjectTypeList, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
&SequenceAddOffsetCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -36,7 +36,7 @@ T Sign(T num) {
} else {
return static_cast<T>(-1.0);
}
} // namespace
}
} // namespace
bool MakeRangeCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@ -74,7 +74,7 @@ bool MakeRangeCpuKernelMod::LaunchKernel(const std::vector<KernelTensorPtr> &inp
}
}
return true;
} // namespace kernel
}
const std::vector<std::pair<KernelAttr, MakeRangeCpuKernelMod::KernelRunFunc>> &MakeRangeCpuKernelMod::GetFuncList()
const {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -18,6 +18,7 @@
from __future__ import absolute_import
from mindspore.ops.composite import base
from mindspore.ops import functional as F
from mindspore.ops.operations import _sequence_ops as seq
zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf', True)
"""
@ -59,6 +60,12 @@ def _zeros_like_tensor(x):
return F.zeros_like(x)
@zeros_like_leaf.register("Tuple")
def _zeros_like_tuple(x):
"""Returns a Tuple with the same shape and dtype as x and all elements are 0."""
return seq.SequenceZerosLike()(x)
@zeros_like_leaf.register("COOTensor")
def _zeros_like_coo_tensor(x):
"""Returns a tensor with the same shape and dtype as x and all elements are 0."""

View File

@ -0,0 +1,118 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore.nn as nn
from mindspore.ops.operations import _sequence_ops as seq
from mindspore import context
from mindspore.common import mutable
from mindspore.nn import Cell
from mindspore.ops.composite import GradOperation
context.set_context(mode=context.GRAPH_MODE)
class NetAdd(nn.Cell):
def __init__(self):
super().__init__()
self.seq_add = seq.SequenceAdd()
def construct(self, x, y):
return self.seq_add(x, y)
class NetAddOffset(nn.Cell):
def __init__(self):
super().__init__()
self.seq_add_offset = seq.SequenceAddOffset()
def construct(self, x, y):
return self.seq_add_offset(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seq_add():
"""
Feature: test sequence_add op
Description: two inputs are dynamic sequence
Expectation: the result match with tuple result
"""
x = mutable((1, 2, 3), True)
y = mutable((4, 5, 6), True)
expect = (1, 2, 3, 4, 5, 6)
net = NetAdd()
res = net(x, y)
assert res == expect
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seq_add_offset():
"""
Feature: test sequence_add_offset op
Description: inputs are dynamic sequence.
Expectation: the result match with tuple result
"""
x = mutable((1, 2, 3), True)
y = mutable((4, 5, 6), True)
expect = (0, 3)
net = NetAddOffset()
res = net(x, y)
assert res == expect
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seq_add_grad():
"""
Feature: test sequence add grad op
Description: inputs are dynamic sequence.
Expectation: the result match with tuple result
"""
class Net(Cell):
def construct(self, x, y):
return x + y
net_ms = Net()
input_x = mutable((1, 2, 3), True)
input_y = mutable((3, 4, 5, 6), True)
dout = mutable((1, 1, 1, 1, 1, 1, 1), True)
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
print("grad out = ", grad_func(input_x, input_y, dout))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seq_add_grad_other():
"""
Feature: test sequence add grad op
Description: inputs are dynamic sequence.
Expectation: the result match with tuple result
"""
class Net(Cell):
def construct(self, x, y):
return x + y
net_ms = Net()
input_x = mutable((1, 2, 3), True)
input_y = (3, 4, 5, 6)
dout = mutable((1, 1, 1, 1, 1, 1, 1), True)
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
print("grad out1 = ", grad_func(input_x, input_y, dout))

View File

@ -0,0 +1,61 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore.nn as nn
from mindspore import context
from mindspore.common import mutable
from mindspore.ops.composite import GradOperation
context.set_context(mode=context.GRAPH_MODE)
class NetGetItem(nn.Cell):
def construct(self, seq, idx):
return seq[idx]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seq_getitem():
"""
Feature: test sequence getitem op
Description: setitem operation on tuple type
Expectation: the behavior is matched to python style
"""
seq = mutable((1, 2, 3, 4, 5, 6), True)
idx = 3
expect = 4
net = NetGetItem()
res = net(seq, idx)
assert res == expect
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seq_getitem_grad():
"""
Feature: test sequence getitem grad op
Description: setitem operation on tuple type
Expectation: the behavior is matched to python style
"""
net_ms = NetGetItem()
seq = mutable((1, 2, 3, 4, 5, 6), True)
index = 1
dout = 1
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
print("grad out1 = ", grad_func(seq, index, dout))

View File

@ -0,0 +1,71 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore.ops.operations import _sequence_ops as seq
from mindspore import context
from mindspore.nn import Cell
from mindspore.common import mutable
from mindspore.ops.composite import GradOperation
from tuple_help import TupleFactory
context.set_context(mode=context.GRAPH_MODE)
class Net(Cell):
def __init__(self):
super().__init__()
self.func = seq.make_range()
def construct(self, x, y, z):
return self.func(x, y, z)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seqence_make_range():
"""
Feature: test sequence makerange op
Description: setitem operation on tuple type
Expectation: the behavior is matched to python style
"""
def func(x, y, z):
return tuple(range(x, y, z))
net_ms = Net()
input_x = 1
input_y = 1000
input_z = 31
fact = TupleFactory(net_ms, func, (input_x, input_y, input_z))
fact.forward_cmp()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seqence_make_range_grad():
"""
Feature: test sequence makerange grad
Description: setitem operation on tuple type
Expectation: the behavior is matched to python style
"""
net_ms = Net()
input_x = mutable(10)
input_y = mutable(100)
input_z = mutable(3)
dout = mutable((1, 1), True)
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
print("grad out1 = ", grad_func(input_x, input_y, input_z, dout))

View File

@ -0,0 +1,82 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore.nn as nn
from mindspore.ops import functional as F
from mindspore import context
from mindspore.common import mutable
from mindspore.ops.composite import GradOperation
context.set_context(mode=context.GRAPH_MODE)
class NetSetItem(nn.Cell):
def construct(self, seq, idx, value):
return F.tuple_setitem(seq, idx, value)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seq_setitem():
"""
Feature: test sequence_setitem op
Description: setitem operation on tuple type
Expectation: the behavior is matched to python style
"""
seq = mutable((1, 2, 3, 4, 5, 6), True)
value = 9
idx = 3
expect = (1, 2, 3, 9, 5, 6)
net = NetSetItem()
res = net(seq, idx, value)
assert res == expect
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seq_setitem_grad_0():
"""
Feature: test sequence setitem grad op
Description: setitem operation on tuple type
Expectation: the behavior is matched to python style
"""
net_ms = NetSetItem()
input_x = mutable((1, 2, 3), True)
idx = mutable(1)
value = mutable(8)
dout = mutable((1, 1, 1), True)
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
print("grad out0 = ", grad_func(input_x, idx, value, dout))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seq_setitem_grad_1():
"""
Feature: test sequence setitem grad op
Description: setitem operation on tuple type
Expectation: the behavior is matched to python style
"""
net_ms = NetSetItem()
input_x = mutable((1, 2, 3), True)
idx = 1
value = 8
dout = mutable((1, 1, 1), True)
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
print("grad out1 = ", grad_func(input_x, idx, value, dout))

View File

@ -0,0 +1,79 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore.nn as nn
from mindspore import context
from mindspore.ops.operations import _sequence_ops as S
from mindspore.common import mutable
from mindspore.ops.composite import GradOperation
from tuple_help import TupleFactory
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seq_slice():
"""
Feature: test sequence_slice op
Description: slice operation on tuple type
Expectation: the behavior is matched to python style
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.seq_slice = S.SequenceSlice()
def construct(self, seq, start, stop, step):
return self.seq_slice(seq, start, stop, step)
def func(seq, start, stop, step):
return seq[start:stop:step]
seq = (1, 2, 3, 4, 5, 6)
start = 1
stop = 3
step = 1
net_ms = Net()
fact = TupleFactory(net_ms, func, (seq, start, stop, step))
fact.forward_cmp()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seq_slice_grad():
"""
Feature: test sequence_slice grad
Description: slice operation on tuple type
Expectation: the behavior is matched to python style
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.seq_slice = S.SequenceSlice()
def construct(self, seq, start, stop, step):
return self.seq_slice(seq, start, stop, step)
seq = mutable((1, 2, 3, 4, 5, 6), True)
start = 1
stop = 3
step = 1
dout = mutable((1, 1), True)
net_ms = Net()
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
print("grad out1 = ", grad_func(seq, start, stop, step, dout))

View File

@ -0,0 +1,47 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore.ops.operations import _sequence_ops as seq
from mindspore import context
from mindspore.nn import Cell
from tuple_help import TupleFactory
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_seqence_zeros_like():
"""
Feature: test sequence zeroslike op
Description: setitem operation on tuple type
Expectation: the behavior is matched to python style
"""
class Net(Cell):
def __init__(self):
super().__init__()
self.func = seq.SequenceZerosLike()
def construct(self, x):
return self.func(x)
def func(x):
return tuple(np.zeros_like(x))
net_ms = Net()
input_x = (1, 2, 3, 4, 5)
fact = TupleFactory(net_ms, func, (input_x,))
fact.forward_cmp()

View File

@ -0,0 +1,49 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore
context.set_context(mode=context.GRAPH_MODE)
class ShapeNet(Cell):
def __init__(self):
super().__init__()
self.shape = mindspore.ops.shape
def construct(self, x):
return self.shape(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_shape():
"""
Feature: test sequence shape op
Description: setitem operation on tuple type
Expectation: the behavior is matched to python style
"""
x = Tensor(np.random.randn(1, 2, 4).astype(np.float32))
dynx = Tensor(shape=(1, 2, None), dtype=mindspore.float32)
expect_x = (1, 2, 4)
net = ShapeNet()
net.set_inputs(dynx)
res_x = net(x)
assert expect_x == res_x