!487 add dtype trans template

Merge pull request !487 from liubuyu/master
This commit is contained in:
mindspore-ci-bot 2020-04-22 14:29:36 +08:00 committed by Gitee
commit 715c0735a8
3 changed files with 56 additions and 29 deletions

View File

@ -103,17 +103,39 @@ const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
template <typename SrcT, typename DstT>
void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) {
auto src_id = TypeIdSize(args.src_type);
auto dst_id = TypeIdSize(args.dst_type);
if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) {
MS_LOG(EXCEPTION) << "Invalid src or dst data size.";
}
for (size_t idx = 0; idx != data_size; idx++) {
SrcT src_data = static_cast<const SrcT *>(args.data)[idx];
static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data);
}
}
template <typename SrcT>
void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) {
auto src_id = TypeIdSize(args.src_type);
auto dst_id = TypeIdSize(args.dst_type);
if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) {
MS_LOG(EXCEPTION) << "Invalid src or dst data size.";
}
auto src_data = static_cast<const SrcT *>(args.data);
auto half_data = static_cast<Eigen::half *>(dst);
for (size_t i = 0; i < data_size; i++) {
half_data[i] = Eigen::half(src_data[i]);
}
}
bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) {
switch (mode) {
case FROM_FLOAT_TO_FLOAT16:
device::FloatToHalf(dst, args.data, data_size);
break;
case FROM_INT32_TO_FLOAT16:
TransDataSrc2Fp16<int32_t>(args, dst, data_size);
break;
case FROM_FLOAT16_TO_FLOAT:
device::HalfToFloat(dst, args.data, data_size);
break;
@ -372,27 +394,27 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
}
bool TransDataType(const TypeIdArgs &args, void *result) {
MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
<< TypeIdLabel(args.device_data_type);
MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.src_type) << " to " << TypeIdLabel(args.dst_type);
MS_EXCEPTION_IF_NULL(result);
std::pair<TypeId, TypeId> type_info(args.host_data_type, args.device_data_type);
std::pair<TypeId, TypeId> type_info(args.src_type, args.dst_type);
auto iter = mode_map.find(type_info);
if (iter == mode_map.end()) {
MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type)
<< ", dst_type:" << TypeIdLabel(args.device_data_type);
MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.src_type)
<< ", dst_type:" << TypeIdLabel(args.dst_type);
return false;
}
auto trans_mode = iter->second;
auto type_size = TypeIdSize(args.device_data_type);
if (type_size < 1) {
MS_LOG(ERROR) << "Invalid host data type.";
auto src_id = TypeIdSize(args.src_type);
auto dst_id = TypeIdSize(args.dst_type);
if (src_id < 1 || dst_id < 1) {
MS_LOG(ERROR) << "Invalid src or dst data type.";
return false;
}
if (args.host_shape_size < 1) {
MS_LOG(ERROR) << "Invalid host data size.";
if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) {
MS_LOG(ERROR) << "Invalid src or dst data size.";
return false;
}
if (!CastKernel(args, result, args.host_shape_size, trans_mode)) {
if (!CastKernel(args, result, args.dst_shape_size, trans_mode)) {
MS_LOG(ERROR) << "Failed to trans datatype..";
return false;
}

View File

@ -31,9 +31,12 @@ namespace mindspore {
namespace trans {
struct TypeIdArgs {
const void *data;
size_t host_shape_size; // Multiply each dimension elements. [a, b, c, d] => a*b*c*d
TypeId host_data_type;
TypeId device_data_type;
size_t src_size;
size_t dst_size;
TypeId src_type;
TypeId dst_type;
size_t src_shape_size;
size_t dst_shape_size;
};
struct FormatArgs {

View File

@ -104,10 +104,10 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
} else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_);
} else {
auto shape_size = trans::ShapeSize(host_shape);
auto host_size = trans::ShapeSize(host_shape);
auto host = std::vector<uint8_t>(size_);
SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type};
const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, host_size, host_size};
sync_ok = trans::TransDataType(type_args, host_ptr);
if (!sync_ok) {
MS_LOG(ERROR) << "trans data type failed.";
@ -153,14 +153,15 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
auto host = std::vector<uint8_t>(size_);
sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data());
if (!sync_ok) {
MS_LOG(ERROR) << "trans format failed.";
MS_LOG(ERROR) << "Trans format failed.";
return false;
}
auto shape_size = trans::ShapeSize(host_shape);
const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type};
auto host_size = trans::ShapeSize(host_shape);
auto device_size = trans::ShapeSize(device_shape);
const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, device_size, host_size};
sync_ok = trans::TransDataType(type_args, host_ptr);
if (!sync_ok) {
MS_LOG(ERROR) << "trans format failed.";
MS_LOG(ERROR) << "Trans format failed.";
return false;
}
} else {
@ -168,7 +169,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
host_shape, device_shape, type_id_};
sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr);
if (!sync_ok) {
MS_LOG(ERROR) << "trans format failed.";
MS_LOG(ERROR) << "Trans format failed.";
return false;
}
}
@ -192,12 +193,12 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
} else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size);
} else {
auto shape_size = trans::ShapeSize(host_shape);
const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_};
auto host_size = trans::ShapeSize(host_shape);
const trans::TypeIdArgs type_args{host_ptr, size, size_, type, type_id_, host_size, host_size};
auto host_tmp = std::vector<uint8_t>(size_);
sync_ok = trans::TransDataType(type_args, host_tmp.data());
if (!sync_ok) {
MS_LOG(ERROR) << "trans data type failed.";
MS_LOG(ERROR) << "Trans data type failed.";
return false;
}
SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE);
@ -234,12 +235,13 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
device_shape = trans::TransShapeToDevice(host_shape, format_);
}
if (type_id_ != type) {
auto shape_size = trans::ShapeSize(host_shape);
const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_};
auto host_size = trans::ShapeSize(host_shape);
auto device_size = trans::ShapeSize(device_shape);
const trans::TypeIdArgs type_args{host_ptr, size, size_, type, type_id_, host_size, device_size};
auto host_tmp = std::vector<uint8_t>(size_);
sync_ok = trans::TransDataType(type_args, host_tmp.data());
if (!sync_ok) {
MS_LOG(ERROR) << "trans datatype failed.";
MS_LOG(ERROR) << "Trans datatype failed.";
return false;
}
const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_,
@ -247,7 +249,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
auto dst_tmp = std::vector<uint8_t>(size_);
sync_ok = trans::TransFormat(format_args, dst_tmp.data());
if (!sync_ok) {
MS_LOG(ERROR) << "trans format failed.";
MS_LOG(ERROR) << "Trans format failed.";
return false;
}
SyncMemory(ptr_, dst_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE);
@ -256,7 +258,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
auto host_tmp = std::vector<uint8_t>(size_);
sync_ok = trans::TransFormat(format_args, host_tmp.data());
if (!sync_ok) {
MS_LOG(ERROR) << "trans format failed.";
MS_LOG(ERROR) << "Trans format failed.";
return false;
}
SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE);