forked from mindspore-Ecosystem/mindspore
nllloss bugfix for null input
This commit is contained in:
parent
c29a919b72
commit
c70beba908
|
@ -32,11 +32,18 @@ class NLLLossInfer : public abstract::OpInferBase {
|
|||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const auto prim_name = primitive->name();
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
|
||||
prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
|
||||
auto logits_shape_ptr = input_args[kInputIndex0]->BuildShape();
|
||||
auto logits_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(logits_shape_ptr)[kShape];
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
|
||||
auto target_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
||||
auto target_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(target_shape_ptr)[kShape];
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]);
|
||||
auto weight_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(weight_shape_ptr)[kShape];
|
||||
|
||||
|
@ -83,6 +90,10 @@ class NLLLossInfer : public abstract::OpInferBase {
|
|||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
const std::set valid_types = {kFloat16, kFloat32};
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t input_num = 3;
|
||||
CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
auto logits_data_type = input_args[kIndex0]->BuildType();
|
||||
auto target_type = input_args[kIndex1]->BuildType();
|
||||
auto weight_data_type = input_args[kIndex2]->BuildType();
|
||||
|
|
Loading…
Reference in New Issue