!41752 add vmap model ensembling parallel training code

Merge pull request !41752 from Erpim/master
i-robot 2022-09-16 01:12:41 +00:00 committed by Gitee
commit 5c9bd1f6a5
13 changed files with 684 additions and 159 deletions

View File

@ -941,14 +941,14 @@ VmapOperation::VmapOperation(const std::string &name) : MetaFuncGraph(name) {
SignatureEnumDType::kDTypeEmptyDefaultValue}});
}
FuncGraphPtr VmapOperation::GetVmap(const AnfNodePtr &vmap, const std::vector<AnfNodePtr> &forward_graph_params) const {
FuncGraphPtr VmapOperation::GetVmap(const AnfNodePtr &vmap, int param_number) const {
FuncGraphPtr vmap_child = std::make_shared<FuncGraph>();
vmap_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
vmap_child->set_flag(FUNC_GRAPH_FLAG_K_GRAPH, true);
std::vector<AnfNodePtr> inputs;
inputs.push_back(vmap);
for (size_t i = 0; i < forward_graph_params.size(); ++i) {
for (int i = 0; i < param_number; ++i) {
inputs.push_back(vmap_child->add_parameter());
}
auto vmap_app = vmap_child->NewCNodeInOrder(inputs);
@ -976,7 +976,7 @@ bool IsAxesAllNone(const ValuePtr &axes) {
return false;
}
ValuePtr CheckAxes(const AbstractBasePtr &axes_abs, const bool &is_in_axes = false, int nparam = 0) {
ValuePtr CheckAxes(const AbstractBasePtr &axes_abs, bool is_in_axes = false, int nparam = 0, size_t cell_size = 0) {
ValuePtr axes_value = nullptr;
auto axes_name = is_in_axes ? "in_axes" : "out_axes";
@ -994,22 +994,81 @@ ValuePtr CheckAxes(const AbstractBasePtr &axes_abs, const bool &is_in_axes = fal
}
}
bool elem_all_none = IsAxesAllNone(axes_value);
if (elem_all_none) {
MS_LOG(EXCEPTION) << "The '" << axes_name << "' of 'vmap' cannot be all None, but got " << axes_value->ToString()
<< ".";
if (elem_all_none && cell_size == 0) {
MS_LOG(EXCEPTION) << "The '" << axes_name
<< "' of 'vmap' cannot be all None while 'fn' is not a 'CellList', but got "
<< axes_value->ToString() << ".";
}
} else {
axes_value = axes_abs->BuildValue();
MS_EXCEPTION_IF_NULL(axes_value);
if (axes_value->isa<None>()) {
MS_LOG(EXCEPTION) << "The '" << axes_name << "' of 'vmap' cannot be a single None.";
} else if (!axes_value->isa<Int64Imm>()) {
if (axes_value->isa<None>() && cell_size == 0) {
MS_LOG(EXCEPTION) << "The '" << axes_name
<< "' of 'vmap' cannot be a single None while 'fn' is not a 'CellList'.";
} else if (!axes_value->isa<None>() && !axes_value->isa<Int64Imm>()) {
MS_LOG(EXCEPTION) << "The axis in vmap`s '" << axes_name << "' can only be of type Int or None, but got "
<< axes_abs->ToString() << ".";
}
}
return axes_value;
}
DebugInfoPtr CheckVmapFunc(const AbstractBasePtr &fn_arg, int *nparam, size_t *cell_size) {
DebugInfoPtr origin_graph_info = nullptr;
// In the model ensembling parallel training scenario, fn is a CellList.
AbstractTuplePtr cell_list = dyn_cast<AbstractTuple>(fn_arg);
if (cell_list != nullptr) {
*cell_size = cell_list->size();
if (*cell_size <= 1) {
MS_LOG(EXCEPTION) << "In the model ensembling parallel training scenario ('VmapOperation' arg0 is a 'CellList'),"
<< " the size of 'CellList' must be greater than 1, but got " << *cell_size << ".";
}
const AbstractBasePtrList &cell_list_fns = cell_list->elements();
for (auto fn_abs : cell_list_fns) {
MS_EXCEPTION_IF_NULL(fn_abs);
AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(fn_abs);
if (fn == nullptr) {
MS_LOG(EXCEPTION) << "'VmapOperation' arg0 is a 'CellList', whose elements must be 'Cell', but got "
<< fn_abs->ToString() << ".";
}
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
if (real_fn == nullptr) {
MS_LOG(EXCEPTION) << "'VmapOperation' arg0 is a 'CellList', whose element " << fn->ToString()
<< " cast to 'FuncGraphAbstractClosure' failed.";
}
FuncGraphPtr orig_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(orig_graph);
orig_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
int fn_nparam = SizeToInt(orig_graph->parameters().size());
if (*nparam == -1) {
origin_graph_info = orig_graph->debug_info();
*nparam = fn_nparam;
} else if (*nparam != fn_nparam) {
MS_LOG(EXCEPTION) << "'VmapOperation' arg0 is a CellList, whose elements's inputs should be consistent.";
}
}
} else {
AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(fn_arg);
if (fn == nullptr) {
MS_LOG(EXCEPTION) << "'VmapOperation' arg0 must be a 'Function' or 'Cell', but got " << fn_arg->ToString() << ".";
}
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
if (real_fn == nullptr) {
MS_LOG(EXCEPTION) << "'VmapOperation' arg0 " << fn->ToString() << " cast to 'FuncGraphAbstractClosure' failed.";
}
FuncGraphPtr orig_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(orig_graph);
orig_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
*nparam = SizeToInt(orig_graph->parameters().size());
origin_graph_info = orig_graph->debug_info();
}
return origin_graph_info;
}
} // namespace
FuncGraphPtr VmapOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
@ -1025,25 +1084,15 @@ FuncGraphPtr VmapOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
auto in_axes_arg = args_spec_list[1];
auto out_axes_arg = args_spec_list[2];
AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(fn_arg);
if (fn == nullptr) {
MS_LOG(EXCEPTION) << "'VmapOperation' arg0 must be a 'Function' or 'Cell', but got " << fn_arg->ToString() << ".";
}
int nparam = -1;
size_t cell_size = 0;
DebugInfoPtr origin_graph_info = CheckVmapFunc(fn_arg, &nparam, &cell_size);
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
if (real_fn == nullptr) {
MS_LOG(EXCEPTION) << "'VmapOperation' arg0 " << fn->ToString() << " cast to 'FuncGraphAbstractClosure' failed.";
}
FuncGraphPtr orig_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(orig_graph);
orig_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
FuncGraphPtr vmap_fg = nullptr;
{
TraceGuard guard(std::make_shared<TraceVmapOperation>(orig_graph->debug_info()));
TraceGuard guard(std::make_shared<TraceVmapOperation>(origin_graph_info));
vmap_fg = std::make_shared<FuncGraph>();
}
int nparam = SizeToInt(orig_graph->parameters().size());
std::ostringstream ss;
ss << "vmap{" << nparam << "}";
@ -1056,12 +1105,13 @@ FuncGraphPtr VmapOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
(void)vmap_fg->add_parameter();
// Validity verification of in_axes and out_axes
ValuePtr in_axes = CheckAxes(in_axes_arg, true, nparam);
ValuePtr in_axes = CheckAxes(in_axes_arg, true, nparam, cell_size);
ValuePtr out_axes = CheckAxes(out_axes_arg);
PrimitivePtr kprim_vmap = std::make_shared<Primitive>(prim::kVmap, kSideEffectPropagate);
kprim_vmap->set_attr("in_axes", in_axes);
kprim_vmap->set_attr("out_axes", out_axes);
kprim_vmap->set_attr("cell_size", MakeValue(cell_size));
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(kprim_vmap));
@ -1070,8 +1120,8 @@ FuncGraphPtr VmapOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
FuncGraphPtr vmap_child = nullptr;
{
TraceGuard guard(std::make_shared<TraceVmapOperation>(orig_graph->debug_info()));
vmap_child = GetVmap(vmap, orig_graph->parameters());
TraceGuard guard(std::make_shared<TraceVmapOperation>(origin_graph_info));
vmap_child = GetVmap(vmap, nparam);
}
vmap_fg->set_output(NewValueNode(vmap_child));
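The net effect of the changes in this file: when 'fn' is a 'CellList', all-None axes are permitted because the mapped axis is supplied by the cell dimension, and the cell count is recorded on the vmap primitive as the 'cell_size' attr. A minimal plain-Python sketch of the relaxed axes rule (check_axes is a hypothetical stand-in for the C++ CheckAxes, not a MindSpore API):

# Hypothetical stand-in for CheckAxes; plain Python, for illustration only.
def check_axes(axes, axes_name, cell_size=0):
    def all_none(a):
        return all(all_none(e) for e in a) if isinstance(a, tuple) else a is None
    if isinstance(axes, tuple):
        # Nested axes: all-None is only legal in the CellList scenario.
        if all_none(axes) and cell_size == 0:
            raise ValueError("The '%s' of 'vmap' cannot be all None while 'fn' "
                             "is not a 'CellList', but got %s." % (axes_name, axes))
    elif axes is None and cell_size == 0:
        raise ValueError("The '%s' of 'vmap' cannot be a single None while 'fn' "
                         "is not a 'CellList'." % axes_name)
    elif axes is not None and not isinstance(axes, int):
        raise ValueError("The axis in vmap's '%s' can only be of type Int or None." % axes_name)
    return axes

check_axes((None, None), "in_axes", cell_size=3)  # accepted: fn is a CellList
check_axes(0, "out_axes")                         # accepted: a single Int axis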

View File

@ -294,7 +294,7 @@ class VmapOperation : public MetaFuncGraph {
~VmapOperation() override = default;
MS_DECLARE_PARENT(VmapOperation, MetaFuncGraph)
FuncGraphPtr GetVmap(const AnfNodePtr &vmap, const std::vector<AnfNodePtr> &forward_graph_params) const;
FuncGraphPtr GetVmap(const AnfNodePtr &vmap, int param_number) const;
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
};

View File

@ -945,18 +945,29 @@ AbstractBasePtr InferImplVmap(const AnalysisEnginePtr &, const PrimitivePtr &pri
auto fn_arg = args_spec_list[0];
MS_LOG(DEBUG) << "Evaluate Vmap: " << fn_arg->ToString() << ".";
AbstractFunctionPtr x = dyn_cast<AbstractFunction>(fn_arg);
MS_EXCEPTION_IF_NULL(x);
AbstractFuncAtomPtrList vmap_v;
ValuePtr in_axes = primitive->GetAttr("in_axes");
ValuePtr out_axes = primitive->GetAttr("out_axes");
AbstractFuncAtomPtrList vmap_v;
auto build_vmap_v = [&vmap_v, &in_axes, &out_axes](const AbstractFuncAtomPtr &func) {
auto vmap_closure = std::make_shared<VmapTransformedAbstractClosure>(func, in_axes, out_axes);
vmap_v.push_back(vmap_closure);
auto traverse_fn = [&vmap_v, &in_axes, &out_axes](const AbstractBasePtr &fn_arg) {
AbstractFunctionPtr x = dyn_cast<AbstractFunction>(fn_arg);
MS_EXCEPTION_IF_NULL(x);
auto build_vmap_v = [&vmap_v, &in_axes, &out_axes](const AbstractFuncAtomPtr &func) {
auto vmap_closure = std::make_shared<VmapTransformedAbstractClosure>(func, in_axes, out_axes);
vmap_v.push_back(vmap_closure);
};
x->Visit(build_vmap_v);
};
x->Visit(build_vmap_v);
AbstractTuplePtr cell_list = dyn_cast<AbstractTuple>(fn_arg);
if (cell_list != nullptr) {
const auto &cell_list_fns = cell_list->elements();
for (const auto &fn : cell_list_fns) {
traverse_fn(fn);
}
} else {
traverse_fn(fn_arg);
}
return AbstractFunction::MakeAbstractFunction(vmap_v);
}
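In short, the inference step now fans out over a CellList: one vmap-transformed closure is collected per element. A rough plain-Python analogue (stand-in names, not the C++ abstract-closure machinery):

def infer_vmap(fn_arg, in_axes, out_axes):
    """Collect one vmap-transformed closure per function (stand-in sketch)."""
    vmap_v = []
    def traverse_fn(fn):
        vmap_v.append(("VmapTransformedClosure", fn, in_axes, out_axes))
    if isinstance(fn_arg, tuple):  # CellList case: one closure per cell
        for fn in fn_arg:
            traverse_fn(fn)
    else:                          # single function/Cell case
        traverse_fn(fn_arg)
    return vmap_v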

View File

@ -35,7 +35,7 @@ class ExpandMetaFgPrim {
public:
ExpandMetaFgPrim() = default;
virtual ~ExpandMetaFgPrim() = default;
bool CheckIfEmbedMetaFgPrim(const CNodePtr &node) const;
virtual bool CheckIfEmbedMetaFgPrim(const CNodePtr &node) const;
const std::vector<CNodePtr> &prim_nodes() const { return prim_nodes_; }
virtual bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) = 0;
void GetMetaFgPrim(const std::vector<AnfNodePtr> &all_nodes);

View File

@ -18,7 +18,11 @@
#include <string>
#include <vector>
#include <set>
#include <regex>
#include "utils/hash_map.h"
#include "ir/func_graph_cloner.h"
#include "base/complex_storage.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "pipeline/pynative/pynative_execute.h"
@ -33,43 +37,123 @@ namespace internal {
const mindspore::HashSet<std::string> throughtout_op{prim::kPrimMakeTuple->name(), prim::kPrimMakeList->name(),
prim::kPrimDepend->name(), prim::kPrimReturn->name(),
prim::kPrimUpdateState->name(), prim::kPrimStopGradient->name()};
CNodePtr BuildBindInAxisTupleInput(const AnfNodePtr &input, const ValuePtr &in_axis, const FuncGraphPtr &fg) {
auto input_abs_elements = dyn_cast<abstract::AbstractTuple>(input->abstract());
CNodePtr BuildBindInAxisSeqInput(const AnfNodePtr &input, const ValuePtr &in_axis, const FuncGraphPtr &fg) {
auto input_abs = input->abstract();
MS_EXCEPTION_IF_NULL(input_abs);
auto input_abs_elements = dyn_cast<abstract::AbstractSequence>(input_abs);
MS_EXCEPTION_IF_NULL(input_abs_elements);
ValueSequencePtr in_axis_value_sequence = nullptr;
if (in_axis->isa<ValueSequence>()) {
in_axis_value_sequence = dyn_cast<ValueSequence>(in_axis);
if (input_abs_elements->size() != in_axis_value_sequence->size()) {
MS_LOG(EXCEPTION) << "The length of input and in_axis should be the same but got input length: "
<< input_abs_elements->size() << ", in_axis length: " << in_axis_value_sequence->size() << ".";
MS_EXCEPTION(ValueError) << "The length of input and in_axis should be the same but got input length: "
<< input_abs_elements->size() << ", in_axis length: " << in_axis_value_sequence->size()
<< ".";
}
}
std::vector<AnfNodePtr> ret_inputs;
(void)ret_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
if (input_abs->isa<abstract::AbstractList>()) {
(void)ret_inputs.emplace_back(NewValueNode(prim::kPrimMakeList));
} else {
(void)ret_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
}
for (unsigned int i = 0; i < input_abs_elements->size(); ++i) {
std::vector<AnfNodePtr> tuple_getitem_cnode_inputs;
(void)tuple_getitem_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
(void)tuple_getitem_cnode_inputs.emplace_back(input);
(void)tuple_getitem_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(i)));
auto tuple_getitem_cnode = fg->NewCNode(tuple_getitem_cnode_inputs);
std::vector<AnfNodePtr> seq_getitem_cnode_inputs;
if (input_abs->isa<abstract::AbstractList>()) {
(void)seq_getitem_cnode_inputs.emplace_back(NewValueNode(prim::kPrimListGetItem));
} else {
(void)seq_getitem_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
}
(void)seq_getitem_cnode_inputs.emplace_back(input);
(void)seq_getitem_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(i)));
auto seq_getitem_cnode = fg->NewCNode(seq_getitem_cnode_inputs);
MS_EXCEPTION_IF_NULL(seq_getitem_cnode);
auto input_abs_element = (*input_abs_elements)[i];
auto in_axis_value = in_axis_value_sequence == nullptr ? in_axis : (*in_axis_value_sequence)[i];
CNodePtr cur_make_tuple = nullptr;
if (input_abs_element->isa<abstract::AbstractTuple>()) {
tuple_getitem_cnode->set_abstract(input_abs_element);
cur_make_tuple = BuildBindInAxisTupleInput(tuple_getitem_cnode, in_axis_value, fg);
CNodePtr cur_make_seq = nullptr;
if (input_abs_element->isa<abstract::AbstractSequence>()) {
seq_getitem_cnode->set_abstract(input_abs_element);
cur_make_seq = BuildBindInAxisSeqInput(seq_getitem_cnode, in_axis_value, fg);
} else {
std::vector<AnfNodePtr> cur_make_tuple_inputs;
(void)cur_make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
(void)cur_make_tuple_inputs.emplace_back(tuple_getitem_cnode);
(void)cur_make_tuple_inputs.emplace_back(seq_getitem_cnode);
(void)cur_make_tuple_inputs.emplace_back(NewValueNode(in_axis_value));
cur_make_tuple = fg->NewCNode(cur_make_tuple_inputs);
cur_make_seq = fg->NewCNode(cur_make_tuple_inputs);
}
(void)ret_inputs.emplace_back(cur_make_tuple);
(void)ret_inputs.emplace_back(cur_make_seq);
}
return fg->NewCNode(ret_inputs);
}
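Conceptually, BuildBindInAxisSeqInput pairs every leaf input with its axis, recursing into nested tuples and lists and preserving the sequence kind. A plain-Python sketch of that pairing (hypothetical helper; numbers stand in for tensors):

def bind_in_axis_seq(value, in_axis):
    """Recursively pair each leaf of a tuple/list input with its in_axis."""
    if isinstance(value, (tuple, list)):
        if isinstance(in_axis, (tuple, list)) and len(value) != len(in_axis):
            raise ValueError("The length of input and in_axis should be the same but got "
                             "input length: %d, in_axis length: %d." % (len(value), len(in_axis)))
        axes = in_axis if isinstance(in_axis, (tuple, list)) else [in_axis] * len(value)
        make_seq = list if isinstance(value, list) else tuple  # keep list vs tuple kind
        return make_seq(bind_in_axis_seq(v, a) for v, a in zip(value, axes))
    return (value, in_axis)

print(bind_in_axis_seq(([1, 2], 3), (0, None)))  # ([(1, 0), (2, 0)], (3, None))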
AnfNodePtr BindInAxis(const CNodePtr &vmap_app, const ValuePtr &in_axes, size_t *u_monad_offset) {
AnfNodePtr UpdateParam(const FuncGraphPtr &vmap_fg, const AnfNodePtr &u_monad_node,
ParamMappingVector *param_mapping_table) {
MS_EXCEPTION_IF_NULL(u_monad_node);
auto update_state_prim = NewValueNode(prim::kPrimUpdateState);
auto load_prim = NewValueNode(prim::kPrimLoad);
std::vector<AnfNodePtr> attach_tuple{NewValueNode(prim::kPrimMakeTuple)};
for (auto param_pair : *param_mapping_table) {
auto ref = param_pair.first;
auto each_cell_params = param_pair.second;
std::vector<AnfNodePtr> param_tuple{NewValueNode(prim::kPrimMakeTuple)};
for (auto param : each_cell_params) {
auto load_cnode = vmap_fg->NewCNode({load_prim, param, u_monad_node});
MS_EXCEPTION_IF_NULL(load_cnode);
auto param_abs = dyn_cast<abstract::AbstractRefTensor>(param->abstract());
MS_EXCEPTION_IF_NULL(param_abs);
load_cnode->set_abstract(param_abs->CloneAsTensor());
param_tuple.push_back(load_cnode);
}
auto param_tuple_cnode = vmap_fg->NewCNode(param_tuple);
auto update_state_after_load = vmap_fg->NewCNode({update_state_prim, u_monad_node, param_tuple_cnode});
const py::function stack_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_stack");
auto stack_fg = parse::ParsePythonCode(stack_fn);
MS_EXCEPTION_IF_NULL(stack_fg);
auto stack_cnode = vmap_fg->NewCNode({NewValueNode(stack_fg), param_tuple_cnode});
auto assign_prim = NewValueNode(prim::kPrimAssign);
auto assign_cnode = vmap_fg->NewCNode({assign_prim, ref, stack_cnode, update_state_after_load});
attach_tuple.push_back(assign_cnode);
}
auto attach_cnode = vmap_fg->NewCNode(attach_tuple);
auto update_state_node = vmap_fg->NewCNode({update_state_prim, u_monad_node, attach_cnode});
return update_state_node;
}
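Semantically, UpdateParam loads the current value of each cell's parameter and stacks the values along a new leading axis into the shared stacked parameter before the vmapped graph runs. A numpy sketch of that data movement (numpy arrays stand in for Load/vmap_stack/Assign; names are illustrative):

import numpy as np

def update_stacked_params(param_mapping_table):
    """For each (stacked_ref, per-cell params) pair, stack and assign."""
    for stacked_ref, each_cell_params in param_mapping_table:
        stacked_ref[...] = np.stack([np.asarray(p) for p in each_cell_params])

stacked = np.zeros((2, 3))
update_stacked_params([(stacked, [np.ones(3), 2 * np.ones(3)])])
print(stacked)  # [[1. 1. 1.], [2. 2. 2.]]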
void GetMonadOffset(const std::vector<AnfNodePtr> &inputs, size_t *u_monad_offset, size_t *io_monad_offset) {
// Check the last two (if exists) is monad input.
if (*u_monad_offset != 0 || *io_monad_offset != 0) {
MS_EXCEPTION(ValueError) << "The initial value of u_monad_offset and io_monad_offset should be 0, but we got "
<< "u_monad_offset: " << *u_monad_offset << " and io_monad_offset: " << *io_monad_offset
<< ".";
}
auto inputs_size = inputs.size();
constexpr size_t max_monad_input_num = 2;
if (HasAbstractMonad(inputs[inputs_size - 1])) {
if (HasAbstractUMonad(inputs[inputs_size - 1])) {
*u_monad_offset = 1;
} else if (inputs_size >= max_monad_input_num && HasAbstractUMonad(inputs[inputs_size - max_monad_input_num])) {
++(*io_monad_offset);
*u_monad_offset = *io_monad_offset + 1;
} else {
++(*io_monad_offset);
}
}
}
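The offset convention GetMonadOffset computes can be read off a toy model: the trailing one or two inputs may be monads, with a U monad preceding an IO monad when both exist. A plain-Python sketch (has_monad/has_u_monad are hypothetical predicates):

def get_monad_offsets(inputs, has_monad, has_u_monad):
    """Return (u_monad_offset, io_monad_offset) counted from the inputs' tail."""
    u_monad_offset, io_monad_offset = 0, 0
    if inputs and has_monad(inputs[-1]):
        if has_u_monad(inputs[-1]):
            u_monad_offset = 1               # [..., U]
        elif len(inputs) >= 2 and has_u_monad(inputs[-2]):
            io_monad_offset = 1              # [..., U, IO]
            u_monad_offset = io_monad_offset + 1
        else:
            io_monad_offset = 1              # [..., IO]
    return u_monad_offset, io_monad_offset

is_monad = lambda s: s in ("U", "IO")
is_u = lambda s: s == "U"
print(get_monad_offsets(["x", "U", "IO"], is_monad, is_u))  # (2, 1)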
void BindUMonad(const AnfNodePtr &u_monad_node, const FuncGraphPtr &vmap_fg, std::vector<AnfNodePtr> *outputs,
ParamMappingVector *param_mapping_table) {
MS_EXCEPTION_IF_NULL(u_monad_node);
if (param_mapping_table == nullptr || param_mapping_table->empty()) {
(void)outputs->emplace_back(u_monad_node);
} else {
auto update_state_node = UpdateParam(vmap_fg, u_monad_node, param_mapping_table);
(void)outputs->emplace_back(update_state_node);
}
}
AnfNodePtr BindInAxis(const CNodePtr &vmap_app, const ValuePtr &in_axes, size_t *u_monad_offset,
size_t *io_monad_offset, ParamMappingVector *param_mapping_table) {
FuncGraphPtr vmap_fg = vmap_app->func_graph();
bool is_in_axes_value_sequence = in_axes->isa<ValueSequence>();
ValueSequencePtr in_axes_to_value_sequence = dyn_cast<ValueSequence>(in_axes);
@ -77,30 +161,17 @@ AnfNodePtr BindInAxis(const CNodePtr &vmap_app, const ValuePtr &in_axes, size_t
auto inputs = vmap_app->inputs();
auto inputs_size = inputs.size();
if (inputs_size == 0) {
MS_LOG(EXCEPTION) << "The inputs number of CNode: " << vmap_app->DebugString()
<< " should be positive but got : " << inputs_size << ".";
MS_EXCEPTION(ValueError) << "The inputs number of CNode: " << vmap_app->DebugString()
<< " should be positive but got : " << inputs_size << ".";
}
// Check the last two (if exists) is monad input.
size_t io_monad_offset = 0;
constexpr size_t max_monad_input_num = 2;
if (HasAbstractMonad(inputs[inputs_size - 1])) {
if (HasAbstractUMonad(inputs[inputs_size - 1])) {
*u_monad_offset = 1;
} else if (inputs_size >= max_monad_input_num && HasAbstractUMonad(inputs[inputs_size - max_monad_input_num])) {
io_monad_offset++;
*u_monad_offset = io_monad_offset + 1;
} else {
io_monad_offset++;
}
}
size_t abstract_monad_count = *u_monad_offset > io_monad_offset ? *u_monad_offset : io_monad_offset;
auto real_params_size = inputs_size - abstract_monad_count;
GetMonadOffset(inputs, u_monad_offset, io_monad_offset);
size_t abstract_monad_count = *u_monad_offset > *io_monad_offset ? *u_monad_offset : *io_monad_offset;
size_t real_params_size = inputs_size > abstract_monad_count ? inputs_size - abstract_monad_count : 0;
if (is_in_axes_value_sequence && real_params_size - 1 != in_axes_to_value_sequence->size()) {
MS_LOG(EXCEPTION) << "The length of vmap_app inputs (except primitive input and monad input) is: "
<< real_params_size - 1 << " and the length of in_axis is: " << in_axes_to_value_sequence->size()
<< ". These two numbers should be equal.";
MS_EXCEPTION(ValueError) << "The length of vmap_app inputs (except primitive input and monad input) is: "
<< (real_params_size - 1)
<< " and the length of in_axis is: " << in_axes_to_value_sequence->size()
<< ". These two numbers should be equal.";
}
std::vector<AnfNodePtr> outputs;
@ -109,21 +180,24 @@ AnfNodePtr BindInAxis(const CNodePtr &vmap_app, const ValuePtr &in_axes, size_t
auto input = inputs[i];
auto in_axis = is_in_axes_value_sequence ? (*in_axes_to_value_sequence)[i - 1] : in_axes;
auto input_abs = input->abstract();
CNodePtr cur_make_tuple_cnode = nullptr;
if (input_abs->isa<abstract::AbstractTuple>()) {
cur_make_tuple_cnode = BuildBindInAxisTupleInput(input, in_axis, vmap_fg);
CNodePtr cur_make_seq_cnode = nullptr;
if (input_abs->isa<abstract::AbstractSequence>()) {
cur_make_seq_cnode = BuildBindInAxisSeqInput(input, in_axis, vmap_fg);
} else {
cur_make_tuple_cnode = vmap_fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), input, NewValueNode(in_axis)});
cur_make_seq_cnode = vmap_fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), input, NewValueNode(in_axis)});
}
(void)outputs.emplace_back(cur_make_tuple_cnode);
(void)outputs.emplace_back(cur_make_seq_cnode);
}
if (abstract_monad_count == 1) {
(void)outputs.emplace_back(inputs.back());
} else if (IntToSize(abstract_monad_count) == max_monad_input_num) {
(void)outputs.emplace_back(inputs[inputs_size - max_monad_input_num]);
if (*u_monad_offset > 0 && inputs_size > *u_monad_offset) {
AnfNodePtr u_monad_node = inputs[inputs_size - *u_monad_offset];
BindUMonad(u_monad_node, vmap_fg, &outputs, param_mapping_table);
}
if (*io_monad_offset > 0 && inputs_size > 0) {
(void)outputs.emplace_back(inputs.back());
}
return vmap_fg->NewCNode(outputs);
}
@ -165,8 +239,8 @@ void GetSubAxisSize(const AbstractBasePtr &sub_abs, ValuePtr *const sub_in_axes,
if (*axis_size == kInvalidAxisSize) {
*axis_size = sub_axis_size;
} else if (*axis_size != sub_axis_size) {
MS_LOG(EXCEPTION) << "The 'axis_size' of each argument in the scope of 'vmap' should be equal, but got "
<< *axis_size << " and " << sub_axis_size << ".";
MS_EXCEPTION(ValueError) << "The 'axis_size' of each argument in the scope of 'vmap' should be equal, but got "
<< *axis_size << " and " << sub_axis_size << ".";
}
}
@ -196,11 +270,29 @@ int GetAxisSizeByAbs(const AbstractBasePtr &abs, ValuePtr *const in_axes) {
auto in_axes_int = dyn_cast<Int64Imm>(*in_axes);
if (in_axes_int != nullptr) {
int64_t axis = in_axes_int->value();
ShapeVector orig_shape = dyn_cast<abstract::Shape>(abs->BuildShape())->shape();
if (!abs->isa<abstract::AbstractTensor>()) {
// If we get an AbstractScalar with a value 0 of type int32, it means that the input is not used later.
auto abs_value = abs->BuildValue();
MS_EXCEPTION_IF_NULL(abs_value);
auto abs_int32_t = dyn_cast<Int32Imm>(abs_value);
MS_EXCEPTION_IF_NULL(abs_int32_t);
if (abs_int32_t->value() == 0) {
MS_LOG(WARNING) << "There is a argument not used in the scope of vmap. Please check whether the inputs"
<< " meet expectations.";
return axis_size;
}
MS_EXCEPTION(ValueError) << "The abs should be AbstractTensor when axis is " << axis << ", but got a "
<< abs->ToString() << ".";
}
auto shape = abs->BuildShape();
MS_EXCEPTION_IF_NULL(shape);
auto shape_ptr = dyn_cast<abstract::Shape>(shape);
MS_EXCEPTION_IF_NULL(shape_ptr);
ShapeVector orig_shape = shape_ptr->shape();
int64_t shape_len = SizeToLong(orig_shape.size());
if (axis < -shape_len || axis >= shape_len) {
MS_LOG(EXCEPTION) << "ValueError: axis " << axis << " is out of bounds for array of dimension [" << -shape_len
<< "," << shape_len << ").";
MS_EXCEPTION(ValueError) << "ValueError: axis " << axis << " is out of bounds for array of dimension ["
<< -shape_len << "," << shape_len << ").";
}
axis = axis < 0 ? shape_len + axis : axis;
*in_axes = std::make_shared<Int64Imm>(axis);
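The normalization at the end of this hunk mirrors numpy's negative-axis convention. A standalone sketch of the bounds check and conversion (hypothetical helper):

def normalize_axis(axis, ndim):
    """Bounds-check an axis against [-ndim, ndim) and make it non-negative."""
    if axis < -ndim or axis >= ndim:
        raise ValueError("axis %d is out of bounds for array of dimension [%d, %d)."
                         % (axis, -ndim, ndim))
    return axis + ndim if axis < 0 else axis

print(normalize_axis(-1, 3))  # 2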
@ -212,11 +304,10 @@ int GetAxisSizeByAbs(const AbstractBasePtr &abs, ValuePtr *const in_axes) {
// Get the axis size of the current vmap scope; at the same time, the negative indexes in in_axes are converted
// to the corresponding positive indexes.
int GetAxisSize(const CNodePtr &cnode, ValuePtr *const in_axes) {
int GetAxisSize(const CNodePtr &cnode, size_t cell_size, size_t parameters_size, ValuePtr *const in_axes) {
MS_EXCEPTION_IF_NULL(cnode);
// `axis_size` is unique within the scope of vmap, so we just need to get one of them.
int axis_size = kInvalidAxisSize;
size_t parameters_size = cnode->size() - 1;
auto in_axes_seq = GetInAxesSeq(*in_axes, parameters_size);
std::vector<ValuePtr> corrected_in_axes;
for (size_t i = 0; i < parameters_size; ++i) {
@ -228,48 +319,143 @@ int GetAxisSize(const CNodePtr &cnode, ValuePtr *const in_axes) {
GetSubAxisSize(sub_abs, &sub_in_axes, &axis_size, &corrected_in_axes);
}
*in_axes = std::make_shared<ValueSequence>(corrected_in_axes);
if (cell_size > 0) {
if (axis_size == kInvalidAxisSize) {
axis_size = SizeToInt(cell_size);
} else if (SizeToInt(cell_size) != axis_size) {
MS_EXCEPTION(ValueError) << "If you want to execute the model ensembling parallel training, please make sure "
<< "the 'axis_size' in the scope of vmap is consistent with the cell size of the input "
<< "'CellList'; otherwise, please do not pass 'CellList' as the first argument, "
<< "but got axis_size: " << axis_size << " and the cell size: " << cell_size << ".";
}
} else if (axis_size == kInvalidAxisSize) {
MS_LOG(EXCEPTION) << "Failed to get 'axis_size' within the scope of vmap.";
}
return axis_size;
}
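The consistency rule added here: in the model ensembling case the mapped axis size defaults to the number of cells, and any axis size derived from the inputs must agree with it. A standalone sketch (hypothetical helper, kInvalidAxisSize modeled as -1):

def resolve_axis_size(axis_size, cell_size, invalid=-1):
    """Reconcile the axis size inferred from inputs with the CellList size."""
    if cell_size > 0:
        if axis_size == invalid:
            return cell_size                  # take axis size from the cell dimension
        if axis_size != cell_size:
            raise ValueError("axis_size %d is inconsistent with the cell size %d."
                             % (axis_size, cell_size))
        return axis_size
    if axis_size == invalid:
        raise RuntimeError("Failed to get 'axis_size' within the scope of vmap.")
    return axis_size

print(resolve_axis_size(-1, 3))  # 3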
AnfNodePtr MatchOutAxis(const AnfNodePtr &expanded_vmap_node, int parameters_size, size_t u_monad_offset, int axis_size,
const ValuePtr &out_axes) {
CNodePtr AttachToOutput(const FuncGraphPtr &func_graph, const CNodePtr &output, const AnfNodePtr &node) {
TraceGuard guard(std::make_shared<TraceCopy>(output->debug_info()));
auto depend = NewValueNode(prim::kPrimDepend);
auto depend_cnode = func_graph->NewCNode({depend, output, node});
MS_EXCEPTION_IF_NULL(depend_cnode);
depend_cnode->set_abstract(output->abstract());
return depend_cnode;
}
AnfNodePtr FeedBackParam(const FuncGraphPtr &vmap_post_fg, const AnfNodePtr &u_monad_node,
const AnfNodePtr &io_monad_node, const CNodePtr &output,
ParamMappingVector *param_mapping_table) {
auto update_state_prim = NewValueNode(prim::kPrimUpdateState);
std::vector<AnfNodePtr> out_attach_tuple{NewValueNode(prim::kPrimMakeTuple)};
for (auto param_pair : *param_mapping_table) {
auto ref = param_pair.first;
auto load_prim = NewValueNode(prim::kPrimLoad);
auto load_cnode = vmap_post_fg->NewCNode({load_prim, ref, u_monad_node});
MS_EXCEPTION_IF_NULL(load_cnode);
auto ref_abs = dyn_cast<abstract::AbstractRefTensor>(ref->abstract());
MS_EXCEPTION_IF_NULL(ref_abs);
load_cnode->set_abstract(ref_abs->CloneAsTensor());
auto update_state_after_load = vmap_post_fg->NewCNode({update_state_prim, u_monad_node, load_cnode});
MS_EXCEPTION_IF_NULL(update_state_after_load);
update_state_after_load->set_abstract(u_monad_node->abstract());
PrimitivePtr kprim_unstack = std::make_shared<Primitive>(prim::kUnstack);
kprim_unstack->set_attr(kAttrAxis, MakeValue(SizeToLong(0)));
auto unstack_prim = NewValueNode(kprim_unstack);
auto unstack_cnode = vmap_post_fg->NewCNode({unstack_prim, load_cnode});
auto each_cell_params = param_pair.second;
std::vector<AnfNodePtr> attach_tuple{NewValueNode(prim::kPrimMakeTuple)};
int64_t cell_index = 0;
for (auto param : each_cell_params) {
auto tuple_getitem_prim = NewValueNode(prim::kPrimTupleGetItem);
auto tuple_getitem_cnode = vmap_post_fg->NewCNode({tuple_getitem_prim, unstack_cnode, NewValueNode(cell_index)});
cell_index++;
auto assign_prim = NewValueNode(prim::kPrimAssign);
auto assign_cnode = vmap_post_fg->NewCNode({assign_prim, param, tuple_getitem_cnode, update_state_after_load});
attach_tuple.push_back(assign_cnode);
}
auto attach_cnode = vmap_post_fg->NewCNode(attach_tuple);
auto update_state_after_assign = vmap_post_fg->NewCNode({update_state_prim, update_state_after_load, attach_cnode});
MS_EXCEPTION_IF_NULL(update_state_after_assign);
update_state_after_assign->set_abstract(update_state_after_load->abstract());
out_attach_tuple.push_back(update_state_after_assign);
}
auto out_attach_cnode = vmap_post_fg->NewCNode(out_attach_tuple);
auto attach_output = AttachToOutput(vmap_post_fg, output, out_attach_cnode);
if (io_monad_node) {
attach_output = AttachToOutput(vmap_post_fg, attach_output, io_monad_node);
}
vmap_post_fg->set_output(attach_output);
return NewValueNode(vmap_post_fg);
}
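FeedBackParam is the inverse of UpdateParam: after the vmapped graph runs, the stacked parameter is unstacked along axis 0 and each slice is written back to the corresponding cell's original parameter. A numpy sketch of that write-back (numpy stands in for Load/Unstack/Assign; names are illustrative):

import numpy as np

def feed_back_params(param_mapping_table):
    """Unstack each stacked parameter and assign the slices back per cell."""
    for stacked_ref, each_cell_params in param_mapping_table:
        for param, piece in zip(each_cell_params, np.asarray(stacked_ref)):
            param[...] = piece

p1, p2 = np.zeros(3), np.zeros(3)
feed_back_params([(np.array([[1., 1., 1.], [2., 2., 2.]]), [p1, p2])])
print(p1, p2)  # [1. 1. 1.] [2. 2. 2.]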
AnfNodePtr PostProcessVmap(const AnfNodePtr &expanded_vmap_node, const std::vector<size_t> &orig_fg_param_info,
const ValuePtr &out_axes, int axis_size, ParamMappingVector *param_mapping_table) {
FuncGraphPtr vmap_post_fg = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> exec_node;
exec_node.push_back(expanded_vmap_node);
AnfNodePtr u_monad_node = nullptr;
int offset = SizeToInt(u_monad_offset);
int u_monad_index = parameters_size > offset ? parameters_size - offset : parameters_size;
for (int i = 0; i < parameters_size; ++i) {
AnfNodePtr io_monad_node = nullptr;
size_t parameters_size = orig_fg_param_info[kParamSizeIndex];
size_t u_monad_offset = orig_fg_param_info[kUMonadOffsetIndex];
size_t io_monad_offset = orig_fg_param_info[kIOMonadOffsetIndex];
size_t u_monad_index = parameters_size > u_monad_offset ? parameters_size - u_monad_offset : parameters_size;
size_t io_monad_index = parameters_size > io_monad_offset ? parameters_size - io_monad_offset : parameters_size;
for (size_t i = 0; i < parameters_size; ++i) {
if (i == u_monad_index) {
u_monad_node = vmap_post_fg->add_parameter();
exec_node.push_back(u_monad_node);
continue;
} else if (i == io_monad_index) {
io_monad_node = vmap_post_fg->add_parameter();
exec_node.push_back(io_monad_node);
continue;
}
exec_node.push_back(vmap_post_fg->add_parameter());
}
auto vmap_outputs = vmap_post_fg->NewCNode(exec_node);
if (u_monad_node != nullptr) {
auto update_state_prim = NewValueNode(prim::kPrimUpdateState);
auto update_state_prim = NewValueNode(prim::kPrimUpdateState);
if (u_monad_node) {
auto update_state_cnode = vmap_post_fg->NewCNode({update_state_prim, u_monad_node, vmap_outputs});
MS_EXCEPTION_IF_NULL(update_state_cnode);
update_state_cnode->set_abstract(u_monad_node->abstract());
u_monad_node = update_state_cnode;
}
if (io_monad_node) {
auto update_state_cnode = vmap_post_fg->NewCNode({update_state_prim, io_monad_node, vmap_outputs});
MS_EXCEPTION_IF_NULL(update_state_cnode);
update_state_cnode->set_abstract(io_monad_node->abstract());
io_monad_node = update_state_cnode;
}
// MatchOutAxis: Convert the outputs according to the out_axes to the specified physical perspective.
auto match_out_axis_app =
vmap_post_fg->NewCNode({NewValueNode(std::make_shared<prim::VmapMatchOutAxis>("VmapMatchOutAxis")), vmap_outputs,
NewValueNode(out_axes), NewValueNode(static_cast<int64_t>(axis_size))});
if (u_monad_node != nullptr) {
auto depend_prim = NewValueNode(prim::kPrimDepend);
auto state_depend = vmap_post_fg->NewCNode({depend_prim, match_out_axis_app, u_monad_node});
state_depend->set_abstract(match_out_axis_app->abstract());
vmap_post_fg->set_output(state_depend);
if (param_mapping_table == nullptr || param_mapping_table->empty()) {
auto output = match_out_axis_app;
if (u_monad_node) {
output = AttachToOutput(vmap_post_fg, output, u_monad_node);
}
if (io_monad_node) {
output = AttachToOutput(vmap_post_fg, output, io_monad_node);
}
vmap_post_fg->set_output(output);
return NewValueNode(vmap_post_fg);
}
vmap_post_fg->set_output(match_out_axis_app);
return NewValueNode(vmap_post_fg);
// Feed parameters back to each cell in the model ensembling parallel training case.
return FeedBackParam(vmap_post_fg, u_monad_node, io_monad_node, match_out_axis_app, param_mapping_table);
}
AnfNodePtr GetVmapRule(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resource, int axis_size) {
@ -392,13 +578,19 @@ AnfNodePtr CopyNodeToVmap(const AnfNodePtr &node, const FuncGraphPtr &func_graph
return node;
}
void BindNoneAxis(const AnfNodePtr &node, const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng) {
void BindFvAxis(const AnfNodePtr &node, const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
const AnfNodePtr &stacked_param_node = nullptr) {
MS_EXCEPTION_IF_NULL(node);
auto &node_user_map = mng->node_users();
auto user = node_user_map.find(node);
if (user != node_user_map.end() && !user->second.empty()) {
auto make_tuple = NewValueNode(prim::kPrimMakeTuple);
auto replace_node = func_graph->NewCNode({make_tuple, node, NewValueNode(kNone)});
CNodePtr replace_node = nullptr;
if (stacked_param_node == nullptr) {
replace_node = func_graph->NewCNode({make_tuple, node, NewValueNode(kNone)});
} else {
replace_node = func_graph->NewCNode({make_tuple, stacked_param_node, NewValueNode(SizeToLong(0))});
}
auto user_set = user->second;
for (auto pair : user_set) {
if (pair.first->func_graph() == func_graph) {
@ -409,13 +601,36 @@ void BindNoneAxis(const AnfNodePtr &node, const FuncGraphPtr &func_graph, const
}
}
void BindParamAxis(const AnfNodePtr &node, const FuncGraphPtr &vmap_fg, const FuncGraphManagerPtr &manager,
mindspore::HashMap<std::string, ParameterPtr> *stacked_params) {
if (stacked_params == nullptr || stacked_params->empty()) {
BindFvAxis(node, vmap_fg, manager);
return;
}
std::string param_name = dyn_cast<Parameter>(node)->name();
std::regex e("^.*?\\d+\\.(.+)$");
param_name = std::regex_replace(param_name, e, "vmap.$1");
auto iter = stacked_params->find(param_name);
if (iter != stacked_params->end()) {
ParameterPtr stacked_param_node = iter->second;
MS_EXCEPTION_IF_NULL(stacked_param_node);
BindFvAxis(node, vmap_fg, manager, stacked_param_node);
} else {
BindFvAxis(node, vmap_fg, manager);
}
}
void ExpandVmapValueNode(const FuncGraphPtr &vmap_fg, const pipeline::ResourceBasePtr &resource,
mindspore::HashSet<FuncGraphPtr> *visited_graph, mindspore::HashSet<AnfNodePtr> *visited_node,
int axis_size) {
VisitedHashSetPair *visited_pair, int axis_size,
mindspore::HashMap<std::string, ParameterPtr> *stacked_params) {
// Map ValueNode.
auto manager = resource->manager();
MS_EXCEPTION_IF_NULL(manager);
auto value_nodes = vmap_fg->value_nodes();
auto visited_graph = &visited_pair->first;
auto visited_node = &visited_pair->second;
for (const auto &value_pair : value_nodes) {
auto node = value_pair.first;
// ValueNode may have been transformed when other graphs are expanded.
@ -432,7 +647,7 @@ void ExpandVmapValueNode(const FuncGraphPtr &vmap_fg, const pipeline::ResourceBa
continue;
}
(void)visited_graph->insert(sub_func_graph);
auto transformed_fg = ExpandVmapFunctor(sub_func_graph, resource, visited_graph, visited_node, axis_size);
auto transformed_fg = ExpandVmapFunctor(sub_func_graph, resource, axis_size, visited_pair, stacked_params);
auto replace_node = NewValueNode(transformed_fg);
(void)visited_node->insert(replace_node);
(void)manager->Replace(node, replace_node);
@ -461,28 +676,34 @@ void ExpandVmapValueNode(const FuncGraphPtr &vmap_fg, const pipeline::ResourceBa
}
void ExpandVmapFreeVariable(const FuncGraphPtr &vmap_fg, const FuncGraphManagerPtr &manager,
const mindspore::HashSet<AnfNodePtr> &visited_node) {
const mindspore::HashSet<AnfNodePtr> &visited_node,
mindspore::HashMap<std::string, ParameterPtr> *stacked_params) {
// Map free variable.
auto free_variables_nodes = vmap_fg->free_variables_nodes();
for (auto &node : free_variables_nodes) {
auto free_variables_nodes = vmap_fg->free_variables();
for (auto &pair : free_variables_nodes) {
auto node = pair.first;
if (visited_node.count(node) > 0 || node->isa<CNode>()) {
MS_LOG(DEBUG) << node->DebugString() << " has been transformed.";
} else if (node->isa<Parameter>() || IsValueNode<Scalar>(node) || IsValueNode<tensor::Tensor>(node) ||
IsValueNode<None>(node) || IsValueNode<ValueTuple>(node) || IsValueNode<Type>(node)) {
BindNoneAxis(node, vmap_fg, manager);
continue;
}
if (IsValueNode<Scalar>(node) || IsValueNode<tensor::Tensor>(node) || IsValueNode<None>(node) ||
IsValueNode<ValueTuple>(node) || IsValueNode<Type>(node)) {
BindFvAxis(node, vmap_fg, manager);
} else if (node->isa<Parameter>()) {
BindParamAxis(node, vmap_fg, manager, stacked_params);
} else {
MS_LOG(EXCEPTION) << "vmap do not support transform " << node->DebugString() << " right now.";
}
}
}
FuncGraphPtr ExpandVmapFunctor(const FuncGraphPtr &vmap_fg, const pipeline::ResourceBasePtr &resource,
mindspore::HashSet<FuncGraphPtr> *visited_graph,
mindspore::HashSet<AnfNodePtr> *visited_node, int axis_size) {
FuncGraphPtr ExpandVmapFunctor(const FuncGraphPtr &vmap_fg, const pipeline::ResourceBasePtr &resource, int axis_size,
VisitedHashSetPair *visited_pair,
mindspore::HashMap<std::string, ParameterPtr> *stacked_params) {
MS_EXCEPTION_IF_NULL(vmap_fg);
auto manager = resource->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(vmap_fg);
auto visited_node = &visited_pair->second;
// The parameters of the current graph will be transformed in the upper graph, and are recorded in
// `visited_node` to avoid being repeatedly transformed when referred to as a free variable in another graph.
@ -492,14 +713,15 @@ FuncGraphPtr ExpandVmapFunctor(const FuncGraphPtr &vmap_fg, const pipeline::Reso
(void)visited_node->insert(node);
}
ExpandVmapValueNode(vmap_fg, resource, visited_graph, visited_node, axis_size);
ExpandVmapFreeVariable(vmap_fg, manager, *visited_node);
ExpandVmapValueNode(vmap_fg, resource, visited_pair, axis_size, stacked_params);
ExpandVmapFreeVariable(vmap_fg, manager, *visited_node, stacked_params);
return vmap_fg;
}
// Entry to perform Vmap transformation.
AnfNodePtr ExpandVmap(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource, int axis_size) {
AnfNodePtr ExpandVmap(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource, int axis_size,
mindspore::HashMap<std::string, ParameterPtr> *stacked_params) {
MS_EXCEPTION_IF_NULL(vnode);
if (IsValueNode<FuncGraph>(vnode)) {
ScopeGuard scope_guard(vnode->scope());
@ -514,7 +736,8 @@ AnfNodePtr ExpandVmap(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr
(void)visited_graph.insert(func_graph);
(void)visited_node.insert(vnode);
auto tf_fg = ExpandVmapFunctor(func_graph, resource, &visited_graph, &visited_node, axis_size);
VisitedHashSetPair visited_pair(visited_graph, visited_node);
auto tf_fg = ExpandVmapFunctor(func_graph, resource, axis_size, &visited_pair, stacked_params);
visited_node.clear();
return NewValueNode(tf_fg);
@ -522,8 +745,150 @@ AnfNodePtr ExpandVmap(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr
MS_LOG(EXCEPTION) << "Currently, the first argument in F.vmap only supports Cell, Python defined "
"function or @ms_function decorated function.";
}
std::string GetShapeString(const ShapeVector &tensor_shape) {
std::ostringstream oss;
oss << " Shape:";
for (auto &dim : tensor_shape) {
oss << " " << dim;
}
return oss.str();
}
void GenerateStackedParams(const FuncGraphPtr vmap_fg, size_t cell_size,
const std::vector<std::vector<AnfNodePtr>> &param_table,
mindspore::HashMap<std::string, ParameterPtr> *stacked_params,
ParamMappingVector *param_mapping_table) {
MS_EXCEPTION_IF_NULL(vmap_fg);
FuncGraphPtr top_fg = vmap_fg;
while (top_fg->parent() != nullptr) {
top_fg = top_fg->parent();
}
ShapeVector tensor_shape;
TypeId tensor_type = kNumberTypeFloat32;
std::string param_name = "";
for (size_t i = 0; i < param_table[0].size(); ++i) {
std::vector<AnfNodePtr> param;
for (size_t j = 0; j < param_table.size(); ++j) {
auto param_node = dyn_cast<Parameter>(param_table[j][i]);
MS_EXCEPTION_IF_NULL(param_node);
(void)param.emplace_back(param_node);
auto default_param = param_node->default_param();
MS_EXCEPTION_IF_NULL(default_param);
auto param_tensor = dyn_cast<tensor::Tensor>(default_param);
MS_EXCEPTION_IF_NULL(param_tensor);
if (j == 0) {
tensor_shape = param_tensor->shape();
tensor_type = param_tensor->data_type();
param_name = param_node->name();
} else {
if (tensor_type != param_tensor->data_type()) {
MS_LOG(EXCEPTION) << "The corresponding parameter's type in each cell should be consistent, but get "
<< TypeIdToType(tensor_type)->ToString() << " and "
<< TypeIdToType(param_tensor->data_type())->ToString() << " for the parameter "
<< param_name << ".";
}
if (tensor_shape != param_tensor->shape()) {
MS_LOG(EXCEPTION) << "The corresponding parameter's shape in each cell should be consistent, but get "
<< GetShapeString(tensor_shape) << " and " << GetShapeString(param_tensor->shape())
<< " for the parameter " << param_name << ".";
}
}
}
std::regex e("^.*?0\\.(.+)$");
param_name = std::regex_replace(param_name, e, "vmap.$1");
ParameterPtr param_node = nullptr;
ShapeVector stacked_shape(tensor_shape);
(void)stacked_shape.insert(stacked_shape.begin(), cell_size);
tensor::TensorPtr stacked_param_tensor = std::make_shared<tensor::Tensor>(tensor_type, stacked_shape);
MS_EXCEPTION_IF_NULL(stacked_param_tensor);
ParamInfoPtr param_info = std::make_shared<ParamInfo>();
param_info->set_name(param_name);
stacked_param_tensor->set_param_info(param_info);
param_node = top_fg->AddFvParameter(param_name, stacked_param_tensor);
MS_LOG(DEBUG) << "Add new parameter " << param_node->ToString() << "to the top graph " << top_fg->ToString() << ".";
(*stacked_params)[param_name] = param_node;
std::pair<ParameterPtr, std::vector<AnfNodePtr>> param_mapping(param_node, param);
(void)param_mapping_table->emplace_back(param_mapping);
}
}
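Two conventions from GenerateStackedParams are worth spelling out: the per-cell index prefix of a parameter name is rewritten to a shared 'vmap.' name, and the stacked tensor prepends cell_size to the original shape. A short Python check of both (re.sub mirrors the std::regex used above):

import re

def stacked_param_name(param_name):
    """Rewrite a cell-indexed name, e.g. 'cells.0.weight' -> 'vmap.weight'."""
    return re.sub(r'^.*?0\.(.+)$', r'vmap.\1', param_name)

def stacked_param_shape(tensor_shape, cell_size):
    """Prepend the cell dimension to the per-cell parameter shape."""
    return [cell_size] + list(tensor_shape)

print(stacked_param_name('cells.0.weight'))      # vmap.weight
print(stacked_param_shape([3, 4], cell_size=3))  # [3, 3, 4]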
void GetCellParams(const FuncGraphPtr &vmap_fg, std::vector<AnfNodePtr> *param_nodes) {
std::set<AnfNodePtr> memo;
auto scan_fn = [&memo, param_nodes](const FuncGraphPtr &vmap_fg) {
auto fv_nodes = vmap_fg->free_variables();
for (auto &pair : fv_nodes) {
auto node = pair.first;
if (node->isa<Parameter>() && node->cast<ParameterPtr>()->has_default() && memo.emplace(node).second) {
(void)param_nodes->emplace_back(node);
}
}
};
scan_fn(vmap_fg);
auto used_fgs = vmap_fg->func_graphs_used_total();
for (auto &fg : used_fgs) {
scan_fn(fg);
}
}
AnfNodePtr TraverseVmapNode(CNodePtr vmap_node, size_t cell_size,
mindspore::HashMap<std::string, ParameterPtr> *stacked_params,
ParamMappingVector *param_mapping_table) {
AnfNodePtr vmap_fn_node = nullptr;
auto cell_list_node = vmap_node->input(1);
CNodePtr cnode = cell_list_node->cast<CNodePtr>();
auto inputs_size = cnode->size();
if (inputs_size != (cell_size + 1)) {
MS_EXCEPTION(ValueError) << "The size of CellList Node should be equal to" << (cell_size + 1) << ", but get"
<< inputs_size << ".";
}
std::vector<std::vector<AnfNodePtr>> param_table(cell_size, std::vector<AnfNodePtr>());
FuncGraphPtr vmap_fg = nullptr;
size_t param_size = 0;
for (size_t i = 1; i < inputs_size; i++) {
vmap_fn_node = cnode->input(i);
vmap_fg = GetValueNode<FuncGraphPtr>(vmap_fn_node);
MS_EXCEPTION_IF_NULL(vmap_fg);
GetCellParams(vmap_fg, &param_table[i - 1]);
if (param_size == 0) {
param_size = param_table[i - 1].size();
} else if (param_size != param_table[i - 1].size()) {
MS_EXCEPTION(ValueError) << "Parameter size of each cell should be consistent, but get " << param_size << " and "
<< param_table[i - 1].size() << ".";
}
}
GenerateStackedParams(vmap_fg, cell_size, param_table, stacked_params, param_mapping_table);
return vmap_fn_node;
}
} // namespace internal
bool ExpandVmapPrim::CheckIfEmbedMetaFgPrim(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
AnfNodePtr value_node = node->input(1);
if (IsPrimitiveCNode(value_node, prim::kPrimMakeTuple)) {
CNodePtr cnode = value_node->cast<CNodePtr>();
value_node = cnode->input(1);
}
if (IsValueNode<Primitive>(value_node)) {
return false;
}
auto func_graph = GetValueNode<FuncGraphPtr>(value_node);
if (func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Unexpected meta function graph node:" << node->DebugString();
}
auto func_graph_manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(func_graph_manager);
return func_graph_manager->func_graph_meta_fg_prim_total(func_graph);
}
bool ExpandVmapPrim::operator()(const FuncGraphPtr &, const OptimizerPtr &optimizer) {
// Expand vmap nodes that don't have embedded j or vmap nodes.
bool change = false;
@ -536,32 +901,41 @@ bool ExpandVmapPrim::operator()(const FuncGraphPtr &, const OptimizerPtr &optimi
MS_EXCEPTION_IF_NULL(in_axes);
ValuePtr out_axes = VmapPrim->GetAttr("out_axes");
MS_EXCEPTION_IF_NULL(out_axes);
ValuePtr cell_size_value = VmapPrim->GetAttr("cell_size");
MS_EXCEPTION_IF_NULL(cell_size_value);
auto cell_size = cell_size_value->isa<UInt64Imm>() ? dyn_cast<UInt64Imm>(cell_size_value)->value() : 0;
auto vmap_fn_node = vmap_node->input(1);
auto vmap_fg = GetValueNode<FuncGraphPtr>(vmap_fn_node);
auto &fn_users = manager->node_users()[vmap_fn_node];
size_t fn_users_size = fn_users.size();
AnfNodePtr vmap_fn_node = nullptr;
mindspore::HashMap<std::string, ParameterPtr> stacked_params;
// Record the stacked parameters and the corresponding original parameters from each cell, preserved
// for later feedback.
ParamMappingVector param_mapping_table;
if (cell_size > 0) {
// This branch handles the model ensembling parallel training case. Get one function node in the 'CellList'
// as the vmap function; meanwhile, preprocess the cells' parameters to get the stacked parameters and
// the parameter mapping table.
vmap_fn_node = internal::TraverseVmapNode(vmap_node, cell_size, &stacked_params, &param_mapping_table);
} else {
vmap_fn_node = vmap_node->input(1);
}
MS_EXCEPTION_IF_NULL(vmap_fn_node);
FuncGraphPtr vmap_fg = GetValueNode<FuncGraphPtr>(vmap_fn_node);
auto users = manager->node_users()[vmap_node];
if (users.size() < 1) {
MS_LOG(EXCEPTION) << "vmap_node could used by at least one CNode, but got users.size() = " << users.size() << ".";
MS_EXCEPTION(ValueError) << "vmap_node could used by at least one CNode, but got users.size() = " << users.size()
<< ".";
}
size_t user_nb = 0;
size_t user_size = users.size();
for (auto &user : users) {
user_nb++;
for (auto &user : users) {
// When `vmap_node` has more than one user or `fn` has more than one user, the original function graph
// cannot be modified directly.
if ((user_size > 1 && user_nb != user_size) || fn_users_size > 1) {
MS_LOG(DEBUG) << "Funcgraph: " << vmap_fg->ToString() << " is also used outside the scope of vmap.";
auto vmap_fg_copy = BasicClone(vmap_fg, true);
auto manager_ptr = optimizer->resource()->manager();
manager_ptr->AddFuncGraph(vmap_fg_copy);
vmap_fn_node = NewValueNode(vmap_fg_copy);
} else {
vmap_fn_node = NewValueNode(vmap_fg);
}
MS_LOG(DEBUG) << "Funcgraph: " << vmap_fg->ToString() << " is also used outside the scope of vmap.";
auto vmap_fg_copy = BasicClone(vmap_fg, true);
manager->AddFuncGraph(vmap_fg_copy);
vmap_fn_node = NewValueNode(vmap_fg_copy);
if (parallel::IsPynativeParallel()) {
auto func_graph = GetValueNode<FuncGraphPtr>(vmap_fn_node);
@ -571,26 +945,34 @@ bool ExpandVmapPrim::operator()(const FuncGraphPtr &, const OptimizerPtr &optimi
// get axis size, simultaneous correction the negative in_axes.
auto vmap_app = user.first->cast<CNodePtr>();
int user_index = user.second;
int parameters_size = SizeToInt(vmap_app->size() - 1);
int axis_size = internal::GetAxisSize(vmap_app, &in_axes);
if (axis_size == kInvalidAxisSize) {
MS_LOG(EXCEPTION) << "Failed to get 'axis_size' within the scope of vmap.";
if (vmap_app->size() < 1) {
MS_LOG(EXCEPTION) << "Something went wrong, CNode vmap_app's arguments is less than 1, CNode: "
<< vmap_app->DebugString();
}
MS_LOG(DEBUG) << "The axis size corresponding to the current level vmap scope is " << axis_size << ".";
size_t parameters_size = vmap_app->size() - 1;
std::vector<size_t> orig_fg_param_info;
(void)orig_fg_param_info.emplace_back(parameters_size);
int axis_size = internal::GetAxisSize(vmap_app, cell_size, parameters_size, &in_axes);
// Step1: Bind the inputs with the corresponding in_axes.
size_t u_monad_offset = 0;
auto bind_axes_node = internal::BindInAxis(vmap_app, in_axes, &u_monad_offset);
size_t io_monad_offset = 0;
auto bind_axes_node =
internal::BindInAxis(vmap_app, in_axes, &u_monad_offset, &io_monad_offset, &param_mapping_table);
MS_EXCEPTION_IF_NULL(bind_axes_node);
(void)manager->Replace(vmap_app, bind_axes_node);
(void)orig_fg_param_info.emplace_back(u_monad_offset);
(void)orig_fg_param_info.emplace_back(io_monad_offset);
// Step2: Bind the variables with the corresponding axis, and overload the original
// operation with the VmapRule operation meanwhile transfer the axis information.
auto expanded_vmap = internal::ExpandVmap(vmap_fn_node->cast<ValueNodePtr>(), optimizer->resource(), axis_size);
auto expanded_vmap =
internal::ExpandVmap(vmap_fn_node->cast<ValueNodePtr>(), optimizer->resource(), axis_size, &stacked_params);
MS_EXCEPTION_IF_NULL(expanded_vmap);
// Step3: Convert the outputs according to the out_axes to the specified physical perspective.
auto match_out_axis = internal::MatchOutAxis(expanded_vmap, parameters_size, u_monad_offset, axis_size, out_axes);
// Step3: Post processing of converted vmap function graph, including: MatchOutAxis and Parameter feedback.
auto match_out_axis =
internal::PostProcessVmap(expanded_vmap, orig_fg_param_info, out_axes, axis_size, &param_mapping_table);
MS_EXCEPTION_IF_NULL(match_out_axis);
manager->SetEdge(bind_axes_node, user_index, match_out_axis);
}

View File

@ -18,6 +18,9 @@
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VMAP_ELIMINATE_H_
#include <memory>
#include <utility>
#include <string>
#include <vector>
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
@ -31,20 +34,28 @@
namespace mindspore {
namespace opt {
namespace irpass {
using ParamMappingVector = std::vector<std::pair<ParameterPtr, std::vector<AnfNodePtr>>>;
// {prim::kPrimVmap, C}
class ExpandVmapPrim : public ExpandMetaFgPrim {
public:
ExpandVmapPrim() { prim_ = prim::kPrimVmap; }
virtual ~ExpandVmapPrim() = default;
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) override;
bool CheckIfEmbedMetaFgPrim(const CNodePtr &node) const override;
};
using ExpandVmapPrimPtr = std::shared_ptr<ExpandVmapPrim>;
namespace internal {
constexpr int64_t kParamSizeIndex = 0;
constexpr int64_t kUMonadOffsetIndex = 1;
constexpr int64_t kIOMonadOffsetIndex = 2;
using VisitedHashSetPair = std::pair<mindspore::HashSet<FuncGraphPtr>, mindspore::HashSet<AnfNodePtr>>;
constexpr char kVmapFunctionModelName[] = "mindspore.ops._vmap";
int GetAxisSizeByAbs(const AbstractBasePtr &abs, ValuePtr *const in_axes);
FuncGraphPtr ExpandVmapFunctor(const FuncGraphPtr &vmap_fg, const pipeline::ResourceBasePtr &resource,
mindspore::HashSet<FuncGraphPtr> *visited_graph,
mindspore::HashSet<AnfNodePtr> *visited_node, int axis_size);
FuncGraphPtr ExpandVmapFunctor(const FuncGraphPtr &vmap_fg, const pipeline::ResourceBasePtr &resource, int axis_size,
VisitedHashSetPair *visited_pair,
mindspore::HashMap<std::string, ParameterPtr> *stacked_params = nullptr);
} // namespace internal
} // namespace irpass
} // namespace opt

View File

@ -742,6 +742,17 @@ class SideEffectFinder {
return TraceTupleGetItemEffectInfo(cnode, &tuple_indexes);
}
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
// Trace make_tuple.
const auto &inputs = cnode->inputs();
EffectInfo info{EffectInfo::kDetected, false, false, false};
for (size_t i = 1; i < inputs.size(); ++i) {
auto input_info = TraceEffectInfo(inputs[i]);
info.Merge(input_info);
}
return info;
}
// For high-order pritimive such as Partial,
// we trace effect info from its argument.
int index_prim = GetSideEffectPropagate(prim);

View File

@ -19,7 +19,7 @@ from __future__ import absolute_import
from . import vmap_base, vmap_array_ops, vmap_grad_nn_ops, vmap_debug_ops, vmap_math_ops, vmap_nn_ops,\
vmap_image_ops, vmap_other_ops, vmap_sparse_ops, vmap_random_ops, vmap_convolution_ops, vmap_grad_math_ops
from .vmap_base import get_vmap_rule, vmap_monad_rule, _broadcast_by_axis, vmap_bind_all_none,\
vmap_unstack, vmap_general_output_process
vmap_unstack, vmap_stack, vmap_general_output_process
__all__ = ['get_vmap_rule', 'vmap_monad_rule', '_broadcast_by_axis', 'vmap_bind_all_none',
'vmap_unstack', 'vmap_general_output_process']
'vmap_unstack', 'vmap_stack', 'vmap_general_output_process']

View File

@ -190,6 +190,10 @@ def vmap_unstack(dim, val):
return P.Unstack(dim)(val)
def vmap_stack(val):
return P.Stack()(val)
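For reference, vmap_stack is the companion of the existing vmap_unstack: it stacks along a new axis 0, so unstacking on axis 0 recovers the original tensors. A minimal round-trip sketch (assumes an installed MindSpore build that includes this change):

import numpy as np
from mindspore import Tensor
from mindspore.ops._vmap.vmap_base import vmap_stack, vmap_unstack

stacked = vmap_stack((Tensor(np.ones(3)), Tensor(np.zeros(3))))  # shape (2, 3)
parts = vmap_unstack(0, stacked)  # tuple of the two original tensors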
def vmap_general_output_process(output):
""" Match output to axis 0"""
vals_out_tuple = ()

View File

@ -14,6 +14,7 @@
# ============================================================================
"""sparse_ops vmap impl."""
from __future__ import absolute_import
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error
from mindspore.ops.primitive import Primitive

View File

@ -601,8 +601,13 @@ class _Vmap(VmapOperation_):
VmapOperation_.__init__(self, 'vmap')
self.vmap_fn = None
self.fn = None
self.in_axes = None
self.out_axes = None
def __call__(self, fn, in_axes=0, out_axes=0):
if self.vmap_fn is not None and self.fn == fn and self.in_axes == in_axes and self.out_axes == out_axes:
return self.vmap_fn
vmap_ = self
@ms_function
@ -611,6 +616,8 @@ class _Vmap(VmapOperation_):
self.vmap_fn = after_vmap
self.fn = fn
self.in_axes = in_axes
self.out_axes = out_axes
return self.vmap_fn

View File

@ -389,3 +389,50 @@ def test_vmap_with_tuple_input():
assert isinstance(res[1], Tensor)
assert np.allclose(res[0].asnumpy(), np.array([[2, 2, 2], [2, 2, 2]]))
assert np.allclose(res[1].asnumpy(), np.array([[4, 4, 4], [4, 4, 4]]))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_with_celllist_input():
"""
Feature: vmap
Description: When vmap uses CellList inputs in graph mode, it executes the model ensembling parallel training scenario.
Expectation: success
"""
class AssignNet(nn.Cell):
def __init__(self):
super(AssignNet, self).__init__()
self.assign = P.Assign()
self.ref_a = Parameter(Tensor([0, 1, 2], mstype.float32), name='ref_a')
self.ref_b = Parameter(Tensor([0, 1, 2], mstype.float32), name='ref_b')
def construct(self, replace_tensor):
out = self.assign(self.ref_a, replace_tensor)
out = self.ref_b + out
return out
m1 = AssignNet()
m2 = AssignNet()
m3 = AssignNet()
mm = nn.CellList([m1, m2, m3])
replace_tensor = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float32)
output = F.vmap(mm, 0)(replace_tensor)
expect_res1 = Tensor([[1, 3, 5], [4, 6, 8], [7, 9, 11]], mstype.float32)
expect_res2 = Tensor([1, 2, 3], mstype.float32)
expect_res3 = Tensor([4, 5, 6], mstype.float32)
expect_res4 = Tensor([7, 8, 9], mstype.float32)
expect_res5 = Tensor([0, 1, 2], mstype.float32)
assert np.allclose(output.asnumpy(), expect_res1.asnumpy())
assert np.allclose(m1.ref_a.asnumpy(), expect_res2.asnumpy())
assert np.allclose(m2.ref_a.asnumpy(), expect_res3.asnumpy())
assert np.allclose(m3.ref_a.asnumpy(), expect_res4.asnumpy())
assert np.allclose(m1.ref_b.asnumpy(), expect_res5.asnumpy())
assert np.allclose(m2.ref_b.asnumpy(), expect_res5.asnumpy())
assert np.allclose(m3.ref_b.asnumpy(), expect_res5.asnumpy())

View File

@ -67,7 +67,7 @@ def test_none_in_axes():
z_hat = 1
with pytest.raises(RuntimeError) as ex:
vmap(ThreeInputsTwoOutputsNet(), in_axes=None, out_axes=0)(x_hat, y_hat, z_hat)
assert "The 'in_axes' of 'vmap' cannot be a single None." in str(ex.value)
assert "The 'in_axes' of 'vmap' cannot be a single None while 'fn' is not a 'CellList'." in str(ex.value)
def test_none_out_axes():
@ -83,7 +83,8 @@ def test_none_out_axes():
with pytest.raises(RuntimeError) as ex:
vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None),
out_axes=(None, None, None, (None, None)))(x_hat, y_hat, z_hat)
assert "The 'out_axes' of 'vmap' cannot be all None, but got (None, None, None, (None, None))." in str(ex.value)
assert "The 'out_axes' of 'vmap' cannot be all None while 'fn' is not a 'CellList', " \
"but got (None, None, None, (None, None))." in str(ex.value)
def test_mismatch_out_axes():