forked from mindspore-Ecosystem/mindspore
!40506 Add filtering of Shard parameters to the GradReducer
Merge pull request !40506 from liuluobin/master_fix
This commit is contained in:
commit
59dbcb2ef8
|
@ -1269,7 +1269,7 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
|||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
new_cnode->set_abstract(cnode->abstract());
|
||||
new_cnode->set_scope(cnode->scope());
|
||||
new_cnode->set_parallel(cnode->is_parallel());
|
||||
new_cnode->set_attrs(cnode->attrs());
|
||||
// record map relations between anf from ME and new anf node used in backend
|
||||
graph->FrontBackendMapAdd(node, new_cnode);
|
||||
}
|
||||
|
@ -2770,33 +2770,7 @@ std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const
|
|||
if (grads_count == 0) {
|
||||
MS_LOG(EXCEPTION) << "Bprop graph has no grad";
|
||||
}
|
||||
uint32_t remove_number = 0;
|
||||
auto parallel_context = parallel::ParallelContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(parallel_context);
|
||||
auto parallel_mode = parallel_context->parallel_mode();
|
||||
if (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel) {
|
||||
auto ret = graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
auto current_node = ret->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(current_node);
|
||||
while (IsPrimitiveCNode(current_node->input(1), prim::kPrimMakeTuple)) {
|
||||
current_node = current_node->input(1)->cast<CNodePtr>();
|
||||
}
|
||||
auto inputs = current_node->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
auto node = inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode->is_parallel()) {
|
||||
remove_number += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "auto parallel remove_number " << remove_number;
|
||||
return {grads_count - remove_number};
|
||||
return {grads_count};
|
||||
}
|
||||
|
||||
std::vector<uint32_t> bucket_size_list;
|
||||
|
|
|
@ -1113,12 +1113,6 @@ void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_i
|
|||
|
||||
graph_output_info.graph_output_tensors.clear();
|
||||
graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);
|
||||
|
||||
// Save grad node to Bucket
|
||||
if (graph->has_flag(kFlagIsPynativeBpropGraph) && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) &&
|
||||
!kernel->is_parallel() && pynative::GraphAdapter::IsAutoParallel()) {
|
||||
graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
|
||||
}
|
||||
}
|
||||
WaitTaskFinish();
|
||||
// Clear bucket resources every step
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include "ir/cell.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
#include "pipeline/jit/parse/parse_dynamic.h"
|
||||
#include "include/common/utils/parallel_context.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "pipeline/jit/debug/trace.h"
|
||||
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
|
||||
|
@ -585,10 +584,6 @@ void GradExecutor::GradNetInner(const py::object *ret, const prim::GradOperation
|
|||
compile::SetMindRTEnable();
|
||||
resource->SetBackendAsync([]() { return compile::CreateBackend(); });
|
||||
MS_LOG(DEBUG) << "Start task emit action";
|
||||
auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
|
||||
if (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel) {
|
||||
ms_function()->MarkMsFunctionNodes(resource);
|
||||
}
|
||||
TaskEmitAction(resource);
|
||||
MS_LOG(DEBUG) << "Start execute action";
|
||||
ExecuteAction(resource);
|
||||
|
|
|
@ -104,38 +104,6 @@ void MsFunction::ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, co
|
|||
top_cell->set_op_info_with_ms_func_forward_tensors(op_run_info->op_info, total_output_tensors);
|
||||
}
|
||||
|
||||
void MsFunction::MarkMsFunctionNodes(const pipeline::ResourcePtr &resource) const {
|
||||
auto func_graph = resource->func_graph();
|
||||
std::vector<size_t> in_ms_function;
|
||||
const auto &parameters = func_graph->parameters();
|
||||
for (const auto &parameter : parameters) {
|
||||
auto param = parameter->cast<ParameterPtr>();
|
||||
if (!param->has_default()) {
|
||||
continue;
|
||||
}
|
||||
auto iter = std::find(ms_function_params_.begin(), ms_function_params_.end(), param->name());
|
||||
if (iter != ms_function_params_.end()) {
|
||||
(void)in_ms_function.emplace_back(1);
|
||||
} else {
|
||||
(void)in_ms_function.emplace_back(0);
|
||||
}
|
||||
}
|
||||
|
||||
auto ret = func_graph->get_return();
|
||||
auto ret_cnode = ret->cast<CNodePtr>();
|
||||
auto grads = ret_cnode->input(1)->cast<CNodePtr>();
|
||||
for (size_t i = 1; i < grads->inputs().size(); i++) {
|
||||
if (in_ms_function[i - 1] != 0) {
|
||||
auto node = grads->input(i);
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
cnode->set_parallel(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MsFunction::UpdateMsFunctionForwardTensors(const FrontendOpRunInfoPtr &op_run_info,
|
||||
const ValuePtr &new_forward_value) const {
|
||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||
|
|
|
@ -36,7 +36,6 @@ class MsFunction {
|
|||
MsFunction() = default;
|
||||
~MsFunction() = default;
|
||||
void set_graph_phase(const std::string &graph_phase) { graph_phase_ = graph_phase; }
|
||||
void MarkMsFunctionNodes(const pipeline::ResourcePtr &resource) const;
|
||||
py::object GradMsFunction(const py::object &out, const py::args &args);
|
||||
|
||||
private:
|
||||
|
|
|
@ -665,16 +665,6 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
|
|||
/// \param debug_infos A node debug info of an anf node.
|
||||
void AddFusedDebugInfoList(const std::vector<NodeDebugInfoPtr> &debug_infos);
|
||||
|
||||
/// \brief Check whether this node is in ms_function or not in PyNative Mode.
|
||||
///
|
||||
/// \return True if in ms_function, otherwise false.
|
||||
bool is_parallel() const { return flags_[kIsParallel]; }
|
||||
|
||||
/// \brief Set is_parallel_ for CNode.
|
||||
///
|
||||
/// \param[in] parallel Boolean.
|
||||
void set_parallel(bool parallel) { flags_[kIsParallel] = parallel; }
|
||||
|
||||
/// \brief Check whether contains a input or indirect input, which is Depend CNode with isolated side-effect node.
|
||||
///
|
||||
/// \return True if contains, otherwise false.
|
||||
|
@ -692,8 +682,7 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
|
|||
static constexpr size_t kInForwardFlag = 1;
|
||||
static constexpr size_t kEffectHandled = 2;
|
||||
static constexpr size_t kIsLoad = 3;
|
||||
static constexpr size_t kIsParallel = 4;
|
||||
static constexpr size_t kNumFlags = 5;
|
||||
static constexpr size_t kNumFlags = 4;
|
||||
static constexpr auto kFuncGraphVarKey = "fg_var";
|
||||
static constexpr auto kOutputValueKey = "out_value";
|
||||
|
||||
|
|
|
@ -115,7 +115,7 @@ KernelGraphPtr KernelGraphUtils::ConstructKernelGraphFromNodeList(const AnfNodeP
|
|||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
new_cnode->set_abstract(cnode->abstract());
|
||||
new_cnode->set_scope(cnode->scope());
|
||||
new_cnode->set_parallel(cnode->is_parallel());
|
||||
new_cnode->set_attrs(cnode->attrs());
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
|
||||
new_cnode->set_fullname_with_scope(cnode->input(kFirstDataInputIndex)->fullname_with_scope());
|
||||
}
|
||||
|
|
|
@ -217,6 +217,7 @@ class Parameter(Tensor_):
|
|||
self._cast_type = None
|
||||
self._unique = False
|
||||
self.is_in_parallel = _is_in_parallel_mode()
|
||||
self.is_in_shard = False
|
||||
self._pipeline_stage_list = []
|
||||
if isinstance(default_input, (Tensor_, Tensor)):
|
||||
Tensor_.__init__(self, default_input.dtype, default_input.shape)
|
||||
|
|
|
@ -25,6 +25,7 @@ from mindspore.context import ParallelMode
|
|||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore import ops, nn
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.api import is_pynative_parallel
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import composite as C
|
||||
|
@ -357,7 +358,8 @@ class TrainOneStepCell(Cell):
|
|||
self.reducer_flag = False
|
||||
self.grad_reducer = F.identity
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
|
||||
self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL) or \
|
||||
is_pynative_parallel()
|
||||
if self.reducer_flag:
|
||||
self.mean = _get_gradients_mean()
|
||||
self.degree = _get_device_num()
|
||||
|
|
|
@ -26,7 +26,6 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
|||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.api import is_pynative_parallel
|
||||
|
||||
|
||||
reduce_opt = C.MultitypeFuncGraph("reduce_opt")
|
||||
|
@ -126,8 +125,6 @@ def _tensors_allreduce_post(degree, mean, allreduce_filter, grad):
|
|||
Args:
|
||||
degree (int): The mean coefficient.
|
||||
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
|
||||
allgather (Primitive): The communication operator for sparse gradients.
|
||||
allreduce (Primitive): The communication operator for gradients.
|
||||
allreduce_filter (bool): When it is true, allreduce would apply.
|
||||
grad (Tensor): The gradient tensor before operation.
|
||||
|
||||
|
@ -137,7 +134,7 @@ def _tensors_allreduce_post(degree, mean, allreduce_filter, grad):
|
|||
if allreduce_filter:
|
||||
if mean:
|
||||
grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
|
||||
return grad
|
||||
return grad
|
||||
return grad
|
||||
|
||||
|
||||
|
@ -395,7 +392,7 @@ class DistributedGradReducer(Cell):
|
|||
self.degree = degree
|
||||
self.degree = Tensor(1.0 / self.degree, mstype.float32)
|
||||
self.mean = mean
|
||||
self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
|
||||
self.allreduce_filter = tuple((x.layerwise_parallel is False) and (x.is_in_shard is False) for x in parameters)
|
||||
is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
|
||||
split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
|
||||
if is_parallel_optimizer and split_indices:
|
||||
|
@ -413,7 +410,6 @@ class DistributedGradReducer(Cell):
|
|||
self.ps_parameters = tuple(ps_filter(x) for x in parameters)
|
||||
self.enable_parameter_server = any(self.ps_parameters)
|
||||
self.mode = context.get_context("mode")
|
||||
self.is_pynative_parallel = is_pynative_parallel()
|
||||
self.enable_tuple_broaden = True
|
||||
|
||||
@ms_function
|
||||
|
@ -431,9 +427,8 @@ class DistributedGradReducer(Cell):
|
|||
"""
|
||||
datatypes = self.map_(F.partial(_get_datatype), grads)
|
||||
grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
|
||||
if self.is_pynative_parallel:
|
||||
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean), self.allreduce_filter, grads)
|
||||
elif self.split_fusion:
|
||||
|
||||
if self.split_fusion:
|
||||
if self.enable_parameter_server:
|
||||
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
|
||||
self.op_list, self.allreduce_filter, grads, self.ps_parameters)
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
from functools import partial
|
||||
from types import FunctionType
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore import log as logger
|
||||
|
@ -906,6 +907,10 @@ class Shard(Shard_):
|
|||
return self.shard_fn
|
||||
shard_ = Shard()
|
||||
|
||||
if isinstance(fn, nn.Cell):
|
||||
for param in fn.trainable_params():
|
||||
param.is_in_shard = True
|
||||
|
||||
def shard_fn(*args):
|
||||
args = (fn,) + args
|
||||
@ms_function(hash_args=fn)
|
||||
|
|
|
@ -202,10 +202,8 @@ class ResNet(nn.Cell):
|
|||
in_channel=in_channels[3],
|
||||
out_channel=out_channels[3],
|
||||
stride=strides[3])
|
||||
self.layer4 = F.shard(self.layer4, in_strategy=((4, 2, 1, 1),), out_strategy=(None,),
|
||||
parameter_plan={
|
||||
'self.layer4.0.conv2.weight': (2, 2, 1, 1),
|
||||
})
|
||||
F.shard(self.layer4, in_strategy=((4, 2, 1, 1),), out_strategy=(None,),
|
||||
parameter_plan={'self.layer4.0.conv2.weight': (2, 2, 1, 1)})
|
||||
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.end_point = nn.Dense(2048, num_classes, has_bias=True,
|
||||
|
@ -387,6 +385,6 @@ def test_train_feed(num_classes=65536):
|
|||
model = Model(net, loss_fn=loss, optimizer=opt)
|
||||
model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
|
||||
loss_value = np.array(parallel_callback.loss_list)
|
||||
expect_out = [11.374571, 11.230516, 10.755886]
|
||||
expect_out = [11.374571, 11.028273, 10.5469265]
|
||||
print(loss_value)
|
||||
assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
|
||||
|
|
Loading…
Reference in New Issue