Add test case for CellList getattr.

This commit is contained in:
huangbingjian 2022-03-08 16:22:33 +08:00
parent eaa07864f5
commit 9946062c53
1 changed files with 118 additions and 2 deletions

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,8 +13,9 @@
# limitations under the License.
# ============================================================================
""" test a list of cell, and getattr by its item """
import pytest
import numpy as np
from mindspore import context, nn, dtype, Tensor
from mindspore import context, nn, dtype, Tensor, ms_function
from mindspore.ops import operations as P
@ -49,6 +50,26 @@ def test_list_item_getattr():
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
@pytest.mark.skip(reason='Not support in graph mode yet')
def test_cell_list_getattr():
"""
Feature: getattr by the item from nn.CellList.
Description: Support RL use method in graph mode.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
actor_list = nn.CellList()
for _ in range(3):
actor_list.append(Actor())
trainer = Trainer(actor_list)
x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32)
res = trainer(x, y)
print(f'res: {res}')
expect_res = Tensor([9], dtype=dtype.float32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
class Trainer2(nn.Cell):
def __init__(self, net_list):
super(Trainer2, self).__init__()
@ -80,3 +101,98 @@ def test_list_item_getattr2():
print(f'res: {res}')
expect_res = Tensor([27], dtype=dtype.float32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
@pytest.mark.skip(reason='Not support in graph mode yet')
def test_cell_list_getattr2():
"""
Feature: getattr by the item from nn.CellList.
Description: Support RL use method in graph mode.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
actor_list = nn.CellList()
for _ in range(3):
actor_list.append(Actor())
trainer = Trainer2(actor_list)
x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32)
res = trainer(x, y)
print(f'res: {res}')
expect_res = Tensor([27], dtype=dtype.float32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
class MSRL(nn.Cell):
def __init__(self, agent):
super(MSRL, self).__init__()
self.agent = agent
class Agent(nn.Cell):
def __init__(self, actor):
super(Agent, self).__init__()
self.actor = actor
def act(self, x, y):
out = self.actor.act(x, y)
return out
class Trainer3(nn.Cell):
def __init__(self, msrl):
super(Trainer3, self).__init__()
self.msrl = msrl
@ms_function
def test(self, x, y):
num_actor = 0
output = 0
while num_actor < 3:
output += self.msrl.agent[num_actor].act(x, y)
num_actor += 1
return output
@pytest.mark.skip(reason='Not support in graph mode yet')
def test_list_item_getattr3():
"""
Feature: getattr by the item from list of cell.
Description: Support RL use method in graph mode.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
agent_list = []
for _ in range(3):
actor = Actor()
agent_list.append(Agent(actor))
msrl = MSRL(agent_list)
trainer = Trainer3(msrl)
x = Tensor([2], dtype=dtype.int32)
y = Tensor([3], dtype=dtype.int32)
res = trainer.test(x, y)
print(f'res: {res}')
expect_res = Tensor([15], dtype=dtype.int32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
@pytest.mark.skip(reason='Not support in graph mode yet')
def test_cell_list_getattr3():
"""
Feature: getattr by the item from list of cell.
Description: Support RL use method in graph mode.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
agent_list = nn.CellList()
for _ in range(3):
actor = Actor()
agent_list.append(Agent(actor))
msrl = MSRL(agent_list)
trainer = Trainer3(msrl)
x = Tensor([2], dtype=dtype.float32)
y = Tensor([3], dtype=dtype.float32)
res = trainer.test(x, y)
print(f'res: {res}')
expect_res = Tensor([15], dtype=dtype.float32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())