forked from mindspore-Ecosystem/mindspore
add check of batchnorm args size
This commit is contained in:
parent
9950d262c0
commit
c0a3d589fd
|
@ -119,6 +119,7 @@ class BatchNormInfer : public abstract::OpInferBase {
|
||||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
const auto prim_name = primitive->name();
|
const auto prim_name = primitive->name();
|
||||||
|
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterThan, 0, prim_name);
|
||||||
auto x_shape_ptr = input_args[kInputIndex0]->BuildShape();
|
auto x_shape_ptr = input_args[kInputIndex0]->BuildShape();
|
||||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
|
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
|
||||||
auto scale_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
auto scale_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
||||||
|
@ -175,6 +176,8 @@ class BatchNormInfer : public abstract::OpInferBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
const auto prim_name = prim->name();
|
||||||
|
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterThan, 0, prim_name);
|
||||||
const std::set valid_types = {kFloat16, kFloat32};
|
const std::set valid_types = {kFloat16, kFloat32};
|
||||||
auto x_type = input_args[0]->BuildType();
|
auto x_type = input_args[0]->BuildType();
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim->name());
|
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim->name());
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <algorithm>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
Loading…
Reference in New Issue