diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h index 10114bc5f05..1cfad0ae64f 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h @@ -187,6 +187,7 @@ using RangeCost = CastCost; using SplitCost = CastCost; using ScatterUpdateCost = CastCost; using UniformRealCost = CastCost; +using ResizeBilinearCost = CastCost; class SqrtCost : public CastCost { public: diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h index 89f292399bf..4480c0d7156 100644 --- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h @@ -207,6 +207,8 @@ REGISTER(ReduceAnyInfo); REGISTER(MatmulDDSInfo); REGISTER(DSDMatmulInfo); REGISTER(UniformRealInfo); +REGISTER(ResizeBilinearInfo); +REGISTER(ResizeNearestNeighborInfo); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h index c705c322f0c..04479aef86b 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h @@ -61,5 +61,6 @@ #include "frontend/parallel/ops_info/matmul_dds_info.h" #include "frontend/parallel/ops_info/dsd_matmul_info.h" #include "frontend/parallel/ops_info/uniform_real_info.h" +#include "frontend/parallel/ops_info/resizebilinear_info.h" #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index bd038ae0c2e..aedccfcfda6 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -134,6 +134,7 @@ constexpr char NEW_AXIS_MASK[] = "new_axis_mask"; constexpr char SHRINK_AXIS_MASK[] = "shrink_axis_mask"; constexpr char BEGIN[] = "begin"; constexpr char SIZE[] = "size"; +constexpr char ALIGN_CORNERS[] = "align_corners"; constexpr char END[] = "end"; constexpr char STRIDES[] = "strides"; constexpr char GROUP[] = "group"; @@ -398,6 +399,8 @@ constexpr char GATHERND[] = "GatherNd"; constexpr char SCATTER_UPDATE[] = "ScatterUpdate"; constexpr char GATHERD[] = "GatherD"; constexpr char DSD_MATMUL[] = "DSDMatmul"; +constexpr char RESIZE_BILINEAR[] = "ResizeBilinear"; +constexpr char RESIZE_NEAREST_NEIGHBOR[] = "ResizeNearestNeighbor"; // pipeline constexpr char MICRO[] = "micro"; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.cc new file mode 100644 index 00000000000..ccbef982086 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.cc @@ -0,0 +1,139 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/resizebilinear_info.h" + +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" +#include "pipeline/jit/resource.h" + +namespace mindspore { +namespace parallel { +Status ResizeBilinearInfo::GetAttrs() { + size_ = GetTupleIntAttr(SIZE); + if (size_.size() != 2) { + MS_LOG(ERROR) << name_ << ": The size of input size must be 2, but got " << size_.size(); + return FAILED; + } + + if (size_[0] != size_[1]) { + MS_LOG(ERROR) << name_ << ": The second two elements of size must be the same, but got (" << size_[0] << ", " + << size_[1] << ")"; + return FAILED; + } + + align_corners_ = GetBoolAttr(ALIGN_CORNERS); + + MS_LOG(INFO) << name_ << ": The input size is " << size_ << ", align_corners is " << align_corners_; + + return SUCCESS; +} + +Status ResizeBilinearInfo::CheckStrategy(const StrategyPtr &strategy) { + MS_EXCEPTION_IF_NULL(strategy); + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Invalid strategy"; + return FAILED; + } + + std::vector stra = strategy->GetInputDim(); + if (stra.size() != 1) { + MS_LOG(ERROR) << name_ << ": The size of strategy must be 1, but got " << stra.size(); + return FAILED; + } + + Dimensions input_strategy = stra[0]; + if (input_strategy.size() != 4) { + MS_LOG(ERROR) << name_ << ": The size of input strategy must be 4, but got" << input_strategy.size(); + return FAILED; + } + + if (input_strategy[2] != 1 || input_strategy[3] != 1) { + MS_LOG(ERROR) << name_ << ": Do not support split from H or W"; + return FAILED; + } + + return SUCCESS; +} + +Status ResizeBilinearInfo::InferDevMatrixShape() { + // the strategy is (n, c, h, w) + // the dev matrix is (n, c, h, w) + MS_EXCEPTION_IF_NULL(strategy_); + std::vector stra = strategy_->GetInputDim(); + if (stra.empty()) { + MS_LOG(ERROR) << name_ << ": The strategy is empty"; + return FAILED; + } + + dev_matrix_shape_ = stra[0]; + return SUCCESS; +} + +Status ResizeBilinearInfo::InferTensorMap() { + // input_strategy: (n, c, h, w) + // output_strategy: (n, c, h, w) + // dev_matrix: (n, c, h, w) + TensorMap input_tensor_map = {3, 2, 1, 0}; + TensorMap output_tensor_map = {3, 2, 1, 0}; + + (void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map)); + (void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map)); + return SUCCESS; +} + +Status ResizeBilinearInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + return SetCostUnderStrategyBase(strategy); +} + +std::vector ResizeBilinearInfo::GenerateOpStrategies(int64_t stage_id) { + Shape input0_split(inputs_shape_[0].size(), 0); + input0_split[0] = 1; + Shapes splittable_inputs = {input0_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed."; + } + return sp_vector; +} + +Status ResizeBilinearInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status ResizeBilinearInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.h new file mode 100644 index 00000000000..d4cef30a4c0 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.h @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESIZEBILINEAR_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESIZEBILINEAR_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class ResizeBilinearInfo : public OperatorInfo { + public: + ResizeBilinearInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + ~ResizeBilinearInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + std::vector GenerateOpStrategies(int64_t) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; + + protected: + Status GetAttrs() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferForwardCommunication() override { return SUCCESS; } + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy); + + private: + std::vector size_; // four integers, NCHW + bool align_corners_; +}; + +class ResizeNearestNeighborInfo : public ResizeBilinearInfo { + public: + ResizeNearestNeighborInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ResizeBilinearInfo(name, inputs_shape, outputs_shape, attrs) {} + ~ResizeNearestNeighborInfo() override = default; +}; + +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESIZEBILINEAR_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index a78f606f1f2..586fcf420f4 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -172,7 +172,7 @@ bool IsSplittableOperator(const std::string &op_name) { SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM, UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD, UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE, - MATMUL_DDS, DSD_MATMUL, UNIFORMREAL}; + MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR}; // clang-format on auto iter = splittable_op.find(op_name); diff --git a/tests/ut/python/parallel/test_resizebilinear.py b/tests/ut/python/parallel/test_resizebilinear.py new file mode 100644 index 00000000000..9315c7d80c9 --- /dev/null +++ b/tests/ut/python/parallel/test_resizebilinear.py @@ -0,0 +1,150 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +'''ResizeBilinear and ResizeNearestNeigbor ut''' +import numpy as np + +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.common.api import _cell_graph_executor +from mindspore.nn import Cell, TrainOneStepCell, Momentum +from mindspore.ops import operations as P + + +class Net(Cell): + ''' + create the test Net + ''' + def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, + strategy1=None, strategy2=None): + super(Net, self).__init__() + self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size, + pad_mode=pad_mode, stride=stride).shard(strategy1) + self.conv2d_weight = Parameter(conv2d_weight, "w1") + self.resize_bilinear = P.ResizeBilinear((16, 16)).shard(strategy2) + + def construct(self, x): + out = self.conv2d(x, self.conv2d_weight) + out = self.resize_bilinear(out) + return out + + +class Net2(Cell): + ''' + create the test Net + ''' + def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, + strategy1=None, strategy2=None): + super(Net2, self).__init__() + self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size, + pad_mode=pad_mode, stride=stride).shard(strategy1) + self.conv2d_weight = Parameter(conv2d_weight, "w1") + self.resize_neighbor = P.ResizeNearestNeighbor((16, 16)).shard(strategy2) + + def construct(self, x): + out = self.conv2d(x, self.conv2d_weight) + out = self.resize_neighbor(out) + return out + +class Net3(Cell): + ''' + create the test Net + ''' + def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, + strategy1=None): + super(Net3, self).__init__() + self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size, + pad_mode=pad_mode, stride=stride).shard(strategy1) + self.conv2d_weight = Parameter(conv2d_weight, "w1") + self.resize_bilinear = P.ResizeBilinear((16, 16)) + + def construct(self, x): + out = self.conv2d(x, self.conv2d_weight) + out = self.resize_bilinear(out) + return out + + +_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32) +_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32) + + +def compile_net(net, inputs=_x): + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + train_net.set_train() + _cell_graph_executor.compile(train_net, inputs) + context.reset_auto_parallel_context() + + +def test_bililear_data_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) + strategy2 = ((8, 1, 1, 1),) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, + strategy1=strategy1, strategy2=strategy2) + compile_net(net) + + +def test_bilinear_model_parallel1(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1)) + strategy2 = ((4, 2, 1, 1),) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, + strategy1=strategy1, strategy2=strategy2) + compile_net(net) + + +def test_bilinear_model_parallel2(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1)) + strategy2 = ((2, 1, 1, 1),) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, + strategy1=strategy1, strategy2=strategy2) + compile_net(net) + + +def test_bilinear_auto_parallel(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1) + compile_net(net) + + +def test_bilinear_no_strategy(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net3(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1) + compile_net(net) + + +def test_neighbor_data_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) + strategy2 = ((8, 1, 1, 1),) + net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, + strategy1=strategy1, strategy2=strategy2) + compile_net(net) + + +def test_neighbor_model_parallel1(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1)) + strategy2 = ((4, 2, 1, 1),) + net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, + strategy1=strategy1, strategy2=strategy2) + compile_net(net) + + +def test_neighbor_auto_parallel(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1) + compile_net(net)