!41947 fix dtype bug
Merge pull request !41947 from cjh9368/op_develop2
This commit is contained in:
commit
e100504d6a
|
@ -55,7 +55,9 @@ int SoftmaxCrossEntropyWithLogitsCpuKernelMod::Resize(const BaseOperatorPtr &bas
|
|||
if (batch_size_ == 0 || class_num_ == 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid batch size or class num input!";
|
||||
}
|
||||
auto mem_desc = CreateDesc<dnnl::memory::desc>(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
|
||||
auto dnnl_dtype =
|
||||
(inputs.at(0)->GetDtype() == kNumberTypeFloat32) ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::f16;
|
||||
auto mem_desc = CreateDesc<dnnl::memory::desc>(mem_dims, dnnl_dtype, dnnl::memory::format_tag::nc);
|
||||
|
||||
auto desc = CreateDesc<dnnl::softmax_forward::desc>(dnnl::prop_kind::forward_training, mem_desc, 1);
|
||||
auto prim_desc = CreateDesc<dnnl::softmax_forward::primitive_desc>(desc, engine_);
|
||||
|
@ -64,8 +66,7 @@ int SoftmaxCrossEntropyWithLogitsCpuKernelMod::Resize(const BaseOperatorPtr &bas
|
|||
AddArgument(DNNL_ARG_SRC, mem_desc);
|
||||
AddArgument(DNNL_ARG_DST, mem_desc);
|
||||
|
||||
workspace_size_list_.clear();
|
||||
size_t type_size = sizeof(float);
|
||||
size_t type_size = (inputs.at(0)->GetDtype() == kNumberTypeFloat32) ? sizeof(float) : sizeof(float16);
|
||||
size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
|
||||
(void)workspace_size_list_.emplace_back(tensor_size);
|
||||
return KRET_OK;
|
||||
|
|
|
@ -40,6 +40,10 @@ abstract::ShapePtr BCEWithLogitsLossInferShape(const PrimitivePtr &primitive,
|
|||
auto logits_shape = logits_shape_map[kShape];
|
||||
auto label_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
|
||||
auto label_shape = label_shape_map[kShape];
|
||||
if (IsDynamicRank(logits_shape) || IsDynamicRank(label_shape)) {
|
||||
auto ds_shape = std::vector<int64_t>{UNKNOWN_RANK};
|
||||
return std::make_shared<abstract::Shape>(ds_shape);
|
||||
}
|
||||
if (!ObscureShapeEqual(logits_shape, label_shape) && !(IsDynamicRank(logits_shape) || IsDynamicRank(label_shape))) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the two input 'logits' and 'label' shape are not equal.";
|
||||
}
|
||||
|
|
|
@ -39,6 +39,12 @@ class SigmoidCrossEntropyWithLogitsInfer : public abstract::OpInferBase {
|
|||
auto label_shape = input_args[1]->BuildShape();
|
||||
auto logits_shape_ptr = logits_shape->cast<abstract::ShapePtr>();
|
||||
auto label_shape_ptr = label_shape->cast<abstract::ShapePtr>();
|
||||
auto logits_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(logits_shape)[kShape];
|
||||
auto label_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(label_shape)[kShape];
|
||||
if (IsDynamicRank(logits_map) || IsDynamicRank(label_map)) {
|
||||
auto ds_shape_ptr = std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
|
||||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{ds_shape_ptr, ds_shape_ptr});
|
||||
}
|
||||
// logits and label must have the same shape when neither of them is dynamic
|
||||
if (!logits_shape_ptr->IsDynamic() && !label_shape_ptr->IsDynamic()) {
|
||||
if (*logits_shape != *label_shape) {
|
||||
|
|
|
@ -37,14 +37,14 @@ class SoftmaxCrossEntropyWithLogitsInfer : public abstract::OpInferBase {
|
|||
auto logits_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(logits_shape)[kShape];
|
||||
auto label_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(label_shape)[kShape];
|
||||
const int64_t input_rank = 2;
|
||||
if (!IsDynamicRank(logits_map)) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("dimension of logits", SizeToLong(logits_map.size()), kEqual, input_rank,
|
||||
prim_name);
|
||||
}
|
||||
if (!IsDynamicRank(label_map)) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("dimension of labels", SizeToLong(label_map.size()), kEqual, input_rank,
|
||||
prim_name);
|
||||
if (IsDynamicRank(logits_map) || IsDynamicRank(label_map)) {
|
||||
auto ds_shape_ptr = std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
|
||||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{ds_shape_ptr, ds_shape_ptr});
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("dimension of logits", SizeToLong(logits_map.size()), kEqual, input_rank,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("dimension of labels", SizeToLong(label_map.size()), kEqual, input_rank,
|
||||
prim_name);
|
||||
auto logits_shape_ptr = logits_shape->cast<abstract::ShapePtr>();
|
||||
auto label_shape_ptr = label_shape->cast<abstract::ShapePtr>();
|
||||
// logits and label must have the same shape when neither of them is dynamic
|
||||
|
|
Loading…
Reference in New Issue