forked from mindspore-Ecosystem/mindspore
!16346 support AvgPool3D operator for Ascend.
From: @liu_xiao_93 Reviewed-by: Signed-off-by:
Commit: 99fbbe0f89
@@ -43,7 +43,8 @@ void DynamicShapeKernel::Execute() {
   auto output_addr = AnfAlgo::GetOutputAddr(cnode, 0);
   MS_EXCEPTION_IF_NULL(output_addr);
   output_addr->SyncHostToDevice(output_shape, LongToSize(output_tensor_for_sync->data().nbytes()),
-                                output_tensor_for_sync->data_type(), output_tensor_for_sync->data_c());
+                                output_tensor_for_sync->data_type(), output_tensor_for_sync->data_c(),
+                                output_tensor_for_sync->device_info().host_format_);
   MS_LOG(INFO) << "Execute DynamicShapeKernel End";
 }

@@ -57,6 +57,8 @@
 #include "backend/optimizer/ascend/ir_fission/lin_space_fission.h"
 #include "backend/optimizer/ascend/ir_fission/space_to_depth_split.h"
 #include "backend/optimizer/ascend/ir_fission/max_pool3d_grad_grad_fission.h"
+#include "backend/optimizer/ascend/ir_fusion/avgpool_3d_fusion.h"
+#include "backend/optimizer/ascend/ir_fusion/avgpool_3d_grad_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h"
@@ -176,6 +178,8 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
   ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
   ir_fusion_pm->AddPass(std::make_shared<MaxPool3DGradGradFission>());
+  ir_fusion_pm->AddPass(std::make_shared<AvgPool3DFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<AvgPool3DGradFusion>());
   ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>());
   ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>());
   ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
@@ -318,6 +322,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
   ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
   ir_fusion_pm->AddPass(std::make_shared<SpaceToDepthSplit>());
   ir_fusion_pm->AddPass(std::make_shared<MaxPool3DGradGradFission>());
+  ir_fusion_pm->AddPass(std::make_shared<AvgPool3DFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<AvgPool3DGradFusion>());
   ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
   ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());
@@ -0,0 +1,288 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/ascend/ir_fusion/avgpool_3d_fusion.h"
#include <vector>
#include <memory>
#include <string>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "base/core_ops.h"
#include "utils/utils.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t k5DInferDims = 5;
constexpr size_t kC0 = 16;

int64_t GetInterSection(int64_t start_1, int64_t end_1, int64_t start_2, int64_t end_2) {
  if (end_1 <= start_2) {
    return 0;
  }
  if (start_1 >= end_2) {
    return 0;
  }
  if (start_1 < start_2) {
    start_1 = start_2;
  }
  if (end_1 > end_2) {
    end_1 = end_2;
  }
  return end_1 - start_1;
}
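
GetInterSection treats its arguments as half-open intervals [start, end) and returns the length of their overlap, clamping to zero when the intervals only touch. A minimal standalone check (a sketch by the editor, not part of the commit):

#include <algorithm>
#include <cassert>
#include <cstdint>

// Same contract as the helper above: overlap length of [s1, e1) and [s2, e2).
int64_t Overlap(int64_t s1, int64_t e1, int64_t s2, int64_t e2) {
  if (e1 <= s2 || s1 >= e2) return 0;
  return std::min(e1, e2) - std::max(s1, s2);
}

int main() {
  assert(Overlap(0, 3, 1, 5) == 2);   // overlap is [1, 3)
  assert(Overlap(0, 3, 3, 5) == 0);   // touching intervals do not overlap
  assert(Overlap(-2, 9, 0, 4) == 4);  // containment clamps to [0, 4)
  return 0;
}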

bool GetKernelSize(const AnfNodePtr &node, int64_t *kd, int64_t *kh, int64_t *kw) {
  MS_EXCEPTION_IF_NULL(node);
  if (AnfAlgo::HasNodeAttr("kernel_size", node->cast<CNodePtr>())) {
    auto kernel_size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "kernel_size");
    if (kernel_size.size() == 1) {
      *kd = kernel_size[0];
      *kh = kernel_size[0];
      *kw = kernel_size[0];
    } else if (kernel_size.size() == 3) {
      *kd = kernel_size[0];
      *kh = kernel_size[1];
      *kw = kernel_size[2];
    } else if (kernel_size.size() == 5) {
      // NCDHW
      *kd = kernel_size[2];
      *kh = kernel_size[3];
      *kw = kernel_size[4];
    } else {
      MS_LOG(EXCEPTION) << "Unknown kernel size " << kernel_size.size();
    }
    return true;
  }
  return false;
}

bool GetStrideSize(const AnfNodePtr &node, int64_t *sd, int64_t *sh, int64_t *sw) {
  MS_EXCEPTION_IF_NULL(node);
  if (AnfAlgo::HasNodeAttr("strides", node->cast<CNodePtr>())) {
    auto strides = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "strides");
    if (strides.size() == 1) {
      *sd = strides[0];
      *sh = strides[0];
      *sw = strides[0];
    } else if (strides.size() == 3) {
      *sd = strides[0];
      *sh = strides[1];
      *sw = strides[2];
    } else if (strides.size() == 5) {
      // NCDHW
      *sd = strides[2];
      *sh = strides[3];
      *sw = strides[4];
    } else {
      MS_LOG(EXCEPTION) << "Unknown strides size " << strides.size();
    }
    return true;
  }
  return false;
}

void GetAttrs(const AnfNodePtr &node, std::vector<int64_t> *pad_list, bool *count_include_pad, bool *ceil_mode,
              int64_t *divisor_override) {
  MS_EXCEPTION_IF_NULL(node);
  if (!AnfAlgo::HasNodeAttr("pad_list", node->cast<CNodePtr>())) {
    MS_LOG(EXCEPTION) << "AvgPool3D should have attr pad_list";
  }
  *pad_list = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "pad_list");
  if (AnfAlgo::HasNodeAttr("count_include_pad", node->cast<CNodePtr>())) {
    *count_include_pad = AnfAlgo::GetNodeAttr<bool>(node, "count_include_pad");
  }
  if (AnfAlgo::HasNodeAttr("ceil_mode", node->cast<CNodePtr>())) {
    *ceil_mode = AnfAlgo::GetNodeAttr<bool>(node, "ceil_mode");
  }
  if (AnfAlgo::HasNodeAttr("divisor_override", node->cast<CNodePtr>())) {
    *divisor_override = AnfAlgo::GetNodeAttr<int64_t>(node, "divisor_override");
  }
}

bool IsVectorImpl(int64_t fh, int64_t fw, int64_t kh, int64_t kw, const std::vector<int64_t> &pad_list) {
  if (std::any_of(pad_list.begin(), pad_list.end(), [](int64_t item) { return item != 0; })) {
    return false;
  }
  if (fh != kh || fw != kw) {
    return false;
  }
  return true;
}

bool IsZeroPads(const std::vector<int64_t> &pad_list) {
  return std::all_of(pad_list.begin(), pad_list.end(), [](int64_t item) { return item == 0; });
}

AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int64_t> &pad_list, int64_t fc, int64_t kd,
                           int64_t kh, int64_t kw, bool ceil_mode, int64_t divisor_override) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // assist tensor 1: a depthwise averaging kernel in FRACTAL_Z_3D layout
  int64_t c1 = (fc + kC0 - 1) / kC0;
  std::vector<int64_t> assist_shape = {c1 * kd * kh * kw, 1, kC0, kC0};  // frac_z_3d
  auto infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)};
  float val = 1.0 / (kd * kh * kw);
  if (divisor_override) {
    val = 1.0 / divisor_override;
  } else if (!IsZeroPads(pad_list) || ceil_mode) {
    val = 1.0;
  }
  // create value node
  tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape);
  TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
  tensor::DeviceInfo device_info{kOpFormat_FRACTAL_Z_3D, tensor_type, kOpFormat_FRACTAL_Z_3D};
  assist_tensor->set_device_info(device_info);
  // fill each kC0 x kC0 block with an identity matrix scaled by val
  auto tensor_data = reinterpret_cast<float16 *>(assist_tensor->data_c());
  int64_t cnt = c1 * kd * kh * kw;
  for (int64_t i = 0; i < cnt; ++i) {
    for (size_t j = 0; j < kC0; ++j) {
      for (size_t k = 0; k < kC0; ++k) {
        float t = j == k ? val : 0;
        *tensor_data = float16(t);
        ++tensor_data;
      }
    }
  }

  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  auto value_node = kernel_graph->NewValueNode(x_abstract, assist_tensor);
  kernel_graph->AddValueNodeToGraph(value_node);
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {infer_shape}, value_node.get());
  return value_node;
}
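
The filter built above is effectively a depthwise kernel: C1 * kd * kh * kw blocks, each a 16 x 16 identity scaled by `val`. A hedged sketch of how `val` is chosen, with illustrative numbers that are not from the commit:

#include <cstdint>
#include <cstdio>

// Mirrors the branch order in ConstructFilter above.
float FilterScale(int64_t kd, int64_t kh, int64_t kw, int64_t divisor_override, bool zero_pads, bool ceil_mode) {
  if (divisor_override != 0) {
    return 1.0f / divisor_override;  // caller-supplied divisor wins
  }
  if (!zero_pads || ceil_mode) {
    return 1.0f;  // filter only sums; the multiplier tensor divides per position
  }
  return 1.0f / (kd * kh * kw);  // plain averaging folded into the filter
}

int main() {
  std::printf("%g\n", FilterScale(2, 2, 2, 0, true, false));   // 0.125: 1/(2*2*2)
  std::printf("%g\n", FilterScale(2, 2, 2, 0, false, false));  // 1: padded case defers to the multiplier
  std::printf("%g\n", FilterScale(2, 2, 2, 4, false, false));  // 0.25: divisor_override = 4
  return 0;
}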

AnfNodePtr ConstructMultiplier(const FuncGraphPtr &func_graph, int64_t fn, int64_t fc, int64_t fd, int64_t fh,
                               int64_t fw, int64_t dd, int64_t dh, int64_t dw, int64_t kd, int64_t kh, int64_t kw,
                               int64_t sd, int64_t sh, int64_t sw, const std::vector<int64_t> &pad_list, bool ceil_mode,
                               bool count_include_pad) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // assist tensor 2: one averaging coefficient per output position
  std::vector<int64_t> assist_shape = {fn, fc, dd, dh, dw};  // NCDHW
  auto infer_shape = {LongToSize(fn), LongToSize(fc), LongToSize(dd), LongToSize(dh), LongToSize(dw)};
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape);
  auto tensor_data = reinterpret_cast<float16 *>(tensor->data_c());
  auto pad_d = pad_list[0] + pad_list[1];
  auto pad_h = pad_list[2] + pad_list[3];
  auto pad_w = pad_list[4] + pad_list[5];
  auto len_d = fd + pad_d;
  auto len_h = fh + pad_h;
  auto len_w = fw + pad_w;
  for (int64_t nn = 0; nn < fn; nn++) {
    for (int64_t cc = 0; cc < fc; cc++) {
      int64_t start_d = 0;
      for (int64_t di = 0; di < dd; di++) {
        auto v_kd = start_d + kd <= len_d ? kd : len_d - start_d;
        int64_t start_h = 0;
        for (int64_t hi = 0; hi < dh; hi++) {
          auto v_kh = start_h + kh <= len_h ? kh : len_h - start_h;
          int64_t start_w = 0;
          for (int64_t wi = 0; wi < dw; wi++) {
            auto v_kw = start_w + kw <= len_w ? kw : len_w - start_w;
            auto valid_d = GetInterSection(start_d, start_d + kd, pad_list[0], pad_list[0] + fd);
            auto valid_h = GetInterSection(start_h, start_h + kh, pad_list[2], pad_list[2] + fh);
            auto valid_w = GetInterSection(start_w, start_w + kw, pad_list[4], pad_list[4] + fw);
            auto valid_data = valid_d * valid_h * valid_w;
            auto valid_kernel = v_kd * v_kh * v_kw;
            float val = count_include_pad ? 1.0 / valid_kernel : 1.0 / valid_data;
            *tensor_data = float16(val);
            ++tensor_data;
            start_w += sw;
          }
          start_h += sh;
        }
        start_d += sd;
      }
    }
  }
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  auto value_node = kernel_graph->NewValueNode(x_abstract, tensor);
  kernel_graph->AddValueNodeToGraph(value_node);
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {infer_shape}, value_node.get());
  return value_node;
}
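
The multiplier carries one coefficient per output position; near a padded border it is the reciprocal of either the in-bounds window length (count_include_pad) or the overlap with real data. A worked depth-axis example with made-up sizes fd = 4, kd = 2, pads = {1, 1}:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t fd = 4, kd = 2, pad_front = 1, pad_back = 1;
  const int64_t len_d = fd + pad_front + pad_back;  // 6: padded axis is [0, 6), real data is [1, 5)
  int64_t start_d = 0;                              // first window
  int64_t v_kd = start_d + kd <= len_d ? kd : len_d - start_d;  // 2: window fully inside the padded range
  int64_t valid_d =
    std::min(start_d + kd, pad_front + fd) - std::max(pad_front, start_d);  // 1: overlap with real data
  std::printf("include pad: %g\n", 1.0 / v_kd);     // 0.5
  std::printf("exclude pad: %g\n", 1.0 / valid_d);  // 1
  return 0;
}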
}  // namespace

const BaseRef AvgPool3DFusion::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimAvgPool3D, Xs});
}

const AnfNodePtr AvgPool3DFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto avg_pool_3d_node = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(avg_pool_3d_node);
  auto dims_in = AnfAlgo::GetPrevNodeOutputInferShape(avg_pool_3d_node, 0);
  auto dims_out = AnfAlgo::GetOutputInferShape(avg_pool_3d_node, 0);
  if (dims_in.size() < k5DInferDims || dims_out.size() < k5DInferDims) {
    MS_LOG(EXCEPTION) << "AvgPool3D's input and output infer shape dims cannot be less than " << k5DInferDims;
  }
  auto fn = SizeToLong(dims_in[0]);
  auto fc = SizeToLong(dims_in[1]);
  auto fd = SizeToLong(dims_in[2]);
  auto fh = SizeToLong(dims_in[3]);
  auto fw = SizeToLong(dims_in[4]);
  auto dout = SizeToLong(dims_out[2]);
  auto dh = SizeToLong(dims_out[3]);
  auto dw = SizeToLong(dims_out[4]);
  // kernel size
  int64_t kd;
  int64_t kh;
  int64_t kw;
  if (!GetKernelSize(avg_pool_3d_node, &kd, &kh, &kw)) {
    MS_LOG(EXCEPTION) << "Get kernel size failed";
  }
  // strides
  int64_t sd;
  int64_t sh;
  int64_t sw;
  if (!GetStrideSize(avg_pool_3d_node, &sd, &sh, &sw)) {
    MS_LOG(EXCEPTION) << "Get stride size failed";
  }
  std::vector<int64_t> pad_list;
  bool count_include_pad = false;
  bool ceil_mode = false;
  int64_t divisor_override = 0;
  GetAttrs(avg_pool_3d_node, &pad_list, &count_include_pad, &ceil_mode, &divisor_override);
  if (IsVectorImpl(fh, fw, kh, kw, pad_list)) {
    MS_LOG(INFO) << "No need fusion";
    return nullptr;
  }
  std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAvgPool3D->name()))};
  (void)new_inputs.insert(new_inputs.end(), avg_pool_3d_node->inputs().begin() + 1, avg_pool_3d_node->inputs().end());
  // assist node 1: the averaging filter
  auto filter_node = ConstructFilter(func_graph, pad_list, fc, kd, kh, kw, ceil_mode, divisor_override);
  MS_EXCEPTION_IF_NULL(filter_node);
  new_inputs.push_back(filter_node);
  // assist node 2: only needed when the filter alone cannot carry the divisor
  if ((!IsZeroPads(pad_list) || ceil_mode) && !divisor_override) {
    auto multiplier = ConstructMultiplier(func_graph, fn, fc, fd, fh, fw, dout, dh, dw, kd, kh, kw, sd, sh, sw,
                                          pad_list, ceil_mode, count_include_pad);
    new_inputs.push_back(multiplier);
  }
  auto new_3d = func_graph->NewCNode(new_inputs);
  MS_EXCEPTION_IF_NULL(new_3d);
  new_3d->set_scope(avg_pool_3d_node->scope());
  new_3d->set_abstract(avg_pool_3d_node->abstract());
  AnfAlgo::CopyNodeAttrs(avg_pool_3d_node, new_3d);
  return new_3d;
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,37 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_AVGPOOL_3D_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_AVGPOOL_3D_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"

namespace mindspore {
namespace opt {
class AvgPool3DFusion : public PatternProcessPass {
 public:
  explicit AvgPool3DFusion(bool multigraph = true) : PatternProcessPass("avg_pool_3d_fusion", multigraph) {}
  ~AvgPool3DFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_AVGPOOL_3D_FUSION_H_
@@ -0,0 +1,254 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/ascend/ir_fusion/avgpool_3d_grad_fusion.h"
#include <vector>
#include <memory>
#include <string>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "base/core_ops.h"
#include "utils/utils.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t k5DInferDims = 5;
constexpr size_t kKernelDims = 3;
constexpr size_t kStridesDims = 3;
constexpr size_t kOrigShapeDims = 5;
constexpr size_t kShapeDims = 6;
constexpr size_t kPadDims = 6;
constexpr size_t kC0 = 16;

void GetAttrs(const AnfNodePtr &node, std::vector<int64_t> *kernel_size, std::vector<int64_t> *strides,
              std::vector<int64_t> *pad_list, std::vector<int64_t> *origin_input_shape, bool *ceil_mode,
              bool *count_include_pad, int64_t *divisor_override, std::string *format_str) {
  MS_EXCEPTION_IF_NULL(node);
  // attr kernel size
  if (!AnfAlgo::HasNodeAttr("kernel_size", node->cast<CNodePtr>())) {
    MS_LOG(EXCEPTION) << "AvgPool3DGrad should have attr kernel_size";
  }
  *kernel_size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "kernel_size");
  // attr strides
  if (!AnfAlgo::HasNodeAttr("strides", node->cast<CNodePtr>())) {
    MS_LOG(EXCEPTION) << "AvgPool3DGrad should have attr strides";
  }
  *strides = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "strides");
  // attr pad_list
  if (!AnfAlgo::HasNodeAttr("pad_list", node->cast<CNodePtr>())) {
    MS_LOG(EXCEPTION) << "AvgPool3DGrad should have attr pad_list";
  }
  *pad_list = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "pad_list");
  // attr origin input shape
  if (!AnfAlgo::HasNodeAttr("origin_input_shape", node->cast<CNodePtr>())) {
    MS_LOG(EXCEPTION) << "AvgPool3DGrad should have attr origin_input_shape";
  }
  *origin_input_shape = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "origin_input_shape");
  // attr count include pad
  if (AnfAlgo::HasNodeAttr("count_include_pad", node->cast<CNodePtr>())) {
    *count_include_pad = AnfAlgo::GetNodeAttr<bool>(node, "count_include_pad");
  }
  // attr ceil mode
  if (AnfAlgo::HasNodeAttr("ceil_mode", node->cast<CNodePtr>())) {
    *ceil_mode = AnfAlgo::GetNodeAttr<bool>(node, "ceil_mode");
  }
  // attr divisor override
  if (AnfAlgo::HasNodeAttr("divisor_override", node->cast<CNodePtr>())) {
    *divisor_override = AnfAlgo::GetNodeAttr<int64_t>(node, "divisor_override");
  }
  if (AnfAlgo::HasNodeAttr("format", node->cast<CNodePtr>())) {
    *format_str = AnfAlgo::GetNodeAttr<std::string>(node, "format");
  }
}

bool IsVectorImpl(const std::vector<int64_t> &fp_shape, const std::vector<int64_t> &k_size,
                  const std::vector<int64_t> &pad_list) {
  // NCDHW
  auto fd = fp_shape[2];
  auto fh = fp_shape[3];
  auto fw = fp_shape[4];
  auto kd = k_size[0];
  auto kh = k_size[1];
  auto kw = k_size[2];
  bool flag1 = kd >= fd + pad_list[0] + pad_list[1];
  bool flag2 = kh >= fh + pad_list[2] + pad_list[3];
  bool flag3 = kw >= fw + pad_list[4] + pad_list[5];
  return flag1 && flag2 && flag3;
}

bool IsZeroPads(const std::vector<int64_t> &pad_list) {
  return std::all_of(pad_list.begin(), pad_list.end(), [](int64_t item) { return item == 0; });
}

AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int64_t> &pad_list, int64_t fc, int64_t kd,
                           int64_t kh, int64_t kw, int64_t divisor_override, bool ceil_mode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // assist tensor 1
  int64_t c1 = (fc + kC0 - 1) / kC0;
  std::vector<int64_t> assist_shape = {c1 * kd * kh * kw, 1, kC0, kC0};  // frac_z_3d
  auto infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)};
  float val = 1.0;
  if (divisor_override) {
    val = 1.0 / divisor_override;
  } else if (IsZeroPads(pad_list) && !ceil_mode) {
    val = 1.0 / (kd * kh * kw);
  }
  // create value node
  tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape);
  TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
  tensor::DeviceInfo device_info{kOpFormat_FRACTAL_Z_3D, tensor_type, kOpFormat_FRACTAL_Z_3D};
  assist_tensor->set_device_info(device_info);
  auto tensor_data = reinterpret_cast<float16 *>(assist_tensor->data_c());
  int64_t cnt = c1 * kd * kh * kw;
  for (int64_t i = 0; i < cnt; ++i) {
    for (size_t j = 0; j < kC0; ++j) {
      for (size_t k = 0; k < kC0; ++k) {
        float t = j == k ? val : 0;
        *tensor_data = float16(t);
        ++tensor_data;
      }
    }
  }

  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  auto value_node = kernel_graph->NewValueNode(x_abstract, assist_tensor);
  kernel_graph->AddValueNodeToGraph(value_node);
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {infer_shape}, value_node.get());
  return value_node;
}

AnfNodePtr ConstructMultiplier(const FuncGraphPtr &func_graph, const std::vector<size_t> &ori_shape,
                               const std::vector<int64_t> &ori_input_shape, const std::vector<int64_t> &kernel_size,
                               const std::vector<int64_t> &strides, const std::vector<int64_t> &pad_list,
                               bool ceil_mode, bool count_include_pad) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // assist tensor 2
  std::vector<int64_t> grad_shape;
  (void)std::transform(ori_shape.begin(), ori_shape.end(), std::back_inserter(grad_shape), SizeToLong);
  std::vector<int64_t> assist_shape = grad_shape;  // NCDHW
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, assist_shape);
  auto tensor_data = reinterpret_cast<float16 *>(tensor->data_c());
  auto pad_d = pad_list[0] + pad_list[1];
  auto pad_h = pad_list[2] + pad_list[3];
  auto pad_w = pad_list[4] + pad_list[5];
  auto len_d = ori_input_shape[2] + pad_d;
  auto len_h = ori_input_shape[3] + pad_h;
  auto len_w = ori_input_shape[4] + pad_w;
  for (int64_t nn = 0; nn < grad_shape[0]; nn++) {
    for (int64_t cc = 0; cc < grad_shape[1]; cc++) {
      int64_t start_d = 0;
      for (int64_t di = 0; di < grad_shape[2]; di++) {
        int64_t start_h = 0;
        for (int64_t hi = 0; hi < grad_shape[3]; hi++) {
          int64_t start_w = 0;
          for (int64_t wi = 0; wi < grad_shape[4]; wi++) {
            int64_t valid_d = 0;
            int64_t valid_h = 0;
            int64_t valid_w = 0;
            if (count_include_pad) {
              valid_d = start_d + kernel_size[0] <= len_d ? kernel_size[0] : len_d - start_d;
              valid_h = start_h + kernel_size[1] <= len_h ? kernel_size[1] : len_h - start_h;
              valid_w = start_w + kernel_size[2] <= len_w ? kernel_size[2] : len_w - start_w;
            } else {
              valid_d =
                std::min(start_d + kernel_size[0], pad_list[0] + ori_input_shape[2]) - std::max(pad_list[0], start_d);
              valid_h =
                std::min(start_h + kernel_size[1], pad_list[2] + ori_input_shape[3]) - std::max(pad_list[2], start_h);
              valid_w =
                std::min(start_w + kernel_size[2], pad_list[4] + ori_input_shape[4]) - std::max(pad_list[4], start_w);
            }
            auto valid_data = valid_d * valid_h * valid_w;
            float val = 1.0 / valid_data;
            *tensor_data = float16(val);
            ++tensor_data;
            start_w += strides[2];
          }
          start_h += strides[1];
        }
        start_d += strides[0];
      }
    }
  }
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  auto value_node = kernel_graph->NewValueNode(x_abstract, tensor);
  kernel_graph->AddValueNodeToGraph(value_node);
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {ori_shape}, value_node.get());
  return value_node;
}
}  // namespace

const BaseRef AvgPool3DGradFusion::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimAvgPool3DGrad, Xs});
}

const AnfNodePtr AvgPool3DGradFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto avg_pool_3d_grad_node = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(avg_pool_3d_grad_node);
  std::vector<int64_t> kernel_size;
  std::vector<int64_t> strides;
  std::vector<int64_t> pad_list;
  std::vector<int64_t> origin_input_shape;
  bool ceil_mode = false;
  bool count_include_pad = true;
  int64_t divisor_override = 0;
  std::string format_str;
  GetAttrs(avg_pool_3d_grad_node, &kernel_size, &strides, &pad_list, &origin_input_shape, &ceil_mode,
           &count_include_pad, &divisor_override, &format_str);
  if (IsVectorImpl(origin_input_shape, kernel_size, pad_list)) {
    MS_LOG(INFO) << "No need fusion";
    return nullptr;
  }
  std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAvgPool3DGrad->name()))};
  (void)new_inputs.insert(new_inputs.end(), avg_pool_3d_grad_node->inputs().begin() + 1,
                          avg_pool_3d_grad_node->inputs().end());
  // assist node 1
  auto kd = kernel_size[0];
  auto kh = kernel_size[1];
  auto kw = kernel_size[2];
  auto fc = origin_input_shape[1];
  auto filter_node = ConstructFilter(func_graph, pad_list, fc, kd, kh, kw, divisor_override, ceil_mode);
  MS_EXCEPTION_IF_NULL(filter_node);
  new_inputs.push_back(filter_node);

  // After const-input-to-attr conversion, the first input is the gradient, at index 0.
  auto dims_in = AnfAlgo::GetPrevNodeOutputInferShape(avg_pool_3d_grad_node, 0);

  // assist node 2
  if (divisor_override == 0 && (!IsZeroPads(pad_list) || ceil_mode)) {
    auto multiplier = ConstructMultiplier(func_graph, dims_in, origin_input_shape, kernel_size, strides, pad_list,
                                          ceil_mode, count_include_pad);
    new_inputs.push_back(multiplier);
  }
  auto new_3d_grad = func_graph->NewCNode(new_inputs);
  MS_EXCEPTION_IF_NULL(new_3d_grad);
  new_3d_grad->set_scope(avg_pool_3d_grad_node->scope());
  new_3d_grad->set_abstract(avg_pool_3d_grad_node->abstract());
  AnfAlgo::CopyNodeAttrs(avg_pool_3d_grad_node, new_3d_grad);
  return new_3d_grad;
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,37 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_AVGPOOL_3D_GRAD_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_AVGPOOL_3D_GRAD_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"

namespace mindspore {
namespace opt {
class AvgPool3DGradFusion : public PatternProcessPass {
 public:
  explicit AvgPool3DGradFusion(bool multigraph = true) : PatternProcessPass("avg_pool_3d_grad_fusion", multigraph) {}
  ~AvgPool3DGradFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_AVGPOOL_3D_GRAD_FUSION_H_
@@ -26,6 +26,7 @@ namespace opt {
 ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(prim::kPrimCast->name(), {1});
   Register(prim::kPrimAvgPoolGradVm->name(), {0});
+  Register(prim::kPrimAvgPool3DGrad->name(), {0});
   Register(prim::kPrimConv2DBackpropInput->name(), {2});
   Register(prim::kPrimConv2DBackpropFilter->name(), {2});
   Register(prim::kPrimConv3DBackpropInput->name(), {2});
@@ -47,8 +47,8 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k
     if (!AnfAlgo::IsParameterWeight(pk_node)) {
       tensor = inputs[no_weight_input++];
       if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
-                                            LongToSize(tensor->data().nbytes()), tensor->data_type(),
-                                            tensor->data_c())) {
+                                            LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c(),
+                                            tensor->device_info().host_format_)) {
         MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
       }
     }
@@ -76,8 +76,8 @@ GraphId AscendInferenceSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_grap
       auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
       MS_EXCEPTION_IF_NULL(tensor);
       if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
-                                            LongToSize(tensor->data().nbytes()), tensor->data_type(),
-                                            tensor->data_c())) {
+                                            LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c(),
+                                            tensor->device_info().host_format_)) {
         MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
       }
     }
@@ -429,8 +429,9 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
 #endif
       auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
       MS_EXCEPTION_IF_NULL(device_address);
-      if (size != 0 && !device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), size,
-                                                         tensor->data_type(), tensor->data_c())) {
+      if (size != 0 &&
+          !device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), size, tensor->data_type(),
+                                            tensor->data_c(), tensor->device_info().host_format_)) {
         MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
       }
       if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
@@ -1400,7 +1401,8 @@ void AscendSession::SyncInitialTenosrToDevice() {
     auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
     MS_EXCEPTION_IF_NULL(addr);
     if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
-                                front_tensor->data_type(), front_tensor->data_c())) {
+                                front_tensor->data_type(), front_tensor->data_c(),
+                                front_tensor->device_info().host_format_)) {
       MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
     }
   }
@@ -551,9 +551,10 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
 }

 bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size, mindspore::TypeId type,
-                                           const void *host_ptr) const {
+                                           const void *host_ptr, const std::string &format) const {
   MS_LOG(INFO) << "SyncHostToDevice, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
-               << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
+               << ", size:" << size_ << "), Host(format:" << format << ", type_id:" << TypeIdLabel(type)
+               << ", size:" << size << ")";
   if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) {
     return true;
   }
@@ -564,7 +565,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size
   if (host_shape.empty()) {
     host_shape.emplace_back(1);
   }
-  if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NCDHW) {
+  if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NCDHW || format_ == format) {
     if (type_id_ == type) {
       SyncMemory(ptr_, host_ptr, size, RT_MEMCPY_HOST_TO_DEVICE);
       sync_ok = true;
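
The extra `format_ == format` disjunct is what lets the FRACTAL_Z_3D assist tensors above skip format transformation: their DeviceInfo declares the host buffer to already be in the device layout, so the sync degenerates to a plain copy. A sketch of the resulting decision (an assumed reading of the diff, not MindSpore code):

#include <cstdio>
#include <string>

// True when the host buffer can be copied to the device verbatim.
bool CopyWithoutTransform(const std::string &device_format, const std::string &host_format) {
  return device_format == "NCHW" || device_format == "DefaultFormat" || device_format == "NCDHW" ||
         device_format == host_format;
}

int main() {
  // Before this change the FRACTAL_Z_3D / FRACTAL_Z_3D pair would not take the fast path.
  std::printf("%d\n", CopyWithoutTransform("FRACTAL_Z_3D", "FRACTAL_Z_3D"));  // 1
  std::printf("%d\n", CopyWithoutTransform("FRACTAL_Z_3D", "DefaultFormat"));  // 0
  return 0;
}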
@@ -42,7 +42,8 @@ class AscendDeviceAddress : public DeviceAddress {
   bool SyncDeviceToHost(size_t size, void *host_ptr) const override;
   bool SyncHostToDevice(size_t size, const void *host_ptr) const override;
   bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
-  bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const override;
+  bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
+                        const std::string &format = "DefaultFormat") const override;
   void ClearDeviceMemory() override;
   DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; }
   bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape,
@@ -74,8 +74,8 @@ bool CPUDeviceAddress::SyncDeviceToHost(const ShapeVector & /*shape*/, size_t si
   return true;
 }

-bool CPUDeviceAddress::SyncHostToDevice(const ShapeVector & /*shape*/, size_t size, TypeId type,
-                                        const void *host_ptr) const {
+bool CPUDeviceAddress::SyncHostToDevice(const ShapeVector & /*shape*/, size_t size, TypeId type, const void *host_ptr,
+                                        const std::string &format) const {
   if (ptr_ == nullptr) {
     MS_LOG(ERROR) << "The pointer ptr_ is null!";
     return false;
@@ -34,7 +34,8 @@ class CPUDeviceAddress : public DeviceAddress {
   ~CPUDeviceAddress() override = default;

   bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
-  bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const override;
+  bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
+                        const std::string &format = "DefaultFormat") const override;
   bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape,
                      TypeId host_type, bool trans_flag) const override;
   void ClearDeviceMemory() override {}
@@ -77,7 +77,8 @@ bool GPUDeviceAddress::SyncDeviceToHost(const ShapeVector &, size_t size, TypeId
   return SyncDeviceToHost(size, host_ptr);
 }

-bool GPUDeviceAddress::SyncHostToDevice(const ShapeVector &, size_t size, TypeId, const void *host_ptr) const {
+bool GPUDeviceAddress::SyncHostToDevice(const ShapeVector &, size_t size, TypeId, const void *host_ptr,
+                                        const std::string &format) const {
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   bool execution_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);
@@ -39,7 +39,8 @@ class GPUDeviceAddress : public DeviceAddress {
   bool SyncDeviceToHost(size_t size, void *host_ptr) const override;
   bool SyncHostToDevice(size_t size, const void *host_ptr) const override;
   bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
-  bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const override;
+  bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
+                        const std::string &format = "DefaultFormat") const override;

   void ClearDeviceMemory() override;
   void set_status(DeviceAddressStatus status) { status_ = status; }
@@ -734,8 +734,8 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph
     MS_EXCEPTION_IF_NULL(device_address);
     tensor->set_device_address(device_address);
     if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
-                                          LongToSize(tensor->data().nbytes()), tensor->data_type(),
-                                          tensor->data_c())) {
+                                          LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c(),
+                                          tensor->device_info().host_format_)) {
       MS_LOG(INFO) << "SyncHostToDevice failed.";
       return false;
     }
@@ -702,7 +702,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
   }
   AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
   if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
-                                 tensor->data_c())) {
+                                 tensor->data_c(), tensor->device_info().host_format_)) {
     MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString()
                                  << "node format is" << AnfAlgo::GetOutputFormat(value_node, output_idx)
                                  << "node dtype is " << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
@@ -200,7 +200,7 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cont
     MS_EXCEPTION_IF_NULL(device_tensor);
     if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(data_nodes_[i], 0),
                                          LongToSize(host_tensor->data().nbytes()), host_tensor->data_type(),
-                                         host_tensor->data_c())) {
+                                         host_tensor->data_c(), host_tensor->device_info().host_format_)) {
       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "SyncHostToDevice failed.");
     }
   }
@@ -103,7 +103,7 @@ void PrepareDataForValueNodeTensor(const ValueNodePtr &node, const ValuePtr &nod

   // Copy data from host tensor to device.
   if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), LongToSize(tensor->data().nbytes()),
-                                       tensor->data_type(), tensor->data_c())) {
+                                       tensor->data_type(), tensor->data_c(), tensor->device_info().host_format_)) {
     MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << node->fullname_with_scope();
   }
 }
@@ -179,8 +179,8 @@ void PrepareDataForWeightNode(const AnfNodePtr &node, const TensorPtr &tensor, c
       MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, node name: " << node->fullname_with_scope();
     }
     if (!host_tensor_address->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0),
-                                               LongToSize(tensor->data().nbytes()), tensor->data_type(),
-                                               tensor->data_c())) {
+                                               LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c(),
+                                               tensor->device_info().host_format_)) {
       MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << node->fullname_with_scope();
     }

@@ -235,6 +235,7 @@ inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive
 inline const PrimitivePtr kPrimAvgPool = std::make_shared<Primitive>("AvgPool");
 inline const PrimitivePtr kPrimAvgPool3D = std::make_shared<Primitive>("AvgPool3D");
 inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
+inline const PrimitivePtr kPrimAvgPool3DGrad = std::make_shared<Primitive>("AvgPool3DGrad");
 inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm");
 inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("FusedSparseAdam");
 inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
@@ -37,7 +37,8 @@ class DeviceSync {

   // Used to sync data between host tensor and device address; additionally needs the data shape and data type.
   virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const = 0;
-  virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const = 0;
+  virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
+                                const std::string &format = "DefaultFormat") const = 0;

   virtual void *GetMutablePtr() const = 0;
   virtual void ClearDeviceMemory() = 0;
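
Note that every override in this commit repeats `= "DefaultFormat"`: in C++, default arguments are bound by the caller's static type rather than dispatched virtually, so omitting or changing the default in a subclass would silently change behavior depending on whether the call goes through a `DeviceSync` pointer. A minimal illustration with hypothetical classes (not MindSpore code):

#include <iostream>
#include <string>

struct Base {
  virtual ~Base() = default;
  virtual void Sync(const std::string &format = "DefaultFormat") const { std::cout << "Base: " << format << "\n"; }
};

struct Derived : Base {
  // Deliberately different default to show the pitfall.
  void Sync(const std::string &format = "FRACTAL_Z_3D") const override { std::cout << "Derived: " << format << "\n"; }
};

int main() {
  Derived d;
  Base *b = &d;
  b->Sync();  // "Derived: DefaultFormat" -- Base's default, Derived's body
  d.Sync();   // "Derived: FRACTAL_Z_3D"  -- Derived's default
  return 0;
}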
@@ -84,8 +84,8 @@ TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) {
   return type_ptr;
 }

-void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type) {
-  DeviceInfo info(format, data_type);
+void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type, const std::string &host_format) {
+  DeviceInfo info(format, data_type, host_format);
   set_device_info(info);
 }

@@ -41,12 +41,14 @@ namespace mindspore {
 namespace tensor {
 // brief Device info of Tensor
 //
-// Includes the format and data type of a tensor.
+// Includes the format, data type and host format of a tensor.
 struct DeviceInfo {
-  explicit DeviceInfo(std::string format = "DefaultFormat", TypePtr data_type = nullptr)
-      : format_(std::move(format)), data_type_(std::move(data_type)) {}
+  explicit DeviceInfo(std::string format = "DefaultFormat", TypePtr data_type = nullptr,
+                      std::string host_format = "DefaultFormat")
+      : format_(std::move(format)), data_type_(std::move(data_type)), host_format_(std::move(host_format)) {}
   std::string format_ = "DefaultFormat";
   TypePtr data_type_ = nullptr;
+  std::string host_format_ = "DefaultFormat";
 };

 // brief Metadata of Tensor
|
@ -138,7 +140,8 @@ class MetaTensor : public Value {
|
|||
// Set tensor's device info.
|
||||
void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; }
|
||||
|
||||
void SetDeviceInfo(const std::string &format, const TypePtr &data_type);
|
||||
void SetDeviceInfo(const std::string &format, const TypePtr &data_type,
|
||||
const std::string &host_format = "DefaultFormat");
|
||||
|
||||
// Get the size of a given dimension by its index number.
|
||||
int DimensionSize(size_t index) const;
|
||||
|
|
|
@@ -0,0 +1,181 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/avg_pool_3d.h"
#include <algorithm>
#include <cmath>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
constexpr size_t k5DInputDims = 5;
constexpr size_t kKernelDims = 3;
constexpr size_t kStridesDims = 3;
constexpr size_t kPadDims = 6;

void GetAttrs(const PrimitivePtr &primitive, std::vector<int64_t> *kernel_size, std::vector<int64_t> *strides,
              int64_t *pad_mode, std::vector<int64_t> *pad_list, bool *ceil_mode, bool *count_include_pad,
              int64_t *divisor_override) {
  MS_EXCEPTION_IF_NULL(primitive);
  // attr kernel size
  *kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
  if (kernel_size->size() != kKernelDims) {
    MS_LOG(EXCEPTION) << "kernel_size of AvgPool3D must be 3.";
  }
  // attr strides
  *strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
  if (strides->size() != kStridesDims) {
    MS_LOG(EXCEPTION) << "strides of AvgPool3D must be 3.";
  }
  // attr pad_list
  *pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPadList));
  // attr count include pad
  *count_include_pad = GetValue<bool>(primitive->GetAttr(kCountIncludePad));
  // attr pad_mode
  CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), pad_mode, true);
  // attr ceil mode
  *ceil_mode = GetValue<bool>(primitive->GetAttr(kCeilMode));
  // attr divisor override
  *divisor_override = GetValue<int64_t>(primitive->GetAttr(kDivisorOverride));
}

std::vector<int64_t> GetOutputShape(const std::vector<int64_t> &in_shape, int64_t kernel_d, int64_t kernel_h,
                                    int64_t kernel_w, int64_t stride_d, int64_t stride_h, int64_t stride_w,
                                    const std::vector<int64_t> &pad_list, bool ceil_mode) {
  auto in_d = in_shape[2];
  auto in_h = in_shape[3];
  auto in_w = in_shape[4];
  int64_t out_d = 0;
  int64_t out_h = 0;
  int64_t out_w = 0;
  if (ceil_mode) {
    out_d = std::floor((in_d + pad_list[0] + pad_list[1] - kernel_d + stride_d - 1) / stride_d + 1);
    out_h = std::floor((in_h + pad_list[2] + pad_list[3] - kernel_h + stride_h - 1) / stride_h + 1);
    out_w = std::floor((in_w + pad_list[4] + pad_list[5] - kernel_w + stride_w - 1) / stride_w + 1);
    // drop a window that would start entirely inside the back padding
    if ((out_d - 1) * stride_d >= in_d + pad_list[0]) {
      out_d--;
    }
    if ((out_h - 1) * stride_h >= in_h + pad_list[2]) {
      out_h--;
    }
    if ((out_w - 1) * stride_w >= in_w + pad_list[4]) {
      out_w--;
    }
  } else {
    out_d = std::floor((in_d + pad_list[0] + pad_list[1] - kernel_d) / stride_d + 1);
    out_h = std::floor((in_h + pad_list[2] + pad_list[3] - kernel_h) / stride_h + 1);
    out_w = std::floor((in_w + pad_list[4] + pad_list[5] - kernel_w) / stride_w + 1);
  }
  std::vector<int64_t> output_shape = {in_shape[0], in_shape[1], out_d, out_h, out_w};
  return output_shape;
}
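
GetOutputShape applies the standard pooling arithmetic per axis; ceil mode rounds the window count up and then drops a window that would start beyond the real data plus front pad. A quick standalone check of the depth-axis arithmetic (a sketch by the editor):

#include <cstdint>
#include <cstdio>

// Depth-axis version of the computation above (pads = {front, back}).
int64_t OutDim(int64_t in, int64_t k, int64_t s, int64_t pad_front, int64_t pad_back, bool ceil_mode) {
  int64_t out;
  if (ceil_mode) {
    out = (in + pad_front + pad_back - k + s - 1) / s + 1;
    if ((out - 1) * s >= in + pad_front) {
      out--;  // last window would start past the real data and its front pad
    }
  } else {
    out = (in + pad_front + pad_back - k) / s + 1;
  }
  return out;
}

int main() {
  std::printf("%lld\n", static_cast<long long>(OutDim(7, 3, 2, 0, 0, false)));  // 3
  std::printf("%lld\n", static_cast<long long>(OutDim(6, 3, 2, 0, 0, false)));  // 2
  std::printf("%lld\n", static_cast<long long>(OutDim(6, 3, 2, 0, 0, true)));   // 3: ceil keeps the partial window
  return 0;
}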

void GetPadsByPadding(int64_t in_d, int64_t in_h, int64_t in_w, int64_t kernel_d, int64_t kernel_h, int64_t kernel_w,
                      int64_t stride_d, int64_t stride_h, int64_t stride_w, const int64_t &pad_mode,
                      const std::vector<int64_t> &padding, std::vector<int64_t> *pad_list) {
  if (pad_mode == PadMode::VALID) {
    (void)pad_list->insert(pad_list->begin(), kPadDims, 0);
  } else if (pad_mode == PadMode::SAME) {
    int64_t tail_d = in_d % stride_d;
    int64_t tail_h = in_h % stride_h;
    int64_t tail_w = in_w % stride_w;
    int64_t pad_d = std::max((tail_d > 0 ? kernel_d - tail_d : kernel_d - stride_d), (int64_t)0);
    int64_t pad_h = std::max((tail_h > 0 ? kernel_h - tail_h : kernel_h - stride_h), (int64_t)0);
    int64_t pad_w = std::max((tail_w > 0 ? kernel_w - tail_w : kernel_w - stride_w), (int64_t)0);
    pad_list->push_back(static_cast<int64_t>(std::floor(pad_d / 2)));
    pad_list->push_back(pad_d - pad_list->at(0));
    pad_list->push_back(static_cast<int64_t>(std::floor(pad_h / 2)));
    pad_list->push_back(pad_h - pad_list->at(2));
    pad_list->push_back(static_cast<int64_t>(std::floor(pad_w / 2)));
    pad_list->push_back(pad_w - pad_list->at(4));
  } else if (pad_mode == PadMode::PAD) {
    (void)pad_list->assign(padding.begin(), padding.end());
  }
}
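
Under SAME, each axis pads by max(kernel - tail, 0), where tail = in % stride and a zero tail counts as a full stride; the total is split front/back with the extra pixel at the back. A compact check (a sketch by the editor):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// SAME-pad split for one axis, mirroring GetPadsByPadding above.
void SamePad(int64_t in, int64_t k, int64_t s, int64_t *front, int64_t *back) {
  int64_t tail = in % s;
  int64_t pad = std::max(tail > 0 ? k - tail : k - s, static_cast<int64_t>(0));
  *front = pad / 2;
  *back = pad - *front;
}

int main() {
  int64_t front = 0, back = 0;
  SamePad(5, 3, 2, &front, &back);  // tail = 1, pad = 2
  std::printf("{%lld, %lld}\n", static_cast<long long>(front), static_cast<long long>(back));  // {1, 1}
  SamePad(7, 2, 2, &front, &back);  // tail = 1, pad = 1
  std::printf("{%lld, %lld}\n", static_cast<long long>(front), static_cast<long long>(back));  // {0, 1}
  return 0;
}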

abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 1, op_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
  CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, k5DInputDims, op_name);

  std::vector<int64_t> kernel_size;
  std::vector<int64_t> strides;
  std::vector<int64_t> pad_list;
  int64_t pad_mode = 0;
  bool ceil_mode = false;
  bool count_include_pad = true;
  int64_t divisor_override = 0;
  GetAttrs(primitive, &kernel_size, &strides, &pad_mode, &pad_list, &ceil_mode, &count_include_pad, &divisor_override);
  auto in_d = in_shape[2];
  auto in_h = in_shape[3];
  auto in_w = in_shape[4];
  auto kernel_d = kernel_size[0];
  auto kernel_h = kernel_size[1];
  auto kernel_w = kernel_size[2];
  auto stride_d = strides[0];
  auto stride_h = strides[1];
  auto stride_w = strides[2];
  std::vector<int64_t> new_pad_list;
  GetPadsByPadding(in_d, in_h, in_w, kernel_d, kernel_h, kernel_w, stride_d, stride_h, stride_w, pad_mode, pad_list,
                   &new_pad_list);
  if (new_pad_list.size() != kPadDims) {
    MS_LOG(EXCEPTION) << "pad_list size must be 6.";
  }
  primitive->set_attr(kPadList, MakeValue(new_pad_list));
  if (pad_mode == PadMode::SAME) {
    primitive->set_attr(kCountIncludePad, MakeValue(false));
  }

  std::vector<int64_t> out_shape =
    GetOutputShape(in_shape, kernel_d, kernel_h, kernel_w, stride_d, stride_h, stride_w, new_pad_list, ceil_mode);
  if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
    MS_LOG(EXCEPTION) << "output size is not valid.";
  }
  return std::make_shared<abstract::Shape>(out_shape);
}

TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 1, op_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto x_dtype = input_args[0]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  return CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, op_name);
}
}  // namespace

AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
                                                    InferShape(primitive, input_args)->shape());
}

REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool3D, prim::kPrimAvgPool3D, AvgPool3DInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
@@ -0,0 +1,43 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_AVG_POOL_3D_H_
#define MINDSPORE_CORE_OPS_AVG_POOL_3D_H_

#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
class AvgPool3D : public PrimitiveC {
 public:
  AvgPool3D() : PrimitiveC(prim::kPrimAvgPool3D->name()) { InitIOName({"input"}, {"output"}); }
  ~AvgPool3D() = default;
  MS_DECLARE_PARENT(AvgPool3D, PrimitiveC);
};

AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args);
using PrimAvgPool3DPtr = std::shared_ptr<AvgPool3D>;
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_AVG_POOL_3D_H_
@ -0,0 +1,69 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/grad/avg_pool_3d_grad.h"
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
constexpr size_t k5DInputDims = 5;

abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 2, op_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
  CheckAndConvertUtils::CheckInteger("grad_rank", grad_shape.size(), kEqual, k5DInputDims, op_name);
  std::vector<int64_t> origin_input_size;
  if (input_args[0]->isa<abstract::AbstractTuple>()) {  // origin_input_size is a tuple
    origin_input_size = GetValue<std::vector<int64_t>>(input_args[0]->BuildValue());
  } else {
    MS_LOG(EXCEPTION) << "origin_input_size must be a tuple for " << op_name;
  }
  return std::make_shared<abstract::Shape>(origin_input_size);
}

TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 2, op_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto grad_dtype = input_args[1]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  return CheckAndConvertUtils::CheckTensorTypeValid("grad", grad_dtype, valid_types, op_name);
}
}  // namespace

AbstractBasePtr AvgPool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const std::vector<AbstractBasePtr> &input_args) {
  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
                                                    InferShape(primitive, input_args)->shape());
}

REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool3DGrad, prim::kPrimAvgPool3DGrad, AvgPool3DGradInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
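Note that InferShape above simply folds the constant origin_input_size back into the output shape, so the gradient tensor always matches the forward input. A minimal Python-side sketch of that contract (the import path and shapes are illustrative only; in practice AvgPool3DGrad is constructed inside the bprop added later in this patch):

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

# First input: the forward input shape; second input: the incoming gradient.
avg_pool3d_grad = G.AvgPool3DGrad(kernel_size=2, strides=1)
dout = Tensor(np.ones((1, 2, 1, 1, 2)), mindspore.float16)
dx = avg_pool3d_grad((1, 2, 2, 2, 3), dout)  # inferred output shape == origin_input_size
print(dx.shape)  # (1, 2, 2, 2, 3)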
@ -0,0 +1,45 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_AVG_POOL_3D_GRAD_H_
#define MINDSPORE_CORE_OPS_AVG_POOL_3D_GRAD_H_

#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
class AvgPool3DGrad : public PrimitiveC {
 public:
  AvgPool3DGrad() : PrimitiveC(prim::kPrimAvgPool3DGrad->name()) {
    InitIOName({"origin_input_size", "grad"}, {"output"});
  }
  ~AvgPool3DGrad() = default;
  MS_DECLARE_PARENT(AvgPool3DGrad, PrimitiveC);
};

AbstractBasePtr AvgPool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const std::vector<AbstractBasePtr> &input_args);
using PrimAvgPool3DGradPtr = std::shared_ptr<AvgPool3DGrad>;
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_AVG_POOL_3D_GRAD_H_
@ -152,6 +152,9 @@ constexpr auto kPadSize = "pad_size";
constexpr auto kPooledH = "pooled_h";
constexpr auto kPooledW = "pooled_w";
constexpr auto kPoolMode = "pool_mode";
constexpr auto kCeilMode = "ceil_mode";
constexpr auto kCountIncludePad = "count_include_pad";
constexpr auto kDivisorOverride = "divisor_override";
constexpr auto kPostNmsTopn = "post_nms_topn";
constexpr auto kPower = "power";
constexpr auto kPreNmsTopn = "pre_nms_topn";
@ -315,6 +315,27 @@ def get_bprop_avg_pool_grad(self):
    return bprop


@bprop_getters.register(P.AvgPool3D)
def get_bprop_avg_pool_3d_grad(self):
    """Grad definition for `AvgPool3D` operation."""
    pad_list = self.get_attr_dict()['pad_list']
    count_include_pad = self.get_attr_dict()['count_include_pad']
    avgpool3d_grad = G.AvgPool3DGrad(kernel_size=self.kernel_size,
                                     strides=self.strides,
                                     pads=pad_list,
                                     ceil_mode=self.ceil_mode,
                                     count_include_pad=count_include_pad,
                                     divisor_override=self.divisor_override,
                                     data_format=self.data_format)

    def bprop(x, out, dout):
        x_shape = F.shape(x)
        dx = avgpool3d_grad(x_shape, dout)
        return (dx,)

    return bprop
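For reference, a minimal sketch of how this bprop is picked up end to end (GradOperation is the standard MindSpore API; the operator construction and shapes below are illustrative, not part of the patch):

import numpy as np
import mindspore
from mindspore import Tensor, ops

avg_pool3d = ops.AvgPool3D(kernel_size=2, strides=1, pad_mode="valid")

def forward(x):
    return avg_pool3d(x)

x = Tensor(np.ones((1, 2, 2, 2, 3)), mindspore.float16)
dx = ops.GradOperation()(forward)(x)  # routed through get_bprop_avg_pool_3d_grad
print(dx.shape)  # (1, 2, 2, 2, 3)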


@bprop_getters.register(P.DropoutGenMask)
def get_bprop_dropout_gen_mask(self):
    """Grad definition for `DropoutGenMask` operation."""
@ -223,6 +223,8 @@ from .scatter_nd_update import _scatter_nd_update_tbe
from .avg_pool import _avg_pool_tbe
from .avg_pool_grad import _avg_pool_grad_tbe
from .avg_pool_grad_vm import _avg_pool_grad_vm_tbe
from .avg_pool_3d import _avg_pool_3d_tbe
from .avg_pool_3d_grad import _avg_pool_3d_grad_tbe
from .ones_like import _ones_like_tbe
from .ones_like_ds import _ones_like_ds_tbe
from .batch_to_space import _batch_to_space_tbe
@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""AvgPool3D op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

avg_pool_3d_op_info = TBERegOp("AvgPool3D") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("avg_pool3d_d.so") \
    .compute_cost(10) \
    .kernel_name("avg_pool3d_d") \
    .partial_flag(True) \
    .attr("kernel_size", "required", "listInt", "all") \
    .attr("strides", "required", "listInt", "all") \
    .attr("pad_list", "required", "listInt", "all") \
    .attr("ceil_mode", "optional", "bool", "all") \
    .attr("count_include_pad", "optional", "bool", "all") \
    .attr("divisor_override", "optional", "int", "all", '0') \
    .attr("format", "optional", "str", "all", 'NCDHW') \
    .input(0, "x", False, "required", "all") \
    .input(1, "filter", False, "optional", "all") \
    .input(2, "multiplier", False, "optional", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_FRACTAL_Z_3D, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
    .get_op_info()


@op_info_register(avg_pool_3d_op_info)
def _avg_pool_3d_tbe():
    """AvgPool3D TBE register"""
    return
@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""AvgPool3DGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

avg_pool_3d_grad_op_info = TBERegOp("AvgPool3DGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("avg_pool3d_grad_d.so") \
    .compute_cost(10) \
    .kernel_name("avg_pool3d_grad_d") \
    .partial_flag(True) \
    .attr("origin_input_shape", "required", "listInt", "all") \
    .attr("kernel_size", "required", "listInt", "all") \
    .attr("strides", "required", "listInt", "all") \
    .attr("pad_list", "required", "listInt", "all") \
    .attr("ceil_mode", "optional", "bool", "all") \
    .attr("count_include_pad", "optional", "bool", "all") \
    .attr("divisor_override", "optional", "int", "all", '0') \
    .attr("format", "optional", "str", "all", 'NCDHW') \
    .input(0, "grads", False, "required", "all") \
    .input(1, "filter", False, "optional", "all") \
    .input(2, "multiplier", False, "optional", "all") \
    .output(0, "output", False, "required", "all") \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_FRACTAL_Z_3D, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
    .get_op_info()


@op_info_register(avg_pool_3d_grad_op_info)
def _avg_pool_3d_grad_tbe():
    """AvgPool3DGrad TBE register"""
    return
@ -70,7 +70,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
                     InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
                     GeLU, Gelu, FastGeLU, FastGelu, Elu,
                     GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
                     LogSoftmax, MaxPool3D,
                     LogSoftmax, MaxPool3D, AvgPool3D,
                     MaxPool, DataFormatDimMap,
                     AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
@ -119,6 +119,7 @@ __all__ = [
    'Argmax',
    'Argmin',
    'MaxPool3D',
    'AvgPool3D',
    'ArgMaxWithValue',
    'ArgMinWithValue',
    'AddN',
@ -879,6 +879,30 @@ class AvgPoolGrad(_PoolGrad):
        return x1_dtype


class AvgPool3DGrad(Primitive):
    """Computes gradients of the AvgPool3D operation."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pads=0, ceil_mode=False,
                 count_include_pad=True, divisor_override=0, data_format="NCDHW"):
        self.init_prim_io_names(inputs=['origin_input_shape', 'grads'], outputs=['output'])
        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
        self.add_prim_attr('kernel_size', self.kernel_size)
        self.strides = _check_3d_int_or_tuple('strides', strides, self.name)
        self.add_prim_attr('strides', self.strides)
        validator.check_value_type('pads', pads, (int, tuple), self.name)
        if isinstance(pads, int):
            pads = (pads,) * 6
        validator.check_equal_int(len(pads), 6, 'pad size', self.name)
        for item in pads:
            validator.check_non_negative_int(item, 'pad item', self.name)
        self.add_prim_attr('pad_list', pads)
        self.ceil_mode = validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
        self.count_include_pad = validator.check_value_type('count_include_pad', count_include_pad, bool, self.name)
        self.divisor_override = validator.check_value_type('divisor_override', divisor_override, int, self.name)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)


class MaxPoolGrad(_PoolGrad):
    """Performs gradients of the max pool operation."""
@ -7638,6 +7638,110 @@ class LRN(PrimitiveWithInfer):
        return x_shape


class AvgPool3D(Primitive):
    r"""
    3D Average pooling operation.

    Applies a 3D average pooling over an input Tensor which can be regarded as a composition of 3D input planes.
    Typically the input is of shape :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})`, and AvgPool3D outputs
    regional averages in the :math:`(D_{in}, H_{in}, W_{in})` dimensions. Given kernel size
    :math:`ks = (d_{ker}, h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1, s_2)`, the operation is as follows.

    .. math::
        \text{output}(N_i, C_j, d, h, w) =
        \frac{1}{d_{ker} * h_{ker} * w_{ker}} \sum_{l=0}^{d_{ker}-1} \sum_{m=0}^{h_{ker}-1} \sum_{n=0}^{w_{ker}-1}
        \text{input}(N_i, C_j, s_0 \times d + l, s_1 \times h + m, s_2 \times w + n)

    Args:
        kernel_size (Union[int, tuple[int]]): The size of the kernel used to take the average value:
            an int number that represents the depth, height and width of the kernel, or a tuple
            of three int numbers that represent depth, height and width respectively. Default: 1.
        strides (Union[int, tuple[int]]): The distance of kernel moving: an int number that represents
            the depth, height and width of movement, or a tuple of three int numbers that
            represent depth, height and width of movement respectively. Default: 1.
        pad_mode (str): The optional value for pad mode: "same", "valid" or "pad", not case sensitive.
            Default: "valid".

            - same: Adopts the way of completion. The depth, height and width of the output will be the same as
              the input. The total number of padding will be calculated in depth, horizontal and vertical
              directions and evenly distributed to head and tail, top and bottom, left and right if possible.
              Otherwise, the last extra padding will be done from the tail, bottom and the right side.

            - valid: Adopts the way of discarding. The possible largest depth, height and width of the output
              will be returned without padding. Extra pixels will be discarded.

            - pad: Implicit paddings on both sides of the input in depth, height and width. The number of `pad` will
              be padded to the input Tensor borders. `pad` must be greater than or equal to 0.
        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
            head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
            integers, the padding of head, tail, top, bottom, left and right equals pad[0], pad[1], pad[2],
            pad[3], pad[4] and pad[5] respectively.
        ceil_mode (bool): If True, use ceil instead of floor to compute the output shape. Default: False.
        count_include_pad (bool): If True, the averaging calculation will include the zero-padding. Default: True.
        divisor_override (int): If specified (non-zero), it will be used as the divisor in the averaging
            calculation; otherwise the kernel size will be used. Default: 0.
        data_format (str): The optional value for data format. Currently only 'NCDHW' is supported.
            Default: 'NCDHW'.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
          Currently supports float16 and float32 data types.

    Outputs:
        Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})`. Has the same data type as `input`.

    Raises:
        TypeError: If `kernel_size`, `strides` or `pad` is neither an int nor a tuple.
        TypeError: If `ceil_mode` or `count_include_pad` is not a bool.
        TypeError: If `pad_mode` or `data_format` is not a string.
        TypeError: If `divisor_override` is not an int.
        ValueError: If numbers in `kernel_size` or `strides` are not positive.
        ValueError: If `kernel_size` or `strides` is a tuple whose length is not equal to 3.
        ValueError: If `pad_mode` is not one of 'same', 'valid' or 'pad'.
        ValueError: If `pad` is a tuple whose length is not equal to 6.
        ValueError: If an element of `pad` is less than 0.
        ValueError: If `pad_mode` is not equal to 'pad' and `pad` is not equal to (0, 0, 0, 0, 0, 0).
        ValueError: If `data_format` is not 'NCDHW'.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input = Tensor(np.arange(1 * 2 * 2 * 2 * 3).reshape((1, 2, 2, 2, 3)), mindspore.float16)
        >>> avg_pool3d = ops.AvgPool3D(kernel_size=2, strides=1, pad_mode="valid")
        >>> output = avg_pool3d(input)
        >>> print(output)
        [[[[[ 5.  6.]]]
          [[[17. 18.]]]]]
    """

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="valid", pad=0, ceil_mode=False,
                 count_include_pad=True, divisor_override=0, data_format="NCDHW"):
        """Initialize AvgPool3D"""
        self.init_prim_io_names(inputs=['input'], outputs=['output'])
        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
        self.add_prim_attr('kernel_size', self.kernel_size)
        self.strides = _check_3d_int_or_tuple('strides', strides, self.name)
        self.add_prim_attr('strides', self.strides)
        validator.check_value_type('pad', pad, (int, tuple), self.name)
        if isinstance(pad, int):
            pad = (pad,) * 6
        validator.check_equal_int(len(pad), 6, 'pad size', self.name)
        self.pad_list = pad
        self.add_prim_attr('pad_list', self.pad_list)
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME', 'PAD'], 'pad_mode', self.name)
        self.add_prim_attr('pad_mode', self.pad_mode)

        if self.pad_mode != 'PAD' and pad != (0, 0, 0, 0, 0, 0):
            raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode must be set to 'pad'.")
        if self.pad_mode == 'PAD':
            for item in pad:
                validator.check_non_negative_int(item, 'pad item', self.name)
        self.ceil_mode = validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
        self.count_include_pad = validator.check_value_type('count_include_pad', count_include_pad, bool, self.name)
        self.divisor_override = validator.check_non_negative_int(divisor_override, 'divisor_override', self.name)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
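
As a quick sanity check on the docstring example above, the same numbers fall out of a plain NumPy re-implementation (a verification sketch only; none of these names are part of the operator):

import numpy as np

# kernel_size=2, strides=1, pad_mode="valid", matching the example.
x = np.arange(1 * 2 * 2 * 2 * 3).reshape((1, 2, 2, 2, 3)).astype(np.float16)
n, c, d, h, w = x.shape
kd = kh = kw = 2
out = np.empty((n, c, d - kd + 1, h - kh + 1, w - kw + 1), dtype=np.float16)
for i in range(out.shape[2]):
    for j in range(out.shape[3]):
        for k in range(out.shape[4]):
            # Average over one kernel window across D, H and W.
            out[:, :, i, j, k] = x[:, :, i:i + kd, j:j + kh, k:k + kw].mean(axis=(2, 3, 4))
print(out)  # [[[[[ 5.  6.]]]  [[[17. 18.]]]]]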

class Conv3D(PrimitiveWithInfer):
    r"""
    3D convolution layer.
@ -1709,6 +1709,14 @@ test_case_nn_ops = [
        'block': P.AvgPool(kernel_size=(2, 2), strides=(2, 2), pad_mode="VALID"),
        'desc_inputs': [[100, 3, 28, 28]],
        'desc_bprop': [[100, 3, 14, 14]]}),
    ('AvgPool3D_1', {
        'block': P.AvgPool3D(kernel_size=2, strides=2, pad_mode="VALID"),
        'desc_inputs': [[10, 3, 28, 28, 28]],
        'desc_bprop': [[10, 3, 14, 14, 14]]}),
    ('AvgPool3D_2', {
        'block': P.AvgPool3D(kernel_size=3, strides=2, pad_mode="PAD", pad=1),
        'desc_inputs': [[10, 3, 28, 31, 24]],
        'desc_bprop': [[10, 3, 14, 16, 12]]}),
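    # For reference, the bprop shapes asserted by the two AvgPool3D cases above follow
    # the usual pooling arithmetic with ceil_mode=False:
    #   out = (size + pad_before + pad_after - kernel) // stride + 1
    # AvgPool3D_1 (kernel 2, stride 2, no pad): (28 - 2) // 2 + 1 = 14
    # AvgPool3D_2 (kernel 3, stride 2, pad 1):  (28 + 2 - 3) // 2 + 1 = 14
    #                                           (31 + 2 - 3) // 2 + 1 = 16
    #                                           (24 + 2 - 3) // 2 + 1 = 12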
    ('MaxPoolWithArgmax', {
        'block': P.MaxPoolWithArgmax(kernel_size=2, strides=2),
        'desc_inputs': [[128, 32, 32, 64]],