add check for resize op

This commit is contained in:
yangzhenzhang 2021-11-16 11:57:08 +08:00
parent 36bafb3654
commit ba99e4c505
4 changed files with 125 additions and 7 deletions

View File

@ -56,6 +56,14 @@ std::string StrategyToString(const Strategys &strategy) {
return strategy_str; return strategy_str;
} }
Status OperatorInfo::CheckOutputStrategy(const StrategyPtr &out_strategy) {
if (out_strategy) {
MS_LOG(ERROR) << name_ << ": It does not support to set output strategy now, please modify the shard set";
return FAILED;
}
return SUCCESS;
}
Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) { Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) {
if (strategy == nullptr) { if (strategy == nullptr) {
MS_LOG(ERROR) << name_ << ": The strategy is null."; MS_LOG(ERROR) << name_ << ": The strategy is null.";

View File

@ -209,7 +209,7 @@ class OperatorInfo {
virtual Status InferMirrorOps(); virtual Status InferMirrorOps();
virtual Status InferTensorInfo(); virtual Status InferTensorInfo();
virtual void InferReplaceOps() {} virtual void InferReplaceOps() {}
virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy) { return SUCCESS; } virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape); Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
void SetRepeatedCalcDevMatrix(); void SetRepeatedCalcDevMatrix();
void ResetTensorMapIfRepeatedCalc(); void ResetTensorMapIfRepeatedCalc();

View File

@ -126,6 +126,25 @@ void ResizeBilinearInfo::ReplaceNodeInputOrAttrs() {
Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) { Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy); MS_EXCEPTION_IF_NULL(strategy);
if (align_corners_) {
std::vector<Dimensions> stra = strategy->GetInputDim();
if (stra.size() != 1) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be 1, but got " << stra.size();
return FAILED;
}
Dimensions input_strategy = stra[0];
if (input_strategy.size() != 4) {
MS_LOG(ERROR) << name_ << ": The size of input strategy must be 4, but got" << input_strategy.size();
return FAILED;
}
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
MS_LOG(ERROR) << name_ << ": The align_corners is True, do not support split from H or W";
return FAILED;
}
}
// check input strategy // check input strategy
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Check input strategy failed"; MS_LOG(ERROR) << name_ << ": Check input strategy failed";
@ -143,6 +162,10 @@ Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) {
std::vector<StrategyPtr> ResizeNearestNeighborInfo::GenerateOpStrategies(int64_t stage_id) { std::vector<StrategyPtr> ResizeNearestNeighborInfo::GenerateOpStrategies(int64_t stage_id) {
Shape multiples_split(inputs_shape_[0].size(), 1); Shape multiples_split(inputs_shape_[0].size(), 1);
if (align_corners_) {
multiples_split[2] = 0;
multiples_split[3] = 0;
}
Shapes splittable_inputs = {multiples_split}; Shapes splittable_inputs = {multiples_split};
std::vector<StrategyPtr> sp_vector; std::vector<StrategyPtr> sp_vector;

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
'''ResizeBilinear and ResizeNearestNeigbor ut''' '''ResizeBilinear and ResizeNearestNeigbor ut'''
import numpy as np import numpy as np
import pytest
import mindspore as ms import mindspore as ms
from mindspore import context, Tensor, Parameter from mindspore import context, Tensor, Parameter
@ -43,13 +44,13 @@ class Net2(Cell):
''' '''
create the test Net create the test Net
''' '''
def __init__(self, conv2d_weight, mul_weight, out_channel, kernel_size, pad_mode, stride, def __init__(self, conv2d_weight, mul_weight, out_channel, kernel_size, pad_mode, stride, align_corners=False,
strategy1=None, strategy2=None): strategy1=None, strategy2=None, out_strategy=None):
super(Net2, self).__init__() super(Net2, self).__init__()
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size, self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride).shard(strategy1) pad_mode=pad_mode, stride=stride).shard(strategy1)
self.conv2d_weight = Parameter(conv2d_weight, "w1") self.conv2d_weight = Parameter(conv2d_weight, "w1")
self.resize_neighbor = P.ResizeNearestNeighbor((16, 16)).shard(strategy2) self.resize_neighbor = P.ResizeNearestNeighbor((16, 16), align_corners).shard(strategy2, out_strategy)
self.mul = P.Mul() self.mul = P.Mul()
self.mul_weight = Parameter(mul_weight, "w2") self.mul_weight = Parameter(mul_weight, "w2")
@ -59,6 +60,7 @@ class Net2(Cell):
out = self.mul(out, self.mul_weight) out = self.mul(out, self.mul_weight)
return out return out
class Net3(Cell): class Net3(Cell):
''' '''
create the test Net create the test Net
@ -92,6 +94,11 @@ def compile_net(net, inputs=_x):
def test_bililear_data_parallel(): def test_bililear_data_parallel():
"""
Feature: test ResizeBilinear data parallel strategy
Description: only shard batch dimension
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),) strategy2 = ((8, 1, 1, 1),)
@ -101,6 +108,11 @@ def test_bililear_data_parallel():
def test_bilinear_model_parallel1(): def test_bilinear_model_parallel1():
"""
Feature: test ResizeBilinear model parallel strategy
Description: shard N/C
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1)) strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((4, 2, 1, 1),) strategy2 = ((4, 2, 1, 1),)
@ -109,7 +121,12 @@ def test_bilinear_model_parallel1():
compile_net(net) compile_net(net)
def test_bilinear_model_parallel2(): def test_bilinear_repeated_calc():
"""
Feature: test ResizeBilinear repeated calculation parallel strategy
Description: only shard batch dimension, but shard num smaller than device num
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1)) strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 1, 1, 1),) strategy2 = ((2, 1, 1, 1),)
@ -119,18 +136,33 @@ def test_bilinear_model_parallel2():
def test_bilinear_auto_parallel(): def test_bilinear_auto_parallel():
"""
Feature: test ResizeBilinear auto parallel
Description:
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1) net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net) compile_net(net)
def test_bilinear_no_strategy(): def test_bilinear_no_strategy():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) """
Feature: test ResizeBilinear semi auto parallel, and has not set strategy for it
Description:
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
net = Net3(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1) net = Net3(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net) compile_net(net)
def test_neighbor_data_parallel(): def test_neighbor_data_parallel():
"""
Feature: test ResizeNearestNeighbor data parallel strategy
Description: only shard batch dimension
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),) strategy2 = ((8, 1, 1, 1),)
@ -139,7 +171,57 @@ def test_neighbor_data_parallel():
compile_net(net) compile_net(net)
def test_neighbor_model_parallel1(): def test_neighbor_model_parallel_align_corners_shard_HW():
"""
Feature: test ResizeNearestNeighbor model parallel strategy
Description: the align_corners is True, and shard N/C/H/W
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 2, 2, 2),)
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1, align_corners=True,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_neighbor_out_strategy():
"""
Feature: test ResizeNearestNeighbor to set output parallel strategy
Description:
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 2, 2, 2),)
out_strategy = ((2, 2, 2, 2),)
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy1=strategy1, strategy2=strategy2, out_strategy=out_strategy)
with pytest.raises(RuntimeError):
compile_net(net)
def test_neighbor_model_parallel_align_corners_shard_NC():
"""
Feature: test ResizeNearestNeighbor model parallel strategy
Description: the align_corners is True, and shard N/C
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((4, 4, 1, 1),)
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1, align_corners=True,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_neighbor_model_parallel_align_corners_is_false():
"""
Feature: test ResizeNearestNeighbor model parallel strategy
Description: the align_corners is False, and shard N/C/H/W
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1)) strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 2, 2, 2),) strategy2 = ((2, 2, 2, 2),)
@ -149,6 +231,11 @@ def test_neighbor_model_parallel1():
def test_neighbor_auto_parallel(): def test_neighbor_auto_parallel():
"""
Feature: test ResizeNearestNeighbor auto parallel
Description:
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1) net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net) compile_net(net)