modify the error message of vmaprule of P.Switch, and make GeneralPreprocess support list inputs.
This commit is contained in:
parent
cc296825d3
commit
a7bdf04cb8
|
@ -313,12 +313,12 @@ FuncGraphPtr VmapGeneralPreprocess::GenerateFuncGraph(const AbstractBasePtrList
|
|||
&offset](const AbstractBasePtrList &args_spec_list) -> const AbstractBasePtrList & {
|
||||
if (args_size == 2) {
|
||||
auto arg = args_spec_list[1];
|
||||
if (!arg->isa<abstract::AbstractTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "The second input to VmapGeneralPreprocess should be AbstractTuple but got: "
|
||||
if (!arg->isa<abstract::AbstractSequence>()) {
|
||||
MS_LOG(EXCEPTION) << "The second input to VmapGeneralPreprocess should be AbstractSequence but got: "
|
||||
<< arg->ToString() << ".";
|
||||
}
|
||||
auto arg_tuple = arg->cast<abstract::AbstractTuplePtr>();
|
||||
const auto &arg_tuple_elements = arg_tuple->elements();
|
||||
auto arg_seq = arg->cast<abstract::AbstractSequencePtr>();
|
||||
const auto &arg_tuple_elements = arg_seq->elements();
|
||||
if (arg_tuple_elements.back()->isa<abstract::AbstractTuple>()) {
|
||||
// Operators with indefinite inputs length, such as `AddN`, whose inputs is wrapped
|
||||
// into a tuple. We need to process the internal elements separately and then re-wrap
|
||||
|
|
|
@ -83,7 +83,9 @@ def get_partical_vmap_rule(prim, axis_size):
|
|||
else:
|
||||
val, dim = val_bdim
|
||||
if dim is not None:
|
||||
_raise_value_error("The source axis of args in {} must be None, "
|
||||
_raise_value_error("In the scenario where vmap contains control flow, currently only the "
|
||||
"case of each batch branch with the same processing operations is "
|
||||
"supported, so that the source axis of args in {} must be None, "
|
||||
"but got {}.".format(prim_name, dim))
|
||||
vals = vals + (val,)
|
||||
|
||||
|
|
Loading…
Reference in New Issue