!29457 fix Conv2DUnifyMindIR when mode set to graph but run pynative

Merge pull request !29457 from yuchaojie/unify_ir
This commit is contained in:
i-robot 2022-01-25 01:17:32 +00:00 committed by Gitee
commit 3e5eaa416f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 9 additions and 7 deletions

View File

@ -41,8 +41,6 @@ constexpr auto kAttrMode = "mode";
constexpr auto kAttrChannelMultiplier = "channel_multiplier";
constexpr auto kAttrInputSizes = "input_sizes";
constexpr auto kAttrInputSize = "input_size";
constexpr auto kIndex2 = 2;
constexpr auto kIndex3 = 3;
bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vector<size_t> out_shape) {
MS_EXCEPTION_IF_NULL(conv2d);
@ -72,6 +70,13 @@ bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vecto
return true;
}
bool IsPynative() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
ms_context->get_param<int>(MS_CTX_ENABLE_PYNATIVE_INFER);
}
ValueNodePtr CreatePermValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &perm) {
MS_EXCEPTION_IF_NULL(func_graph);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
@ -97,11 +102,9 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d);
MS_EXCEPTION_IF_NULL(input_node);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto perm = std::vector<int64_t>{1, 0, 2, 3};
std::vector<AnfNodePtr> transpose_inputs;
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
if (IsPynative()) {
transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node};
} else {
transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node,
@ -129,7 +132,7 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
auto output_names = std::vector<std::string>{"output"};
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), transpose);
AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), transpose);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
if (IsPynative()) {
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), transpose);
}
return transpose;

View File

@ -569,7 +569,6 @@ ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vec
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
if (ret_code != 0) {
MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
return nullptr;
}
shape_value = shape_tensor;
abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);