forked from mindspore-Ecosystem/mindspore
!11511 Fix GPU BiasAdd NHWC data format problem
From: @TFbunny Reviewed-by: @tom__chen,@robingrosman Signed-off-by: @robingrosman
This commit is contained in:
commit
fe3473c0cc
|
@ -530,7 +530,7 @@ AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &
|
||||||
ShapeVector x_min_shape = x->shape()->min_shape();
|
ShapeVector x_min_shape = x->shape()->min_shape();
|
||||||
ShapeVector x_max_shape = x->shape()->max_shape();
|
ShapeVector x_max_shape = x->shape()->max_shape();
|
||||||
std::set<std::string> available_data_format{"NCHW", "NHWC"};
|
std::set<std::string> available_data_format{"NCHW", "NHWC"};
|
||||||
auto data_format_ptr = primitive->GetAttr("data_format");
|
auto data_format_ptr = primitive->GetAttr("format");
|
||||||
std::string data_format = "NCHW";
|
std::string data_format = "NCHW";
|
||||||
if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) {
|
if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) {
|
||||||
data_format = data_format_ptr->cast<StringImmPtr>()->value();
|
data_format = data_format_ptr->cast<StringImmPtr>()->value();
|
||||||
|
|
Loading…
Reference in New Issue