forked from mindspore-Ecosystem/mindspore
support ncdhw <-> frac_z_3d data trans at host
This commit is contained in:
parent
7799c86797
commit
d4050cf98d
|
@ -200,7 +200,7 @@ size_t CubeSizeByType(const TypeId data_type) {
|
|||
namespace {
|
||||
bool CheckDims(const std::vector<size_t> &shape) {
|
||||
if (shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Host shape dims shoud be 4";
|
||||
MS_LOG(ERROR) << "Host shape dims should be 4";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
@ -370,7 +370,7 @@ std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape)
|
|||
std::copy(shape.begin(), shape.end(), shape_4d.begin());
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
|
||||
MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
|
||||
}
|
||||
return shape_4d;
|
||||
}
|
||||
|
@ -545,7 +545,8 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
|
|||
const std::map<std::string, FormatTransfer> format_trans_map{
|
||||
{kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw},
|
||||
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
|
||||
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw}};
|
||||
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw},
|
||||
{kOpFormat_FRACTAL_Z_3D, FracZ3DToNcdhw}};
|
||||
|
||||
MS_LOG(DEBUG) << "Start trans format.";
|
||||
if (abstract::TypeIdSize(args.src_data_type) < 1) {
|
||||
|
@ -1248,5 +1249,119 @@ bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
|
||||
if (args.host_shape.size() != 5) {
|
||||
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
|
||||
return false;
|
||||
}
|
||||
auto size = abstract::TypeIdSize(args.src_data_type);
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (total_size != args.device_size) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto n = args.host_shape[0];
|
||||
auto c = args.host_shape[1];
|
||||
auto d = args.host_shape[2];
|
||||
auto h = args.host_shape[3];
|
||||
auto w = args.host_shape[4];
|
||||
|
||||
auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
|
||||
auto c0 = CubeSizeByType(args.src_data_type);
|
||||
auto c1 = DivCeil(c, c0);
|
||||
auto hw = h * w;
|
||||
auto dhw = d * hw;
|
||||
auto cdhw = c * dhw;
|
||||
auto n1n0c0 = n1n0 * c0;
|
||||
auto wn1n0c0 = w * n1n0c0;
|
||||
auto hwn1n0c0 = h * wn1n0c0;
|
||||
auto dhwn1n0c0 = d * hwn1n0c0;
|
||||
|
||||
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
|
||||
for (size_t d_i = 0; d_i < d; d_i++) {
|
||||
for (size_t h_i = 0; h_i < h; h_i++) {
|
||||
for (size_t w_i = 0; w_i < w; w_i++) {
|
||||
for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
|
||||
for (size_t c0_i = 0; c0_i < c0; c0_i++) {
|
||||
size_t dst_i = c1_i * dhwn1n0c0 + d_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
|
||||
// ncdhw
|
||||
size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
|
||||
auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
|
||||
SetData(size, pad_zero, src_i, dst_i, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
|
||||
if (args.host_shape.size() != 5) {
|
||||
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
|
||||
return false;
|
||||
}
|
||||
auto size = abstract::TypeIdSize(args.src_data_type);
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (total_size != args.device_size) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
auto n = args.host_shape[0];
|
||||
auto c = args.host_shape[1];
|
||||
auto d = args.host_shape[2];
|
||||
auto h = args.host_shape[3];
|
||||
auto w = args.host_shape[4];
|
||||
auto n0 = args.device_shape[1];
|
||||
auto ni = args.device_shape[1];
|
||||
auto c0 = args.device_shape[3];
|
||||
auto hw = h * w;
|
||||
auto dhw = d * hw;
|
||||
auto cdhw = c * dhw;
|
||||
auto nc = ni * n0;
|
||||
auto ncc0 = nc * c0;
|
||||
auto wncc0 = w * ncc0;
|
||||
auto hwncc0 = h * wncc0;
|
||||
auto dhwncc0 = d * hwncc0;
|
||||
|
||||
for (size_t n_i = 0; n_i < n; n_i++) {
|
||||
size_t n_head = n_i * cdhw;
|
||||
for (size_t c_i = 0; c_i < c; c_i++) {
|
||||
size_t c_head = n_head + c_i * dhw;
|
||||
for (size_t d_i = 0; d_i < d; d_i++) {
|
||||
size_t d_head = c_head + d_i * hw;
|
||||
for (size_t h_i = 0; h_i < h; h_i++) {
|
||||
size_t h_head = d_head + h_i * w;
|
||||
for (size_t w_i = 0; w_i < w; w_i++) {
|
||||
size_t dst_i = h_head + w_i;
|
||||
size_t c1_i = c_i / c0;
|
||||
size_t c0_i = c_i % c0;
|
||||
size_t nc_i = n_i;
|
||||
size_t src_i = c1_i * dhwncc0 + d_i * hwncc0 + h_i * wncc0 + w_i * ncc0 + nc_i * c0 + c0_i;
|
||||
SetData(size, false, src_i, dst_i, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace trans
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -63,6 +63,7 @@ bool NchwTo4D(const FormatArgs &args, void *result);
|
|||
bool NchwToFracZ(const FormatArgs &args, void *result);
|
||||
bool NchwToFracNz(const FormatArgs &args, void *result);
|
||||
bool NchwToNc1hwc0(const FormatArgs &args, void *result);
|
||||
bool NcdhwToFracZ3D(const FormatArgs &args, void *result);
|
||||
bool NchwToFracZc04(const FormatArgs &args, void *result);
|
||||
bool NchwToNc1hwc04(const FormatArgs &args, void *result);
|
||||
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
|
||||
|
@ -74,6 +75,7 @@ bool FracZToNchw(const FormatArgs &args, void *result);
|
|||
bool FracNzToNchw(const FormatArgs &args, void *result);
|
||||
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
|
||||
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
|
||||
bool FracZ3DToNcdhw(const FormatArgs &args, void *result);
|
||||
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
|
||||
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result);
|
||||
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
|
||||
|
@ -81,7 +83,7 @@ const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{
|
|||
{kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz},
|
||||
{kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
|
||||
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
|
||||
{kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}};
|
||||
{kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}, {kOpFormat_FRACTAL_Z_3D, NcdhwToFracZ3D}};
|
||||
|
||||
} // namespace trans
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -93,9 +93,9 @@ namespace device {
|
|||
namespace ascend {
|
||||
const int FLOAT_LEN = sizeof(float);
|
||||
const int FLOAT16_LEN = 2; // sizeof(float16);
|
||||
const std::set<std::string> kOpNeedTransFormat = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0,
|
||||
kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
|
||||
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDC1HWC0};
|
||||
const std::set<std::string> kOpNeedTransFormat = {
|
||||
kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
|
||||
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
|
||||
|
||||
void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
|
@ -575,7 +575,8 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
|
|||
host_shape.emplace_back(1);
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || format_ == kOpFormat_NDC1HWC0) {
|
||||
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || format_ == kOpFormat_NDC1HWC0 ||
|
||||
format_ == kOpFormat_FRACTAL_Z_3D) {
|
||||
device_shape = trans::TransShapeToDevice(host_shape, format_);
|
||||
} else {
|
||||
host_shape = trans::PaddingShapeTo4d(host_shape);
|
||||
|
|
Loading…
Reference in New Issue