add adaptive maxpool3d grad gpu vmap rule

This commit is contained in:
dabaiji 2022-11-15 17:33:14 +08:00
parent 5a5491eb04
commit 0ad330a069
4 changed files with 8 additions and 9 deletions

View File

@ -32,10 +32,8 @@ class MIND_API CumProd : public BaseOperator {
void Init(const bool exclusive, const bool reverse);
void SetExclusive(const bool exclusive);
void SetReverse(const bool reverse);
void SetAxis(const int64_t axis);
bool GetExclusive() const;
bool GetReverse() const;
int64_t GetAxis() const;
};
abstract::AbstractBasePtr CumProdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);

View File

@ -38,8 +38,6 @@ class MIND_API CumSum : public BaseOperator {
void set_exclusive(const bool exclusive);
/// \brief Set reverse.
void set_reverse(const bool reverse);
/// \brief Set axis.
void set_axis(const int64_t axis);
/// \brief Get exclusive.
///
/// \return exclusive.
@ -49,8 +47,6 @@ class MIND_API CumSum : public BaseOperator {
/// \return reverse.
bool get_reverse() const;
///
/// \return axis.
int64_t get_axis() const;
};
abstract::AbstractBasePtr CumSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);

View File

@ -226,11 +226,15 @@ def get_cdist_grad_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register(G.AdaptiveMaxPool3DGrad)
@vmap_rules_getters.register(G.AdaptiveMaxPool2DGrad)
def get_adaptive_avgpool2d_vmap_rule(prim, axis_size):
"""VmapRule for `AdaptiveMaxPool2DGrad` operation."""
"""VmapRule for `AdaptiveMaxPool2DGrad` and `AdaptiveMaxPool3DGrad` operation."""
chw_reverse_index = -3
hw_reverse_index = -2
if prim.name == "AdaptiveMaxPool2DGrad":
hw_reverse_index = -2
else:
hw_reverse_index = -3
def vmap_rule(ygrad_bdim, x_bdim, max_index_bdim):
is_all_none, result = vmap_general_preprocess(prim, ygrad_bdim, x_bdim, max_index_bdim)
@ -353,7 +357,7 @@ def get_batchnorm_grad_vmap_rule(prim, axis_size):
if is_all_none:
return result
if data_format == "NHWC":
#BatchNormGrad with NHWC format is a GPU backend operation and not supported for now.
# BatchNormGrad with NHWC format is a GPU backend operation and not supported for now.
return batchnorm_grad_nhwc_vmap(grad_bdim, x_bdim, scale_bdim, rsv_1_bdim, rsv_2_bdim, rsv_3_bdim)
grad, grad_dim = grad_bdim
input_x, input_x_dim = x_bdim

View File

@ -3,6 +3,7 @@ protobuf >= 3.13.0
asttokens >= 2.0.4
pillow >= 6.2.0
scipy >= 1.5.4
decorator >= 4.4.0
matplotlib >= 3.1.3 # for ut test
opencv-python >= 4.1.2.30 # for ut test
sklearn >= 0.0 # for st test