modify the error message of vmaprule of P.Switch, and make GeneralPreprocess support list inputs.
This commit is contained in:
parent
cc296825d3
commit
a7bdf04cb8
|
@ -313,12 +313,12 @@ FuncGraphPtr VmapGeneralPreprocess::GenerateFuncGraph(const AbstractBasePtrList
|
||||||
&offset](const AbstractBasePtrList &args_spec_list) -> const AbstractBasePtrList & {
|
&offset](const AbstractBasePtrList &args_spec_list) -> const AbstractBasePtrList & {
|
||||||
if (args_size == 2) {
|
if (args_size == 2) {
|
||||||
auto arg = args_spec_list[1];
|
auto arg = args_spec_list[1];
|
||||||
if (!arg->isa<abstract::AbstractTuple>()) {
|
if (!arg->isa<abstract::AbstractSequence>()) {
|
||||||
MS_LOG(EXCEPTION) << "The second input to VmapGeneralPreprocess should be AbstractTuple but got: "
|
MS_LOG(EXCEPTION) << "The second input to VmapGeneralPreprocess should be AbstractSequence but got: "
|
||||||
<< arg->ToString() << ".";
|
<< arg->ToString() << ".";
|
||||||
}
|
}
|
||||||
auto arg_tuple = arg->cast<abstract::AbstractTuplePtr>();
|
auto arg_seq = arg->cast<abstract::AbstractSequencePtr>();
|
||||||
const auto &arg_tuple_elements = arg_tuple->elements();
|
const auto &arg_tuple_elements = arg_seq->elements();
|
||||||
if (arg_tuple_elements.back()->isa<abstract::AbstractTuple>()) {
|
if (arg_tuple_elements.back()->isa<abstract::AbstractTuple>()) {
|
||||||
// Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped
|
// Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped
|
||||||
// into a tuple. We need to process the internal elements separately and then re-wrap
|
// into a tuple. We need to process the internal elements separately and then re-wrap
|
||||||
|
|
|
@ -83,7 +83,9 @@ def get_partical_vmap_rule(prim, axis_size):
|
||||||
else:
|
else:
|
||||||
val, dim = val_bdim
|
val, dim = val_bdim
|
||||||
if dim is not None:
|
if dim is not None:
|
||||||
_raise_value_error("The source axis of args in {} must be None, "
|
_raise_value_error("In the scenario where vmap contains control flow, currently only the "
|
||||||
|
"case of each batch branch with the same processing operations is "
|
||||||
|
"supported, so that the source axis of args in {} must be None, "
|
||||||
"but got {}.".format(prim_name, dim))
|
"but got {}.".format(prim_name, dim))
|
||||||
vals = vals + (val,)
|
vals = vals + (val,)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue