MindRT Ascend Devcie Context

This commit is contained in:
hwjiaorui 2021-11-08 16:03:48 +08:00
parent 1509d3f848
commit d6b2a34a69
20 changed files with 177 additions and 28 deletions

View File

@ -2612,14 +2612,18 @@ std::set<FuncGraphPtr> AnfRuntimeAlgorithm::GetFuncGraphbyCallNode(const AnfNode
if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitch)) {
// First input node of call is switch node.
const auto &switch_inputs = call_input0->cast<CNodePtr>()->inputs();
const auto &input_cnode = call_input0->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
const auto &switch_inputs = input_cnode->inputs();
for (size_t i = kSwitchTrueBranchPos; i < switch_inputs.size(); ++i) {
MS_EXCEPTION_IF_NULL(switch_inputs[i]);
(void)func_graphs.emplace(GetFuncGraphFromPartial(switch_inputs[i], call_depth));
}
} else if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitchLayer)) {
// First input node of call is switch layer node.
const auto &tuple_node = call_input0->cast<CNodePtr>()->input(kSwitchLayerBranchPos);
const auto &input_cnode = call_input0->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
const auto &tuple_node = input_cnode->input(kSwitchLayerBranchPos);
if (!AnfAlgo::CheckPrimitiveType(tuple_node, prim::kPrimMakeTuple)) {
MS_LOG(EXCEPTION) << "Invalid input tuple node:" << tuple_node->DebugString()
<< " for switch layer node:" << cnode->DebugString();

View File

@ -236,9 +236,9 @@ bool TensorNeedSync(const std::shared_ptr<KernelGraph> &kernel_graph, const AnfN
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
if (tensor_address != device_address) {
if (!kernel_graph->is_dynamic_shape() && EnableDeviceCopy() && NeedMemcpyInDevice(tensor_address, device_address)) {
auto status = device_address->SyncDeviceToDevice(trans::GetRuntimePaddingShape(parameter, 0),
tensor_address->GetSize(), tensor_address->type_id(),
tensor_address->GetPtr(), tensor_address->format());
auto status = device_address->AsyncDeviceToDevice(trans::GetRuntimePaddingShape(parameter, 0),
tensor_address->GetSize(), tensor_address->type_id(),
tensor_address->GetPtr(), tensor_address->format());
if (!status) {
MS_LOG(EXCEPTION) << "SyncDeviceToDevice failed.";
}
@ -1830,9 +1830,9 @@ void AscendSession::UpdateOutputTensors(const VectorRef *outputs,
if (EnableDeviceCopy() && tensor->NeedSyncDeviceToHostImmediately()) {
auto dst_device_address = AssignExtraMemForGraphOutput(tensor, node, output_index);
MS_EXCEPTION_IF_NULL(dst_device_address);
if (!dst_device_address->SyncDeviceToDevice(trans::GetRuntimePaddingShape(node, output_index),
address->GetSize(), address->type_id(), address->GetPtr(),
address->format())) {
if (!dst_device_address->AsyncDeviceToDevice(trans::GetRuntimePaddingShape(node, output_index),
address->GetSize(), address->type_id(), address->GetPtr(),
address->format())) {
MS_LOG(EXCEPTION) << "SyncDeviceToDevice failed!";
}
tensor->set_sync_status(kNoNeedSync);

View File

@ -1600,6 +1600,7 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &para
// for example "def f(x,y,z) {return x + y}", parameter z in unused
auto new_parameter = CreateNewParameter(parameter, graph);
graph_inputs->push_back(new_parameter);
graph->FrontBackendlMapAdd(parameter, new_parameter);
MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
continue;
}

View File

@ -1461,9 +1461,9 @@ void ClearResAtexit() {
mindspore::RDR::ResetRecorder();
#endif
session::ExecutorManager::Instance().Clear();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
runtime::GraphScheduler::GetInstance().Clear();
device::DeviceContextManager::GetInstance().ClearDeviceContexts();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
ad::g_k_prims.clear();
ad::ClearKPynativeCellStaticRes();
ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();

View File

@ -180,6 +180,8 @@ void AscendDeviceAddress::BindDevice() const {
if (!ascend_device_context->BindDeviceToCurrentThread()) {
MS_LOG(EXCEPTION) << "BindDeviceToCurrentThread failed.";
}
} else {
MS_LOG(WARNING) << "device name is null.";
}
}
@ -432,9 +434,32 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size
return sync_ok;
}
bool AscendDeviceAddress::SyncDeviceToDevice(const ShapeVector &, size_t size, TypeId type, const void *src_ptr,
bool AscendDeviceAddress::SyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
const std::string &format) const {
MS_LOG(INFO) << "SyncDeviceToDevice, dst(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) {
return true;
}
BindDevice();
if (format_ != format || type_id_ != type) {
MS_LOG(ERROR) << "format or type is different, src(format:" << format << ", type_id:" << TypeIdLabel(type)
<< "), dst(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_);
return false;
}
if (size_ < size) {
MS_LOG(ERROR) << "src size is greater than det size, src size is: " << size << ", dst size is: " << size_;
return false;
}
auto ret_rt_memcpy = rtMemcpy(ptr_, size, src_ptr, size, RT_MEMCPY_DEVICE_TO_DEVICE);
if (ret_rt_memcpy != RT_ERROR_NONE) {
MS_LOG(ERROR) << "SyncDeviceToDevice failed, rtMemcpy mem size [" << size << "], ret [" << ret_rt_memcpy << "]";
return false;
}
return true;
}
bool AscendDeviceAddress::AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
const std::string &format) const {
MS_LOG(INFO) << "AsyncDeviceToDevice, dst(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
<< ", size:" << size_ << "), src(format:" << format << ", type_id:" << TypeIdLabel(type)
<< ", size:" << size << ")";
if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) {

View File

@ -50,6 +50,8 @@ class AscendDeviceAddress : public DeviceAddress {
bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
const std::string &format = "DefaultFormat") const override;
bool AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
const std::string &format) const override;
bool SyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
const std::string &format) const override;
void ClearDeviceMemory() override;

View File

@ -301,6 +301,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
HcclCollectiveGroup::instance().FinalizeCollective();
}
initialized_ = false;
MS_LOG(INFO) << "Ascend finalize end";
}
@ -407,7 +408,10 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, kAscendDevice, device_id);
auto ascend_device_address_ptr =
std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, kAscendDevice, device_id);
ascend_device_address_ptr->set_is_ptr_persisted(true);
return ascend_device_address_ptr;
}
DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
@ -415,8 +419,10 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node_index, kAscendDevice,
device_id);
auto ascend_device_address_ptr = std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id,
node_index, kAscendDevice, device_id);
ascend_device_address_ptr->set_is_ptr_persisted(true);
return ascend_device_address_ptr;
}
bool AscendKernelRuntime::Load(const session::KernelGraph &graph, bool is_task_sink) {

View File

@ -73,7 +73,6 @@ class AscendKernelRuntime : public KernelRuntime {
void *GetModelStream(uint32_t graph_id) const override;
// add for MindRT
void ReleaseDeviceRes() override;
void SetCurrentContext();
protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
@ -90,6 +89,7 @@ class AscendKernelRuntime : public KernelRuntime {
static bool HcclInit();
static bool NeedDestroyHccl();
static bool DestroyHccl();
void SetCurrentContext();
void ClearGraphModelMap();
bool GraphWithEmptyTaskList(const session::KernelGraph &graph) const;

View File

@ -148,6 +148,7 @@ void AiCoreDynamicKernel::AllocateWorkspace() {
workspace_addr_.clear();
for (auto size : workspaces_size_) {
auto device_address_ptr = std::make_shared<AscendDeviceAddress>(nullptr, size, kAscendDevice, device_id);
device_address_ptr->set_is_ptr_persisted(true);
auto device_ptr = runtime_instance->MallocMem(MemType::kDynamicMem, size, device_address_ptr);
if (device_ptr == nullptr) {
MS_LOG(EXCEPTION) << "MallocMem from memory pool failed. Node info :" << cnode->fullname_with_scope();

View File

@ -98,6 +98,8 @@ class DeviceAddress : public mindspore::DeviceSync {
TypeId type_id() const { return type_id_; }
bool from_mem_pool() const { return from_mem_pool_; }
void set_from_mem_pool(bool from_mem_pool) { from_mem_pool_ = from_mem_pool; }
bool is_ptr_persisted() const { return is_ptr_persisted_; }
void set_is_ptr_persisted(bool is_ptr_persisted) { is_ptr_persisted_ = is_ptr_persisted; }
void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; }
virtual void set_status(DeviceAddressStatus status) {}
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
@ -134,6 +136,10 @@ class DeviceAddress : public mindspore::DeviceSync {
ShapeVector host_shape_{};
// {node, out_index}
std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0};
// The device address of the node that owns the device address cannot be updated and replaced.
// application scenario: set to true when the hardware execution mode requires that ptr cannot be changed during
// execution.
bool is_ptr_persisted_{false};
// The key of device context.
std::string device_name_{""};

View File

@ -997,6 +997,7 @@ void KernelAdjust::AssignLoopCtrlTensorMem(const session::KernelGraph &kernel_gr
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
device_address =
std::make_shared<device::ascend::AscendDeviceAddress>(nullptr, size, format, type_id, kAscendDevice, device_id);
device_address->set_is_ptr_persisted(true);
if (runtime_instance->MallocMem(kStaticMem, size, device_address) == nullptr) {
MS_LOG(EXCEPTION) << "Cannot alloc static memory for device loop control parameter " << name

View File

@ -159,6 +159,10 @@ bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_
} else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
// Other device tensor copy to CPU device tensor.
return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr());
} else if (dst_device_tensor->DeviceType() == src_device_tensor->DeviceType()) {
return dst_device_tensor->SyncDeviceToDevice(ShapeVector(), src_device_tensor->GetSize(),
src_device_tensor->type_id(), src_device_tensor->GetPtr(),
src_device_tensor->format());
} else {
MS_LOG(ERROR) << "Invalid device type, src device type: " << src_device_tensor->DeviceType()
<< ", dst device type: " << dst_device_tensor->DeviceType();

View File

@ -239,7 +239,8 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vect
auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
MS_EXCEPTION_IF_NULL(device_address);
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType()) &&
!device_address->is_ptr_persisted()) {
AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
}
}

View File

@ -348,6 +348,9 @@ GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const Device
auto graph_id = CompileGraphImpl(root_graph, device_context);
// dump all graphs.
device_context->DumpAllGraphs(all_graphs);
// Cache the backend graph output nodes to front nodes with output index.
auto output = func_graph->output();
MS_EXCEPTION_IF_NULL(output);

View File

@ -19,7 +19,6 @@
#include <set>
#include "backend/optimizer/ascend/ascend_backend_optimization.h"
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
#include "backend/session/ascend_auto_monad.h"
#include "utils/context/graph_kernel_flags.h"
#include "runtime/device/ascend/kernel_select_ascend.h"
#include "runtime/device/kernel_adjust.h"
@ -34,6 +33,17 @@
#include "debug/dump_proto.h"
#include "debug/data_dump/e2e_dump.h"
#endif
#ifdef ENABLE_DEBUGGER
#include "debug/tensor_load.h"
#include "debug/debugger/proto_exporter.h"
#else
#include "debug/debugger/proto_exporter_stub.h"
#endif
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/running_data_recorder.h"
#include "debug/rdr/recorder_manager.h"
#include "debug/rdr/graph_recorder.h"
#endif
namespace mindspore {
namespace device {
@ -69,11 +79,47 @@ void Dump(const KernelGraphPtr &graph, uint32_t rank_id) {
}
#endif
void AscendDeviceContext::DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const {
#ifdef ENABLE_DUMP_IR
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto &json_parser = DumpJsonParser::GetInstance();
json_parser.Parse();
if (!save_graphs && !json_parser.e2e_dump_enabled() && !json_parser.async_dump_enabled() &&
!mindspore::RecorderManager::Instance().RdrEnable()) {
return;
}
for (auto &graph : all_graphs) {
MS_EXCEPTION_IF_NULL(graph);
std::string name = "graph_build." + std::to_string(graph->graph_id());
DumpGraphParams dump_params = {true, static_cast<int>(kWholeStack)};
(void)mindspore::RDR::RecordAnfGraph(SUBMODULE_ID, name, graph, dump_params, ".ir;.pb");
if (save_graphs) {
std::string file_name = "graph_build_" + std::to_string(graph->graph_id()) + ".ir";
DumpIR(file_name, graph, true, kWholeStack);
DumpIRProto(graph, "vm_build_" + std::to_string(graph->graph_id()));
DumpIR("trace_code_graph", graph, true, kWholeStack);
}
std::string final_graph = "trace_code_graph_" + std::to_string(graph->graph_id());
if (json_parser.e2e_dump_enabled() || json_parser.async_dump_enabled()) {
std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id_);
std::string target_dir = root_dir + "/graphs";
std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
DumpIRProtoWithSrcInfo(graph, final_graph, target_dir, kDebugWholeStack);
DumpIR("trace_code_graph", graph, true, kWholeStack, ir_file_path);
DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(graph->graph_id()) + ".csv", root_dir,
graph->execution_order());
}
}
#endif
}
void AscendDeviceContext::Initialize() {
MS_LOG(INFO) << "Status record: Enter Initialize...";
if (initialized_) {
MS_EXCEPTION_IF_NULL(runtime_instance_);
runtime_instance_->SetCurrentContext();
runtime_instance_->SetContext();
return;
}
@ -109,8 +155,7 @@ void AscendDeviceContext::Destroy() {
}
MS_LOG(INFO) << "Status record: Destroy start...";
rank_id_ = 0;
if (runtime_instance_ != nullptr) {
runtime_instance_->ReleaseDeviceRes();
if (runtime_instance_) {
runtime_instance_ = nullptr;
}
initialized_ = false;
@ -181,9 +226,44 @@ void AscendDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph)
MS_LOG(INFO) << "PreprocessBeforeRunGraph success.";
return;
}
AssignOutputNopNodeDeviceAddress(graph);
MS_LOG(INFO) << "PreprocessBeforeRunGraph success.";
}
void AscendDeviceContext::AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
auto outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
for (auto output : outputs) {
if (!output->isa<CNode>() || !AnfAlgo::IsRealKernel(output)) {
continue;
}
if (!opt::IsNopNode(output)) {
continue;
}
size_t input_num = AnfAlgo::GetInputTensorNum(output);
if (input_num != 1) {
MS_LOG(WARNING) << "The input number of nop node :" << output->fullname_with_scope() << " is " << input_num
<< ", not equal 1";
continue;
}
auto real_input_index = AnfAlgo::GetRealInputIndex(output, 0);
auto pre_node_out_device_address = AnfAlgo::GetPrevNodeOutputAddr(output, real_input_index);
MS_EXCEPTION_IF_NULL(pre_node_out_device_address);
auto ptr = pre_node_out_device_address->GetPtr();
auto size = pre_node_out_device_address->GetSize();
std::string output_format = AnfAlgo::GetOutputFormat(output, 0);
auto output_type = AnfAlgo::GetOutputDeviceDataType(output, 0);
auto device_address = CreateDeviceAddress(const_cast<void *>(ptr), size, output_format, output_type);
AnfAlgo::SetOutputAddr(device_address, 0, output.get());
AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(false), output);
MS_LOG(INFO) << "Assign device address to output nop node " << output->fullname_with_scope();
}
}
void AscendDeviceContext::AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const {
MS_EXCEPTION_IF_NULL(runtime_instance_);
runtime_instance_->ClearGlobalIdleMem();
@ -225,7 +305,7 @@ void AscendDeviceContext::LoadModel(const NotNull<KernelGraphPtr> &root_graph) c
bool AscendDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size) const {
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(runtime_instance_);
runtime_instance_->SetCurrentContext();
runtime_instance_->SetContext();
auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
if (!device_ptr) {
return false;
@ -249,7 +329,7 @@ void AscendDeviceContext::FreeMemory(DeviceAddress *const &address) const {
bool AscendDeviceContext::AllocateContinuousMemory(const std::vector<DeviceAddressPtr> &addr_list, size_t total_size,
const std::vector<size_t> &size_list) const {
MS_EXCEPTION_IF_NULL(runtime_instance_);
runtime_instance_->SetCurrentContext();
runtime_instance_->SetContext();
return mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list);
}
@ -296,7 +376,7 @@ bool AscendDeviceContext::LaunchGraph(const KernelGraphPtr &graph) const {
MS_LOG(INFO) << "Status record: start launch graph. graph id: " << graph->graph_id();
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(runtime_instance_);
runtime_instance_->SetCurrentContext();
runtime_instance_->SetContext();
device::KernelAdjust::GetInstance().LoadDeviceLoopCtrlParameters(graph);
auto ret = ExecuteGraph(graph);
MS_LOG(INFO) << "Status record: end launch graph. graph id: " << graph->graph_id();
@ -336,7 +416,7 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<Addr
}
bool AscendDeviceContext::BindDeviceToCurrentThread() const {
runtime_instance_->SetCurrentContext();
runtime_instance_->SetContext();
return true;
}

View File

@ -126,6 +126,9 @@ class AscendDeviceContext : public DeviceContext {
// set rt_context_ to this thread to control device
bool BindDeviceToCurrentThread() const;
// dump all graphs.
void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const override;
private:
// Graph loader interface
void AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const;
@ -150,6 +153,7 @@ class AscendDeviceContext : public DeviceContext {
mutable std::set<KernelGraphPtr> memo_;
// Using node to get it's atomics
mutable std::map<CNodePtr, std::vector<CNodePtr>> node_atomics_;
void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph) const;
};
} // namespace ascend
} // namespace device

