forked from mindspore-Ecosystem/mindspore
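Fix LayerNorm bug: CheckStrategy now checks input_strategy[i] for each axis from begin_norm_axis_ onward, instead of re-checking input_strategy[begin_norm_axis_] on every iteration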
This commit is contained in:
parent
348b0ef53c
commit
4750861054
|
@ -69,7 +69,7 @@ Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
}
|
}
|
||||||
// check input strategy
|
// check input strategy
|
||||||
for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) {
|
for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) {
|
||||||
if (input_strategy[begin_norm_axis_] != NO_SPLIT_STRATEGY) {
|
if (input_strategy[i] != NO_SPLIT_STRATEGY) {
|
||||||
MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy);
|
MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy);
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,8 +11,9 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
from mindspore import context, Tensor, Parameter
|
from mindspore import context, Tensor, Parameter
|
||||||
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
||||||
|
@ -24,7 +25,7 @@ from mindspore.common.initializer import initializer
|
||||||
class Net(Cell):
|
class Net(Cell):
|
||||||
def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
|
def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.begin_norm_axis = -1
|
self.begin_norm_axis = 2
|
||||||
self.begin_params_axis = 1
|
self.begin_params_axis = 1
|
||||||
self.mul = P.Mul().set_strategy(strategy1)
|
self.mul = P.Mul().set_strategy(strategy1)
|
||||||
self.layer_norm = P.LayerNorm(self.begin_norm_axis, self.begin_params_axis).set_strategy(strategy2)
|
self.layer_norm = P.LayerNorm(self.begin_norm_axis, self.begin_params_axis).set_strategy(strategy2)
|
||||||
|
@ -64,18 +65,18 @@ def test_layer_norm_data_parallel():
|
||||||
|
|
||||||
def test_layer_norm_model_parallel():
|
def test_layer_norm_model_parallel():
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||||
strategy1 = ((1, 1, 16, 1), (1, 1, 16, 1))
|
strategy1 = ((1, 16, 1, 1), (1, 16, 1, 1))
|
||||||
strategy2 = ((1, 1, 16, 1), (1, 16, 1), (1, 16, 1))
|
strategy2 = ((1, 16, 1, 1), (16, 1, 1), (16, 1, 1))
|
||||||
strategy3 = ((1, 1, 16, 1), (1, 1, 16, 1))
|
strategy3 = ((1, 16, 1, 1), (1, 16, 1, 1))
|
||||||
net = Net(_w, strategy1, strategy2, strategy3)
|
net = Net(_w, strategy1, strategy2, strategy3)
|
||||||
compile(net)
|
compile(net)
|
||||||
|
|
||||||
|
|
||||||
def test_layer_norm_hybrid_parallel():
|
def test_layer_norm_hybrid_parallel():
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||||
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
strategy1 = ((2, 8, 1, 1), (2, 8, 1, 1))
|
||||||
strategy2 = ((2, 2, 4, 1), (2, 4, 1), (2, 4, 1))
|
strategy2 = ((2, 8, 1, 1), (8, 1, 1), (8, 1, 1))
|
||||||
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
strategy3 = ((2, 8, 1, 1), (2, 8, 1, 1))
|
||||||
net = Net(_w, strategy1, strategy2, strategy3)
|
net = Net(_w, strategy1, strategy2, strategy3)
|
||||||
compile(net)
|
compile(net)
|
||||||
|
|
||||||
|
@ -89,8 +90,17 @@ def test_layer_norm_auto_parallel():
|
||||||
def test_layer_norm_repeat_calc():
|
def test_layer_norm_repeat_calc():
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||||
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
||||||
strategy2 = ((1, 2, 2, 1), (2, 2, 1), (2, 2, 1))
|
strategy2 = ((2, 2, 1, 1), (2, 1, 1), (2, 1, 1))
|
||||||
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
||||||
net = Net(_w, strategy1, strategy2, strategy3)
|
net = Net(_w, strategy1, strategy2, strategy3)
|
||||||
compile(net)
|
compile(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer_norm_wrong_strategy():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||||
|
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
||||||
|
strategy2 = ((1, 2, 1, 2), (2, 1, 2), (2, 1, 2))
|
||||||
|
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
||||||
|
net = Net(_w, strategy1, strategy2, strategy3)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
compile(net)
|
||||||
|
|
Loading…
Reference in New Issue