forked from mindspore-Ecosystem/mindspore
add adaptive maxpool3d grad gpu vmap rule
This commit is contained in:
parent
5a5491eb04
commit
0ad330a069
|
@ -32,10 +32,8 @@ class MIND_API CumProd : public BaseOperator {
|
|||
void Init(const bool exclusive, const bool reverse);
|
||||
void SetExclusive(const bool exclusive);
|
||||
void SetReverse(const bool reverse);
|
||||
void SetAxis(const int64_t axis);
|
||||
bool GetExclusive() const;
|
||||
bool GetReverse() const;
|
||||
int64_t GetAxis() const;
|
||||
};
|
||||
abstract::AbstractBasePtr CumProdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
|
|
@ -38,8 +38,6 @@ class MIND_API CumSum : public BaseOperator {
|
|||
void set_exclusive(const bool exclusive);
|
||||
/// \brief Set reverse.
|
||||
void set_reverse(const bool reverse);
|
||||
/// \brief Set axis.
|
||||
void set_axis(const int64_t axis);
|
||||
/// \brief Get exclusive.
|
||||
///
|
||||
/// \return exclusive.
|
||||
|
@ -49,8 +47,6 @@ class MIND_API CumSum : public BaseOperator {
|
|||
/// \return reverse.
|
||||
bool get_reverse() const;
|
||||
///
|
||||
/// \return axis.
|
||||
int64_t get_axis() const;
|
||||
};
|
||||
abstract::AbstractBasePtr CumSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
|
|
@ -226,11 +226,15 @@ def get_cdist_grad_vmap_rule(prim, axis_size):
|
|||
return vmap_rule
|
||||
|
||||
|
||||
@vmap_rules_getters.register(G.AdaptiveMaxPool3DGrad)
|
||||
@vmap_rules_getters.register(G.AdaptiveMaxPool2DGrad)
|
||||
def get_adaptive_avgpool2d_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for `AdaptiveMaxPool2DGrad` operation."""
|
||||
"""VmapRule for `AdaptiveMaxPool2DGrad` and `AdaptiveMaxPool3DGrad` operation."""
|
||||
chw_reverse_index = -3
|
||||
hw_reverse_index = -2
|
||||
if prim.name == "AdaptiveMaxPool2DGrad":
|
||||
hw_reverse_index = -2
|
||||
else:
|
||||
hw_reverse_index = -3
|
||||
|
||||
def vmap_rule(ygrad_bdim, x_bdim, max_index_bdim):
|
||||
is_all_none, result = vmap_general_preprocess(prim, ygrad_bdim, x_bdim, max_index_bdim)
|
||||
|
@ -353,7 +357,7 @@ def get_batchnorm_grad_vmap_rule(prim, axis_size):
|
|||
if is_all_none:
|
||||
return result
|
||||
if data_format == "NHWC":
|
||||
#BatchNormGrad with NHWC format is a GPU backend operation and not supported for now.
|
||||
# BatchNormGrad with NHWC format is a GPU backend operation and not supported for now.
|
||||
return batchnorm_grad_nhwc_vmap(grad_bdim, x_bdim, scale_bdim, rsv_1_bdim, rsv_2_bdim, rsv_3_bdim)
|
||||
grad, grad_dim = grad_bdim
|
||||
input_x, input_x_dim = x_bdim
|
||||
|
|
|
@ -3,6 +3,7 @@ protobuf >= 3.13.0
|
|||
asttokens >= 2.0.4
|
||||
pillow >= 6.2.0
|
||||
scipy >= 1.5.4
|
||||
decorator >= 4.4.0
|
||||
matplotlib >= 3.1.3 # for ut test
|
||||
opencv-python >= 4.1.2.30 # for ut test
|
||||
sklearn >= 0.0 # for st test
|
||||
|
|
Loading…
Reference in New Issue