forked from mindspore-Ecosystem/mindspore
!23569 Produce parallel operators for ResizeBilinear and ResizeNearestNeighbor
Merge pull request !23569 from Bert0108/resizebilinear_parallel_ops
This commit is contained in:
commit
e7cb505e68
|
@ -187,6 +187,7 @@ using RangeCost = CastCost;
|
|||
using SplitCost = CastCost;
|
||||
using ScatterUpdateCost = CastCost;
|
||||
using UniformRealCost = CastCost;
|
||||
using ResizeBilinearCost = CastCost;
|
||||
|
||||
class SqrtCost : public CastCost {
|
||||
public:
|
||||
|
|
|
@ -207,6 +207,8 @@ REGISTER(ReduceAnyInfo);
|
|||
REGISTER(MatmulDDSInfo);
|
||||
REGISTER(DSDMatmulInfo);
|
||||
REGISTER(UniformRealInfo);
|
||||
REGISTER(ResizeBilinearInfo);
|
||||
REGISTER(ResizeNearestNeighborInfo);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -61,5 +61,6 @@
|
|||
#include "frontend/parallel/ops_info/matmul_dds_info.h"
|
||||
#include "frontend/parallel/ops_info/dsd_matmul_info.h"
|
||||
#include "frontend/parallel/ops_info/uniform_real_info.h"
|
||||
#include "frontend/parallel/ops_info/resizebilinear_info.h"
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
|
||||
|
|
|
@ -134,6 +134,7 @@ constexpr char NEW_AXIS_MASK[] = "new_axis_mask";
|
|||
constexpr char SHRINK_AXIS_MASK[] = "shrink_axis_mask";
|
||||
constexpr char BEGIN[] = "begin";
|
||||
constexpr char SIZE[] = "size";
|
||||
constexpr char ALIGN_CORNERS[] = "align_corners";
|
||||
constexpr char END[] = "end";
|
||||
constexpr char STRIDES[] = "strides";
|
||||
constexpr char GROUP[] = "group";
|
||||
|
@ -398,6 +399,8 @@ constexpr char GATHERND[] = "GatherNd";
|
|||
constexpr char SCATTER_UPDATE[] = "ScatterUpdate";
|
||||
constexpr char GATHERD[] = "GatherD";
|
||||
constexpr char DSD_MATMUL[] = "DSDMatmul";
|
||||
constexpr char RESIZE_BILINEAR[] = "ResizeBilinear";
|
||||
constexpr char RESIZE_NEAREST_NEIGHBOR[] = "ResizeNearestNeighbor";
|
||||
|
||||
// pipeline
|
||||
constexpr char MICRO[] = "micro";
|
||||
|
|
|
@ -0,0 +1,139 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/parallel/ops_info/resizebilinear_info.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/parallel/device_matrix.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status ResizeBilinearInfo::GetAttrs() {
|
||||
size_ = GetTupleIntAttr(SIZE);
|
||||
if (size_.size() != 2) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of input size must be 2, but got " << size_.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (size_[0] != size_[1]) {
|
||||
MS_LOG(ERROR) << name_ << ": The second two elements of size must be the same, but got (" << size_[0] << ", "
|
||||
<< size_[1] << ")";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
align_corners_ = GetBoolAttr(ALIGN_CORNERS);
|
||||
|
||||
MS_LOG(INFO) << name_ << ": The input size is " << size_ << ", align_corners is " << align_corners_;
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ResizeBilinearInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
MS_EXCEPTION_IF_NULL(strategy);
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
std::vector<Dimensions> stra = strategy->GetInputDim();
|
||||
if (stra.size() != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of strategy must be 1, but got " << stra.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Dimensions input_strategy = stra[0];
|
||||
if (input_strategy.size() != 4) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of input strategy must be 4, but got" << input_strategy.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": Do not support split from H or W";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ResizeBilinearInfo::InferDevMatrixShape() {
|
||||
// the strategy is (n, c, h, w)
|
||||
// the dev matrix is (n, c, h, w)
|
||||
MS_EXCEPTION_IF_NULL(strategy_);
|
||||
std::vector<Dimensions> stra = strategy_->GetInputDim();
|
||||
if (stra.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The strategy is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
dev_matrix_shape_ = stra[0];
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ResizeBilinearInfo::InferTensorMap() {
|
||||
// input_strategy: (n, c, h, w)
|
||||
// output_strategy: (n, c, h, w)
|
||||
// dev_matrix: (n, c, h, w)
|
||||
TensorMap input_tensor_map = {3, 2, 1, 0};
|
||||
TensorMap output_tensor_map = {3, 2, 1, 0};
|
||||
|
||||
(void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map));
|
||||
(void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map));
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ResizeBilinearInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
return SetCostUnderStrategyBase(strategy);
|
||||
}
|
||||
|
||||
std::vector<StrategyPtr> ResizeBilinearInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
Shape input0_split(inputs_shape_[0].size(), 0);
|
||||
input0_split[0] = 1;
|
||||
Shapes splittable_inputs = {input0_split};
|
||||
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
|
||||
}
|
||||
return sp_vector;
|
||||
}
|
||||
|
||||
Status ResizeBilinearInfo::Init(const StrategyPtr &strategy) {
|
||||
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init failed.";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << name_ << ": Init success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ResizeBilinearInfo::InitForCostModel(const StrategyPtr &strategy) {
|
||||
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << name_ << ": Init for cost model success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESIZEBILINEAR_INFO_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESIZEBILINEAR_INFO_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/value.h"
|
||||
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
|
||||
#include "frontend/parallel/ops_info/operator_info.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
class ResizeBilinearInfo : public OperatorInfo {
|
||||
public:
|
||||
ResizeBilinearInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ResizeBilinearCost>()) {}
|
||||
~ResizeBilinearInfo() override = default;
|
||||
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &) override;
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override;
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
|
||||
|
||||
private:
|
||||
std::vector<int64_t> size_; // four integers, NCHW
|
||||
bool align_corners_;
|
||||
};
|
||||
|
||||
class ResizeNearestNeighborInfo : public ResizeBilinearInfo {
|
||||
public:
|
||||
ResizeNearestNeighborInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ResizeBilinearInfo(name, inputs_shape, outputs_shape, attrs) {}
|
||||
~ResizeNearestNeighborInfo() override = default;
|
||||
};
|
||||
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESIZEBILINEAR_INFO_H_
|
|
@ -172,7 +172,7 @@ bool IsSplittableOperator(const std::string &op_name) {
|
|||
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
|
||||
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
|
||||
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
|
||||
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL};
|
||||
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR};
|
||||
// clang-format on
|
||||
|
||||
auto iter = splittable_op.find(op_name);
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
'''ResizeBilinear and ResizeNearestNeigbor ut'''
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(Cell):
|
||||
'''
|
||||
create the test Net
|
||||
'''
|
||||
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
|
||||
strategy1=None, strategy2=None):
|
||||
super(Net, self).__init__()
|
||||
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
|
||||
pad_mode=pad_mode, stride=stride).shard(strategy1)
|
||||
self.conv2d_weight = Parameter(conv2d_weight, "w1")
|
||||
self.resize_bilinear = P.ResizeBilinear((16, 16)).shard(strategy2)
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv2d(x, self.conv2d_weight)
|
||||
out = self.resize_bilinear(out)
|
||||
return out
|
||||
|
||||
|
||||
class Net2(Cell):
|
||||
'''
|
||||
create the test Net
|
||||
'''
|
||||
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
|
||||
strategy1=None, strategy2=None):
|
||||
super(Net2, self).__init__()
|
||||
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
|
||||
pad_mode=pad_mode, stride=stride).shard(strategy1)
|
||||
self.conv2d_weight = Parameter(conv2d_weight, "w1")
|
||||
self.resize_neighbor = P.ResizeNearestNeighbor((16, 16)).shard(strategy2)
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv2d(x, self.conv2d_weight)
|
||||
out = self.resize_neighbor(out)
|
||||
return out
|
||||
|
||||
class Net3(Cell):
|
||||
'''
|
||||
create the test Net
|
||||
'''
|
||||
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
|
||||
strategy1=None):
|
||||
super(Net3, self).__init__()
|
||||
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
|
||||
pad_mode=pad_mode, stride=stride).shard(strategy1)
|
||||
self.conv2d_weight = Parameter(conv2d_weight, "w1")
|
||||
self.resize_bilinear = P.ResizeBilinear((16, 16))
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv2d(x, self.conv2d_weight)
|
||||
out = self.resize_bilinear(out)
|
||||
return out
|
||||
|
||||
|
||||
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
|
||||
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile_net(net, inputs=_x):
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
train_net.set_auto_parallel()
|
||||
train_net.set_train()
|
||||
_cell_graph_executor.compile(train_net, inputs)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_bililear_data_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_bilinear_model_parallel1():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((4, 2, 1, 1),)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_bilinear_model_parallel2():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((2, 1, 1, 1),)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_bilinear_auto_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_bilinear_no_strategy():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net3(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_neighbor_data_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_neighbor_model_parallel1():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((4, 2, 1, 1),)
|
||||
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_neighbor_auto_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
||||
compile_net(net)
|
Loading…
Reference in New Issue