forked from mindspore-Ecosystem/mindspore
!35448 add deformable offsets grad op for Ascend
Merge pull request !35448 from guoqi/deformable_offsets_grad
commit d6bbef46ee
@ -52,4 +52,7 @@ mindspore.ops.deformable_conv2d
 .. note::
     - This is an experimental interface that is subject to change or deletion.
-    - On the Ascend platform, only scenarios that simultaneously satisfy the following are supported: :math:`C_{in}` is divisible by 8, `deformable_groups` is 1, and the data of `offsets` is of floating-point type (i.e. it contains a decimal part). For example, scenarios where the shape of `x` is :math:`(N, 2, H_{in}, W_{in})`, `deformable_groups` is 2, or `offsets` is assigned using the "numpy.ones()" function are not supported.
+    - On the Ascend platform, the following conditions must currently be met:
+
+      - :math:`C_{in}` is divisible by 8.
+      - `deformable_groups` is 1 and the data of `offsets` is of floating-point type (i.e. it must contain a decimal part).
+      - `kernel_size` must be greater than 1.
@ -59,6 +59,7 @@
 #include "plugin/device/ascend/optimizer/ir_fusion/refresh_parameter_format.h"
 #include "plugin/device/ascend/optimizer/ir_fusion/transpose_transdata_fusion.h"
 #include "plugin/device/ascend/optimizer/ir_fusion/deformable_offsets_fusion.h"
+#include "plugin/device/ascend/optimizer/ir_fusion/deformable_offsets_grad_fusion.h"
 #include "plugin/device/ascend/optimizer/ir_fission/transdata_split.h"
 #include "plugin/device/ascend/optimizer/ir_fission/topk_split.h"
 #include "plugin/device/ascend/optimizer/ir_fission/conv2d_backprop_filter_mul_fission.h"
@ -225,6 +226,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<DiagFission>());
   ir_fusion_pm->AddPass(std::make_shared<DiagPartFission>());
   ir_fusion_pm->AddPass(std::make_shared<DeformableOffsetsFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<DeformableOffsetsGradFusion>());
   ir_fusion_pm->AddPass(std::make_shared<MaxPool3DGradGradFission>());
   ir_fusion_pm->AddPass(std::make_shared<AdaptiveMaxPool2DFusion>());
   ir_fusion_pm->AddPass(std::make_shared<AvgPool3DFusion>());
@ -422,6 +424,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
   ir_fusion_pm->AddPass(std::make_shared<DiagFission>());
   ir_fusion_pm->AddPass(std::make_shared<DiagPartFission>());
   ir_fusion_pm->AddPass(std::make_shared<DeformableOffsetsFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<DeformableOffsetsGradFusion>());
   ir_fusion_pm->AddPass(std::make_shared<AdamWeightDecayFission>());
   ir_fusion_pm->AddPass(std::make_shared<ScaleGradFission>());
   ir_fusion_pm->AddPass(std::make_shared<LambFission>());
@ -0,0 +1,130 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/ascend/optimizer/ir_fusion/deformable_offsets_grad_fusion.h"
#include <memory>
#include <vector>
#include <algorithm>
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kAxisH = 2;
constexpr size_t kAxisW = 3;
constexpr size_t kAxisC = 1;
constexpr size_t kDeformableOffsetsGradInputNum = 4;
// The offsets tensor packs three channel blocks: x-coordinates, y-coordinates and the mask.
constexpr size_t kChannel = 3;
}  // namespace

ValueNodePtr DeformableOffsetsGradFusion::CreateHelperNode(
  const FuncGraphPtr &func_graph, const AnfNodePtr &node, const std::vector<size_t> &offset_shape,
  const std::vector<int64_t> &kernel_sizes, const std::vector<int64_t> &strides, const std::vector<int64_t> &pads,
  const std::vector<int64_t> &dilations, const size_t axis_h, const size_t axis_w, const size_t axis_c) const {
  size_t h_out = offset_shape[axis_h];
  size_t w_out = offset_shape[axis_w];
  int64_t kernel_size_h = kernel_sizes[0];
  int64_t kernel_size_w = kernel_sizes[1];
  int64_t stride_h = strides[axis_h];
  int64_t stride_w = strides[axis_w];
  int64_t dilation_h = dilations[axis_h];
  int64_t dilation_w = dilations[axis_w];
  size_t group = offset_shape[axis_c] / (kChannel * kernel_size_h * kernel_size_w);
  int64_t pad_top = pads[0];
  int64_t pad_left = pads[axis_w];
  int64_t h_index;
  int64_t w_index;
  std::vector<size_t> out_shape = {1, offset_shape[1], offset_shape[2], offset_shape[3]};
  std::vector<int64_t> assist_shape;
  std::transform(out_shape.begin(), out_shape.end(), std::back_inserter(assist_shape), SizeToLong);
  tensor::TensorPtr helper_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, assist_shape);
  TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat32);
  tensor::DeviceInfo device_info{kOpFormat_NHWC, tensor_type, kOpFormat_NHWC};
  helper_tensor->set_device_info(device_info);
  auto tensor_data = reinterpret_cast<float *>(helper_tensor->data_c());
  // Fill the assist tensor with the default (undeformed) sampling grid, laid out as NHWC.
  for (size_t h = 0; h < h_out; ++h) {
    for (size_t w = 0; w < w_out; ++w) {
      for (size_t g = 0; g < group; ++g) {
        for (int64_t k_h = 0; k_h < kernel_size_h; ++k_h) {
          for (int64_t k_w = 0; k_w < kernel_size_w; ++k_w) {
            w_index = static_cast<int64_t>(h * w_out * kChannel * group * kernel_size_h * kernel_size_w +
                                           w * kChannel * group * kernel_size_h * kernel_size_w +
                                           0 * group * kernel_size_h * kernel_size_w +
                                           g * kernel_size_h * kernel_size_w + k_h * kernel_size_w + k_w);
            h_index = static_cast<int64_t>(h * w_out * kChannel * group * kernel_size_h * kernel_size_w +
                                           w * kChannel * group * kernel_size_h * kernel_size_w +
                                           1 * group * kernel_size_h * kernel_size_w +
                                           g * kernel_size_h * kernel_size_w + k_h * kernel_size_w + k_w);
            float w_val = static_cast<float>(w * stride_w - pad_left + k_w * dilation_w);
            float h_val = static_cast<float>(h * stride_h - pad_top + k_h * dilation_h);
            tensor_data[w_index] = w_val;
            tensor_data[h_index] = h_val;
          }
        }
      }
    }
  }
  AbstractBasePtr x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat, assist_shape);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto assist_value_node = kernel_graph->NewValueNode(x_abstract, helper_tensor);
  kernel_graph->AddValueNodeToGraph(assist_value_node);
  common::AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {out_shape}, assist_value_node.get());
  return assist_value_node;
}

const BaseRef DeformableOffsetsGradFusion::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimDeformableOffsetsGrad, Xs});
}

const AnfNodePtr DeformableOffsetsGradFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                      const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  auto deformable_offsets_grad_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(deformable_offsets_grad_cnode);
  size_t origin_input_size = deformable_offsets_grad_cnode->inputs().size();
  if (origin_input_size <= kDeformableOffsetsGradInputNum) {
    MS_LOG(INFO) << "The node " << deformable_offsets_grad_cnode->DebugString() << " has no more than "
                 << kDeformableOffsetsGradInputNum << " inputs.";
  }
  auto pads = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(deformable_offsets_grad_cnode, kAttrPads);
  auto stride = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(deformable_offsets_grad_cnode, kAttrStrides);
  auto dilation = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(deformable_offsets_grad_cnode, kAttrDilations);
  auto kernel_size = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(deformable_offsets_grad_cnode, kAttrKsize);
  auto offset_shape = common::AnfAlgo::GetOutputInferShape(deformable_offsets_grad_cnode->inputs()[kIndex3], 0);
  std::vector<AnfNodePtr> new_inputs{
    NewValueNode(std::make_shared<Primitive>(prim::kPrimDeformableOffsetsGrad->name()))};
  // Build the assist tensor holding the default sampling positions and append it as an extra input.
  auto assist_const = CreateHelperNode(func_graph, deformable_offsets_grad_cnode, offset_shape, kernel_size, stride,
                                       pads, dilation, kAxisH, kAxisW, kAxisC);
  (void)new_inputs.insert(new_inputs.end(), deformable_offsets_grad_cnode->inputs().begin() + 1,
                          deformable_offsets_grad_cnode->inputs().end());
  new_inputs.push_back(assist_const);
  auto new_cnode = NewCNode(new_inputs, func_graph);
  new_cnode->set_abstract(deformable_offsets_grad_cnode->abstract());
  new_cnode->set_scope(deformable_offsets_grad_cnode->scope());
  common::AnfAlgo::CopyNodeAttrs(deformable_offsets_grad_cnode, new_cnode);
  common::AnfAlgo::SetNodeAttr(kAttrDataFormat, MakeValue("NHWC"), new_cnode);
  if (kernel_graph != nullptr) {
    kernel_graph->AddValueNodeToGraph(assist_const);
    MS_LOG(INFO) << "Added assist tensor for the DeformableOffsetsGrad op.";
  }
  return new_cnode;
}
}  // namespace opt
}  // namespace mindspore
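For readers who want to sanity-check the index arithmetic in the fill loop above, here is a minimal numpy sketch of the same computation. The function name and its defaults are illustrative, not part of the commit:

import numpy as np

def make_assist_tensor(h_out, w_out, kernel_size, strides, pads, dilations, groups=1):
    """Hypothetical numpy mirror of CreateHelperNode's fill loop (illustration only)."""
    k_h, k_w = kernel_size
    stride_h, stride_w = strides
    pad_top, pad_left = pads
    dil_h, dil_w = dilations
    # NHWC buffer with N == 1; the channel axis packs [w-coords | h-coords | mask]
    # blocks of size groups * k_h * k_w each (kChannel == 3 in the C++ code).
    block = groups * k_h * k_w
    assist = np.zeros((1, h_out, w_out, 3 * block), dtype=np.float32)
    for h in range(h_out):
        for w in range(w_out):
            for g in range(groups):
                for i in range(k_h):
                    for j in range(k_w):
                        base = g * k_h * k_w + i * k_w + j
                        # Block 0: default x (width) sampling position.
                        assist[0, h, w, 0 * block + base] = w * stride_w - pad_left + j * dil_w
                        # Block 1: default y (height) sampling position.
                        assist[0, h, w, 1 * block + base] = h * stride_h - pad_top + i * dil_h
                        # Block 2 (the mask) stays zero, as in the C++ loop.
    return assist

# Matches the test at the end of this diff: a 2x2 kernel on a 1x1 output grid
# yields shape (1, 1, 1, 12), i.e. 3 * 1 * 2 * 2 channels.
print(make_assist_tensor(1, 1, (2, 2), (1, 1), (0, 0), (1, 1)).shape)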
@ -0,0 +1,41 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_DEFORMABLE_OFFSETS_GRAD_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_DEFORMABLE_OFFSETS_GRAD_FUSION_H_

#include <vector>
#include "backend/common/optimizer/optimizer.h"

namespace mindspore {
namespace opt {
class DeformableOffsetsGradFusion : public PatternProcessPass {
 public:
  explicit DeformableOffsetsGradFusion(bool multigraph = true)
      : PatternProcessPass("deformable_offsets_grad_fusion", multigraph) {}
  ~DeformableOffsetsGradFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  ValueNodePtr CreateHelperNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                const std::vector<size_t> &offset_shape, const std::vector<int64_t> &kernel_sizes,
                                const std::vector<int64_t> &strides, const std::vector<int64_t> &pads,
                                const std::vector<int64_t> &dilations, const size_t axis_h, const size_t axis_w,
                                const size_t axis_c) const;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_DEFORMABLE_OFFSETS_GRAD_FUSION_H_
@ -26,7 +26,6 @@ from ..operations import _inner_ops as inner
 from ..operations import _rl_inner_ops as rl_ops
 from ... import context
 from .._utils.utils import range_op, get_1d_shape
-from ..operations.nn_ops import DeformableOffsets


 @bprop_getters.register(P.BiasAdd)
@ -1281,7 +1280,7 @@ def get_bprop_basic_lstm_cell(self):
     return bprop


-@bprop_getters.register(DeformableOffsets)
+@bprop_getters.register(nps.DeformableOffsets)
 def get_bprop_deformable_offsets(self):
     """Grad definition for `DeformableOffsets` operation."""
     grad = G.DeformableOffsetsGrad(self.strides, self.pads, self.ksize, self.dilations, self.data_format,
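The hunk above is cut off mid-call by the diff view. For orientation, here is a hedged sketch of how a bprop getter of this shape is typically completed; the argument order and return packing are assumptions, not the committed code:

# Hypothetical completion, for illustration only; the committed body may differ.
@bprop_getters.register(nps.DeformableOffsets)
def get_bprop_deformable_offsets(self):
    """Grad definition for `DeformableOffsets` operation."""
    grad = G.DeformableOffsetsGrad(self.strides, self.pads, self.ksize, self.dilations,
                                   self.data_format, self.deformable_groups, self.modulated)

    def bprop(x, offsets, out, dout):
        # The grad op consumes the upstream gradient plus the forward inputs
        # and is assumed to return (d_x, d_offsets).
        return grad(dout, x, offsets)

    return bprop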
@ -19,9 +19,9 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 deformable_offsets_grad_op_info = TBERegOp("DeformableOffsetsGrad") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("deformable_offsets.so") \
+    .binfile_name("deformable_offsets_grad.so") \
     .compute_cost(10) \
-    .kernel_name("deformable_offsets") \
+    .kernel_name("deformable_offsets_grad") \
     .partial_flag(True) \
     .need_check_supported(True) \
     .attr("strides", "required", "listInt", "all") \
@ -21,7 +21,7 @@ bprop.13:y*
bprop.13:keep_prob*
bprop.13:out*
bprop.13:dout2
-bprop.13:[CNode]17:6:@6f2fbfa8ce1391cda7ed6ad0bd4a9ed790ee77536ef6167337a9cb401c782fe7Pb.
+bprop.13:[CNode]17:6:@b30f014d58a844d179a6688149acfc5f85f53edc69e56e997a1c1c38e981c97cPbH
#S-Prim-hyper_map[zeros_like_leaf]:4!S-Prim-hyper_map[zeros_like_leaf]b.
S-Prim-DropoutDoMask:2S-Prim-DropoutDoMaskb&
S-Prim-MakeTuple:7S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:4!S-Prim-hyper_map[zeros_like_leaf]h
S-Prim-MakeTuple:7S-Prim-MakeTupleh
@ -11,6 +11,6 @@
bprop.3:keep_prob*
bprop.3:out*
bprop.3:dout2
-bprop.3:[CNode]6:4:@6f2fbfa8ce1391cda7ed6ad0bd4a9ed790ee77536ef6167337a9cb401c782fe7Pb&
+bprop.3:[CNode]6:4:@b30f014d58a844d179a6688149acfc5f85f53edc69e56e997a1c1c38e981c97cPb&
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
@ -19,6 +19,6 @@
bprop.7:off_value*
bprop.7:out*
bprop.7:dout2
-bprop.7:[CNode]12:6:@6f2fbfa8ce1391cda7ed6ad0bd4a9ed790ee77536ef6167337a9cb401c782fe7PbH
+bprop.7:[CNode]12:6:@b30f014d58a844d179a6688149acfc5f85f53edc69e56e997a1c1c38e981c97cPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:7S-Prim-MakeTupleh
Binary file not shown.
@ -8,9 +8,9 @@ m
bprop.1:x*
bprop.1:out*
bprop.1:dout2
-bprop.1:[CNode]2:3:@6f2fbfa8ce1391cda7ed6ad0bd4a9ed790ee77536ef6167337a9cb401c782fe7Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebr
+bprop.1:[CNode]2:3:@b30f014d58a844d179a6688149acfc5f85f53edc69e56e997a1c1c38e981c97cPbr
S-Prim-ReluGrad:2S-Prim-ReluGrad
output_names€ŠZoutput€+
input_names€ŠZ
y_backprop€ŠZx€h
y_backprop€ŠZx€b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh
@ -15,10 +15,10 @@ bprop.18:ybprop.18:[CNode]19:3bprop.18:[CNode]19:3"(REF::S-Prim-hyper_map[ze
bprop.18:y*
bprop.18:out*
bprop.18:dout2
-bprop.18:[CNode]20:5:@6f2fbfa8ce1391cda7ed6ad0bd4a9ed790ee77536ef6167337a9cb401c782fe7Pbr
+bprop.18:[CNode]20:5:@b30f014d58a844d179a6688149acfc5f85f53edc69e56e997a1c1c38e981c97cPbH
#S-Prim-hyper_map[zeros_like_leaf]:4!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:6S-Prim-MakeTuplebr
S-Prim-ReluGrad:2S-Prim-ReluGrad
output_names€ŠZoutput€+
input_names€ŠZ
y_backprop€ŠZx€b&
S-Prim-MakeTuple:6S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:4!S-Prim-hyper_map[zeros_like_leaf]h
y_backprop€ŠZx€h
@ -380,10 +382,12 @@ def deformable_conv2d(x, weight, offsets, kernel_size, strides, padding, bias=No

     .. note::
         - This is an experimental interface that is subject to change or deletion.
-        - For Ascend platform, only supports cases when :math:`C_{in}` can be divisible by 8, `deformable_groups` is 1
-          and `offsets` value is float which needs to contain a decimal part. For example, these scenarios where the
-          shape of `x` is :math:`(N, 2, H_{in}, W_{in})`, `deformable_groups` is 2 or `offsets` is assigned with
-          "numpy.ones()" function are not supported.
+        - For Ascend platform, the following cases are not supported:
+
+          - :math:`C_{in}` cannot be divisible by 8, e.g. the shape of `x` is :math:`(N, 2, H_{in}, W_{in})`.
+          - `deformable_groups` is not 1, e.g. `deformable_groups` is 2.
+          - `offsets` value is a float without a decimal part, e.g. `offsets` is assigned with "numpy.ones()".
+          - `kernel_size` is less than 2, e.g. `kernel_size` is (1, 1).

     Examples:
         >>> x = Tensor(np.ones((4, 3, 10, 10)), mstype.float32)
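To make the rewritten note concrete, here is a minimal sketch of an Ascend-friendly call; the shapes are chosen to satisfy the constraints and are not taken from the commit:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

# C_in = 8 is divisible by 8, and deformable_groups defaults to 1.
x = Tensor(np.ones((4, 8, 10, 10)), ms.float32)
weight = Tensor(np.ones((5, 8, 2, 2)), ms.float32)
# offsets carries 3 * deformable_groups * K_h * K_w = 12 channels; scaling by
# 0.1 gives the values a decimal part, which np.ones() alone would not.
offsets = Tensor(np.ones((4, 12, 9, 9)) * 0.1, ms.float32)
out = ops.deformable_conv2d(x, weight, offsets, (2, 2), (1, 1, 1, 1), (0, 0, 0, 0))
print(out.shape)  # (4, 5, 9, 9)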
@ -0,0 +1,123 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.composite import GradOperation
from mindspore.ops import operations as P
from mindspore.ops.operations import nn_ops as NN


class Net(nn.Cell):
    def __init__(self, out_channel, kernel_size, pad, stride, dilation):
        super(Net, self).__init__()
        self.net = NN.DeformableOffsets(ksize=(kernel_size, kernel_size),
                                        pads=(pad, pad, pad, pad),
                                        strides=(stride, stride, stride, stride),
                                        dilations=(dilation, dilation, dilation, dilation),
                                        deformable_groups=1,
                                        modulated=True,
                                        data_format="NCHW")
        # DeformableOffsets expands each output position into a K_h x K_w block,
        # so the plain convolution below walks that block with stride=kernel_size.
        self.conv = P.Conv2D(out_channel,
                             kernel_size,
                             mode=1,
                             pad_mode="pad",
                             pad=pad,
                             stride=kernel_size,
                             dilation=1,
                             group=1,
                             data_format="NCHW")

    def construct(self, x, w, offset):
        x = self.net(x, offset)
        return self.conv(x, w)


class Grad(nn.Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = GradOperation(get_all=True, sens_param=True)
        self.network = network

    def construct(self, x, w, offset, output_grad):
        return self.grad(self.network)(x, w, offset, output_grad)


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_deformable_conv2d_grad():
    """
    Feature: deformable_conv2d_grad function
    Description: Test case for simplest deformable_conv2d_grad
    Expectation: The results are as expected
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    kernel_size = 2
    stride = 1
    pad = 0
    dilation = 1
    # x shape [1, 8, 2, 2]
    x = Tensor(np.ones([1, 8, 2, 2]).astype(np.float32) * 0.1)
    # weight shape [1, 8, 2, 2]
    weight = Tensor(np.ones([1, 8, 2, 2]).astype(np.float32) * 0.1)
    # offsets shape [1, 12, 1, 1]
    offsets = Tensor(np.ones([1, 12, 1, 1]).astype(np.float32) * 0.1)
    # out_channel, kernel_size, pad, stride, dilation
    dfm_conv2d_net = Net(1, kernel_size, pad, stride, dilation)
    out = dfm_conv2d_net(x, weight, offsets)
    grad_net = Grad(dfm_conv2d_net)
    # grad_output holds (d_x, d_weight, d_offsets), in input order.
    grad_output = grad_net(x, weight, offsets, out)
    expected_out = np.array([[[[0.02888089]]]]).astype(np.float32)
    expect_grad_x = np.array([[[[0.00023391, 0.0002599],
                                [0.0002599, 0.00028877]],
                               [[0.00023391, 0.0002599],
                                [0.0002599, 0.00028877]],
                               [[0.00023391, 0.0002599],
                                [0.0002599, 0.00028877]],
                               [[0.00023391, 0.0002599],
                                [0.0002599, 0.00028877]],
                               [[0.00023391, 0.0002599],
                                [0.0002599, 0.00028877]],
                               [[0.00023391, 0.0002599],
                                [0.0002599, 0.00028877]],
                               [[0.00023391, 0.0002599],
                                [0.0002599, 0.00028877]],
                               [[0.00023391, 0.0002599],
                                [0.0002599, 0.00028877]]]]).astype(np.float32)
    # grad_output[1] is the gradient w.r.t. `weight`, so name it accordingly.
    expect_grad_weight = np.array([[[[0.00028891, 0.00026004],
                                     [0.00026004, 0.00023404]],
                                    [[0.00028891, 0.00026004],
                                     [0.00026004, 0.00023404]],
                                    [[0.00028891, 0.00026004],
                                     [0.00026004, 0.00023404]],
                                    [[0.00028891, 0.00026004],
                                     [0.00026004, 0.00023404]],
                                    [[0.00028891, 0.00026004],
                                     [0.00026004, 0.00023404]],
                                    [[0.00028891, 0.00026004],
                                     [0.00026004, 0.00023404]],
                                    [[0.00028891, 0.00026004],
                                     [0.00026004, 0.00023404]],
                                    [[0.00028891, 0.00026004],
                                     [0.00026004, 0.00023404]]]]).astype(np.float32)
    assert np.allclose(out.asnumpy(), expected_out, 0.0001, 0.0001)
    assert np.allclose(grad_output[0].asnumpy(), expect_grad_x, 0.0001, 0.0001)
    assert np.allclose(grad_output[1].asnumpy(), expect_grad_weight, 0.0001, 0.0001)
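As a cross-check of `expected_out`, the scalar 0.02888089 can be reproduced by hand with numpy, assuming zero padding outside `x`, bilinear sampling, and the 0.1 mask applied multiplicatively (all assumptions about the kernel's semantics, not taken from the commit):

import numpy as np

# Hand-check of expected_out for the test above: constant 0.1 input, all
# offsets and masks 0.1, 2x2 kernel on a 2x2 image, zero padding outside.
x_val, off, mask, w_val = 0.1, 0.1, 0.1, 0.1
taps = []
for i in (0.0, 1.0):          # base y positions of the 2x2 kernel
    for j in (0.0, 1.0):      # base x positions
        y, xpos = i + off, j + off
        # Bilinear weight drops linearly once a coordinate passes the last
        # in-bounds pixel (index 1); this toy only overflows on that side.
        wy = 1.0 - (y - 1.0) if y > 1.0 else 1.0
        wx = 1.0 - (xpos - 1.0) if xpos > 1.0 else 1.0
        taps.append(x_val * wy * wx * mask)
# 8 input channels, conv weight 0.1 on every tap.
print(8 * sum(taps) * w_val)  # ~0.02888, matching expected_out up to float32 rounding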