modify the error message of vmaprule of P.Switch, and make GeneralPreprocess support list inputs.

This commit is contained in:
Erpim 2022-10-31 21:06:43 +08:00
parent cc296825d3
commit a7bdf04cb8
2 changed files with 7 additions and 5 deletions

View File

@ -313,12 +313,12 @@ FuncGraphPtr VmapGeneralPreprocess::GenerateFuncGraph(const AbstractBasePtrList
&offset](const AbstractBasePtrList &args_spec_list) -> const AbstractBasePtrList & { &offset](const AbstractBasePtrList &args_spec_list) -> const AbstractBasePtrList & {
if (args_size == 2) { if (args_size == 2) {
auto arg = args_spec_list[1]; auto arg = args_spec_list[1];
if (!arg->isa<abstract::AbstractTuple>()) { if (!arg->isa<abstract::AbstractSequence>()) {
MS_LOG(EXCEPTION) << "The second input to VmapGeneralPreprocess should be AbstractTuple but got: " MS_LOG(EXCEPTION) << "The second input to VmapGeneralPreprocess should be AbstractSequence but got: "
<< arg->ToString() << "."; << arg->ToString() << ".";
} }
auto arg_tuple = arg->cast<abstract::AbstractTuplePtr>(); auto arg_seq = arg->cast<abstract::AbstractSequencePtr>();
const auto &arg_tuple_elements = arg_tuple->elements(); const auto &arg_tuple_elements = arg_seq->elements();
if (arg_tuple_elements.back()->isa<abstract::AbstractTuple>()) { if (arg_tuple_elements.back()->isa<abstract::AbstractTuple>()) {
// Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped // Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped
// into a tuple. We need to process the internal elements separately and then re-wrap // into a tuple. We need to process the internal elements separately and then re-wrap

View File

@ -83,7 +83,9 @@ def get_partical_vmap_rule(prim, axis_size):
else: else:
val, dim = val_bdim val, dim = val_bdim
if dim is not None: if dim is not None:
_raise_value_error("The source axis of args in {} must be None, " _raise_value_error("In the scenario where vmap contains control flow, currently only the "
"case of each batch branch with the same processing operations is "
"supported, so that the source axis of args in {} must be None, "
"but got {}.".format(prim_name, dim)) "but got {}.".format(prim_name, dim))
vals = vals + (val,) vals = vals + (val,)