!23569 Produce parallel operators for ResizeBilinear and ResizeNearestNeighbor

Merge pull request !23569 from Bert0108/resizebilinear_parallel_ops
This commit is contained in:
i-robot 2021-09-24 03:05:23 +00:00 committed by Gitee
commit e7cb505e68
8 changed files with 365 additions and 1 deletions

View File

@ -187,6 +187,7 @@ using RangeCost = CastCost;
using SplitCost = CastCost;
using ScatterUpdateCost = CastCost;
using UniformRealCost = CastCost;
using ResizeBilinearCost = CastCost;
class SqrtCost : public CastCost {
public:

View File

@ -207,6 +207,8 @@ REGISTER(ReduceAnyInfo);
REGISTER(MatmulDDSInfo);
REGISTER(DSDMatmulInfo);
REGISTER(UniformRealInfo);
REGISTER(ResizeBilinearInfo);
REGISTER(ResizeNearestNeighborInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -61,5 +61,6 @@
#include "frontend/parallel/ops_info/matmul_dds_info.h"
#include "frontend/parallel/ops_info/dsd_matmul_info.h"
#include "frontend/parallel/ops_info/uniform_real_info.h"
#include "frontend/parallel/ops_info/resizebilinear_info.h"
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_

View File

@ -134,6 +134,7 @@ constexpr char NEW_AXIS_MASK[] = "new_axis_mask";
constexpr char SHRINK_AXIS_MASK[] = "shrink_axis_mask";
constexpr char BEGIN[] = "begin";
constexpr char SIZE[] = "size";
constexpr char ALIGN_CORNERS[] = "align_corners";
constexpr char END[] = "end";
constexpr char STRIDES[] = "strides";
constexpr char GROUP[] = "group";
@ -398,6 +399,8 @@ constexpr char GATHERND[] = "GatherNd";
constexpr char SCATTER_UPDATE[] = "ScatterUpdate";
constexpr char GATHERD[] = "GatherD";
constexpr char DSD_MATMUL[] = "DSDMatmul";
constexpr char RESIZE_BILINEAR[] = "ResizeBilinear";
constexpr char RESIZE_NEAREST_NEIGHBOR[] = "ResizeNearestNeighbor";
// pipeline
constexpr char MICRO[] = "micro";

View File

@ -0,0 +1,139 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/resizebilinear_info.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"
namespace mindspore {
namespace parallel {
Status ResizeBilinearInfo::GetAttrs() {
size_ = GetTupleIntAttr(SIZE);
if (size_.size() != 2) {
MS_LOG(ERROR) << name_ << ": The size of input size must be 2, but got " << size_.size();
return FAILED;
}
if (size_[0] != size_[1]) {
MS_LOG(ERROR) << name_ << ": The second two elements of size must be the same, but got (" << size_[0] << ", "
<< size_[1] << ")";
return FAILED;
}
align_corners_ = GetBoolAttr(ALIGN_CORNERS);
MS_LOG(INFO) << name_ << ": The input size is " << size_ << ", align_corners is " << align_corners_;
return SUCCESS;
}
Status ResizeBilinearInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED;
}
std::vector<Dimensions> stra = strategy->GetInputDim();
if (stra.size() != 1) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be 1, but got " << stra.size();
return FAILED;
}
Dimensions input_strategy = stra[0];
if (input_strategy.size() != 4) {
MS_LOG(ERROR) << name_ << ": The size of input strategy must be 4, but got" << input_strategy.size();
return FAILED;
}
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
MS_LOG(ERROR) << name_ << ": Do not support split from H or W";
return FAILED;
}
return SUCCESS;
}
Status ResizeBilinearInfo::InferDevMatrixShape() {
// the strategy is (n, c, h, w)
// the dev matrix is (n, c, h, w)
MS_EXCEPTION_IF_NULL(strategy_);
std::vector<Dimensions> stra = strategy_->GetInputDim();
if (stra.empty()) {
MS_LOG(ERROR) << name_ << ": The strategy is empty";
return FAILED;
}
dev_matrix_shape_ = stra[0];
return SUCCESS;
}
Status ResizeBilinearInfo::InferTensorMap() {
// input_strategy: (n, c, h, w)
// output_strategy: (n, c, h, w)
// dev_matrix: (n, c, h, w)
TensorMap input_tensor_map = {3, 2, 1, 0};
TensorMap output_tensor_map = {3, 2, 1, 0};
(void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map));
(void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map));
return SUCCESS;
}
Status ResizeBilinearInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
return SetCostUnderStrategyBase(strategy);
}
std::vector<StrategyPtr> ResizeBilinearInfo::GenerateOpStrategies(int64_t stage_id) {
Shape input0_split(inputs_shape_[0].size(), 0);
input0_split[0] = 1;
Shapes splittable_inputs = {input0_split};
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
}
return sp_vector;
}
Status ResizeBilinearInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}
Status ResizeBilinearInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESIZEBILINEAR_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESIZEBILINEAR_INFO_H_
#include <string>
#include <memory>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
class ResizeBilinearInfo : public OperatorInfo {
public:
ResizeBilinearInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ResizeBilinearCost>()) {}
~ResizeBilinearInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
private:
std::vector<int64_t> size_; // four integers, NCHW
bool align_corners_;
};
class ResizeNearestNeighborInfo : public ResizeBilinearInfo {
public:
ResizeNearestNeighborInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ResizeBilinearInfo(name, inputs_shape, outputs_shape, attrs) {}
~ResizeNearestNeighborInfo() override = default;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESIZEBILINEAR_INFO_H_

View File

@ -172,7 +172,7 @@ bool IsSplittableOperator(const std::string &op_name) {
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL};
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR};
// clang-format on
auto iter = splittable_op.find(op_name);

View File

@ -0,0 +1,150 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''ResizeBilinear and ResizeNearestNeigbor ut'''
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
class Net(Cell):
'''
create the test Net
'''
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
strategy1=None, strategy2=None):
super(Net, self).__init__()
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride).shard(strategy1)
self.conv2d_weight = Parameter(conv2d_weight, "w1")
self.resize_bilinear = P.ResizeBilinear((16, 16)).shard(strategy2)
def construct(self, x):
out = self.conv2d(x, self.conv2d_weight)
out = self.resize_bilinear(out)
return out
class Net2(Cell):
'''
create the test Net
'''
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
strategy1=None, strategy2=None):
super(Net2, self).__init__()
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride).shard(strategy1)
self.conv2d_weight = Parameter(conv2d_weight, "w1")
self.resize_neighbor = P.ResizeNearestNeighbor((16, 16)).shard(strategy2)
def construct(self, x):
out = self.conv2d(x, self.conv2d_weight)
out = self.resize_neighbor(out)
return out
class Net3(Cell):
'''
create the test Net
'''
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
strategy1=None):
super(Net3, self).__init__()
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride).shard(strategy1)
self.conv2d_weight = Parameter(conv2d_weight, "w1")
self.resize_bilinear = P.ResizeBilinear((16, 16))
def construct(self, x):
out = self.conv2d(x, self.conv2d_weight)
out = self.resize_bilinear(out)
return out
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
def compile_net(net, inputs=_x):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_cell_graph_executor.compile(train_net, inputs)
context.reset_auto_parallel_context()
def test_bililear_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_bilinear_model_parallel1():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((4, 2, 1, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_bilinear_model_parallel2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 1, 1, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_bilinear_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net)
def test_bilinear_no_strategy():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net3(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net)
def test_neighbor_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_neighbor_model_parallel1():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((4, 2, 1, 1),)
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_neighbor_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net)