forked from mindspore-Ecosystem/mindspore
add check of batchnorm args size
This commit is contained in:
parent
9950d262c0
commit
c0a3d589fd
|
@ -119,6 +119,7 @@ class BatchNormInfer : public abstract::OpInferBase {
|
|||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
const auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterThan, 0, prim_name);
|
||||
auto x_shape_ptr = input_args[kInputIndex0]->BuildShape();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
|
||||
auto scale_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
||||
|
@ -175,6 +176,8 @@ class BatchNormInfer : public abstract::OpInferBase {
|
|||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
const auto prim_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterThan, 0, prim_name);
|
||||
const std::set valid_types = {kFloat16, kFloat32};
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim->name());
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
|
Loading…
Reference in New Issue