!42059 Support DynamicShape/Rank for BCEWithLogitsLoss, StridedSlice, CumSum

Merge pull request !42059 from zhengzuohe/r1.9
This commit is contained in:
i-robot 2022-09-16 02:30:00 +00:00 committed by Gitee
commit 5c8ea33af0
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 22 additions and 11 deletions

View File

@ -358,6 +358,9 @@ void GPUKernelExecutor::FuseOperators(const KernelGraphPtr &graph) const {
namespace {
void RunOpOptimize(const KernelGraphPtr &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
if (kernel_graph->is_dynamic_shape()) {
return;
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>());

View File

@ -40,7 +40,7 @@ abstract::ShapePtr BCEWithLogitsLossInferShape(const PrimitivePtr &primitive,
auto logits_shape = logits_shape_map[kShape];
auto label_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
auto label_shape = label_shape_map[kShape];
if (!ObscureShapeEqual(logits_shape, label_shape)) {
if (!ObscureShapeEqual(logits_shape, label_shape) && !(IsDynamicRank(logits_shape) || IsDynamicRank(label_shape))) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the two input 'logits' and 'label' shape are not equal.";
}
auto weight_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape());

View File

@ -60,6 +60,9 @@ abstract::ShapePtr CumSumInferShape(const PrimitivePtr &primitive, const std::ve
if (x_shape_ptr->IsDynamic()) {
return x_shape_ptr->cast<abstract::ShapePtr>();
}
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
auto rank = SizeToLong(x_shape.size());
(void)CheckAndConvertUtils::CheckInteger("rank of 'x'", rank, kGreaterThan, 0, prim_name);
int64_t axis;
if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
@ -67,9 +70,13 @@ abstract::ShapePtr CumSumInferShape(const PrimitivePtr &primitive, const std::ve
MS_EXCEPTION_IF_NULL(axis_ptr);
auto axis_value_ptr = axis_ptr->BuildValue();
MS_EXCEPTION_IF_NULL(axis_value_ptr);
auto axis_tensor = axis_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(axis_tensor);
axis = *static_cast<int64_t *>(axis_tensor->data_c());
if (axis_value_ptr->isa<tensor::Tensor>()) {
auto axis_tensor = axis_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(axis_tensor);
axis = *static_cast<int64_t *>(axis_tensor->data_c());
} else {
return std::make_shared<abstract::Shape>(x_shape);
}
} else if (input_args[kInputIndex1]->isa<abstract::AbstractScalar>()) {
auto axis_ptr = input_args[kInputIndex1]->cast<abstract::AbstractScalarPtr>();
MS_EXCEPTION_IF_NULL(axis_ptr);
@ -79,9 +86,6 @@ abstract::ShapePtr CumSumInferShape(const PrimitivePtr &primitive, const std::ve
<< "', the second input type should be tensor or scalar, but got invalid abstract type:"
<< input_args[kInputIndex1]->type_name() << ".";
}
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
auto rank = SizeToLong(x_shape.size());
(void)CheckAndConvertUtils::CheckInteger("rank of 'x'", rank, kGreaterThan, 0, prim_name);
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-rank, rank - 1}, prim_name);
return std::make_shared<abstract::Shape>(x_shape);
}

View File

@ -354,8 +354,9 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
const size_t x_index = 0;
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[x_index]->BuildShape());
auto x_shape = shape_map[kShape];
bool x_is_dyn =
std::any_of(x_shape.begin(), x_shape.end(), [](int64_t value) { return value == abstract::Shape::SHP_ANY; });
if (IsDynamicRank(x_shape)) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
}
ShapeVector begin_v;
ShapeVector end_v;
@ -376,7 +377,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
<< ", 'strides': " << stride_len << ".";
}
bool slice_dynamic = false;
if (begin_dynamic || end_dynamic || stride_dynamic || x_is_dyn) {
if (begin_dynamic || end_dynamic || stride_dynamic || IsDynamic(x_shape)) {
slice_dynamic = true;
}
if (!slice_dynamic) {

View File

@ -25,7 +25,8 @@ from mindspore import log as logger
from mindspore import context
from mindspore.common.initializer import Zero
from mindspore.ops import signature as sig
from mindspore.ops._utils import get_broadcast_shape, is_shape_unknown
from mindspore.ops._utils import get_broadcast_shape
from mindspore.common._utils import is_shape_unknown, is_dim_unknown
from mindspore.ops.primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
@ -3553,6 +3554,8 @@ class StridedSlice(PrimitiveWithInfer):
def _compute_dynamic_slicing_shape(self, x, begin_v, end_v, strides_v, slice_len):
"""Computes the shape of the slicing for dynamic shape, mask is currently not supported."""
x_shape = x['shape']
if is_dim_unknown(x_shape):
return [-2]
x_rank = len(x_shape)
new_axis_pos = bin(self.new_axis_mask)[-1:1:-1]
shrink_axis_pos = bin(self.shrink_axis_mask)[-1:1:-1]