add check of batchnorm args size

This commit is contained in:
tangxl 2023-02-09 20:39:29 +08:00
parent 9950d262c0
commit c0a3d589fd
2 changed files with 3 additions and 1 deletions

View File

@ -119,6 +119,7 @@ class BatchNormInfer : public abstract::OpInferBase {
BaseShapePtr InferShape(const PrimitivePtr &primitive, BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override { const std::vector<AbstractBasePtr> &input_args) const override {
const auto prim_name = primitive->name(); const auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterThan, 0, prim_name);
auto x_shape_ptr = input_args[kInputIndex0]->BuildShape(); auto x_shape_ptr = input_args[kInputIndex0]->BuildShape();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape]; auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
auto scale_shape_ptr = input_args[kInputIndex1]->BuildShape(); auto scale_shape_ptr = input_args[kInputIndex1]->BuildShape();
@ -175,6 +176,8 @@ class BatchNormInfer : public abstract::OpInferBase {
} }
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override { TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
const auto prim_name = prim->name();
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterThan, 0, prim_name);
const std::set valid_types = {kFloat16, kFloat32}; const std::set valid_types = {kFloat16, kFloat32};
auto x_type = input_args[0]->BuildType(); auto x_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim->name()); (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim->name());

View File

@ -14,7 +14,6 @@
* limitations under the License. * limitations under the License.
*/ */
#include <string> #include <string>
#include <algorithm>
#include <memory> #include <memory>
#include <set> #include <set>
#include <vector> #include <vector>