forked from mindspore-Ecosystem/mindspore
!15766 modify tile op
From: @yangzhenzhang Reviewed-by: @kisnwang,@stsuteng Signed-off-by: @stsuteng
This commit is contained in:
commit
cbca36d723
|
@ -45,7 +45,6 @@ class SelectInfo : public OperatorInfo {
|
|||
protected:
|
||||
Status GetAttrs() override { return SUCCESS; }
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferMirrorOps() override { return SUCCESS; }
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferTensorInfo() override;
|
||||
Status InferDevMatrixShape() override;
|
||||
|
|
|
@ -65,8 +65,19 @@ Status TileInfo::GetAttrs() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
// if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
|
||||
// Validate a user-supplied strategy for the Tile operator.
//
// For every dimension, the extent that may be split is chosen as follows:
//   - if the multiple of that dimension is not 1, the multiple itself is split;
//   - otherwise, the corresponding dimension of the input tensor is split.
// The resulting "splittable shape" is handed to CheckStrategyValue together with the strategy.
//
// Returns FAILED if the input shape is missing or has fewer dimensions than are needed to
// look up a dimension whose multiple is 1; otherwise returns whatever CheckStrategyValue returns.
Status TileInfo::CheckStrategy(const StrategyPtr &strategy) {
  // inputs_shape_[0] is read below (both in the loop and in the log message), so guard it first.
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs' shape is empty";
    return FAILED;
  }
  const Shape &input_shape = inputs_shape_[0];

  Shape splittable_shape;
  splittable_shape.reserve(full_multiples_.size());
  for (size_t i = 0; i < full_multiples_.size(); ++i) {
    if (full_multiples_[i] != 1) {
      splittable_shape.push_back(full_multiples_[i]);
      continue;
    }
    // The multiple is 1, so the input dimension is the one that gets split; it must exist.
    if (i >= input_shape.size()) {
      MS_LOG(ERROR) << name_ << ": The size of multiples (" << full_multiples_.size()
                    << ") is larger than the dimension of input (" << input_shape.size() << ")";
      return FAILED;
    }
    splittable_shape.push_back(input_shape[i]);
  }

  // Declared exactly once: the strategy is checked against the splittable shape, not the raw multiples.
  Shapes multiples = {splittable_shape};
  MS_LOG(INFO) << name_ << ": The input shape is " << ShapeToString(inputs_shape_[0]) << ", the multiples is "
               << ShapeToString(full_multiples_) << ", so the 'shape' can be split is " << ShapeToString(splittable_shape);
  return CheckStrategyValue(strategy, multiples);
}
|
||||
|
||||
|
@ -74,7 +85,7 @@ Status TileInfo::InferDevMatrixShape() {
|
|||
MS_EXCEPTION_IF_NULL(strategy_);
|
||||
std::vector<Dimensions> stra = strategy_->GetInputDim();
|
||||
if (stra.empty()) {
|
||||
MS_LOG(ERROR) << name_ << "The strategy is empty";
|
||||
MS_LOG(ERROR) << name_ << ": The strategy is empty";
|
||||
return FAILED;
|
||||
}
|
||||
if (full_multiples_.size() != stra[0].size()) {
|
||||
|
@ -86,6 +97,9 @@ Status TileInfo::InferDevMatrixShape() {
|
|||
|
||||
slice_multiples_ = full_multiples_;
|
||||
for (size_t i = 0; i < full_multiples_.size(); ++i) {
|
||||
if (full_multiples_[i] == 1) {
|
||||
continue;
|
||||
}
|
||||
slice_multiples_[i] = slice_multiples_[i] / dev_matrix_shape_[i];
|
||||
}
|
||||
return SUCCESS;
|
||||
|
@ -95,13 +109,18 @@ Status TileInfo::InferTensorMap() {
|
|||
TensorMap input_tensor_map;
|
||||
TensorMap output_tensor_map;
|
||||
if (inputs_shape_.empty() || outputs_shape_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << "The inputs or outputs' shape is empty";
|
||||
MS_LOG(ERROR) << name_ << ": The inputs or outputs' shape is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// the input tensor cannot be split
|
||||
// if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
|
||||
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
|
||||
input_tensor_map.push_back(MAP_NONE);
|
||||
input_tensor_map.push_back(inputs_shape_[0].size() - i - 1);
|
||||
}
|
||||
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
|
||||
if (full_multiples_[i] != 1) {
|
||||
input_tensor_map[i] = MAP_NONE;
|
||||
}
|
||||
}
|
||||
|
||||
// cannot use dev_matrix_shape_ replace outputs_shape_[0], because it may not be fully split in all devices.
|
||||
|
@ -163,11 +182,11 @@ Status TileInfo::InferTensorInfo() {
|
|||
void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "The size of tile cnode's inputs must be 3";
|
||||
MS_LOG(EXCEPTION) << name_ << ": The size of tile cnode's inputs must be 3";
|
||||
}
|
||||
|
||||
if (!IsValueNode<ValueTuple>(cnode->input(2))) {
|
||||
MS_LOG(EXCEPTION) << "The input[2] of tile cnode is not ValueTuple.";
|
||||
MS_LOG(EXCEPTION) << name_ << ": The input[2] of tile cnode is not ValueTuple.";
|
||||
}
|
||||
|
||||
auto func_graph = cnode->func_graph();
|
||||
|
@ -199,7 +218,14 @@ Status TileInfo::GenerateStrategies(int64_t stage_id) {
|
|||
Shape multiples_split(full_multiples_.size(), 1);
|
||||
Shapes splittable_inputs = {multiples_split};
|
||||
|
||||
// if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
Shape tmp_input_shape = full_multiples_;
|
||||
for (size_t i = 0; i < full_multiples_.size(); ++i) {
|
||||
if (full_multiples_[i] == 0) {
|
||||
tmp_input_shape[i] = inputs_shape_[0][i];
|
||||
}
|
||||
}
|
||||
Shapes tmp_inputs_shape = {full_multiples_};
|
||||
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
|
||||
return FAILED;
|
||||
|
|
|
@ -94,6 +94,15 @@ def test_select_model_parallel():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_select_mirror():
    """Compile the select network in semi-auto-parallel mode on 8 devices.

    Both strategies split only the last two dimensions (2 x 2), so each
    operator uses 4 slices while device_num is 8.
    """
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 2, 2), (1, 2, 2))
    strategy2 = ((1, 2, 2), (1, 2, 2), (1, 2, 2))
    net = Net(_w1, _w2, strategy1, strategy2)
    compile_net(net)
|
||||
|
||||
|
||||
def test_select_auto_parallel():
|
||||
context.set_auto_parallel_context(
|
||||
parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
|
|
|
@ -52,20 +52,38 @@ class Net2(Cell):
|
|||
out = self.tile(out, (8, 8, 4, 2))
|
||||
return out
|
||||
|
||||
class Net3(Cell):
    """Network that tiles its weight by (8, 1, 1) and multiplies the input by the result.

    Args:
        weight: tensor to be tiled; wrapped in a Parameter named "w1" when
            ``is_parameter`` is True, otherwise used as-is.
        strategy1: shard strategy passed to the Mul operator.
        strategy2: shard strategy passed to the Tile operator.
        is_parameter: whether ``weight`` is wrapped in a Parameter.
    """

    def __init__(self, weight, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.tile = P.Tile().shard(strategy2)
        if is_parameter:
            self.weight = Parameter(weight, "w1")
        else:
            self.weight = weight
        # Not used by construct(); kept so the cell's attributes stay unchanged.
        self.mul2 = P.Mul()

    def construct(self, x, b):
        # `b` is not used here; it is accepted because compile_net() passes two inputs.
        out = self.tile(self.weight, (8, 1, 1))
        out = self.mul(x, out)
        return out
|
||||
|
||||
|
||||
# Shared all-ones float32 fixtures used by the tile tests.
_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_x1 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
# [16, 16, 16] tiled by (8, 1, 1) in Net3 gives [128, 16, 16], matching _x1.
_w1 = Tensor(np.ones([16, 16, 16]), dtype=ms.float32)
_w2 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w3 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile_net(net, x=_b, b=_b):
    """Wrap ``net`` in a one-step training cell and compile it for auto parallel.

    Args:
        net: the network under test.
        x: first input passed to the compile call; defaults to ``_b``.
        b: second input passed to the compile call; defaults to ``_b``.

    The auto-parallel context is reset afterwards so that settings made by
    one test do not carry over to the next.
    """
    # NOTE(review): save_graphs=True dumps graph files on every test run;
    # confirm this is intended rather than a debugging leftover.
    context.set_context(save_graphs=True)
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    _executor.compile(train_net, x, b)
    context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
|
@ -101,6 +119,14 @@ def test_tile_tensor_no_full_split():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_tile_tensor_no_full_split2():
    """Compile Net3 in semi-auto-parallel mode with a strategy that does not use all devices.

    The Tile strategy (2, 2, 1) covers 4 slices while device_num is 8. The
    multiples used by Net3 are (8, 1, 1), so the strategy touches both a
    dimension whose multiple is greater than 1 and one whose multiple is 1.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 1), (2, 2, 1))
    strategy2 = ((2, 2, 1),)
    net = Net3(_w1, strategy1, strategy2)
    compile_net(net, _x1, _b)
|
||||
|
||||
|
||||
def test_tile_output():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
|
@ -108,6 +134,7 @@ def test_tile_output():
|
|||
net = Net2(_w2, strategy1, strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_tile_output_no_full_split():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
|
@ -123,7 +150,14 @@ def test_tile_no_strategy():
|
|||
net = Net2(_w2, strategy1, strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_tile_auto_parallel():
    """Compile Net2 in auto-parallel mode on 8 devices without any user-specified strategy."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net2(_w2)
    compile_net(net)
|
||||
|
||||
|
||||
def test_tile_auto_parallel_2():
    """Compile Net3 on 8 devices without any user-specified strategy.

    NOTE(review): the test name says "auto_parallel" but the mode set below is
    "semi_auto_parallel" -- confirm which mode this test is meant to exercise.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    net = Net3(_w1)
    compile_net(net, _x1, _b)
|
||||
|
|
Loading…
Reference in New Issue