forked from mindspore-Ecosystem/mindspore
Support dynamic shape void* in message
This commit is contained in:
parent
6685a89540
commit
60f82161f6
|
@ -901,7 +901,8 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName,
|
|||
kSegmentMeanOpName,
|
||||
kSegmentProdOpName,
|
||||
kNonZeroOpName,
|
||||
kSparseSparseMinimumOpName};
|
||||
kSparseSparseMinimumOpName,
|
||||
kRpcRecvOpName};
|
||||
|
||||
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
|
||||
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
|
||||
|
|
|
@ -39,24 +39,25 @@ class RpcRecvKernelMod : public RpcKernelMod {
|
|||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(remote_input_);
|
||||
// If the string body is not empty, it means we need to copy data from 'body' instead of raw pointer 'data'.
|
||||
bool use_string_msg = !remote_input_->Body().empty();
|
||||
auto data_ptr = use_string_msg ? (remote_input_->Body().data()) : (static_cast<char *>(remote_input_->data));
|
||||
size_t data_size = use_string_msg ? (remote_input_->Body().size()) : (remote_input_->size);
|
||||
|
||||
if (is_dynamic_shape_) {
|
||||
if (real_data_offset_.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Dynamic shape data must have data offsets.";
|
||||
MS_LOG(EXCEPTION) << "Dynamic shape data must have data offsets to copy from source message.";
|
||||
}
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
MS_EXCEPTION_IF_NULL(inputs[i]->addr);
|
||||
int ret = memcpy_s(inputs[i]->addr, inputs[i]->size, remote_input_->Body().data() + real_data_offset_[i],
|
||||
inputs[i]->size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s for recv output failed, ret code: " << ret;
|
||||
MS_LOG(EXCEPTION) << "memcpy_s for recv output " << i << " failed, ret code: " << ret;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
size_t offset = 0;
|
||||
// If the string body is not empty, it means we need to copy data from 'body' instead of raw pointer 'data'.
|
||||
bool use_string_msg = !remote_input_->Body().empty();
|
||||
auto data_ptr = use_string_msg ? (remote_input_->Body().data()) : (static_cast<char *>(remote_input_->data));
|
||||
size_t data_size = use_string_msg ? (remote_input_->Body().size()) : (remote_input_->size);
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
MS_EXCEPTION_IF_NULL(inputs[i]->addr);
|
||||
int ret = memcpy_s(inputs[i]->addr, inputs[i]->size, data_ptr + offset, inputs[i]->size);
|
||||
|
@ -67,6 +68,7 @@ class RpcRecvKernelMod : public RpcKernelMod {
|
|||
offset += inputs[i]->size;
|
||||
// Maybe the size of data from remote is smaller than inputs size, need to break in advance to avoid illegal
|
||||
// memory access. For example, the 'umonad' inputs of RpcRecvKernel is not sent from remote.
|
||||
// This should be fixed in graph optimizing step.
|
||||
if (offset == data_size) {
|
||||
break;
|
||||
}
|
||||
|
@ -86,6 +88,8 @@ class RpcRecvKernelMod : public RpcKernelMod {
|
|||
recv_monad_ = true;
|
||||
}
|
||||
is_dynamic_shape_ = common::AnfAlgo::IsDynamicShape(kernel_node);
|
||||
// RpcRecv kernel is similar with Unique, the next op's infer op must be launched after RpcRecv kernel is done.
|
||||
is_need_retrieve_output_shape_ = true;
|
||||
}
|
||||
|
||||
int Resize(
|
||||
|
|
|
@ -15,15 +15,42 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/rpc/rpc_send_kernel.h"
|
||||
#include <string>
|
||||
#include "runtime/device/ms_device_shape_transfer.h"
|
||||
#include "proto/rpc.pb.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void RpcSendKernelMod::Init(const CNodePtr &kernel_node) {
|
||||
DeprecatedNativeCpuKernelMod::Init(kernel_node);
|
||||
// Assign workspace memory with the same size of inputs. It's the data which will be sent to remote.
|
||||
// Assign one piece of workspace memory with the same size of all inputs. It's the data which will be sent to remote.
|
||||
// Only allocate one piece of workspace memory to avoid extra memory copying and serialize inputs data to one message.
|
||||
size_t total_size = 0;
|
||||
total_size = std::accumulate(input_size_list_.begin(), input_size_list_.end(), total_size,
|
||||
[](size_t total_size, const auto &input_size) { return total_size + input_size; });
|
||||
if (common::AnfAlgo::IsDynamicShape(kernel_node)) {
|
||||
// In dynamic shape scenario, workspace size should be updated.
|
||||
size_t input_size = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t i = 0; i < input_size; i++) {
|
||||
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel_node, i, false);
|
||||
auto real_input = input_node_with_index.first;
|
||||
auto real_input_index = input_node_with_index.second;
|
||||
MS_EXCEPTION_IF_NULL(real_input);
|
||||
|
||||
auto shapes = trans::GetRuntimePaddingShape(real_input, real_input_index);
|
||||
TypeId data_type = common::AnfAlgo::GetOutputInferDataType(real_input, real_input_index);
|
||||
|
||||
runtime::rpc::DynamicShapeMessage pb_msg;
|
||||
pb_msg.set_type_id(static_cast<int>(data_type));
|
||||
*pb_msg.mutable_shape_vector() = {shapes.begin(), shapes.end()};
|
||||
std::string pb_msg_str = pb_msg.SerializeAsString();
|
||||
total_size += strlen(kRpcDynamicShapeData);
|
||||
total_size += sizeof(size_t);
|
||||
total_size += pb_msg_str.size();
|
||||
total_size += input_size_list_[i];
|
||||
}
|
||||
} else {
|
||||
total_size = std::accumulate(input_size_list_.begin(), input_size_list_.end(), total_size,
|
||||
[](size_t total_size, const auto &input_size) { return total_size + input_size; });
|
||||
}
|
||||
workspace_size_list_.push_back(total_size);
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr char kRpcDynamicShapeData[] = "RPC_DYNAMIC_SHAPE_DATA";
|
||||
// RpcSendKernel send data to another process across network communication.
|
||||
class RpcSendKernelMod : public RpcKernelMod {
|
||||
public:
|
||||
|
|
|
@ -55,13 +55,22 @@ MessageBase *MuxRecvActor::HandleMessage(MessageBase *const msg) {
|
|||
void MuxRecvActor::ParseFinalizeReqData(size_t data_len, const MessageBase *const msg, bool *need_finalize) {
|
||||
MS_EXCEPTION_IF_NULL(msg);
|
||||
MS_EXCEPTION_IF_NULL(need_finalize);
|
||||
const std::string &msg_body = msg->body;
|
||||
size_t msg_len = msg_body.length();
|
||||
if (data_len == msg_len) {
|
||||
|
||||
size_t req_data_size = 0;
|
||||
RpcDataPtr finaliz_req_data;
|
||||
if (common::GetEnv("use_void").empty()) {
|
||||
req_data_size = msg->body.size();
|
||||
finaliz_req_data = const_cast<RpcDataPtr>(msg->body.c_str());
|
||||
} else {
|
||||
MS_EXCEPTION_IF_NULL(msg->data);
|
||||
req_data_size = msg->size;
|
||||
finaliz_req_data = static_cast<RpcDataPtr>(msg->data);
|
||||
}
|
||||
if (data_len == req_data_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t remainder_len = msg_len - data_len;
|
||||
size_t remainder_len = req_data_size - data_len;
|
||||
size_t finalize_header_size = strlen(kFinalizeMuxRecvActor);
|
||||
if (remainder_len <= finalize_header_size) {
|
||||
MS_LOG(EXCEPTION) << "Not found msg header[" << kFinalizeMuxRecvActor << "] in received message";
|
||||
|
@ -71,7 +80,7 @@ void MuxRecvActor::ParseFinalizeReqData(size_t data_len, const MessageBase *cons
|
|||
MS_LOG(EXCEPTION) << "Invalid finalize request message";
|
||||
}
|
||||
|
||||
const void *need_finalize_actor_data = msg_body.c_str() + data_len + finalize_header_size;
|
||||
const void *need_finalize_actor_data = finaliz_req_data + data_len + finalize_header_size;
|
||||
MS_EXCEPTION_IF_NULL(need_finalize_actor_data);
|
||||
bool finalize_in_msg = *(static_cast<const bool *>(need_finalize_actor_data));
|
||||
MS_LOG(INFO) << "Received a message which contains finalize command: " << finalize_in_msg;
|
||||
|
|
|
@ -297,14 +297,84 @@ size_t RecvActor::ParseDynamicShapeData(const std::string &dynamic_shape_data, A
|
|||
return offset;
|
||||
}
|
||||
|
||||
size_t RecvActor::ParseDynamicShapeData(const RpcDataPtr &dynamic_shape_data, size_t data_size,
|
||||
AbstractBasePtrList *args_spec_list, size_t count) {
|
||||
// The data which could be parsed by offset in dynamic shape scenario.
|
||||
auto data_to_be_parsed = dynamic_shape_data;
|
||||
// The real data offsets which will be used by RpcRecvKernel.
|
||||
std::vector<size_t> real_data_offsets;
|
||||
|
||||
// Once the magic header is dynamic shape, each input of the Recv is dynamic shape.
|
||||
// So traverse each input and parse the dynamic shape data.
|
||||
size_t offset = 0;
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
if (data_to_be_parsed >= dynamic_shape_data + data_size) {
|
||||
MS_LOG(EXCEPTION) << "The dynamic shape data size is invalid.";
|
||||
}
|
||||
// Step 1: parse the magic header which indicates the dynamic shape.
|
||||
std::string dynamic_shape_magic_header(data_to_be_parsed, strlen(kRpcDynamicShapeData));
|
||||
if (dynamic_shape_magic_header != kRpcDynamicShapeData) {
|
||||
MS_LOG(EXCEPTION) << "The dynamie shape data must have the magic header RPC_DYNAMIC_SHAPE_DATA";
|
||||
}
|
||||
|
||||
// Step 2: parse the size of serialized protobuf message.
|
||||
data_to_be_parsed += strlen(kRpcDynamicShapeData);
|
||||
size_t pb_msg_size = 0;
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(memcpy_s(&pb_msg_size, sizeof(pb_msg_size), data_to_be_parsed, sizeof(size_t)) == 0,
|
||||
"memcpy_s protobuf message size failed.");
|
||||
|
||||
// Step 3: deserialize the protobuf message.
|
||||
data_to_be_parsed += sizeof(pb_msg_size);
|
||||
rpc::DynamicShapeMessage pb_msg;
|
||||
(void)pb_msg.ParseFromArray(data_to_be_parsed, SizeToInt(pb_msg_size));
|
||||
|
||||
// Step 4: parse the data shape and
|
||||
ShapeVector shapes(pb_msg.shape_vector().begin(), pb_msg.shape_vector().end());
|
||||
TypeId data_type = static_cast<TypeId>(pb_msg.type_id());
|
||||
data_to_be_parsed += pb_msg_size;
|
||||
|
||||
// Step 5: get the size of real data as recv's input.
|
||||
int64_t real_data_size = 1;
|
||||
if (!kernel::GetShapeSize(shapes, TypeIdToType(data_type), &real_data_size)) {
|
||||
MS_LOG(EXCEPTION) << "Getting shape size for shape " << shapes << " failed.";
|
||||
}
|
||||
data_to_be_parsed += real_data_size;
|
||||
|
||||
// Step 6: update the abstract.
|
||||
AddArgSpecForInput(args_spec_list, shapes, data_type, i);
|
||||
|
||||
offset += strlen(kRpcDynamicShapeData) + sizeof(pb_msg_size) + pb_msg_size;
|
||||
real_data_offsets.push_back(offset);
|
||||
offset += LongToSize(real_data_size);
|
||||
}
|
||||
|
||||
auto recv_kernel_mod = dynamic_cast<kernel::RpcRecvKernelMod *>(kernel_info_->MutableKernelMod());
|
||||
MS_EXCEPTION_IF_NULL(recv_kernel_mod);
|
||||
recv_kernel_mod->set_real_data_offset(real_data_offsets);
|
||||
return offset;
|
||||
}
|
||||
|
||||
void RecvActor::PreprocessRemoteInput(MessageBase *const msg, bool *need_finalize) {
|
||||
MS_EXCEPTION_IF_NULL(msg);
|
||||
MS_EXCEPTION_IF_NULL(need_finalize);
|
||||
if (msg->body.size() <= strlen(kRpcDynamicShapeData)) {
|
||||
|
||||
size_t data_size = 0;
|
||||
std::string msg_magic_header;
|
||||
RpcDataPtr dynamic_shape_data;
|
||||
if (common::GetEnv("use_void").empty()) {
|
||||
data_size = msg->body.size();
|
||||
msg_magic_header = msg->body.substr(0, strlen(kRpcDynamicShapeData));
|
||||
dynamic_shape_data = const_cast<RpcDataPtr>(msg->body.c_str());
|
||||
} else {
|
||||
MS_EXCEPTION_IF_NULL(msg->data);
|
||||
data_size = msg->size;
|
||||
msg_magic_header = std::string(static_cast<RpcDataPtr>(msg->data), strlen(kRpcDynamicShapeData));
|
||||
dynamic_shape_data = static_cast<RpcDataPtr>(msg->data);
|
||||
}
|
||||
if (data_size <= strlen(kRpcDynamicShapeData)) {
|
||||
MS_LOG(DEBUG) << "This is not a dynamic shape data. No need to preprocess.";
|
||||
return;
|
||||
}
|
||||
std::string msg_magic_header = msg->body.substr(0, strlen(kRpcDynamicShapeData));
|
||||
if (msg_magic_header != kRpcDynamicShapeData) {
|
||||
MS_LOG(DEBUG) << "This is not a dynamic shape data. No need to preprocess.";
|
||||
return;
|
||||
|
@ -313,7 +383,7 @@ void RecvActor::PreprocessRemoteInput(MessageBase *const msg, bool *need_finaliz
|
|||
MS_LOG(INFO) << "Preprocess for dynamic shape data.";
|
||||
AbstractBasePtrList args_spec_list;
|
||||
size_t input_size = common::AnfAlgo::GetInputTensorNum(kernel_);
|
||||
size_t dynamic_shape_data_msg_len = ParseDynamicShapeData(msg->body, &args_spec_list, input_size);
|
||||
size_t dynamic_shape_data_msg_len = ParseDynamicShapeData(dynamic_shape_data, data_size, &args_spec_list, input_size);
|
||||
ParseFinalizeReqData(dynamic_shape_data_msg_len, msg, need_finalize);
|
||||
|
||||
// The args_spec_list is updated in ParseDynamicShapeData method. So do the Infer and Resize operation.
|
||||
|
|
|
@ -106,6 +106,9 @@ class RecvActor : public RpcActor {
|
|||
size_t ParseDynamicShapeData(const std::string &dynamic_shape_data, AbstractBasePtrList *args_spec_list,
|
||||
size_t count);
|
||||
|
||||
size_t ParseDynamicShapeData(const RpcDataPtr &dynamic_shape_data, size_t data_size,
|
||||
AbstractBasePtrList *args_spec_list, size_t count);
|
||||
|
||||
// After Recv actor receives data from a remote peer, the data could be with dynamic shape so we need to preprocess
|
||||
// it, e.g., infer shape for RpcRecv kernel and call Resize().
|
||||
void PreprocessRemoteInput(MessageBase *const msg, bool *need_finalize);
|
||||
|
|
|
@ -25,5 +25,18 @@ void RpcActor::set_actor_route_table_proxy(const ActorRouteTableProxyPtr &proxy)
|
|||
void RpcActor::set_inter_process_edge_names(const std::vector<std::string> &edge_names) {
|
||||
inter_process_edge_names_ = edge_names;
|
||||
}
|
||||
|
||||
bool RpcActor::CopyRpcDataWithOffset(RpcDataPtr *rpc_data, const void *src_data, size_t src_data_size) const {
|
||||
MS_EXCEPTION_IF_NULL(rpc_data);
|
||||
MS_EXCEPTION_IF_NULL(*rpc_data);
|
||||
|
||||
int ret = memcpy_s(*rpc_data, src_data_size, src_data, src_data_size);
|
||||
if (EOK != ret) {
|
||||
MS_LOG(ERROR) << "Failed to memcpy_s for rpc data. Error number: " << ret;
|
||||
return false;
|
||||
}
|
||||
*rpc_data += src_data_size;
|
||||
return true;
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -45,6 +45,9 @@ constexpr char kInterProcessEdgeMark[] = "->";
|
|||
// The magic header of the rpc data which indicates this message contains dynamic shape data.
|
||||
constexpr char kRpcDynamicShapeData[] = "RPC_DYNAMIC_SHAPE_DATA";
|
||||
|
||||
// RpcDataPtr will be used for serializing and deserializing rpc message raw pointer data.
|
||||
using RpcDataPtr = char *;
|
||||
|
||||
// RpcActor is used to do rpc with other processes in distributed execution.
|
||||
// Besides data arrows and controlling arrows, RpcActor also has inter-process arrows which is in charge of remote
|
||||
// communication with other processes. It supports both sync and async communication.
|
||||
|
@ -82,6 +85,15 @@ class RpcActor : public KernelActor {
|
|||
const std::string &dst_node_name) {}
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @description: Copy rpc data with size and update the input data's address with offset.
|
||||
* @param {RpcDataPtr} *rpc_data: Destination data address which will be updated in this method.
|
||||
* @param {void} *src_data: Source data address.
|
||||
* @param {size_t} src_data_size: Source data size and the offset.
|
||||
* @return {bool}: Whether data is successfully copied.
|
||||
*/
|
||||
bool CopyRpcDataWithOffset(RpcDataPtr *rpc_data, const void *src_data, size_t src_data_size) const;
|
||||
|
||||
// The op context to run rpc actor inter-process op. Set by method 'SetOpcontext'.
|
||||
OpContext<DeviceTensor> *op_context_;
|
||||
|
||||
|
|
|
@ -97,82 +97,25 @@ void SendActor::EraseInput(const OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
}
|
||||
|
||||
void SendActor::SerializeDynamicShapeMessgae(std::string *msg_body, const ShapeVector &shape_vec,
|
||||
const TypeId &data_type, const kernel::AddressPtr &addr) const {
|
||||
MS_EXCEPTION_IF_NULL(msg_body);
|
||||
MS_EXCEPTION_IF_NULL(addr);
|
||||
|
||||
rpc::DynamicShapeMessage pb_msg;
|
||||
pb_msg.set_type_id(static_cast<int>(data_type));
|
||||
*pb_msg.mutable_shape_vector() = {shape_vec.begin(), shape_vec.end()};
|
||||
std::string pb_msg_str = pb_msg.SerializeAsString();
|
||||
|
||||
// 1. Magic header for dynamic shape.
|
||||
(void)msg_body->append(kRpcDynamicShapeData);
|
||||
// 2. The size of the protobuf message DynamicShapeMessage.
|
||||
size_t pb_msg_size = pb_msg_str.size();
|
||||
(void)msg_body->append(reinterpret_cast<char *>(&pb_msg_size), sizeof(pb_msg_size));
|
||||
// 3. Protobuf message DynamicShapeMessage.
|
||||
(void)msg_body->append(pb_msg_str);
|
||||
// 4. The real data buffer of the input.
|
||||
(void)msg_body->append(static_cast<char *>(addr->addr), addr->size);
|
||||
}
|
||||
|
||||
std::unique_ptr<MessageBase> SendActor::BuildRpcMessage(const kernel::AddressPtrList &data_list,
|
||||
const std::string &server_url) {
|
||||
std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(message, nullptr);
|
||||
message->to = AID("", server_url);
|
||||
|
||||
// To reach optimal performance, we use workspace memory as the data sent to the remote. So the size must be
|
||||
// strictly checked to avoid illegal memory access.
|
||||
auto send_workspace = launch_info_.workspaces_;
|
||||
if (send_workspace.empty()) {
|
||||
MS_LOG(EXCEPTION) << "RpcSendKernel's workspace should not be empty.";
|
||||
}
|
||||
// Only use one piece of workspace memory to avoid extra memory copying and serialize inputs data to one message.
|
||||
auto workspace_addr = send_workspace[kIndex0];
|
||||
if (is_dynamic_shape_) {
|
||||
MS_LOG(INFO) << "This send actor builds message with dynamic shape.";
|
||||
size_t input_size = common::AnfAlgo::GetInputTensorNum(kernel_);
|
||||
for (size_t i = 0; i < input_size; i++) {
|
||||
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel_, i, false);
|
||||
auto real_input = input_node_with_index.first;
|
||||
auto real_input_index = input_node_with_index.second;
|
||||
MS_EXCEPTION_IF_NULL(real_input);
|
||||
|
||||
auto shapes = trans::GetRuntimePaddingShape(real_input, real_input_index);
|
||||
for (const auto &shape : shapes) {
|
||||
MS_LOG(INFO) << "Shape of input " << real_input->fullname_with_scope() << " is " << shape;
|
||||
}
|
||||
TypeId data_type = common::AnfAlgo::GetOutputInferDataType(real_input, real_input_index);
|
||||
|
||||
// Serialize the message body and append the data.
|
||||
SerializeDynamicShapeMessgae(&message->body, shapes, data_type, data_list[i]);
|
||||
}
|
||||
SerializeDynamicShapeMessage(message.get(), data_list, workspace_addr);
|
||||
} else {
|
||||
size_t total_size = 0;
|
||||
total_size =
|
||||
std::accumulate(data_list.begin(), data_list.end(), total_size,
|
||||
[](size_t total_size, const kernel::AddressPtr &output) { return total_size + output->size; });
|
||||
auto send_workspace = launch_info_.workspaces_;
|
||||
if (send_workspace.empty()) {
|
||||
MS_LOG(EXCEPTION) << "RpcSendKernel workspace should not be empty.";
|
||||
}
|
||||
if (send_workspace[0]->size != total_size) {
|
||||
MS_LOG(EXCEPTION) << "Workspace size should be the same as inputs size. But got " << send_workspace[0]->size
|
||||
<< " and " << total_size;
|
||||
}
|
||||
|
||||
if (common::GetEnv("use_void").empty()) {
|
||||
message->body.reserve(total_size);
|
||||
for (const auto &data : data_list) {
|
||||
(void)message->body.append(static_cast<char *>(data->addr), data->size);
|
||||
}
|
||||
} else {
|
||||
size_t offset = 0;
|
||||
for (const auto &data : data_list) {
|
||||
if (EOK !=
|
||||
memcpy_s(static_cast<char *>(send_workspace[0]->addr) + offset, data->size, data->addr, data->size)) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s for send data failed.";
|
||||
}
|
||||
offset += data->size;
|
||||
}
|
||||
message->data = send_workspace[0]->addr;
|
||||
message->size = total_size;
|
||||
}
|
||||
SerializeCommonMessage(message.get(), data_list, workspace_addr);
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
@ -195,5 +138,123 @@ std::vector<DeviceTensor *> SendActor::FindDeviceTensorNeedsFree(void *data) {
|
|||
}
|
||||
return free_list;
|
||||
}
|
||||
|
||||
void SendActor::SerializeDynamicShapeMessgae(std::string *msg_body, const ShapeVector &shape_vec,
|
||||
const TypeId &data_type, const kernel::AddressPtr &addr) const {
|
||||
MS_EXCEPTION_IF_NULL(msg_body);
|
||||
MS_EXCEPTION_IF_NULL(addr);
|
||||
|
||||
rpc::DynamicShapeMessage pb_msg;
|
||||
pb_msg.set_type_id(static_cast<int>(data_type));
|
||||
*pb_msg.mutable_shape_vector() = {shape_vec.begin(), shape_vec.end()};
|
||||
std::string pb_msg_str = pb_msg.SerializeAsString();
|
||||
|
||||
// 1. Magic header for dynamic shape.
|
||||
(void)msg_body->append(kRpcDynamicShapeData);
|
||||
// 2. The size of the protobuf message DynamicShapeMessage.
|
||||
size_t pb_msg_size = pb_msg_str.size();
|
||||
(void)msg_body->append(reinterpret_cast<RpcDataPtr>(&pb_msg_size), sizeof(pb_msg_size));
|
||||
// 3. Protobuf message DynamicShapeMessage.
|
||||
(void)msg_body->append(pb_msg_str);
|
||||
// 4. The real data buffer of the input.
|
||||
(void)msg_body->append(static_cast<RpcDataPtr>(addr->addr), addr->size);
|
||||
}
|
||||
|
||||
size_t SendActor::SerializeSingleDynamicShapeInput(RpcDataPtr rpc_data, const ShapeVector &shape_vec,
|
||||
const TypeId &data_type, const kernel::AddressPtr &addr) const {
|
||||
MS_EXCEPTION_IF_NULL(rpc_data);
|
||||
MS_EXCEPTION_IF_NULL(addr);
|
||||
|
||||
// The serialize data size needs to be computed.
|
||||
size_t serialized_data_size = 0;
|
||||
|
||||
// Serialize data's meta info to protobuffer.
|
||||
rpc::DynamicShapeMessage pb_msg;
|
||||
pb_msg.set_type_id(static_cast<int>(data_type));
|
||||
*pb_msg.mutable_shape_vector() = {shape_vec.begin(), shape_vec.end()};
|
||||
std::string pb_msg_str = pb_msg.SerializeAsString();
|
||||
|
||||
// Part 1. Magic header for dynamic shape.
|
||||
size_t header_size = strlen(kRpcDynamicShapeData);
|
||||
if (!CopyRpcDataWithOffset(&rpc_data, kRpcDynamicShapeData, header_size)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to copy data for kRpcDynamicShapeData.";
|
||||
}
|
||||
serialized_data_size += header_size;
|
||||
|
||||
// Part 2. The size of the protobuf message DynamicShapeMessage.
|
||||
size_t pb_msg_size = pb_msg_str.size();
|
||||
if (!CopyRpcDataWithOffset(&rpc_data, &pb_msg_size, sizeof(pb_msg_size))) {
|
||||
MS_LOG(EXCEPTION) << "Failed to copy data for protobuffer data's size.";
|
||||
}
|
||||
serialized_data_size += sizeof(pb_msg_size);
|
||||
|
||||
// Part 3. Protobuf message DynamicShapeMessage.
|
||||
if (!CopyRpcDataWithOffset(&rpc_data, pb_msg_str.c_str(), pb_msg_str.size())) {
|
||||
MS_LOG(EXCEPTION) << "Failed to copy data for protobuffer data.";
|
||||
}
|
||||
serialized_data_size += pb_msg_str.size();
|
||||
|
||||
// Part 4. The real data buffer of the input.
|
||||
if (!CopyRpcDataWithOffset(&rpc_data, addr->addr, addr->size)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to copy data for real input data.";
|
||||
}
|
||||
serialized_data_size += addr->size;
|
||||
|
||||
return serialized_data_size;
|
||||
}
|
||||
|
||||
void SendActor::SerializeDynamicShapeMessage(MessageBase *message, const kernel::AddressPtrList &data_list,
|
||||
const kernel::AddressPtr &workspace_addr) const {
|
||||
size_t offset = 0;
|
||||
RpcDataPtr rpc_data = static_cast<RpcDataPtr>(workspace_addr->addr);
|
||||
size_t input_size = common::AnfAlgo::GetInputTensorNum(kernel_);
|
||||
for (size_t i = 0; i < input_size; i++) {
|
||||
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel_, i, false);
|
||||
auto real_input = input_node_with_index.first;
|
||||
auto real_input_index = input_node_with_index.second;
|
||||
MS_EXCEPTION_IF_NULL(real_input);
|
||||
|
||||
auto shapes = trans::GetRuntimePaddingShape(real_input, real_input_index);
|
||||
TypeId data_type = common::AnfAlgo::GetOutputInferDataType(real_input, real_input_index);
|
||||
|
||||
if (common::GetEnv("use_void").empty()) {
|
||||
SerializeDynamicShapeMessgae(&message->body, shapes, data_type, data_list[i]);
|
||||
} else {
|
||||
size_t serialized_data_size =
|
||||
SerializeSingleDynamicShapeInput(rpc_data + offset, shapes, data_type, data_list[i]);
|
||||
offset += serialized_data_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SendActor::SerializeCommonMessage(MessageBase *message, const kernel::AddressPtrList &data_list,
|
||||
const kernel::AddressPtr &workspace_addr) const {
|
||||
size_t total_size = 0;
|
||||
total_size =
|
||||
std::accumulate(data_list.begin(), data_list.end(), total_size,
|
||||
[](size_t total_size, const kernel::AddressPtr &output) { return total_size + output->size; });
|
||||
|
||||
if (common::GetEnv("use_void").empty()) {
|
||||
message->body.reserve(total_size);
|
||||
for (const auto &data : data_list) {
|
||||
(void)message->body.append(static_cast<RpcDataPtr>(data->addr), data->size);
|
||||
}
|
||||
} else {
|
||||
if (workspace_addr->size != total_size) {
|
||||
MS_LOG(EXCEPTION) << "Workspace size should be the same as inputs size. But got " << workspace_addr->size
|
||||
<< " and " << total_size;
|
||||
}
|
||||
|
||||
RpcDataPtr rpc_data = static_cast<RpcDataPtr>(workspace_addr->addr);
|
||||
for (size_t i = 0; i < data_list.size(); i++) {
|
||||
if (!CopyRpcDataWithOffset(&rpc_data, data_list[i]->addr, data_list[i]->size)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to copy data for rpc send input " << i;
|
||||
}
|
||||
}
|
||||
message->data = workspace_addr->addr;
|
||||
message->size = workspace_addr->size;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -63,10 +63,9 @@ class SendActor : public RpcActor {
|
|||
*/
|
||||
virtual bool FreeMessage(void *data);
|
||||
|
||||
// The tcp client connection to multiple servers.
|
||||
std::unique_ptr<TCPClient> client_;
|
||||
|
||||
OpContext<DeviceTensor> *context_;
|
||||
|
||||
private:
|
||||
/**
|
||||
* @description: Find the memory list needs to be freed after the data is sent to remote. This should be called by
|
||||
|
@ -82,8 +81,46 @@ class SendActor : public RpcActor {
|
|||
void SerializeDynamicShapeMessgae(std::string *msg_body, const ShapeVector &shape_vec, const TypeId &data_type,
|
||||
const kernel::AddressPtr &addr) const;
|
||||
|
||||
/**
|
||||
* @description: Serialize one dynamic shape input data to a piece of memory and returns the serialized data
|
||||
* size for accessing memory by offset.
|
||||
* The format is shown below:
|
||||
* |--------22 bytes------|---4 bytes--|PB data size bytes| data size bytes |
|
||||
* |RPC_DYNAMIC_SHAPE_DATA|PB data size| PB data | real data |
|
||||
* @param {RpcDataPtr} &rpc_data: A piece of memory which is allocated by the caller for serialized data to copy to.
|
||||
* @param {ShapeVector} &shape_vec: Input data's shape vector.
|
||||
* @param {TypeId} &data_type: Input data's type.
|
||||
* @param {AddressPtr} &addr: Input data's address and size.
|
||||
* @return {size_t}: Size of the serialized data.
|
||||
*/
|
||||
size_t SerializeSingleDynamicShapeInput(RpcDataPtr rpc_data, const ShapeVector &shape_vec, const TypeId &data_type,
|
||||
const kernel::AddressPtr &addr) const;
|
||||
|
||||
/**
|
||||
* @description: Serialize message with dynamic shape data. For each input in dynamic shape scenario, extra meta info
|
||||
* like data shape, data type will be serialized as protobuffer and copied to message.
|
||||
* @param {MessageBase} *message: MessageBase object.
|
||||
* @param {AddressPtrList} &data_list: The inputs data of rpc send kernel.
|
||||
* @return {void}
|
||||
*/
|
||||
void SerializeDynamicShapeMessage(MessageBase *message, const kernel::AddressPtrList &data_list,
|
||||
const kernel::AddressPtr &workspace_addr) const;
|
||||
|
||||
/**
|
||||
* @description: Serialize common message without extra info, which means: the data of raw pointer will be directly
|
||||
* copied to the message.
|
||||
* @param {MessageBase} *message: MessageBase object.
|
||||
* @param {AddressPtrList} &data_list: The inputs data of rpc send kernel.
|
||||
* @return {void}
|
||||
*/
|
||||
void SerializeCommonMessage(MessageBase *message, const kernel::AddressPtrList &data_list,
|
||||
const kernel::AddressPtr &workspace_addr) const;
|
||||
|
||||
friend class GraphScheduler;
|
||||
|
||||
// OpC ontext passed by graph scheduler.
|
||||
OpContext<DeviceTensor> *context_;
|
||||
|
||||
// This send actor's destination peers' actor ids and route table.
|
||||
std::vector<std::string> peer_actor_ids_;
|
||||
mindspore::HashMap<std::string, std::string> peer_actor_urls_;
|
||||
|
|
Loading…
Reference in New Issue