forked from mindspore-Ecosystem/mindspore
length of kernel_size and strides should be five
This commit is contained in:
parent
a274d27a87
commit
e4ebd8f998
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -20,10 +20,6 @@
|
|||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/trace_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -287,7 +283,7 @@ const AnfNodePtr AvgPool3DFusion::Process(const FuncGraphPtr &func_graph, const
|
|||
MS_LOG(EXCEPTION) << "Get stride size failed" << trace::DumpSourceLines(node);
|
||||
}
|
||||
std::vector<int64_t> pad_list;
|
||||
bool count_include_pad = false;
|
||||
bool count_include_pad = true;
|
||||
bool ceil_mode = false;
|
||||
int64_t divisor_override = 0;
|
||||
GetAttrs(avg_pool_3d_node, &pad_list, &count_include_pad, &ceil_mode, &divisor_override);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,10 +21,6 @@
|
|||
#include <algorithm>
|
||||
#include "backend/optimizer/ascend/ir_fusion/avgpool_3d_fusion.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/trace_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -37,6 +33,8 @@ constexpr size_t kOrigShapeDims = 5;
|
|||
constexpr size_t kShapeDims = 6;
|
||||
constexpr size_t kPadDims = 6;
|
||||
constexpr int64_t kC0 = 16;
|
||||
constexpr auto kKernelSize = "kernel_size";
|
||||
constexpr auto kStrides = "strides";
|
||||
|
||||
void GetAttrs(const AnfNodePtr &node, std::vector<int64_t> *kernel_size, std::vector<int64_t> *strides,
|
||||
std::vector<int64_t> *pad_list, std::vector<int64_t> *origin_input_shape, bool *ceil_mode,
|
||||
|
@ -211,6 +209,14 @@ const AnfNodePtr AvgPool3DGradFusion::Process(const FuncGraphPtr &func_graph, co
|
|||
std::string format_str;
|
||||
GetAttrs(avg_pool_3d_grad_node, &kernel_size, &strides, &pad_list, &origin_input_shape, &ceil_mode,
|
||||
&count_include_pad, &divisor_override, &format_str);
|
||||
const int64_t dim_one = SizeToLong(1);
|
||||
AnfAlgo::SetNodeAttr(
|
||||
kKernelSize,
|
||||
MakeValue(std::vector<int64_t>{dim_one, dim_one, kernel_size[kDim0], kernel_size[kDim1], kernel_size[kDim2]}),
|
||||
avg_pool_3d_grad_node);
|
||||
AnfAlgo::SetNodeAttr(
|
||||
kStrides, MakeValue(std::vector<int64_t>{dim_one, dim_one, strides[kDim0], strides[kDim1], strides[kDim2]}),
|
||||
avg_pool_3d_grad_node);
|
||||
if (IsVectorImpl(origin_input_shape, kernel_size, pad_list)) {
|
||||
MS_LOG(INFO) << "No need fusion";
|
||||
return nullptr;
|
||||
|
@ -241,11 +247,6 @@ const AnfNodePtr AvgPool3DGradFusion::Process(const FuncGraphPtr &func_graph, co
|
|||
new_3d_grad->set_scope(avg_pool_3d_grad_node->scope());
|
||||
new_3d_grad->set_abstract(avg_pool_3d_grad_node->abstract());
|
||||
AnfAlgo::CopyNodeAttrs(avg_pool_3d_grad_node, new_3d_grad);
|
||||
const int64_t dim_one = SizeToLong(1);
|
||||
AnfAlgo::SetNodeAttr("kernel_size", MakeValue(std::vector<int64_t>{dim_one, dim_one, kd, kh, kw}), new_3d_grad);
|
||||
AnfAlgo::SetNodeAttr(
|
||||
"strides", MakeValue(std::vector<int64_t>{dim_one, dim_one, strides[kDim0], strides[kDim1], strides[kDim2]}),
|
||||
new_3d_grad);
|
||||
return new_3d_grad;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,8 +19,6 @@
|
|||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
|
Loading…
Reference in New Issue