fix concat to attr and pad vmap
This commit is contained in:
parent
72a44d9012
commit
26ef2e292e
|
@ -23,7 +23,6 @@ namespace mindspore::opt {
|
|||
#define RER_ASCEND_DYNAMIC_CONST_TO_ATTR(op_name, ...) RER_CONST_TO_ATTR_LIST(op_name, kAscendDevice, true, __VA_ARGS__)
|
||||
|
||||
RER_ASCEND_DYNAMIC_CONST_TO_ATTR(kCastOpName, 1);
|
||||
RER_ASCEND_DYNAMIC_CONST_TO_ATTR(kConcatOpName, 0);
|
||||
RER_ASCEND_DYNAMIC_CONST_TO_ATTR(kEmbeddingLookupOpName, 2, 3, 4, 5);
|
||||
RER_ASCEND_DYNAMIC_CONST_TO_ATTR(kExpandDimsOpName, 1);
|
||||
RER_ASCEND_DYNAMIC_CONST_TO_ATTR(kGatherDGradV2OpName, 1);
|
||||
|
@ -38,7 +37,6 @@ RER_ASCEND_STATIC_CONST_TO_ATTR(kAvgPoolGradVmOpName, 0);
|
|||
RER_ASCEND_STATIC_CONST_TO_ATTR(kBatchToSpaceOpName, 1);
|
||||
RER_ASCEND_STATIC_CONST_TO_ATTR(kCastOpName, 1);
|
||||
RER_ASCEND_STATIC_CONST_TO_ATTR(kCentralizationOpName, 1);
|
||||
RER_ASCEND_STATIC_CONST_TO_ATTR(kConcatOpName, 0);
|
||||
RER_ASCEND_STATIC_CONST_TO_ATTR(kConv2DBackpropFilterOpName, 2);
|
||||
RER_ASCEND_STATIC_CONST_TO_ATTR(kConv2DBackpropInputOpName, 2);
|
||||
RER_ASCEND_STATIC_CONST_TO_ATTR(kConv2DTransposeOpName, 2);
|
||||
|
|
|
@ -23,7 +23,6 @@ namespace mindspore::opt {
|
|||
#define RER_CPU_DYNAMIC_CONST_TO_ATTR(op_name, ...) RER_CONST_TO_ATTR_LIST(op_name, kCPUDevice, true, __VA_ARGS__)
|
||||
|
||||
RER_CPU_DYNAMIC_CONST_TO_ATTR(kCastOpName, 1);
|
||||
RER_CPU_DYNAMIC_CONST_TO_ATTR(kConcatOpName, 0);
|
||||
RER_CPU_DYNAMIC_CONST_TO_ATTR(kFillOpName, 0);
|
||||
RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceAllOpName, 1);
|
||||
RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceAnyOpName, 1);
|
||||
|
|
|
@ -1072,24 +1072,24 @@ def get_pad_v3_vmap_rule(prim, axis_size):
|
|||
if is_all_none:
|
||||
return result
|
||||
if len(params_bdim) < 2:
|
||||
_raise_value_error("The input params in `{}` must >= 2, "
|
||||
"but got {}.".format(prim.name, len(params_bdim)))
|
||||
_raise_value_error("The input params in `PadV3` must >= 2, "
|
||||
"but got {}.".format(len(params_bdim)))
|
||||
input_x, input_x_dim = params_bdim[0]
|
||||
paddings, paddings_dim = params_bdim[1]
|
||||
values = None
|
||||
out = None
|
||||
x = _bdim_at_front(input_x, input_x_dim, axis_size)
|
||||
if paddings_dim is not None:
|
||||
_raise_value_error("The source axis of `paddings` in `{}` must be None, "
|
||||
"but got {}.".format(prim.name, paddings_dim))
|
||||
_raise_value_error("The source axis of `paddings` in `PadV3` must be None, "
|
||||
"but got {}.".format(paddings_dim))
|
||||
if mode == "constant":
|
||||
if len(params_bdim) != 3:
|
||||
_raise_value_error("The input params in `{}` of constant mode must be 3, "
|
||||
"but got {}.".format(prim.name, len(params_bdim)))
|
||||
_raise_value_error("The input params in `PadV3` of constant mode must be 3, "
|
||||
"but got {}.".format(len(params_bdim)))
|
||||
values, values_dim = params_bdim[2]
|
||||
if values_dim is not None:
|
||||
_raise_value_error("The source axis of `values_dim` in `{}` must be None, "
|
||||
"but got {}.".format(prim.name, values_dim))
|
||||
_raise_value_error("The source axis of `values_dim` in `PadV3` must be None, "
|
||||
"but got {}.".format(values_dim))
|
||||
if isinstance(paddings, Tensor):
|
||||
pad_dim = F.shape(paddings)[0] / pad_pair
|
||||
else:
|
||||
|
@ -1118,8 +1118,8 @@ def get_pad_v3_vmap_rule(prim, axis_size):
|
|||
real_out_shape = x_shape[:diff_dim + 1] + out_shape[1:]
|
||||
out = F.reshape(out, real_out_shape)
|
||||
else:
|
||||
_raise_value_error("The dim of `input_x` in `{}` must be bigger than {}, "
|
||||
"but got {}.".format(prim.name, pad_dim, x_ndim))
|
||||
_raise_value_error("The dim of `input_x` in `PadV3` must be bigger than {}, "
|
||||
"but got {}.".format(pad_dim, x_ndim))
|
||||
return out, 0
|
||||
|
||||
return vmap_rule
|
||||
|
|
Loading…
Reference in New Issue