add check for resize op

This commit is contained in:
yangzhenzhang 2021-11-16 11:57:08 +08:00
parent 36bafb3654
commit ba99e4c505
4 changed files with 125 additions and 7 deletions

View File

@ -56,6 +56,14 @@ std::string StrategyToString(const Strategys &strategy) {
return strategy_str;
}
Status OperatorInfo::CheckOutputStrategy(const StrategyPtr &out_strategy) {
if (out_strategy) {
MS_LOG(ERROR) << name_ << ": It does not support to set output strategy now, please modify the shard set";
return FAILED;
}
return SUCCESS;
}
Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) {
if (strategy == nullptr) {
MS_LOG(ERROR) << name_ << ": The strategy is null.";

View File

@ -209,7 +209,7 @@ class OperatorInfo {
virtual Status InferMirrorOps();
virtual Status InferTensorInfo();
virtual void InferReplaceOps() {}
virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy) { return SUCCESS; }
virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
void SetRepeatedCalcDevMatrix();
void ResetTensorMapIfRepeatedCalc();

View File

@ -126,6 +126,25 @@ void ResizeBilinearInfo::ReplaceNodeInputOrAttrs() {
Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (align_corners_) {
std::vector<Dimensions> stra = strategy->GetInputDim();
if (stra.size() != 1) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be 1, but got " << stra.size();
return FAILED;
}
Dimensions input_strategy = stra[0];
if (input_strategy.size() != 4) {
MS_LOG(ERROR) << name_ << ": The size of input strategy must be 4, but got" << input_strategy.size();
return FAILED;
}
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
MS_LOG(ERROR) << name_ << ": The align_corners is True, do not support split from H or W";
return FAILED;
}
}
// check input strategy
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Check input strategy failed";
@ -143,6 +162,10 @@ Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) {
std::vector<StrategyPtr> ResizeNearestNeighborInfo::GenerateOpStrategies(int64_t stage_id) {
Shape multiples_split(inputs_shape_[0].size(), 1);
if (align_corners_) {
multiples_split[2] = 0;
multiples_split[3] = 0;
}
Shapes splittable_inputs = {multiples_split};
std::vector<StrategyPtr> sp_vector;

View File

@ -13,6 +13,7 @@
# limitations under the License.
'''ResizeBilinear and ResizeNearestNeigbor ut'''
import numpy as np
import pytest
import mindspore as ms
from mindspore import context, Tensor, Parameter
@ -43,13 +44,13 @@ class Net2(Cell):
'''
create the test Net
'''
def __init__(self, conv2d_weight, mul_weight, out_channel, kernel_size, pad_mode, stride,
strategy1=None, strategy2=None):
def __init__(self, conv2d_weight, mul_weight, out_channel, kernel_size, pad_mode, stride, align_corners=False,
strategy1=None, strategy2=None, out_strategy=None):
super(Net2, self).__init__()
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride).shard(strategy1)
self.conv2d_weight = Parameter(conv2d_weight, "w1")
self.resize_neighbor = P.ResizeNearestNeighbor((16, 16)).shard(strategy2)
self.resize_neighbor = P.ResizeNearestNeighbor((16, 16), align_corners).shard(strategy2, out_strategy)
self.mul = P.Mul()
self.mul_weight = Parameter(mul_weight, "w2")
@ -59,6 +60,7 @@ class Net2(Cell):
out = self.mul(out, self.mul_weight)
return out
class Net3(Cell):
'''
create the test Net
@ -92,6 +94,11 @@ def compile_net(net, inputs=_x):
def test_bililear_data_parallel():
"""
Feature: test ResizeBilinear data parallel strategy
Description: only shard batch dimension
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
@ -101,6 +108,11 @@ def test_bililear_data_parallel():
def test_bilinear_model_parallel1():
"""
Feature: test ResizeBilinear model parallel strategy
Description: shard N/C
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((4, 2, 1, 1),)
@ -109,7 +121,12 @@ def test_bilinear_model_parallel1():
compile_net(net)
def test_bilinear_model_parallel2():
def test_bilinear_repeated_calc():
"""
Feature: test ResizeBilinear repeated calculation parallel strategy
Description: only shard batch dimension, but shard num smaller than device num
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 1, 1, 1),)
@ -119,18 +136,33 @@ def test_bilinear_model_parallel2():
def test_bilinear_auto_parallel():
"""
Feature: test ResizeBilinear auto parallel
Description:
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net)
def test_bilinear_no_strategy():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
"""
Feature: test ResizeBilinear semi auto parallel, and has not set strategy for it
Description:
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
net = Net3(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net)
def test_neighbor_data_parallel():
"""
Feature: test ResizeNearestNeighbor data parallel strategy
Description: only shard batch dimension
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
@ -139,7 +171,57 @@ def test_neighbor_data_parallel():
compile_net(net)
def test_neighbor_model_parallel1():
def test_neighbor_model_parallel_align_corners_shard_HW():
"""
Feature: test ResizeNearestNeighbor model parallel strategy
Description: the align_corners is True, and shard N/C/H/W
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 2, 2, 2),)
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1, align_corners=True,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_neighbor_out_strategy():
"""
Feature: test ResizeNearestNeighbor to set output parallel strategy
Description:
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 2, 2, 2),)
out_strategy = ((2, 2, 2, 2),)
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy1=strategy1, strategy2=strategy2, out_strategy=out_strategy)
with pytest.raises(RuntimeError):
compile_net(net)
def test_neighbor_model_parallel_align_corners_shard_NC():
"""
Feature: test ResizeNearestNeighbor model parallel strategy
Description: the align_corners is True, and shard N/C
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((4, 4, 1, 1),)
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1, align_corners=True,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_neighbor_model_parallel_align_corners_is_false():
"""
Feature: test ResizeNearestNeighbor model parallel strategy
Description: the align_corners is False, and shard N/C/H/W
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 2, 2, 2),)
@ -149,6 +231,11 @@ def test_neighbor_model_parallel1():
def test_neighbor_auto_parallel():
"""
Feature: test ResizeNearestNeighbor auto parallel
Description:
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net)