forked from mindspore-Ecosystem/mindspore
!13164 optimize GPU format transform
From: @limingqi107 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
24f1bc268d
|
@ -178,6 +178,21 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TransformFormatPosition(std::vector<size_t> *format_position, size_t position_num) {
|
||||||
|
MS_EXCEPTION_IF_NULL(format_position);
|
||||||
|
if (format_position->size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the inserted position is kAllPositions, then insert all the positions.
|
||||||
|
if ((*format_position)[0] == kAllPositions) {
|
||||||
|
format_position->clear();
|
||||||
|
for (size_t index = 0; index < position_num; index++) {
|
||||||
|
format_position->push_back(index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) {
|
bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) {
|
||||||
auto ms_context = MsContext::GetInstance();
|
auto ms_context = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(ms_context);
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
|
@ -198,20 +213,28 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<Type
|
||||||
if (inputs_type.size() == 0) {
|
if (inputs_type.size() == 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto inputs_format_position = iter->second.first;
|
auto inputs_format_position = iter->second.first;
|
||||||
// If input position is empty, then insert all the input positions, because the input numbers of this op are variable.
|
|
||||||
if (inputs_format_position.size() == 0) {
|
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
for (size_t input_index = 0; input_index < input_num; input_index++) {
|
TransformFormatPosition(&inputs_format_position, input_num);
|
||||||
inputs_format_position.push_back(input_index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (const auto &input_format_position : inputs_format_position) {
|
for (const auto &input_format_position : inputs_format_position) {
|
||||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_format_position);
|
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_format_position);
|
||||||
|
// Only support the transformer between NCHW and NHWC, so need the shape is 4 dimension.
|
||||||
if (input_shape.size() != 4) {
|
if (input_shape.size() != 4) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto outputs_format_position = iter->second.second;
|
||||||
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||||
|
TransformFormatPosition(&outputs_format_position, output_num);
|
||||||
|
for (const auto &output_format_position : outputs_format_position) {
|
||||||
|
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, output_format_position);
|
||||||
|
// Only support the transformer between NCHW and NHWC, so need the shape is 4 dimension.
|
||||||
|
if (output_shape.size() != 4) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -226,13 +249,8 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI
|
||||||
auto cal_format = (inputs_type[0] == kNumberTypeFloat16) ? kOpFormat_NHWC : kOpFormat_NCHW;
|
auto cal_format = (inputs_type[0] == kNumberTypeFloat16) ? kOpFormat_NHWC : kOpFormat_NCHW;
|
||||||
MS_LOG(DEBUG) << "Kernel node: " << kernel_node->fullname_with_scope() << ", format: " << cal_format;
|
MS_LOG(DEBUG) << "Kernel node: " << kernel_node->fullname_with_scope() << ", format: " << cal_format;
|
||||||
auto inputs_format_position = iter->second.first;
|
auto inputs_format_position = iter->second.first;
|
||||||
// If input position is empty, then insert all the input positions, because the input numbers of this op are variable.
|
|
||||||
if (inputs_format_position.size() == 0) {
|
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
for (size_t input_index = 0; input_index < input_num; input_index++) {
|
TransformFormatPosition(&inputs_format_position, input_num);
|
||||||
inputs_format_position.push_back(input_index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (const auto &input_format_position : inputs_format_position) {
|
for (const auto &input_format_position : inputs_format_position) {
|
||||||
if (input_format_position >= inputs_format->size()) {
|
if (input_format_position >= inputs_format->size()) {
|
||||||
MS_LOG(EXCEPTION) << "The position [" << input_format_position << "] is out of range of the input size ["
|
MS_LOG(EXCEPTION) << "The position [" << input_format_position << "] is out of range of the input size ["
|
||||||
|
@ -240,7 +258,10 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI
|
||||||
}
|
}
|
||||||
(*inputs_format)[input_format_position] = cal_format;
|
(*inputs_format)[input_format_position] = cal_format;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto outputs_format_position = iter->second.second;
|
auto outputs_format_position = iter->second.second;
|
||||||
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||||
|
TransformFormatPosition(&outputs_format_position, output_num);
|
||||||
for (const auto &output_format_position : outputs_format_position) {
|
for (const auto &output_format_position : outputs_format_position) {
|
||||||
if (output_format_position >= outputs_format->size()) {
|
if (output_format_position >= outputs_format->size()) {
|
||||||
MS_LOG(EXCEPTION) << "The position [" << output_format_position << "] is out of range of the output size ["
|
MS_LOG(EXCEPTION) << "The position [" << output_format_position << "] is out of range of the output size ["
|
||||||
|
|
|
@ -32,8 +32,11 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
// map<opName, (inputFormatPosition, outputFormatPosition)>, used for getting the insert position of format transform.
|
const size_t kAllPositions = SIZE_MAX;
|
||||||
// If input position is empty, then insert all the input positions, because the input numbers of this op are variable.
|
|
||||||
|
// Map<opName, (inputFormatPosition, outputFormatPosition)>, used for getting the inserted position of format transform.
|
||||||
|
// If the inserted position is kAllPositions, then insert all the positions, because the input or output numbers of
|
||||||
|
// this op are variable.
|
||||||
static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = {
|
static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = {
|
||||||
// Format sensitive.
|
// Format sensitive.
|
||||||
{prim::kPrimConv2D->name(), {{0, 1}, {0}}},
|
{prim::kPrimConv2D->name(), {{0, 1}, {0}}},
|
||||||
|
@ -58,8 +61,8 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
|
||||||
{prim::kPrimRelu6Grad->name(), {{0, 1}, {0}}},
|
{prim::kPrimRelu6Grad->name(), {{0, 1}, {0}}},
|
||||||
{kSliceOpName, {{0}, {0}}},
|
{kSliceOpName, {{0}, {0}}},
|
||||||
{kTensorAddOpName, {{0, 1}, {0}}},
|
{kTensorAddOpName, {{0, 1}, {0}}},
|
||||||
{prim::kPrimConcat->name(), {{}, {0}}},
|
{prim::kPrimConcat->name(), {{kAllPositions}, {0}}},
|
||||||
{prim::kPrimAddN->name(), {{}, {0}}},
|
{prim::kPrimAddN->name(), {{kAllPositions}, {0}}},
|
||||||
};
|
};
|
||||||
|
|
||||||
void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
|
void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
|
||||||
|
|
Loading…
Reference in New Issue