!25544 add parallel op for resizenearestneighbor

Merge pull request !25544 from yangzhenzhang/add-parallel-op-for-resizenearestneighbor
This commit is contained in:
i-robot 2021-10-29 02:47:49 +00:00 committed by Gitee
commit 48699c939c
4 changed files with 62 additions and 18 deletions

View File

@ -35,14 +35,7 @@ Status ResizeBilinearInfo::GetAttrs() {
return FAILED;
}
if (size_[0] != size_[1]) {
MS_LOG(ERROR) << name_ << ": The second two elements of size must be the same, but got (" << size_[0] << ", "
<< size_[1] << ")";
return FAILED;
}
align_corners_ = GetBoolAttr(ALIGN_CORNERS);
MS_LOG(INFO) << name_ << ": The input size is " << size_ << ", align_corners is " << align_corners_;
return SUCCESS;
@ -85,7 +78,15 @@ Status ResizeBilinearInfo::InferDevMatrixShape() {
return FAILED;
}
if (stra[0].size() != 4) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be 4, but got " << stra[0].size();
return FAILED;
}
dev_matrix_shape_ = stra[0];
slice_size_ = size_;
slice_size_[0] = slice_size_[0] / dev_matrix_shape_[2];
slice_size_[1] = slice_size_[1] / dev_matrix_shape_[3];
return SUCCESS;
}
@ -135,5 +136,40 @@ Status ResizeBilinearInfo::InitForCostModel(const StrategyPtr &strategy) {
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
void ResizeBilinearInfo::ReplaceNodeInputOrAttrs() {
auto prim = GetValueNode<PrimitivePtr>(cnode_->input(0));
prim->set_attr(SIZE, MakeValue(slice_size_));
}
// Validates a strategy for ResizeNearestNeighbor.
// The same strategy is passed to CheckStrategyValue twice: first together with inputs_shape_, then
// together with outputs_shape_. The first failing check logs an error and returns FAILED; the output
// check is not attempted when the input check fails. Returns SUCCESS when both checks pass.
// A null strategy is handled by MS_EXCEPTION_IF_NULL before any check runs.
Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);

  // Strategy versus the input shapes.
  const bool input_ok = (CheckStrategyValue(strategy, inputs_shape_) == SUCCESS);
  if (!input_ok) {
    MS_LOG(ERROR) << name_ << ": Check input strategy failed";
    return FAILED;
  }

  // Strategy versus the output shapes.
  const bool output_ok = (CheckStrategyValue(strategy, outputs_shape_) == SUCCESS);
  if (!output_ok) {
    MS_LOG(ERROR) << name_ << ": Check output strategy failed";
    return FAILED;
  }

  return SUCCESS;
}
// Builds the list of candidate strategies for ResizeNearestNeighbor.
//
// stage_id is forwarded unchanged to GenerateStrategiesForIndependentInputs, together with
// inputs_shape_ and a per-dimension flag vector for the first input in which every entry is 1
// (presumably meaning "this dimension may be split" -- confirm against the helper's contract).
//
// Returns the strategies written by the helper. Both failure paths (no input shape available, or the
// helper not returning SUCCESS) are reported through MS_LOG(EXCEPTION).
std::vector<StrategyPtr> ResizeNearestNeighborInfo::GenerateOpStrategies(int64_t stage_id) {
  // inputs_shape_[0] is read below; indexing an empty vector is undefined behavior, so fail loudly here.
  // NOTE(review): this relies on MS_LOG(EXCEPTION) not returning normally -- confirm.
  if (inputs_shape_.empty()) {
    MS_LOG(EXCEPTION) << name_ << ": the inputs shape is empty";
  }

  // One flag per dimension of the first input, all set to 1.
  Shape splittable_dims(inputs_shape_[0].size(), 1);
  Shapes splittable_inputs = {splittable_dims};

  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": generate strategies failed";
  }
  return sp_vector;
}
} // namespace parallel
} // namespace mindspore

View File

@ -47,11 +47,11 @@ class ResizeBilinearInfo : public OperatorInfo {
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
void ReplaceNodeInputOrAttrs() override;
private:
std::vector<int64_t> size_; // four integers, NCHW
bool align_corners_;
std::vector<int64_t> size_;
std::vector<int64_t> slice_size_;
bool align_corners_ = false;
};
class ResizeNearestNeighborInfo : public ResizeBilinearInfo {
@ -60,6 +60,10 @@ class ResizeNearestNeighborInfo : public ResizeBilinearInfo {
const PrimitiveAttrs &attrs)
: ResizeBilinearInfo(name, inputs_shape, outputs_shape, attrs) {}
~ResizeNearestNeighborInfo() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
};
} // namespace parallel

View File

@ -3841,7 +3841,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
def infer_shape(self, x_shape):
validator.check('the dimension of input_x', len(x_shape), '', 4, Rel.EQ, self.name)
return tuple(x_shape)[:-2] + tuple(self.size)
return tuple(x_shape)[:-2] + tuple(super().get_attr_dict()['size'])
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)

View File

@ -43,17 +43,20 @@ class Net2(Cell):
'''
create the test Net
'''
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
def __init__(self, conv2d_weight, mul_weight, out_channel, kernel_size, pad_mode, stride,
strategy1=None, strategy2=None):
super(Net2, self).__init__()
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride).shard(strategy1)
self.conv2d_weight = Parameter(conv2d_weight, "w1")
self.resize_neighbor = P.ResizeNearestNeighbor((16, 16)).shard(strategy2)
self.mul = P.Mul()
self.mul_weight = Parameter(mul_weight, "w2")
def construct(self, x):
out = self.conv2d(x, self.conv2d_weight)
out = self.resize_neighbor(out)
out = self.mul(out, self.mul_weight)
return out
class Net3(Cell):
@ -76,6 +79,7 @@ class Net3(Cell):
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_w2 = Tensor(np.ones([32, 8, 16, 16]), dtype=ms.float32)
def compile_net(net, inputs=_x):
@ -130,21 +134,21 @@ def test_neighbor_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_neighbor_model_parallel1():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((4, 2, 1, 1),)
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy2 = ((2, 2, 2, 2),)
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_neighbor_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net)