add parallel op for conv2d backprop input

yangzhenzhang 2021-06-22 11:45:35 +08:00
parent eb30e1e691
commit 69acf757d0
10 changed files with 345 additions and 23 deletions

View File

@@ -197,6 +197,8 @@ REGISTER(TopKInfo);
REGISTER(ScatterUpdateInfo);
REGISTER(VirtualOutputInfo);
REGISTER(Conv2DInfo);
REGISTER(Conv2DBackpropInputInfo);
REGISTER(Conv2DTransposeInfo);
REGISTER(BatchNormInfo);
REGISTER(MaxPoolInfo);
REGISTER(AvgPoolInfo);
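
The REGISTER lines above enter each operator-info class into a name-to-creator table used by the parallel planner. A conceptual sketch of that mechanism in plain Python (illustrative only; the real code is the C++ REGISTER macro and its factory):

```python
# Hypothetical stand-in for the C++ registration macro: map a primitive
# name to the class that knows how to shard it.
OPERATOR_INFO_REGISTRY = {}

def register(name, creator):
    OPERATOR_INFO_REGISTRY[name] = creator

register("Conv2DBackpropInput", object)  # stands in for Conv2DBackpropInputInfo
register("Conv2DTranspose", object)      # stands in for Conv2DTransposeInfo
```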

View File

@@ -28,7 +28,7 @@
namespace mindspore {
namespace parallel {
Status Conv2DInfo::GetAttrs() {
Status Conv2DInfo::GetAttrsBase() {
// out_channel
out_channel_ = GetIntAttr(OUT_CHANNEL);
if (out_channel_ <= 0) {
@@ -121,6 +121,8 @@ Status Conv2DInfo::GetAttrs() {
return SUCCESS;
}
Status Conv2DInfo::GetAttrs() { return GetAttrsBase(); }
Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
if (pad_mode_ == 0) { // 'pad' mode
MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W";
@@ -175,7 +177,7 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
return SUCCESS;
}
Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
@@ -197,24 +199,12 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}
if (input_strategy[1] != weight_strategy[1]) {
MS_LOG(ERROR) << name_ << ": The shard num of c-in for input strategy is " << input_strategy[1]
<< ", but the shard num of c-in for weight strategy is " << weight_strategy[1];
return FAILED;
}
if (weight_strategy[2] != 1 || weight_strategy[3] != 1) {
MS_LOG(ERROR) << name_ << ": The kernel size can not be split, but the strategy for kernel size is ("
<< weight_strategy[2] << ", " << weight_strategy[3] << ")";
return FAILED;
}
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
return FAILED;
}
}
if (weight_strategy[0] > 1) {
out_channel_shard_ = true;
new_out_channel_ = out_channel_ / weight_strategy[0];
@@ -225,6 +215,29 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
return SUCCESS;
}
Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyBase(strategy) != SUCCESS) {
return FAILED;
}
std::vector<Dimensions> stra = strategy->GetInputDim();
Dimensions input_strategy = stra[0];
Dimensions weight_strategy = stra[1];
if (input_strategy[1] != weight_strategy[1]) {
MS_LOG(ERROR) << name_ << ": The shard num of c-in for input strategy is " << input_strategy[1]
<< ", but the shard num of c-in for weight strategy is " << weight_strategy[1];
return FAILED;
}
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
return FAILED;
}
}
return SUCCESS;
}
Status Conv2DInfo::InferDevMatrixShape() {
// the strategy is ((n, i, h, w), (o, i, 1, 1))
// the dev matrix is (n, i, h, w, o)
@@ -254,7 +267,8 @@ Status Conv2DInfo::InferTensorMap() {
return SUCCESS;
}
// if in channel is split, it need to insert all reduce
// Conv2DInfo: dev_matrix is (n, i, h, w, o); if the in channel is split, an AllReduce needs to be inserted
// Conv2DBackpropInputInfo: dev_matrix is (n, o, h, w, i); if the out channel is split, an AllReduce needs to be inserted
Status Conv2DInfo::InferForwardCommunication() {
forward_op_.clear();
size_t relevant_dim_index = IN_CHANNEL_INDEX;
@@ -329,5 +343,190 @@ Status Conv2DInfo::InitForCostModel(const StrategyPtr &strategy) {
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
Status Conv2DBackpropInputInfo::GetOutShape() {
if (input_value_.size() != 3) {
MS_LOG(ERROR) << name_ << ": The size of input value must be 3, but got " << input_value_.size();
return FAILED;
}
if (input_value_[2] == nullptr) {
MS_LOG(ERROR) << name_ << ": The input_value_[2] is nullptr";
return FAILED;
}
std::vector<ValuePtr> elements;
auto value_tuple = input_value_[2]->cast<ValueTuplePtr>();
if (value_tuple == nullptr) {
MS_LOG(ERROR) << name_ << ": Input_value_[2] must be ValueTuplePtr.";
return FAILED;
}
elements = value_tuple->value();
if (elements.size() != 4) {
MS_LOG(ERROR) << name_ << ": Elements size must be 4, but got " << elements.size();
return FAILED;
}
for (auto &element : elements) {
MS_EXCEPTION_IF_NULL(element);
if (element->isa<Int64Imm>()) {
int64_t axis = element->cast<Int64ImmPtr>()->value();
out_shape_.push_back(axis);
} else {
MS_LOG(ERROR) << name_ << ": The value of shape must be int";
return FAILED;
}
}
return SUCCESS;
}
Status Conv2DBackpropInputInfo::GetAttrs() {
if (GetAttrsBase() != SUCCESS) {
return FAILED;
}
return GetOutShape();
}
Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyBase(strategy) != SUCCESS) {
return FAILED;
}
std::vector<Dimensions> stra = strategy->GetInputDim();
Dimensions input_strategy = stra[0];
Dimensions weight_strategy = stra[1];
if (input_strategy[1] != weight_strategy[0]) {
MS_LOG(ERROR) << name_ << ": The shard num of c-out for input strategy is " << input_strategy[1]
<< ", but the shard num of c-out for weight strategy is " << weight_strategy[0];
return FAILED;
}
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
return FAILED;
}
}
return SUCCESS;
}
Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) { return SUCCESS; }
Status Conv2DBackpropInputInfo::InferDevMatrixShape() {
// the strategy is ((n, o, h, w), (o, i, 1, 1))
// the dev matrix is (n, o, h, w, i)
MS_EXCEPTION_IF_NULL(strategy_);
std::vector<Dimensions> stra = strategy_->GetInputDim();
if (stra.size() != 2) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
return FAILED;
}
dev_matrix_shape_ = stra[0];
dev_matrix_shape_.push_back(stra[1][1]);
Shape out_strategy = stra[0];
out_strategy[1] = stra[1][1];
out_slice_shape_ = out_shape_;
if (out_shape_.size() != out_strategy.size()) {
MS_LOG(ERROR) << name_ << ": The size of out shape is " << out_shape_.size()
<< ", but the size of output strategy is " << out_strategy.size();
return FAILED;
}
for (size_t i = 0; i < out_slice_shape_.size(); ++i) {
if (out_slice_shape_[i] % out_strategy[i] != 0) {
MS_LOG(ERROR) << name_ << ": The output can not be split by strategy. The shape of output is " << out_slice_shape_
<< ", but the strategy of output is " << out_strategy;
return FAILED;
}
out_slice_shape_[i] = out_slice_shape_[i] / out_strategy[i];
}
MS_LOG(INFO) << name_ << ": The output slice shape is " << out_slice_shape_;
return SUCCESS;
}
Status Conv2DBackpropInputInfo::InferTensorMap() {
// input_strategy: ((n, o, h, w), (o, i, 1, 1))
// output_strategy: ((n, i, h, w),)
// dev_matrix: (n, o, h, w, i)
TensorMap input_tensor_map = {4, 3, 2, 1};
TensorMap weight_tensor_map = {3, 0, -1, -1};
TensorMap output_tensor_map = {4, 0, 2, 1};
(void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map));
(void)inputs_tensor_map_.emplace_back(std::move(weight_tensor_map));
(void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map));
return SUCCESS;
}
Status Conv2DBackpropInputInfo::InferMirrorOps() {
mirror_ops_.clear();
if (inputs_shape_.empty()) {
MS_LOG(INFO) << name_ << ": The inputs size is empty";
return SUCCESS;
}
if (inputs_tensor_map_.size() != inputs_shape_.size()) {
MS_LOG(ERROR) << name_ << ": The size of inputs tensor map is not equal to the size of inputs shape";
return FAILED;
}
bool group_is_empty = true;
for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
std::vector<Group> group;
if (CreateGroupByTensorMap(inputs_tensor_map_[i], &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Create group failed, the input index is " << i;
mirror_ops_.clear();
return FAILED;
}
OperatorVector mirror_op;
if (group.empty()) {
MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
mirror_ops_.push_back(mirror_op);
continue;
}
group_is_empty = false;
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
mirror_ops_.push_back(mirror_op);
}
if (group_is_empty) {
mirror_ops_.clear();
MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
return SUCCESS;
}
OperatorVector tmp_mirror_op; // tmp mirror op for 'out_shape'
mirror_ops_.push_back(tmp_mirror_op);
return SUCCESS;
}
void Conv2DBackpropInputInfo::UpdateOutShape(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() != 4) {
MS_LOG(EXCEPTION) << name_ << ": The size of cnode's inputs must be 4, but got " << cnode->size();
}
if (!IsValueNode<ValueTuple>(cnode->input(3))) {
MS_LOG(EXCEPTION) << name_ << ": The cnode's input[3] is not value node";
}
auto func_graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
ValuePtr out_shape = MakeValue(out_slice_shape_);
AnfNodePtr val = NewValueNode(out_shape);
(void)manager->Replace(cnode->input(3), val);
MS_LOG(INFO) << name_ << ": Update the output shape " << out_slice_shape_;
}
} // namespace parallel
} // namespace mindspore
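
A minimal standalone sketch of the slice-shape arithmetic in InferDevMatrixShape above (plain Python, illustrative names, not MindSpore API): the dev matrix is the input strategy (n, o, h, w) extended with the weight's c-in shard, and the output strategy reuses the input strategy with the c-out shard replaced by the c-in shard.

```python
# Sketch of Conv2DBackpropInputInfo::InferDevMatrixShape (illustrative).
def infer_out_slice_shape(input_strategy, weight_strategy, out_shape):
    # dev matrix: (n, o, h, w) extended with the weight's c-in shard -> (n, o, h, w, i)
    dev_matrix = list(input_strategy) + [weight_strategy[1]]
    # output strategy: (n, i, h, w) -- c-out shard replaced by c-in shard
    out_strategy = list(input_strategy)
    out_strategy[1] = weight_strategy[1]
    for dim, shard in zip(out_shape, out_strategy):
        assert dim % shard == 0, "the output can not be split by this strategy"
    return dev_matrix, [d // s for d, s in zip(out_shape, out_strategy)]

# strategy ((2, 2, 1, 1), (2, 2, 1, 1)) on out_shape (32, 16, 8, 8) yields
# dev matrix (2, 2, 1, 1, 2) and output slice shape (16, 8, 8, 8)
```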

View File

@@ -43,15 +43,15 @@ class Conv2DInfo : public OperatorInfo {
void ReComputeBatchSplitFlagList() override;
protected:
Status GetAttrsBase();
Status GetAttrs() override;
Status CheckStrategyBase(const StrategyPtr &strategy);
Status CheckStrategy(const StrategyPtr &strategy) override;
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
Status InferForwardCommunication() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
private:
int64_t out_channel_ = 1;
std::vector<int64_t> kernel_size_; // two integers
int64_t mode_ = 1;
@@ -63,9 +63,43 @@ class Conv2DInfo : public OperatorInfo {
std::string format_;
bool out_channel_shard_ = false;
int64_t new_out_channel_ = 1;
private:
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
};
class Conv2DBackpropInputInfo : public Conv2DInfo {
public:
Conv2DBackpropInputInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: Conv2DInfo(name, inputs_shape, outputs_shape, attrs) {}
~Conv2DBackpropInputInfo() override = default;
void UpdateOutShape(const CNodePtr &cnode);
protected:
Status GetAttrs() override;
Status GetOutShape();
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferMirrorOps() override;  // can not use OperatorInfo::InferMirrorOps() since the 'out_shape' input is not a tensor
private:
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
Shape out_shape_;
Shape out_slice_shape_;
};
class Conv2DTransposeInfo : public Conv2DBackpropInputInfo {
public:
Conv2DTransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: Conv2DBackpropInputInfo(name, inputs_shape, outputs_shape, attrs) {}
~Conv2DTransposeInfo() override = default;
};
constexpr size_t IN_CHANNEL_INDEX = 1;
using Conv2DBackpropInputInfoPtr = std::shared_ptr<Conv2DBackpropInputInfo>;
} // namespace parallel
} // namespace mindspore
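
The tensor maps declared in InferTensorMap follow mechanically from the dev matrix (n, o, h, w, i): axes are numbered from the right (i = 0, ..., n = 4), each tensor dimension maps to the axis that shards it, and -1 marks an unsharded dimension. A small cross-check in plain Python (illustrative, not MindSpore API):

```python
# Dev-matrix axes of (n, o, h, w, i), numbered from the rightmost axis.
DEV_AXES = {"n": 4, "o": 3, "h": 2, "w": 1, "i": 0}

def tensor_map(dims):
    # -1 for dimensions (like the 1x1 kernel dims) that no axis shards
    return [DEV_AXES.get(d, -1) for d in dims]

assert tensor_map(["n", "o", "h", "w"]) == [4, 3, 2, 1]    # dout
assert tensor_map(["o", "i", "1", "1"]) == [3, 0, -1, -1]  # weight
assert tensor_map(["n", "i", "h", "w"]) == [4, 0, 2, 1]    # output
```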

View File

@@ -269,6 +269,8 @@ constexpr char REDUCE_MEAN[] = "ReduceMean";
constexpr char ARGMAXWITHVALUE[] = "ArgMaxWithValue";
constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue";
constexpr char CONV2D[] = "Conv2D";
constexpr char CONV2D_BACK_PROP_INPUT[] = "Conv2DBackpropInput";
constexpr char CONV2D_TRANSPOSE[] = "Conv2DTranspose";
constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm";
constexpr char FUSE_BATCH_NORM_EX[] = "FusedBatchNormEx";
constexpr char BATCH_NORM[] = "BatchNorm";

View File

@@ -169,7 +169,7 @@ bool IsSplittableOperator(const std::string &op_name) {
BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT,
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT};
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE};
// clang-format on
auto iter = splittable_op.find(op_name);

View File

@@ -2705,9 +2705,25 @@ void HandleTileNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &
tile->UpdateMultiples(cnode);
}
void HandleConv2dTransposeNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() != 4 || !IsValueNode<Primitive>(cnode->input(0))) {
return;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->name() != CONV2D_BACK_PROP_INPUT && prim->name() != CONV2D_TRANSPOSE) {
return;
}
Conv2DBackpropInputInfoPtr op_ptr = std::dynamic_pointer_cast<Conv2DBackpropInputInfo>(distribute_operator);
MS_EXCEPTION_IF_NULL(op_ptr);
op_ptr->UpdateOutShape(cnode);
}
void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
HandleDropoutNode(distribute_operator, cnode);
HandleTileNode(distribute_operator, cnode);
HandleConv2dTransposeNode(distribute_operator, cnode);
}
std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
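
HandleConv2dTransposeNode exists because the third input of Conv2DBackpropInput/Conv2DTranspose is the full output shape written by the user, and after sharding it must be rewritten to the per-device slice (UpdateOutShape replaces the value node). The arithmetic, with numbers matching the test file at the end of this commit (illustrative Python, not MindSpore code):

```python
# Full output shape from the user's call, shrunk to the per-device slice.
full_out_shape = (32, 16, 8, 8)   # third input of the op
out_strategy = (2, 2, 1, 1)       # derived shard counts per dimension
slice_shape = tuple(d // s for d, s in zip(full_out_shape, out_strategy))
assert slice_shape == (16, 8, 8, 8)  # what the value node is rewritten to
```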

View File

@@ -44,6 +44,8 @@ def get_bprop_bias_add(self):
@bprop_getters.register(P.Conv2D)
def get_bprop_conv2d(self):
"""Grad definition for `Conv2D` operation."""
self.out_channel = self.get_attr_dict()["out_channel"]
self.pad_list = self.get_attr_dict()["pad_list"]
input_grad = P.Conv2DBackpropInput(
self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
@@ -1055,12 +1057,13 @@ def get_bprop_roi_align(self):
def get_bprop_conv2d_backprop_input(self):
"""Grad definition for `Conv2DBackpropInput` operation."""
pad_list = self.get_attr_dict()['pad_list']
out_channel = self.get_attr_dict()['out_channel']
filter_grad = G.Conv2DBackpropFilter(
self.out_channel, self.kernel_size, self.pad_mode, self.pad, pad_list, mode=self.mode,
out_channel, self.kernel_size, self.pad_mode, self.pad, pad_list, mode=self.mode,
dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
)
input_grad = P.Conv2D(
self.out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
)
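
The InferForwardCommunication comment earlier in this commit says an AllReduce is needed when the summed channel dimension is split. A quick numpy demonstration of why (a 1x1 convolution reduces to a channel matmul; this is an illustration, not MindSpore code):

```python
import numpy as np

n, c_in, h, w, c_out = 2, 8, 4, 4, 6
x = np.random.rand(n, c_in, h, w)
weight = np.random.rand(c_out, c_in)  # 1x1 kernel, spatial dims squeezed

full = np.einsum("nihw,oi->nohw", x, weight)

# Split c_in across two "devices": each shard yields only a partial sum,
# and adding the partials (the AllReduce) recovers the full result.
partials = [np.einsum("nihw,oi->nohw", x[:, s], weight[:, s])
            for s in (slice(0, 4), slice(4, 8))]
assert np.allclose(full, sum(partials))
```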

View File

@@ -1,9 +1,9 @@
0.1.0 MindSpore*1.1.0:î

bprop.10:doutbprop.10:[CNode]12:2bprop.10:[CNode]11:1"S-Prim-MakeTuple:HGradients/Default/network-NetIdentity/gradIdentity/S-Prim-MakeTuple-op20bprop.10*
bprop.10:doutbprop.10:[CNode]12:2bprop.10:[CNode]11:1"S-Prim-MakeTuple:HGradients/Default/network-NetIdentity/gradIdentity/S-Prim-MakeTuple-op15bprop.10*
bprop.10:x*
bprop.10:out*
bprop.10:dout2
bprop.10:[CNode]12:2:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d92366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb365c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8ece4dfce2542fdfdaa21de56a7b25766339f88051618a03a241de7dfdcb0347a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
bprop.10:[CNode]12:2:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d92366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260

View File

@@ -8,4 +8,4 @@
bprop.2:x*
bprop.2:out*
bprop.2:dout2
bprop.2:[CNode]4:3:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d92366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb365c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4ec8ece4dfce2542fdfdaa21de56a7b25766339f88051618a03a241de7dfdcb0347a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
bprop.2:[CNode]4:3:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d92366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260

View File

@@ -0,0 +1,66 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
                 strategy1=None, strategy2=None):
        super().__init__()
        self.conv2d_transpose = P.Conv2DTranspose(out_channel=out_channel, kernel_size=kernel_size,
                                                  pad_mode=pad_mode, stride=stride).shard(strategy1)
        self.neg = P.Neg().shard(strategy2)
        self.weight = Parameter(conv2d_weight, "w1")

    def construct(self, x, b):
        out = self.conv2d_transpose(x, self.weight, (32, 16, 8, 8))
        out = self.neg(out)
        return out


_x = Tensor(np.ones([32, 8, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)


def compile_net(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_conv2d_transpose_data_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_transpose_model_parallel1():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)
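
A further case that could be added in the same style (an untested sketch, not part of this commit): shard only the channel dimensions, which exercises the c-out consistency check in Conv2DBackpropInputInfo::CheckStrategy (input_strategy[1] must equal weight_strategy[0]) and the forward AllReduce path.

```python
def test_conv2d_transpose_model_parallel2():
    # untested sketch: channel-only sharding, input_strategy[1] == weight_strategy[0]
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 8, 1, 1), (8, 1, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)
```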