forked from mindspore-Ecosystem/mindspore
actor runtime code review modify
This commit is contained in:
parent
fba1dd8f2f
commit
57a14f55e8
|
@ -69,11 +69,16 @@ class DeviceAddress : public mindspore::DeviceSync {
|
|||
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
|
||||
virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
|
||||
void *GetMutablePtr() const override { return ptr_; }
|
||||
|
||||
// The related interface of reference count operation.
|
||||
void set_original_ref_count(size_t original_ref_count) { original_ref_count_ = original_ref_count; }
|
||||
size_t original_ref_count() const { return original_ref_count_; }
|
||||
void set_ref_count(size_t ref_count) { ref_count_ = ref_count; }
|
||||
void IncreaseRefCount() { ref_count_++; }
|
||||
void DecreaseRefCountUsed() { ref_count_dynamic_used_--; }
|
||||
void ResetRefCountUsed() { ref_count_dynamic_used_ = ref_count_; }
|
||||
size_t ref_count_dynamic_used() const { return ref_count_dynamic_used_; }
|
||||
size_t ref_count() const { return ref_count_; }
|
||||
void IncreaseOriginalRefCount() { original_ref_count_++; }
|
||||
void DecreaseRefCount() { ref_count_--; }
|
||||
void ResetRefCount() { ref_count_ = original_ref_count_; }
|
||||
|
||||
virtual bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape,
|
||||
TypeId host_type, bool trans_flag) const {
|
||||
return true;
|
||||
|
@ -91,9 +96,9 @@ class DeviceAddress : public mindspore::DeviceSync {
|
|||
void set_ptr(void *ptr) { ptr_ = ptr; }
|
||||
void *ptr_{nullptr};
|
||||
size_t size_{0};
|
||||
size_t original_ref_count_{1};
|
||||
// It will be decreased in the running, and reset by original_ref_count_ when it is zero.
|
||||
size_t ref_count_{1};
|
||||
// It will be decreased in the running, and reset by ref_count_ when it is zero.
|
||||
size_t ref_count_dynamic_used_{1};
|
||||
string format_{"DefaultFormat"};
|
||||
TypeId type_id_{kNumberTypeFloat16};
|
||||
bool from_mem_pool_{false};
|
||||
|
|
|
@ -36,7 +36,7 @@ using mindspore::device::DeviceContext;
|
|||
|
||||
// The data source actor is used to fetch data from data source and process them into device tensors,
|
||||
// and then send them to kernel actor. The processing flow is FetchData -> FillDataBuffer -> AllocateMemory
|
||||
// -> OnMemoryAllocFinish -> SendOutput -> FreeMemory.
|
||||
// -> OnMemoryAllocFinish -> FreeMemory -> SendOutput.
|
||||
class DataSourceActor : public MemoryInterfaceActor {
|
||||
public:
|
||||
DataSourceActor(std::string name, size_t buffer_capacity, const DeviceContext *device_context,
|
||||
|
|
|
@ -37,7 +37,7 @@ using mindspore::kernel::AddressPtr;
|
|||
|
||||
// The kernel actor is used to receive the device tensors and control info to luanch kernel.
|
||||
// The processing flow is RunOpData/RunOpControl -> CheckLaunchCondition -> AllocateMemory
|
||||
// -> OnMemoryAllocFinish -> LaunchKernel -> SendOutput -> FreeMemory.
|
||||
// -> OnMemoryAllocFinish -> LaunchKernel -> FreeMemory -> SendOutput.
|
||||
class KernelActor : public MemoryInterfaceActor {
|
||||
public:
|
||||
KernelActor(std::string name, CNodePtr kernel, const DeviceContext *device_context, const AID memory_manager_aid)
|
||||
|
|
|
@ -50,13 +50,13 @@ void MemoryManagerActor::FreeMemory(std::vector<DeviceTensor *> free_list, const
|
|||
for (auto &device_tensor : free_list) {
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
// The reference count is decremented to zero to free memory, and reset to the original count.
|
||||
device_tensor->DecreaseRefCountUsed();
|
||||
if (device_tensor->ref_count_dynamic_used() == 0) {
|
||||
device_tensor->DecreaseRefCount();
|
||||
if (device_tensor->ref_count() == 0) {
|
||||
// Free memory through the device context.
|
||||
if (device_tensor->GetPtr() != nullptr) {
|
||||
device_context->FreeMemory(device_tensor);
|
||||
}
|
||||
device_tensor->ResetRefCountUsed();
|
||||
device_tensor->ResetRefCount();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -78,10 +78,8 @@ void CreateParameterDeviceAddress(const DeviceContext *device_context, const Ker
|
|||
auto output_size = AnfAlgo::GetOutputTensorNum(item);
|
||||
for (size_t index = 0; index < output_size; index++) {
|
||||
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
|
||||
// if graph output is a weight and doesn't link to any cnode, it's data type will be unknown
|
||||
if (output_type_id == kTypeUnknown) {
|
||||
MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
|
||||
continue;
|
||||
output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
|
||||
}
|
||||
|
||||
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
|
||||
|
@ -212,13 +210,9 @@ GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePt
|
|||
GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(device_context_);
|
||||
// Optimization pass which is irrelevant to device type or format.
|
||||
device_context_->OptimizeGraphWithoutDeviceInfo(graph);
|
||||
|
||||
device_context_->SetOperatorInfo(graph->execution_order());
|
||||
|
||||
// Optimization pass which is relevant to device type or format.
|
||||
device_context_->OptimizeGraphWithDeviceInfo(graph);
|
||||
// Execute optimization pass.
|
||||
device_context_->OptimizeGraph(graph);
|
||||
|
||||
// Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
|
||||
// 'KernelMod' is real executive object of kernel.
|
||||
|
@ -248,9 +242,8 @@ GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const Graph
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(device_context_);
|
||||
device_context_->SetOperatorInfo(graph->execution_order());
|
||||
|
||||
device_context_->OptimizeSingleOpGraph(graph);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
session_->RunOpHideNopNode(graph);
|
||||
session_->RunOpRemoveNopNode(graph);
|
||||
|
|
|
@ -99,8 +99,8 @@ void UpdateRefCount(const AnfNodePtr &node, size_t output_idx) {
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_idx);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
device_tensor->IncreaseRefCount();
|
||||
device_tensor->ResetRefCountUsed();
|
||||
device_tensor->IncreaseOriginalRefCount();
|
||||
device_tensor->ResetRefCount();
|
||||
}
|
||||
|
||||
// The branch processing of PrepareDataForValueNode that value type is tensor.
|
||||
|
@ -252,8 +252,8 @@ BaseRef CreateOutputTensor(const session::KernelWithIndex &node_output_pair, con
|
|||
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_index);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
tensor->set_device_address(device_tensor);
|
||||
device_tensor->set_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCountUsed();
|
||||
device_tensor->set_original_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCount();
|
||||
return tensor;
|
||||
}
|
||||
}
|
||||
|
@ -307,8 +307,8 @@ void AllocateContinuousMemoryForInput(const AnfNodePtr &kernel, const DeviceCont
|
|||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
// In the scene of communication op and computing op parallel multi stream, the input address of communication op
|
||||
// can't be reused, so set the max reference count.
|
||||
device_tensor->set_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCountUsed();
|
||||
device_tensor->set_original_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCount();
|
||||
|
||||
if (device_tensor->GetPtr() == nullptr) {
|
||||
is_need_alloc_memory = true;
|
||||
|
@ -341,8 +341,8 @@ void AllocateContinuousMemoryForOutput(const AnfNodePtr &kernel, const DeviceCon
|
|||
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
// One time application for continuous memory, so set the max reference count.
|
||||
device_tensor->set_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCountUsed();
|
||||
device_tensor->set_original_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCount();
|
||||
|
||||
if (device_tensor->GetPtr() == nullptr) {
|
||||
is_need_alloc_memory = true;
|
||||
|
@ -925,8 +925,8 @@ void GraphScheduler::PersistDeviceTensor(const KernelGraphPtr &graph) {
|
|||
}
|
||||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0);
|
||||
DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor);
|
||||
device_tensor->set_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCountUsed();
|
||||
device_tensor->set_original_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCount();
|
||||
}
|
||||
|
||||
for (auto &input_node : graph->input_nodes()) {
|
||||
|
@ -935,8 +935,8 @@ void GraphScheduler::PersistDeviceTensor(const KernelGraphPtr &graph) {
|
|||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor);
|
||||
device_tensor->set_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCountUsed();
|
||||
device_tensor->set_original_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCount();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1008,7 +1008,7 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
|
|||
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_kernel, i, false);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
|
||||
<< "\tref_count:" << device_tensor->ref_count_dynamic_used() << "\n ";
|
||||
<< "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
|
||||
}
|
||||
} else if (actor_name.find("_HostQueueDataSourceActor") != string::npos) {
|
||||
// Dump the member info of host queue data source actor.
|
||||
|
@ -1021,7 +1021,7 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
|
|||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
ofs << "\t\t\tnode_order_number:" << i << "\tnode_name:" << data_node->fullname_with_scope()
|
||||
<< "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
|
||||
<< "\tref_count:" << device_tensor->ref_count_dynamic_used() << "\n ";
|
||||
<< "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1065,7 +1065,7 @@ void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &of
|
|||
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
|
||||
<< "\tref_count:" << device_tensor->ref_count_dynamic_used() << "\n ";
|
||||
<< "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
|
||||
}
|
||||
|
||||
ofs << "\t\tdevice_tensor_stores:" << actor->device_tensor_store_keys_.size() << "\n ";
|
||||
|
|
|
@ -55,10 +55,11 @@ DeviceAddressPtr CPUDeviceContext::CreateDeviceAddress(void *device_ptr, size_t
|
|||
return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id);
|
||||
}
|
||||
|
||||
void CPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {
|
||||
void CPUDeviceContext::OptimizeGraph(const KernelGraphPtr &graph) const {
|
||||
// Update Graph Dynamic Shape Attr.
|
||||
UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
|
||||
|
||||
SetOperatorInfo(graph->execution_order());
|
||||
OptimizeGraphImpl(graph);
|
||||
|
||||
// Remove reorder after PS feature finish adapting push/pull in auto_monad.
|
||||
|
@ -67,7 +68,11 @@ void CPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &grap
|
|||
graph->set_execution_order(execution_order);
|
||||
}
|
||||
|
||||
void CPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const { OptimizeGraphImpl(graph); }
|
||||
void CPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
SetOperatorInfo(graph->execution_order());
|
||||
OptimizeGraphImpl(graph);
|
||||
}
|
||||
|
||||
void CPUDeviceContext::OptimizeGraphImpl(const KernelGraphPtr &graph) const {
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
|
|
|
@ -41,7 +41,7 @@ class CPUDeviceContext : public DeviceContext {
|
|||
TypeId type_id) const override;
|
||||
DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kCPU; }
|
||||
|
||||
void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override;
|
||||
void OptimizeGraph(const KernelGraphPtr &graph) const override;
|
||||
void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override;
|
||||
|
||||
void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override;
|
||||
|
|
|
@ -70,11 +70,8 @@ class DeviceContext {
|
|||
// Get device address type according different device type, such GPU, Ascend.
|
||||
virtual DeviceAddressType GetDeviceAddressType() const = 0;
|
||||
|
||||
// The two functions below will be merged to one in the future.
|
||||
// General graph optimezer ignore device data type and format.
|
||||
virtual void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {}
|
||||
// Optimize the kernel graph according to device data type and format.
|
||||
virtual void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const {}
|
||||
// Optimize the kernel graph for graph mode.
|
||||
virtual void OptimizeGraph(const KernelGraphPtr &graph) const {}
|
||||
|
||||
// Optimize the single operator graph for PyNative mode.
|
||||
virtual void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const {}
|
||||
|
|
|
@ -165,6 +165,17 @@ DeviceAddressPtr GPUDeviceContext::CreateDeviceAddress(void *device_ptr, size_t
|
|||
return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id);
|
||||
}
|
||||
|
||||
void GPUDeviceContext::OptimizeGraph(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// Optimization pass which is irrelevant to device type or format.
|
||||
OptimizeGraphWithoutDeviceInfo(graph);
|
||||
|
||||
SetOperatorInfo(graph->execution_order());
|
||||
|
||||
// Optimization pass which is relevant to device type or format.
|
||||
OptimizeGraphWithDeviceInfo(graph);
|
||||
}
|
||||
|
||||
void GPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// Operator fusion optimization.
|
||||
|
@ -240,6 +251,9 @@ void GPUDeviceContext::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr>
|
|||
}
|
||||
|
||||
void GPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
SetOperatorInfo(graph->execution_order());
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));
|
||||
|
|
|
@ -48,11 +48,8 @@ class GPUDeviceContext : public DeviceContext {
|
|||
TypeId type_id) const override;
|
||||
DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kGPU; }
|
||||
|
||||
// General graph optimezer ignore device data type and format.
|
||||
void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override;
|
||||
// Optimize the kernel graph according to device type, such format transform.
|
||||
void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const override;
|
||||
|
||||
// Optimize the kernel graph for graph mode.
|
||||
void OptimizeGraph(const KernelGraphPtr &graph) const override;
|
||||
// Optimize the single operator graph for PyNative mode.
|
||||
void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override;
|
||||
|
||||
|
@ -67,6 +64,11 @@ class GPUDeviceContext : public DeviceContext {
|
|||
DISABLE_COPY_AND_ASSIGN(GPUDeviceContext);
|
||||
bool InitDevice();
|
||||
|
||||
// General graph optimezer ignore device data type and format.
|
||||
void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const;
|
||||
// Optimize the kernel graph according to device type, such format transform.
|
||||
void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const;
|
||||
|
||||
// Operator fusion optimization.
|
||||
void FuseOperators(const KernelGraphPtr &graph) const;
|
||||
|
||||
|
|
Loading…
Reference in New Issue