forked from mindspore-Ecosystem/mindspore
!3510 Cancel insert assign from condition to true branch of while
Merge pull request !3510 from chenfei_mindspore/avoid-insert-assign-from-condition-to-true-branch
This commit is contained in:
commit
d14a01f403
|
@ -22,6 +22,7 @@
|
|||
#include "backend/optimizer/pass/convert_const_input_to_tensor_input.h"
|
||||
#include "backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h"
|
||||
#include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h"
|
||||
#include "backend/optimizer/pass/convert_const_scalar_to_tensor.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
|
||||
|
@ -46,8 +47,9 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
auto common_pm = std::make_shared<PassManager>("common_pm");
|
||||
common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>());
|
||||
common_pm->AddPass(std::make_shared<ConstToAttrStridedSliceGradPass>());
|
||||
common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());
|
||||
common_pm->AddPass(std::make_shared<ConvertConstInputToTensorInput>());
|
||||
common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());
|
||||
common_pm->AddPass(std::make_shared<ConvertConstScalarToTensor>());
|
||||
common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
|
||||
optimizer->AddPassManager(common_pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
|
|
|
@ -781,5 +781,27 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
|
|||
MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
|
||||
return false;
|
||||
}
|
||||
|
||||
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
|
||||
new_value_node->set_abstract(value_node->abstract());
|
||||
// create kernel_info fo new value node
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
new_value_node->set_kernel_info(kernel_info);
|
||||
// create kernel_build_info for new value node
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
// set the format of value_node to DEFAULT_FORMAT
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
// set value node initial device data type = infer data type
|
||||
std::vector<TypeId> types;
|
||||
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
|
||||
types.push_back(kTypeUnknown);
|
||||
}
|
||||
kernel_build_info_builder->SetOutputsDeviceType(types);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
|
||||
return new_value_node;
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -194,6 +194,9 @@ bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name);
|
|||
|
||||
// Check node's data type is in supported data type set
|
||||
bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set);
|
||||
|
||||
// Create a new value node of func graph,not kernel graph
|
||||
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_
|
||||
|
|
|
@ -29,28 +29,8 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
|
||||
new_value_node->set_abstract(value_node->abstract());
|
||||
// create kernel_info fo new value node
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
new_value_node->set_kernel_info(kernel_info);
|
||||
// create kernel_build_info for new value node
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
// set the format of value_node to DEFAULT_FORMAT
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
// set value node initial device data type = infer data type
|
||||
std::vector<TypeId> types;
|
||||
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
|
||||
types.push_back(kTypeUnknown);
|
||||
}
|
||||
kernel_build_info_builder->SetOutputsDeviceType(types);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
|
||||
return new_value_node;
|
||||
}
|
||||
|
||||
AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) {
|
||||
AnfNodePtr CreateTensorInput(const AnfNodePtr &node, const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) {
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
auto value_node = input_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
|
@ -60,6 +40,9 @@ AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePt
|
|||
if (value->isa<Scalar>()) {
|
||||
tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
|
||||
} else if (value->isa<ValueTuple>()) {
|
||||
if (!AnfAlgo::IsRealCNodeKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple";
|
||||
|
@ -93,7 +76,7 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
|
|||
for (size_t i = 0; i < inputs.size() - 1; ++i) {
|
||||
auto input_node = inputs[i + 1];
|
||||
if (IsValueNode<Scalar>(input_node) || IsValueNode<ValueTuple>(input_node)) {
|
||||
auto tensor_input = CreateTensorInput(kernel_graph, input_node);
|
||||
auto tensor_input = CreateTensorInput(cnode, kernel_graph, input_node);
|
||||
if (tensor_input == nullptr) {
|
||||
new_inputs.push_back(input_node);
|
||||
continue;
|
||||
|
@ -139,7 +122,7 @@ AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) {
|
|||
|
||||
const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||
if (node == nullptr || func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!node->isa<CNode>()) {
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/pass/convert_const_scalar_to_tensor.h"
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "utils/graph_utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) {
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (!input_node->isa<ValueNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto value_node = input_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (!value->isa<Scalar>()) {
|
||||
return nullptr;
|
||||
}
|
||||
tensor::TensorPtr tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
|
||||
if (tensor_ptr == nullptr) {
|
||||
MS_LOG(WARNING) << "Create tensor of" << input_node->DebugString() << "failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto tensor_input = std::make_shared<ValueNode>(tensor_ptr);
|
||||
MS_EXCEPTION_IF_NULL(tensor_input);
|
||||
tensor_input->set_abstract(tensor_ptr->ToAbstract());
|
||||
if (kernel_graph != nullptr) {
|
||||
tensor_input = kernel_graph->NewValueNode(tensor_input);
|
||||
kernel_graph->AddValueNodeToGraph(tensor_input);
|
||||
} else {
|
||||
tensor_input = MakeValueNode(tensor_input);
|
||||
}
|
||||
tensor_input->set_scope(input_node->scope());
|
||||
return tensor_input;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr ConvertConstScalarToTensor::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
bool input_changed = false;
|
||||
for (size_t i = 0; i < cnode->inputs().size(); ++i) {
|
||||
auto new_input = CreateTensorInput(func_graph->cast<KernelGraphPtr>(), cnode->inputs()[i]);
|
||||
if (new_input != nullptr) {
|
||||
cnode->set_input(i, new_input);
|
||||
input_changed = true;
|
||||
}
|
||||
}
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
if (kernel_graph == nullptr || !input_changed) {
|
||||
return nullptr;
|
||||
}
|
||||
return kernel_graph->NewCNode(cnode);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_SCALAR_TO_TENSOR_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_SCALAR_TO_TENSOR_H_
|
||||
#include <string>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ConvertConstScalarToTensor : public PatternProcessPass {
|
||||
public:
|
||||
explicit ConvertConstScalarToTensor(bool multigraph = true)
|
||||
: PatternProcessPass("convert_const_scalar_to_tensor", multigraph) {}
|
||||
~ConvertConstScalarToTensor() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_SCALAR_TO_TENSOR_H_
|
|
@ -75,8 +75,10 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
|
|||
}
|
||||
}
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
return cnode_input_changed ? kernel_graph->NewCNode(cnode) : nullptr;
|
||||
if (kernel_graph == nullptr || !cnode_input_changed) {
|
||||
return nullptr;
|
||||
}
|
||||
return kernel_graph->NewCNode(cnode);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -881,22 +881,22 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
|
|||
return;
|
||||
}
|
||||
memo->insert(graph.get());
|
||||
|
||||
graph->UpdateChildGraphOrder();
|
||||
for (auto &child_graph : graph->child_graph_order()) {
|
||||
CreateMultiBranchOutput(NOT_NULL(child_graph), memo);
|
||||
}
|
||||
|
||||
// If graph has no output, the graph is the true graph of while and will call condition graph, no need insert assign
|
||||
// from condition to true graph
|
||||
if (graph->get_output_null()) {
|
||||
return;
|
||||
}
|
||||
std::map<AnfNodePtr, AnfNodePtr> need_replace_list;
|
||||
auto node_list = GetCNodes(TopoSort(graph->get_return()));
|
||||
for (auto &node : node_list) {
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
|
||||
// create a parameter to store the output of multiple branch and set the parameter as the condition graph's output
|
||||
// auto multi_output_param = graph->NewParameter();
|
||||
auto origin_inputs = graph->inputs();
|
||||
auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
|
||||
MS_EXCEPTION_IF_NULL(graph->MutableInputs());
|
||||
graph->MutableInputs()->operator=(origin_inputs);
|
||||
graph->AddChildGraphResult(output_param);
|
||||
|
||||
std::vector<AnfNodePtr> depend_inputs = {
|
||||
|
|
|
@ -133,7 +133,6 @@ AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) {
|
|||
if (value_node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
|
||||
new_value_node->set_abstract(value_node->abstract());
|
||||
this->SetKernelInfoForNode(new_value_node);
|
||||
|
|
|
@ -99,18 +99,13 @@ TEST_F(TestHWConstInputToTensorInput, test_value_tuple_tensor_input) {
|
|||
EXPECT_NE(ret->input(1)->cast<CNodePtr>(), nullptr);
|
||||
auto cnode = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
|
||||
EXPECT_EQ(AnfAlgo::GetCNodeName(cnode), prim::kPrimDropoutGenMask->name());
|
||||
std::vector<int> out;
|
||||
for (size_t i = 1; i <= 4; i++) {
|
||||
auto input = cnode->input(i);
|
||||
ASSERT_TRUE(input != nullptr);
|
||||
EXPECT_TRUE(IsValueNode<tensor::Tensor>(input));
|
||||
auto tensor = input->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>();
|
||||
ASSERT_TRUE(tensor != nullptr);
|
||||
int *data = (int *)(tensor->data_c());
|
||||
ASSERT_TRUE(data != nullptr);
|
||||
out.push_back(*data);
|
||||
}
|
||||
EXPECT_EQ(out, std::vector<int>({2, 4, 2, 2}));
|
||||
auto input1 = cnode->input(1);
|
||||
ASSERT_TRUE(input1 != nullptr);
|
||||
EXPECT_TRUE(IsValueNode<tensor::Tensor>(input1));
|
||||
auto tensor = input1->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>();
|
||||
ASSERT_TRUE(tensor != nullptr);
|
||||
auto data = tensor->data_c();
|
||||
EXPECT_EQ(std::vector<int>((int *)data, (int *)data + 4), std::vector<int>({2, 4, 2, 2}));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue