!40506 Add filtering of Shard parameters to the GradReducer

Merge pull request !40506 from liuluobin/master_fix
i-robot 2022-08-30 02:25:27 +00:00 committed by Gitee
commit 59dbcb2ef8
12 changed files with 20 additions and 100 deletions
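Summary of the change: the PyNative-only CNode flag (`is_parallel` / `set_parallel` and its `kIsParallel` slot), the `MsFunction::MarkMsFunctionNodes` pass, and the gradient-count adjustment in `GenerateBucketSizeList` are removed from the C++ side. The bookkeeping moves to Python: `Parameter` gains an `is_in_shard` attribute, `Shard` sets it on the trainable parameters of the wrapped Cell, `DistributedGradReducer` excludes those parameters from allreduce through `allreduce_filter`, and `TrainOneStepCell` also enables its grad reducer when `is_pynative_parallel()` is true. Illustrative sketches of the Python-side pieces follow the corresponding hunks below.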

View File

@@ -1269,7 +1269,7 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());
new_cnode->set_parallel(cnode->is_parallel());
new_cnode->set_attrs(cnode->attrs());
// record map relations between anf from ME and new anf node used in backend
graph->FrontBackendMapAdd(node, new_cnode);
}
@@ -2770,33 +2770,7 @@ std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const
if (grads_count == 0) {
MS_LOG(EXCEPTION) << "Bprop graph has no grad";
}
uint32_t remove_number = 0;
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
auto parallel_mode = parallel_context->parallel_mode();
if (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel) {
auto ret = graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
auto current_node = ret->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(current_node);
while (IsPrimitiveCNode(current_node->input(1), prim::kPrimMakeTuple)) {
current_node = current_node->input(1)->cast<CNodePtr>();
}
auto inputs = current_node->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
auto node = inputs[i];
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (cnode->is_parallel()) {
remove_number += 1;
}
}
}
MS_LOG(INFO) << "auto parallel remove_number " << remove_number;
return {grads_count - remove_number};
return {grads_count};
}
std::vector<uint32_t> bucket_size_list;

View File

@@ -1113,12 +1113,6 @@ void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_i
graph_output_info.graph_output_tensors.clear();
graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);
// Save grad node to Bucket
if (graph->has_flag(kFlagIsPynativeBpropGraph) && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) &&
!kernel->is_parallel() && pynative::GraphAdapter::IsAutoParallel()) {
graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
}
}
WaitTaskFinish();
// Clear bucket resources every step

View File

@@ -22,7 +22,6 @@
#include "ir/cell.h"
#include "include/common/debug/anf_ir_dump.h"
#include "pipeline/jit/parse/parse_dynamic.h"
#include "include/common/utils/parallel_context.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/debug/trace.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
@@ -585,10 +584,6 @@ void GradExecutor::GradNetInner(const py::object *ret, const prim::GradOperation
compile::SetMindRTEnable();
resource->SetBackendAsync([]() { return compile::CreateBackend(); });
MS_LOG(DEBUG) << "Start task emit action";
auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
if (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel) {
ms_function()->MarkMsFunctionNodes(resource);
}
TaskEmitAction(resource);
MS_LOG(DEBUG) << "Start execute action";
ExecuteAction(resource);

View File

@@ -104,38 +104,6 @@ void MsFunction::ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, co
top_cell->set_op_info_with_ms_func_forward_tensors(op_run_info->op_info, total_output_tensors);
}
void MsFunction::MarkMsFunctionNodes(const pipeline::ResourcePtr &resource) const {
auto func_graph = resource->func_graph();
std::vector<size_t> in_ms_function;
const auto &parameters = func_graph->parameters();
for (const auto &parameter : parameters) {
auto param = parameter->cast<ParameterPtr>();
if (!param->has_default()) {
continue;
}
auto iter = std::find(ms_function_params_.begin(), ms_function_params_.end(), param->name());
if (iter != ms_function_params_.end()) {
(void)in_ms_function.emplace_back(1);
} else {
(void)in_ms_function.emplace_back(0);
}
}
auto ret = func_graph->get_return();
auto ret_cnode = ret->cast<CNodePtr>();
auto grads = ret_cnode->input(1)->cast<CNodePtr>();
for (size_t i = 1; i < grads->inputs().size(); i++) {
if (in_ms_function[i - 1] != 0) {
auto node = grads->input(i);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
cnode->set_parallel(true);
}
}
}
void MsFunction::UpdateMsFunctionForwardTensors(const FrontendOpRunInfoPtr &op_run_info,
const ValuePtr &new_forward_value) const {
MS_EXCEPTION_IF_NULL(op_run_info);

View File

@@ -36,7 +36,6 @@ class MsFunction {
MsFunction() = default;
~MsFunction() = default;
void set_graph_phase(const std::string &graph_phase) { graph_phase_ = graph_phase; }
void MarkMsFunctionNodes(const pipeline::ResourcePtr &resource) const;
py::object GradMsFunction(const py::object &out, const py::args &args);
private:

View File

@@ -665,16 +665,6 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
/// \param debug_infos A node debug info of an anf node.
void AddFusedDebugInfoList(const std::vector<NodeDebugInfoPtr> &debug_infos);
/// \brief Check whether this node is in ms_function or not in PyNative Mode.
///
/// \return True if in ms_function, otherwise false.
bool is_parallel() const { return flags_[kIsParallel]; }
/// \brief Set is_parallel_ for CNode.
///
/// \param[in] parallel Boolean.
void set_parallel(bool parallel) { flags_[kIsParallel] = parallel; }
/// \brief Check whether contains a input or indirect input, which is Depend CNode with isolated side-effect node.
///
/// \return True if contains, otherwise false.
@@ -692,8 +682,7 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
static constexpr size_t kInForwardFlag = 1;
static constexpr size_t kEffectHandled = 2;
static constexpr size_t kIsLoad = 3;
static constexpr size_t kIsParallel = 4;
static constexpr size_t kNumFlags = 5;
static constexpr size_t kNumFlags = 4;
static constexpr auto kFuncGraphVarKey = "fg_var";
static constexpr auto kOutputValueKey = "out_value";

View File

@@ -115,7 +115,7 @@ KernelGraphPtr KernelGraphUtils::ConstructKernelGraphFromNodeList(const AnfNodeP
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());
new_cnode->set_parallel(cnode->is_parallel());
new_cnode->set_attrs(cnode->attrs());
if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
new_cnode->set_fullname_with_scope(cnode->input(kFirstDataInputIndex)->fullname_with_scope());
}

View File

@@ -217,6 +217,7 @@ class Parameter(Tensor_):
self._cast_type = None
self._unique = False
self.is_in_parallel = _is_in_parallel_mode()
self.is_in_shard = False
self._pipeline_stage_list = []
if isinstance(default_input, (Tensor_, Tensor)):
Tensor_.__init__(self, default_input.dtype, default_input.shape)
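For reference, a minimal sketch of the new default (assumes a MindSpore build that includes this change; the parameter name and shape are arbitrary):

```python
import numpy as np
from mindspore import Tensor, Parameter

w = Parameter(Tensor(np.ones((2, 2), np.float32)), name="w")
# A freshly created parameter is not considered sharded until a Shard
# wrapper marks it (see the Shard.__call__ hunk further down).
print(w.is_in_shard)  # False
```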

View File

@@ -25,6 +25,7 @@ from mindspore.context import ParallelMode
from mindspore._checkparam import Validator as validator
from mindspore import ops, nn
from mindspore.common import dtype as mstype
from mindspore.common.api import is_pynative_parallel
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops.primitive import constexpr
from mindspore.ops import composite as C
@@ -357,7 +358,8 @@ class TrainOneStepCell(Cell):
self.reducer_flag = False
self.grad_reducer = F.identity
self.parallel_mode = _get_parallel_mode()
self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL) or \
is_pynative_parallel()
if self.reducer_flag:
self.mean = _get_gradients_mean()
self.degree = _get_device_num()
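The widened condition, sketched with a plain helper (the helper name is made up; it only mirrors the boolean expression added above):

```python
from mindspore.context import ParallelMode

def needs_grad_reducer(parallel_mode, pynative_parallel):
    # Mirrors the new reducer_flag expression in TrainOneStepCell.__init__.
    return (parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
            or pynative_parallel)

# Previously False for this combination; with this change a grad reducer is
# also set up when is_pynative_parallel() reports PyNative semi-auto/auto parallel.
print(needs_grad_reducer(ParallelMode.SEMI_AUTO_PARALLEL, True))  # True
```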

View File

@@ -26,7 +26,6 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.api import ms_function
from mindspore.common.api import is_pynative_parallel
reduce_opt = C.MultitypeFuncGraph("reduce_opt")
@@ -126,8 +125,6 @@ def _tensors_allreduce_post(degree, mean, allreduce_filter, grad):
Args:
degree (int): The mean coefficient.
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
allgather (Primitive): The communication operator for sparse gradients.
allreduce (Primitive): The communication operator for gradients.
allreduce_filter (bool): When it is true, allreduce would apply.
grad (Tensor): The gradient tensor before operation.
@@ -137,7 +134,7 @@ def _tensors_allreduce_post(degree, mean, allreduce_filter, grad):
if allreduce_filter:
if mean:
grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
return grad
return grad
return grad
@@ -395,7 +392,7 @@ class DistributedGradReducer(Cell):
self.degree = degree
self.degree = Tensor(1.0 / self.degree, mstype.float32)
self.mean = mean
self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
self.allreduce_filter = tuple((x.layerwise_parallel is False) and (x.is_in_shard is False) for x in parameters)
is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
if is_parallel_optimizer and split_indices:
@@ -413,7 +410,6 @@
self.ps_parameters = tuple(ps_filter(x) for x in parameters)
self.enable_parameter_server = any(self.ps_parameters)
self.mode = context.get_context("mode")
self.is_pynative_parallel = is_pynative_parallel()
self.enable_tuple_broaden = True
@ms_function
@@ -431,9 +427,8 @@
"""
datatypes = self.map_(F.partial(_get_datatype), grads)
grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
if self.is_pynative_parallel:
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean), self.allreduce_filter, grads)
elif self.split_fusion:
if self.split_fusion:
if self.enable_parameter_server:
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
self.op_list, self.allreduce_filter, grads, self.ps_parameters)
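A self-contained toy version of the updated filter, with a plain Python stand-in for mindspore.Parameter (the class and parameter names below are illustrative only):

```python
class FakeParam:
    """Stand-in exposing only the two attributes the filter reads."""
    def __init__(self, name, layerwise_parallel=False, is_in_shard=False):
        self.name = name
        self.layerwise_parallel = layerwise_parallel
        self.is_in_shard = is_in_shard

params = [
    FakeParam("fc.weight"),                              # ordinary parameter
    FakeParam("layer4.conv2.weight", is_in_shard=True),  # marked by Shard
]

# Same expression as the updated DistributedGradReducer.__init__:
allreduce_filter = tuple((x.layerwise_parallel is False) and (x.is_in_shard is False)
                         for x in params)
print(allreduce_filter)  # (True, False): the sharded parameter's gradient is
                         # returned as-is instead of being allreduced.
```

Since the dedicated is_pynative_parallel branch in construct is gone (the elif above becomes a plain if), PyNative parallel now takes the same split_fusion/allreduce path, with sharded parameters excluded per parameter by this filter.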

View File

@@ -19,6 +19,7 @@
from functools import partial
from types import FunctionType
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore import log as logger
@@ -906,6 +907,10 @@ class Shard(Shard_):
return self.shard_fn
shard_ = Shard()
if isinstance(fn, nn.Cell):
for param in fn.trainable_params():
param.is_in_shard = True
def shard_fn(*args):
args = (fn,) + args
@ms_function(hash_args=fn)
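What the added loop does, in isolation (assumes MindSpore is installed; nn.Dense is just a convenient example Cell):

```python
from mindspore import nn

net = nn.Dense(4, 4)
# Mirrors the loop added to Shard.__call__: every trainable parameter of the
# wrapped Cell is tagged so DistributedGradReducer can skip it later.
for param in net.trainable_params():
    param.is_in_shard = True

print([p.is_in_shard for p in net.trainable_params()])  # [True, True]
```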

View File

@@ -202,10 +202,8 @@ class ResNet(nn.Cell):
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3])
self.layer4 = F.shard(self.layer4, in_strategy=((4, 2, 1, 1),), out_strategy=(None,),
parameter_plan={
'self.layer4.0.conv2.weight': (2, 2, 1, 1),
})
F.shard(self.layer4, in_strategy=((4, 2, 1, 1),), out_strategy=(None,),
parameter_plan={'self.layer4.0.conv2.weight': (2, 2, 1, 1)})
self.mean = P.ReduceMean(keep_dims=True)
self.end_point = nn.Dense(2048, num_classes, has_bias=True,
@@ -387,6 +385,6 @@ def test_train_feed(num_classes=65536):
model = Model(net, loss_fn=loss, optimizer=opt)
model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
loss_value = np.array(parallel_callback.loss_list)
expect_out = [11.374571, 11.230516, 10.755886]
expect_out = [11.374571, 11.028273, 10.5469265]
print(loss_value)
assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)