fix bug of GatherDGradV2-vmap when dim is negative

This commit is contained in:
xuanyue 2022-09-19 10:17:56 +08:00
parent b4402302df
commit 57ca6777bf
2 changed files with 47 additions and 47 deletions

View File

@ -34,9 +34,8 @@ class CropAndResizeInfer : public abstract::OpInferBase {
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
MS_EXCEPTION_IF_CHECK_FAIL(
input_args.size() == kCropAndResizeInputSize,
"For primitive[" + prim_name + "], [input number] must be 4 but got " + std::to_string(input_args.size()));
(void)CheckAndConvertUtils::CheckInteger("[input number]", static_cast<int64_t>(input_args.size()), kEqual,
kCropAndResizeInputSize, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -58,13 +57,12 @@ class CropAndResizeInfer : public abstract::OpInferBase {
std::vector<int64_t>{UNKNOWN_DIM, UNKNOWN_DIM, UNKNOWN_DIM, UNKNOWN_DIM});
}
size_t batch_rank = 0;
int64_t batch_rank = 0;
if (primitive->HasAttr(kBatchRank)) {
batch_rank = GetValue<int64_t>(primitive->GetAttr(kBatchRank));
}
MS_EXCEPTION_IF_CHECK_FAIL(x_shape.size() == kShapeRank4 + batch_rank,
"For primitive[" + prim_name + "], the [x shape-length] should be 4, bug got " +
std::to_string(static_cast<int>(x_shape.size()) - static_cast<int>(batch_rank)) + ".");
int64_t x_dims = static_cast<int64_t>(x_shape.size()) - batch_rank;
(void)CheckAndConvertUtils::CheckInteger("[x shape-length]", x_dims, kEqual, kShapeRank4, prim_name);
int64_t out_channel = x_shape.back();
std::vector<int64_t> batch_shape(x_shape.begin(), x_shape.begin() + static_cast<int>(batch_rank));
@ -78,8 +76,7 @@ class CropAndResizeInfer : public abstract::OpInferBase {
crop_size = CheckAndConvertUtils::CheckTensorIntValue("crop_size", value_ptr, prim_name);
} else if (IsIdentidityOrSubclass(crop_size_type, kTuple)) {
auto value_tuple = value_ptr->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_CHECK_FAIL(value_tuple != nullptr,
"For primitive[" + prim_name + "], the [crop_size] must a Tuple.");
MS_EXCEPTION_IF_NULL(value_tuple);
auto &elements = value_tuple->value();
for (const auto &element : elements) {
if (element->isa<Int64Imm>()) {
@ -87,18 +84,20 @@ class CropAndResizeInfer : public abstract::OpInferBase {
} else {
auto type = element->type();
std::string real_type_str = type == nullptr ? "Unknown." : type->ToString() + ".";
MS_LOG(EXCEPTION) << ("For primitive[" + prim_name +
"], the [crop_size] must be a tuple with two Int elements, but got " + real_type_str);
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
<< "], the [crop_size] must be a tuple with two Int elements, but got "
<< real_type_str;
}
}
} else {
MS_LOG(EXCEPTION) << ("For primitive[" + prim_name +
"], the [crop_size] is must be a Tensor or a Tuple with two Int elements, but got " +
crop_size_type->ToString());
MS_EXCEPTION(TypeError) << "For primitive[" + prim_name
<< "], the [crop_size] is must be a Tensor or a Tuple with two Int elements, but got "
<< crop_size_type->ToString();
}
CheckAndConvertUtils::Check("crop_size length", crop_size.size(), kEqual, kShapeRank2, prim_name);
CheckAndConvertUtils::Check("crop height", crop_size[0], kGreaterThan, 0, prim_name);
CheckAndConvertUtils::Check("crop weight", crop_size.back(), kGreaterThan, 0, prim_name);
(void)CheckAndConvertUtils::CheckInteger("[crop_size length]", static_cast<int64_t>(crop_size.size()), kEqual,
kShapeRank2, prim_name);
(void)CheckAndConvertUtils::CheckInteger("[crop height]", crop_size[0], kGreaterThan, 0, prim_name);
(void)CheckAndConvertUtils::CheckInteger("[crop weight]", crop_size.back(), kGreaterThan, 0, prim_name);
ShapeVector out_shape = {num_boxes, crop_size[0], crop_size.back(), out_channel};
return std::make_shared<abstract::Shape>(out_shape);
}
@ -106,9 +105,8 @@ class CropAndResizeInfer : public abstract::OpInferBase {
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
MS_EXCEPTION_IF_CHECK_FAIL(input_args.size() == kCropAndResizeInputSize,
"For primitive[" + prim_name + "], the [x shape-length] should be 4, bug got " +
std::to_string(input_args.size()) + ".");
(void)CheckAndConvertUtils::CheckInteger("[input number]", static_cast<int64_t>(input_args.size()), kEqual,
kCropAndResizeInputSize, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -125,37 +123,33 @@ class CropAndResizeInfer : public abstract::OpInferBase {
protected:
int64_t ParseNumBoxes(const ShapeVector &box_shape, const ShapeVector &box_index_shape, const std::string &prim_name,
const std::vector<int64_t> &batch_shape) const {
size_t batch_rank = batch_shape.size();
MS_EXCEPTION_IF_CHECK_FAIL(box_shape.size() == kShapeRank2 + batch_rank,
"For primitive[" + prim_name + "], the [boxes shape-length] should be 2, bug got " +
std::to_string(static_cast<int>(box_shape.size()) - static_cast<int>(batch_rank)) +
".");
auto batch_rank = static_cast<int64_t>(batch_shape.size());
int64_t box_dims = static_cast<int64_t>(box_shape.size()) - batch_rank;
(void)CheckAndConvertUtils::CheckInteger("[boxes shape-length]", box_dims, kEqual, kShapeRank2, prim_name);
MS_EXCEPTION_IF_CHECK_FAIL(
batch_shape == std::vector<int64_t>(box_shape.begin(), box_shape.begin() + batch_rank),
"For primitive[" + prim_name + "], the [batch_shape] of boxes is not equal to that of input.");
MS_EXCEPTION_IF_CHECK_FAIL(box_shape.back() == kLimitValue4, "For primitive[" + prim_name +
"], the [boxes second-dim] must be 4, but got " +
std::to_string(box_shape.back()) + ".");
(void)CheckAndConvertUtils::CheckInteger("[boxes second-dim]", box_shape.back(), kEqual, kLimitValue4, prim_name);
MS_EXCEPTION_IF_CHECK_FAIL(
box_index_shape.size() == 1 + batch_rank,
"For primitive[" + prim_name + "], the [box_index shape-length] should be 1, bug got " +
std::to_string(static_cast<int>(box_index_shape.size()) - static_cast<int>(batch_rank)) + ".");
int64_t box_index_dims = static_cast<int64_t>(box_index_shape.size()) - batch_rank;
(void)CheckAndConvertUtils::CheckInteger("[box_index shape-length]", box_index_dims, kEqual, 1, prim_name);
MS_EXCEPTION_IF_CHECK_FAIL(
batch_shape == std::vector<int64_t>(box_index_shape.begin(), box_index_shape.begin() + batch_rank),
"For primitive[" + prim_name + "], the [batch_shape] of box_index is not equal to that of input.");
MS_EXCEPTION_IF_CHECK_FAIL(
box_shape[batch_rank] == box_index_shape[batch_rank],
"For primitive[" + prim_name + "], the [boxes first-dim] must be equal to [box_index first-dim], but got " +
std::to_string(box_shape[batch_rank]) + " vs " + std::to_string(box_index_shape[batch_rank]) + ".");
if (box_shape[batch_rank] != box_index_shape[batch_rank]) {
MS_EXCEPTION(ValueError) << "For primitive[" + prim_name +
"], the [boxes first-dim] must be equal to [box_index first-dim], but got " +
std::to_string(box_shape[batch_rank]) + " vs " +
std::to_string(box_index_shape[batch_rank]) + ".";
}
return box_shape[batch_rank];
}
private:
const int64_t kLimitValue4 = 4;
const size_t kCropAndResizeInputSize = 4;
const size_t kShapeRank2 = 2;
const size_t kShapeRank4 = 4;
const int64_t kCropAndResizeInputSize = 4;
const int64_t kShapeRank2 = 2;
const int64_t kShapeRank4 = 4;
};
void CropAndResize::Init(ResizeMethod method, float extrapolation_value) {

View File

@ -1303,6 +1303,19 @@ def get_gatherd_grad_v2_vmap_rule(prim, axis_size):
if hasattr(prim, 'dim'):
dim = prim.dim
@constexpr
def _update_attr(x_rank, batch_dim):
pdim = dim
if pdim < 0:
pdim += x_rank
if pdim < 0 or pdim >= x_rank:
_raise_value_error(
"The `dim` in `GatherDGradV2` must be in range [{}, {}], but got {}.".format(-x_rank, x_rank - 1, dim))
if pdim >= batch_dim:
_vmap_update_prim_attr(prim, 'dim', pdim + 1)
elif dim < 0:
_vmap_update_prim_attr(prim, 'dim', pdim)
def vmap_rule(x_bdim, index_bdim, grad_bdim):
is_all_none, result = vmap_general_preprocess(prim, x_bdim, index_bdim, grad_bdim)
if is_all_none:
@ -1326,14 +1339,7 @@ def get_gatherd_grad_v2_vmap_rule(prim, axis_size):
# Adjust dim-attr if needed
x_rank = F.rank(x) - 1
pdim = dim
if pdim < 0:
pdim += x_rank
if pdim < 0 or pdim >= x_rank:
_raise_value_error(
"The `dim` in `GatherDGradV2` must be in range [{}, {}], but got {}.".format(-x_rank, x_rank - 1, dim))
if pdim >= batch_dim:
_vmap_update_prim_attr(prim, 'dim', pdim + 1)
_update_attr(x_rank, batch_dim)
out = prim(x, index, grad)
return (out, batch_dim)