fix concat to attr and pad vmap

fangzehua 2022-11-08 09:43:42 +08:00
parent 72a44d9012
commit 26ef2e292e
3 changed files with 10 additions and 13 deletions

@@ -23,7 +23,6 @@ namespace mindspore::opt {
 #define RER_ASCEND_DYNAMIC_CONST_TO_ATTR(op_name, ...) RER_CONST_TO_ATTR_LIST(op_name, kAscendDevice, true, __VA_ARGS__)
 RER_ASCEND_DYNAMIC_CONST_TO_ATTR(kCastOpName, 1);
-RER_ASCEND_DYNAMIC_CONST_TO_ATTR(kConcatOpName, 0);
 RER_ASCEND_DYNAMIC_CONST_TO_ATTR(kEmbeddingLookupOpName, 2, 3, 4, 5);
 RER_ASCEND_DYNAMIC_CONST_TO_ATTR(kExpandDimsOpName, 1);
 RER_ASCEND_DYNAMIC_CONST_TO_ATTR(kGatherDGradV2OpName, 1);
@@ -38,7 +37,6 @@ RER_ASCEND_STATIC_CONST_TO_ATTR(kAvgPoolGradVmOpName, 0);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kBatchToSpaceOpName, 1);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kCastOpName, 1);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kCentralizationOpName, 1);
-RER_ASCEND_STATIC_CONST_TO_ATTR(kConcatOpName, 0);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kConv2DBackpropFilterOpName, 2);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kConv2DBackpropInputOpName, 2);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kConv2DTransposeOpName, 2);

@@ -23,7 +23,6 @@ namespace mindspore::opt {
 #define RER_CPU_DYNAMIC_CONST_TO_ATTR(op_name, ...) RER_CONST_TO_ATTR_LIST(op_name, kCPUDevice, true, __VA_ARGS__)
 RER_CPU_DYNAMIC_CONST_TO_ATTR(kCastOpName, 1);
-RER_CPU_DYNAMIC_CONST_TO_ATTR(kConcatOpName, 0);
 RER_CPU_DYNAMIC_CONST_TO_ATTR(kFillOpName, 0);
 RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceAllOpName, 1);
 RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceAnyOpName, 1);
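
Taken together, the two files above drop kConcatOpName from the dynamic and static const-to-attr registrations on both Ascend and CPU, so a constant Concat input at index 0 is no longer folded into a primitive attribute by this pass. A minimal sketch of the user-facing call that goes through Concat (standard MindSpore API; the comment about the pass is an interpretation of this diff, not something the diff states):

import numpy as np
from mindspore import Tensor, ops

# Ordinary Concat call; after this commit the backend pass above no longer
# rewrites the op's constant input 0 into an attribute on Ascend/CPU.
x = Tensor(np.ones((2, 3), np.float32))
y = Tensor(np.zeros((2, 3), np.float32))
out = ops.Concat(axis=0)((x, y))
print(out.shape)  # (4, 3)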

@@ -1072,24 +1072,24 @@ def get_pad_v3_vmap_rule(prim, axis_size):
         if is_all_none:
             return result
         if len(params_bdim) < 2:
-            _raise_value_error("The input params in `{}` must >= 2, "
-                               "but got {}.".format(prim.name, len(params_bdim)))
+            _raise_value_error("The input params in `PadV3` must >= 2, "
+                               "but got {}.".format(len(params_bdim)))
         input_x, input_x_dim = params_bdim[0]
         paddings, paddings_dim = params_bdim[1]
         values = None
         out = None
         x = _bdim_at_front(input_x, input_x_dim, axis_size)
         if paddings_dim is not None:
-            _raise_value_error("The source axis of `paddings` in `{}` must be None, "
-                               "but got {}.".format(prim.name, paddings_dim))
+            _raise_value_error("The source axis of `paddings` in `PadV3` must be None, "
+                               "but got {}.".format(paddings_dim))
         if mode == "constant":
             if len(params_bdim) != 3:
-                _raise_value_error("The input params in `{}` of constant mode must be 3, "
-                                   "but got {}.".format(prim.name, len(params_bdim)))
+                _raise_value_error("The input params in `PadV3` of constant mode must be 3, "
+                                   "but got {}.".format(len(params_bdim)))
             values, values_dim = params_bdim[2]
             if values_dim is not None:
-                _raise_value_error("The source axis of `values_dim` in `{}` must be None, "
-                                   "but got {}.".format(prim.name, values_dim))
+                _raise_value_error("The source axis of `values_dim` in `PadV3` must be None, "
+                                   "but got {}.".format(values_dim))
         if isinstance(paddings, Tensor):
             pad_dim = F.shape(paddings)[0] / pad_pair
         else:
@@ -1118,8 +1118,8 @@ def get_pad_v3_vmap_rule(prim, axis_size):
                 real_out_shape = x_shape[:diff_dim + 1] + out_shape[1:]
                 out = F.reshape(out, real_out_shape)
         else:
-            _raise_value_error("The dim of `input_x` in `{}` must be bigger than {}, "
-                               "but got {}.".format(prim.name, pad_dim, x_ndim))
+            _raise_value_error("The dim of `input_x` in `PadV3` must be bigger than {}, "
+                               "but got {}.".format(pad_dim, x_ndim))
         return out, 0
     return vmap_rule
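
The Python hunks swap prim.name for the literal string "PadV3" in every error message of the PadV3 vmap rule, so the messages no longer depend on formatting a primitive attribute inside the compiled rule. A minimal sketch of the batched-pad path these checks guard, assuming the functional ops.vmap and ops.pad APIs of recent MindSpore releases (ops.pad lowers to PadV3; the shapes are illustrative):

import numpy as np
from mindspore import Tensor, ops

# Batch a constant-mode pad with vmap; the PadV3 vmap rule patched above
# moves the batch axis to the front before applying the padding.
def pad_fn(x):
    return ops.pad(x, padding=(1, 1), mode="constant", value=0.0)

batched = ops.vmap(pad_fn, in_axes=0, out_axes=0)
x = Tensor(np.ones((4, 2, 3), np.float32))  # batch of 4 inputs of shape (2, 3)
print(batched(x).shape)  # expected (4, 2, 5): last dim padded by 1 on each side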