View File

@ -51,6 +51,8 @@ void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
OptimizeGraphWithDeviceInfo(graph);
OptimizeExecutionOrder(graph);
PostOptimization(graph);
// must clear memo_ which holds kernelgraph after using AscendGraphOptimization class.
memo_.clear();
MS_LOG(INFO) << "Status record: end optimize graph. graph id: " << graph->graph_id();
}
@ -71,6 +73,9 @@ void AscendGraphOptimization::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &
MS_EXCEPTION_IF_NULL(graph);
memo_.clear();
HardWareOptimization(graph);
// copy child graph ref output map to father graph ref output map
memo_.clear();
UpdateRefOutputMap(graph);
AnfAlgo::InsertMakeTupleForOutput(NOT_NULL(graph));
}
@ -112,9 +117,6 @@ void AscendGraphOptimization::OptimizeExecutionOrder(const KernelGraphPtr &graph
void AscendGraphOptimization::PostOptimization(const KernelGraphPtr &graph) {
MS_LOG(INFO) << "Status record: start post optimization. graph id: " << graph->graph_id();
// copy child graph ref output map to father graph ref output map
memo_.clear();
UpdateRefOutputMap(graph);
graph->SetInputNodes();
graph->SetOptimizerFlag();
MS_LOG(INFO) << "Status record: end post optimization. graph id: " << graph->graph_id();

View File

@ -152,6 +152,10 @@ class DeviceContext {
// Return collective communication object for caller to access
CollectiveCommunicationLibPtr collective_comm_lib() const { return collective_comm_lib_; }
// TODO(jiaorui): will be delete
// Dump all graphs.
virtual void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const {}
protected:
DeviceContextKey device_context_key_;
CollectiveCommunicationLibPtr collective_comm_lib_;

View File

@ -43,6 +43,11 @@ class DeviceSync {
const std::string &format) const {
return true;
}
virtual bool AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
const std::string &format) const {
return true;
}
virtual void *GetMutablePtr() const = 0;
virtual void ClearDeviceMemory() = 0;

View File

@ -95,7 +95,7 @@ fi
while read line; do
if [ -f "${line}" ]; then
${CLANG_FORMAT} -i "${line}"
"${CLANG_FORMAT}" -i "${line}"
fi
done < "${FMT_FILE_LIST}"