forked from mindspore-Ecosystem/mindspore
!26350 add check for resizenearestneighbor parallel op
Merge pull request !26350 from yangzhenzhang/add-check-for-resize-op
This commit is contained in:
commit
9f52343a6a
|
@ -56,6 +56,14 @@ std::string StrategyToString(const Strategys &strategy) {
|
|||
return strategy_str;
|
||||
}
|
||||
|
||||
Status OperatorInfo::CheckOutputStrategy(const StrategyPtr &out_strategy) {
|
||||
if (out_strategy) {
|
||||
MS_LOG(ERROR) << name_ << ": It does not support to set output strategy now, please modify the shard set";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) {
|
||||
if (strategy == nullptr) {
|
||||
MS_LOG(ERROR) << name_ << ": The strategy is null.";
|
||||
|
|
|
@ -209,7 +209,7 @@ class OperatorInfo {
|
|||
virtual Status InferMirrorOps();
|
||||
virtual Status InferTensorInfo();
|
||||
virtual void InferReplaceOps() {}
|
||||
virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy) { return SUCCESS; }
|
||||
virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
|
||||
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
|
||||
void SetRepeatedCalcDevMatrix();
|
||||
void ResetTensorMapIfRepeatedCalc();
|
||||
|
|
|
@ -126,6 +126,25 @@ void ResizeBilinearInfo::ReplaceNodeInputOrAttrs() {
|
|||
Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
MS_EXCEPTION_IF_NULL(strategy);
|
||||
|
||||
if (align_corners_) {
|
||||
std::vector<Dimensions> stra = strategy->GetInputDim();
|
||||
if (stra.size() != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of strategy must be 1, but got " << stra.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Dimensions input_strategy = stra[0];
|
||||
if (input_strategy.size() != 4) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of input strategy must be 4, but got" << input_strategy.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The align_corners is True, do not support split from H or W";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
// check input strategy
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Check input strategy failed";
|
||||
|
@ -143,6 +162,10 @@ Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
|
||||
std::vector<StrategyPtr> ResizeNearestNeighborInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
Shape multiples_split(inputs_shape_[0].size(), 1);
|
||||
if (align_corners_) {
|
||||
multiples_split[2] = 0;
|
||||
multiples_split[3] = 0;
|
||||
}
|
||||
Shapes splittable_inputs = {multiples_split};
|
||||
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
'''ResizeBilinear and ResizeNearestNeigbor ut'''
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
|
@ -43,13 +44,13 @@ class Net2(Cell):
|
|||
'''
|
||||
create the test Net
|
||||
'''
|
||||
def __init__(self, conv2d_weight, mul_weight, out_channel, kernel_size, pad_mode, stride,
|
||||
strategy1=None, strategy2=None):
|
||||
def __init__(self, conv2d_weight, mul_weight, out_channel, kernel_size, pad_mode, stride, align_corners=False,
|
||||
strategy1=None, strategy2=None, out_strategy=None):
|
||||
super(Net2, self).__init__()
|
||||
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
|
||||
pad_mode=pad_mode, stride=stride).shard(strategy1)
|
||||
self.conv2d_weight = Parameter(conv2d_weight, "w1")
|
||||
self.resize_neighbor = P.ResizeNearestNeighbor((16, 16)).shard(strategy2)
|
||||
self.resize_neighbor = P.ResizeNearestNeighbor((16, 16), align_corners).shard(strategy2, out_strategy)
|
||||
self.mul = P.Mul()
|
||||
self.mul_weight = Parameter(mul_weight, "w2")
|
||||
|
||||
|
@ -59,6 +60,7 @@ class Net2(Cell):
|
|||
out = self.mul(out, self.mul_weight)
|
||||
return out
|
||||
|
||||
|
||||
class Net3(Cell):
|
||||
'''
|
||||
create the test Net
|
||||
|
@ -92,6 +94,11 @@ def compile_net(net, inputs=_x):
|
|||
|
||||
|
||||
def test_bililear_data_parallel():
|
||||
"""
|
||||
Feature: test ResizeBilinear data parallel strategy
|
||||
Description: only shard batch dimension
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
|
@ -101,6 +108,11 @@ def test_bililear_data_parallel():
|
|||
|
||||
|
||||
def test_bilinear_model_parallel1():
|
||||
"""
|
||||
Feature: test ResizeBilinear model parallel strategy
|
||||
Description: shard N/C
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((4, 2, 1, 1),)
|
||||
|
@ -109,7 +121,12 @@ def test_bilinear_model_parallel1():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_bilinear_model_parallel2():
|
||||
def test_bilinear_repeated_calc():
|
||||
"""
|
||||
Feature: test ResizeBilinear repeated calculation parallel strategy
|
||||
Description: only shard batch dimension, but shard num smaller than device num
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((2, 1, 1, 1),)
|
||||
|
@ -119,18 +136,33 @@ def test_bilinear_model_parallel2():
|
|||
|
||||
|
||||
def test_bilinear_auto_parallel():
|
||||
"""
|
||||
Feature: test ResizeBilinear auto parallel
|
||||
Description:
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_bilinear_no_strategy():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
"""
|
||||
Feature: test ResizeBilinear semi auto parallel, and has not set strategy for it
|
||||
Description:
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net3(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_neighbor_data_parallel():
|
||||
"""
|
||||
Feature: test ResizeNearestNeighbor data parallel strategy
|
||||
Description: only shard batch dimension
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
|
@ -139,7 +171,57 @@ def test_neighbor_data_parallel():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_neighbor_model_parallel1():
|
||||
def test_neighbor_model_parallel_align_corners_shard_HW():
|
||||
"""
|
||||
Feature: test ResizeNearestNeighbor model parallel strategy
|
||||
Description: the align_corners is True, and shard N/C/H/W
|
||||
Expectation: compile failed
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((2, 2, 2, 2),)
|
||||
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1, align_corners=True,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_neighbor_out_strategy():
|
||||
"""
|
||||
Feature: test ResizeNearestNeighbor to set output parallel strategy
|
||||
Description:
|
||||
Expectation: compile failed
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((2, 2, 2, 2),)
|
||||
out_strategy = ((2, 2, 2, 2),)
|
||||
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
|
||||
strategy1=strategy1, strategy2=strategy2, out_strategy=out_strategy)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_neighbor_model_parallel_align_corners_shard_NC():
|
||||
"""
|
||||
Feature: test ResizeNearestNeighbor model parallel strategy
|
||||
Description: the align_corners is True, and shard N/C
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((4, 4, 1, 1),)
|
||||
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1, align_corners=True,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_neighbor_model_parallel_align_corners_is_false():
|
||||
"""
|
||||
Feature: test ResizeNearestNeighbor model parallel strategy
|
||||
Description: the align_corners is False, and shard N/C/H/W
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((2, 2, 2, 2),)
|
||||
|
@ -149,6 +231,11 @@ def test_neighbor_model_parallel1():
|
|||
|
||||
|
||||
def test_neighbor_auto_parallel():
|
||||
"""
|
||||
Feature: test ResizeNearestNeighbor auto parallel
|
||||
Description:
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
||||
compile_net(net)
|
||||
|
|
Loading…
Reference in New Issue