add testcase for batchreadwrite and tensorsqueue

This commit is contained in:
VectorSL 2022-04-01 11:05:05 +08:00
parent 74bdc00fed
commit 747f064f6b
3 changed files with 240 additions and 26 deletions

View File

@ -58,19 +58,7 @@ class TensorsQueuePutCpuKernelMod : public TensorsQueueCPUBaseMod {
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64)};
static std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true)};
return support_list;
}
@ -90,19 +78,7 @@ class TensorsQueueGetCpuKernelMod : public TensorsQueueCPUBaseMod {
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool)};
static std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true)};
return support_list;
}

View File

@ -0,0 +1,129 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.nn.reinforcement._batch_read_write import BatchRead, BatchWrite
class DstNet(nn.Cell):
'''Dst net'''
def __init__(self):
super(DstNet, self).__init__()
self.a = Parameter(Tensor(0.1, mstype.float32), name="a")
self.dense = nn.Dense(in_channels=16, out_channels=1)
def construct(self, data):
d = self.dense(data)
out = d + self.a
return out
class SourceNet(nn.Cell):
'''Source net'''
def __init__(self):
super(SourceNet, self).__init__()
self.a = Parameter(Tensor(0.5, mstype.float32), name="a")
self.dense = nn.Dense(in_channels=16, out_channels=1, weight_init=0)
def construct(self, data):
d = self.dense(data)
out = d + self.a
return out
class Write(nn.Cell):
'''Write cell'''
def __init__(self, dst, src):
super(Write, self).__init__()
self.write = BatchWrite()
self.dst = ParameterTuple(dst.trainable_params())
self.src = ParameterTuple(src.trainable_params())
def construct(self):
success = self.write(self.dst, self.src)
return success
class Read(nn.Cell):
'''Read cell'''
def __init__(self, dst, src):
super(Read, self).__init__()
self.read = BatchRead()
self.dst = ParameterTuple(dst.trainable_params())
self.src = ParameterTuple(src.trainable_params())
def construct(self):
success = self.read(self.dst, self.src)
return success
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_read_write_model_gpu():
"""
Feature: BatchPushPull gpu TEST.
Description: Test the batch assign.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
dst_net = DstNet()
source_net = SourceNet()
dst_param = dst_net.trainable_params()
source_param = source_net.trainable_params()
nets = nn.CellList()
nets.append(dst_net)
nets.append(source_net)
# Test read source net's params to replace dst_net's params.
_ = Read(nets[0], nets[1])()
assert np.allclose(dst_param[0].asnumpy(), 0.5)
# Test write dst net's params to overwrite the source.
dst_net2 = DstNet()
nets[0] = dst_net2
_ = Write(nets[1], nets[0])()
assert np.allclose(source_param[0].asnumpy(), 0.1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_read_write_model_cpu():
"""
Feature: BatchPushPull cpu TEST.
Description: Test the batch assign.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
dst_net = DstNet()
source_net = SourceNet()
dst_param = dst_net.trainable_params()
source_param = source_net.trainable_params()
cpu_nets = nn.CellList()
cpu_nets.append(dst_net)
cpu_nets.append(source_net)
_ = Read(cpu_nets[0], cpu_nets[1])()
assert np.allclose(dst_param[0].asnumpy(), 0.5)
dst_net2 = DstNet()
cpu_nets[0] = dst_net2
_ = Write(cpu_nets[1], cpu_nets[0])()
assert np.allclose(source_param[0].asnumpy(), 0.1)

View File

@ -0,0 +1,109 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
import mindspore.common.dtype as mstype
from mindspore.ops import composite as C
from mindspore.nn.reinforcement._tensors_queue import TensorsQueue
class TensorsQueueNet(nn.Cell):
def __init__(self, dtype, shapes, size=0, name="q"):
super(TensorsQueueNet, self).__init__()
self.tq = TensorsQueue(dtype, shapes, size, name)
def construct(self, grads):
self.tq.put(grads)
self.tq.put(grads)
size_before = self.tq.size()
ans = self.tq.pop()
size_after = self.tq.size()
self.tq.clear()
self.tq.close()
return ans, size_before, size_after
class SourceNet(nn.Cell):
'''Source net'''
def __init__(self):
super(SourceNet, self).__init__()
self.a = Parameter(Tensor(0.5, mstype.float32), name="a")
self.dense = nn.Dense(in_channels=4, out_channels=1, weight_init=0)
def construct(self, data):
d = self.dense(data)
out = d + self.a
return out
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tensorsqueue_gpu():
"""
Feature: TensorsQueue gpu TEST.
Description: Test the function write, read, stack, clear, close in both graph and pynative mode.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
input_data = Tensor(np.arange(8).reshape(2, 4), mstype.float32)
net = SourceNet()
weight = ParameterTuple(net.trainable_params())
grad = C.GradOperation(get_by_list=True, sens_param=False)
_ = net(input_data)
grads = grad(net, weight)(input_data)
shapes = []
for i in grads:
shapes.append(i.shape)
tq = TensorsQueueNet(dtype=mstype.float32, shapes=shapes, size=5, name="tq")
ans, size_before, size_after = tq(grads)
assert np.allclose(size_before.asnumpy(), 2)
assert np.allclose(size_after.asnumpy(), 1)
assert np.allclose(ans[0].asnumpy(), 2.0)
assert np.allclose(ans[1].asnumpy(), [[4.0, 6.0, 8.0, 10.0]])
assert np.allclose(ans[2].asnumpy(), [2.0])
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tensorsqueue_cpu():
"""
Feature: TensorsQueue cpu TEST.
Description: Test the function write, read, stack, clear, close in both graph and pynative mode.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
data = Tensor(np.arange(8).reshape(2, 4), mstype.float32)
net = SourceNet()
weight = ParameterTuple(net.trainable_params())
grad = C.GradOperation(get_by_list=True, sens_param=False)
_ = net(data)
grads = grad(net, weight)(data)
shapes = []
for i in grads:
shapes.append(i.shape)
tq_cpu = TensorsQueueNet(dtype=mstype.float32, shapes=shapes, size=5, name="tq")
ans, size_before, size_after = tq_cpu(grads)
assert np.allclose(ans[0].asnumpy(), 2.0)
assert np.allclose(ans[1].asnumpy(), [[4.0, 6.0, 8.0, 10.0]])
assert np.allclose(ans[2].asnumpy(), [2.0])
assert np.allclose(size_before.asnumpy(), 2)
assert np.allclose(size_after.asnumpy(), 1)