forked from mindspore-Ecosystem/mindspore
fix codedex warning
This commit is contained in:
parent
507b63ea20
commit
0b6b5e5123
|
@ -101,13 +101,20 @@ const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
|
|||
{std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32}};
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) {
|
||||
auto src_id = TypeIdSize(args.src_type);
|
||||
auto dst_id = TypeIdSize(args.dst_type);
|
||||
if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) {
|
||||
void CheckMemSize(const TypeIdArgs &args) {
|
||||
auto src_type_size = TypeIdSize(args.host_data_type);
|
||||
auto dst_type_size = TypeIdSize(args.device_data_type);
|
||||
if (src_type_size < 1 || dst_type_size < 1) {
|
||||
MS_LOG(EXCEPTION) << "Invalid src or dst data type.";
|
||||
}
|
||||
if (args.data_size / src_type_size != args.host_shape_size) {
|
||||
MS_LOG(EXCEPTION) << "Invalid src or dst data size.";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) {
|
||||
CheckMemSize(args);
|
||||
for (size_t idx = 0; idx != data_size; idx++) {
|
||||
SrcT src_data = static_cast<const SrcT *>(args.data)[idx];
|
||||
static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data);
|
||||
|
@ -116,11 +123,7 @@ void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size)
|
|||
|
||||
template <typename SrcT>
|
||||
void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) {
|
||||
auto src_id = TypeIdSize(args.src_type);
|
||||
auto dst_id = TypeIdSize(args.dst_type);
|
||||
if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) {
|
||||
MS_LOG(EXCEPTION) << "Invalid src or dst data size.";
|
||||
}
|
||||
CheckMemSize(args);
|
||||
auto src_data = static_cast<const SrcT *>(args.data);
|
||||
auto half_data = static_cast<Eigen::half *>(dst);
|
||||
for (size_t i = 0; i < data_size; i++) {
|
||||
|
@ -394,27 +397,18 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
|
|||
}
|
||||
|
||||
bool TransDataType(const TypeIdArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.src_type) << " to " << TypeIdLabel(args.dst_type);
|
||||
MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
|
||||
<< TypeIdLabel(args.device_data_type);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
std::pair<TypeId, TypeId> type_info(args.src_type, args.dst_type);
|
||||
std::pair<TypeId, TypeId> type_info(args.host_data_type, args.device_data_type);
|
||||
auto iter = mode_map.find(type_info);
|
||||
if (iter == mode_map.end()) {
|
||||
MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.src_type)
|
||||
<< ", dst_type:" << TypeIdLabel(args.dst_type);
|
||||
MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type)
|
||||
<< ", dst_type:" << TypeIdLabel(args.device_data_type);
|
||||
return false;
|
||||
}
|
||||
auto trans_mode = iter->second;
|
||||
auto src_id = TypeIdSize(args.src_type);
|
||||
auto dst_id = TypeIdSize(args.dst_type);
|
||||
if (src_id < 1 || dst_id < 1) {
|
||||
MS_LOG(ERROR) << "Invalid src or dst data type.";
|
||||
return false;
|
||||
}
|
||||
if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) {
|
||||
MS_LOG(ERROR) << "Invalid src or dst data size.";
|
||||
return false;
|
||||
}
|
||||
if (!CastKernel(args, result, args.dst_shape_size, trans_mode)) {
|
||||
if (!CastKernel(args, result, args.host_shape_size, trans_mode)) {
|
||||
MS_LOG(ERROR) << "Failed to trans datatype..";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -31,12 +31,10 @@ namespace mindspore {
|
|||
namespace trans {
|
||||
struct TypeIdArgs {
|
||||
const void *data;
|
||||
size_t src_size;
|
||||
size_t dst_size;
|
||||
TypeId src_type;
|
||||
TypeId dst_type;
|
||||
size_t src_shape_size;
|
||||
size_t dst_shape_size;
|
||||
size_t host_shape_size; // Multiply each dimension elements. [a, b, c, d] => a*b*c*d
|
||||
TypeId host_data_type;
|
||||
TypeId device_data_type;
|
||||
size_t data_size;
|
||||
};
|
||||
|
||||
struct FormatArgs {
|
||||
|
|
|
@ -104,10 +104,10 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
|
|||
} else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
|
||||
sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_);
|
||||
} else {
|
||||
auto host_size = trans::ShapeSize(host_shape);
|
||||
auto shape_size = trans::ShapeSize(host_shape);
|
||||
auto host = std::vector<uint8_t>(size_);
|
||||
SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
|
||||
const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, host_size, host_size};
|
||||
const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size};
|
||||
sync_ok = trans::TransDataType(type_args, host_ptr);
|
||||
if (!sync_ok) {
|
||||
MS_LOG(ERROR) << "trans data type failed.";
|
||||
|
@ -156,9 +156,8 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
|
|||
MS_LOG(ERROR) << "Trans format failed.";
|
||||
return false;
|
||||
}
|
||||
auto host_size = trans::ShapeSize(host_shape);
|
||||
auto device_size = trans::ShapeSize(device_shape);
|
||||
const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, device_size, host_size};
|
||||
auto shape_size = trans::ShapeSize(host_shape);
|
||||
const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size};
|
||||
sync_ok = trans::TransDataType(type_args, host_ptr);
|
||||
if (!sync_ok) {
|
||||
MS_LOG(ERROR) << "Trans format failed.";
|
||||
|
@ -193,8 +192,8 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
|
|||
} else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
|
||||
sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size);
|
||||
} else {
|
||||
auto host_size = trans::ShapeSize(host_shape);
|
||||
const trans::TypeIdArgs type_args{host_ptr, size, size_, type, type_id_, host_size, host_size};
|
||||
auto shape_size = trans::ShapeSize(host_shape);
|
||||
const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size};
|
||||
auto host_tmp = std::vector<uint8_t>(size_);
|
||||
sync_ok = trans::TransDataType(type_args, host_tmp.data());
|
||||
if (!sync_ok) {
|
||||
|
@ -235,9 +234,8 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
|
|||
device_shape = trans::TransShapeToDevice(host_shape, format_);
|
||||
}
|
||||
if (type_id_ != type) {
|
||||
auto host_size = trans::ShapeSize(host_shape);
|
||||
auto device_size = trans::ShapeSize(device_shape);
|
||||
const trans::TypeIdArgs type_args{host_ptr, size, size_, type, type_id_, host_size, device_size};
|
||||
auto shape_size = trans::ShapeSize(host_shape);
|
||||
const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size};
|
||||
auto host_tmp = std::vector<uint8_t>(size_);
|
||||
sync_ok = trans::TransDataType(type_args, host_tmp.data());
|
||||
if (!sync_ok) {
|
||||
|
|
Loading…
Reference in New Issue