forked from mindspore-Ecosystem/mindspore
add check for resize op
This commit is contained in:
parent
36bafb3654
commit
ba99e4c505
|
@ -56,6 +56,14 @@ std::string StrategyToString(const Strategys &strategy) {
|
||||||
return strategy_str;
|
return strategy_str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status OperatorInfo::CheckOutputStrategy(const StrategyPtr &out_strategy) {
|
||||||
|
if (out_strategy) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": It does not support to set output strategy now, please modify the shard set";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) {
|
Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) {
|
||||||
if (strategy == nullptr) {
|
if (strategy == nullptr) {
|
||||||
MS_LOG(ERROR) << name_ << ": The strategy is null.";
|
MS_LOG(ERROR) << name_ << ": The strategy is null.";
|
||||||
|
|
|
@ -209,7 +209,7 @@ class OperatorInfo {
|
||||||
virtual Status InferMirrorOps();
|
virtual Status InferMirrorOps();
|
||||||
virtual Status InferTensorInfo();
|
virtual Status InferTensorInfo();
|
||||||
virtual void InferReplaceOps() {}
|
virtual void InferReplaceOps() {}
|
||||||
virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy) { return SUCCESS; }
|
virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
|
||||||
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
|
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
|
||||||
void SetRepeatedCalcDevMatrix();
|
void SetRepeatedCalcDevMatrix();
|
||||||
void ResetTensorMapIfRepeatedCalc();
|
void ResetTensorMapIfRepeatedCalc();
|
||||||
|
|
|
@ -126,6 +126,25 @@ void ResizeBilinearInfo::ReplaceNodeInputOrAttrs() {
|
||||||
Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) {
|
Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
MS_EXCEPTION_IF_NULL(strategy);
|
MS_EXCEPTION_IF_NULL(strategy);
|
||||||
|
|
||||||
|
if (align_corners_) {
|
||||||
|
std::vector<Dimensions> stra = strategy->GetInputDim();
|
||||||
|
if (stra.size() != 1) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The size of strategy must be 1, but got " << stra.size();
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
Dimensions input_strategy = stra[0];
|
||||||
|
if (input_strategy.size() != 4) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The size of input strategy must be 4, but got" << input_strategy.size();
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The align_corners is True, do not support split from H or W";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// check input strategy
|
// check input strategy
|
||||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||||
MS_LOG(ERROR) << name_ << ": Check input strategy failed";
|
MS_LOG(ERROR) << name_ << ": Check input strategy failed";
|
||||||
|
@ -143,6 +162,10 @@ Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
|
|
||||||
std::vector<StrategyPtr> ResizeNearestNeighborInfo::GenerateOpStrategies(int64_t stage_id) {
|
std::vector<StrategyPtr> ResizeNearestNeighborInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||||
Shape multiples_split(inputs_shape_[0].size(), 1);
|
Shape multiples_split(inputs_shape_[0].size(), 1);
|
||||||
|
if (align_corners_) {
|
||||||
|
multiples_split[2] = 0;
|
||||||
|
multiples_split[3] = 0;
|
||||||
|
}
|
||||||
Shapes splittable_inputs = {multiples_split};
|
Shapes splittable_inputs = {multiples_split};
|
||||||
|
|
||||||
std::vector<StrategyPtr> sp_vector;
|
std::vector<StrategyPtr> sp_vector;
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
'''ResizeBilinear and ResizeNearestNeigbor ut'''
|
'''ResizeBilinear and ResizeNearestNeigbor ut'''
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
from mindspore import context, Tensor, Parameter
|
from mindspore import context, Tensor, Parameter
|
||||||
|
@ -43,13 +44,13 @@ class Net2(Cell):
|
||||||
'''
|
'''
|
||||||
create the test Net
|
create the test Net
|
||||||
'''
|
'''
|
||||||
def __init__(self, conv2d_weight, mul_weight, out_channel, kernel_size, pad_mode, stride,
|
def __init__(self, conv2d_weight, mul_weight, out_channel, kernel_size, pad_mode, stride, align_corners=False,
|
||||||
strategy1=None, strategy2=None):
|
strategy1=None, strategy2=None, out_strategy=None):
|
||||||
super(Net2, self).__init__()
|
super(Net2, self).__init__()
|
||||||
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
|
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
|
||||||
pad_mode=pad_mode, stride=stride).shard(strategy1)
|
pad_mode=pad_mode, stride=stride).shard(strategy1)
|
||||||
self.conv2d_weight = Parameter(conv2d_weight, "w1")
|
self.conv2d_weight = Parameter(conv2d_weight, "w1")
|
||||||
self.resize_neighbor = P.ResizeNearestNeighbor((16, 16)).shard(strategy2)
|
self.resize_neighbor = P.ResizeNearestNeighbor((16, 16), align_corners).shard(strategy2, out_strategy)
|
||||||
self.mul = P.Mul()
|
self.mul = P.Mul()
|
||||||
self.mul_weight = Parameter(mul_weight, "w2")
|
self.mul_weight = Parameter(mul_weight, "w2")
|
||||||
|
|
||||||
|
@ -59,6 +60,7 @@ class Net2(Cell):
|
||||||
out = self.mul(out, self.mul_weight)
|
out = self.mul(out, self.mul_weight)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Net3(Cell):
|
class Net3(Cell):
|
||||||
'''
|
'''
|
||||||
create the test Net
|
create the test Net
|
||||||
|
@ -92,6 +94,11 @@ def compile_net(net, inputs=_x):
|
||||||
|
|
||||||
|
|
||||||
def test_bililear_data_parallel():
|
def test_bililear_data_parallel():
|
||||||
|
"""
|
||||||
|
Feature: test ResizeBilinear data parallel strategy
|
||||||
|
Description: only shard batch dimension
|
||||||
|
Expectation: compile success
|
||||||
|
"""
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
||||||
strategy2 = ((8, 1, 1, 1),)
|
strategy2 = ((8, 1, 1, 1),)
|
||||||
|
@ -101,6 +108,11 @@ def test_bililear_data_parallel():
|
||||||
|
|
||||||
|
|
||||||
def test_bilinear_model_parallel1():
|
def test_bilinear_model_parallel1():
|
||||||
|
"""
|
||||||
|
Feature: test ResizeBilinear model parallel strategy
|
||||||
|
Description: shard N/C
|
||||||
|
Expectation: compile success
|
||||||
|
"""
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||||
strategy2 = ((4, 2, 1, 1),)
|
strategy2 = ((4, 2, 1, 1),)
|
||||||
|
@ -109,7 +121,12 @@ def test_bilinear_model_parallel1():
|
||||||
compile_net(net)
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
def test_bilinear_model_parallel2():
|
def test_bilinear_repeated_calc():
|
||||||
|
"""
|
||||||
|
Feature: test ResizeBilinear repeated calculation parallel strategy
|
||||||
|
Description: only shard batch dimension, but shard num smaller than device num
|
||||||
|
Expectation: compile success
|
||||||
|
"""
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||||
strategy2 = ((2, 1, 1, 1),)
|
strategy2 = ((2, 1, 1, 1),)
|
||||||
|
@ -119,18 +136,33 @@ def test_bilinear_model_parallel2():
|
||||||
|
|
||||||
|
|
||||||
def test_bilinear_auto_parallel():
|
def test_bilinear_auto_parallel():
|
||||||
|
"""
|
||||||
|
Feature: test ResizeBilinear auto parallel
|
||||||
|
Description:
|
||||||
|
Expectation: compile success
|
||||||
|
"""
|
||||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
||||||
compile_net(net)
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
def test_bilinear_no_strategy():
|
def test_bilinear_no_strategy():
|
||||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
"""
|
||||||
|
Feature: test ResizeBilinear semi auto parallel, and has not set strategy for it
|
||||||
|
Description:
|
||||||
|
Expectation: compile success
|
||||||
|
"""
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
net = Net3(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
net = Net3(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
||||||
compile_net(net)
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
def test_neighbor_data_parallel():
|
def test_neighbor_data_parallel():
|
||||||
|
"""
|
||||||
|
Feature: test ResizeNearestNeighbor data parallel strategy
|
||||||
|
Description: only shard batch dimension
|
||||||
|
Expectation: compile success
|
||||||
|
"""
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
||||||
strategy2 = ((8, 1, 1, 1),)
|
strategy2 = ((8, 1, 1, 1),)
|
||||||
|
@ -139,7 +171,57 @@ def test_neighbor_data_parallel():
|
||||||
compile_net(net)
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
def test_neighbor_model_parallel1():
|
def test_neighbor_model_parallel_align_corners_shard_HW():
|
||||||
|
"""
|
||||||
|
Feature: test ResizeNearestNeighbor model parallel strategy
|
||||||
|
Description: the align_corners is True, and shard N/C/H/W
|
||||||
|
Expectation: compile failed
|
||||||
|
"""
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||||
|
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||||
|
strategy2 = ((2, 2, 2, 2),)
|
||||||
|
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1, align_corners=True,
|
||||||
|
strategy1=strategy1, strategy2=strategy2)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_neighbor_out_strategy():
|
||||||
|
"""
|
||||||
|
Feature: test ResizeNearestNeighbor to set output parallel strategy
|
||||||
|
Description:
|
||||||
|
Expectation: compile failed
|
||||||
|
"""
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||||
|
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||||
|
strategy2 = ((2, 2, 2, 2),)
|
||||||
|
out_strategy = ((2, 2, 2, 2),)
|
||||||
|
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
|
||||||
|
strategy1=strategy1, strategy2=strategy2, out_strategy=out_strategy)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_neighbor_model_parallel_align_corners_shard_NC():
|
||||||
|
"""
|
||||||
|
Feature: test ResizeNearestNeighbor model parallel strategy
|
||||||
|
Description: the align_corners is True, and shard N/C
|
||||||
|
Expectation: compile success
|
||||||
|
"""
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||||
|
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||||
|
strategy2 = ((4, 4, 1, 1),)
|
||||||
|
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1, align_corners=True,
|
||||||
|
strategy1=strategy1, strategy2=strategy2)
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_neighbor_model_parallel_align_corners_is_false():
|
||||||
|
"""
|
||||||
|
Feature: test ResizeNearestNeighbor model parallel strategy
|
||||||
|
Description: the align_corners is False, and shard N/C/H/W
|
||||||
|
Expectation: compile success
|
||||||
|
"""
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||||
strategy2 = ((2, 2, 2, 2),)
|
strategy2 = ((2, 2, 2, 2),)
|
||||||
|
@ -149,6 +231,11 @@ def test_neighbor_model_parallel1():
|
||||||
|
|
||||||
|
|
||||||
def test_neighbor_auto_parallel():
|
def test_neighbor_auto_parallel():
|
||||||
|
"""
|
||||||
|
Feature: test ResizeNearestNeighbor auto parallel
|
||||||
|
Description:
|
||||||
|
Expectation: compile success
|
||||||
|
"""
|
||||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||||
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
||||||
compile_net(net)
|
compile_net(net)
|
||||||
|
|
Loading…
Reference in New Issue