forked from mindspore-Ecosystem/mindspore
Add test case for CellList getattr.
This commit is contained in:
parent
eaa07864f5
commit
9946062c53
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,8 +13,9 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test a list of cell, and getattr by its item """
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import context, nn, dtype, Tensor
|
||||
from mindspore import context, nn, dtype, Tensor, ms_function
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
|
@ -49,6 +50,26 @@ def test_list_item_getattr():
|
|||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph mode yet')
|
||||
def test_cell_list_getattr():
|
||||
"""
|
||||
Feature: getattr by the item from nn.CellList.
|
||||
Description: Support RL use method in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
actor_list = nn.CellList()
|
||||
for _ in range(3):
|
||||
actor_list.append(Actor())
|
||||
trainer = Trainer(actor_list)
|
||||
x = Tensor([3], dtype=dtype.float32)
|
||||
y = Tensor([6], dtype=dtype.float32)
|
||||
res = trainer(x, y)
|
||||
print(f'res: {res}')
|
||||
expect_res = Tensor([9], dtype=dtype.float32)
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
|
||||
|
||||
class Trainer2(nn.Cell):
|
||||
def __init__(self, net_list):
|
||||
super(Trainer2, self).__init__()
|
||||
|
@ -80,3 +101,98 @@ def test_list_item_getattr2():
|
|||
print(f'res: {res}')
|
||||
expect_res = Tensor([27], dtype=dtype.float32)
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph mode yet')
|
||||
def test_cell_list_getattr2():
|
||||
"""
|
||||
Feature: getattr by the item from nn.CellList.
|
||||
Description: Support RL use method in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
actor_list = nn.CellList()
|
||||
for _ in range(3):
|
||||
actor_list.append(Actor())
|
||||
trainer = Trainer2(actor_list)
|
||||
x = Tensor([3], dtype=dtype.float32)
|
||||
y = Tensor([6], dtype=dtype.float32)
|
||||
res = trainer(x, y)
|
||||
print(f'res: {res}')
|
||||
expect_res = Tensor([27], dtype=dtype.float32)
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
|
||||
|
||||
class MSRL(nn.Cell):
|
||||
def __init__(self, agent):
|
||||
super(MSRL, self).__init__()
|
||||
self.agent = agent
|
||||
|
||||
|
||||
class Agent(nn.Cell):
|
||||
def __init__(self, actor):
|
||||
super(Agent, self).__init__()
|
||||
self.actor = actor
|
||||
|
||||
def act(self, x, y):
|
||||
out = self.actor.act(x, y)
|
||||
return out
|
||||
|
||||
|
||||
class Trainer3(nn.Cell):
|
||||
def __init__(self, msrl):
|
||||
super(Trainer3, self).__init__()
|
||||
self.msrl = msrl
|
||||
|
||||
@ms_function
|
||||
def test(self, x, y):
|
||||
num_actor = 0
|
||||
output = 0
|
||||
while num_actor < 3:
|
||||
output += self.msrl.agent[num_actor].act(x, y)
|
||||
num_actor += 1
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph mode yet')
|
||||
def test_list_item_getattr3():
|
||||
"""
|
||||
Feature: getattr by the item from list of cell.
|
||||
Description: Support RL use method in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
agent_list = []
|
||||
for _ in range(3):
|
||||
actor = Actor()
|
||||
agent_list.append(Agent(actor))
|
||||
msrl = MSRL(agent_list)
|
||||
trainer = Trainer3(msrl)
|
||||
x = Tensor([2], dtype=dtype.int32)
|
||||
y = Tensor([3], dtype=dtype.int32)
|
||||
res = trainer.test(x, y)
|
||||
print(f'res: {res}')
|
||||
expect_res = Tensor([15], dtype=dtype.int32)
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph mode yet')
|
||||
def test_cell_list_getattr3():
|
||||
"""
|
||||
Feature: getattr by the item from list of cell.
|
||||
Description: Support RL use method in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
agent_list = nn.CellList()
|
||||
for _ in range(3):
|
||||
actor = Actor()
|
||||
agent_list.append(Agent(actor))
|
||||
msrl = MSRL(agent_list)
|
||||
trainer = Trainer3(msrl)
|
||||
x = Tensor([2], dtype=dtype.float32)
|
||||
y = Tensor([3], dtype=dtype.float32)
|
||||
res = trainer.test(x, y)
|
||||
print(f'res: {res}')
|
||||
expect_res = Tensor([15], dtype=dtype.float32)
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
|
|
Loading…
Reference in New Issue