|
|
|
@ -18,7 +18,11 @@
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <regex>
|
|
|
|
|
#include "utils/hash_map.h"
|
|
|
|
|
#include "ir/func_graph_cloner.h"
|
|
|
|
|
#include "base/complex_storage.h"
|
|
|
|
|
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
|
|
|
|
#include "frontend/parallel/step_parallel_utils.h"
|
|
|
|
|
#include "pipeline/pynative/pynative_execute.h"
|
|
|
|
@ -33,43 +37,123 @@ namespace internal {
|
|
|
|
|
const mindspore::HashSet<std::string> throughtout_op{prim::kPrimMakeTuple->name(), prim::kPrimMakeList->name(),
|
|
|
|
|
prim::kPrimDepend->name(), prim::kPrimReturn->name(),
|
|
|
|
|
prim::kPrimUpdateState->name(), prim::kPrimStopGradient->name()};
|
|
|
|
|
CNodePtr BuildBindInAxisTupleInput(const AnfNodePtr &input, const ValuePtr &in_axis, const FuncGraphPtr &fg) {
|
|
|
|
|
auto input_abs_elements = dyn_cast<abstract::AbstractTuple>(input->abstract());
|
|
|
|
|
CNodePtr BuildBindInAxisSeqInput(const AnfNodePtr &input, const ValuePtr &in_axis, const FuncGraphPtr &fg) {
|
|
|
|
|
auto input_abs = input->abstract();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_abs);
|
|
|
|
|
auto input_abs_elements = dyn_cast<abstract::AbstractSequence>(input_abs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_abs_elements);
|
|
|
|
|
ValueSequencePtr in_axis_value_sequence = nullptr;
|
|
|
|
|
if (in_axis->isa<ValueSequence>()) {
|
|
|
|
|
in_axis_value_sequence = dyn_cast<ValueSequence>(in_axis);
|
|
|
|
|
if (input_abs_elements->size() != in_axis_value_sequence->size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The length of input and in_axis should be the same but got input length: "
|
|
|
|
|
<< input_abs_elements->size() << ", in_axis length: " << in_axis_value_sequence->size() << ".";
|
|
|
|
|
MS_EXCEPTION(ValueError) << "The length of input and in_axis should be the same but got input length: "
|
|
|
|
|
<< input_abs_elements->size() << ", in_axis length: " << in_axis_value_sequence->size()
|
|
|
|
|
<< ".";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> ret_inputs;
|
|
|
|
|
(void)ret_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
|
|
|
|
if (input_abs->isa<abstract::AbstractList>()) {
|
|
|
|
|
(void)ret_inputs.emplace_back(NewValueNode(prim::kPrimMakeList));
|
|
|
|
|
} else {
|
|
|
|
|
(void)ret_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (unsigned int i = 0; i < input_abs_elements->size(); ++i) {
|
|
|
|
|
std::vector<AnfNodePtr> tuple_getitem_cnode_inputs;
|
|
|
|
|
(void)tuple_getitem_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
|
|
|
|
|
(void)tuple_getitem_cnode_inputs.emplace_back(input);
|
|
|
|
|
(void)tuple_getitem_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(i)));
|
|
|
|
|
auto tuple_getitem_cnode = fg->NewCNode(tuple_getitem_cnode_inputs);
|
|
|
|
|
std::vector<AnfNodePtr> seq_getitem_cnode_inputs;
|
|
|
|
|
if (input_abs->isa<abstract::AbstractList>()) {
|
|
|
|
|
(void)seq_getitem_cnode_inputs.emplace_back(NewValueNode(prim::kPrimListGetItem));
|
|
|
|
|
} else {
|
|
|
|
|
(void)seq_getitem_cnode_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
|
|
|
|
|
}
|
|
|
|
|
(void)seq_getitem_cnode_inputs.emplace_back(input);
|
|
|
|
|
(void)seq_getitem_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(i)));
|
|
|
|
|
auto seq_getitem_cnode = fg->NewCNode(seq_getitem_cnode_inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(seq_getitem_cnode);
|
|
|
|
|
auto input_abs_element = (*input_abs_elements)[i];
|
|
|
|
|
auto in_axis_value = in_axis_value_sequence == nullptr ? in_axis : (*in_axis_value_sequence)[i];
|
|
|
|
|
CNodePtr cur_make_tuple = nullptr;
|
|
|
|
|
if (input_abs_element->isa<abstract::AbstractTuple>()) {
|
|
|
|
|
tuple_getitem_cnode->set_abstract(input_abs_element);
|
|
|
|
|
cur_make_tuple = BuildBindInAxisTupleInput(tuple_getitem_cnode, in_axis_value, fg);
|
|
|
|
|
CNodePtr cur_make_seq = nullptr;
|
|
|
|
|
if (input_abs_element->isa<abstract::AbstractSequence>()) {
|
|
|
|
|
seq_getitem_cnode->set_abstract(input_abs_element);
|
|
|
|
|
cur_make_seq = BuildBindInAxisSeqInput(seq_getitem_cnode, in_axis_value, fg);
|
|
|
|
|
} else {
|
|
|
|
|
std::vector<AnfNodePtr> cur_make_tuple_inputs;
|
|
|
|
|
(void)cur_make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
|
|
|
|
(void)cur_make_tuple_inputs.emplace_back(tuple_getitem_cnode);
|
|
|
|
|
(void)cur_make_tuple_inputs.emplace_back(seq_getitem_cnode);
|
|
|
|
|
(void)cur_make_tuple_inputs.emplace_back(NewValueNode(in_axis_value));
|
|
|
|
|
cur_make_tuple = fg->NewCNode(cur_make_tuple_inputs);
|
|
|
|
|
cur_make_seq = fg->NewCNode(cur_make_tuple_inputs);
|
|
|
|
|
}
|
|
|
|
|
(void)ret_inputs.emplace_back(cur_make_tuple);
|
|
|
|
|
(void)ret_inputs.emplace_back(cur_make_seq);
|
|
|
|
|
}
|
|
|
|
|
return fg->NewCNode(ret_inputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr BindInAxis(const CNodePtr &vmap_app, const ValuePtr &in_axes, size_t *u_monad_offset) {
|
|
|
|
|
AnfNodePtr UpdateParam(const FuncGraphPtr &vmap_fg, const AnfNodePtr &u_monad_node,
|
|
|
|
|
ParamMappingVector *param_mapping_table) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(u_monad_node);
|
|
|
|
|
auto update_state_prim = NewValueNode(prim::kPrimUpdateState);
|
|
|
|
|
auto load_prim = NewValueNode(prim::kPrimLoad);
|
|
|
|
|
std::vector<AnfNodePtr> attach_tuple{NewValueNode(prim::kPrimMakeTuple)};
|
|
|
|
|
for (auto param_pair : *param_mapping_table) {
|
|
|
|
|
auto ref = param_pair.first;
|
|
|
|
|
auto each_cell_params = param_pair.second;
|
|
|
|
|
std::vector<AnfNodePtr> param_tuple{NewValueNode(prim::kPrimMakeTuple)};
|
|
|
|
|
for (auto param : each_cell_params) {
|
|
|
|
|
auto load_cnode = vmap_fg->NewCNode({load_prim, param, u_monad_node});
|
|
|
|
|
MS_EXCEPTION_IF_NULL(load_cnode);
|
|
|
|
|
auto param_abs = dyn_cast<abstract::AbstractRefTensor>(param->abstract());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(param_abs);
|
|
|
|
|
load_cnode->set_abstract(param_abs->CloneAsTensor());
|
|
|
|
|
param_tuple.push_back(load_cnode);
|
|
|
|
|
}
|
|
|
|
|
auto param_tuple_cnode = vmap_fg->NewCNode(param_tuple);
|
|
|
|
|
auto update_state_after_load = vmap_fg->NewCNode({update_state_prim, u_monad_node, param_tuple_cnode});
|
|
|
|
|
const py::function stack_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_stack");
|
|
|
|
|
auto stack_fg = parse::ParsePythonCode(stack_fn);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(stack_fg);
|
|
|
|
|
auto stack_cnode = vmap_fg->NewCNode({NewValueNode(stack_fg), param_tuple_cnode});
|
|
|
|
|
auto assign_prim = NewValueNode(prim::kPrimAssign);
|
|
|
|
|
auto assign_cnode = vmap_fg->NewCNode({assign_prim, ref, stack_cnode, update_state_after_load});
|
|
|
|
|
attach_tuple.push_back(assign_cnode);
|
|
|
|
|
}
|
|
|
|
|
auto attach_cnode = vmap_fg->NewCNode(attach_tuple);
|
|
|
|
|
auto update_state_node = vmap_fg->NewCNode({update_state_prim, u_monad_node, attach_cnode});
|
|
|
|
|
return update_state_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GetMonadOffset(const std::vector<AnfNodePtr> &inputs, size_t *u_monad_offset, size_t *io_monad_offset) {
|
|
|
|
|
// Check the last two (if exists) is monad input.
|
|
|
|
|
if (*u_monad_offset != 0 || *io_monad_offset != 0) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << "The initial value of u_monad_offset and io_monad_offset should be 0, but we got "
|
|
|
|
|
<< "u_monad_offset: " << *u_monad_offset << " and io_monad_offset: " << *io_monad_offset
|
|
|
|
|
<< ".";
|
|
|
|
|
}
|
|
|
|
|
auto inputs_size = inputs.size();
|
|
|
|
|
constexpr size_t max_monad_input_num = 2;
|
|
|
|
|
if (HasAbstractMonad(inputs[inputs_size - 1])) {
|
|
|
|
|
if (HasAbstractUMonad(inputs[inputs_size - 1])) {
|
|
|
|
|
*u_monad_offset = 1;
|
|
|
|
|
} else if (inputs_size >= max_monad_input_num && HasAbstractUMonad(inputs[inputs_size - max_monad_input_num])) {
|
|
|
|
|
++(*io_monad_offset);
|
|
|
|
|
*u_monad_offset = *io_monad_offset + 1;
|
|
|
|
|
} else {
|
|
|
|
|
++(*io_monad_offset);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BindUMonad(const AnfNodePtr &u_monad_node, const FuncGraphPtr &vmap_fg, std::vector<AnfNodePtr> *outputs,
|
|
|
|
|
ParamMappingVector *param_mapping_table) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(u_monad_node);
|
|
|
|
|
if (param_mapping_table == nullptr || param_mapping_table->empty()) {
|
|
|
|
|
(void)outputs->emplace_back(u_monad_node);
|
|
|
|
|
} else {
|
|
|
|
|
auto update_state_node = UpdateParam(vmap_fg, u_monad_node, param_mapping_table);
|
|
|
|
|
(void)outputs->emplace_back(update_state_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr BindInAxis(const CNodePtr &vmap_app, const ValuePtr &in_axes, size_t *u_monad_offset,
|
|
|
|
|
size_t *io_monad_offset, ParamMappingVector *param_mapping_table) {
|
|
|
|
|
FuncGraphPtr vmap_fg = vmap_app->func_graph();
|
|
|
|
|
bool is_in_axes_value_sequence = in_axes->isa<ValueSequence>();
|
|
|
|
|
ValueSequencePtr in_axes_to_value_sequence = dyn_cast<ValueSequence>(in_axes);
|
|
|
|
@ -77,30 +161,17 @@ AnfNodePtr BindInAxis(const CNodePtr &vmap_app, const ValuePtr &in_axes, size_t
|
|
|
|
|
auto inputs = vmap_app->inputs();
|
|
|
|
|
auto inputs_size = inputs.size();
|
|
|
|
|
if (inputs_size == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The inputs number of CNode: " << vmap_app->DebugString()
|
|
|
|
|
<< " should be positive but got : " << inputs_size << ".";
|
|
|
|
|
MS_EXCEPTION(ValueError) << "The inputs number of CNode: " << vmap_app->DebugString()
|
|
|
|
|
<< " should be positive but got : " << inputs_size << ".";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Check the last two (if exists) is monad input.
|
|
|
|
|
size_t io_monad_offset = 0;
|
|
|
|
|
constexpr size_t max_monad_input_num = 2;
|
|
|
|
|
if (HasAbstractMonad(inputs[inputs_size - 1])) {
|
|
|
|
|
if (HasAbstractUMonad(inputs[inputs_size - 1])) {
|
|
|
|
|
*u_monad_offset = 1;
|
|
|
|
|
} else if (inputs_size >= max_monad_input_num && HasAbstractUMonad(inputs[inputs_size - max_monad_input_num])) {
|
|
|
|
|
io_monad_offset++;
|
|
|
|
|
*u_monad_offset = io_monad_offset + 1;
|
|
|
|
|
} else {
|
|
|
|
|
io_monad_offset++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t abstract_monad_count = *u_monad_offset > io_monad_offset ? *u_monad_offset : io_monad_offset;
|
|
|
|
|
auto real_params_size = inputs_size - abstract_monad_count;
|
|
|
|
|
GetMonadOffset(inputs, u_monad_offset, io_monad_offset);
|
|
|
|
|
size_t abstract_monad_count = *u_monad_offset > *io_monad_offset ? *u_monad_offset : *io_monad_offset;
|
|
|
|
|
size_t real_params_size = inputs_size > abstract_monad_count ? inputs_size - abstract_monad_count : 0;
|
|
|
|
|
if (is_in_axes_value_sequence && real_params_size - 1 != in_axes_to_value_sequence->size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The length of vmap_app inputs (except primitive input and monad input) is: "
|
|
|
|
|
<< real_params_size - 1 << " and the length of in_axis is: " << in_axes_to_value_sequence->size()
|
|
|
|
|
<< ". These two numbers should be equal.";
|
|
|
|
|
MS_EXCEPTION(ValueError) << "The length of vmap_app inputs (except primitive input and monad input) is: "
|
|
|
|
|
<< (real_params_size - 1)
|
|
|
|
|
<< " and the length of in_axis is: " << in_axes_to_value_sequence->size()
|
|
|
|
|
<< ". These two numbers should be equal.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> outputs;
|
|
|
|
@ -109,21 +180,24 @@ AnfNodePtr BindInAxis(const CNodePtr &vmap_app, const ValuePtr &in_axes, size_t
|
|
|
|
|
auto input = inputs[i];
|
|
|
|
|
auto in_axis = is_in_axes_value_sequence ? (*in_axes_to_value_sequence)[i - 1] : in_axes;
|
|
|
|
|
auto input_abs = input->abstract();
|
|
|
|
|
CNodePtr cur_make_tuple_cnode = nullptr;
|
|
|
|
|
if (input_abs->isa<abstract::AbstractTuple>()) {
|
|
|
|
|
cur_make_tuple_cnode = BuildBindInAxisTupleInput(input, in_axis, vmap_fg);
|
|
|
|
|
CNodePtr cur_make_seq_cnode = nullptr;
|
|
|
|
|
if (input_abs->isa<abstract::AbstractSequence>()) {
|
|
|
|
|
cur_make_seq_cnode = BuildBindInAxisSeqInput(input, in_axis, vmap_fg);
|
|
|
|
|
} else {
|
|
|
|
|
cur_make_tuple_cnode = vmap_fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), input, NewValueNode(in_axis)});
|
|
|
|
|
cur_make_seq_cnode = vmap_fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), input, NewValueNode(in_axis)});
|
|
|
|
|
}
|
|
|
|
|
(void)outputs.emplace_back(cur_make_tuple_cnode);
|
|
|
|
|
(void)outputs.emplace_back(cur_make_seq_cnode);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (abstract_monad_count == 1) {
|
|
|
|
|
(void)outputs.emplace_back(inputs.back());
|
|
|
|
|
} else if (IntToSize(abstract_monad_count) == max_monad_input_num) {
|
|
|
|
|
(void)outputs.emplace_back(inputs[inputs_size - max_monad_input_num]);
|
|
|
|
|
if (*u_monad_offset > 0 && inputs_size > *u_monad_offset) {
|
|
|
|
|
AnfNodePtr u_monad_node = inputs[inputs_size - *u_monad_offset];
|
|
|
|
|
BindUMonad(u_monad_node, vmap_fg, &outputs, param_mapping_table);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (*io_monad_offset > 0 && inputs_size > 0) {
|
|
|
|
|
(void)outputs.emplace_back(inputs.back());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return vmap_fg->NewCNode(outputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -165,8 +239,8 @@ void GetSubAxisSize(const AbstractBasePtr &sub_abs, ValuePtr *const sub_in_axes,
|
|
|
|
|
if (*axis_size == kInvalidAxisSize) {
|
|
|
|
|
*axis_size = sub_axis_size;
|
|
|
|
|
} else if (*axis_size != sub_axis_size) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The 'axis_size' of each argument in the scope of 'vmap' should be equal, but got "
|
|
|
|
|
<< *axis_size << " and " << sub_axis_size << ".";
|
|
|
|
|
MS_EXCEPTION(ValueError) << "The 'axis_size' of each argument in the scope of 'vmap' should be equal, but got "
|
|
|
|
|
<< *axis_size << " and " << sub_axis_size << ".";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -196,11 +270,29 @@ int GetAxisSizeByAbs(const AbstractBasePtr &abs, ValuePtr *const in_axes) {
|
|
|
|
|
auto in_axes_int = dyn_cast<Int64Imm>(*in_axes);
|
|
|
|
|
if (in_axes_int != nullptr) {
|
|
|
|
|
int64_t axis = in_axes_int->value();
|
|
|
|
|
ShapeVector orig_shape = dyn_cast<abstract::Shape>(abs->BuildShape())->shape();
|
|
|
|
|
if (!abs->isa<abstract::AbstractTensor>()) {
|
|
|
|
|
// If got a AbstractScalar with a value 0 of type int32, it means that the input is not used later.
|
|
|
|
|
auto abs_value = abs->BuildValue();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(abs_value);
|
|
|
|
|
auto abs_int32_t = dyn_cast<Int32Imm>(abs_value);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(abs_int32_t);
|
|
|
|
|
if (abs_int32_t->value() == 0) {
|
|
|
|
|
MS_LOG(WARNING) << "There is a argument not used in the scope of vmap. Please check whether the inputs"
|
|
|
|
|
<< " meet expectations.";
|
|
|
|
|
return axis_size;
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION(ValueError) << "The abs should be AbstractTensor when axis is " << axis << ", but got a "
|
|
|
|
|
<< abs->ToString() << ".";
|
|
|
|
|
}
|
|
|
|
|
auto shape = abs->BuildShape();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(shape);
|
|
|
|
|
auto shape_ptr = dyn_cast<abstract::Shape>(shape);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(shape_ptr);
|
|
|
|
|
ShapeVector orig_shape = shape_ptr->shape();
|
|
|
|
|
int64_t shape_len = SizeToLong(orig_shape.size());
|
|
|
|
|
if (axis < -shape_len || axis >= shape_len) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ValueError: axis " << axis << " is out of bounds for array of dimension [" << -shape_len
|
|
|
|
|
<< "," << shape_len << ").";
|
|
|
|
|
MS_EXCEPTION(ValueError) << "ValueError: axis " << axis << " is out of bounds for array of dimension ["
|
|
|
|
|
<< -shape_len << "," << shape_len << ").";
|
|
|
|
|
}
|
|
|
|
|
axis = axis < 0 ? shape_len + axis : axis;
|
|
|
|
|
*in_axes = std::make_shared<Int64Imm>(axis);
|
|
|
|
@ -212,11 +304,10 @@ int GetAxisSizeByAbs(const AbstractBasePtr &abs, ValuePtr *const in_axes) {
|
|
|
|
|
|
|
|
|
|
// get the axis size of currently vmap scope, at the same time, the negative indexes in in_axes are converted to
|
|
|
|
|
// corresponding positive indexes.
|
|
|
|
|
int GetAxisSize(const CNodePtr &cnode, ValuePtr *const in_axes) {
|
|
|
|
|
int GetAxisSize(const CNodePtr &cnode, size_t cell_size, size_t parameters_size, ValuePtr *const in_axes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
// `axis_size` is unique within the scope of vmap, so we just need to get one of them.
|
|
|
|
|
int axis_size = kInvalidAxisSize;
|
|
|
|
|
size_t parameters_size = cnode->size() - 1;
|
|
|
|
|
auto in_axes_seq = GetInAxesSeq(*in_axes, parameters_size);
|
|
|
|
|
std::vector<ValuePtr> corrected_in_axes;
|
|
|
|
|
for (size_t i = 0; i < parameters_size; ++i) {
|
|
|
|
@ -228,48 +319,143 @@ int GetAxisSize(const CNodePtr &cnode, ValuePtr *const in_axes) {
|
|
|
|
|
GetSubAxisSize(sub_abs, &sub_in_axes, &axis_size, &corrected_in_axes);
|
|
|
|
|
}
|
|
|
|
|
*in_axes = std::make_shared<ValueSequence>(corrected_in_axes);
|
|
|
|
|
|
|
|
|
|
if (cell_size > 0) {
|
|
|
|
|
if (axis_size == kInvalidAxisSize) {
|
|
|
|
|
axis_size = SizeToLong(cell_size);
|
|
|
|
|
} else if (SizeToLong(cell_size) != axis_size) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << "If you want to execute the model ensembling parallel training, please make sure "
|
|
|
|
|
<< "the 'axis_size' in the scope of vmap consistent with the cell size of the input "
|
|
|
|
|
<< "'CellList', otherwise, please do not enter 'CellList' as the first argument, "
|
|
|
|
|
<< "but we get axis_size: " << axis_size << " and the cell size: " << cell_size << ".";
|
|
|
|
|
}
|
|
|
|
|
} else if (axis_size == kInvalidAxisSize) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failed to get 'axis_size' within the scope of vmap.";
|
|
|
|
|
}
|
|
|
|
|
return axis_size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr MatchOutAxis(const AnfNodePtr &expanded_vmap_node, int parameters_size, size_t u_monad_offset, int axis_size,
|
|
|
|
|
const ValuePtr &out_axes) {
|
|
|
|
|
CNodePtr AttachToOutput(const FuncGraphPtr &func_graph, const CNodePtr &output, const AnfNodePtr &node) {
|
|
|
|
|
TraceGuard guard(std::make_shared<TraceCopy>(output->debug_info()));
|
|
|
|
|
auto depend = NewValueNode(prim::kPrimDepend);
|
|
|
|
|
auto depend_cnode = func_graph->NewCNode({depend, output, node});
|
|
|
|
|
MS_EXCEPTION_IF_NULL(depend_cnode);
|
|
|
|
|
depend_cnode->set_abstract(output->abstract());
|
|
|
|
|
return depend_cnode;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr FeedBackParam(const FuncGraphPtr &vmap_post_fg, const AnfNodePtr &u_monad_node,
|
|
|
|
|
const AnfNodePtr &io_monad_node, const CNodePtr &output,
|
|
|
|
|
ParamMappingVector *param_mapping_table) {
|
|
|
|
|
auto update_state_prim = NewValueNode(prim::kPrimUpdateState);
|
|
|
|
|
std::vector<AnfNodePtr> out_attach_tuple{NewValueNode(prim::kPrimMakeTuple)};
|
|
|
|
|
|
|
|
|
|
for (auto param_pair : *param_mapping_table) {
|
|
|
|
|
auto ref = param_pair.first;
|
|
|
|
|
auto load_prim = NewValueNode(prim::kPrimLoad);
|
|
|
|
|
auto load_cnode = vmap_post_fg->NewCNode({load_prim, ref, u_monad_node});
|
|
|
|
|
MS_EXCEPTION_IF_NULL(load_cnode);
|
|
|
|
|
auto ref_abs = dyn_cast<abstract::AbstractRefTensor>(ref->abstract());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ref_abs);
|
|
|
|
|
load_cnode->set_abstract(ref_abs->CloneAsTensor());
|
|
|
|
|
|
|
|
|
|
auto update_state_after_load = vmap_post_fg->NewCNode({update_state_prim, u_monad_node, load_cnode});
|
|
|
|
|
MS_EXCEPTION_IF_NULL(update_state_after_load);
|
|
|
|
|
update_state_after_load->set_abstract(u_monad_node->abstract());
|
|
|
|
|
|
|
|
|
|
PrimitivePtr kprim_unstack = std::make_shared<Primitive>(prim::kUnstack);
|
|
|
|
|
kprim_unstack->set_attr(kAttrAxis, MakeValue(SizeToLong(0)));
|
|
|
|
|
auto unstack_prim = NewValueNode(kprim_unstack);
|
|
|
|
|
auto unstack_cnode = vmap_post_fg->NewCNode({unstack_prim, load_cnode});
|
|
|
|
|
|
|
|
|
|
auto each_cell_params = param_pair.second;
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> attach_tuple{NewValueNode(prim::kPrimMakeTuple)};
|
|
|
|
|
|
|
|
|
|
int64_t cell_index = 0;
|
|
|
|
|
for (auto param : each_cell_params) {
|
|
|
|
|
auto tuple_getitem_prim = NewValueNode(prim::kPrimTupleGetItem);
|
|
|
|
|
auto tuple_getitem_cnode = vmap_post_fg->NewCNode({tuple_getitem_prim, unstack_cnode, NewValueNode(cell_index)});
|
|
|
|
|
cell_index++;
|
|
|
|
|
|
|
|
|
|
auto assign_prim = NewValueNode(prim::kPrimAssign);
|
|
|
|
|
auto assign_cnode = vmap_post_fg->NewCNode({assign_prim, param, tuple_getitem_cnode, update_state_after_load});
|
|
|
|
|
attach_tuple.push_back(assign_cnode);
|
|
|
|
|
}
|
|
|
|
|
auto attach_cnode = vmap_post_fg->NewCNode(attach_tuple);
|
|
|
|
|
auto update_state_after_assign = vmap_post_fg->NewCNode({update_state_prim, update_state_after_load, attach_cnode});
|
|
|
|
|
MS_EXCEPTION_IF_NULL(update_state_after_assign);
|
|
|
|
|
update_state_after_assign->set_abstract(update_state_after_load->abstract());
|
|
|
|
|
out_attach_tuple.push_back(update_state_after_assign);
|
|
|
|
|
}
|
|
|
|
|
auto out_attach_cnode = vmap_post_fg->NewCNode(out_attach_tuple);
|
|
|
|
|
auto attach_output = AttachToOutput(vmap_post_fg, output, out_attach_cnode);
|
|
|
|
|
if (io_monad_node) {
|
|
|
|
|
attach_output = AttachToOutput(vmap_post_fg, attach_output, io_monad_node);
|
|
|
|
|
}
|
|
|
|
|
vmap_post_fg->set_output(attach_output);
|
|
|
|
|
return NewValueNode(vmap_post_fg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr PostProcessVmap(const AnfNodePtr &expanded_vmap_node, const std::vector<size_t> &orig_fg_param_info,
|
|
|
|
|
const ValuePtr &out_axes, int axis_size, ParamMappingVector *param_mapping_table) {
|
|
|
|
|
FuncGraphPtr vmap_post_fg = std::make_shared<FuncGraph>();
|
|
|
|
|
std::vector<AnfNodePtr> exec_node;
|
|
|
|
|
exec_node.push_back(expanded_vmap_node);
|
|
|
|
|
AnfNodePtr u_monad_node = nullptr;
|
|
|
|
|
int offset = SizeToInt(u_monad_offset);
|
|
|
|
|
int u_monad_index = parameters_size > offset ? parameters_size - offset : parameters_size;
|
|
|
|
|
for (int i = 0; i < parameters_size; ++i) {
|
|
|
|
|
AnfNodePtr io_monad_node = nullptr;
|
|
|
|
|
size_t parameters_size = orig_fg_param_info[kParamSizeIndex];
|
|
|
|
|
size_t u_monad_offset = orig_fg_param_info[kUMonadOffsetIndex];
|
|
|
|
|
size_t io_monad_offset = orig_fg_param_info[kIOMonadOffsetIndex];
|
|
|
|
|
size_t u_monad_index = parameters_size > u_monad_offset ? parameters_size - u_monad_offset : parameters_size;
|
|
|
|
|
size_t io_monad_index = parameters_size > io_monad_offset ? parameters_size - io_monad_offset : parameters_size;
|
|
|
|
|
for (size_t i = 0; i < parameters_size; ++i) {
|
|
|
|
|
if (i == u_monad_index) {
|
|
|
|
|
u_monad_node = vmap_post_fg->add_parameter();
|
|
|
|
|
exec_node.push_back(u_monad_node);
|
|
|
|
|
continue;
|
|
|
|
|
} else if (i == io_monad_index) {
|
|
|
|
|
io_monad_node = vmap_post_fg->add_parameter();
|
|
|
|
|
exec_node.push_back(io_monad_node);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
exec_node.push_back(vmap_post_fg->add_parameter());
|
|
|
|
|
}
|
|
|
|
|
auto vmap_outputs = vmap_post_fg->NewCNode(exec_node);
|
|
|
|
|
|
|
|
|
|
if (u_monad_node != nullptr) {
|
|
|
|
|
auto update_state_prim = NewValueNode(prim::kPrimUpdateState);
|
|
|
|
|
auto update_state_prim = NewValueNode(prim::kPrimUpdateState);
|
|
|
|
|
if (u_monad_node) {
|
|
|
|
|
auto update_state_cnode = vmap_post_fg->NewCNode({update_state_prim, u_monad_node, vmap_outputs});
|
|
|
|
|
MS_EXCEPTION_IF_NULL(update_state_cnode);
|
|
|
|
|
update_state_cnode->set_abstract(u_monad_node->abstract());
|
|
|
|
|
u_monad_node = update_state_cnode;
|
|
|
|
|
}
|
|
|
|
|
if (io_monad_node) {
|
|
|
|
|
auto update_state_cnode = vmap_post_fg->NewCNode({update_state_prim, io_monad_node, vmap_outputs});
|
|
|
|
|
MS_EXCEPTION_IF_NULL(update_state_cnode);
|
|
|
|
|
update_state_cnode->set_abstract(io_monad_node->abstract());
|
|
|
|
|
io_monad_node = update_state_cnode;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// MatchOutAxis: Convert the outputs according to the out_axes to the specified physical perspective.
|
|
|
|
|
auto match_out_axis_app =
|
|
|
|
|
vmap_post_fg->NewCNode({NewValueNode(std::make_shared<prim::VmapMatchOutAxis>("VmapMatchOutAxis")), vmap_outputs,
|
|
|
|
|
NewValueNode(out_axes), NewValueNode(static_cast<int64_t>(axis_size))});
|
|
|
|
|
|
|
|
|
|
if (u_monad_node != nullptr) {
|
|
|
|
|
auto depend_prim = NewValueNode(prim::kPrimDepend);
|
|
|
|
|
auto state_depend = vmap_post_fg->NewCNode({depend_prim, match_out_axis_app, u_monad_node});
|
|
|
|
|
state_depend->set_abstract(match_out_axis_app->abstract());
|
|
|
|
|
vmap_post_fg->set_output(state_depend);
|
|
|
|
|
if (param_mapping_table == nullptr || param_mapping_table->empty()) {
|
|
|
|
|
auto output = match_out_axis_app;
|
|
|
|
|
if (u_monad_node) {
|
|
|
|
|
output = AttachToOutput(vmap_post_fg, output, u_monad_node);
|
|
|
|
|
}
|
|
|
|
|
if (io_monad_node) {
|
|
|
|
|
output = AttachToOutput(vmap_post_fg, output, io_monad_node);
|
|
|
|
|
}
|
|
|
|
|
vmap_post_fg->set_output(output);
|
|
|
|
|
return NewValueNode(vmap_post_fg);
|
|
|
|
|
}
|
|
|
|
|
vmap_post_fg->set_output(match_out_axis_app);
|
|
|
|
|
|
|
|
|
|
return NewValueNode(vmap_post_fg);
|
|
|
|
|
// Feed parameters back to each cell in the model ensembling parallel training case.
|
|
|
|
|
return FeedBackParam(vmap_post_fg, u_monad_node, io_monad_node, match_out_axis_app, param_mapping_table);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr GetVmapRule(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resource, int axis_size) {
|
|
|
|
@ -392,13 +578,19 @@ AnfNodePtr CopyNodeToVmap(const AnfNodePtr &node, const FuncGraphPtr &func_graph
|
|
|
|
|
return node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BindNoneAxis(const AnfNodePtr &node, const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng) {
|
|
|
|
|
void BindFvAxis(const AnfNodePtr &node, const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
|
|
|
|
|
const AnfNodePtr &stacked_param_node = nullptr) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto &node_user_map = mng->node_users();
|
|
|
|
|
auto user = node_user_map.find(node);
|
|
|
|
|
if (user != node_user_map.end() && !user->second.empty()) {
|
|
|
|
|
auto make_tuple = NewValueNode(prim::kPrimMakeTuple);
|
|
|
|
|
auto replace_node = func_graph->NewCNode({make_tuple, node, NewValueNode(kNone)});
|
|
|
|
|
CNodePtr replace_node = nullptr;
|
|
|
|
|
if (stacked_param_node == nullptr) {
|
|
|
|
|
replace_node = func_graph->NewCNode({make_tuple, node, NewValueNode(kNone)});
|
|
|
|
|
} else {
|
|
|
|
|
replace_node = func_graph->NewCNode({make_tuple, stacked_param_node, NewValueNode(SizeToLong(0))});
|
|
|
|
|
}
|
|
|
|
|
auto user_set = user->second;
|
|
|
|
|
for (auto pair : user_set) {
|
|
|
|
|
if (pair.first->func_graph() == func_graph) {
|
|
|
|
@ -409,13 +601,36 @@ void BindNoneAxis(const AnfNodePtr &node, const FuncGraphPtr &func_graph, const
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BindParamAxis(const AnfNodePtr &node, const FuncGraphPtr &vmap_fg, const FuncGraphManagerPtr &manager,
|
|
|
|
|
mindspore::HashMap<std::string, ParameterPtr> *stacked_params) {
|
|
|
|
|
if (stacked_params == nullptr || stacked_params->empty()) {
|
|
|
|
|
BindFvAxis(node, vmap_fg, manager);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
std::string param_name = dyn_cast<Parameter>(node)->name();
|
|
|
|
|
std::regex e("^.*?\\d+\\.(.+)$");
|
|
|
|
|
param_name = std::regex_replace(param_name, e, "vmap.$1");
|
|
|
|
|
auto iter = stacked_params->find(param_name);
|
|
|
|
|
if (iter != stacked_params->end()) {
|
|
|
|
|
ParameterPtr stacked_param_node = iter->second;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(stacked_param_node);
|
|
|
|
|
BindFvAxis(node, vmap_fg, manager, stacked_param_node);
|
|
|
|
|
} else {
|
|
|
|
|
BindFvAxis(node, vmap_fg, manager);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ExpandVmapValueNode(const FuncGraphPtr &vmap_fg, const pipeline::ResourceBasePtr &resource,
|
|
|
|
|
mindspore::HashSet<FuncGraphPtr> *visited_graph, mindspore::HashSet<AnfNodePtr> *visited_node,
|
|
|
|
|
int axis_size) {
|
|
|
|
|
VisitedHashSetPair *visited_pair, int axis_size,
|
|
|
|
|
mindspore::HashMap<std::string, ParameterPtr> *stacked_params) {
|
|
|
|
|
// Map ValueNode.
|
|
|
|
|
auto manager = resource->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
auto value_nodes = vmap_fg->value_nodes();
|
|
|
|
|
|
|
|
|
|
auto visited_graph = &visited_pair->first;
|
|
|
|
|
auto visited_node = &visited_pair->second;
|
|
|
|
|
|
|
|
|
|
for (const auto &value_pair : value_nodes) {
|
|
|
|
|
auto node = value_pair.first;
|
|
|
|
|
// ValueNode may have been transformed when other graphs are expanded.
|
|
|
|
@ -432,7 +647,7 @@ void ExpandVmapValueNode(const FuncGraphPtr &vmap_fg, const pipeline::ResourceBa
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
(void)visited_graph->insert(sub_func_graph);
|
|
|
|
|
auto transformed_fg = ExpandVmapFunctor(sub_func_graph, resource, visited_graph, visited_node, axis_size);
|
|
|
|
|
auto transformed_fg = ExpandVmapFunctor(sub_func_graph, resource, axis_size, visited_pair, stacked_params);
|
|
|
|
|
auto replace_node = NewValueNode(transformed_fg);
|
|
|
|
|
(void)visited_node->insert(replace_node);
|
|
|
|
|
(void)manager->Replace(node, replace_node);
|
|
|
|
@ -461,28 +676,34 @@ void ExpandVmapValueNode(const FuncGraphPtr &vmap_fg, const pipeline::ResourceBa
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ExpandVmapFreeVariable(const FuncGraphPtr &vmap_fg, const FuncGraphManagerPtr &manager,
|
|
|
|
|
const mindspore::HashSet<AnfNodePtr> &visited_node) {
|
|
|
|
|
const mindspore::HashSet<AnfNodePtr> &visited_node,
|
|
|
|
|
mindspore::HashMap<std::string, ParameterPtr> *stacked_params) {
|
|
|
|
|
// Map free variable.
|
|
|
|
|
auto free_variables_nodes = vmap_fg->free_variables_nodes();
|
|
|
|
|
for (auto &node : free_variables_nodes) {
|
|
|
|
|
auto free_variables_nodes = vmap_fg->free_variables();
|
|
|
|
|
for (auto &pair : free_variables_nodes) {
|
|
|
|
|
auto node = pair.first;
|
|
|
|
|
if (visited_node.count(node) > 0 || node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(DEBUG) << node->DebugString() << " has been transformed.";
|
|
|
|
|
} else if (node->isa<Parameter>() || IsValueNode<Scalar>(node) || IsValueNode<tensor::Tensor>(node) ||
|
|
|
|
|
IsValueNode<None>(node) || IsValueNode<ValueTuple>(node) || IsValueNode<Type>(node)) {
|
|
|
|
|
BindNoneAxis(node, vmap_fg, manager);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (IsValueNode<Scalar>(node) || IsValueNode<tensor::Tensor>(node) || IsValueNode<None>(node) ||
|
|
|
|
|
IsValueNode<ValueTuple>(node) || IsValueNode<Type>(node)) {
|
|
|
|
|
BindFvAxis(node, vmap_fg, manager);
|
|
|
|
|
} else if (node->isa<Parameter>()) {
|
|
|
|
|
BindParamAxis(node, vmap_fg, manager, stacked_params);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "vmap do not support transform " << node->DebugString() << " right now.";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr ExpandVmapFunctor(const FuncGraphPtr &vmap_fg, const pipeline::ResourceBasePtr &resource,
|
|
|
|
|
mindspore::HashSet<FuncGraphPtr> *visited_graph,
|
|
|
|
|
mindspore::HashSet<AnfNodePtr> *visited_node, int axis_size) {
|
|
|
|
|
FuncGraphPtr ExpandVmapFunctor(const FuncGraphPtr &vmap_fg, const pipeline::ResourceBasePtr &resource, int axis_size,
|
|
|
|
|
VisitedHashSetPair *visited_pair,
|
|
|
|
|
mindspore::HashMap<std::string, ParameterPtr> *stacked_params) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(vmap_fg);
|
|
|
|
|
auto manager = resource->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
manager->AddFuncGraph(vmap_fg);
|
|
|
|
|
auto visited_node = &visited_pair->second;
|
|
|
|
|
|
|
|
|
|
// The parameters of the current graph will be transformed in the upper graph, and recorded in
|
|
|
|
|
// `visited_node` to avoid being repeatedly transformed refer as a free variable in other graph.
|
|
|
|
@ -492,14 +713,15 @@ FuncGraphPtr ExpandVmapFunctor(const FuncGraphPtr &vmap_fg, const pipeline::Reso
|
|
|
|
|
(void)visited_node->insert(node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpandVmapValueNode(vmap_fg, resource, visited_graph, visited_node, axis_size);
|
|
|
|
|
ExpandVmapFreeVariable(vmap_fg, manager, *visited_node);
|
|
|
|
|
ExpandVmapValueNode(vmap_fg, resource, visited_pair, axis_size, stacked_params);
|
|
|
|
|
ExpandVmapFreeVariable(vmap_fg, manager, *visited_node, stacked_params);
|
|
|
|
|
|
|
|
|
|
return vmap_fg;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Entry to perform Vmap transformation.
|
|
|
|
|
AnfNodePtr ExpandVmap(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource, int axis_size) {
|
|
|
|
|
AnfNodePtr ExpandVmap(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource, int axis_size,
|
|
|
|
|
mindspore::HashMap<std::string, ParameterPtr> *stacked_params) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(vnode);
|
|
|
|
|
if (IsValueNode<FuncGraph>(vnode)) {
|
|
|
|
|
ScopeGuard scope_guard(vnode->scope());
|
|
|
|
@ -514,7 +736,8 @@ AnfNodePtr ExpandVmap(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr
|
|
|
|
|
(void)visited_graph.insert(func_graph);
|
|
|
|
|
(void)visited_node.insert(vnode);
|
|
|
|
|
|
|
|
|
|
auto tf_fg = ExpandVmapFunctor(func_graph, resource, &visited_graph, &visited_node, axis_size);
|
|
|
|
|
VisitedHashSetPair visited_pair(visited_graph, visited_node);
|
|
|
|
|
auto tf_fg = ExpandVmapFunctor(func_graph, resource, axis_size, &visited_pair, stacked_params);
|
|
|
|
|
visited_node.clear();
|
|
|
|
|
|
|
|
|
|
return NewValueNode(tf_fg);
|
|
|
|
@ -522,8 +745,150 @@ AnfNodePtr ExpandVmap(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr
|
|
|
|
|
MS_LOG(EXCEPTION) << "Currently, the first argument in F.vmap only supports Cell, Python defined "
|
|
|
|
|
"function or @ms_function decorated function.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string GetShapeString(const ShapeVector &tensor_shape) {
|
|
|
|
|
std::ostringstream oss;
|
|
|
|
|
oss << " Shape:";
|
|
|
|
|
for (auto &dim : tensor_shape) {
|
|
|
|
|
oss << " " << dim;
|
|
|
|
|
}
|
|
|
|
|
return oss.str();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GenerateStackedParams(const FuncGraphPtr vmap_fg, size_t cell_size,
|
|
|
|
|
const std::vector<std::vector<AnfNodePtr>> ¶m_table,
|
|
|
|
|
mindspore::HashMap<std::string, ParameterPtr> *stacked_params,
|
|
|
|
|
ParamMappingVector *param_mapping_table) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(vmap_fg);
|
|
|
|
|
FuncGraphPtr top_fg = vmap_fg;
|
|
|
|
|
while (top_fg->parent() != nullptr) {
|
|
|
|
|
top_fg = top_fg->parent();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ShapeVector tensor_shape;
|
|
|
|
|
TypeId tensor_type = kNumberTypeFloat32;
|
|
|
|
|
std::string param_name = "";
|
|
|
|
|
for (size_t i = 0; i < param_table[0].size(); ++i) {
|
|
|
|
|
std::vector<AnfNodePtr> param;
|
|
|
|
|
for (size_t j = 0; j < param_table.size(); ++j) {
|
|
|
|
|
auto param_node = dyn_cast<Parameter>(param_table[j][i]);
|
|
|
|
|
(void)param.emplace_back(param_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(param_node);
|
|
|
|
|
auto default_param = param_node->default_param();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(default_param);
|
|
|
|
|
auto param_tensor = dyn_cast<tensor::Tensor>(default_param);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(param_tensor);
|
|
|
|
|
if (j == 0) {
|
|
|
|
|
tensor_shape = param_tensor->shape();
|
|
|
|
|
tensor_type = param_tensor->data_type();
|
|
|
|
|
param_name = param_node->name();
|
|
|
|
|
} else {
|
|
|
|
|
if (tensor_type != param_tensor->data_type()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The corresponding parameter's type in each cell should be consistent, but get "
|
|
|
|
|
<< TypeIdToType(tensor_type)->ToString() << " and "
|
|
|
|
|
<< TypeIdToType(param_tensor->data_type())->ToString() << " for the parameter "
|
|
|
|
|
<< param_name << ".";
|
|
|
|
|
}
|
|
|
|
|
if (tensor_shape != param_tensor->shape()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The corresponding parameter's shape in each cell should be consistent, but get "
|
|
|
|
|
<< GetShapeString(tensor_shape) << " and " << GetShapeString(param_tensor->shape())
|
|
|
|
|
<< " for the parameter " << param_name << ".";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::regex e("^.*?0\\.(.+)$");
|
|
|
|
|
param_name = std::regex_replace(param_name, e, "vmap.$1");
|
|
|
|
|
ParameterPtr param_node = nullptr;
|
|
|
|
|
|
|
|
|
|
ShapeVector stacked_shape(tensor_shape);
|
|
|
|
|
(void)stacked_shape.insert(stacked_shape.begin(), cell_size);
|
|
|
|
|
tensor::TensorPtr stacked_param_tensor = std::make_shared<tensor::Tensor>(tensor_type, stacked_shape);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(stacked_param_tensor);
|
|
|
|
|
|
|
|
|
|
ParamInfoPtr param_info = std::make_shared<ParamInfo>();
|
|
|
|
|
param_info->set_name(param_name);
|
|
|
|
|
stacked_param_tensor->set_param_info(param_info);
|
|
|
|
|
|
|
|
|
|
param_node = top_fg->AddFvParameter(param_name, stacked_param_tensor);
|
|
|
|
|
MS_LOG(DEBUG) << "Add new parameter " << param_node->ToString() << "to the top graph " << top_fg->ToString() << ".";
|
|
|
|
|
|
|
|
|
|
(*stacked_params)[param_name] = param_node;
|
|
|
|
|
std::pair<ParameterPtr, std::vector<AnfNodePtr>> param_mapping(param_node, param);
|
|
|
|
|
(void)param_mapping_table->emplace_back(param_mapping);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GetCellParams(const FuncGraphPtr &vmap_fg, std::vector<AnfNodePtr> *param_nodes) {
|
|
|
|
|
std::set<AnfNodePtr> memo;
|
|
|
|
|
auto scan_fn = [&memo, param_nodes](const FuncGraphPtr &vmap_fg) {
|
|
|
|
|
auto fv_nodes = vmap_fg->free_variables();
|
|
|
|
|
for (auto &pair : fv_nodes) {
|
|
|
|
|
auto node = pair.first;
|
|
|
|
|
if (node->isa<Parameter>() && node->cast<ParameterPtr>()->has_default() && memo.emplace(node).second) {
|
|
|
|
|
(void)param_nodes->emplace_back(node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
scan_fn(vmap_fg);
|
|
|
|
|
auto used_fgs = vmap_fg->func_graphs_used_total();
|
|
|
|
|
for (auto &fg : used_fgs) {
|
|
|
|
|
scan_fn(fg);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Handle the model-ensembling vmap case (`cell_size` > 0): walk the CellList node
// held as input 1 of `vmap_node`, collect each cell's weight parameters into a
// table, verify every cell exposes the same number of parameters, then build the
// stacked parameters and the stacked->original mapping table.
// Returns the function node of the last cell, used as the representative vmap fn.
AnfNodePtr TraverseVmapNode(CNodePtr vmap_node, size_t cell_size,
                            mindspore::HashMap<std::string, ParameterPtr> *stacked_params,
                            ParamMappingVector *param_mapping_table) {
  MS_EXCEPTION_IF_NULL(vmap_node);
  AnfNodePtr vmap_fn_node = nullptr;
  auto cell_list_node = vmap_node->input(1);
  MS_EXCEPTION_IF_NULL(cell_list_node);
  CNodePtr cnode = cell_list_node->cast<CNodePtr>();
  // The cast yields nullptr if input 1 is not a CNode; check before dereferencing.
  MS_EXCEPTION_IF_NULL(cnode);
  auto inputs_size = cnode->size();
  // Input 0 is the MakeTuple/CellList primitive, so `cell_size` cells occupy
  // `cell_size + 1` inputs in total.
  if (inputs_size != (cell_size + 1)) {
    MS_EXCEPTION(ValueError) << "The size of CellList Node should be equal to " << (cell_size + 1) << ", but get "
                             << inputs_size << ".";
  }
  std::vector<std::vector<AnfNodePtr>> param_table(cell_size, std::vector<AnfNodePtr>());
  FuncGraphPtr vmap_fg = nullptr;
  size_t param_size = 0;
  for (size_t i = 1; i < inputs_size; i++) {
    vmap_fn_node = cnode->input(i);
    vmap_fg = GetValueNode<FuncGraphPtr>(vmap_fn_node);
    MS_EXCEPTION_IF_NULL(vmap_fg);
    GetCellParams(vmap_fg, &param_table[i - 1]);
    // Compare against the first cell explicitly: a `param_size == 0` sentinel would
    // silently accept a mismatch when the first cell has no parameters.
    if (i == 1) {
      param_size = param_table[i - 1].size();
    } else if (param_size != param_table[i - 1].size()) {
      MS_EXCEPTION(ValueError) << "Parameter size of each cell should be consistent, but get " << param_size << " and "
                               << param_table[i - 1].size() << ".";
    }
  }

  GenerateStackedParams(vmap_fg, cell_size, param_table, stacked_params, param_mapping_table);
  return vmap_fn_node;
}
|
|
|
|
|
} // namespace internal
|
|
|
|
|
|
|
|
|
|
bool ExpandVmapPrim::CheckIfEmbedMetaFgPrim(const CNodePtr &node) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
AnfNodePtr value_node = node->input(1);
|
|
|
|
|
if (IsPrimitiveCNode(value_node, prim::kPrimMakeTuple)) {
|
|
|
|
|
CNodePtr cnode = value_node->cast<CNodePtr>();
|
|
|
|
|
value_node = cnode->input(1);
|
|
|
|
|
}
|
|
|
|
|
if (IsValueNode<Primitive>(value_node)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(value_node);
|
|
|
|
|
if (func_graph == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unexpected meta function graph node:" << node->DebugString();
|
|
|
|
|
}
|
|
|
|
|
auto func_graph_manager = func_graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph_manager);
|
|
|
|
|
return func_graph_manager->func_graph_meta_fg_prim_total(func_graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ExpandVmapPrim::operator()(const FuncGraphPtr &, const OptimizerPtr &optimizer) {
|
|
|
|
|
// Expand vmap nodes that don't have embed j or vmap nodes.
|
|
|
|
|
bool change = false;
|
|
|
|
@ -536,32 +901,41 @@ bool ExpandVmapPrim::operator()(const FuncGraphPtr &, const OptimizerPtr &optimi
|
|
|
|
|
MS_EXCEPTION_IF_NULL(in_axes);
|
|
|
|
|
ValuePtr out_axes = VmapPrim->GetAttr("out_axes");
|
|
|
|
|
MS_EXCEPTION_IF_NULL(out_axes);
|
|
|
|
|
ValuePtr cell_size_value = VmapPrim->GetAttr("cell_size");
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cell_size_value);
|
|
|
|
|
auto cell_size = cell_size_value->isa<UInt64Imm>() ? dyn_cast<UInt64Imm>(cell_size_value)->value() : 0;
|
|
|
|
|
|
|
|
|
|
auto vmap_fn_node = vmap_node->input(1);
|
|
|
|
|
auto vmap_fg = GetValueNode<FuncGraphPtr>(vmap_fn_node);
|
|
|
|
|
auto &fn_users = manager->node_users()[vmap_fn_node];
|
|
|
|
|
size_t fn_users_size = fn_users.size();
|
|
|
|
|
AnfNodePtr vmap_fn_node = nullptr;
|
|
|
|
|
|
|
|
|
|
mindspore::HashMap<std::string, ParameterPtr> stacked_params;
|
|
|
|
|
// Record the stacked parameters, and the corresponding origin parameters from each cell, preserved
|
|
|
|
|
// for future feedback.
|
|
|
|
|
ParamMappingVector param_mapping_table;
|
|
|
|
|
|
|
|
|
|
if (cell_size > 0) {
|
|
|
|
|
// This branch handles the model ensembling parallel training case. Get one function node in the 'CellList'
|
|
|
|
|
// as the vmap function, meanwhile preprocess the cells parameters to get the stacked parameters and
|
|
|
|
|
// the parameters mapping table.
|
|
|
|
|
vmap_fn_node = internal::TraverseVmapNode(vmap_node, cell_size, &stacked_params, ¶m_mapping_table);
|
|
|
|
|
} else {
|
|
|
|
|
vmap_fn_node = vmap_node->input(1);
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(vmap_fn_node);
|
|
|
|
|
FuncGraphPtr vmap_fg = GetValueNode<FuncGraphPtr>(vmap_fn_node);
|
|
|
|
|
|
|
|
|
|
auto users = manager->node_users()[vmap_node];
|
|
|
|
|
if (users.size() < 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "vmap_node could used by at least one CNode, but got users.size() = " << users.size() << ".";
|
|
|
|
|
MS_EXCEPTION(ValueError) << "vmap_node could used by at least one CNode, but got users.size() = " << users.size()
|
|
|
|
|
<< ".";
|
|
|
|
|
}
|
|
|
|
|
size_t user_nb = 0;
|
|
|
|
|
size_t user_size = users.size();
|
|
|
|
|
for (auto &user : users) {
|
|
|
|
|
user_nb++;
|
|
|
|
|
|
|
|
|
|
for (auto &user : users) {
|
|
|
|
|
// When `vmap_node` has more than one user or `fn` has more than one user, the original function graph
|
|
|
|
|
// cannot be modified directly.
|
|
|
|
|
if ((user_size > 1 && user_nb != user_size) || fn_users_size > 1) {
|
|
|
|
|
MS_LOG(DEBUG) << "Funcgraph: " << vmap_fg->ToString() << " is also used outside the scope of vmap.";
|
|
|
|
|
auto vmap_fg_copy = BasicClone(vmap_fg, true);
|
|
|
|
|
auto manager_ptr = optimizer->resource()->manager();
|
|
|
|
|
manager_ptr->AddFuncGraph(vmap_fg_copy);
|
|
|
|
|
vmap_fn_node = NewValueNode(vmap_fg_copy);
|
|
|
|
|
} else {
|
|
|
|
|
vmap_fn_node = NewValueNode(vmap_fg);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Funcgraph: " << vmap_fg->ToString() << " is also used outside the scope of vmap.";
|
|
|
|
|
auto vmap_fg_copy = BasicClone(vmap_fg, true);
|
|
|
|
|
manager->AddFuncGraph(vmap_fg_copy);
|
|
|
|
|
vmap_fn_node = NewValueNode(vmap_fg_copy);
|
|
|
|
|
|
|
|
|
|
if (parallel::IsPynativeParallel()) {
|
|
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(vmap_fn_node);
|
|
|
|
@ -571,26 +945,34 @@ bool ExpandVmapPrim::operator()(const FuncGraphPtr &, const OptimizerPtr &optimi
|
|
|
|
|
// get axis size, simultaneous correction the negative in_axes.
|
|
|
|
|
auto vmap_app = user.first->cast<CNodePtr>();
|
|
|
|
|
int user_index = user.second;
|
|
|
|
|
int parameters_size = SizeToInt(vmap_app->size() - 1);
|
|
|
|
|
int axis_size = internal::GetAxisSize(vmap_app, &in_axes);
|
|
|
|
|
if (axis_size == kInvalidAxisSize) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failed to get 'axis_size' within the scope of vmap.";
|
|
|
|
|
if (vmap_app->size() < 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Something went wrong, CNode vmap_app's arguments is less than 1, CNode: "
|
|
|
|
|
<< vmap_app->DebugString();
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "The axis size corresponding to the current level vmap scope is " << axis_size << ".";
|
|
|
|
|
size_t parameters_size = vmap_app->size() - 1;
|
|
|
|
|
std::vector<size_t> orig_fg_param_info;
|
|
|
|
|
(void)orig_fg_param_info.emplace_back(parameters_size);
|
|
|
|
|
int axis_size = internal::GetAxisSize(vmap_app, cell_size, parameters_size, &in_axes);
|
|
|
|
|
|
|
|
|
|
// Step1: Bind the inputs with the corresponding in_axes.
|
|
|
|
|
size_t u_monad_offset = 0;
|
|
|
|
|
auto bind_axes_node = internal::BindInAxis(vmap_app, in_axes, &u_monad_offset);
|
|
|
|
|
size_t io_monad_offset = 0;
|
|
|
|
|
auto bind_axes_node =
|
|
|
|
|
internal::BindInAxis(vmap_app, in_axes, &u_monad_offset, &io_monad_offset, ¶m_mapping_table);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(bind_axes_node);
|
|
|
|
|
(void)manager->Replace(vmap_app, bind_axes_node);
|
|
|
|
|
(void)orig_fg_param_info.emplace_back(u_monad_offset);
|
|
|
|
|
(void)orig_fg_param_info.emplace_back(io_monad_offset);
|
|
|
|
|
|
|
|
|
|
// Step2: Bind the variables with the corresponding axis, and overload the original
|
|
|
|
|
// operation with the VmapRule operation meanwhile transfer the axis information.
|
|
|
|
|
auto expanded_vmap = internal::ExpandVmap(vmap_fn_node->cast<ValueNodePtr>(), optimizer->resource(), axis_size);
|
|
|
|
|
auto expanded_vmap =
|
|
|
|
|
internal::ExpandVmap(vmap_fn_node->cast<ValueNodePtr>(), optimizer->resource(), axis_size, &stacked_params);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(expanded_vmap);
|
|
|
|
|
|
|
|
|
|
// Step3: Convert the outputs according to the out_axes to the specified physical perspective.
|
|
|
|
|
auto match_out_axis = internal::MatchOutAxis(expanded_vmap, parameters_size, u_monad_offset, axis_size, out_axes);
|
|
|
|
|
// Step3: Post processing of converted vmap function graph, including: MatchOutAxis and Parameter feedback.
|
|
|
|
|
auto match_out_axis =
|
|
|
|
|
internal::PostProcessVmap(expanded_vmap, orig_fg_param_info, out_axes, axis_size, ¶m_mapping_table);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(match_out_axis);
|
|
|
|
|
manager->SetEdge(bind_axes_node, user_index, match_out_axis);
|
|
|
|
|
}
|
|
|
|
|