!29457 fix Conv2DUnifyMindIR when mode set to graph but run pynative
Merge pull request !29457 from yuchaojie/unify_ir
This commit is contained in:
commit
3e5eaa416f
|
@ -41,8 +41,6 @@ constexpr auto kAttrMode = "mode";
|
|||
constexpr auto kAttrChannelMultiplier = "channel_multiplier";
|
||||
constexpr auto kAttrInputSizes = "input_sizes";
|
||||
constexpr auto kAttrInputSize = "input_size";
|
||||
constexpr auto kIndex2 = 2;
|
||||
constexpr auto kIndex3 = 3;
|
||||
|
||||
bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vector<size_t> out_shape) {
|
||||
MS_EXCEPTION_IF_NULL(conv2d);
|
||||
|
@ -72,6 +70,13 @@ bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vecto
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IsPynative() {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
|
||||
ms_context->get_param<int>(MS_CTX_ENABLE_PYNATIVE_INFER);
|
||||
}
|
||||
|
||||
ValueNodePtr CreatePermValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &perm) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
|
@ -97,11 +102,9 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(conv2d);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto perm = std::vector<int64_t>{1, 0, 2, 3};
|
||||
std::vector<AnfNodePtr> transpose_inputs;
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
if (IsPynative()) {
|
||||
transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node};
|
||||
} else {
|
||||
transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node,
|
||||
|
@ -129,7 +132,7 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
|
|||
auto output_names = std::vector<std::string>{"output"};
|
||||
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), transpose);
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), transpose);
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
if (IsPynative()) {
|
||||
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), transpose);
|
||||
}
|
||||
return transpose;
|
||||
|
|
|
@ -569,7 +569,6 @@ ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vec
|
|||
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
|
||||
if (ret_code != 0) {
|
||||
MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
|
||||
return nullptr;
|
||||
}
|
||||
shape_value = shape_tensor;
|
||||
abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
|
||||
|
|
Loading…
Reference in New Issue