!16653 parallel dropout extend

From: @yao_yf
Reviewed-by: @yangzhenzhang,@stsuteng
Signed-off-by: @stsuteng
This commit is contained in:
mindspore-ci-bot 2021-05-21 15:54:01 +08:00 committed by Gitee
commit 1381a58cfd
5 changed files with 166 additions and 43 deletions

View File

@ -34,23 +34,6 @@ Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { return Se
Status Activation::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); } Status Activation::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
return FAILED;
}
// dropout don't support repeated calculation
auto input_strategy = strategy->GetInputDim().at(0);
auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int64_t>());
if (product_p != stage_device_size_) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
return FAILED;
}
return SUCCESS;
}
Status ActivationInfo::GetAttrs() { Status ActivationInfo::GetAttrs() {
if (attrs_.size() < ACTIVATION_ATTR_SIZE) { if (attrs_.size() < ACTIVATION_ATTR_SIZE) {
MS_LOG(ERROR) << name_ << " : The size of attrs small than 1."; MS_LOG(ERROR) << name_ << " : The size of attrs small than 1.";
@ -116,15 +99,6 @@ Status Activation::GenerateStrategies(int64_t stage_id) {
return SUCCESS; return SUCCESS;
} }
bool DropoutInfo::IsRepeatedStrategy(const StrategyPtr &sp) {
auto input_strategy = sp->GetInputDim().at(0);
auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int64_t>());
if (product_p != stage_device_size_) {
return true;
}
return false;
}
Status DropoutInfo::GenerateStrategies(int64_t stage_id) { Status DropoutInfo::GenerateStrategies(int64_t stage_id) {
Shape input0_split(inputs_shape_[0].size(), 1); Shape input0_split(inputs_shape_[0].size(), 1);
Shapes splittable_inputs = {input0_split}; Shapes splittable_inputs = {input0_split};
@ -136,9 +110,6 @@ Status DropoutInfo::GenerateStrategies(int64_t stage_id) {
} }
size_t success = 0; size_t success = 0;
for (auto &sp : sp_vector) { for (auto &sp : sp_vector) {
if (IsRepeatedStrategy(sp)) {
continue;
}
if (SetCostUnderStrategy(sp) == SUCCESS) { if (SetCostUnderStrategy(sp) == SUCCESS) {
success++; success++;
MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
@ -333,6 +304,30 @@ Status ActivationBase::InferTensorInfo() {
return SUCCESS; return SUCCESS;
} }
Status DropoutInfo::GetAttrs() {
auto iter0 = attrs_.find(SEED0);
if (iter0 != attrs_.end()) {
MS_EXCEPTION_IF_NULL(iter0->second);
if (iter0->second->isa<Int64Imm>()) {
seed0_ = iter0->second->cast<Int64ImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << " : The value of seed0 is not int64_t.";
return FAILED;
}
}
auto iter1 = attrs_.find(SEED1);
if (iter1 != attrs_.end()) {
MS_EXCEPTION_IF_NULL(iter1->second);
if (iter1->second->isa<Int64Imm>()) {
seed1_ = iter1->second->cast<Int64ImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << " : The value of seed1 is not int64_t.";
return FAILED;
}
}
return SUCCESS;
}
Status DropoutInfo::InferTensorInfo() { Status DropoutInfo::InferTensorInfo() {
// infer tensor shape // infer tensor shape
Shape input_shape = inputs_shape_.at(0); Shape input_shape = inputs_shape_.at(0);
@ -359,6 +354,36 @@ Status DropoutInfo::InferTensorInfo() {
return SUCCESS; return SUCCESS;
} }
Status DropoutInfo::InferReplaceOps(const StrategyPtr &) {
if ((seed0_ != 0) || (seed1_ != 0) || (repeated_calc_num_ == 1)) {
return SUCCESS;
}
int64_t seed = get_seed();
ValuePtr new_seed0 = MakeValue(seed);
ValuePtr new_seed1 = MakeValue(seed);
Attr attr_seed0 = std::make_pair(SEED0, new_seed0);
Attr attr_seed1 = std::make_pair(SEED1, new_seed1);
Attr attr_keep_probs = std::make_pair(KEEP_PROB, attrs_[KEEP_PROB]);
OperatorAttrs attrs = {attr_keep_probs, attr_seed0, attr_seed1};
OperatorParams params;
OperatorArgs args = std::make_pair(attrs, params);
replace_op_ = {std::make_pair(DROPOUT, args)};
return SUCCESS;
}
Status DropoutInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed";
return FAILED;
}
if (InferReplaceOps(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Infer replace Ops failed";
return FAILED;
}
MS_LOG(INFO) << name_ << " : Init success";
return SUCCESS;
}
Status ActivationBase::Init(const StrategyPtr &strategy) { Status ActivationBase::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed."; MS_LOG(ERROR) << name_ << " : Init failed.";

View File

@ -266,14 +266,20 @@ class DropoutInfo : public ActivationOther {
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {} : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {}
~DropoutInfo() override = default; ~DropoutInfo() override = default;
Status GenerateStrategies(int64_t stage_id) override; Status GenerateStrategies(int64_t stage_id) override;
Status Init(const StrategyPtr &strategy) override;
protected: protected:
Status CheckStrategy(const StrategyPtr &strategy) override; Status GetAttrs() override;
Status GetAttrs() override { return SUCCESS; }
Status InferTensorInfo() override; Status InferTensorInfo() override;
Status InferReplaceOps(const StrategyPtr &strategy);
private: private:
bool IsRepeatedStrategy(const StrategyPtr &sp); int64_t seed0_ = 0;
int64_t seed1_ = 0;
int64_t get_seed() {
static int64_t SEED_NUM;
return ++SEED_NUM;
}
}; };
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore

View File

@ -17,14 +17,14 @@
#include "frontend/parallel/parallel_stub/executor_manager_stub.h" #include "frontend/parallel/parallel_stub/executor_manager_stub.h"
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device_name, int device_id) { std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &dev_name, int dev_id) {
std::string device_key = device_name + "_" + std::to_string(device_id); std::string dev_key = dev_name + "_" + std::to_string(dev_id);
auto iter = executors_.find(device_key); auto iter = executors_.find(dev_key);
if (iter != executors_.end()) { if (iter != executors_.end()) {
return iter->second; return iter->second;
} }
auto executor = std::make_shared<Executor>(device_name, device_id); auto executor = std::make_shared<Executor>(dev_name, dev_id);
executors_[device_key] = executor; executors_[dev_key] = executor;
return executor; return executor;
} }
} // namespace parallel } // namespace parallel

View File

@ -134,8 +134,8 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
for (auto &dims : node_stra.second->GetInputDim()) { for (auto &dims : node_stra.second->GetInputDim()) {
straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy(); straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
MS_EXCEPTION_IF_NULL(parallel_strategy); MS_EXCEPTION_IF_NULL(parallel_strategy);
for (auto dim : dims) { for (auto stra_dim : dims) {
parallel_strategy->add_dim(LongToUlong(dim)); parallel_strategy->add_dim(LongToUlong(stra_dim));
} }
} }
} }
@ -147,13 +147,13 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts(); straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix(); straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
MS_EXCEPTION_IF_NULL(dev_matrix); MS_EXCEPTION_IF_NULL(dev_matrix);
for (auto dim : tensor_layout->device_arrangement().array()) { for (auto dev_dim : tensor_layout->device_arrangement().array()) {
dev_matrix->add_dim(LongToUlong(dim)); dev_matrix->add_dim(LongToUlong(dev_dim));
} }
straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map(); straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
MS_EXCEPTION_IF_NULL(tensor_map); MS_EXCEPTION_IF_NULL(tensor_map);
for (auto dim : tensor_layout->tensor_map().array()) { for (auto map_dim : tensor_layout->tensor_map().array()) {
tensor_map->add_dim(dim); tensor_map->add_dim(map_dim);
} }
straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape(); straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset(); straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset();

View File

@ -0,0 +1,92 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
class Net(Cell):
def __init__(self, mul_weight, strategy1=None, strategy2=None):
super().__init__()
self.mul = P.Mul().shard(strategy1)
self.dropout1 = P.Dropout(keep_prob=0.5).shard(strategy2)
self.relu = P.ReLU().shard(strategy2)
self.dropout2 = P.Dropout(keep_prob=0.5).shard(strategy2)
self.relu2 = P.ReLU().shard(strategy2)
self.mul_weight = Parameter(mul_weight, "w1")
def construct(self, x, b):
out = self.mul(x, self.mul_weight)
out, _ = self.dropout1(out)
out = self.relu(out)
out, _ = self.dropout2(out)
out = self.relu2(out)
return out
_x = Tensor(np.ones([128, 64]), dtype=ms.float32)
_w1 = Tensor(np.ones([128, 64]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64]), dtype=ms.float32)
def compile_net(net):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_executor.compile(train_net, _x, _b)
context.reset_auto_parallel_context()
def test_dropout_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((16, 1), (16, 1))
strategy2 = ((16, 1),)
net = Net(_w1, strategy1, strategy2)
compile_net(net)
def test_dropout_model_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((1, 16), (1, 16))
strategy2 = ((1, 16),)
net = Net(_w1, strategy1, strategy2)
compile_net(net)
def test_dropout_mixed_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((4, 4), (4, 4))
strategy2 = ((4, 4),)
net = Net(_w1, strategy1, strategy2)
compile_net(net)
def test_dropout_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
net = Net(_w1)
compile_net(net)
def test_dropout_repeat_calc():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((4, 4), (4, 4))
strategy2 = ((2, 4),)
net = Net(_w1, strategy1, strategy2)
compile_net(net)