forked from mindspore-Ecosystem/mindspore
fix bug of GatherDGradV2-vmap when dim is negative
This commit is contained in:
parent
b4402302df
commit
57ca6777bf
|
@ -34,9 +34,8 @@ class CropAndResizeInfer : public abstract::OpInferBase {
|
|||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(
|
||||
input_args.size() == kCropAndResizeInputSize,
|
||||
"For primitive[" + prim_name + "], [input number] must be 4 but got " + std::to_string(input_args.size()));
|
||||
(void)CheckAndConvertUtils::CheckInteger("[input number]", static_cast<int64_t>(input_args.size()), kEqual,
|
||||
kCropAndResizeInputSize, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -58,13 +57,12 @@ class CropAndResizeInfer : public abstract::OpInferBase {
|
|||
std::vector<int64_t>{UNKNOWN_DIM, UNKNOWN_DIM, UNKNOWN_DIM, UNKNOWN_DIM});
|
||||
}
|
||||
|
||||
size_t batch_rank = 0;
|
||||
int64_t batch_rank = 0;
|
||||
if (primitive->HasAttr(kBatchRank)) {
|
||||
batch_rank = GetValue<int64_t>(primitive->GetAttr(kBatchRank));
|
||||
}
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(x_shape.size() == kShapeRank4 + batch_rank,
|
||||
"For primitive[" + prim_name + "], the [x shape-length] should be 4, bug got " +
|
||||
std::to_string(static_cast<int>(x_shape.size()) - static_cast<int>(batch_rank)) + ".");
|
||||
int64_t x_dims = static_cast<int64_t>(x_shape.size()) - batch_rank;
|
||||
(void)CheckAndConvertUtils::CheckInteger("[x shape-length]", x_dims, kEqual, kShapeRank4, prim_name);
|
||||
int64_t out_channel = x_shape.back();
|
||||
|
||||
std::vector<int64_t> batch_shape(x_shape.begin(), x_shape.begin() + static_cast<int>(batch_rank));
|
||||
|
@ -78,8 +76,7 @@ class CropAndResizeInfer : public abstract::OpInferBase {
|
|||
crop_size = CheckAndConvertUtils::CheckTensorIntValue("crop_size", value_ptr, prim_name);
|
||||
} else if (IsIdentidityOrSubclass(crop_size_type, kTuple)) {
|
||||
auto value_tuple = value_ptr->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(value_tuple != nullptr,
|
||||
"For primitive[" + prim_name + "], the [crop_size] must a Tuple.");
|
||||
MS_EXCEPTION_IF_NULL(value_tuple);
|
||||
auto &elements = value_tuple->value();
|
||||
for (const auto &element : elements) {
|
||||
if (element->isa<Int64Imm>()) {
|
||||
|
@ -87,18 +84,20 @@ class CropAndResizeInfer : public abstract::OpInferBase {
|
|||
} else {
|
||||
auto type = element->type();
|
||||
std::string real_type_str = type == nullptr ? "Unknown." : type->ToString() + ".";
|
||||
MS_LOG(EXCEPTION) << ("For primitive[" + prim_name +
|
||||
"], the [crop_size] must be a tuple with two Int elements, but got " + real_type_str);
|
||||
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
|
||||
<< "], the [crop_size] must be a tuple with two Int elements, but got "
|
||||
<< real_type_str;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << ("For primitive[" + prim_name +
|
||||
"], the [crop_size] is must be a Tensor or a Tuple with two Int elements, but got " +
|
||||
crop_size_type->ToString());
|
||||
MS_EXCEPTION(TypeError) << "For primitive[" + prim_name
|
||||
<< "], the [crop_size] is must be a Tensor or a Tuple with two Int elements, but got "
|
||||
<< crop_size_type->ToString();
|
||||
}
|
||||
CheckAndConvertUtils::Check("crop_size length", crop_size.size(), kEqual, kShapeRank2, prim_name);
|
||||
CheckAndConvertUtils::Check("crop height", crop_size[0], kGreaterThan, 0, prim_name);
|
||||
CheckAndConvertUtils::Check("crop weight", crop_size.back(), kGreaterThan, 0, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("[crop_size length]", static_cast<int64_t>(crop_size.size()), kEqual,
|
||||
kShapeRank2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("[crop height]", crop_size[0], kGreaterThan, 0, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("[crop weight]", crop_size.back(), kGreaterThan, 0, prim_name);
|
||||
ShapeVector out_shape = {num_boxes, crop_size[0], crop_size.back(), out_channel};
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
@ -106,9 +105,8 @@ class CropAndResizeInfer : public abstract::OpInferBase {
|
|||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(input_args.size() == kCropAndResizeInputSize,
|
||||
"For primitive[" + prim_name + "], the [x shape-length] should be 4, bug got " +
|
||||
std::to_string(input_args.size()) + ".");
|
||||
(void)CheckAndConvertUtils::CheckInteger("[input number]", static_cast<int64_t>(input_args.size()), kEqual,
|
||||
kCropAndResizeInputSize, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -125,37 +123,33 @@ class CropAndResizeInfer : public abstract::OpInferBase {
|
|||
protected:
|
||||
int64_t ParseNumBoxes(const ShapeVector &box_shape, const ShapeVector &box_index_shape, const std::string &prim_name,
|
||||
const std::vector<int64_t> &batch_shape) const {
|
||||
size_t batch_rank = batch_shape.size();
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(box_shape.size() == kShapeRank2 + batch_rank,
|
||||
"For primitive[" + prim_name + "], the [boxes shape-length] should be 2, bug got " +
|
||||
std::to_string(static_cast<int>(box_shape.size()) - static_cast<int>(batch_rank)) +
|
||||
".");
|
||||
auto batch_rank = static_cast<int64_t>(batch_shape.size());
|
||||
int64_t box_dims = static_cast<int64_t>(box_shape.size()) - batch_rank;
|
||||
(void)CheckAndConvertUtils::CheckInteger("[boxes shape-length]", box_dims, kEqual, kShapeRank2, prim_name);
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(
|
||||
batch_shape == std::vector<int64_t>(box_shape.begin(), box_shape.begin() + batch_rank),
|
||||
"For primitive[" + prim_name + "], the [batch_shape] of boxes is not equal to that of input.");
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(box_shape.back() == kLimitValue4, "For primitive[" + prim_name +
|
||||
"], the [boxes second-dim] must be 4, but got " +
|
||||
std::to_string(box_shape.back()) + ".");
|
||||
(void)CheckAndConvertUtils::CheckInteger("[boxes second-dim]", box_shape.back(), kEqual, kLimitValue4, prim_name);
|
||||
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(
|
||||
box_index_shape.size() == 1 + batch_rank,
|
||||
"For primitive[" + prim_name + "], the [box_index shape-length] should be 1, bug got " +
|
||||
std::to_string(static_cast<int>(box_index_shape.size()) - static_cast<int>(batch_rank)) + ".");
|
||||
int64_t box_index_dims = static_cast<int64_t>(box_index_shape.size()) - batch_rank;
|
||||
(void)CheckAndConvertUtils::CheckInteger("[box_index shape-length]", box_index_dims, kEqual, 1, prim_name);
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(
|
||||
batch_shape == std::vector<int64_t>(box_index_shape.begin(), box_index_shape.begin() + batch_rank),
|
||||
"For primitive[" + prim_name + "], the [batch_shape] of box_index is not equal to that of input.");
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(
|
||||
box_shape[batch_rank] == box_index_shape[batch_rank],
|
||||
"For primitive[" + prim_name + "], the [boxes first-dim] must be equal to [box_index first-dim], but got " +
|
||||
std::to_string(box_shape[batch_rank]) + " vs " + std::to_string(box_index_shape[batch_rank]) + ".");
|
||||
if (box_shape[batch_rank] != box_index_shape[batch_rank]) {
|
||||
MS_EXCEPTION(ValueError) << "For primitive[" + prim_name +
|
||||
"], the [boxes first-dim] must be equal to [box_index first-dim], but got " +
|
||||
std::to_string(box_shape[batch_rank]) + " vs " +
|
||||
std::to_string(box_index_shape[batch_rank]) + ".";
|
||||
}
|
||||
return box_shape[batch_rank];
|
||||
}
|
||||
|
||||
private:
|
||||
const int64_t kLimitValue4 = 4;
|
||||
const size_t kCropAndResizeInputSize = 4;
|
||||
const size_t kShapeRank2 = 2;
|
||||
const size_t kShapeRank4 = 4;
|
||||
const int64_t kCropAndResizeInputSize = 4;
|
||||
const int64_t kShapeRank2 = 2;
|
||||
const int64_t kShapeRank4 = 4;
|
||||
};
|
||||
|
||||
void CropAndResize::Init(ResizeMethod method, float extrapolation_value) {
|
||||
|
|
|
@ -1303,6 +1303,19 @@ def get_gatherd_grad_v2_vmap_rule(prim, axis_size):
|
|||
if hasattr(prim, 'dim'):
|
||||
dim = prim.dim
|
||||
|
||||
@constexpr
|
||||
def _update_attr(x_rank, batch_dim):
|
||||
pdim = dim
|
||||
if pdim < 0:
|
||||
pdim += x_rank
|
||||
if pdim < 0 or pdim >= x_rank:
|
||||
_raise_value_error(
|
||||
"The `dim` in `GatherDGradV2` must be in range [{}, {}], but got {}.".format(-x_rank, x_rank - 1, dim))
|
||||
if pdim >= batch_dim:
|
||||
_vmap_update_prim_attr(prim, 'dim', pdim + 1)
|
||||
elif dim < 0:
|
||||
_vmap_update_prim_attr(prim, 'dim', pdim)
|
||||
|
||||
def vmap_rule(x_bdim, index_bdim, grad_bdim):
|
||||
is_all_none, result = vmap_general_preprocess(prim, x_bdim, index_bdim, grad_bdim)
|
||||
if is_all_none:
|
||||
|
@ -1326,14 +1339,7 @@ def get_gatherd_grad_v2_vmap_rule(prim, axis_size):
|
|||
|
||||
# Adjust dim-attr if needed
|
||||
x_rank = F.rank(x) - 1
|
||||
pdim = dim
|
||||
if pdim < 0:
|
||||
pdim += x_rank
|
||||
if pdim < 0 or pdim >= x_rank:
|
||||
_raise_value_error(
|
||||
"The `dim` in `GatherDGradV2` must be in range [{}, {}], but got {}.".format(-x_rank, x_rank - 1, dim))
|
||||
if pdim >= batch_dim:
|
||||
_vmap_update_prim_attr(prim, 'dim', pdim + 1)
|
||||
_update_attr(x_rank, batch_dim)
|
||||
|
||||
out = prim(x, index, grad)
|
||||
return (out, batch_dim)
|
||||
|
|
Loading…
Reference in New Issue