forked from mindspore-Ecosystem/mindspore
!5704 gpu inceptionv3 optimize
Merge pull request !5704 from limingqi107/master
This commit is contained in:
commit
311493fe83
|
@ -96,11 +96,13 @@ class BiasAddGpuKernel : public GpuKernel {
|
|||
b_dims[i] = (i == pos) ? SizeToInt(x_shape[i]) : 1;
|
||||
}
|
||||
|
||||
auto input_device_format = AnfAlgo::GetInputFormat(kernel_node, 0);
|
||||
auto cudnn_cal_format = (input_device_format == kOpFormat_NHWC) ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensorNdDescriptorEx(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), x_dims.get()),
|
||||
cudnnSetTensorNdDescriptorEx(x_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), x_dims.get()),
|
||||
"cudnnSetTensorNdDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensorNdDescriptorEx(b_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), b_dims.get()),
|
||||
cudnnSetTensorNdDescriptorEx(b_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), b_dims.get()),
|
||||
"cudnnSetTensorNdDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetOpTensorDescriptor(op_desc_, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN),
|
||||
|
|
|
@ -94,11 +94,13 @@ class BiasAddGradGpuKernel : public GpuKernel {
|
|||
}
|
||||
}
|
||||
|
||||
auto input_device_format = AnfAlgo::GetInputFormat(kernel_node, 0);
|
||||
auto cudnn_cal_format = (input_device_format == kOpFormat_NHWC) ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensorNdDescriptorEx(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), dy_dims.get()),
|
||||
cudnnSetTensorNdDescriptorEx(dy_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), dy_dims.get()),
|
||||
"cudnnSetTensorNdDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensorNdDescriptorEx(db_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()),
|
||||
cudnnSetTensorNdDescriptorEx(db_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()),
|
||||
"cudnnSetTensorNdDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN,
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
|
@ -38,21 +39,22 @@ std::vector<int> TransposeAxis(const std::string &src_format, const std::string
|
|||
}
|
||||
|
||||
// Transpose can be replaceed by nop reshape in some situations.
|
||||
// 1. out_shape [x, 1, 1, y] with transpose perm {0, 2, 3, 1}
|
||||
// 2. out_shape [x, y, 1, 1] with transpose perm {0, 3, 1, 2}
|
||||
// 1. out_shape [x, 1, 1, y]
|
||||
// 2. out_shape [x, y, 1, 1]
|
||||
// 3. out_shape [x, 1, y, 1]
|
||||
bool IsFakeTranspose(const std::vector<size_t> &out_shape, const std::vector<int> &transpose_perm) {
|
||||
if (out_shape.size() != 4) {
|
||||
MS_LOG(EXCEPTION) << "Invalid data shape, 4-D data was needed, but get " << out_shape.size() << "-D.";
|
||||
}
|
||||
std::vector<int> perm1 = {0, 2, 3, 1};
|
||||
std::vector<int> perm2 = {0, 3, 1, 2};
|
||||
if (transpose_perm == perm1) {
|
||||
return (out_shape[1] == 1 && out_shape[2] == 1);
|
||||
} else if (transpose_perm == perm2) {
|
||||
return (out_shape[2] == 1 && out_shape[3] == 1);
|
||||
} else {
|
||||
return false;
|
||||
auto num = std::count(out_shape.begin(), out_shape.end(), 1);
|
||||
if ((transpose_perm == perm1) || (transpose_perm == perm2)) {
|
||||
if (num >= 2) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void SetTransposeOpBuildInfo(const std::string &input_format, const std::string &output_format,
|
||||
|
@ -73,6 +75,8 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string
|
|||
// Insert transpose op between node and used_node whose position is used_node_index.
|
||||
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
|
||||
int used_node_index, const std::vector<int> &transpose_perm) {
|
||||
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
|
||||
<< ", index: " << used_node_index;
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// 0.Judge whether it is a fake transpose
|
||||
auto transed_shape = AnfAlgo::GetInputDeviceShape(used_node, used_node_index);
|
||||
|
@ -95,15 +99,10 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
|
|||
if (!is_fake) {
|
||||
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
|
||||
}
|
||||
// 4.Set the input of used_node.
|
||||
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
|
||||
<< ", index: " << used_node_index;
|
||||
AnfAlgo::SetNodeInput(utils::cast<CNodePtr>(used_node), transpose_op, used_node_index);
|
||||
// 5. Update the manager info of transpose op.
|
||||
// 4. Set the new edge of transpose op.
|
||||
FuncGraphManagerPtr manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Clear();
|
||||
manager->AddFuncGraph(graph);
|
||||
manager->SetEdge(used_node, used_node_index + 1, transpose_op);
|
||||
return transpose_op;
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -252,11 +252,11 @@ void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<s
|
|||
bn_cnt++;
|
||||
}
|
||||
}
|
||||
if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) {
|
||||
format_transform_ = false;
|
||||
if (conv_cnt + bn_cnt > 1) {
|
||||
format_transform_ = true;
|
||||
return;
|
||||
}
|
||||
format_transform_ = true;
|
||||
format_transform_ = false;
|
||||
}
|
||||
|
||||
void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||
|
|
|
@ -34,23 +34,27 @@ namespace gpu {
|
|||
// map<opName, (inputFormatPosition, outputFormatPosition)>, used for getting the insert position of format transform.
|
||||
// If input position is empty, then insert all the input positions, because the input numbers of this op are variable.
|
||||
static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = {
|
||||
// Format sensitive.
|
||||
{prim::kPrimConv2D->name(), {{0, 1}, {0}}},
|
||||
{prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {0}}},
|
||||
{prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {0}}},
|
||||
{prim::kPrimRelu->name(), {{0}, {0}}},
|
||||
{prim::kPrimReluGrad->name(), {{0, 1}, {0}}},
|
||||
{prim::kPrimMaxPool->name(), {{0}, {0}}},
|
||||
{prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}},
|
||||
{kSliceOpName, {{0}, {0}}},
|
||||
{kAvgPoolOpName, {{0}, {0}}},
|
||||
{kAvgPoolGradGpuOpName, {{0, 1, 2}, {0}}},
|
||||
{kTensorAddOpName, {{0, 1}, {0}}},
|
||||
{kFusedBatchNormEx, {{0}, {0}}},
|
||||
{kFusedBatchNormExWithActivation, {{0}, {0}}},
|
||||
{kFusedBatchNormExWithAddAndActivation, {{0, 5}, {0}}},
|
||||
{kFusedBatchNormGradEx, {{0, 1}, {0}}},
|
||||
{kFusedBatchNormGradExWithActivation, {{0, 1, 7}, {0}}},
|
||||
{kFusedBatchNormGradExWithAddAndActivation, {{0, 1, 7}, {0, 3}}},
|
||||
{kBiasAddOpName, {{0}, {0}}},
|
||||
{prim::kPrimBiasAddGrad->name(), {{0}, {}}},
|
||||
// Format insensitive.
|
||||
{prim::kPrimRelu->name(), {{0}, {0}}},
|
||||
{prim::kPrimReluGrad->name(), {{0, 1}, {0}}},
|
||||
{kSliceOpName, {{0}, {0}}},
|
||||
{kTensorAddOpName, {{0, 1}, {0}}},
|
||||
{prim::kPrimConcat->name(), {{}, {0}}},
|
||||
{prim::kPrimAddN->name(), {{}, {0}}},
|
||||
};
|
||||
|
@ -74,8 +78,6 @@ class FormatTransformChecker {
|
|||
FormatTransformChecker &operator=(const FormatTransformChecker &);
|
||||
|
||||
bool format_transform_{true};
|
||||
static constexpr size_t kConv2dCount = 96;
|
||||
static constexpr size_t kFusedBatchNormCount = 94;
|
||||
};
|
||||
|
||||
class KernelAttr {
|
||||
|
|
Loading…
Reference in New Issue