forked from mindspore-Ecosystem/mindspore
commit
1309df9bb7
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -76,25 +76,20 @@ bool SequenceAddOffsetCpuKernelMod::Launch(const std::vector<KernelTensorPtr> &i
|
||||||
|
|
||||||
std::vector<std::pair<KernelAttr, SequenceAddOffsetCpuKernelMod::SequenceAddOffsetFunc>>
|
std::vector<std::pair<KernelAttr, SequenceAddOffsetCpuKernelMod::SequenceAddOffsetFunc>>
|
||||||
SequenceAddOffsetCpuKernelMod::func_list_ = {{KernelAttr()
|
SequenceAddOffsetCpuKernelMod::func_list_ = {{KernelAttr()
|
||||||
|
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
|
||||||
|
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||||
|
&SequenceAddOffsetCpuKernelMod::LaunchKernel<float>},
|
||||||
|
{KernelAttr()
|
||||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat64)
|
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat64)
|
||||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat64)
|
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat64)
|
||||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||||
&SequenceAddOffsetCpuKernelMod::LaunchKernel<double>},
|
&SequenceAddOffsetCpuKernelMod::LaunchKernel<double>},
|
||||||
{KernelAttr()
|
|
||||||
.AddInputAttr(kObjectTypeList, kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kObjectTypeList, kNumberTypeInt32)
|
|
||||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
|
||||||
&SequenceAddOffsetCpuKernelMod::LaunchKernel<int>},
|
|
||||||
{KernelAttr()
|
{KernelAttr()
|
||||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
||||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
||||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||||
&SequenceAddOffsetCpuKernelMod::LaunchKernel<int>},
|
&SequenceAddOffsetCpuKernelMod::LaunchKernel<int>},
|
||||||
{KernelAttr()
|
|
||||||
.AddInputAttr(kObjectTypeList, kNumberTypeInt64)
|
|
||||||
.AddInputAttr(kObjectTypeList, kNumberTypeInt64)
|
|
||||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
|
||||||
&SequenceAddOffsetCpuKernelMod::LaunchKernel<int64_t>},
|
|
||||||
{KernelAttr()
|
{KernelAttr()
|
||||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
||||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -36,7 +36,7 @@ T Sign(T num) {
|
||||||
} else {
|
} else {
|
||||||
return static_cast<T>(-1.0);
|
return static_cast<T>(-1.0);
|
||||||
}
|
}
|
||||||
} // namespace
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
bool MakeRangeCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
bool MakeRangeCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
@ -74,7 +74,7 @@ bool MakeRangeCpuKernelMod::LaunchKernel(const std::vector<KernelTensorPtr> &inp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
} // namespace kernel
|
}
|
||||||
|
|
||||||
const std::vector<std::pair<KernelAttr, MakeRangeCpuKernelMod::KernelRunFunc>> &MakeRangeCpuKernelMod::GetFuncList()
|
const std::vector<std::pair<KernelAttr, MakeRangeCpuKernelMod::KernelRunFunc>> &MakeRangeCpuKernelMod::GetFuncList()
|
||||||
const {
|
const {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from mindspore.ops.composite import base
|
from mindspore.ops.composite import base
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops.operations import _sequence_ops as seq
|
||||||
|
|
||||||
zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf', True)
|
zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf', True)
|
||||||
"""
|
"""
|
||||||
|
@ -59,6 +60,12 @@ def _zeros_like_tensor(x):
|
||||||
return F.zeros_like(x)
|
return F.zeros_like(x)
|
||||||
|
|
||||||
|
|
||||||
|
@zeros_like_leaf.register("Tuple")
|
||||||
|
def _zeros_like_tuple(x):
|
||||||
|
"""Returns a Tuple with the same shape and dtype as x and all elements are 0."""
|
||||||
|
return seq.SequenceZerosLike()(x)
|
||||||
|
|
||||||
|
|
||||||
@zeros_like_leaf.register("COOTensor")
|
@zeros_like_leaf.register("COOTensor")
|
||||||
def _zeros_like_coo_tensor(x):
|
def _zeros_like_coo_tensor(x):
|
||||||
"""Returns a tensor with the same shape and dtype as x and all elements are 0."""
|
"""Returns a tensor with the same shape and dtype as x and all elements are 0."""
|
||||||
|
|
|
@ -0,0 +1,118 @@
|
||||||
|
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops.operations import _sequence_ops as seq
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.common import mutable
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
from mindspore.ops.composite import GradOperation
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
class NetAdd(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.seq_add = seq.SequenceAdd()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.seq_add(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class NetAddOffset(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.seq_add_offset = seq.SequenceAddOffset()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.seq_add_offset(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_add():
|
||||||
|
"""
|
||||||
|
Feature: test sequence_add op
|
||||||
|
Description: two inputs are dynamic sequence
|
||||||
|
Expectation: the result match with tuple result
|
||||||
|
"""
|
||||||
|
x = mutable((1, 2, 3), True)
|
||||||
|
y = mutable((4, 5, 6), True)
|
||||||
|
expect = (1, 2, 3, 4, 5, 6)
|
||||||
|
net = NetAdd()
|
||||||
|
res = net(x, y)
|
||||||
|
assert res == expect
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_add_offset():
|
||||||
|
"""
|
||||||
|
Feature: test sequence_add_offset op
|
||||||
|
Description: inputs are dynamic sequence.
|
||||||
|
Expectation: the result match with tuple result
|
||||||
|
"""
|
||||||
|
x = mutable((1, 2, 3), True)
|
||||||
|
y = mutable((4, 5, 6), True)
|
||||||
|
expect = (0, 3)
|
||||||
|
net = NetAddOffset()
|
||||||
|
res = net(x, y)
|
||||||
|
assert res == expect
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_add_grad():
|
||||||
|
"""
|
||||||
|
Feature: test sequence add grad op
|
||||||
|
Description: inputs are dynamic sequence.
|
||||||
|
Expectation: the result match with tuple result
|
||||||
|
"""
|
||||||
|
class Net(Cell):
|
||||||
|
def construct(self, x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
net_ms = Net()
|
||||||
|
input_x = mutable((1, 2, 3), True)
|
||||||
|
input_y = mutable((3, 4, 5, 6), True)
|
||||||
|
dout = mutable((1, 1, 1, 1, 1, 1, 1), True)
|
||||||
|
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||||
|
print("grad out = ", grad_func(input_x, input_y, dout))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_add_grad_other():
|
||||||
|
"""
|
||||||
|
Feature: test sequence add grad op
|
||||||
|
Description: inputs are dynamic sequence.
|
||||||
|
Expectation: the result match with tuple result
|
||||||
|
"""
|
||||||
|
class Net(Cell):
|
||||||
|
def construct(self, x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
net_ms = Net()
|
||||||
|
input_x = mutable((1, 2, 3), True)
|
||||||
|
input_y = (3, 4, 5, 6)
|
||||||
|
dout = mutable((1, 1, 1, 1, 1, 1, 1), True)
|
||||||
|
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||||
|
print("grad out1 = ", grad_func(input_x, input_y, dout))
|
|
@ -0,0 +1,61 @@
|
||||||
|
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.common import mutable
|
||||||
|
from mindspore.ops.composite import GradOperation
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
class NetGetItem(nn.Cell):
|
||||||
|
def construct(self, seq, idx):
|
||||||
|
return seq[idx]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_getitem():
|
||||||
|
"""
|
||||||
|
Feature: test sequence getitem op
|
||||||
|
Description: setitem operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
seq = mutable((1, 2, 3, 4, 5, 6), True)
|
||||||
|
idx = 3
|
||||||
|
expect = 4
|
||||||
|
net = NetGetItem()
|
||||||
|
res = net(seq, idx)
|
||||||
|
assert res == expect
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_getitem_grad():
|
||||||
|
"""
|
||||||
|
Feature: test sequence getitem grad op
|
||||||
|
Description: setitem operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
net_ms = NetGetItem()
|
||||||
|
seq = mutable((1, 2, 3, 4, 5, 6), True)
|
||||||
|
index = 1
|
||||||
|
dout = 1
|
||||||
|
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||||
|
print("grad out1 = ", grad_func(seq, index, dout))
|
|
@ -0,0 +1,71 @@
|
||||||
|
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
|
from mindspore.ops.operations import _sequence_ops as seq
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
from mindspore.common import mutable
|
||||||
|
from mindspore.ops.composite import GradOperation
|
||||||
|
from tuple_help import TupleFactory
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
class Net(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.func = seq.make_range()
|
||||||
|
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
return self.func(x, y, z)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seqence_make_range():
|
||||||
|
"""
|
||||||
|
Feature: test sequence makerange op
|
||||||
|
Description: setitem operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
|
||||||
|
def func(x, y, z):
|
||||||
|
return tuple(range(x, y, z))
|
||||||
|
|
||||||
|
net_ms = Net()
|
||||||
|
input_x = 1
|
||||||
|
input_y = 1000
|
||||||
|
input_z = 31
|
||||||
|
fact = TupleFactory(net_ms, func, (input_x, input_y, input_z))
|
||||||
|
fact.forward_cmp()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seqence_make_range_grad():
|
||||||
|
"""
|
||||||
|
Feature: test sequence makerange grad
|
||||||
|
Description: setitem operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
net_ms = Net()
|
||||||
|
input_x = mutable(10)
|
||||||
|
input_y = mutable(100)
|
||||||
|
input_z = mutable(3)
|
||||||
|
dout = mutable((1, 1), True)
|
||||||
|
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||||
|
print("grad out1 = ", grad_func(input_x, input_y, input_z, dout))
|
|
@ -0,0 +1,82 @@
|
||||||
|
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.common import mutable
|
||||||
|
from mindspore.ops.composite import GradOperation
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
class NetSetItem(nn.Cell):
|
||||||
|
def construct(self, seq, idx, value):
|
||||||
|
return F.tuple_setitem(seq, idx, value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_setitem():
|
||||||
|
"""
|
||||||
|
Feature: test sequence_setitem op
|
||||||
|
Description: setitem operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
seq = mutable((1, 2, 3, 4, 5, 6), True)
|
||||||
|
value = 9
|
||||||
|
idx = 3
|
||||||
|
expect = (1, 2, 3, 9, 5, 6)
|
||||||
|
net = NetSetItem()
|
||||||
|
res = net(seq, idx, value)
|
||||||
|
assert res == expect
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_setitem_grad_0():
|
||||||
|
"""
|
||||||
|
Feature: test sequence setitem grad op
|
||||||
|
Description: setitem operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
net_ms = NetSetItem()
|
||||||
|
input_x = mutable((1, 2, 3), True)
|
||||||
|
idx = mutable(1)
|
||||||
|
value = mutable(8)
|
||||||
|
dout = mutable((1, 1, 1), True)
|
||||||
|
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||||
|
print("grad out0 = ", grad_func(input_x, idx, value, dout))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_setitem_grad_1():
|
||||||
|
"""
|
||||||
|
Feature: test sequence setitem grad op
|
||||||
|
Description: setitem operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
net_ms = NetSetItem()
|
||||||
|
input_x = mutable((1, 2, 3), True)
|
||||||
|
idx = 1
|
||||||
|
value = 8
|
||||||
|
dout = mutable((1, 1, 1), True)
|
||||||
|
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||||
|
print("grad out1 = ", grad_func(input_x, idx, value, dout))
|
|
@ -0,0 +1,79 @@
|
||||||
|
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.ops.operations import _sequence_ops as S
|
||||||
|
from mindspore.common import mutable
|
||||||
|
from mindspore.ops.composite import GradOperation
|
||||||
|
from tuple_help import TupleFactory
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_slice():
|
||||||
|
"""
|
||||||
|
Feature: test sequence_slice op
|
||||||
|
Description: slice operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.seq_slice = S.SequenceSlice()
|
||||||
|
|
||||||
|
def construct(self, seq, start, stop, step):
|
||||||
|
return self.seq_slice(seq, start, stop, step)
|
||||||
|
|
||||||
|
def func(seq, start, stop, step):
|
||||||
|
return seq[start:stop:step]
|
||||||
|
|
||||||
|
seq = (1, 2, 3, 4, 5, 6)
|
||||||
|
start = 1
|
||||||
|
stop = 3
|
||||||
|
step = 1
|
||||||
|
net_ms = Net()
|
||||||
|
fact = TupleFactory(net_ms, func, (seq, start, stop, step))
|
||||||
|
fact.forward_cmp()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seq_slice_grad():
|
||||||
|
"""
|
||||||
|
Feature: test sequence_slice grad
|
||||||
|
Description: slice operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.seq_slice = S.SequenceSlice()
|
||||||
|
|
||||||
|
def construct(self, seq, start, stop, step):
|
||||||
|
return self.seq_slice(seq, start, stop, step)
|
||||||
|
|
||||||
|
seq = mutable((1, 2, 3, 4, 5, 6), True)
|
||||||
|
start = 1
|
||||||
|
stop = 3
|
||||||
|
step = 1
|
||||||
|
dout = mutable((1, 1), True)
|
||||||
|
net_ms = Net()
|
||||||
|
grad_func = GradOperation(get_all=True, sens_param=True)(net_ms)
|
||||||
|
print("grad out1 = ", grad_func(seq, start, stop, step, dout))
|
|
@ -0,0 +1,47 @@
|
||||||
|
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
from mindspore.ops.operations import _sequence_ops as seq
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
from tuple_help import TupleFactory
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_seqence_zeros_like():
|
||||||
|
"""
|
||||||
|
Feature: test sequence zeroslike op
|
||||||
|
Description: setitem operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
class Net(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.func = seq.SequenceZerosLike()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.func(x)
|
||||||
|
|
||||||
|
def func(x):
|
||||||
|
return tuple(np.zeros_like(x))
|
||||||
|
net_ms = Net()
|
||||||
|
input_x = (1, 2, 3, 4, 5)
|
||||||
|
fact = TupleFactory(net_ms, func, (input_x,))
|
||||||
|
fact.forward_cmp()
|
|
@ -0,0 +1,49 @@
|
||||||
|
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
import mindspore
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
class ShapeNet(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.shape = mindspore.ops.shape
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.shape(x)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_shape():
|
||||||
|
"""
|
||||||
|
Feature: test sequence shape op
|
||||||
|
Description: setitem operation on tuple type
|
||||||
|
Expectation: the behavior is matched to python style
|
||||||
|
"""
|
||||||
|
x = Tensor(np.random.randn(1, 2, 4).astype(np.float32))
|
||||||
|
dynx = Tensor(shape=(1, 2, None), dtype=mindspore.float32)
|
||||||
|
expect_x = (1, 2, 4)
|
||||||
|
net = ShapeNet()
|
||||||
|
net.set_inputs(dynx)
|
||||||
|
res_x = net(x)
|
||||||
|
assert expect_x == res_x
|
Loading…
Reference in New Issue