clean code2

VectorSL 2021-05-27 11:26:13 +08:00
parent 9e62c0ed08
commit 03210aee81
5 changed files with 17 additions and 16 deletions

View File

@@ -958,14 +958,14 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr>
     MS_EXCEPTION_IF_NULL(cnode);
     size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
     for (size_t input_index = 0; input_index < input_num; ++input_index) {
-      inputs_device_format.push_back(kOpFormat_DEFAULT);
-      inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
+      inputs_device_format.emplace_back(kOpFormat_DEFAULT);
+      inputs_device_type.emplace_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
     }
     size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
     for (size_t output_index = 0; output_index < output_num; ++output_index) {
-      outputs_device_format.push_back(kOpFormat_DEFAULT);
-      outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
-      outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
+      outputs_device_format.emplace_back(kOpFormat_DEFAULT);
+      outputs_device_type.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
+      outputs_shape.emplace_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
     }
   }
   builder.SetInputsFormat(inputs_device_format);
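
Note: emplace_back forwards its arguments to the element constructor, which can avoid creating a temporary when the argument type differs from the vector's element type. A minimal standalone sketch of the difference (hypothetical types and literal, not the MindSpore kOpFormat_DEFAULT constant):

#include <string>
#include <vector>

int main() {
  std::vector<std::string> formats;
  formats.push_back("DefaultFormat");     // builds a temporary std::string, then moves it into the vector
  formats.emplace_back("DefaultFormat");  // forwards the const char* and constructs the string in place
  return 0;
}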

View File

@@ -46,6 +46,9 @@ constexpr size_t kAssignSubInputTensorNum = 2;
 constexpr size_t kDropoutInputTensorNum = 1;
 constexpr size_t kAssignInputTensorNum = 2;
+constexpr size_t kGradIndex = 3;
+constexpr size_t kAddNInputNum = 2;
 constexpr size_t kConvBn1OutputNum = 3;
 constexpr size_t kBn2ReluOutputNum = 4;

View File

@@ -46,20 +46,20 @@ bool GetDealList(const std::vector<AnfNodePtr> &node_list, std::vector<std::vect
     if (dst == kNumberTypeFloat16 && src == kNumberTypeFloat32) {
       cast_32to16_list.push_back(cast_node);
       if (IsPrimitiveCNode(input0, prim::kPrimLoad)) {
-        auto &monad = input0->cast<CNodePtr>()->inputs().at(second_input_index);
+        auto &monad_32to16 = input0->cast<CNodePtr>()->inputs().at(second_input_index);
         if (cast_32to16_load_monad == nullptr) {
-          cast_32to16_load_monad = monad;
-        } else if (cast_32to16_load_monad != monad) {
+          cast_32to16_load_monad = monad_32to16;
+        } else if (cast_32to16_load_monad != monad_32to16) {
           return false;
         }
       }
     } else if (dst == kNumberTypeFloat32 && src == kNumberTypeFloat16) {
       cast_16to32_list.push_back(cast_node);
       if (IsPrimitiveCNode(input0, prim::kPrimLoad)) {
-        auto &monad = input0->cast<CNodePtr>()->inputs().at(second_input_index);
+        auto &monad_16to32 = input0->cast<CNodePtr>()->inputs().at(second_input_index);
         if (cast_16to32_load_monad == nullptr) {
-          cast_16to32_load_monad = monad;
-        } else if (cast_16to32_load_monad != monad) {
+          cast_16to32_load_monad = monad_16to32;
+        } else if (cast_16to32_load_monad != monad_16to32) {
           return false;
         }
       }
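
Note: the rename only gives each branch its own descriptively named reference to the Load node's monad input; the check itself is unchanged: remember the first monad seen and refuse the fusion if a later Cast was loaded under a different one. A simplified sketch of that pattern (hypothetical Node type, not the real AnfNodePtr API):

#include <memory>

struct Node {};
using NodePtr = std::shared_ptr<Node>;

// Record the first monad encountered; report a mismatch for any later, different monad.
bool RecordOrCompareMonad(NodePtr &recorded_monad, const NodePtr &current_monad) {
  if (recorded_monad == nullptr) {
    recorded_monad = current_monad;  // first occurrence: remember it
  } else if (recorded_monad != current_monad) {
    return false;                    // different monad: reject the fusion
  }
  return true;
}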

View File

@@ -24,7 +24,6 @@
 #include "utils/utils.h"
 #include "backend/optimizer/common/helper.h"
-#define ADD_NUM 2
 namespace mindspore {
 namespace opt {
 const BaseRef ReplaceAddNFusion::DefinePattern() const {
@@ -42,7 +41,7 @@ const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const Anf
   MS_EXCEPTION_IF_NULL(A);
   MS_EXCEPTION_IF_NULL(B);
   int64_t num_input = AnfAlgo::GetNodeAttr<int64_t>(node, "n");
-  if (num_input == ADD_NUM) {
+  if (num_input == kAddNInputNum) {
     auto prim = std::make_shared<Primitive>(prim::kPrimAdd->name());
     MS_EXCEPTION_IF_NULL(prim);
     std::vector<AnfNodePtr> inputs = {NewValueNode(prim), A, B};
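
Note: the file-local ADD_NUM macro is replaced by the shared kAddNInputNum constant declared in the header hunk above, so the magic number 2 gets a type, a scope, and a single definition. A minimal sketch of the difference (illustrative namespace, not the exact header layout):

#include <cstddef>

// Macro: untyped and unscoped; leaks into every translation unit that includes the file.
#define ADD_NUM 2

// constexpr constant: typed, scoped to its namespace, and defined once in a shared header.
namespace opt {
constexpr std::size_t kAddNInputNum = 2;
}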

View File

@@ -24,7 +24,6 @@
 #include "utils/utils.h"
 #include "backend/optimizer/common/helper.h"
-#define GRAD_INDEX 3
 namespace mindspore {
 namespace opt {
 const BaseRef ReplaceMomentumCastFusion::DefinePattern() const {
@@ -39,7 +38,7 @@ const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, c
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(equiv);
-  auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), GRAD_INDEX);
+  auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), kGradIndex);
   MS_EXCEPTION_IF_NULL(grad_cast);
   auto src = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);
   // momentum only support fp32/fp16 by now, do nothing if not.
@@ -58,7 +57,7 @@ const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, c
     outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, i));
     outputs_shape.push_back(AnfAlgo::GetOutputInferShape(node, i));
   }
-  outputs_type[GRAD_INDEX] = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);
+  outputs_type[kGradIndex] = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);
   AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, node.get());