!11511 Fix GPU BiasAdd NHWC data format problem

From: @TFbunny
Reviewed-by: @tom__chen,@robingrosman
Signed-off-by: @robingrosman
This commit is contained in:
mindspore-ci-bot 2021-01-22 22:19:18 +08:00 committed by Gitee
commit fe3473c0cc
1 changed files with 1 additions and 1 deletions

View File

@ -530,7 +530,7 @@ AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &
ShapeVector x_min_shape = x->shape()->min_shape();
ShapeVector x_max_shape = x->shape()->max_shape();
std::set<std::string> available_data_format{"NCHW", "NHWC"};
auto data_format_ptr = primitive->GetAttr("data_format");
auto data_format_ptr = primitive->GetAttr("format");
std::string data_format = "NCHW";
if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) {
data_format = data_format_ptr->cast<StringImmPtr>()->value();