!18546 add parallel operator for maxpool and avgpool

Merge pull request !18546 from yangzhenzhang/add-parallel-op-for-maxpool
This commit is contained in:
i-robot 2021-06-21 14:23:21 +08:00 committed by Gitee
commit 504ffc6292
8 changed files with 407 additions and 1 deletions

View File

@ -254,6 +254,7 @@ class GeLUCost : public SqrtCost {
using BesselI0eCost = GeLUCost;
using BesselI1eCost = GeLUCost;
using L2NormalizeCost = GeLUCost;
using MaxPoolCost = GeLUCost;
class SoftmaxCost : public OperatorCost {
public:

View File

@ -198,6 +198,8 @@ REGISTER(ScatterUpdateInfo);
REGISTER(VirtualOutputInfo);
REGISTER(Conv2DInfo);
REGISTER(BatchNormInfo);
REGISTER(MaxPoolInfo);
REGISTER(AvgPoolInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,198 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/maxpool_info.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"
namespace mindspore {
namespace parallel {
Status MaxPoolInfo::GetAttrs() {
// kernel_size
kernel_size_ = GetTupleIntAttr(KERNEL_SIZE);
if (kernel_size_.size() != 4) {
MS_LOG(ERROR) << name_ << ": The size of kernel_size must be 4, but got " << kernel_size_.size();
return FAILED;
}
if (kernel_size_[0] != 1 || kernel_size_[1] != 1) {
MS_LOG(ERROR) << name_ << ": The first two elements of kernel_size must be 1, but got (" << kernel_size_[0] << ", "
<< kernel_size_[1] << ")";
return FAILED;
}
// pad_mode
pad_mode_ = GetIntAttr(PAD_MODE);
if (pad_mode_ != POOL_PAD_MODE_VALID && pad_mode_ != POOL_PAD_MODE_SAME) {
MS_LOG(ERROR) << name_ << ": The pad_mode value is invalid: " << pad_mode_;
return FAILED;
}
// stride
stride_ = GetTupleIntAttr(STRIDES);
if (stride_.size() != 4) {
MS_LOG(ERROR) << name_ << ": The size of stride must be 4, but got " << stride_.size();
return FAILED;
}
if (stride_[0] != 1 || stride_[1] != 1) {
MS_LOG(ERROR) << name_ << ": The first two elements of stride must be 1, but got (" << stride_[0] << ", "
<< stride_[1] << ")";
return FAILED;
}
// format
format_ = GetStringAttr(FORMAT);
if (format_ != NCHW) {
MS_LOG(ERROR) << name_ << ": The format only support 'NCHW', but got " << format_;
return FAILED;
}
MS_LOG(INFO) << name_ << ": The kernel size is " << kernel_size_ << ", pad mode is " << pad_mode_ << ", pad list is "
<< ", stride is " << stride_ << ", format is " << format_;
return SUCCESS;
}
Status MaxPoolInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
if (h_strategy > 1) {
if (kernel_size_[2] > stride_[2]) {
MS_LOG(ERROR) << name_ << ": It does not support to split H dimension when kernel_size > stride";
return FAILED;
}
int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
if (h_slice_shape % stride_[2] != 0) {
MS_LOG(ERROR) << name_
<< ": It does not support to split H dimension when kernel_size <= stride but slice shape is not "
"divisible by stride ";
return FAILED;
}
}
if (w_strategy > 1) {
if (kernel_size_[3] > stride_[3]) {
MS_LOG(ERROR) << name_ << ": It does not support to split W dimension when kernel_size > stride";
return FAILED;
}
int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
if (w_slice_shape % stride_[3] != 0) {
MS_LOG(ERROR) << name_
<< ": It does not support to split W dimension when kernel_size <= stride but slice shape is not "
"divisible by stride ";
return FAILED;
}
}
return SUCCESS;
}
Status MaxPoolInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED;
}
std::vector<Dimensions> stra = strategy->GetInputDim();
if (stra.size() != 1) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be 1, but got " << stra.size();
return FAILED;
}
Dimensions input_strategy = stra[0];
if (input_strategy.size() != 4) {
MS_LOG(ERROR) << name_ << ": The size of input strategy must be 4, but got" << input_strategy.size();
return FAILED;
}
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
return FAILED;
}
}
return SUCCESS;
}
Status MaxPoolInfo::InferDevMatrixShape() {
// the strategy is (n, c, h, w)
// the dev matrix is (n, c, h, w)
MS_EXCEPTION_IF_NULL(strategy_);
std::vector<Dimensions> stra = strategy_->GetInputDim();
if (stra.empty()) {
MS_LOG(ERROR) << name_ << ": The strategy is empty";
return FAILED;
}
dev_matrix_shape_ = stra[0];
return SUCCESS;
}
Status MaxPoolInfo::InferTensorMap() {
// input_strategy: (n, c, h, w)
// output_strategy: (n, c, h, w)
// dev_matrix: (n, c, h, w)
TensorMap input_tensor_map = {3, 2, 1, 0};
TensorMap output_tensor_map = {3, 2, 1, 0};
(void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map));
(void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map));
return SUCCESS;
}
Status MaxPoolInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
std::vector<StrategyPtr> MaxPoolInfo::GenerateOpStrategies(int64_t stage_id) {
Shape input0_split(inputs_shape_[0].size(), 1);
Shapes splittable_inputs = {input0_split};
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
}
return sp_vector;
}
Status MaxPoolInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}
Status MaxPoolInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,72 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MAXPOOL_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MAXPOOL_INFO_H_
#include <string>
#include <memory>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
class MaxPoolInfo : public OperatorInfo {
public:
MaxPoolInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<MaxPoolCost>()) {}
~MaxPoolInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
private:
std::vector<int64_t> kernel_size_; // four integers
int64_t pad_mode_ = 0; // "same": 1; "valid": 2;
std::vector<int64_t> stride_; // four integers
std::string format_;
};
class AvgPoolInfo : public MaxPoolInfo {
public:
AvgPoolInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: MaxPoolInfo(name, inputs_shape, outputs_shape, attrs) {}
~AvgPoolInfo() override = default;
};
constexpr int64_t POOL_PAD_MODE_SAME = 1;
constexpr int64_t POOL_PAD_MODE_VALID = 2;
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MAXPOOL_INFO_H_

