!41947 fix dtype bug

Merge pull request !41947 from cjh9368/op_develop2
This commit is contained in:
i-robot 2022-09-17 02:00:27 +00:00 committed by Gitee
commit e100504d6a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 21 additions and 10 deletions

View File

@ -55,7 +55,9 @@ int SoftmaxCrossEntropyWithLogitsCpuKernelMod::Resize(const BaseOperatorPtr &bas
if (batch_size_ == 0 || class_num_ == 0) {
MS_LOG(EXCEPTION) << "Invalid batch size or class num input!";
}
auto mem_desc = CreateDesc<dnnl::memory::desc>(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
auto dnnl_dtype =
(inputs.at(0)->GetDtype() == kNumberTypeFloat32) ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::f16;
auto mem_desc = CreateDesc<dnnl::memory::desc>(mem_dims, dnnl_dtype, dnnl::memory::format_tag::nc);
auto desc = CreateDesc<dnnl::softmax_forward::desc>(dnnl::prop_kind::forward_training, mem_desc, 1);
auto prim_desc = CreateDesc<dnnl::softmax_forward::primitive_desc>(desc, engine_);
@ -64,8 +66,7 @@ int SoftmaxCrossEntropyWithLogitsCpuKernelMod::Resize(const BaseOperatorPtr &bas
AddArgument(DNNL_ARG_SRC, mem_desc);
AddArgument(DNNL_ARG_DST, mem_desc);
workspace_size_list_.clear();
size_t type_size = sizeof(float);
size_t type_size = (inputs.at(0)->GetDtype() == kNumberTypeFloat32) ? sizeof(float) : sizeof(float16);
size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
(void)workspace_size_list_.emplace_back(tensor_size);
return KRET_OK;

View File

@ -40,6 +40,10 @@ abstract::ShapePtr BCEWithLogitsLossInferShape(const PrimitivePtr &primitive,
auto logits_shape = logits_shape_map[kShape];
auto label_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
auto label_shape = label_shape_map[kShape];
if (IsDynamicRank(logits_shape) || IsDynamicRank(label_shape)) {
auto ds_shape = std::vector<int64_t>{UNKNOWN_RANK};
return std::make_shared<abstract::Shape>(ds_shape);
}
if (!ObscureShapeEqual(logits_shape, label_shape) && !(IsDynamicRank(logits_shape) || IsDynamicRank(label_shape))) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the two input 'logits' and 'label' shape are not equal.";
}

View File

@ -39,6 +39,12 @@ class SigmoidCrossEntropyWithLogitsInfer : public abstract::OpInferBase {
auto label_shape = input_args[1]->BuildShape();
auto logits_shape_ptr = logits_shape->cast<abstract::ShapePtr>();
auto label_shape_ptr = label_shape->cast<abstract::ShapePtr>();
auto logits_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(logits_shape)[kShape];
auto label_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(label_shape)[kShape];
if (IsDynamicRank(logits_map) || IsDynamicRank(label_map)) {
auto ds_shape_ptr = std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{ds_shape_ptr, ds_shape_ptr});
}
// logits and label must have the same shape when is not dynamic
if (!logits_shape_ptr->IsDynamic() && !label_shape_ptr->IsDynamic()) {
if (*logits_shape != *label_shape) {

View File

@ -37,14 +37,14 @@ class SoftmaxCrossEntropyWithLogitsInfer : public abstract::OpInferBase {
auto logits_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(logits_shape)[kShape];
auto label_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(label_shape)[kShape];
const int64_t input_rank = 2;
if (!IsDynamicRank(logits_map)) {
(void)CheckAndConvertUtils::CheckInteger("dimension of logits", SizeToLong(logits_map.size()), kEqual, input_rank,
prim_name);
}
if (!IsDynamicRank(label_map)) {
(void)CheckAndConvertUtils::CheckInteger("dimension of labels", SizeToLong(label_map.size()), kEqual, input_rank,
prim_name);
if (IsDynamicRank(logits_map) || IsDynamicRank(label_map)) {
auto ds_shape_ptr = std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{ds_shape_ptr, ds_shape_ptr});
}
(void)CheckAndConvertUtils::CheckInteger("dimension of logits", SizeToLong(logits_map.size()), kEqual, input_rank,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("dimension of labels", SizeToLong(label_map.size()), kEqual, input_rank,
prim_name);
auto logits_shape_ptr = logits_shape->cast<abstract::ShapePtr>();
auto label_shape_ptr = label_shape->cast<abstract::ShapePtr>();
// logits and label must have the same shape when is not dynamic