forked from mindspore-Ecosystem/mindspore
!16653 parallel dropout extend
From: @yao_yf Reviewed-by: @yangzhenzhang,@stsuteng Signed-off-by: @stsuteng
This commit is contained in:
commit
1381a58cfd
|
@ -34,23 +34,6 @@ Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { return Se
|
|||
|
||||
Status Activation::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
|
||||
|
||||
Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// dropout don't support repeated calculation
|
||||
auto input_strategy = strategy->GetInputDim().at(0);
|
||||
auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int64_t>());
|
||||
if (product_p != stage_device_size_) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ActivationInfo::GetAttrs() {
|
||||
if (attrs_.size() < ACTIVATION_ATTR_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << " : The size of attrs small than 1.";
|
||||
|
@ -116,15 +99,6 @@ Status Activation::GenerateStrategies(int64_t stage_id) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
bool DropoutInfo::IsRepeatedStrategy(const StrategyPtr &sp) {
|
||||
auto input_strategy = sp->GetInputDim().at(0);
|
||||
auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int64_t>());
|
||||
if (product_p != stage_device_size_) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Status DropoutInfo::GenerateStrategies(int64_t stage_id) {
|
||||
Shape input0_split(inputs_shape_[0].size(), 1);
|
||||
Shapes splittable_inputs = {input0_split};
|
||||
|
@ -136,9 +110,6 @@ Status DropoutInfo::GenerateStrategies(int64_t stage_id) {
|
|||
}
|
||||
size_t success = 0;
|
||||
for (auto &sp : sp_vector) {
|
||||
if (IsRepeatedStrategy(sp)) {
|
||||
continue;
|
||||
}
|
||||
if (SetCostUnderStrategy(sp) == SUCCESS) {
|
||||
success++;
|
||||
MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
|
||||
|
@ -333,6 +304,30 @@ Status ActivationBase::InferTensorInfo() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status DropoutInfo::GetAttrs() {
|
||||
auto iter0 = attrs_.find(SEED0);
|
||||
if (iter0 != attrs_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(iter0->second);
|
||||
if (iter0->second->isa<Int64Imm>()) {
|
||||
seed0_ = iter0->second->cast<Int64ImmPtr>()->value();
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << " : The value of seed0 is not int64_t.";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
auto iter1 = attrs_.find(SEED1);
|
||||
if (iter1 != attrs_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(iter1->second);
|
||||
if (iter1->second->isa<Int64Imm>()) {
|
||||
seed1_ = iter1->second->cast<Int64ImmPtr>()->value();
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << " : The value of seed1 is not int64_t.";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status DropoutInfo::InferTensorInfo() {
|
||||
// infer tensor shape
|
||||
Shape input_shape = inputs_shape_.at(0);
|
||||
|
@ -359,6 +354,36 @@ Status DropoutInfo::InferTensorInfo() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status DropoutInfo::InferReplaceOps(const StrategyPtr &) {
|
||||
if ((seed0_ != 0) || (seed1_ != 0) || (repeated_calc_num_ == 1)) {
|
||||
return SUCCESS;
|
||||
}
|
||||
int64_t seed = get_seed();
|
||||
ValuePtr new_seed0 = MakeValue(seed);
|
||||
ValuePtr new_seed1 = MakeValue(seed);
|
||||
Attr attr_seed0 = std::make_pair(SEED0, new_seed0);
|
||||
Attr attr_seed1 = std::make_pair(SEED1, new_seed1);
|
||||
Attr attr_keep_probs = std::make_pair(KEEP_PROB, attrs_[KEEP_PROB]);
|
||||
OperatorAttrs attrs = {attr_keep_probs, attr_seed0, attr_seed1};
|
||||
OperatorParams params;
|
||||
OperatorArgs args = std::make_pair(attrs, params);
|
||||
replace_op_ = {std::make_pair(DROPOUT, args)};
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status DropoutInfo::Init(const StrategyPtr &strategy) {
|
||||
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Init failed";
|
||||
return FAILED;
|
||||
}
|
||||
if (InferReplaceOps(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Infer replace Ops failed";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << name_ << " : Init success";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ActivationBase::Init(const StrategyPtr &strategy) {
|
||||
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Init failed.";
|
||||
|
|
|
@ -266,14 +266,20 @@ class DropoutInfo : public ActivationOther {
|
|||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {}
|
||||
~DropoutInfo() override = default;
|
||||
Status GenerateStrategies(int64_t stage_id) override;
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
|
||||
protected:
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status GetAttrs() override { return SUCCESS; }
|
||||
Status GetAttrs() override;
|
||||
Status InferTensorInfo() override;
|
||||
Status InferReplaceOps(const StrategyPtr &strategy);
|
||||
|
||||
private:
|
||||
bool IsRepeatedStrategy(const StrategyPtr &sp);
|
||||
int64_t seed0_ = 0;
|
||||
int64_t seed1_ = 0;
|
||||
int64_t get_seed() {
|
||||
static int64_t SEED_NUM;
|
||||
return ++SEED_NUM;
|
||||
}
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,14 +17,14 @@
|
|||
#include "frontend/parallel/parallel_stub/executor_manager_stub.h"
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device_name, int device_id) {
|
||||
std::string device_key = device_name + "_" + std::to_string(device_id);
|
||||
auto iter = executors_.find(device_key);
|
||||
std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &dev_name, int dev_id) {
|
||||
std::string dev_key = dev_name + "_" + std::to_string(dev_id);
|
||||
auto iter = executors_.find(dev_key);
|
||||
if (iter != executors_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
auto executor = std::make_shared<Executor>(device_name, device_id);
|
||||
executors_[device_key] = executor;
|
||||
auto executor = std::make_shared<Executor>(dev_name, dev_id);
|
||||
executors_[dev_key] = executor;
|
||||
return executor;
|
||||
}
|
||||
} // namespace parallel
|
||||
|
|
|
@ -134,8 +134,8 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
|
|||
for (auto &dims : node_stra.second->GetInputDim()) {
|
||||
straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
|
||||
MS_EXCEPTION_IF_NULL(parallel_strategy);
|
||||
for (auto dim : dims) {
|
||||
parallel_strategy->add_dim(LongToUlong(dim));
|
||||
for (auto stra_dim : dims) {
|
||||
parallel_strategy->add_dim(LongToUlong(stra_dim));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -147,13 +147,13 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
|
|||
straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
|
||||
straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
|
||||
MS_EXCEPTION_IF_NULL(dev_matrix);
|
||||
for (auto dim : tensor_layout->device_arrangement().array()) {
|
||||
dev_matrix->add_dim(LongToUlong(dim));
|
||||
for (auto dev_dim : tensor_layout->device_arrangement().array()) {
|
||||
dev_matrix->add_dim(LongToUlong(dev_dim));
|
||||
}
|
||||
straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
|
||||
MS_EXCEPTION_IF_NULL(tensor_map);
|
||||
for (auto dim : tensor_layout->tensor_map().array()) {
|
||||
tensor_map->add_dim(dim);
|
||||
for (auto map_dim : tensor_layout->tensor_map().array()) {
|
||||
tensor_map->add_dim(map_dim);
|
||||
}
|
||||
straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
|
||||
straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset();
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self, mul_weight, strategy1=None, strategy2=None):
|
||||
super().__init__()
|
||||
self.mul = P.Mul().shard(strategy1)
|
||||
self.dropout1 = P.Dropout(keep_prob=0.5).shard(strategy2)
|
||||
self.relu = P.ReLU().shard(strategy2)
|
||||
self.dropout2 = P.Dropout(keep_prob=0.5).shard(strategy2)
|
||||
self.relu2 = P.ReLU().shard(strategy2)
|
||||
self.mul_weight = Parameter(mul_weight, "w1")
|
||||
|
||||
def construct(self, x, b):
|
||||
out = self.mul(x, self.mul_weight)
|
||||
out, _ = self.dropout1(out)
|
||||
out = self.relu(out)
|
||||
out, _ = self.dropout2(out)
|
||||
out = self.relu2(out)
|
||||
return out
|
||||
|
||||
|
||||
_x = Tensor(np.ones([128, 64]), dtype=ms.float32)
|
||||
_w1 = Tensor(np.ones([128, 64]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([128, 64]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile_net(net):
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
train_net.set_auto_parallel()
|
||||
train_net.set_train()
|
||||
_executor.compile(train_net, _x, _b)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_dropout_data_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((16, 1), (16, 1))
|
||||
strategy2 = ((16, 1),)
|
||||
net = Net(_w1, strategy1, strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_dropout_model_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((1, 16), (1, 16))
|
||||
strategy2 = ((1, 16),)
|
||||
net = Net(_w1, strategy1, strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_dropout_mixed_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((4, 4), (4, 4))
|
||||
strategy2 = ((4, 4),)
|
||||
net = Net(_w1, strategy1, strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_dropout_auto_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
|
||||
net = Net(_w1)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_dropout_repeat_calc():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((4, 4), (4, 4))
|
||||
strategy2 = ((2, 4),)
|
||||
net = Net(_w1, strategy1, strategy2)
|
||||
compile_net(net)
|
Loading…
Reference in New Issue