Merge pull request !19875 from liubuyu/master
i-robot 2021-07-27 03:23:52 +00:00 committed by Gitee
commit b86ce1e832
9 changed files with 99 additions and 85 deletions


@@ -107,7 +107,6 @@ def build_op(build_type, json_str, tune_mode=None):
op_module = __import__("impl.dynamic." + op_name, globals(), locals(), [op_name], 0)
op_module_name = "impl.dynamic." + op_name
else:
- # op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
op_module_name = "impl." + op_name
# get function
if build_type == op_build:


@@ -143,7 +143,6 @@ def single_to_fusion(json_file, tune_mode):
"l1_size": -1,
"op_list": ops
}
- # op_info = {"fusion_op": end_file}
res = json.dumps(end_file, ensure_ascii=False)
return res


@@ -25,6 +25,7 @@
namespace mindspore {
namespace opt {
+ constexpr size_t kInsertIdx = 3;
const BaseRef InsertPlaceholderForDynamicRNN::DefinePattern() const {
std::shared_ptr<Var> V = std::make_shared<CondVar>(UnVisited);
std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();

@@ -52,7 +53,7 @@ const AnfNodePtr InsertPlaceholderForDynamicRNN::Process(const FuncGraphPtr &fun
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
for (size_t in_idx = 0; in_idx < input_num; in_idx++) {
auto input_node = AnfAlgo::GetInputNode(cnode, in_idx);
- if (in_idx == 3) {
+ if (in_idx == kInsertIdx) {
auto value = std::make_shared<None>();
auto value_node = NewValueNode(value);
value_node->set_abstract(std::make_shared<abstract::AbstractNone>());


@@ -26,6 +26,7 @@ namespace opt {
namespace {
constexpr size_t kDynamicRNNGradInputNum = 16;
constexpr size_t kSplitVOutputNum = 2;
+ constexpr size_t kBasicCellOutputNum = 2;
void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
std::vector<std::vector<AnfNodePtr>> *result_nodes) {

@@ -66,8 +67,8 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
// Create split
std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
auto split_v = func_graph->NewCNode(splitv_input);
- auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2);
- auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 3);
+ auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex2);
+ auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex3);
std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[kDim1], origin_output2_shape[kDim2]};
std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[kDim0], origin_output3_shape[kDim1]};
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},

@@ -210,7 +211,8 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
AnfAlgo::CopyNodeAttrs(basic_lstm_cell_c_state_grad_nodes[i], basic_lstm_cell_c_state_grad);
// Create outputs for current basic_lstm_cell_c_state_grad node
std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_outputs;
- CreateMultipleOutputsOfAnfNode(func_graph, basic_lstm_cell_c_state_grad, 2, &basic_lstm_cell_c_state_grad_outputs);
+ CreateMultipleOutputsOfAnfNode(func_graph, basic_lstm_cell_c_state_grad, kBasicCellOutputNum,
+ &basic_lstm_cell_c_state_grad_outputs);
pre_basic_lstm_cell_c_state_grad_outputs = basic_lstm_cell_c_state_grad_outputs;
// Create MatMul

@@ -232,7 +234,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
// Create outputs for current split node
std::vector<AnfNodePtr> split_outputs;
- CreateMultipleOutputsOfAnfNode(func_graph, split_v, 2, &split_outputs);
+ CreateMultipleOutputsOfAnfNode(func_graph, split_v, kSplitVOutputNum, &split_outputs);
pre_split_outputs = split_outputs;
lstm_x_concat_input[idx + 1] = split_outputs[0];

@@ -376,7 +378,8 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
// Create reshape to change shape
std::vector<size_t> shape_tmp;
- if (origin_input4_shape.size() == 3) {
+ constexpr size_t kShapeDimNum = 3;
+ if (origin_input4_shape.size() == kShapeDimNum) {
shape_tmp = origin_input4_shape;
} else {
shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};

@@ -506,7 +509,7 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
if (dynamic_rnn_grad_cnode->inputs().size() < kDynamicRNNGradInputNum + 1) {
MS_LOG(INFO) << "The node " << dynamic_rnn_grad_cnode->DebugString() << " has less than "
- << kDynamicRNNGradInputNum + 1 << " inputs";
+ << (kDynamicRNNGradInputNum + 1) << " inputs";
return nullptr;
}
std::vector<AnfNodePtr> new_outputs;


@@ -30,18 +30,22 @@
using mindspore::abstract::Shape;
namespace mindspore {
namespace trans {
+ const int b1 = 1;
+ const int b2 = 2;
+ const int b4 = 4;
+ const int b8 = 8;
inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
switch (size) {
- case 1:
+ case b1:
static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
break;
- case 2:
+ case b2:
static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
break;
- case 4:
+ case b4:
static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
break;
- case 8:
+ case b8:
static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
break;
default:

@@ -357,47 +361,47 @@ std::vector<int64_t> Nc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape
std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
// NCDHW
- if (shape.size() != 5) {
+ if (shape.size() != kNcdhw) {
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
}
std::vector<size_t> device_shape;
const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
const size_t C0 = kCubeSize;
- device_shape.push_back(shape[0]);
- device_shape.push_back(shape[2]);
+ device_shape.push_back(shape[N_ncdhw]);
+ device_shape.push_back(shape[D_ncdhw]);
device_shape.push_back(C1);
- device_shape.push_back(shape[3]);
- device_shape.push_back(shape[4]);
+ device_shape.push_back(shape[H_ncdhw]);
+ device_shape.push_back(shape[W_ncdhw]);
device_shape.push_back(C0);
return device_shape;
}
std::vector<int64_t> Ndc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
// NCDHW
- if (shape.size() != 5) {
+ if (shape.size() != kNcdhw) {
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
}
std::vector<int64_t> device_shape;
const int64_t C1 = (shape[1] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[1] + kCubeSize - 1) / kCubeSize;
const int64_t C0 = kCubeSize;
- device_shape.push_back(shape[0]);
- device_shape.push_back(shape[2]);
+ device_shape.push_back(shape[N_ncdhw]);
+ device_shape.push_back(shape[D_ncdhw]);
device_shape.push_back(C1);
- device_shape.push_back(shape[3]);
- device_shape.push_back(shape[4]);
+ device_shape.push_back(shape[H_ncdhw]);
+ device_shape.push_back(shape[W_ncdhw]);
device_shape.push_back(C0);
return device_shape;
}
std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
// NCDHW -> Frac_Z_3D
- if (shape.size() != 5) {
+ if (shape.size() != kNcdhw) {
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
}
std::vector<size_t> device_shape;
const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
- device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]);
+ device_shape.push_back(shape[D_ncdhw] * C1 * shape[H_ncdhw] * shape[W_ncdhw]);
device_shape.push_back(N1);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
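
For illustration only (not part of the patch), a minimal standalone C++ sketch of the NCDHW to NDC1HWC0 mapping that Ndc1hwc0DeviceShape above implements, assuming kCubeSize == 16 as elsewhere in this file; the shape values are made up:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const size_t kCubeSize = 16;
  // Hypothetical NCDHW host shape: N=32, C=48, D=8, H=28, W=28.
  std::vector<size_t> ncdhw{32, 48, 8, 28, 28};
  const size_t C1 = (ncdhw[1] + kCubeSize - 1) / kCubeSize;  // ceil(48 / 16) = 3
  const size_t C0 = kCubeSize;
  // Device shape in NDC1HWC0 order: {N, D, C1, H, W, C0}.
  std::vector<size_t> device{ncdhw[0], ncdhw[2], C1, ncdhw[3], ncdhw[4], C0};
  for (size_t d : device) std::cout << d << ' ';  // prints: 32 8 3 28 28 16
  std::cout << '\n';
  return 0;
}
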
@@ -406,15 +410,15 @@ std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
std::vector<int64_t> Fracz3DDeviceDynamicShape(const std::vector<int64_t> &shape) {
// NCDHW -> Frac_Z_3D
- if (shape.size() != 5) {
+ if (shape.size() != kNcdhw) {
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
}
std::vector<int64_t> device_shape;
- if (HasShapeDynamic({shape[1], shape[2], shape[3], shape[4]})) {
+ if (HasShapeDynamic({shape[C_ncdhw], shape[D_ncdhw], shape[H_ncdhw], shape[W_ncdhw]})) {
device_shape.push_back(Shape::SHP_ANY);
} else {
const int64_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
- device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]);
+ device_shape.push_back(shape[D_ncdhw] * C1 * shape[H_ncdhw] * shape[W_ncdhw]);
}
const int64_t N1 = (shape[0] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[0] + kCubeSize - 1) / kCubeSize;

@@ -1558,7 +1562,7 @@ bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
MS_EXCEPTION_IF_NULL(result);
- if (args.host_shape.size() != 5) {
+ if (args.host_shape.size() != kNcdhw) {
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
return false;
}

@@ -1572,13 +1576,13 @@ bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
return false;
}
- auto n = args.host_shape[0];
- auto c = args.host_shape[1];
- auto d = args.host_shape[2];
- auto h = args.host_shape[3];
- auto w = args.host_shape[4];
- auto c1 = args.device_shape[2];
- auto c0 = args.device_shape[5];
+ auto n = args.host_shape[N_ncdhw];
+ auto c = args.host_shape[C_ncdhw];
+ auto d = args.host_shape[D_ncdhw];
+ auto h = args.host_shape[H_ncdhw];
+ auto w = args.host_shape[W_ncdhw];
+ auto c1 = args.device_shape[C1_ndc1hwc0];
+ auto c0 = args.device_shape[C0_ndc1hwc0];
const size_t cdhw = c * d * h * w;
const size_t dhw = d * h * w;
const size_t hw = h * w;

@@ -1613,7 +1617,7 @@ bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
MS_EXCEPTION_IF_NULL(result);
- if (args.host_shape.size() != 5) {
+ if (args.host_shape.size() != kNcdhw) {
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
return false;
}

@@ -1628,11 +1632,11 @@ bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
return false;
}
- auto n = args.host_shape[0];
- auto c = args.host_shape[1];
- auto d = args.host_shape[2];
- auto h = args.host_shape[3];
- auto w = args.host_shape[4];
+ auto n = args.host_shape[N_ncdhw];
+ auto c = args.host_shape[C_ncdhw];
+ auto d = args.host_shape[D_ncdhw];
+ auto h = args.host_shape[H_ncdhw];
+ auto w = args.host_shape[W_ncdhw];
auto c0 = kCubeSize;
auto c1 = DivCeil(c, c0);
const size_t cdhw = c * d * h * w;

@@ -1672,7 +1676,7 @@ bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
MS_EXCEPTION_IF_NULL(result);
- if (args.host_shape.size() != 5) {
+ if (args.host_shape.size() != kNcdhw) {
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
return false;
}

@@ -1687,11 +1691,11 @@ bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
return false;
}
- auto n = args.host_shape[0];
- auto c = args.host_shape[1];
- auto d = args.host_shape[2];
- auto h = args.host_shape[3];
- auto w = args.host_shape[4];
+ auto n = args.host_shape[N_ncdhw];
+ auto c = args.host_shape[C_ncdhw];
+ auto d = args.host_shape[D_ncdhw];
+ auto h = args.host_shape[H_ncdhw];
+ auto w = args.host_shape[W_ncdhw];
auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
const size_t c0 = 16;

@@ -1728,7 +1732,7 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw";
MS_EXCEPTION_IF_NULL(result);
- if (args.host_shape.size() != 5) {
+ if (args.host_shape.size() != kNcdhw) {
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
return false;
}

@@ -1742,12 +1746,13 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
return false;
}
- auto n = args.host_shape[0];
- auto c = args.host_shape[1];
- auto d = args.host_shape[2];
- auto h = args.host_shape[3];
- auto w = args.host_shape[4];
- auto c0 = args.device_shape[3];
+ auto n = args.host_shape[N_ncdhw];
+ auto c = args.host_shape[C_ncdhw];
+ auto d = args.host_shape[D_ncdhw];
+ auto h = args.host_shape[H_ncdhw];
+ auto w = args.host_shape[W_ncdhw];
+ const int kFZ3D_C0 = 3;
+ auto c0 = args.device_shape[kFZ3D_C0];
auto c1 = DivCeil(c, kCubeSize);
auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
auto n1n0c0 = n1n0 * c0;
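
A similar standalone sketch (made-up values) following Fracz3DDeviceShape above: the FRACTAL_Z_3D device shape is {D*C1*H*W, N1, N0, C0}, which is why index 3, named kFZ3D_C0 in the hunk above, selects C0:

#include <cstddef>
#include <vector>

int main() {
  const size_t kCubeSize = 16;
  std::vector<size_t> ncdhw{32, 48, 8, 28, 28};              // N, C, D, H, W (hypothetical)
  const size_t C1 = (ncdhw[1] + kCubeSize - 1) / kCubeSize;  // 3
  const size_t N1 = (ncdhw[0] + kCubeSize - 1) / kCubeSize;  // 2
  std::vector<size_t> frac_z_3d{ncdhw[2] * C1 * ncdhw[3] * ncdhw[4], N1, kCubeSize, kCubeSize};
  const int kFZ3D_C0 = 3;                            // same index the patch names
  return frac_z_3d[kFZ3D_C0] == kCubeSize ? 0 : 1;   // C0 lives at index 3
}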


@@ -31,14 +31,22 @@
namespace mindspore {
namespace trans {
- enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNcdhw };
+ enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims };
enum Axis5D : int {
N_ncdhw = 0,
C_ncdhw,
D_ncdhw,
H_ncdhw,
W_ncdhw,
+ kNcdhw,
+ N_ndc1hwc0 = 0,
+ D_ndc1hwc0,
+ C1_ndc1hwc0,
+ H_ndc1hwc0,
+ W_ndc1hwc0,
+ C0_ndc1hwc0
};
struct TypeIdArgs {
const void *data;
size_t host_shape_size; // Multiply each dimension elements. [a, b, c, d] => a*b*c*d
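
The enumerators above restart at 0 on purpose: the *_ncdhw values index the 5-D NCDHW host shape, the *_ndc1hwc0 values index the 6-D NDC1HWC0 device shape, and kNcdhw now doubles as the expected NCDHW rank used by the trans.cc hunks. A minimal sketch (the enum is repeated here only for illustration):

enum Axis5D : int {
  N_ncdhw = 0, C_ncdhw, D_ncdhw, H_ncdhw, W_ncdhw, kNcdhw,                      // 5-D NCDHW host axes
  N_ndc1hwc0 = 0, D_ndc1hwc0, C1_ndc1hwc0, H_ndc1hwc0, W_ndc1hwc0, C0_ndc1hwc0  // 6-D NDC1HWC0 device axes
};
static_assert(kNcdhw == 5, "kNcdhw doubles as the expected NCDHW rank");
static_assert(C0_ndc1hwc0 == 5, "C0 is the last axis of the 6-D device shape");
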
@@ -118,25 +126,25 @@ std::vector<T> PaddingShapeTo5dDefault(const std::vector<T> &shape) {
}
std::vector<T> shape_5d(kNcdhw, 1);
switch (shape.size()) {
- case 0:
+ case N_ncdhw:
return shape_5d;
- case 1:
- shape_5d[1] = shape[0];
+ case C_ncdhw:
+ shape_5d[C_ncdhw] = shape[N_ncdhw];
break;
- case 2:
- shape_5d[1] = shape[0];
- shape_5d[2] = shape[1];
+ case D_ncdhw:
+ shape_5d[C_ncdhw] = shape[N_ncdhw];
+ shape_5d[D_ncdhw] = shape[C_ncdhw];
break;
- case 3:
- shape_5d[1] = shape[0];
- shape_5d[2] = shape[1];
- shape_5d[3] = shape[2];
+ case H_ncdhw:
+ shape_5d[C_ncdhw] = shape[N_ncdhw];
+ shape_5d[D_ncdhw] = shape[C_ncdhw];
+ shape_5d[H_ncdhw] = shape[D_ncdhw];
break;
- case 4:
- shape_5d[1] = shape[0];
- shape_5d[2] = shape[1];
- shape_5d[3] = shape[2];
- shape_5d[4] = shape[3];
+ case W_ncdhw:
+ shape_5d[C_ncdhw] = shape[N_ncdhw];
+ shape_5d[D_ncdhw] = shape[C_ncdhw];
+ shape_5d[H_ncdhw] = shape[D_ncdhw];
+ shape_5d[W_ncdhw] = shape[H_ncdhw];
break;
default:
MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
@@ -148,21 +156,21 @@ template <typename T>
std::vector<T> PaddingShapeTo4dDefault(const std::vector<T> &shape) {
std::vector<T> shape_4d(kNchwDims, 1);
switch (shape.size()) {
- case 0:
+ case kN:
return shape_4d;
- case 1:
+ case kC:
shape_4d[kC] = shape[kN];
break;
- case 2:
+ case kH:
shape_4d[kC] = shape[kN];
shape_4d[kH] = shape[kC];
break;
- case 3:
+ case kW:
shape_4d[kC] = shape[kN];
shape_4d[kH] = shape[kC];
shape_4d[kW] = shape[kH];
break;
- case 4:
+ case kNchwDims:
std::copy(shape.begin(), shape.end(), shape_4d.begin());
break;
default:


@@ -30,13 +30,10 @@ std::optional<std::string> Common::GetRealPath(const std::string &input_path) {
if (input_path.length() >= PATH_MAX) {
MS_LOG(EXCEPTION) << "The length of path: " << input_path << " exceeds limit: " << PATH_MAX;
}
- #if defined(SYSTEM_ENV_POSIX)
- size_t path_split_pos = input_path.find_last_of('/');
- #elif defined(SYSTEM_ENV_WINDOWS)
- size_t path_split_pos = input_path.find_last_of('\\');
- #else
- MS_LOG(EXCEPTION) << "Unsupported platform.";
- #endif
+ auto path_split_pos = input_path.find_last_of('/');
+ if (path_split_pos == std::string::npos) {
+ path_split_pos = input_path.find_last_of('\\');
+ }
// get real path
char real_path[PATH_MAX] = {0};
// input_path is dir + file_name
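
A standalone sketch (not the MindSpore helper itself) of the separator fallback this hunk switches to; it shows that a single code path now handles both POSIX and Windows style paths without platform #ifdefs:

#include <iostream>
#include <string>

int main() {
  for (const std::string path : {"/tmp/model.ckpt", "C:\\tmp\\model.ckpt"}) {
    // Try the POSIX separator first, then fall back to the Windows one.
    auto pos = path.find_last_of('/');
    if (pos == std::string::npos) {
      pos = path.find_last_of('\\');
    }
    std::cout << path.substr(0, pos) << " | " << path.substr(pos + 1) << '\n';
  }
  return 0;
}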


@@ -29,6 +29,7 @@ namespace ops {
namespace {
constexpr size_t kLenLogProbs = 3;
constexpr size_t kLenTarget = 2;
+ constexpr int64_t kMulti = 2;
constexpr size_t kInputSize = 4;
abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {

@@ -54,7 +55,7 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
ShapeVector output_shape;
std::vector<int64_t> out_dim0 = {N};
- std::vector<int64_t> out_dim1 = {N, T, 2 * S + 1};
+ std::vector<int64_t> out_dim1 = {N, T, kMulti * S + 1};
abstract::ShapePtr neg_log_shape = std::make_shared<abstract::Shape>(out_dim0);
abstract::ShapePtr log_alpha_shape = std::make_shared<abstract::Shape>(out_dim1);
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{neg_log_shape, log_alpha_shape});
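
A short sketch of where the kMulti * S + 1 width comes from: CTC interleaves a blank before, between, and after the S target symbols, so the log_alpha forward table needs 2 * S + 1 states per time step (the target length below is made up):

#include <cstdint>
#include <iostream>

int main() {
  const int64_t kMulti = 2;
  const int64_t S = 5;                  // hypothetical target length
  std::cout << kMulti * S + 1 << '\n';  // 11 alpha states per time step
  return 0;
}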


@@ -28,6 +28,7 @@ namespace ops {
namespace {
constexpr size_t kLenLogProbs = 3;
constexpr size_t kInputSize = 7;
+ constexpr size_t kIdx2 = 2;
abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

@@ -43,7 +44,7 @@ abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
}
int64_t T = log_probs_shape[0];
int64_t N = log_probs_shape[1];
- int64_t C = log_probs_shape[2];
+ int64_t C = log_probs_shape[kIdx2];
ShapeVector output_shape = {N, T, C};
return std::make_shared<abstract::Shape>(output_shape);
}