View File

@ -57,5 +57,6 @@
#include "frontend/parallel/ops_info/virtual_output_info.h"
#include "frontend/parallel/ops_info/conv2d_info.h"
#include "frontend/parallel/ops_info/batchnorm_info.h"
#include "frontend/parallel/ops_info/maxpool_info.h"
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_

View File

@ -236,6 +236,7 @@ constexpr char ACTIVATION[] = "Activation";
constexpr char PRELU[] = "PReLU";
constexpr char FLOORDIV[] = "FloorDiv";
constexpr char MAXPOOL[] = "MaxPool";
constexpr char AVGPOOL[] = "AvgPool";
constexpr char MAXPOOLV2[] = "MaxPoolV2";
constexpr char L2_NORMALIZE[] = "L2Normalize";
constexpr char TRANSPOSE[] = "Transpose";

View File

@ -158,7 +158,7 @@ bool IsSplittableOperator(const std::string &op_name) {
// clang-format off
static const std::set<std::string> splittable_op =
{MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, AVGPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, STACK,
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,

View File

@ -0,0 +1,131 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
class Net(Cell):
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, pool_kernel_size, pool_strides,
strategy1=None, strategy2=None):
super().__init__()
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride).shard(strategy1)
self.conv2d_weight = Parameter(conv2d_weight, "w1")
self.max_pool = P.MaxPool(kernel_size=pool_kernel_size, strides=pool_strides).shard(strategy2)
def construct(self, x, b):
out = self.conv2d(x, self.conv2d_weight)
out = self.max_pool(out)
return out
class Net2(Cell):
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, pool_kernel_size, pool_strides,
strategy1=None, strategy2=None):
super().__init__()
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride).shard(strategy1)
self.conv2d_weight = Parameter(conv2d_weight, "w1")
self.avg_pool = P.AvgPool(kernel_size=pool_kernel_size, strides=pool_strides).shard(strategy2)
def construct(self, x, b):
out = self.conv2d(x, self.conv2d_weight)
out = self.avg_pool(out)
return out
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
def compile_net(net):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_executor.compile(train_net, _x, _b)
context.reset_auto_parallel_context()
def test_maxpool_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_maxpool_model_parallel1():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 1, 2, 2),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_maxpool_model_parallel2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 1, 2, 2),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_maxpool_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4)
compile_net(net)
def test_avgpool_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_avgpool_model_parallel1():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 1, 2, 2),)
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_avgpool_model_parallel2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 1, 2, 2),)
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_avgpool_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4)
compile_net(net)