From c4537b63436f237e5d14ceb28cc9b56ff7ce1ba5 Mon Sep 17 00:00:00 2001 From: l00591931 Date: Thu, 22 Apr 2021 09:33:32 +0800 Subject: [PATCH] fix shape operator --- mindspore/core/ops/shape.cc | 2 ++ mindspore/core/ops/softmax.cc | 2 ++ 2 files changed, 4 insertions(+) diff --git a/mindspore/core/ops/shape.cc b/mindspore/core/ops/shape.cc index dd61504a2db..e9be46d02fe 100644 --- a/mindspore/core/ops/shape.cc +++ b/mindspore/core/ops/shape.cc @@ -36,6 +36,8 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); auto in_shape = shape_map[kShape]; // infer type + std::set valid_params_types = {kTensorType}; + CheckAndConvertUtils::CheckSubClass("shape type", input_args[0]->BuildType(), valid_params_types, op_name); AbstractBasePtrList abs_list; std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list), [](int64_t item) -> std::shared_ptr { diff --git a/mindspore/core/ops/softmax.cc b/mindspore/core/ops/softmax.cc index 5026a4712f3..2be3f1fb6fb 100644 --- a/mindspore/core/ops/softmax.cc +++ b/mindspore/core/ops/softmax.cc @@ -44,6 +44,7 @@ void Softmax::Init(const int64_t axis) { this->set_axis(axis_vec); } +namespace { abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); @@ -74,6 +75,7 @@ TypePtr SoftMaxInferType(const PrimitivePtr &prim, const std::vector valid_types = {kFloat16, kFloat32, kFloat64}; return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name()); } +} // namespace AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) {