forked from mindspore-Ecosystem/mindspore
!3966 [AutoParallel]Add dropout distributed op
Merge pull request !3966 from lichen/add_dropout_distributed_op
commit 617b98f104
@@ -135,6 +135,7 @@ REGISTER(GatherV2PInfo);
REGISTER(EmbeddingLookupInfo);
REGISTER(TileInfo);
REGISTER(StridedSliceInfo);
REGISTER(DropoutInfo);
}  // namespace parallel
}  // namespace mindspore
@@ -20,6 +20,8 @@
#include <memory>
#include <vector>
#include <utility>
#include <functional>
#include <numeric>

#include "ir/value.h"
#include "frontend/parallel/auto_parallel/costmodel.h"
@@ -54,6 +56,29 @@ Status Activation::CheckStrategy(const StrategyPtr &strategy) {
  return SUCCESS;
}

Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
    } else {
      MS_LOG(ERROR) << name_ << " : Invalid strategy.";
    }
    return FAILED;
  }

  // dropout does not support repeated calculation
  CheckGlobalDeviceManager();
  auto input_strategy = strategy->GetInputDim().at(0);
  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
  auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int>());
  if (IntToSize(product_p) != dev_num) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy. Repeated calculation is not supported.";
    return FAILED;
  }

  return SUCCESS;
}

Status ActivationInfo::GetAttrs() {
  if (attrs_.size() < ACTIVATION_ATTR_SIZE) {
    MS_LOG(ERROR) << name_ << " : The size of attrs is smaller than 1.";
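The check above enforces that Dropout's strategy uses every device exactly once (no repeated calculation), presumably because replicated slices would each draw an independent random mask. A minimal Python sketch of the same rule, using a hypothetical helper name rather than any MindSpore API:

```python
from functools import reduce
from operator import mul

def check_dropout_strategy(input_strategy, dev_num):
    """Hypothetical restatement of DropoutInfo::CheckStrategy.

    input_strategy: shard counts per input dimension, e.g. (8, 1).
    dev_num: number of devices in the current stage.
    The product of the shard counts must consume every device exactly once;
    otherwise the strategy implies repeated calculation, which Dropout rejects.
    """
    product = reduce(mul, input_strategy, 1)
    if product != dev_num:
        raise ValueError(f"Invalid strategy {input_strategy}: product {product} != "
                         f"device num {dev_num}; repeated calculation is not supported.")
    return True

# (8, 1) on 8 devices passes; (2, 2) on 8 devices would be rejected.
check_dropout_strategy((8, 1), 8)
```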
@@ -120,6 +145,27 @@ Status Activation::GenerateStrategies(int32_t stage_id) {
  return SUCCESS;
}

Status DropoutInfo::GenerateStrategies(int32_t stage_id) {
  is_auto_parallel_ = true;
  Shape input0_split(inputs_shape_[0].size(), 1);
  Shapes splittable_inputs = {input0_split};

  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed.";
    return FAILED;
  }
  size_t success = 0;
  for (auto &sp : sp_vector) {
    if (SetCostUnderStrategy(sp) == SUCCESS) {
      success++;
      MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
      PrintStrategy(sp);
    }
  }
  return SUCCESS;
}

Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
    if (is_auto_parallel_) {
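GenerateStrategiesForIndependentInputs is internal to the framework; the sketch below only illustrates, for a single 2-D input and a power-of-two device count (both assumptions), the kind of candidate shard strategies that marking every dimension splittable (the all-ones input0_split above) allows:

```python
def candidate_strategies_2d(dev_num):
    """Enumerate ((d0, d1),) shard pairs whose product equals dev_num.

    Mirrors the effect of declaring both dimensions of the single Dropout
    input splittable; the real generator is more general than this sketch.
    """
    strategies = []
    d0 = 1
    while d0 <= dev_num:
        if dev_num % d0 == 0:
            strategies.append(((d0, dev_num // d0),))
        d0 *= 2
    return strategies

print(candidate_strategies_2d(8))
# [((1, 8),), ((2, 4),), ((4, 2),), ((8, 1),)]
# Every candidate uses all 8 devices, so each also passes the
# no-repeated-calculation check in DropoutInfo::CheckStrategy.
```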
@@ -334,6 +380,32 @@ Status ActivationBase::InferTensorInfo() {
  return SUCCESS;
}

Status DropoutInfo::InferTensorInfo() {
  // infer tensor shape
  Shape input_shape = inputs_shape_.at(0);

  // infer slice shape
  Shapes inputs_slice_shape, outputs_slice_shape;
  Strategys inputs_strategy = strategy_->GetInputDim();
  // dropout has two outputs
  Strategys outputs_strategy = {inputs_strategy.at(0), inputs_strategy.at(0)};
  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
    return FAILED;
  }
  Shape input_slice_shape = inputs_slice_shape.at(0);
  TensorLayout input_tensor_layout;
  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
    return FAILED;
  }
  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
  inputs_tensor_info_.push_back(input_tensor_info);
  // both outputs of dropout have the same tensor_info as the input
  outputs_tensor_info_.push_back(input_tensor_info);
  outputs_tensor_info_.push_back(input_tensor_info);

  return SUCCESS;
}

Status ActivationBase::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Init failed.";
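The slice-shape bookkeeping above reduces, roughly, to dividing each dimension of the full shape by its shard count and reusing the resulting tensor info for both Dropout outputs. A simplified sketch of that arithmetic (the real code goes through TensorLayout and the device matrix):

```python
def infer_slice_shape(full_shape, strategy):
    """Divide each dimension by its shard count; shapes are assumed divisible."""
    assert len(full_shape) == len(strategy)
    return [dim // cut for dim, cut in zip(full_shape, strategy)]

input_shape = [64, 32]
strategy = (8, 1)
input_slice = infer_slice_shape(input_shape, strategy)
# Dropout produces (output, mask); both output slices reuse the input slice.
output_slices = [input_slice, input_slice]
print(input_slice, output_slices)  # [8, 32] [[8, 32], [8, 32]]
```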
@@ -219,6 +219,20 @@ class SigmoidInfo : public ActivationOther {
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  ~SigmoidInfo() override = default;
};

class DropoutInfo : public ActivationOther {
 public:
  DropoutInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  ~DropoutInfo() override = default;
  Status GenerateStrategies(int32_t stage_id) override;

 protected:
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status GetAttrs() override { return SUCCESS; }
  Status InferTensorInfo() override;
};
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
@@ -238,6 +238,7 @@ constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative";
constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
constexpr char ADD[] = "Add";
constexpr char DROPOUT[] = "Dropout";
constexpr char KStridedSlice[] = "StridedSlice";

// Parallel doesn't care
@@ -256,7 +256,7 @@ bool IsSplittableOperator(const std::string &op_name) {
    REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
    MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
    LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
    STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE,
    STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
    SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
  // clang-format on
@@ -0,0 +1,99 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        predict = self.network(x, y)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y):
        return C.grad_all(self.network)(x, y)


class Net(nn.Cell):
    def __init__(self, strategy1=None, strategy2=None):
        super().__init__()
        self.dropout = P.Dropout(keep_prob=0.6).set_strategy(strategy1)
        self.matmul = P.MatMul().set_strategy(strategy2)

    def construct(self, x, y):
        out = self.matmul(x, y)
        out, _ = self.dropout(out)
        return out


def test_dropout_semi_auto():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net()))
    net.set_auto_parallel()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    _executor.compile(net, x, y)


def test_dropout_semi_auto2():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    strategy1 = ((8, 1),)
    strategy2 = ((4, 2), (2, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
    net.set_auto_parallel()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    _executor.compile(net, x, y)


def test_dropout_semi_auto3():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 4),)
    strategy2 = ((4, 2), (2, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
    net.set_auto_parallel()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    _executor.compile(net, x, y)


def test_dropout_auto():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
    net = GradWrap(NetWithLoss(Net()))
    net.set_auto_parallel()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    _executor.compile(net, x, y)
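For contrast with the passing cases above, a strategy whose shard product does not cover all eight devices (repeated calculation) should be rejected by the new DropoutInfo::CheckStrategy. The sketch below is not part of the commit; it relies on the definitions in the test file above, and the concrete exception type raised at compile time is an assumption:

```python
def test_dropout_repeated_calc_rejected():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 2),)  # product is 4 on 8 devices -> repeated calculation
    strategy2 = ((4, 2), (2, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
    net.set_auto_parallel()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    try:
        _executor.compile(net, x, y)
        raised = False
    except Exception:  # the exact error type is not specified by this commit
        raised = True
    assert raised
```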