!1231 Assign label resource for new control sink

Merge pull request !1231 from zhoufeng/label-assign
This commit is contained in:
mindspore-ci-bot 2020-05-19 10:48:35 +08:00 committed by Gitee
commit c793540cc9
5 changed files with 150 additions and 0 deletions

View File

@ -0,0 +1,88 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "device/ascend/ascend_label_assign.h"
#include "session/anf_runtime_algorithm.h"
static constexpr uint32_t kLabelGotoLabelId = 1;
static constexpr uint32_t kLabelSwitchLabelId = 2;
namespace mindspore {
namespace device {
namespace ascend {
static void UpdateLabelGoto(NotNull<CNodePtr> node) {
if (node->size() <= kLabelGotoLabelId) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size();
}
auto label_set = AnfAlgo::GetCNodePrimitive(node->input(kLabelGotoLabelId));
MS_EXCEPTION_IF_NULL(label_set);
auto value = label_set->GetAttr(kAttrLabelIndex);
MS_EXCEPTION_IF_NULL(value);
uint32_t goto_label_id = GetValue<uint32_t>(value);
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(goto_label_id), node.get());
MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id;
}
static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
if (node->size() <= kLabelGotoLabelId) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size();
}
std::vector<uint32_t> label_list;
for (size_t i = kLabelSwitchLabelId; i < node->size(); ++i) {
auto input = node->input(i);
if (!input->isa<CNode>() || AnfAlgo::GetCNodeName(input) != kLabelSetOpName) {
break;
}
auto label_set = AnfAlgo::GetCNodePrimitive(input);
MS_EXCEPTION_IF_NULL(label_set);
auto value = label_set->GetAttr(kAttrLabelIndex);
MS_EXCEPTION_IF_NULL(value);
uint32_t goto_label_id = GetValue<uint32_t>(value);
label_list.push_back(goto_label_id);
MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id;
}
AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get());
}
void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph) {
auto cnode_list = graph->execution_order();
// 1 assign label id to label_set
uint32_t cur_label_id = 0;
for (auto &node : cnode_list) {
if (AnfAlgo::GetCNodeName(node) == kLabelSetOpName) {
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(cur_label_id), node);
MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << cur_label_id;
++cur_label_id;
}
}
// 2 update label_switch / label_goto
for (auto &node : cnode_list) {
if (AnfAlgo::GetCNodeName(node) == kLabelGotoOpName) {
UpdateLabelGoto(NOT_NULL(node));
}
if (AnfAlgo::GetCNodeName(node) == kLabelSwitchOpName) {
UpdateLabelSwitch(NOT_NULL(node));
}
}
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_
#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_
#include <memory>
#include "session/kernel_graph.h"
#include "utils/contract.h"
namespace mindspore {
namespace device {
namespace ascend {
class AscendLabelAssign {
public:
static AscendLabelAssign &GetInstance() {
static AscendLabelAssign instance; // Guaranteed to be destroyed.
return instance;
}
AscendLabelAssign(const AscendLabelAssign &) = delete;
AscendLabelAssign &operator=(const AscendLabelAssign &) = delete;
void AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph);
private:
AscendLabelAssign() = default;
~AscendLabelAssign() = default;
};
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_

View File

@ -30,6 +30,7 @@
#include "pre_activate/ascend/ascend_backend_optimization.h"
#include "device/kernel_adjust.h"
#include "device/ascend/ascend_stream_assign.h"
#include "device/ascend/ascend_label_assign.h"
#include "predict/predict.h"
#include "session/anf_runtime_algorithm.h"
#include "ir/scalar.h"
@ -189,6 +190,8 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
RootGraphExecutorValidate(graph.get());
// assign stream
AssignStream(graph);
// assign label
AssignLabel(NOT_NULL(graph));
// build kernel if node is cnode
BuildKernel(graph);
// alloc mem
@ -469,6 +472,12 @@ void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_grap
MS_LOG(INFO) << "Finish!";
}
void AscendSession::AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const {
MS_LOG(INFO) << "Start!";
device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph);
MS_LOG(INFO) << "Finish!";
}
void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "Start!";
struct timeval start_time, end_time;

View File

@ -74,6 +74,7 @@ class AscendSession : public SessionBasic {
void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const;
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void MemoryAlloc(KernelGraph *kernel_graph) const;
void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;

View File

@ -14,12 +14,16 @@
* limitations under the License.
*/
#include "device/ascend/ascend_stream_assign.h"
#include "device/ascend/ascend_label_assign.h"
#include "device/ascend/tasksink/task_generator.h"
#include "device/kernel_adjust.h"
namespace mindspore {
namespace device {
namespace ascend {
void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &>) {}
void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; }
uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; }