From 45ab1edb6767f599484454ae759963ea5cd11c6e Mon Sep 17 00:00:00 2001
From: cjh9368
Date: Wed, 14 Sep 2022 14:42:08 +0800
Subject: [PATCH] support dynamic shape ops

---
 ...softmax_cross_entropy_with_logits_cpu_kernel.cc |  7 ++++---
 mindspore/core/ops/bce_with_logits_loss.cc         |  4 ++++
 .../core/ops/sigmoid_cross_entropy_with_logits.cc  |  6 ++++++
 .../core/ops/softmax_cross_entropy_with_logits.cc  | 14 +++++++-------
 4 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc
index 7966059e64c..ce8d35a5182 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc
@@ -55,7 +55,9 @@ int SoftmaxCrossEntropyWithLogitsCpuKernelMod::Resize(const BaseOperatorPtr &bas
   if (batch_size_ == 0 || class_num_ == 0) {
     MS_LOG(EXCEPTION) << "Invalid batch size or class num input!";
   }
-  auto mem_desc = CreateDesc<dnnl::memory::desc>(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
+  auto dnnl_dtype =
+    (inputs.at(0)->GetDtype() == kNumberTypeFloat32) ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::f16;
+  auto mem_desc = CreateDesc<dnnl::memory::desc>(mem_dims, dnnl_dtype, dnnl::memory::format_tag::nc);
   auto desc = CreateDesc<dnnl::softmax_forward::desc>(dnnl::prop_kind::forward_training, mem_desc, 1);
   auto prim_desc = CreateDesc<dnnl::softmax_forward::primitive_desc>(desc, engine_);
 
@@ -64,8 +66,7 @@ int SoftmaxCrossEntropyWithLogitsCpuKernelMod::Resize(const BaseOperatorPtr &bas
   AddArgument(DNNL_ARG_SRC, mem_desc);
   AddArgument(DNNL_ARG_DST, mem_desc);
 
-  workspace_size_list_.clear();
-  size_t type_size = sizeof(float);
+  size_t type_size = (inputs.at(0)->GetDtype() == kNumberTypeFloat32) ? sizeof(float) : sizeof(float16);
   size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
   (void)workspace_size_list_.emplace_back(tensor_size);
   return KRET_OK;
diff --git a/mindspore/core/ops/bce_with_logits_loss.cc b/mindspore/core/ops/bce_with_logits_loss.cc
index af259e88b11..fdaab2dc030 100644
--- a/mindspore/core/ops/bce_with_logits_loss.cc
+++ b/mindspore/core/ops/bce_with_logits_loss.cc
@@ -40,6 +40,10 @@ abstract::ShapePtr BCEWithLogitsLossInferShape(const PrimitivePtr &primitive,
   auto logits_shape = logits_shape_map[kShape];
   auto label_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
   auto label_shape = label_shape_map[kShape];
+  if (IsDynamicRank(logits_shape) || IsDynamicRank(label_shape)) {
+    auto ds_shape = std::vector<int64_t>{UNKNOWN_RANK};
+    return std::make_shared<abstract::Shape>(ds_shape);
+  }
   if (!ObscureShapeEqual(logits_shape, label_shape) && !(IsDynamicRank(logits_shape) || IsDynamicRank(label_shape))) {
     MS_EXCEPTION(ValueError) << "For '" << op_name << "', the two input 'logits' and 'label' shape are not equal.";
   }
diff --git a/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc b/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc
index f7f452fcb4c..46692b0e505 100644
--- a/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc
+++ b/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc
@@ -39,6 +39,12 @@ class SigmoidCrossEntropyWithLogitsInfer : public abstract::OpInferBase {
     auto label_shape = input_args[1]->BuildShape();
     auto logits_shape_ptr = logits_shape->cast<abstract::ShapePtr>();
     auto label_shape_ptr = label_shape->cast<abstract::ShapePtr>();
+    auto logits_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(logits_shape)[kShape];
+    auto label_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(label_shape)[kShape];
+    if (IsDynamicRank(logits_map) || IsDynamicRank(label_map)) {
+      auto ds_shape_ptr = std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
+      return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{ds_shape_ptr, ds_shape_ptr});
+    }
     // logits and label must have the same shape when is not dynamic
     if (!logits_shape_ptr->IsDynamic() && !label_shape_ptr->IsDynamic()) {
       if (*logits_shape != *label_shape) {
diff --git a/mindspore/core/ops/softmax_cross_entropy_with_logits.cc b/mindspore/core/ops/softmax_cross_entropy_with_logits.cc
index 811d1d207d1..791b21b15a5 100644
--- a/mindspore/core/ops/softmax_cross_entropy_with_logits.cc
+++ b/mindspore/core/ops/softmax_cross_entropy_with_logits.cc
@@ -37,14 +37,14 @@ class SoftmaxCrossEntropyWithLogitsInfer : public abstract::OpInferBase {
     auto logits_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(logits_shape)[kShape];
     auto label_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(label_shape)[kShape];
     const int64_t input_rank = 2;
-    if (!IsDynamicRank(logits_map)) {
-      (void)CheckAndConvertUtils::CheckInteger("dimension of logits", SizeToLong(logits_map.size()), kEqual, input_rank,
-                                               prim_name);
-    }
-    if (!IsDynamicRank(label_map)) {
-      (void)CheckAndConvertUtils::CheckInteger("dimension of labels", SizeToLong(label_map.size()), kEqual, input_rank,
-                                               prim_name);
+    if (IsDynamicRank(logits_map) || IsDynamicRank(label_map)) {
+      auto ds_shape_ptr = std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
+      return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{ds_shape_ptr, ds_shape_ptr});
     }
+    (void)CheckAndConvertUtils::CheckInteger("dimension of logits", SizeToLong(logits_map.size()), kEqual, input_rank,
+                                             prim_name);
+    (void)CheckAndConvertUtils::CheckInteger("dimension of labels", SizeToLong(label_map.size()), kEqual, input_rank,
+                                             prim_name);
     auto logits_shape_ptr = logits_shape->cast<abstract::ShapePtr>();
     auto label_shape_ptr = label_shape->cast<abstract::ShapePtr>();
     // logits and label must have the same shape when is not dynamic