forked from mindspore-Ecosystem/mindspore
!177 prelu operator support parallel on the channel
Merge pull request !177 from yao_yf/fix_auto_parallel_prelu
This commit is contained in:
commit
2e6e94b2b6
|
@ -52,7 +52,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr& strategy) {
|
|||
}
|
||||
return FAILED;
|
||||
}
|
||||
if ((stra[0][PRELU_CHANNEL_INDEX] != PRELU_CHANNEL_STRATEGY) || (stra[1][0] != PRELU_CHANNEL_STRATEGY)) {
|
||||
if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0]) {
|
||||
if (is_auto_parallel_) {
|
||||
MS_LOG(DEBUG) << name_ << ": Invalid channel strategy.";
|
||||
} else {
|
||||
|
|
|
@ -146,11 +146,10 @@ TEST_F(TestPReLUInfo, CheckStrategy1) {
|
|||
}
|
||||
|
||||
TEST_F(TestPReLUInfo, CheckStrategy2) {
|
||||
// Success: {{2,1,8,16},{1}}
|
||||
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {4}};
|
||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||
Status ret = prelu->Init(strategy);
|
||||
ASSERT_EQ(ret, FAILED);
|
||||
ASSERT_EQ(ret, SUCCESS);
|
||||
}
|
||||
|
||||
TEST_F(TestPReLUInfo, AutoStrategy1) {
|
||||
|
@ -252,11 +251,10 @@ TEST_F(TestPReLUInfo, CheckStrategy_2d1) {
|
|||
}
|
||||
|
||||
TEST_F(TestPReLUInfo, CheckStrategy_2d2) {
|
||||
// Success: {{2,1,8,16},{1}}
|
||||
std::vector<Dimensions> inputs = {{128, 4}, {4}};
|
||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||
Status ret = prelu_2d->Init(strategy);
|
||||
ASSERT_EQ(ret, FAILED);
|
||||
ASSERT_EQ(ret, SUCCESS);
|
||||
}
|
||||
|
||||
TEST_F(TestPReLUInfo, AutoStrategy_2d1) {
|
||||
|
|
|
@ -149,3 +149,20 @@ def test_prelu_parallel_success3():
|
|||
w = Tensor(np.random.rand(16),dtype=ms.float32)
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
_executor.compile(net, x, y, w)
|
||||
|
||||
def test_prelu_parallel_success4():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy):
|
||||
super().__init__()
|
||||
self.prelu = P.PReLU().set_strategy(strategy)
|
||||
def construct(self, x, y):
|
||||
out = self.prelu(x, y)
|
||||
return out
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=64, global_rank=0)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
strategy = ((2, 4, 4, 2), (4, ))
|
||||
x = Tensor(np.random.rand(4, 16, 32, 64),dtype=ms.float32)
|
||||
w = Tensor(np.random.rand(16),dtype=ms.float32)
|
||||
net = GradWrap(NetWithLoss(Net(strategy)))
|
||||
_executor.compile(net, x, w)
|
||||
|
|
Loading…
Reference in New Issue