forked from mindspore-Ecosystem/mindspore
commit
1309df9bb7
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -76,25 +76,20 @@ bool SequenceAddOffsetCpuKernelMod::Launch(const std::vector<KernelTensorPtr> &i
|
|||
|
||||
std::vector<std::pair<KernelAttr, SequenceAddOffsetCpuKernelMod::SequenceAddOffsetFunc>>
|
||||
SequenceAddOffsetCpuKernelMod::func_list_ = {{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||
&SequenceAddOffsetCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat64)
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||
&SequenceAddOffsetCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeList, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeList, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||
&SequenceAddOffsetCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||
&SequenceAddOffsetCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeList, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeList, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||
&SequenceAddOffsetCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -36,7 +36,7 @@ T Sign(T num) {
|
|||
} else {
|
||||
return static_cast<T>(-1.0);
|
||||
}
|
||||
} // namespace
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool MakeRangeCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
|
@ -74,7 +74,7 @@ bool MakeRangeCpuKernelMod::LaunchKernel(const std::vector<KernelTensorPtr> &inp
|
|||
}
|
||||
}
|
||||
return true;
|
||||
} // namespace kernel
|
||||
}
|
||||
|
||||
const std::vector<std::pair<KernelAttr, MakeRangeCpuKernelMod::KernelRunFunc>> &MakeRangeCpuKernelMod::GetFuncList()
|
||||
const {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
from __future__ import absolute_import
|
||||
from mindspore.ops.composite import base
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.operations import _sequence_ops as seq
|
||||
|
||||
zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf', True)
|
||||
"""
|
||||
|
@ -59,6 +60,12 @@ def _zeros_like_tensor(x):
|
|||
return F.zeros_like(x)
|
||||
|
||||
|
||||
@zeros_like_leaf.register("Tuple")
|
||||
def _zeros_like_tuple(x):
|
||||
"""Returns a Tuple with the same shape and dtype as x and all elements are 0."""
|
||||
return seq.SequenceZerosLike()(x)
|
||||
|
||||
|
||||
@zeros_like_leaf.register("COOTensor")
|
||||
def _zeros_like_coo_tensor(x):
|
||||
"""Returns a tensor with the same shape and dtype as x and all elements are 0."""
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops.operations import _sequence_ops as seq
|
||||
from mindspore import context
|
||||
from mindspore.common import mutable
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.ops.composite import GradOperation
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class NetAdd(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.seq_add = seq.SequenceAdd()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.seq_add(x, y)
|
||||
|
||||
|
||||
class NetAddOffset(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.seq_add_offset = seq.SequenceAddOffset()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.seq_add_offset(x, y)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_add():
|
||||
"""
|
||||
Feature: test sequence_add op
|
||||
Description: two inputs are dynamic sequence
|
||||
Expectation: the result match with tuple result
|
||||
"""
|
||||
x = mutable((1, 2, 3), True)
|
||||
y = mutable((4, 5, 6), True)
|
||||
expect = (1, 2, 3, 4, 5, 6)
|
||||
net = NetAdd()
|
||||
res = net(x, y)
|
||||
assert res == expect
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_add_offset():
|
||||
"""
|
||||
Feature: test sequence_add_offset op
|
||||
Description: inputs are dynamic sequence.
|
||||
Expectation: the result match with tuple result
|
||||
"""
|
||||
x = mutable((1, 2, 3), True)
|
||||
y = mutable((4, 5, 6), True)
|
||||
expect = (0, 3)
|
||||
net = NetAddOffset()
|
||||
res = net(x, y)
|
||||
assert res == expect
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_add_grad():
|
||||
"""
|
||||
Feature: test sequence add grad op
|
||||
Description: inputs are dynamic sequence.
|
||||
Expectation: the result match with tuple result
|
||||
"""
|
||||
class Net(Cell):
|
||||
def construct(self, x, y):
|
||||
return x + y
|
||||
|
||||
net_ms = Net()
|
||||
input_x = mutable((1, 2, 3), True)
|
||||
input_y = mutable((3, 4, 5, 6), True)
|
||||
dout = mutable((1, 1, 1, 1, 1, 1, 1), True)
|
||||
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||
print("grad out = ", grad_func(input_x, input_y, dout))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_add_grad_other():
|
||||
"""
|
||||
Feature: test sequence add grad op
|
||||
Description: inputs are dynamic sequence.
|
||||
Expectation: the result match with tuple result
|
||||
"""
|
||||
class Net(Cell):
|
||||
def construct(self, x, y):
|
||||
return x + y
|
||||
|
||||
net_ms = Net()
|
||||
input_x = mutable((1, 2, 3), True)
|
||||
input_y = (3, 4, 5, 6)
|
||||
dout = mutable((1, 1, 1, 1, 1, 1, 1), True)
|
||||
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||
print("grad out1 = ", grad_func(input_x, input_y, dout))
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common import mutable
|
||||
from mindspore.ops.composite import GradOperation
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class NetGetItem(nn.Cell):
|
||||
def construct(self, seq, idx):
|
||||
return seq[idx]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_getitem():
|
||||
"""
|
||||
Feature: test sequence getitem op
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
seq = mutable((1, 2, 3, 4, 5, 6), True)
|
||||
idx = 3
|
||||
expect = 4
|
||||
net = NetGetItem()
|
||||
res = net(seq, idx)
|
||||
assert res == expect
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_getitem_grad():
|
||||
"""
|
||||
Feature: test sequence getitem grad op
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
net_ms = NetGetItem()
|
||||
seq = mutable((1, 2, 3, 4, 5, 6), True)
|
||||
index = 1
|
||||
dout = 1
|
||||
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||
print("grad out1 = ", grad_func(seq, index, dout))
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
from mindspore.ops.operations import _sequence_ops as seq
|
||||
from mindspore import context
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.common import mutable
|
||||
from mindspore.ops.composite import GradOperation
|
||||
from tuple_help import TupleFactory
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.func = seq.make_range()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
return self.func(x, y, z)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seqence_make_range():
|
||||
"""
|
||||
Feature: test sequence makerange op
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
|
||||
def func(x, y, z):
|
||||
return tuple(range(x, y, z))
|
||||
|
||||
net_ms = Net()
|
||||
input_x = 1
|
||||
input_y = 1000
|
||||
input_z = 31
|
||||
fact = TupleFactory(net_ms, func, (input_x, input_y, input_z))
|
||||
fact.forward_cmp()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seqence_make_range_grad():
|
||||
"""
|
||||
Feature: test sequence makerange grad
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
net_ms = Net()
|
||||
input_x = mutable(10)
|
||||
input_y = mutable(100)
|
||||
input_z = mutable(3)
|
||||
dout = mutable((1, 1), True)
|
||||
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||
print("grad out1 = ", grad_func(input_x, input_y, input_z, dout))
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import context
|
||||
from mindspore.common import mutable
|
||||
from mindspore.ops.composite import GradOperation
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class NetSetItem(nn.Cell):
|
||||
def construct(self, seq, idx, value):
|
||||
return F.tuple_setitem(seq, idx, value)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_setitem():
|
||||
"""
|
||||
Feature: test sequence_setitem op
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
seq = mutable((1, 2, 3, 4, 5, 6), True)
|
||||
value = 9
|
||||
idx = 3
|
||||
expect = (1, 2, 3, 9, 5, 6)
|
||||
net = NetSetItem()
|
||||
res = net(seq, idx, value)
|
||||
assert res == expect
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_setitem_grad_0():
|
||||
"""
|
||||
Feature: test sequence setitem grad op
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
net_ms = NetSetItem()
|
||||
input_x = mutable((1, 2, 3), True)
|
||||
idx = mutable(1)
|
||||
value = mutable(8)
|
||||
dout = mutable((1, 1, 1), True)
|
||||
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||
print("grad out0 = ", grad_func(input_x, idx, value, dout))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_setitem_grad_1():
|
||||
"""
|
||||
Feature: test sequence setitem grad op
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
net_ms = NetSetItem()
|
||||
input_x = mutable((1, 2, 3), True)
|
||||
idx = 1
|
||||
value = 8
|
||||
dout = mutable((1, 1, 1), True)
|
||||
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||
print("grad out1 = ", grad_func(input_x, idx, value, dout))
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.ops.operations import _sequence_ops as S
|
||||
from mindspore.common import mutable
|
||||
from mindspore.ops.composite import GradOperation
|
||||
from tuple_help import TupleFactory
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_slice():
|
||||
"""
|
||||
Feature: test sequence_slice op
|
||||
Description: slice operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.seq_slice = S.SequenceSlice()
|
||||
|
||||
def construct(self, seq, start, stop, step):
|
||||
return self.seq_slice(seq, start, stop, step)
|
||||
|
||||
def func(seq, start, stop, step):
|
||||
return seq[start:stop:step]
|
||||
|
||||
seq = (1, 2, 3, 4, 5, 6)
|
||||
start = 1
|
||||
stop = 3
|
||||
step = 1
|
||||
net_ms = Net()
|
||||
fact = TupleFactory(net_ms, func, (seq, start, stop, step))
|
||||
fact.forward_cmp()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seq_slice_grad():
|
||||
"""
|
||||
Feature: test sequence_slice grad
|
||||
Description: slice operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.seq_slice = S.SequenceSlice()
|
||||
|
||||
def construct(self, seq, start, stop, step):
|
||||
return self.seq_slice(seq, start, stop, step)
|
||||
|
||||
seq = mutable((1, 2, 3, 4, 5, 6), True)
|
||||
start = 1
|
||||
stop = 3
|
||||
step = 1
|
||||
dout = mutable((1, 1), True)
|
||||
net_ms = Net()
|
||||
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||
print("grad out1 = ", grad_func(seq, start, stop, step, dout))
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore.ops.operations import _sequence_ops as seq
|
||||
from mindspore import context
|
||||
from mindspore.nn import Cell
|
||||
from tuple_help import TupleFactory
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_seqence_zeros_like():
|
||||
"""
|
||||
Feature: test sequence zeroslike op
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
class Net(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.func = seq.SequenceZerosLike()
|
||||
|
||||
def construct(self, x):
|
||||
return self.func(x)
|
||||
|
||||
def func(x):
|
||||
return tuple(np.zeros_like(x))
|
||||
net_ms = Net()
|
||||
input_x = (1, 2, 3, 4, 5)
|
||||
fact = TupleFactory(net_ms, func, (input_x,))
|
||||
fact.forward_cmp()
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
import mindspore
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class ShapeNet(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.shape = mindspore.ops.shape
|
||||
|
||||
def construct(self, x):
|
||||
return self.shape(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_shape():
|
||||
"""
|
||||
Feature: test sequence shape op
|
||||
Description: setitem operation on tuple type
|
||||
Expectation: the behavior is matched to python style
|
||||
"""
|
||||
x = Tensor(np.random.randn(1, 2, 4).astype(np.float32))
|
||||
dynx = Tensor(shape=(1, 2, None), dtype=mindspore.float32)
|
||||
expect_x = (1, 2, 4)
|
||||
net = ShapeNet()
|
||||
net.set_inputs(dynx)
|
||||
res_x = net(x)
|
||||
assert expect_x == res_x
|
Loading…
Reference in New Issue