Merge pull request !24237 from kisnwang/clean-code
i-robot 2021-09-28 06:32:33 +00:00 committed by Gitee
commit 9ef686f046
11 changed files with 49 additions and 30 deletions

View File

@@ -14,7 +14,6 @@
* limitations under the License.
*/
#include <algorithm>
#include "backend/session/ascend_inference_session.h"
#include "ir/tensor.h"
#include "ir/anf.h"

View File

@@ -233,7 +233,7 @@ size_t AscendMemoryManager::GetAvailableMemSize() {
return available_mem_size;
}
-void AscendMemoryManager::SwapIn(void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
+void AscendMemoryManager::SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
if (stream == nullptr) {
auto ret_rt_memcpy = rtMemcpy(device_ptr, mem_size, host_ptr, mem_size, RT_MEMCPY_HOST_TO_DEVICE);
if (ret_rt_memcpy != RT_ERROR_NONE) {
@@ -250,7 +250,7 @@ void AscendMemoryManager::SwapIn(void *host_ptr, void *device_ptr, size_t mem_si
}
}
-void AscendMemoryManager::SwapOut(void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
+void AscendMemoryManager::SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
if (stream == nullptr) {
auto ret_rt_memcpy = rtMemcpy(host_ptr, mem_size, device_ptr, mem_size, RT_MEMCPY_DEVICE_TO_HOST);
if (ret_rt_memcpy != RT_ERROR_NONE) {
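The const added to SwapIn/SwapOut only qualifies the source side of each copy: SwapIn reads from host memory and writes to the device, SwapOut reads from the device and writes to host memory, so the pointer that is only read can be const. A minimal, self-contained sketch of the same signatures, with std::memcpy standing in for rtMemcpy and all names purely illustrative:

#include <cstddef>
#include <cstring>
#include <vector>

// The source buffer is only read, so its pointer can be const;
// the destination buffer is written, so it stays non-const.
void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size) {
  std::memcpy(device_ptr, host_ptr, mem_size);  // host -> "device"
}

void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size) {
  std::memcpy(host_ptr, device_ptr, mem_size);  // "device" -> host
}

int main() {
  std::vector<char> host(16, 'h');
  std::vector<char> device(16, 'd');
  SwapIn(host.data(), device.data(), host.size());
  SwapOut(device.data(), host.data(), device.size());
  return 0;
}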

View File

@@ -42,8 +42,8 @@ class AscendMemoryManager : public MemoryManager {
return AscendMemoryPool::GetInstance().AllocContinuousTensorMem(total_size, size_list);
}
-void SwapIn(void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override;
-void SwapOut(void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override;
+void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override;
+void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override;
size_t GetAvailableMemSize() override;
protected:

View File

@@ -80,7 +80,7 @@ bool MPICollective::CreateCommGroup(const std::string &name, const std::vector<u
}
CHECK_RET(rtSetDevice(local_rank_id_), RT_ERROR_NONE, "Call rtSetDevice error.");
HcclRootInfo rootInfo;
-if (static_cast<size_t>(rank_id_) == ranks[0]) {
+if (static_cast<unsigned int>(rank_id_) == ranks[0]) {
CHECK_RET(HcclGetRootInfo(&rootInfo), ::HcclResult::HCCL_SUCCESS, "HcclGetRootInfo failed.");
}
MPI_Group mpi_group = MPI_GROUP_NULL;
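Casting rank_id_ to unsigned int rather than size_t matches the element type of the ranks argument (a std::vector<unsigned int>, judging by the truncated signature in the hunk header), so the comparison is between operands of the same type and width. A small illustration of the pattern, using hypothetical local names:

#include <iostream>
#include <vector>

int main() {
  int rank_id = 0;                           // signed, like rank_id_
  std::vector<unsigned int> ranks = {0, 1};  // group ranks, element type unsigned int
  // Cast to the container's element type so the == compares like with like.
  if (static_cast<unsigned int>(rank_id) == ranks[0]) {
    std::cout << "this process is the root of the group" << std::endl;
  }
  return 0;
}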

View File

@@ -16,7 +16,6 @@
#include "runtime/device/bucket.h"
#include <memory>
#include "runtime/device/kernel_runtime_manager.h"
#include "frontend/parallel/context.h"
#include "utils/profile.h"

View File

@@ -38,6 +38,7 @@ namespace {
constexpr auto kGradients = "Gradients";
constexpr auto kSpecifyParameter = "accu_status";
size_t kNPUShape = 8;
+size_t kLastHandleDiff = 2;
} // namespace
namespace mindspore {
namespace device {
@@ -1093,7 +1094,7 @@ void KernelAdjust::InsertOverflowCheckOperations(const std::shared_ptr<session::
new_execution_order.push_back(npu_get_cnode);
new_execution_order.push_back(assign_add_cnode);
}
-if (i == execution_order.size() - 2) {
+if (i == execution_order.size() - kLastHandleDiff) {
new_execution_order.push_back(execution_order[i + 1]);
if (next_full_name.find(kGradients) != std::string::npos) {
auto npu_get_cnode = CreateNPUGetFloatStatus(kernel_graph_ptr, npu_alloc_cnode);
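kLastHandleDiff gives a name to the offset 2 used to spot the second-to-last kernel in the execution order. A toy sketch of the same loop shape, with illustrative data only:

#include <cstddef>
#include <iostream>
#include <vector>

constexpr size_t kLastHandleDiff = 2;  // distance from the end that needs special handling

int main() {
  std::vector<int> execution_order = {10, 20, 30, 40};
  for (size_t i = 0; i < execution_order.size(); ++i) {
    if (i == execution_order.size() - kLastHandleDiff) {
      // At the second-to-last position, also take care of the final element.
      std::cout << "append trailing kernel " << execution_order[i + 1] << std::endl;
    }
  }
  return 0;
}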

View File

@@ -45,6 +45,7 @@ namespace device {
constexpr float kMaxMemReuseFactor = 0.8;
constexpr float kMinMemReuseFactor = 0.5;
constexpr float kRetryFactor = 0.1;
+constexpr size_t kAtomicCleanInputSize = 2;
namespace {
std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph &graph) {
auto graph_inputs = graph.inputs();
@@ -68,11 +69,21 @@ std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph &graph) {
}
} // namespace
constexpr size_t kMinInputSize = 2;
-KernelRuntime::~KernelRuntime() {}
+KernelRuntime::~KernelRuntime() {
+stream_ = nullptr;
+independent_stream_ = nullptr;
+communication_stream_ = nullptr;
+}
-bool KernelRuntime::Load(const session::KernelGraph &graph, bool is_task_sink) { return true; }
+bool KernelRuntime::Load(const session::KernelGraph &, bool) {
+MS_LOG(INFO) << "Call default load.";
+return true;
+}
-bool KernelRuntime::LoadData(const session::KernelGraph &) { return false; }
+bool KernelRuntime::LoadData(const session::KernelGraph &) {
+MS_LOG(INFO) << "Call default load data.";
+return false;
+}
bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
MS_EXCEPTION_IF_NULL(kernel);
@@ -789,7 +800,10 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
output_ptr += align_size_list[j];
}
}
-bool KernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) { return false; }
+bool KernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) {
+MS_EXCEPTION_IF_NULL(node);
+return false;
+}
DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) const {
MS_EXCEPTION_IF_NULL(anf_node);
@@ -1165,8 +1179,7 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList
const std::shared_ptr<MemScheduler> &mem_scheduler) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(kernel_inputs);
-const size_t kNodeInputSize = 2;
-if (cnode->inputs().size() != kNodeInputSize) {
+if (cnode->inputs().size() != kAtomicCleanInputSize) {
MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2.";
}
MS_EXCEPTION_IF_NULL(cnode->inputs()[1]);
@@ -1341,7 +1354,9 @@ void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &m
tensor->set_sync_status(kNeedSyncHostToDevice);
continue;
}
-SyncStream();
+if (!SyncStream()) {
+MS_LOG(ERROR) << "SyncStream failed";
+}
auto origin_ptr = device_address->ptr_;
if (origin_ptr == nullptr) {
device_address->ptr_ = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size_);
@@ -1372,7 +1387,6 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
if (!input_node->isa<Parameter>()) {
continue;
}
-auto input_param = input_node->cast<ParameterPtr>();
if (AnfAlgo::OutputAddrExist(input_node, 0)) {
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
MS_EXCEPTION_IF_NULL(tensor);
@@ -1382,7 +1396,8 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
tensor->data_sync(false);
priority = kMemPriorityLow;
}
-mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor->data().nbytes(), priority);
+auto tensor_size = LongToSize(tensor->data().nbytes());
+mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor_size, priority);
}
}
}
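Introducing tensor_size makes the conversion from the signed byte count returned by tensor->data().nbytes() to the size_t expected by the scheduler explicit instead of implicit. A self-contained sketch of a LongToSize-style helper, assuming (as is conventional for such helpers, not verified against the project's implementation) that negative values are rejected:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>

// Illustrative conversion helper: fail loudly on a negative byte count
// instead of silently wrapping around.
size_t LongToSizeChecked(int64_t value) {
  if (value < 0) {
    throw std::runtime_error("negative value cannot be converted to size_t");
  }
  return static_cast<size_t>(value);
}

int main() {
  int64_t nbytes = 4096;  // e.g. a tensor's byte count as a signed integer
  size_t tensor_size = LongToSizeChecked(nbytes);
  std::cout << tensor_size << std::endl;
  return 0;
}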

View File

@@ -32,7 +32,7 @@ size_t MemoryManager::GetCommonAlignSize(size_t input_size) {
}
size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) {
-return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + 2 * kMemAlignSize;
+return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + kTwiceMemAlignSize;
}
void MemoryManager::MallocSomasDynamicMem(const session::KernelGraph &graph) {
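GetCommunicationAlignSize rounds input_size up to the next multiple of kMemAlignSize and then adds two more aligned blocks, which is what the new kTwiceMemAlignSize constant (kMemAlignSize << 1, declared in the header hunk further down) expresses. A quick check of the arithmetic with the 512-byte alignment from that header:

#include <cstdint>
#include <iostream>

constexpr uint64_t kMemAlignSize = 512;
constexpr uint64_t kTwiceMemAlignSize = kMemAlignSize << 1;  // 1024

// Same expression as the return statement above: round up, then add two blocks.
uint64_t CommunicationAlignSize(uint64_t input_size) {
  return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + kTwiceMemAlignSize;
}

int main() {
  std::cout << CommunicationAlignSize(600) << std::endl;   // 1024 + 1024 = 2048
  std::cout << CommunicationAlignSize(1024) << std::endl;  // already aligned: 1024 + 1024 = 2048
  return 0;
}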
@@ -136,7 +136,10 @@ uint8_t *MemoryManager::MallocMem(MemType type, size_t size, const DeviceAddress
return ptr;
}
-uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) { return nullptr; }
+uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
+MS_LOG(INFO) << "Call default dynamic malloc " << size << " v " << communication_mem;
+return nullptr;
+}
bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t size) {
MS_EXCEPTION_IF_NULL(address);

View File

@@ -27,8 +27,9 @@
namespace mindspore {
namespace device {
enum MemType { kStaticMem, kDynamicMem, kSomasReuseDynamicMem };
-const int kGetAllOuts = -1;
-const uint64_t kMemAlignSize = 512;
+constexpr int kGetAllOuts = -1;
+constexpr uint64_t kMemAlignSize = 512;
+constexpr uint64_t kTwiceMemAlignSize = kMemAlignSize << 1;
using SomasPtr = mindspore::somas::SomasPtr;
class MemoryManager : public MemHandler {
@@ -95,8 +96,12 @@ class MemoryManager : public MemHandler {
auto mem_size = iter->second->size();
cached_host_mem_[mem_size].emplace(iter->first);
}
-void SwapIn(void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {}
-void SwapOut(void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {}
+void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {
+MS_LOG(INFO) << "Call default swap in " << host_ptr << "," << device_ptr << "," << mem_size << "," << stream;
+}
+void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {
+MS_LOG(INFO) << "Call default swap out " << host_ptr << "," << device_ptr << "," << mem_size << "," << stream;
+}
size_t GetAvailableMemSize() override {
MS_LOG(ERROR) << "Return default 0 mem size!";
return 0;

View File

@@ -15,12 +15,9 @@
*/
#include "runtime/device/memory_scheduler.h"
#include <map>
#include <set>
#include <memory>
#include <utility>
#include <algorithm>
#include "utils/log_adapter.h"
namespace mindspore {
namespace device {
void MemScheduler::Clear() {
@@ -277,7 +274,7 @@ void MemScheduler::GenEventSpan() {
MS_EXCEPTION_IF_NULL(event);
auto span = event->index - last_index;
if (span > 1) {
-(void)event_span_.insert(std::pair<size_t, std::shared_ptr<Event>>(span, event));
+(void)event_span_.emplace(std::pair<size_t, std::shared_ptr<Event>>(span, event));
}
last_index = event->index;
}
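event_span_ appears to be an ordered multimap from span length to event (an assumption based on this usage); with a multimap, emplace forwards its arguments to the element constructor, so the call could also drop the explicit std::pair, although as written the switch from insert to emplace is mainly stylistic. A standalone illustration under that assumption:

#include <cstddef>
#include <iostream>
#include <map>
#include <memory>

struct Event {
  size_t index = 0;
};

int main() {
  // Assumed shape of event_span_: span length -> event, duplicate keys allowed.
  std::multimap<size_t, std::shared_ptr<Event>> event_span;
  auto event = std::make_shared<Event>();
  event->index = 7;
  size_t span = 3;
  // emplace constructs the key/value pair in place from its arguments.
  (void)event_span.emplace(span, event);
  std::cout << event_span.size() << std::endl;
  return 0;
}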

View File

@@ -31,8 +31,8 @@ class MemHandler {
virtual void FreeDevice(void *ptr) = 0;
virtual void *MallocHost(size_t mem_size) = 0;
virtual void FreeHost(void *ptr) = 0;
-virtual void SwapIn(void *host_ptr, void *device_ptr, size_t mem_size, void *stream) = 0;
-virtual void SwapOut(void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
+virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) = 0;
+virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
};
enum MemPriority { kMemPriorityLow, kMemPriorityMedium, kMemPriorityHigh };