forked from OSSInnovation/mindspore
!414 add 6d format transfer
Merge pull request !414 from liubuyu/dev_lby
This commit is contained in:
commit
41c969ab00
|
@ -231,7 +231,98 @@ std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std
|
|||
return shape_4d;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool CheckDims(const std::vector<size_t> &shape) {
|
||||
if (shape.size() != 4) {
|
||||
MS_LOG(ERROR) << "Host shape dims shoud be 4";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Ccheck dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
device_shape.push_back(shape[0]);
|
||||
device_shape.push_back(shape[2]);
|
||||
device_shape.push_back(shape[3]);
|
||||
device_shape.push_back(shape[1]);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
device_shape.push_back(shape[2]);
|
||||
device_shape.push_back(shape[3]);
|
||||
device_shape.push_back(shape[1]);
|
||||
device_shape.push_back(shape[0]);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
|
||||
size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
|
||||
device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize);
|
||||
device_shape.push_back(cout16 / kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
|
||||
size_t C0 = kCubeSize;
|
||||
device_shape.push_back(shape[0]);
|
||||
device_shape.push_back(C1);
|
||||
device_shape.push_back(shape[2]);
|
||||
device_shape.push_back(shape[3]);
|
||||
device_shape.push_back(C0);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
device_shape.push_back((shape[1] - 1) / kCubeSize + 1);
|
||||
device_shape.push_back(shape[2]);
|
||||
device_shape.push_back(shape[3]);
|
||||
device_shape.push_back(shape[0]);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
|
||||
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
|
||||
const std::map<std::string, DeviceShapeTransfer> device_shape_map{
|
||||
{kOpFormat_NCHW, NchwDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape},
|
||||
{kOpFormat_HWCN, HwchDeviceShape}, {kOpFormat_FRAC_Z, FracZDeviceShape},
|
||||
{kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
|
||||
};
|
||||
|
||||
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
|
||||
return shape;
|
||||
}
|
||||
|
@ -255,37 +346,31 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
|
|||
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
|
||||
temp_shape = PaddingShapeTo4dByDefault(shape);
|
||||
}
|
||||
if (format == kOpFormat_NC1HWC0) {
|
||||
size_t C1 = (temp_shape[1] + kCubeSize - 1) / kCubeSize;
|
||||
size_t C0 = kCubeSize;
|
||||
device_shape.push_back(temp_shape[0]);
|
||||
device_shape.push_back(C1);
|
||||
device_shape.push_back(temp_shape[2]);
|
||||
device_shape.push_back(temp_shape[3]);
|
||||
device_shape.push_back(C0);
|
||||
return device_shape;
|
||||
} else if (format == kOpFormat_FRAC_Z) {
|
||||
size_t cout16 = ((temp_shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
|
||||
size_t cin16 = ((temp_shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
|
||||
device_shape.push_back(temp_shape[2] * temp_shape[3] * cin16 / kCubeSize);
|
||||
device_shape.push_back(cout16 / kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
} else if (format == kOpFormat_NHWC) {
|
||||
device_shape.push_back(temp_shape[0]);
|
||||
device_shape.push_back(temp_shape[2]);
|
||||
device_shape.push_back(temp_shape[3]);
|
||||
device_shape.push_back(temp_shape[1]);
|
||||
return device_shape;
|
||||
} else if (format == kOpFormat_HWCN) {
|
||||
return {temp_shape[2], temp_shape[3], temp_shape[1], temp_shape[0]};
|
||||
} else if (format == kOpFormat_NCHW) {
|
||||
return temp_shape;
|
||||
auto iter = device_shape_map.find(format);
|
||||
if (iter != device_shape_map.end()) {
|
||||
return iter->second(temp_shape);
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
|
||||
}
|
||||
|
||||
bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
|
||||
if (args.host_shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
*size = TypeIdSize(args.src_data_type);
|
||||
if (*size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
*total_size = ShapeSize(args.device_shape) * (*size);
|
||||
if (*total_size != args.device_size) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TransDataType(const TypeIdArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
|
||||
<< TypeIdLabel(args.device_data_type);
|
||||
|
@ -320,13 +405,14 @@ bool TransFormat(const FormatArgs &args, void *result) {
|
|||
MS_LOG(ERROR) << "Invalid datatype..";
|
||||
return false;
|
||||
}
|
||||
if ((args.host_format == kOpFormat_NCHW || args.host_format == kOpFormat_ND) &&
|
||||
args.device_format == kOpFormat_FRAC_Z) {
|
||||
if (args.device_format == kOpFormat_FRAC_Z) {
|
||||
return NchwToFracZ(args, result);
|
||||
} else if (args.device_format == kOpFormat_FRAC_NZ) {
|
||||
return NchwToFracNz(args, result);
|
||||
} else if (args.device_format == kOpFormat_NC1HWC0) {
|
||||
return NchwToNc1hwc0(args, result);
|
||||
} else if (args.device_format == kOpFormat_C1HWNCoC0) {
|
||||
return NchwToC1hwncoc0(args, result);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -337,13 +423,14 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
|
|||
MS_LOG(ERROR) << "Invalid datatype..";
|
||||
return false;
|
||||
}
|
||||
if ((args.host_format == kOpFormat_NCHW || args.host_format == kOpFormat_ND) &&
|
||||
args.device_format == kOpFormat_FRAC_Z) {
|
||||
if (args.device_format == kOpFormat_FRAC_Z) {
|
||||
return FracZToNchw(args, result);
|
||||
} else if (args.device_format == kOpFormat_FRAC_NZ) {
|
||||
return FracNzToNchw(args, result);
|
||||
} else if (args.device_format == kOpFormat_NC1HWC0) {
|
||||
return Nc1hwc0ToNchw(args, result);
|
||||
} else if (args.device_format == kOpFormat_C1HWNCoC0) {
|
||||
return C1hwncoc0ToNchw(args, result);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -801,5 +888,99 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
|
||||
// trans nchw to c1hwncoc0
|
||||
MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0.";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
size_t size = 0;
|
||||
size_t total_size = 0;
|
||||
if (!CheckArgs(args, &size, &total_size)) {
|
||||
MS_LOG(ERROR) << "Check args failed.";
|
||||
return false;
|
||||
}
|
||||
auto n = args.host_shape[0];
|
||||
auto c = args.host_shape[1];
|
||||
auto h = args.host_shape[2];
|
||||
auto w = args.host_shape[3];
|
||||
auto c1 = args.device_shape[0];
|
||||
auto co = args.device_shape[4];
|
||||
auto c0 = args.device_shape[5];
|
||||
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
|
||||
for (size_t h_i = 0; h_i < h; h_i++) {
|
||||
for (size_t w_i = 0; w_i < w; w_i++) {
|
||||
for (size_t n_i = 0; n_i < n; n_i++) {
|
||||
for (size_t co_i = 0; co_i < co; co_i++) {
|
||||
for (size_t c0_i = 0; c0_i < c0; c0_i++) {
|
||||
size_t dst_offset = (c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 +
|
||||
n_i * co * c0 + co_i * c0 + c0_i) *
|
||||
size;
|
||||
size_t protected_size = total_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
|
||||
? total_size - dst_offset
|
||||
: static_cast<size_t>(SECUREC_MEM_MAX_LEN);
|
||||
size_t c_i = c0_i + c1_i * c0;
|
||||
size_t src_offset = (n_i * c * h * w + c_i * h * w + h_i * w + w_i) * size;
|
||||
error_t ret;
|
||||
if (c_i < c && c0_i == co_i) {
|
||||
ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size,
|
||||
static_cast<uint8_t const *>(args.data) + src_offset, size);
|
||||
} else {
|
||||
ret = memset_s(static_cast<uint8_t *>(result) + dst_offset, protected_size, 0, size);
|
||||
}
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "Failed to operate the dst memory, error-code:" << ret;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
|
||||
// trans c1hwncoc0 to nchw
|
||||
MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
size_t size = 0;
|
||||
size_t total_size = 0;
|
||||
if (!CheckArgs(args, &size, &total_size)) {
|
||||
MS_LOG(ERROR) << "Check args failed.";
|
||||
return false;
|
||||
}
|
||||
auto n = args.host_shape[0];
|
||||
auto c = args.host_shape[1];
|
||||
auto h = args.host_shape[2];
|
||||
auto w = args.host_shape[3];
|
||||
auto co = args.device_shape[4];
|
||||
auto c0 = args.device_shape[5];
|
||||
for (size_t n_i = 0; n_i < n; n_i++) {
|
||||
for (size_t c_i = 0; c_i < c; c_i++) {
|
||||
for (size_t h_i = 0; h_i < h; h_i++) {
|
||||
for (size_t w_i = 0; w_i < w; w_i++) {
|
||||
size_t dst_offset = (n_i * c * h * w + c_i * h * w + h_i * w + w_i) * size;
|
||||
size_t c1_i = c_i / kCubeSize;
|
||||
size_t c0_i = c_i % kCubeSize;
|
||||
size_t co_i = c0_i;
|
||||
size_t src_offset = (c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 +
|
||||
co_i * c0 + c0_i) *
|
||||
size;
|
||||
size_t protected_size = total_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
|
||||
? total_size - dst_offset
|
||||
: static_cast<size_t>(SECUREC_MEM_MAX_LEN);
|
||||
auto ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size,
|
||||
static_cast<uint8_t const *>(args.data) + src_offset, size);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "Failed to operate the dst memory, error-code:" << ret;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace trans
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -63,10 +63,12 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result);
|
|||
bool NchwToFracZ(const FormatArgs &args, void *result);
|
||||
bool NchwToFracNz(const FormatArgs &args, void *result);
|
||||
bool NchwToNc1hwc0(const FormatArgs &args, void *result);
|
||||
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
|
||||
// device to host
|
||||
bool FracZToNchw(const FormatArgs &args, void *result);
|
||||
bool FracNzToNchw(const FormatArgs &args, void *result);
|
||||
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
|
||||
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
|
||||
} // namespace trans
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -114,8 +114,11 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
|
|||
return false;
|
||||
}
|
||||
}
|
||||
} else if (format_ == kOpFormat_NC1HWC0 || format_ == kOpFormat_FRAC_Z || format_ == kOpFormat_FRAC_NZ) {
|
||||
sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr);
|
||||
} else {
|
||||
auto iter = kNeedTransFormatSet.find(format_);
|
||||
if (iter != kNeedTransFormatSet.end()) {
|
||||
sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr);
|
||||
}
|
||||
}
|
||||
if (!sync_ok) {
|
||||
MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_)
|
||||
|
@ -199,8 +202,11 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
|
|||
}
|
||||
SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE);
|
||||
}
|
||||
} else if (format_ == kOpFormat_NC1HWC0 || format_ == kOpFormat_FRAC_Z || format_ == kOpFormat_FRAC_NZ) {
|
||||
sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr);
|
||||
} else {
|
||||
auto iter = kNeedTransFormatSet.find(format_);
|
||||
if (iter != kNeedTransFormatSet.end()) {
|
||||
sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr);
|
||||
}
|
||||
}
|
||||
if (!sync_ok) {
|
||||
MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_)
|
||||
|
|
|
@ -186,8 +186,10 @@ constexpr auto kOpFormat_FRAC_Z = "FracZ";
|
|||
constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
|
||||
constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
|
||||
constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
|
||||
const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
|
||||
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0};
|
||||
const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
|
||||
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
|
||||
kOpFormat_C1HWNCoC0};
|
||||
|
||||
const std::set<std::string> k2DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_Z,
|
||||
kOpFormat_NC1KHKWHWC0};
|
||||
const std::set<std::string> k3DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0};
|
||||
|
|
Loading…
Reference in New Issue