forked from mindspore-Ecosystem/mindspore
fix review code
commit 2a8f0a75be (parent 8867c67d61)
@@ -52,6 +52,38 @@ namespace mindspore {
 namespace device {
 namespace ascend {
 static const size_t PRAMATER_OUTPUT_INDEX = 0;
+namespace {
+std::string GetRankId() {
+  std::string rank_id_str;
+#ifdef ENABLE_MPI
+  auto mpi_config_ptr = MpiConfig::GetInstance();
+  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
+  if (mpi_config_ptr->enable_mpi()) {
+    int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
+    const char *offset = std::getenv("RANK_OFFSET");
+    if (offset != nullptr) {
+      try {
+        int rank_offset = std::stoi(offset);
+        rank_id += rank_offset;
+      } catch (std::invalid_argument) {
+        MS_LOG(EXCEPTION) << "stoi invalid argument:" << offset;
+      } catch (std::out_of_range) {
+        MS_LOG(EXCEPTION) << "stoi out_of_range:" << offset;
+      }
+    }
+    rank_id_str = std::to_string(rank_id);
+  } else {
+    rank_id_str = std::getenv("RANK_ID");
+  }
+#else
+  rank_id_str = std::getenv("RANK_ID");
+#endif
+  if (rank_id_str.empty()) {
+    MS_LOG(ERROR) << "get hccl rankid failed, please set env RANK_ID";
+  }
+  return rank_id_str;
+}
+}  // namespace
 
 AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); }
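Note: on both non-MPI paths of GetRankId(), `rank_id_str = std::getenv("RANK_ID");` assigns getenv's result straight to a std::string; when RANK_ID is unset, getenv returns nullptr and constructing a std::string from a null pointer is undefined behavior (the pre-refactor code in HcclInit checked `identify == nullptr` before use). A minimal null-safe variant, as a sketch — the GetEnvOr helper is illustrative, not part of this patch:

    #include <cstdlib>
    #include <string>

    // Hypothetical helper: read an environment variable, falling back to a
    // default instead of building a std::string from a null pointer.
    std::string GetEnvOr(const char *name, const std::string &fallback = "") {
      const char *value = std::getenv(name);
      return value != nullptr ? std::string(value) : fallback;
    }

With it, `rank_id_str = GetEnvOr("RANK_ID");` stays empty when the variable is unset, which the existing `rank_id_str.empty()` check then reports.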
@@ -497,7 +529,6 @@ bool AscendKernelRuntime::HcclInit() {
   if (!context_ptr->IsTsdOpened()) {
     MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open";
   }
-
   MS_LOG(INFO) << "do hcom init";
   auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH");
   if (config_path_str == nullptr) {
@@ -507,44 +538,14 @@ bool AscendKernelRuntime::HcclInit() {
   }
     return false;
   }
+  std::string rank_id_str = GetRankId();
   auto full_path = realpath(config_path_str, nullptr);
   if (full_path == nullptr) {
     MS_LOG(ERROR) << "file path " << config_path_str << " does not exist";
     return false;
   }
-  const char *identify = nullptr;
-#ifdef ENABLE_MPI
-  std::string rank_id_tmp;
-  auto mpi_config_ptr = MpiConfig::GetInstance();
-  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
-  if (mpi_config_ptr->enable_mpi()) {
-    int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
-    const char *offset = std::getenv("RANK_OFFSET");
-    if (offset != nullptr) {
-      try {
-        int rank_offset = std::stoi(offset);
-        rank_id += rank_offset;
-      } catch (std::invalid_argument) {
-        MS_LOG(EXCEPTION) << "stoi invalid argument:" << offset;
-      } catch (std::out_of_range) {
-        MS_LOG(EXCEPTION) << "stoi out_of_range:" << offset;
-      }
-    }
-    rank_id_tmp = std::to_string(rank_id);
-    identify = rank_id_tmp.c_str();
-  } else {
-    identify = std::getenv("RANK_ID");
-  }
-#else
-  identify = std::getenv("RANK_ID");
-#endif
-  if (identify == nullptr) {
-    MS_LOG(ERROR) << "get hccl rankid failed, please set env RANK_ID";
-    free(full_path);
-    return false;
-  }
-  MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << identify;
-  hcclResult_t res = hcom_init(full_path, identify);
+  MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str;
+  hcclResult_t res = hcom_init(full_path, rank_id_str.c_str());
   free(full_path);
   if (res != HCCL_SUCCESS) {
     MS_LOG(ERROR) << "hcom init failed, res is " << static_cast<int>(res);
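Note: the block deleted above duplicated GetRankId() almost line for line (plus the fragile `identify = rank_id_tmp.c_str()` aliasing), so HcclInit() now only resolves the config path and calls hcom_init. One detail to keep in mind: realpath(config_path_str, nullptr) returns a malloc'd buffer, so every exit path after it must free(full_path). A sketch of an RAII alternative, assuming POSIX realpath — not part of this patch:

    #include <cstdlib>
    #include <memory>

    // Owns realpath's malloc'd result; any early return releases it.
    using MallocString = std::unique_ptr<char, decltype(&std::free)>;

    MallocString ResolvePath(const char *path) {
      return MallocString(realpath(path, nullptr), &std::free);
    }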
@@ -303,15 +303,12 @@ void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
       fusion_hcom_index.emplace_back(i);
     }
   }
-
   if (fusion_hcom_index.size() < 2) {
     MS_LOG(INFO) << "fusion hcom size is less than 2, no need insert event between them";
     return;
   }
-
   uint32_t first_index = fusion_hcom_index[0];
   uint32_t last_index = fusion_hcom_index[fusion_hcom_index.size() - 1];
-
   uint32_t cur_event_id = total_event_num_;
   uint32_t pre_hcom_stream_id = UINT32_MAX;
   std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_index, std::back_inserter(orders));
@@ -322,13 +319,11 @@ void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
       orders.emplace_back(cur_cnode);
       continue;
     }
-
     auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode);
     if (cur_hcom_stream_id == pre_hcom_stream_id) {
       orders.emplace_back(cur_cnode);
       continue;
     }
-
     if (i == first_index) {
       // first fusion hcom
       orders.emplace_back(cur_cnode);
@@ -348,15 +343,12 @@ void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
       auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
       orders.emplace_back(send);
     }
-
     pre_hcom_stream_id = cur_hcom_stream_id;
   }
-
   std::copy(cnode_ptr_list.begin() + last_index + 1, cnode_ptr_list.end(), std::back_inserter(orders));
   graph_ptr->set_execution_order(orders);
   total_event_num_ = cur_event_id;
-  MS_LOG(INFO) << "after indsert between allreduce, total event nums[" << total_event_num_ << "]";
-  MS_LOG(INFO) << "end";
+  MS_LOG(INFO) << "after indsert between allreduce, total event nums[" << total_event_num_ << "]\n end";
 }
 
 void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
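Note: besides the blank-line cleanup, the two trailing INFO logs are merged into one. Judging from the context lines, the pass walks the execution order and allocates a send/recv event pair whenever consecutive fusion-hcom nodes land on different streams (nodes on the same stream as their predecessor pass straight through). An illustrative stand-alone sketch of that stream-change rule — not the pass itself:

    #include <cstdint>
    #include <vector>

    // Counts how many send/recv event pairs the stream-change rule implies:
    // one pair each time the fusion-hcom stream id differs from the previous one.
    uint32_t CountEventPairs(const std::vector<uint32_t> &hcom_stream_ids) {
      uint32_t pre_stream_id = UINT32_MAX;  // same sentinel the pass uses
      uint32_t pairs = 0;
      for (uint32_t cur : hcom_stream_ids) {
        if (pre_stream_id != UINT32_MAX && cur != pre_stream_id) {
          ++pairs;  // send on the previous stream, recv on the current one
        }
        pre_stream_id = cur;
      }
      return pairs;
    }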
@@ -826,7 +818,6 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
   std::vector<CNodePtr> exe_orders;
   std::vector<CNodePtr> independents;
   std::vector<CNodePtr> others;
-
   auto cnode_ptr_list = graph_ptr->execution_order();
   MS_LOG(INFO) << "before reorder, graph orders size:" << cnode_ptr_list.size();
   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
@@ -838,19 +829,16 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
       others.emplace_back(cur_cnode_ptr);
     }
   }
-
   if (others.empty()) {
     std::copy(independents.begin(), independents.end(), std::back_inserter(exe_orders));
     graph_ptr->set_execution_order(exe_orders);
     return;
   }
-
   if (independents.empty()) {
     std::copy(others.begin(), others.end(), std::back_inserter(exe_orders));
     graph_ptr->set_execution_order(exe_orders);
     return;
   }
-
   std::vector<CNodePtr> processed;
   for (size_t i = 0; i < others.size(); i++) {
     auto begin = others.begin() + i;
@@ -862,7 +850,6 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
     if (it != processed.end()) {
       continue;
     }
-
     auto res = FindTargetOp(begin, end, cur_independent);
     if (res != end) {
       flag = true;
@@ -872,12 +859,10 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
         break;
       }
     }
-
     if (!flag) {
       exe_orders.emplace_back(*begin);
     }
   }
-
   MS_LOG(INFO) << "after reorder, graph orders size:" << exe_orders.size();
   graph_ptr->set_execution_order(exe_orders);
 }
@@ -121,7 +121,6 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
     MS_LOG(ERROR) << "Register profiling Engine failed.";
     return false;
   }
-
   auto context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context);
   const string prof_options_str = context->profiling_options();
@@ -130,7 +129,6 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
     MS_LOG(WARNING) << "Profiling is enabled, but profiling option is not set!";
     return true;
   }
-
   // current one docker only use one device`
   Json p_device;
   // JOBID
@@ -149,7 +147,6 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
   // only one device, but sProfMgrStartUp API require for device list
   Json devices;
   devices[0] = p_device;
-
   Json startCfg;
   startCfg["startCfg"] = devices;
 
@@ -157,9 +154,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
   std::stringstream ss;
   ss << startCfg;
   std::string cfg = ss.str();
-
   MS_LOG(INFO) << "profiling config " << cfg;
-
   auto ret = rtProfilerStart();
   if (ret != RT_ERROR_NONE) {
     MS_LOG(INFO) << "Call rtProfilerStart failed, ret:" << ret;
@@ -33,7 +33,7 @@ bool CPUDeviceAddress::SyncDeviceToHost(const std::vector<int> & /*shape*/, size
   }
 
   if (type == type_id_) {
-    auto ret_code = memcpy_s(host_ptr, size, ptr_, size);
+    auto ret_code = memcpy_s(host_ptr, size, ptr_, size_);
     if (ret_code != EOK) {
       MS_LOG(ERROR) << "Failed to copy tensor!";
       return false;
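Note: this one-character change is a real behavioral fix. In the securec/C11-Annex-K-style contract assumed here, memcpy_s(dest, dest_max, src, count) copies exactly `count` bytes and fails if `count` exceeds `dest_max`; passing the caller's `size` as the count could read past the stored device buffer whenever the host buffer is larger than the tensor actually held in `ptr_`. Using the address's own `size_` copies precisely what was stored. A sketch of that assumed contract — checked_copy is illustrative, not the securec implementation:

    #include <cstddef>
    #include <cstring>

    // Mirrors the assumed memcpy_s behavior: bounded, null-checked copy.
    int checked_copy(void *dest, size_t dest_max, const void *src, size_t count) {
      if (dest == nullptr || src == nullptr || count > dest_max) {
        return -1;  // corresponds to a non-EOK return
      }
      std::memcpy(dest, src, count);
      return 0;  // EOK
    }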
@@ -34,6 +34,23 @@ namespace mindspore {
 namespace device {
 namespace cpu {
 const size_t INIT_NODE_REF = 1;
+namespace {
+TypeId GetCPUSupportOutputTypeId(const TypeId type_id) {
+  TypeId support_type_id = type_id;
+  if (type_id == kNumberTypeUInt32) {
+    support_type_id = kNumberTypeInt32;
+  }
+  if (type_id == kNumberTypeFloat || type_id == kNumberTypeFloat16 || type_id == kNumberTypeFloat32 ||
+      type_id == kNumberTypeFloat64) {
+    support_type_id = kNumberTypeFloat32;
+  }
+  if (support_type_id != kNumberTypeInt32 && support_type_id != kNumberTypeFloat32) {
+    MS_LOG(EXCEPTION) << "Check output type failed.";
+  }
+  return support_type_id;
+}
+}  // namespace
+
 void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) {
   AssignValueNodeAddress(kernel_graph);
   AssignInputNodeAddress(kernel_graph);
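Note: the new helper centralizes the CPU output-type narrowing that CreatTensorForOutput previously inlined: uint32 narrows to int32, every float width collapses to float32, and anything that still isn't int32/float32 raises. A standalone mirror of the mapping with stand-in enum values (illustrative only; the real TypeId constants live in MindSpore's type system):

    #include <cstdio>
    #include <stdexcept>

    enum TypeId { kInt32, kUInt32, kFloat16, kFloat32, kFloat64, kBool };

    TypeId ToCPUSupportedOutputType(TypeId t) {
      if (t == kUInt32) return kInt32;  // unsigned int narrows to int32
      if (t == kFloat16 || t == kFloat32 || t == kFloat64) return kFloat32;
      if (t != kInt32) throw std::runtime_error("Check output type failed.");
      return t;
    }

    int main() {
      std::printf("%d\n", static_cast<int>(ToCPUSupportedOutputType(kFloat16)));
    }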
@@ -149,16 +166,7 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const AnfNodePtr &input_node, siz
   std::vector<int> temp_shape;
   (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
   TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
-  if (type_id == kNumberTypeUInt32) {
-    type_id = kNumberTypeInt32;
-  }
-  if (type_id == kNumberTypeFloat || type_id == kNumberTypeFloat16 || type_id == kNumberTypeFloat32 ||
-      type_id == kNumberTypeFloat64) {
-    type_id = kNumberTypeFloat32;
-  }
-  if (type_id != kNumberTypeInt32 && type_id != kNumberTypeFloat32) {
-    MS_LOG(EXCEPTION) << "Check output type failed.";
-  }
+  type_id = GetCPUSupportOutputTypeId(type_id);
   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
   MS_EXCEPTION_IF_NULL(tensor);
   if (address->ref_count_ > 0 && address->ptr_ != nullptr) {
@@ -54,8 +54,9 @@ KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &d
     return runtime_iter->second.get();
   } else if (runtime_map_.size() > 0) {
     auto cur_runtime_key = runtime_map_.begin()->first;
-    if (cur_runtime_key.rfind('_') != std::string::npos) {
-      auto cur_device_id = cur_runtime_key.substr(cur_runtime_key.rfind('_') + 1);
+    auto find_pos = cur_runtime_key.rfind('_');
+    if (find_pos != std::string::npos) {
+      auto cur_device_id = cur_runtime_key.substr(find_pos + 1);
       MS_LOG(EXCEPTION) << "Can't change device id in runtime, already set device id: " << cur_device_id
                         << ", set device id: " << device_id << " failed";
     }
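Note: hoisting rfind('_') into find_pos avoids scanning the key twice and keeps the branch and the substr in agreement. The code assumes runtime-map keys end in '_<device_id>'; the exact key format is not visible in this diff, so the example below is only a guess at its shape:

    #include <iostream>
    #include <string>

    int main() {
      std::string key = "Ascend_0";  // hypothetical runtime key
      auto find_pos = key.rfind('_');
      if (find_pos != std::string::npos) {
        std::cout << key.substr(find_pos + 1) << "\n";  // prints "0"
      }
    }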
@@ -24,50 +24,32 @@ namespace mindspore {
 namespace kernel {
 void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   CheckParam(kernel_node);
 
   input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
   input_lens_ = 1;
   for (auto shape : input_shape_) {
-    MS_LOG(INFO) << "input shape: " << shape;
     input_lens_ = input_lens_ * shape;
   }
-  MS_LOG(INFO) << "input lens: " << input_lens_;
 
   indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
   indices_lens_ = 1;
   for (auto shape : indices_shape_) {
-    MS_LOG(INFO) << "indice shape: " << shape;
     indices_lens_ = indices_lens_ * shape;
   }
-  MS_LOG(INFO) << "indice lens: " << indices_lens_;
 
   output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-  for (auto shape : output_shape_) {
-    MS_LOG(INFO) << "output shape: " << shape;
-  }
-  auto output_type = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
-  MS_LOG(INFO) << "output type: " << output_type;
-
   axis_ = 4 - input_shape_.size();
-  MS_LOG(INFO) << "axis_: " << axis_;
   reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "reduce_scatter_flag");
-  MS_LOG(INFO) << "reduce_scatter_flag: " << reduce_scatter_flag_;
 #ifdef ENABLE_MPI
   if (reduce_scatter_flag_) {
     size_t gatherv2_out_lens = 1;
     for (int i = 0; i < SizeToInt(input_shape_.size()); i++) {
       if (i == 0) {
         for (int j = 0; j < SizeToInt(indices_shape_.size()); j++) {
-          MS_LOG(DEBUG) << "gatherv2 out shape: " << indices_shape_[j];
           gatherv2_out_lens = gatherv2_out_lens * indices_shape_[j];
         }
       } else {
-        MS_LOG(DEBUG) << "gatherv2 out shape: " << input_shape_[i];
         gatherv2_out_lens = gatherv2_out_lens * input_shape_[i];
       }
     }
     gatherv2_out_lens_ = gatherv2_out_lens * sizeof(float);
-    MS_LOG(INFO) << "gatherv2 out lens: " << gatherv2_out_lens_;
     gather_v2_out_ = malloc(gatherv2_out_lens_);
     if (gather_v2_out_ == nullptr) {
       MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_;
@@ -76,9 +58,7 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     if (ret != 0) {
       MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset gatherv2 out buff failed";
     }
-
     split_num_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "split_num");
-    MS_LOG(INFO) << "split_num: " << split_num_;
   }
 #else
   if (reduce_scatter_flag_) {
@@ -86,7 +66,6 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   }
 #endif
   offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "offset");
-  MS_LOG(INFO) << "offset: " << offset_;
   CPUKernelUtils::ExpandDimsTo4(&input_shape_);
   CPUKernelUtils::ExpandDimsTo4(&output_shape_);
 }
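Note: the InitKernel hunks are log cleanup — the per-dimension INFO/DEBUG prints ran on every kernel build and added nothing beyond the shapes themselves. What survives is a plain product over each shape vector; for reference, the same fold written with the standard library (a sketch, not part of the patch):

    #include <functional>
    #include <numeric>
    #include <vector>

    // Equivalent to the input_lens_/indices_lens_ accumulation loops above.
    size_t Product(const std::vector<size_t> &shape) {
      return std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1),
                             std::multiplies<size_t>());
    }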
@@ -94,21 +73,11 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
 bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                       const std::vector<kernel::AddressPtr> & /*workspace*/,
                                       const std::vector<kernel::AddressPtr> &outputs) {
-#if defined(_WIN32) || defined(_WIN64)
-  auto start_time = std::chrono::steady_clock::now();
-#else
-  struct timeval start_time, end_time;
-  (void)gettimeofday(&start_time, nullptr);
-#endif
   auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
-  MS_LOG(DEBUG) << "output addr: " << output_addr << "output size: " << outputs[0]->size;
   float *gather_out_addr = reduce_scatter_flag_ ? reinterpret_cast<float *>(gather_v2_out_) : output_addr;
-  MS_LOG(DEBUG) << "gatherv2 out addr: " << gather_out_addr;
 
   size_t dim0 = input_shape_[0];
   size_t dim1 = input_shape_[1];
   size_t dim2 = input_shape_[2];
 
   if (axis_ == 3) {
     for (size_t i = 0; i < dim0; ++i) {
       for (size_t j = 0; j < dim1; ++j) {
@@ -130,7 +99,6 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
   } else if (axis_ == 0) {
     LookUpTable(inputs, 0, 0, 0, &gather_out_addr);
   }
-
 #ifdef ENABLE_MPI
   if (reduce_scatter_flag_) {
     size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float);
@@ -143,21 +111,10 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
     }
   }
 #endif
-
-#if defined(_WIN32) || defined(_WIN64)
-  auto end_time = std::chrono::steady_clock::now();
-  std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
-  MS_LOG(INFO) << "EmbeddingLookUpCPUKernel, used time: " << cost.count() << " us";
-#else
-  (void)gettimeofday(&end_time, nullptr);
-  uint64_t time = 1000000 * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
-  time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
-  MS_LOG(INFO) << "EmbeddingLookUpCPUKernel, used time: " << time << " us";
-#endif
   return true;
 }
 
-void LookUpTable_task(float *input_addr, float *output_addr, int *indices_addr, size_t indices_lens, size_t num,
+void LookUpTable_task(const float *input_addr, float *output_addr, int *indices_addr, size_t indices_lens, size_t num,
                       size_t dim0, size_t dim1, size_t dim2, int offset, size_t axis, std::vector<size_t> input_shape,
                       size_t input_lens) {
   size_t lens = num * sizeof(float);
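Note: Launch() also sheds its platform-split timing scaffolding (std::chrono on _WIN32/_WIN64, gettimeofday elsewhere) along with the per-launch DEBUG prints. Should per-launch timing ever be wanted again, one std::chrono::steady_clock version is portable and removes the #ifdef entirely — a sketch:

    #include <chrono>

    // Measures elapsed microseconds without any platform branching.
    struct ScopedTimerUs {
      std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
      double ElapsedUs() const {
        return std::chrono::duration<double, std::micro>(
                   std::chrono::steady_clock::now() - start).count();
      }
    };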
@@ -182,7 +139,6 @@ void LookUpTable_task(float *input_addr, float *output_addr, int *indices_addr,
     if (ret != EOK) {
       MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed.";
     }
-
   } else {
     auto ret = memset_s(output_addr, lens, 0, lens);
     if (ret != EOK) {
@@ -204,6 +160,7 @@ void LookUpTable_task(float *input_addr, float *output_addr, int *indices_addr,
     output_addr += num;
   }
 }
+
 void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1,
                                            size_t dim2, float **output_addr) {
   auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
@@ -30,7 +30,7 @@ void SubCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   }
 }
 
-void sub_task(int *in_addr, int *out_addr, size_t lens, int offset) {
+void sub_task(const int *in_addr, int *out_addr, size_t lens, int offset) {
   for (size_t i = 0; i < lens; i++) {
     out_addr[i] = in_addr[i] - offset;
   }
@@ -55,7 +55,7 @@ bool SubCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
       output_addr[i] = input_addr[i] - offset_;
     }
   } else {
-    size_t thread_num = 4;
+    const size_t thread_num = 4;
     std::thread threads[4];
     size_t process_lens = (lens + thread_num - 1) / thread_num;
     size_t process_offset = 0;
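Note: the remaining changes are const-correctness (sub_task and LookUpTable_task only read their input pointers) plus a const on the fixed thread count. The surrounding Launch code shows the ceil-division chunking used to fan the subtraction out over four threads; a self-contained sketch assembled from those context lines (parallel_sub itself is illustrative, not the kernel's code):

    #include <algorithm>
    #include <cstddef>
    #include <thread>
    #include <vector>

    void sub_task(const int *in_addr, int *out_addr, size_t lens, int offset) {
      for (size_t i = 0; i < lens; i++) {
        out_addr[i] = in_addr[i] - offset;
      }
    }

    void parallel_sub(const int *in, int *out, size_t lens, int offset) {
      const size_t thread_num = 4;
      std::vector<std::thread> threads;
      // Ceil division: each chunk except possibly the last gets process_lens elements.
      size_t process_lens = (lens + thread_num - 1) / thread_num;
      size_t process_offset = 0;
      for (size_t t = 0; t < thread_num && process_offset < lens; ++t) {
        size_t count = std::min(process_lens, lens - process_offset);
        threads.emplace_back(sub_task, in + process_offset, out + process_offset, count, offset);
        process_offset += count;
      }
      for (auto &th : threads) {
        th.join();
      }
    }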