forked from mindspore-Ecosystem/mindspore
!18730 add parallel operator for Conv2dTranspose and Conv2DBackPropInput
Merge pull request !18730 from yangzhenzhang/add-parallel-op-for-conv2d-backprop-input
commit 2d8e44f3d0
@@ -197,6 +197,8 @@ REGISTER(TopKInfo);
 REGISTER(ScatterUpdateInfo);
 REGISTER(VirtualOutputInfo);
 REGISTER(Conv2DInfo);
+REGISTER(Conv2DBackpropInputInfo);
+REGISTER(Conv2DTransposeInfo);
 REGISTER(BatchNormInfo);
 REGISTER(MaxPoolInfo);
 REGISTER(AvgPoolInfo);
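
Note: registering these OperatorInfo classes is what exposes Conv2DBackpropInput and Conv2DTranspose to the (semi-)auto-parallel planner, so a user-supplied shard strategy on these primitives is now honored. A minimal Python sketch of the usage this enables (it mirrors the test added at the end of this PR):

from mindspore.ops import operations as P

# (8, 1, 1, 1) splits the batch dimension of the first input across 8 devices;
# (1, 1, 1, 1) leaves the weight unsplit.
conv2d_transpose = P.Conv2DTranspose(out_channel=8, kernel_size=2,
                                     pad_mode="same", stride=1).shard(((8, 1, 1, 1), (1, 1, 1, 1)))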
@@ -28,7 +28,7 @@
 
 namespace mindspore {
 namespace parallel {
-Status Conv2DInfo::GetAttrs() {
+Status Conv2DInfo::GetAttrsBase() {
   // out_channel
   out_channel_ = GetIntAttr(OUT_CHANNEL);
   if (out_channel_ <= 0) {
@@ -121,6 +121,8 @@ Status Conv2DInfo::GetAttrs() {
   return SUCCESS;
 }
 
+Status Conv2DInfo::GetAttrs() { return GetAttrsBase(); }
+
 Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
   if (pad_mode_ == 0) {  // 'pad' mode
     MS_LOG(ERROR) << name_ << ": The 'pad' mode does not support splitting H or W";
@@ -175,7 +177,7 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
   return SUCCESS;
 }
 
-Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
+Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
   MS_EXCEPTION_IF_NULL(strategy);
   if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Invalid strategy";
@@ -197,24 +199,12 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
 
-  if (input_strategy[1] != weight_strategy[1]) {
-    MS_LOG(ERROR) << name_ << ": The shard num of c-in for input strategy is " << input_strategy[1]
-                  << ", but the shard num of c-in for weight strategy is " << weight_strategy[1];
-    return FAILED;
-  }
-
   if (weight_strategy[2] != 1 || weight_strategy[3] != 1) {
     MS_LOG(ERROR) << name_ << ": The kernel size cannot be split, but the strategy for kernel size is ("
                   << weight_strategy[2] << ", " << weight_strategy[3] << ")";
     return FAILED;
   }
 
-  if (input_strategy[2] != 1 || input_strategy[3] != 1) {
-    if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
-      return FAILED;
-    }
-  }
-
   if (weight_strategy[0] > 1) {
     out_channel_shard_ = true;
     new_out_channel_ = out_channel_ / weight_strategy[0];
@@ -225,6 +215,29 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
   return SUCCESS;
 }
 
+Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
+  if (CheckStrategyBase(strategy) != SUCCESS) {
+    return FAILED;
+  }
+
+  std::vector<Dimensions> stra = strategy->GetInputDim();
+  Dimensions input_strategy = stra[0];
+  Dimensions weight_strategy = stra[1];
+  if (input_strategy[1] != weight_strategy[1]) {
+    MS_LOG(ERROR) << name_ << ": The shard num of c-in for input strategy is " << input_strategy[1]
+                  << ", but the shard num of c-in for weight strategy is " << weight_strategy[1];
+    return FAILED;
+  }
+
+  if (input_strategy[2] != 1 || input_strategy[3] != 1) {
+    if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
+      return FAILED;
+    }
+  }
+
+  return SUCCESS;
+}
+
 Status Conv2DInfo::InferDevMatrixShape() {
   // the strategy is ((n, i, h, w), (o, i, 1, 1))
   // the dev matrix is (n, i, h, w, o)
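
Note: the refactor above splits validation into CheckStrategyBase (shared rules: the kernel dimensions must not be split, and a sharded c-out divides the effective out_channel by the shard count) and a per-operator CheckStrategy (Conv2D additionally requires the c-in shard of input and weight to match, and any H/W split to pass CheckHWStrategy). A small sketch of the Conv2D rules (illustration only, not MindSpore code):

# Conv2D strategy is ((n, i, h, w), (o, i, 1, 1)); entries are shard counts.
def check_conv2d_strategy(input_strategy, weight_strategy):
    if input_strategy[1] != weight_strategy[1]:
        return "FAILED: c-in shard mismatch"
    if weight_strategy[2] != 1 or weight_strategy[3] != 1:
        return "FAILED: kernel size cannot be split"
    return "SUCCESS"

print(check_conv2d_strategy((8, 2, 1, 1), (4, 2, 1, 1)))  # SUCCESS
print(check_conv2d_strategy((8, 2, 1, 1), (4, 1, 1, 1)))  # FAILED: c-in shard mismatch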
@@ -254,7 +267,8 @@ Status Conv2DInfo::InferTensorMap() {
   return SUCCESS;
 }
 
-// if in channel is split, it need to insert all reduce
+// Conv2d: dev_matrix is (n, i, h, w, o); if the in channel is split, an all-reduce needs to be inserted
+// Conv2DBackpropInputInfo: dev_matrix is (n, o, h, w, i); if the out channel is split, an all-reduce needs to be inserted
 Status Conv2DInfo::InferForwardCommunication() {
   forward_op_.clear();
   size_t relevant_dim_index = IN_CHANNEL_INDEX;
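
Note: the comments above are the crux of the forward communication: when the reduction dimension is sharded (c-in for Conv2D, c-out for Conv2DBackpropInput), each device produces only a partial sum, so an AllReduce over the corresponding device-matrix dimension is inserted. A runnable numpy sketch of the underlying identity (my illustration, 1x1 kernel for brevity):

import numpy as np

n, c_in, h, w, c_out = 2, 4, 5, 5, 3
x = np.random.randn(n, c_in, h, w)
weight = np.random.randn(c_out, c_in, 1, 1)

def conv1x1(x, w):
    # (n, c_in, h, w) * (c_out, c_in) -> (n, c_out, h, w)
    return np.einsum('nihw,oi->nohw', x, w[:, :, 0, 0])

full = conv1x1(x, weight)
# "Shard" c_in across two devices; summing the partial results (the AllReduce)
# recovers the unsharded output.
partial0 = conv1x1(x[:, :2], weight[:, :2])
partial1 = conv1x1(x[:, 2:], weight[:, 2:])
assert np.allclose(full, partial0 + partial1)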
@@ -329,5 +343,190 @@ Status Conv2DInfo::InitForCostModel(const StrategyPtr &strategy) {
   MS_LOG(INFO) << name_ << ": Init for cost model success.";
   return SUCCESS;
 }
+
+Status Conv2DBackpropInputInfo::GetOutShape() {
+  if (input_value_.size() != 3) {
+    MS_LOG(ERROR) << name_ << ": The size of input value must be 3, but got " << input_value_.size();
+    return FAILED;
+  }
+
+  if (input_value_[2] == nullptr) {
+    MS_LOG(ERROR) << name_ << ": The input_value_[2] is nullptr";
+    return FAILED;
+  }
+
+  std::vector<ValuePtr> elements;
+  auto value_tuple = input_value_[2]->cast<ValueTuplePtr>();
+  if (value_tuple == nullptr) {
+    MS_LOG(ERROR) << name_ << ": The input_value_[2] must be a ValueTuple.";
+    return FAILED;
+  }
+  elements = value_tuple->value();
+  if (elements.size() != 4) {
+    MS_LOG(ERROR) << name_ << ": Elements size must be 4, but got " << elements.size();
+    return FAILED;
+  }
+
+  for (auto &element : elements) {
+    MS_EXCEPTION_IF_NULL(element);
+    if (element->isa<Int64Imm>()) {
+      int64_t axis = element->cast<Int64ImmPtr>()->value();
+      out_shape_.push_back(axis);
+    } else {
+      MS_LOG(ERROR) << name_ << ": The value of shape must be int";
+      return FAILED;
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status Conv2DBackpropInputInfo::GetAttrs() {
+  if (GetAttrsBase() != SUCCESS) {
+    return FAILED;
+  }
+
+  return GetOutShape();
+}
+
+Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
+  if (CheckStrategyBase(strategy) != SUCCESS) {
+    return FAILED;
+  }
+
+  std::vector<Dimensions> stra = strategy->GetInputDim();
+  Dimensions input_strategy = stra[0];
+  Dimensions weight_strategy = stra[1];
+  if (input_strategy[1] != weight_strategy[0]) {
+    MS_LOG(ERROR) << name_ << ": The shard num of c-out for input strategy is " << input_strategy[1]
+                  << ", but the shard num of c-out for weight strategy is " << weight_strategy[0];
+    return FAILED;
+  }
+
+  if (input_strategy[2] != 1 || input_strategy[3] != 1) {
+    if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
+      return FAILED;
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) { return SUCCESS; }
+
+Status Conv2DBackpropInputInfo::InferDevMatrixShape() {
+  // the strategy is ((n, o, h, w), (o, i, 1, 1))
+  // the dev matrix is (n, o, h, w, i)
+  MS_EXCEPTION_IF_NULL(strategy_);
+  std::vector<Dimensions> stra = strategy_->GetInputDim();
+  if (stra.size() != 2) {
+    MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
+    return FAILED;
+  }
+
+  dev_matrix_shape_ = stra[0];
+  dev_matrix_shape_.push_back(stra[1][1]);
+
+  Shape out_strategy = stra[0];
+  out_strategy[1] = stra[1][1];
+
+  out_slice_shape_ = out_shape_;
+  if (out_shape_.size() != out_strategy.size()) {
+    MS_LOG(ERROR) << name_ << ": The size of out shape is " << out_shape_.size()
+                  << ", but the size of output strategy is " << out_strategy.size();
+    return FAILED;
+  }
+
+  for (size_t i = 0; i < out_slice_shape_.size(); ++i) {
+    if (out_slice_shape_[i] % out_strategy[i] != 0) {
+      MS_LOG(ERROR) << name_ << ": The output can not be split by strategy. The shape of output is "
+                    << out_slice_shape_ << ", but the strategy of output is " << out_strategy;
+      return FAILED;
+    }
+    out_slice_shape_[i] = out_slice_shape_[i] / out_strategy[i];
+  }
+
+  MS_LOG(INFO) << name_ << ": The output slice shape is " << out_slice_shape_;
+  return SUCCESS;
+}
+
+Status Conv2DBackpropInputInfo::InferTensorMap() {
+  // input_strategy: ((n, o, h, w), (o, i, 1, 1))
+  // output_strategy: ((n, i, h, w),)
+  // dev_matrix: (n, o, h, w, i)
+  TensorMap input_tensor_map = {4, 3, 2, 1};
+  TensorMap weight_tensor_map = {3, 0, -1, -1};
+  TensorMap output_tensor_map = {4, 0, 2, 1};
+
+  (void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map));
+  (void)inputs_tensor_map_.emplace_back(std::move(weight_tensor_map));
+  (void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map));
+  return SUCCESS;
+}
+
+Status Conv2DBackpropInputInfo::InferMirrorOps() {
+  mirror_ops_.clear();
+  if (inputs_shape_.empty()) {
+    MS_LOG(INFO) << name_ << ": The inputs size is empty";
+    return SUCCESS;
+  }
+
+  if (inputs_tensor_map_.size() != inputs_shape_.size()) {
+    MS_LOG(ERROR) << name_ << ": The size of inputs tensor map is not equal to the size of inputs shape";
+    return FAILED;
+  }
+
+  bool group_is_empty = true;
+  for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
+    std::vector<Group> group;
+    if (CreateGroupByTensorMap(inputs_tensor_map_[i], &group) != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": Create group failed, the input index is " << i;
+      mirror_ops_.clear();
+      return FAILED;
+    }
+
+    OperatorVector mirror_op;
+    if (group.empty()) {
+      MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
+      mirror_ops_.push_back(mirror_op);
+      continue;
+    }
+
+    group_is_empty = false;
+    mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
+    mirror_ops_.push_back(mirror_op);
+  }
+
+  if (group_is_empty) {
+    mirror_ops_.clear();
+    MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
+    return SUCCESS;
+  }
+
+  OperatorVector tmp_mirror_op;  // tmp mirror op for 'out_shape'
+  mirror_ops_.push_back(tmp_mirror_op);
+  return SUCCESS;
+}
+
+void Conv2DBackpropInputInfo::UpdateOutShape(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->size() != 4) {
+    MS_LOG(EXCEPTION) << name_ << ": The size of cnode's inputs must be 4, but got " << cnode->size();
+  }
+
+  if (!IsValueNode<ValueTuple>(cnode->input(3))) {
+    MS_LOG(EXCEPTION) << name_ << ": The cnode's input[3] is not a value node";
+  }
+
+  auto func_graph = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+
+  ValuePtr out_shape = MakeValue(out_slice_shape_);
+  AnfNodePtr val = NewValueNode(out_shape);
+  (void)manager->Replace(cnode->input(3), val);
+  MS_LOG(INFO) << name_ << ": Update the output shape " << out_slice_shape_;
+}
 } // namespace parallel
 } // namespace mindspore
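
Note: Conv2DBackpropInput receives its full output shape as a constant third input, so once the output is sharded that constant must be rewritten to the per-device slice; InferDevMatrixShape computes it (out_slice_shape_) and UpdateOutShape writes it back into the graph. The arithmetic as a sketch, using values from the new test below:

# out_strategy is the dout strategy (n, o, h, w) with axis 1 replaced by the
# weight's i shard, because the output layout is (n, i, h, w).
input_strategy = [8, 1, 1, 1]
weight_strategy = [1, 1, 1, 1]   # (o, i, 1, 1)
out_shape = [32, 16, 8, 8]       # full output shape

out_strategy = input_strategy[:]
out_strategy[1] = weight_strategy[1]
out_slice_shape = [s // t for s, t in zip(out_shape, out_strategy)]
print(out_slice_shape)  # [4, 16, 8, 8]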
@@ -43,15 +43,15 @@ class Conv2DInfo : public OperatorInfo {
   void ReComputeBatchSplitFlagList() override;
 
  protected:
+  Status GetAttrsBase();
   Status GetAttrs() override;
+  Status CheckStrategyBase(const StrategyPtr &strategy);
   Status CheckStrategy(const StrategyPtr &strategy) override;
-  Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
   Status InferForwardCommunication() override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
 
- private:
   int64_t out_channel_ = 1;
   std::vector<int64_t> kernel_size_;  // two integers
   int64_t mode_ = 1;
@@ -63,9 +63,43 @@ class Conv2DInfo : public OperatorInfo {
   std::string format_;
   bool out_channel_shard_ = false;
   int64_t new_out_channel_ = 1;
+
+ private:
+  Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
 };
+
+class Conv2DBackpropInputInfo : public Conv2DInfo {
+ public:
+  Conv2DBackpropInputInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                          const PrimitiveAttrs &attrs)
+      : Conv2DInfo(name, inputs_shape, outputs_shape, attrs) {}
+  ~Conv2DBackpropInputInfo() override = default;
+  void UpdateOutShape(const CNodePtr &cnode);
+
+ protected:
+  Status GetAttrs() override;
+  Status GetOutShape();
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status InferDevMatrixShape() override;
+  Status InferTensorMap() override;
+  Status InferMirrorOps() override;  // cannot use OperatorInfo::InferMirrorOps(), since the 'out_shape' input is not a tensor
+
+ private:
+  Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
+  Shape out_shape_;
+  Shape out_slice_shape_;
+};
+
+class Conv2DTransposeInfo : public Conv2DBackpropInputInfo {
+ public:
+  Conv2DTransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                      const PrimitiveAttrs &attrs)
+      : Conv2DBackpropInputInfo(name, inputs_shape, outputs_shape, attrs) {}
+  ~Conv2DTransposeInfo() override = default;
+};
+
 constexpr size_t IN_CHANNEL_INDEX = 1;
+using Conv2DBackpropInputInfoPtr = std::shared_ptr<Conv2DBackpropInputInfo>;
 } // namespace parallel
 } // namespace mindspore
@@ -269,6 +269,8 @@ constexpr char REDUCE_MEAN[] = "ReduceMean";
 constexpr char ARGMAXWITHVALUE[] = "ArgMaxWithValue";
 constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue";
 constexpr char CONV2D[] = "Conv2D";
+constexpr char CONV2D_BACK_PROP_INPUT[] = "Conv2DBackpropInput";
+constexpr char CONV2D_TRANSPOSE[] = "Conv2DTranspose";
 constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm";
 constexpr char FUSE_BATCH_NORM_EX[] = "FusedBatchNormEx";
 constexpr char BATCH_NORM[] = "BatchNorm";
@@ -169,7 +169,7 @@ bool IsSplittableOperator(const std::string &op_name) {
     BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
     SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
     UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT,
-    UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT};
+    UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE};
   // clang-format on
 
   auto iter = splittable_op.find(op_name);
@@ -2705,9 +2705,25 @@ void HandleTileNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
   tile->UpdateMultiples(cnode);
 }
 
+void HandleConv2dTransposeNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->size() != 4 || !IsValueNode<Primitive>(cnode->input(0))) {
+    return;
+  }
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  if (prim->name() != CONV2D_BACK_PROP_INPUT && prim->name() != CONV2D_TRANSPOSE) {
+    return;
+  }
+
+  Conv2DBackpropInputInfoPtr op_ptr = std::dynamic_pointer_cast<Conv2DBackpropInputInfo>(distribute_operator);
+  MS_EXCEPTION_IF_NULL(op_ptr);
+  op_ptr->UpdateOutShape(cnode);
+}
+
 void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
   HandleDropoutNode(distribute_operator, cnode);
   HandleTileNode(distribute_operator, cnode);
+  HandleConv2dTransposeNode(distribute_operator, cnode);
 }
 
 std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
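
Note: HandleConv2dTransposeNode only fires for CNodes of the form (primitive, dout, weight, out_shape); input(3) is the ValueTuple holding the full output shape, which UpdateOutShape swaps for the slice shape computed during strategy inference. A sketch of the early-return dispatch pattern used by these handlers (illustration only, not the ANF API):

# Each special-case handler inspects the node and returns unchanged input if it
# does not apply, so HandleSpecialNode can call all handlers unconditionally.
def handle_conv2d_transpose(prim_name, inputs):
    if prim_name not in ("Conv2DBackpropInput", "Conv2DTranspose") or len(inputs) != 3:
        return inputs
    dout, weight, out_shape = inputs
    return [dout, weight, (4, 16, 8, 8)]  # full shape replaced by the slice shape

print(handle_conv2d_transpose("Conv2DTranspose", ["dout", "w", (32, 16, 8, 8)]))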
@@ -44,6 +44,8 @@ def get_bprop_bias_add(self):
 @bprop_getters.register(P.Conv2D)
 def get_bprop_conv2d(self):
     """Grad definition for `Conv2D` operation."""
+    self.out_channel = self.get_attr_dict()["out_channel"]
+    self.pad_list = self.get_attr_dict()["pad_list"]
     input_grad = P.Conv2DBackpropInput(
         self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
         dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
@@ -1055,12 +1057,13 @@ def get_bprop_roi_align(self):
 def get_bprop_conv2d_backprop_input(self):
     """Grad definition for `Conv2DBackpropInput` operation."""
     pad_list = self.get_attr_dict()['pad_list']
+    out_channel = self.get_attr_dict()['out_channel']
     filter_grad = G.Conv2DBackpropFilter(
-        self.out_channel, self.kernel_size, self.pad_mode, self.pad, pad_list, mode=self.mode,
+        out_channel, self.kernel_size, self.pad_mode, self.pad, pad_list, mode=self.mode,
         dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
     )
     input_grad = P.Conv2D(
-        self.out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
+        out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
         dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
     )
 
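
Note: the switch from self.out_channel to get_attr_dict()['out_channel'] matters under parallel execution: the pass above rewrites attributes on the primitive (for example, out_channel becomes the per-device new_out_channel_ when c-out is sharded), and the attr dict is the view that reflects such updates. A hypothetical sketch of the distinction (Prim is a stand-in, not a MindSpore class):

class Prim:
    def __init__(self):
        self.out_channel = 64                  # Python-side cached value
        self._attrs = {"out_channel": 64}

    def get_attr_dict(self):
        return self._attrs

p = Prim()
p._attrs["out_channel"] = 16                   # e.g. the parallel pass shards c-out by 4
assert p.out_channel == 64                     # stale cache
assert p.get_attr_dict()["out_channel"] == 16  # updated view, used by the bprop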
@@ -1,9 +1,9 @@
[binary diff: serialized bprop MindIR (bprop.10). The S-Prim-MakeTuple debug name changes from ...S-Prim-MakeTuple-op20 to ...S-Prim-MakeTuple-op15 and the embedded checksum is regenerated; the opaque payload is omitted here.]
@@ -8,4 +8,4 @@
[binary diff: serialized bprop MindIR (bprop.2). Only the embedded checksum changes; the opaque payload is omitted here.]
@@ -0,0 +1,66 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.common.api import _executor
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+
+
+class Net(Cell):
+    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
+                 strategy1=None, strategy2=None):
+        super().__init__()
+        self.conv2d_transpose = P.Conv2DTranspose(out_channel=out_channel, kernel_size=kernel_size,
+                                                  pad_mode=pad_mode, stride=stride).shard(strategy1)
+        self.neg = P.Neg().shard(strategy2)
+        self.weight = Parameter(conv2d_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.conv2d_transpose(x, self.weight, (32, 16, 8, 8))
+        out = self.neg(out)
+        return out
+
+
+_x = Tensor(np.ones([32, 8, 8, 8]), dtype=ms.float32)
+_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
+_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
+
+
+def compile_net(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    train_net.set_train()
+    _executor.compile(train_net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_conv2d_transpose_data_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
+    strategy2 = ((8, 1, 1, 1),)
+    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)
+
+
+def test_conv2d_transpose_model_parallel1():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
+    strategy2 = ((8, 1, 1, 1),)
+    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)