fix layernorm bug

This commit is contained in:
yangzhenzhang 2020-04-26 16:37:57 +08:00
parent 348b0ef53c
commit 4750861054
2 changed files with 20 additions and 10 deletions

View File

@ -69,7 +69,7 @@ Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
// check input strategy // check input strategy
for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) { for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) {
if (input_strategy[begin_norm_axis_] != NO_SPLIT_STRATEGY) { if (input_strategy[i] != NO_SPLIT_STRATEGY) {
MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy); MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy);
return FAILED; return FAILED;
} }

View File

@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================
import numpy as np import numpy as np
import pytest
import mindspore as ms import mindspore as ms
from mindspore import context, Tensor, Parameter from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, TrainOneStepCell, Momentum from mindspore.nn import Cell, TrainOneStepCell, Momentum
@ -24,7 +25,7 @@ from mindspore.common.initializer import initializer
class Net(Cell): class Net(Cell):
def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None): def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
super().__init__() super().__init__()
self.begin_norm_axis = -1 self.begin_norm_axis = 2
self.begin_params_axis = 1 self.begin_params_axis = 1
self.mul = P.Mul().set_strategy(strategy1) self.mul = P.Mul().set_strategy(strategy1)
self.layer_norm = P.LayerNorm(self.begin_norm_axis, self.begin_params_axis).set_strategy(strategy2) self.layer_norm = P.LayerNorm(self.begin_norm_axis, self.begin_params_axis).set_strategy(strategy2)
@ -64,18 +65,18 @@ def test_layer_norm_data_parallel():
def test_layer_norm_model_parallel(): def test_layer_norm_model_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((1, 1, 16, 1), (1, 1, 16, 1)) strategy1 = ((1, 16, 1, 1), (1, 16, 1, 1))
strategy2 = ((1, 1, 16, 1), (1, 16, 1), (1, 16, 1)) strategy2 = ((1, 16, 1, 1), (16, 1, 1), (16, 1, 1))
strategy3 = ((1, 1, 16, 1), (1, 1, 16, 1)) strategy3 = ((1, 16, 1, 1), (1, 16, 1, 1))
net = Net(_w, strategy1, strategy2, strategy3) net = Net(_w, strategy1, strategy2, strategy3)
compile(net) compile(net)
def test_layer_norm_hybrid_parallel(): def test_layer_norm_hybrid_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1)) strategy1 = ((2, 8, 1, 1), (2, 8, 1, 1))
strategy2 = ((2, 2, 4, 1), (2, 4, 1), (2, 4, 1)) strategy2 = ((2, 8, 1, 1), (8, 1, 1), (8, 1, 1))
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1)) strategy3 = ((2, 8, 1, 1), (2, 8, 1, 1))
net = Net(_w, strategy1, strategy2, strategy3) net = Net(_w, strategy1, strategy2, strategy3)
compile(net) compile(net)
@ -89,8 +90,17 @@ def test_layer_norm_auto_parallel():
def test_layer_norm_repeat_calc(): def test_layer_norm_repeat_calc():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1)) strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
strategy2 = ((1, 2, 2, 1), (2, 2, 1), (2, 2, 1)) strategy2 = ((2, 2, 1, 1), (2, 1, 1), (2, 1, 1))
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1)) strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
net = Net(_w, strategy1, strategy2, strategy3) net = Net(_w, strategy1, strategy2, strategy3)
compile(net) compile(net)
def test_layer_norm_wrong_strategy():
    """Check that compiling fails for a LayerNorm strategy that is expected to be invalid.

    The LayerNorm input strategy used here, (1, 2, 1, 2), puts a split factor
    of 2 on the last input dimension; compile() must raise RuntimeError.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    mul_strategy = ((2, 2, 4, 1), (2, 2, 4, 1))
    # The trailing 2 splits the last input dimension -- the case under test.
    layer_norm_strategy = ((1, 2, 1, 2), (2, 1, 2), (2, 1, 2))
    third_strategy = ((2, 2, 4, 1), (2, 2, 4, 1))
    net = Net(_w, mul_strategy, layer_norm_strategy, third_strategy)
    with pytest.raises(RuntimeError):
        compile(net)