forked from mindspore-Ecosystem/mindspore
clean code2
This commit is contained in:
parent
9e62c0ed08
commit
03210aee81
|
@ -958,14 +958,14 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr>
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_device_format.push_back(kOpFormat_DEFAULT);
|
||||
inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
|
||||
inputs_device_format.emplace_back(kOpFormat_DEFAULT);
|
||||
inputs_device_type.emplace_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_device_format.push_back(kOpFormat_DEFAULT);
|
||||
outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
|
||||
outputs_device_format.emplace_back(kOpFormat_DEFAULT);
|
||||
outputs_device_type.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
|
||||
outputs_shape.emplace_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
|
||||
}
|
||||
}
|
||||
builder.SetInputsFormat(inputs_device_format);
|
||||
|
|
|
@ -46,6 +46,9 @@ constexpr size_t kAssignSubInputTensorNum = 2;
|
|||
constexpr size_t kDropoutInputTensorNum = 1;
|
||||
constexpr size_t kAssignInputTensorNum = 2;
|
||||
|
||||
constexpr size_t kGradIndex = 3;
|
||||
constexpr size_t kAddNInputNum = 2;
|
||||
|
||||
constexpr size_t kConvBn1OutputNum = 3;
|
||||
constexpr size_t kBn2ReluOutputNum = 4;
|
||||
|
||||
|
|
|
@ -46,20 +46,20 @@ bool GetDealList(const std::vector<AnfNodePtr> &node_list, std::vector<std::vect
|
|||
if (dst == kNumberTypeFloat16 && src == kNumberTypeFloat32) {
|
||||
cast_32to16_list.push_back(cast_node);
|
||||
if (IsPrimitiveCNode(input0, prim::kPrimLoad)) {
|
||||
auto &monad = input0->cast<CNodePtr>()->inputs().at(second_input_index);
|
||||
auto &monad_32to16 = input0->cast<CNodePtr>()->inputs().at(second_input_index);
|
||||
if (cast_32to16_load_monad == nullptr) {
|
||||
cast_32to16_load_monad = monad;
|
||||
} else if (cast_32to16_load_monad != monad) {
|
||||
cast_32to16_load_monad = monad_32to16;
|
||||
} else if (cast_32to16_load_monad != monad_32to16) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else if (dst == kNumberTypeFloat32 && src == kNumberTypeFloat16) {
|
||||
cast_16to32_list.push_back(cast_node);
|
||||
if (IsPrimitiveCNode(input0, prim::kPrimLoad)) {
|
||||
auto &monad = input0->cast<CNodePtr>()->inputs().at(second_input_index);
|
||||
auto &monad_16to32 = input0->cast<CNodePtr>()->inputs().at(second_input_index);
|
||||
if (cast_16to32_load_monad == nullptr) {
|
||||
cast_16to32_load_monad = monad;
|
||||
} else if (cast_16to32_load_monad != monad) {
|
||||
cast_16to32_load_monad = monad_16to32;
|
||||
} else if (cast_16to32_load_monad != monad_16to32) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,7 +24,6 @@
|
|||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
#define ADD_NUM 2
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef ReplaceAddNFusion::DefinePattern() const {
|
||||
|
@ -42,7 +41,7 @@ const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const Anf
|
|||
MS_EXCEPTION_IF_NULL(A);
|
||||
MS_EXCEPTION_IF_NULL(B);
|
||||
int64_t num_input = AnfAlgo::GetNodeAttr<int64_t>(node, "n");
|
||||
if (num_input == ADD_NUM) {
|
||||
if (num_input == kAddNInputNum) {
|
||||
auto prim = std::make_shared<Primitive>(prim::kPrimAdd->name());
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), A, B};
|
||||
|
|
|
@ -24,7 +24,6 @@
|
|||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
#define GRAD_INDEX 3
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef ReplaceMomentumCastFusion::DefinePattern() const {
|
||||
|
@ -39,7 +38,7 @@ const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, c
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
|
||||
auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), GRAD_INDEX);
|
||||
auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), kGradIndex);
|
||||
MS_EXCEPTION_IF_NULL(grad_cast);
|
||||
auto src = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);
|
||||
// momentum only support fp32/fp16 by now, do nothing if not.
|
||||
|
@ -58,7 +57,7 @@ const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, c
|
|||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, i));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(node, i));
|
||||
}
|
||||
outputs_type[GRAD_INDEX] = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);
|
||||
outputs_type[kGradIndex] = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);
|
||||
|
||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, node.get());
|
||||
|
||||
|
|
Loading…
Reference in New Issue