forked from mindspore-Ecosystem/mindspore
!1090 fix mean and var shape in layernorm
Merge pull request !1090 from JichenZhao/layernorm_mean_var_shape
This commit is contained in:
commit
298a784878
|
@ -301,7 +301,7 @@ AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr
|
||||||
|
|
||||||
// begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1
|
// begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1
|
||||||
ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis");
|
ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis");
|
||||||
(void)CheckAxis(op_name, bna_ptr, -1, SizeToInt(input_rank) - 1);
|
int begin_norm_axis = CheckAxis(op_name, bna_ptr, -1, SizeToInt(input_rank) - 1);
|
||||||
|
|
||||||
ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis");
|
ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis");
|
||||||
int begin_params_axis = CheckAxis(op_name, bpa_ptr, -1, SizeToInt(input_rank) - 1);
|
int begin_params_axis = CheckAxis(op_name, bpa_ptr, -1, SizeToInt(input_rank) - 1);
|
||||||
|
@ -341,7 +341,13 @@ AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr
|
||||||
}
|
}
|
||||||
|
|
||||||
auto mean_var_shape_value = input_shape->shape();
|
auto mean_var_shape_value = input_shape->shape();
|
||||||
|
if (begin_norm_axis == -1) {
|
||||||
mean_var_shape_value[input_rank - 1] = 1;
|
mean_var_shape_value[input_rank - 1] = 1;
|
||||||
|
} else {
|
||||||
|
for (size_t i = begin_norm_axis; i < input_rank; ++i) {
|
||||||
|
mean_var_shape_value[i] = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto mean = input_x->Broaden();
|
auto mean = input_x->Broaden();
|
||||||
mean->set_shape(std::make_shared<Shape>(mean_var_shape_value));
|
mean->set_shape(std::make_shared<Shape>(mean_var_shape_value));
|
||||||
|
|
Loading…
Reference in New Issue