!42059 Support DynamicShape/Rank for BCEWithLogitsLoss, StridedSlice, CumSum
Merge pull request !42059 from zhengzuohe/r1.9
commit 5c8ea33af0

@@ -358,6 +358,9 @@ void GPUKernelExecutor::FuseOperators(const KernelGraphPtr &graph) const {
 namespace {
 void RunOpOptimize(const KernelGraphPtr &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
+  if (kernel_graph->is_dynamic_shape()) {
+    return;
+  }
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
   pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>());
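
The added lines make RunOpOptimize a no-op for dynamic-shape graphs, so shape-pattern fusion passes such as BCEWithLogitsLossFusion never run against shapes that are only known at runtime. A minimal Python sketch of the same guard pattern follows; the class and method names are illustrative, not MindSpore's actual API.

# Hypothetical sketch of the early-return guard added above.
class FusionPipeline:
    def __init__(self, passes):
        self.passes = passes

    def run_op_optimize(self, kernel_graph):
        # Fusion patterns match concrete shapes, so skip the whole
        # pipeline when the graph carries dynamic shapes.
        if kernel_graph.is_dynamic_shape():
            return
        for fusion_pass in self.passes:
            fusion_pass(kernel_graph)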

@@ -40,7 +40,7 @@ abstract::ShapePtr BCEWithLogitsLossInferShape(const PrimitivePtr &primitive,
   auto logits_shape = logits_shape_map[kShape];
   auto label_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
   auto label_shape = label_shape_map[kShape];
-  if (!ObscureShapeEqual(logits_shape, label_shape)) {
+  if (!ObscureShapeEqual(logits_shape, label_shape) && !(IsDynamicRank(logits_shape) || IsDynamicRank(label_shape))) {
     MS_EXCEPTION(ValueError) << "For '" << op_name << "', the two input 'logits' and 'label' shape are not equal.";
   }
   auto weight_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape());
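
The relaxed check lets shape validation pass when either input has a dynamic rank, deferring the real comparison to runtime. Below is a self-contained Python sketch of the logic, using MindSpore's shape conventions (-1 marks an unknown dimension, a single -2 an unknown rank); the helper semantics are assumptions mirroring the C++ names, not the real implementations.

UNKNOWN_DIM, UNKNOWN_RANK = -1, -2

def is_dynamic_rank(shape):
    # A dynamic-rank shape is conventionally written as [-2].
    return len(shape) == 1 and shape[0] == UNKNOWN_RANK

def obscure_shape_equal(a, b):
    # Assumed semantics of ObscureShapeEqual: dims match if equal,
    # or if either side is the unknown-dimension placeholder -1.
    return len(a) == len(b) and all(
        x == y or UNKNOWN_DIM in (x, y) for x, y in zip(a, b))

def check_logits_label(logits_shape, label_shape):
    if not obscure_shape_equal(logits_shape, label_shape) and not (
            is_dynamic_rank(logits_shape) or is_dynamic_rank(label_shape)):
        raise ValueError("'logits' and 'label' shapes are not equal")

check_logits_label([4, -1], [4, 3])  # passes: -1 matches any dim
check_logits_label([-2], [4, 3])     # passes after this fix: rank unknown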

@@ -60,6 +60,9 @@ abstract::ShapePtr CumSumInferShape(const PrimitivePtr &primitive, const std::ve
   if (x_shape_ptr->IsDynamic()) {
     return x_shape_ptr->cast<abstract::ShapePtr>();
   }
+  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
+  auto rank = SizeToLong(x_shape.size());
+  (void)CheckAndConvertUtils::CheckInteger("rank of 'x'", rank, kGreaterThan, 0, prim_name);
 
   int64_t axis;
   if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {

@@ -67,9 +70,13 @@ abstract::ShapePtr CumSumInferShape(const PrimitivePtr &primitive, const std::ve
     MS_EXCEPTION_IF_NULL(axis_ptr);
     auto axis_value_ptr = axis_ptr->BuildValue();
     MS_EXCEPTION_IF_NULL(axis_value_ptr);
+    if (axis_value_ptr->isa<tensor::Tensor>()) {
       auto axis_tensor = axis_value_ptr->cast<tensor::TensorPtr>();
       MS_EXCEPTION_IF_NULL(axis_tensor);
       axis = *static_cast<int64_t *>(axis_tensor->data_c());
+    } else {
+      return std::make_shared<abstract::Shape>(x_shape);
+    }
   } else if (input_args[kInputIndex1]->isa<abstract::AbstractScalar>()) {
     auto axis_ptr = input_args[kInputIndex1]->cast<abstract::AbstractScalarPtr>();
     MS_EXCEPTION_IF_NULL(axis_ptr);

@@ -79,9 +86,6 @@ abstract::ShapePtr CumSumInferShape(const PrimitivePtr &primitive, const std::ve
                       << "', the second input type should be tensor or scalar, but got invalid abstract type:"
                       << input_args[kInputIndex1]->type_name() << ".";
   }
-  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
-  auto rank = SizeToLong(x_shape.size());
-  (void)CheckAndConvertUtils::CheckInteger("rank of 'x'", rank, kGreaterThan, 0, prim_name);
   CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-rank, rank - 1}, prim_name);
   return std::make_shared<abstract::Shape>(x_shape);
 }
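
Taken together, the three CumSum hunks hoist the x_shape/rank computation ahead of the axis parsing and add a fallback that returns the already-known output shape when the axis value is not yet a concrete tensor. A rough Python rendering of the resulting control flow (helper names are illustrative, not MindSpore's API):

def cumsum_infer_shape(x_shape, axis=None):
    # CumSum's output shape always equals its input shape.
    if -1 in x_shape or x_shape == [-2]:
        return x_shape  # dynamic dim/rank: pass through as-is
    rank = len(x_shape)
    if rank <= 0:
        raise ValueError("rank of 'x' must be greater than 0")
    if axis is None:
        # Axis value only known at runtime: the shape is still fully
        # determined, so skip the range check and return it.
        return x_shape
    if not -rank <= axis <= rank - 1:
        raise ValueError(f"'axis' must be in [{-rank}, {rank - 1}]")
    return x_shape

print(cumsum_infer_shape([2, 3, 4], axis=1))  # [2, 3, 4]
print(cumsum_infer_shape([-1, 3]))            # [-1, 3], dynamic passthrough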

@@ -354,8 +354,9 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
   const size_t x_index = 0;
   auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[x_index]->BuildShape());
   auto x_shape = shape_map[kShape];
-  bool x_is_dyn =
-    std::any_of(x_shape.begin(), x_shape.end(), [](int64_t value) { return value == abstract::Shape::SHP_ANY; });
+  if (IsDynamicRank(x_shape)) {
+    return std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
+  }
 
   ShapeVector begin_v;
   ShapeVector end_v;

@@ -376,7 +377,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
                       << ", 'strides': " << stride_len << ".";
   }
   bool slice_dynamic = false;
-  if (begin_dynamic || end_dynamic || stride_dynamic || x_is_dyn) {
+  if (begin_dynamic || end_dynamic || stride_dynamic || IsDynamic(x_shape)) {
     slice_dynamic = true;
   }
   if (!slice_dynamic) {
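
The StridedSlice changes replace the local x_is_dyn flag, which only scanned for the unknown-dimension placeholder SHP_ANY (-1), with the shared helpers: IsDynamicRank short-circuits to an unknown-rank result, and IsDynamic also covers the dynamic-rank case when deciding between the static and dynamic slicing paths. A small Python sketch of that predicate (values follow MindSpore's -1/-2 convention; names are illustrative):

UNKNOWN_DIM, UNKNOWN_RANK = -1, -2

def is_dynamic(shape):
    # Covers both an unknown dimension and an unknown rank; the old
    # x_is_dyn flag above only tested for -1 and missed [-2].
    return UNKNOWN_DIM in shape or shape == [UNKNOWN_RANK]

def slice_is_dynamic(x_shape, begin_dynamic, end_dynamic, stride_dynamic):
    return begin_dynamic or end_dynamic or stride_dynamic or is_dynamic(x_shape)

print(slice_is_dynamic([4, -1], False, False, False))  # True
print(slice_is_dynamic([-2], False, False, False))     # True
print(slice_is_dynamic([4, 3], False, False, False))   # False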

@@ -25,7 +25,8 @@ from mindspore import log as logger
 from mindspore import context
 from mindspore.common.initializer import Zero
 from mindspore.ops import signature as sig
-from mindspore.ops._utils import get_broadcast_shape, is_shape_unknown
+from mindspore.ops._utils import get_broadcast_shape
+from mindspore.common._utils import is_shape_unknown, is_dim_unknown
 from mindspore.ops.primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
 from mindspore._checkparam import Rel
 from mindspore._checkparam import Validator as validator

@@ -3553,6 +3554,8 @@ class StridedSlice(PrimitiveWithInfer):
     def _compute_dynamic_slicing_shape(self, x, begin_v, end_v, strides_v, slice_len):
         """Computes the shape of the slicing for dynamic shape, mask is currently not supported."""
         x_shape = x['shape']
+        if is_dim_unknown(x_shape):
+            return [-2]
         x_rank = len(x_shape)
         new_axis_pos = bin(self.new_axis_mask)[-1:1:-1]
         shrink_axis_pos = bin(self.shrink_axis_mask)[-1:1:-1]
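
On the Python side, the early return mirrors the C++ change: a rank-unknown input yields the rank-unknown sentinel [-2] before any mask handling. The snippet below unpacks the two idioms used here, with is_dim_unknown's semantics assumed from its usage rather than taken from mindspore.common._utils:

def is_dim_unknown(shape):
    # Assumed semantics: some dimension is the rank-unknown sentinel -2,
    # i.e. the overall rank of the shape is not yet known.
    return any(d == -2 for d in shape)

def mask_positions(mask):
    # bin(6) == '0b110'; [-1:1:-1] strips the '0b' prefix and reverses
    # the digits, so character i tells whether bit i (axis i) is set.
    return bin(mask)[-1:1:-1]

print(is_dim_unknown([-2]))   # True  -> slicing shape collapses to [-2]
print(mask_positions(0b110))  # '011' -> axes 1 and 2 are masked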