forked from mindspore-Ecosystem/mindspore
assign label resource for new control sink
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
f23bfe0d71
commit
c99f904276
|
@ -0,0 +1,88 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include "device/ascend/ascend_label_assign.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
||||
static constexpr uint32_t kLabelGotoLabelId = 1;
|
||||
static constexpr uint32_t kLabelSwitchLabelId = 2;
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
|
||||
static void UpdateLabelGoto(NotNull<CNodePtr> node) {
|
||||
if (node->size() <= kLabelGotoLabelId) {
|
||||
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size();
|
||||
}
|
||||
auto label_set = AnfAlgo::GetCNodePrimitive(node->input(kLabelGotoLabelId));
|
||||
MS_EXCEPTION_IF_NULL(label_set);
|
||||
auto value = label_set->GetAttr(kAttrLabelIndex);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
uint32_t goto_label_id = GetValue<uint32_t>(value);
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(goto_label_id), node.get());
|
||||
MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id;
|
||||
}
|
||||
|
||||
static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
|
||||
if (node->size() <= kLabelGotoLabelId) {
|
||||
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size();
|
||||
}
|
||||
std::vector<uint32_t> label_list;
|
||||
for (size_t i = kLabelSwitchLabelId; i < node->size(); ++i) {
|
||||
auto input = node->input(i);
|
||||
if (!input->isa<CNode>() || AnfAlgo::GetCNodeName(input) != kLabelSetOpName) {
|
||||
break;
|
||||
}
|
||||
|
||||
auto label_set = AnfAlgo::GetCNodePrimitive(input);
|
||||
MS_EXCEPTION_IF_NULL(label_set);
|
||||
auto value = label_set->GetAttr(kAttrLabelIndex);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
uint32_t goto_label_id = GetValue<uint32_t>(value);
|
||||
label_list.push_back(goto_label_id);
|
||||
MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id;
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get());
|
||||
}
|
||||
|
||||
void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph) {
|
||||
auto cnode_list = graph->execution_order();
|
||||
// 1 assign label id to label_set
|
||||
uint32_t cur_label_id = 0;
|
||||
for (auto &node : cnode_list) {
|
||||
if (AnfAlgo::GetCNodeName(node) == kLabelSetOpName) {
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(cur_label_id), node);
|
||||
MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << cur_label_id;
|
||||
++cur_label_id;
|
||||
}
|
||||
}
|
||||
// 2 update label_switch / label_goto
|
||||
for (auto &node : cnode_list) {
|
||||
if (AnfAlgo::GetCNodeName(node) == kLabelGotoOpName) {
|
||||
UpdateLabelGoto(NOT_NULL(node));
|
||||
}
|
||||
|
||||
if (AnfAlgo::GetCNodeName(node) == kLabelSwitchOpName) {
|
||||
UpdateLabelSwitch(NOT_NULL(node));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_
|
||||
#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_
|
||||
|
||||
#include <memory>
|
||||
#include "session/kernel_graph.h"
|
||||
#include "utils/contract.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
|
||||
class AscendLabelAssign {
|
||||
public:
|
||||
static AscendLabelAssign &GetInstance() {
|
||||
static AscendLabelAssign instance; // Guaranteed to be destroyed.
|
||||
return instance;
|
||||
}
|
||||
|
||||
AscendLabelAssign(const AscendLabelAssign &) = delete;
|
||||
AscendLabelAssign &operator=(const AscendLabelAssign &) = delete;
|
||||
|
||||
void AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph);
|
||||
|
||||
private:
|
||||
AscendLabelAssign() = default;
|
||||
~AscendLabelAssign() = default;
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_
|
|
@ -30,6 +30,7 @@
|
|||
#include "pre_activate/ascend/ascend_backend_optimization.h"
|
||||
#include "device/kernel_adjust.h"
|
||||
#include "device/ascend/ascend_stream_assign.h"
|
||||
#include "device/ascend/ascend_label_assign.h"
|
||||
#include "predict/predict.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "ir/scalar.h"
|
||||
|
@ -189,6 +190,8 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|||
RootGraphExecutorValidate(graph.get());
|
||||
// assign stream
|
||||
AssignStream(graph);
|
||||
// assign label
|
||||
AssignLabel(NOT_NULL(graph));
|
||||
// build kernel if node is cnode
|
||||
BuildKernel(graph);
|
||||
// alloc mem
|
||||
|
@ -469,6 +472,12 @@ void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
void AscendSession::AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph);
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
struct timeval start_time, end_time;
|
||||
|
|
|
@ -74,6 +74,7 @@ class AscendSession : public SessionBasic {
|
|||
void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const;
|
||||
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void MemoryAlloc(KernelGraph *kernel_graph) const;
|
||||
void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
|
||||
|
|
|
@ -14,12 +14,16 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "device/ascend/ascend_stream_assign.h"
|
||||
#include "device/ascend/ascend_label_assign.h"
|
||||
#include "device/ascend/tasksink/task_generator.h"
|
||||
#include "device/kernel_adjust.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
|
||||
void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &>) {}
|
||||
|
||||
void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; }
|
||||
|
||||
uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; }
|
||||
|
|
Loading…
Reference in New Issue