forked from mindspore-Ecosystem/mindspore

!15766 modify tile op

From: @yangzhenzhang
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng

commit cbca36d723
select_info.h
@@ -45,7 +45,6 @@ class SelectInfo : public OperatorInfo {
  protected:
   Status GetAttrs() override { return SUCCESS; }
   Status CheckStrategy(const StrategyPtr &strategy) override;
-  Status InferMirrorOps() override { return SUCCESS; }
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
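Dropping the InferMirrorOps override (which returned SUCCESS without inserting any operator) presumably lets SelectInfo fall back to the default mirror-operator inference in OperatorInfo, so parameters consumed by Select can get mirror operators under semi-auto parallel; the new test_select_mirror case below exercises this path.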
tile_info.cc
@@ -65,8 +65,19 @@ Status TileInfo::GetAttrs() {
   return SUCCESS;
 }
 
+// if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
 Status TileInfo::CheckStrategy(const StrategyPtr &strategy) {
-  Shapes multiples = {full_multiples_};
+  Shape tmp;
+  for (size_t i = 0; i < full_multiples_.size(); ++i) {
+    if (full_multiples_[i] != 1) {
+      tmp.push_back(full_multiples_[i]);
+    } else {
+      tmp.push_back(inputs_shape_[0][i]);
+    }
+  }
+  Shapes multiples = {tmp};
+  MS_LOG(INFO) << name_ << ": The input shape is " << ShapeToString(inputs_shape_[0]) << ", the multiples is "
+               << ShapeToString(full_multiples_) << ", so the 'shape' can be split is " << ShapeToString(tmp);
   return CheckStrategyValue(strategy, multiples);
 }
 
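The rewritten CheckStrategy validates the strategy against a per-dimension "splittable" shape rather than the raw multiples. A minimal Python sketch of that computation (splittable_shape is an illustrative name, not MindSpore API):

def splittable_shape(input_shape, multiples):
    # A multiple > 1 means the strategy splits the multiple (the repeated
    # copies); a multiple == 1 means it splits the input dimension itself.
    return [m if m != 1 else d for d, m in zip(input_shape, multiples)]

# e.g. input [16, 16, 16] tiled by (8, 1, 1) -> splittable shape [8, 16, 16]
assert splittable_shape([16, 16, 16], [8, 1, 1]) == [8, 16, 16]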
@@ -74,7 +85,7 @@ Status TileInfo::InferDevMatrixShape() {
   MS_EXCEPTION_IF_NULL(strategy_);
   std::vector<Dimensions> stra = strategy_->GetInputDim();
   if (stra.empty()) {
-    MS_LOG(ERROR) << name_ << "The strategy is empty";
+    MS_LOG(ERROR) << name_ << ": The strategy is empty";
     return FAILED;
   }
   if (full_multiples_.size() != stra[0].size()) {
@@ -86,6 +97,9 @@ Status TileInfo::InferDevMatrixShape() {
 
   slice_multiples_ = full_multiples_;
   for (size_t i = 0; i < full_multiples_.size(); ++i) {
+    if (full_multiples_[i] == 1) {
+      continue;
+    }
     slice_multiples_[i] = slice_multiples_[i] / dev_matrix_shape_[i];
   }
   return SUCCESS;
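With the new guard, only dimensions whose multiple exceeds 1 have their multiple divided across the device matrix; dimensions with multiple == 1 keep it untouched (the input, not the multiple, is split there). A small sketch under that reading (slice_multiples is an illustrative name):

def slice_multiples(full_multiples, dev_matrix):
    # Skip dims whose multiple is 1; divide the rest across devices.
    out = list(full_multiples)
    for i, m in enumerate(full_multiples):
        if m == 1:
            continue
        out[i] = m // dev_matrix[i]
    return out

# e.g. multiples (8, 1, 1) under strategy (8, 1, 1): each device tiles by (1, 1, 1).
assert slice_multiples([8, 1, 1], [8, 1, 1]) == [1, 1, 1]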
@@ -95,13 +109,18 @@ Status TileInfo::InferTensorMap() {
   TensorMap input_tensor_map;
   TensorMap output_tensor_map;
   if (inputs_shape_.empty() || outputs_shape_.empty()) {
-    MS_LOG(ERROR) << name_ << "The inputs or outputs' shape is empty";
+    MS_LOG(ERROR) << name_ << ": The inputs or outputs' shape is empty";
     return FAILED;
   }
 
-  // the input tensor cannot be split
+  // if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
   for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
-    input_tensor_map.push_back(MAP_NONE);
+    input_tensor_map.push_back(inputs_shape_[0].size() - i - 1);
+  }
+  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
+    if (full_multiples_[i] != 1) {
+      input_tensor_map[i] = MAP_NONE;
+    }
   }
 
   // cannot use dev_matrix_shape_ replace outputs_shape_[0], because it may not be fully split in all devices.
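The input tensor map now starts fully mapped (dimension i to device-matrix axis n-1-i) and is then masked wherever the multiple, not the input, is split. A sketch, assuming MAP_NONE is the usual -1 sentinel:

MAP_NONE = -1  # assumed sentinel for "not sharded"

def input_tensor_map(input_shape, multiples):
    n = len(input_shape)
    tmap = [n - 1 - i for i in range(n)]  # default: dim i -> axis n-1-i
    for i, m in enumerate(multiples):
        if m != 1:
            tmap[i] = MAP_NONE  # the multiple is split here, not the input
    return tmap

# e.g. input [16, 16, 16], multiples (8, 1, 1) -> [-1, 1, 0]
assert input_tensor_map([16, 16, 16], [8, 1, 1]) == [-1, 1, 0]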
@@ -163,11 +182,11 @@ Status TileInfo::InferTensorInfo() {
 void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   if (cnode->size() != 3) {
-    MS_LOG(EXCEPTION) << "The size of tile cnode's inputs must be 3";
+    MS_LOG(EXCEPTION) << name_ << ": The size of tile cnode's inputs must be 3";
   }
 
   if (!IsValueNode<ValueTuple>(cnode->input(2))) {
-    MS_LOG(EXCEPTION) << "The input[2] of tile cnode is not ValueTuple.";
+    MS_LOG(EXCEPTION) << name_ << ": The input[2] of tile cnode is not ValueTuple.";
   }
 
   auto func_graph = cnode->func_graph();
@@ -199,7 +218,14 @@ Status TileInfo::GenerateStrategies(int64_t stage_id) {
   Shape multiples_split(full_multiples_.size(), 1);
   Shapes splittable_inputs = {multiples_split};
 
+  // if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
   std::vector<StrategyPtr> sp_vector;
+  Shape tmp_input_shape = full_multiples_;
+  for (size_t i = 0; i < full_multiples_.size(); ++i) {
+    if (full_multiples_[i] == 0) {
+      tmp_input_shape[i] = inputs_shape_[0][i];
+    }
+  }
   Shapes tmp_inputs_shape = {full_multiples_};
   if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
     return FAILED;
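Within the lines shown, tmp_input_shape is computed but tmp_inputs_shape is still seeded from full_multiples_, so the auto-generated search space ranges over the multiples shape with every dimension marked splittable. A simplified stand-in for what GenerateStrategiesForIndependentInputs enumerates (candidate_strategies is illustrative, not the real API):

from itertools import product

def candidate_strategies(shape, device_num):
    # Each dimension may be cut into d parts where d divides the dimension,
    # and the product of all cuts must divide the device count.
    def divisors(n):
        return [d for d in range(1, n + 1) if n % d == 0]
    result = []
    for stra in product(*(divisors(s) for s in shape)):
        prod = 1
        for d in stra:
            prod *= d
        if device_num % prod == 0:
            result.append(stra)
    return result

# e.g. multiples (8, 1, 1) on 8 devices:
assert candidate_strategies((8, 1, 1), 8) == [(1, 1, 1), (2, 1, 1), (4, 1, 1), (8, 1, 1)]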
test_select.py
@@ -94,6 +94,15 @@ def test_select_model_parallel():
     compile_net(net)
 
 
+def test_select_mirror():
+    context.set_auto_parallel_context(
+        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((1, 2, 2), (1, 2, 2))
+    strategy2 = ((1, 2, 2), (1, 2, 2), (1, 2, 2))
+    net = Net(_w1, _w2, strategy1, strategy2)
+    compile_net(net)
+
+
 def test_select_auto_parallel():
     context.set_auto_parallel_context(
         parallel_mode="auto_parallel", device_num=8, global_rank=0)
test_tile.py
@@ -52,20 +52,38 @@ class Net2(Cell):
         out = self.tile(out, (8, 8, 4, 2))
         return out
 
 
+class Net3(Cell):
+    def __init__(self, weight, strategy1=None, strategy2=None, is_parameter=True):
+        super().__init__()
+        self.mul = P.Mul().shard(strategy1)
+        self.tile = P.Tile().shard(strategy2)
+        if is_parameter:
+            self.weight = Parameter(weight, "w1")
+        else:
+            self.weight = weight
+        self.mul2 = P.Mul()
+
+    def construct(self, x, b):
+        out = self.tile(self.weight, (8, 1, 1))
+        out = self.mul(x, out)
+        return out
+
+
 _x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_x1 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
 _w1 = Tensor(np.ones([16, 16, 16]), dtype=ms.float32)
 _w2 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_w3 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
 _b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
 
 
-def compile_net(net):
-    context.set_context(save_graphs=False)
+def compile_net(net, x=_b, b=_b):
+    context.set_context(save_graphs=True)
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
     train_net = TrainOneStepCell(net, optimizer)
     train_net.set_auto_parallel()
     train_net.set_train()
-    _executor.compile(train_net, _x, _b)
+    _executor.compile(train_net, x, b)
     context.reset_auto_parallel_context()
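Why the new tests pair Net3's _w1 weight with _x1: tiling the (16, 16, 16) weight by (8, 1, 1) yields (128, 16, 16), which matches _x1, so self.mul(x, out) is elementwise; the reworked compile_net(net, x=_b, b=_b) defaults keep the existing call sites working while letting the new tests pass _x1 explicitly. A quick shape check with numpy:

import numpy as np

w1 = np.ones([16, 16, 16])
# Tile by (8, 1, 1): the leading dimension grows 16 -> 128.
assert np.tile(w1, (8, 1, 1)).shape == (128, 16, 16)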
@@ -101,6 +119,14 @@ def test_tile_tensor_no_full_split():
     compile_net(net)
 
 
+def test_tile_tensor_no_full_split2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 2, 1), (2, 2, 1))
+    strategy2 = ((2, 2, 1),)
+    net = Net3(_w1, strategy1, strategy2)
+    compile_net(net, _x1, _b)
+
+
 def test_tile_output():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 2, 2), (2, 2, 2))
@@ -108,6 +134,7 @@ def test_tile_output():
     net = Net2(_w2, strategy1, strategy2)
     compile_net(net)
 
+
 def test_tile_output_no_full_split():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 2, 2), (2, 2, 2))
@@ -123,7 +150,14 @@ def test_tile_no_strategy():
     net = Net2(_w2, strategy1, strategy2)
     compile_net(net)
 
+
 def test_tile_auto_parallel():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
     net = Net2(_w2)
     compile_net(net)
+
+
+def test_tile_auto_parallel_2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    net = Net3(_w1)
+    compile_net(net, _x1, _b)