forked from mindspore-Ecosystem/mindspore
commit b86ce1e832
@@ -107,7 +107,6 @@ def build_op(build_type, json_str, tune_mode=None):
            op_module = __import__("impl.dynamic." + op_name, globals(), locals(), [op_name], 0)
            op_module_name = "impl.dynamic." + op_name
        else:
            # op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
            op_module_name = "impl." + op_name
        # get function
        if build_type == op_build:

@@ -143,7 +143,6 @@ def single_to_fusion(json_file, tune_mode):
        "l1_size": -1,
        "op_list": ops
    }
    # op_info = {"fusion_op": end_file}
    res = json.dumps(end_file, ensure_ascii=False)
    return res

@@ -25,6 +25,7 @@

namespace mindspore {
namespace opt {
constexpr size_t kInsertIdx = 3;
const BaseRef InsertPlaceholderForDynamicRNN::DefinePattern() const {
  std::shared_ptr<Var> V = std::make_shared<CondVar>(UnVisited);
  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();

@@ -52,7 +53,7 @@ const AnfNodePtr InsertPlaceholderForDynamicRNN::Process(const FuncGraphPtr &fun
  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  for (size_t in_idx = 0; in_idx < input_num; in_idx++) {
    auto input_node = AnfAlgo::GetInputNode(cnode, in_idx);
    if (in_idx == 3) {
    if (in_idx == kInsertIdx) {
      auto value = std::make_shared<None>();
      auto value_node = NewValueNode(value);
      value_node->set_abstract(std::make_shared<abstract::AbstractNone>());

@@ -26,6 +26,7 @@ namespace opt {
namespace {
constexpr size_t kDynamicRNNGradInputNum = 16;
constexpr size_t kSplitVOutputNum = 2;
constexpr size_t kBasicCellOutputNum = 2;

void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                     std::vector<std::vector<AnfNodePtr>> *result_nodes) {

@@ -66,8 +67,8 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
    // Create split
    std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
    auto split_v = func_graph->NewCNode(splitv_input);
    auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2);
    auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 3);
    auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex2);
    auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex3);
    std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[kDim1], origin_output2_shape[kDim2]};
    std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[kDim0], origin_output3_shape[kDim1]};
    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},

@@ -210,7 +211,8 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
    AnfAlgo::CopyNodeAttrs(basic_lstm_cell_c_state_grad_nodes[i], basic_lstm_cell_c_state_grad);
    // Create outputs for current basic_lstm_cell_c_state_grad node
    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_outputs;
    CreateMultipleOutputsOfAnfNode(func_graph, basic_lstm_cell_c_state_grad, 2, &basic_lstm_cell_c_state_grad_outputs);
    CreateMultipleOutputsOfAnfNode(func_graph, basic_lstm_cell_c_state_grad, kBasicCellOutputNum,
                                   &basic_lstm_cell_c_state_grad_outputs);
    pre_basic_lstm_cell_c_state_grad_outputs = basic_lstm_cell_c_state_grad_outputs;

    // Create MatMul

@@ -232,7 +234,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &

    // Create outputs for current split node
    std::vector<AnfNodePtr> split_outputs;
    CreateMultipleOutputsOfAnfNode(func_graph, split_v, 2, &split_outputs);
    CreateMultipleOutputsOfAnfNode(func_graph, split_v, kSplitVOutputNum, &split_outputs);
    pre_split_outputs = split_outputs;

    lstm_x_concat_input[idx + 1] = split_outputs[0];

@@ -376,7 +378,8 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
  auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
  // Create reshape to change shape
  std::vector<size_t> shape_tmp;
  if (origin_input4_shape.size() == 3) {
  constexpr size_t kShapeDimNum = 3;
  if (origin_input4_shape.size() == kShapeDimNum) {
    shape_tmp = origin_input4_shape;
  } else {
    shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};

@@ -506,7 +509,7 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  if (dynamic_rnn_grad_cnode->inputs().size() < kDynamicRNNGradInputNum + 1) {
    MS_LOG(INFO) << "The node " << dynamic_rnn_grad_cnode->DebugString() << " has less than "
                 << kDynamicRNNGradInputNum + 1 << " inputs";
                 << (kDynamicRNNGradInputNum + 1) << " inputs";
    return nullptr;
  }
  std::vector<AnfNodePtr> new_outputs;

@@ -30,18 +30,22 @@
using mindspore::abstract::Shape;
namespace mindspore {
namespace trans {
const int b1 = 1;
const int b2 = 2;
const int b4 = 4;
const int b8 = 8;
inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
  switch (size) {
    case 1:
    case b1:
      static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
      break;
    case 2:
    case b2:
      static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
      break;
    case 4:
    case b4:
      static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
      break;
    case 8:
    case b8:
      static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
      break;
    default:

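Note (not part of this diff): the switch above copies a single 1-, 2-, 4- or 8-byte element from args.data into result, or writes a zero of the same width when pad_zero is set. A rough standalone sketch of the same idea, with made-up names rather than MindSpore API:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Copy one elem_size-byte element from src[src_idx] to dst[dst_idx], or write
// zeros of that width when pad_zero is set (elem_size must be 1, 2, 4 or 8).
inline void CopyOrPadElement(size_t elem_size, bool pad_zero, size_t src_idx, size_t dst_idx,
                             const void *src, void *dst) {
  const uint64_t zero = 0;  // wide enough for the largest supported element
  const auto *from = pad_zero ? reinterpret_cast<const uint8_t *>(&zero)
                              : static_cast<const uint8_t *>(src) + src_idx * elem_size;
  std::memcpy(static_cast<uint8_t *>(dst) + dst_idx * elem_size, from, elem_size);
}
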
@@ -357,47 +361,47 @@ std::vector<int64_t> Nc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape

std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
  // NCDHW
  if (shape.size() != 5) {
  if (shape.size() != kNcdhw) {
    MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  }
  std::vector<size_t> device_shape;
  const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
  const size_t C0 = kCubeSize;
  device_shape.push_back(shape[0]);
  device_shape.push_back(shape[2]);
  device_shape.push_back(shape[N_ncdhw]);
  device_shape.push_back(shape[D_ncdhw]);
  device_shape.push_back(C1);
  device_shape.push_back(shape[3]);
  device_shape.push_back(shape[4]);
  device_shape.push_back(shape[H_ncdhw]);
  device_shape.push_back(shape[W_ncdhw]);
  device_shape.push_back(C0);
  return device_shape;
}

std::vector<int64_t> Ndc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
  // NCDHW
  if (shape.size() != 5) {
  if (shape.size() != kNcdhw) {
    MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  }
  std::vector<int64_t> device_shape;
  const int64_t C1 = (shape[1] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[1] + kCubeSize - 1) / kCubeSize;
  const int64_t C0 = kCubeSize;
  device_shape.push_back(shape[0]);
  device_shape.push_back(shape[2]);
  device_shape.push_back(shape[N_ncdhw]);
  device_shape.push_back(shape[D_ncdhw]);
  device_shape.push_back(C1);
  device_shape.push_back(shape[3]);
  device_shape.push_back(shape[4]);
  device_shape.push_back(shape[H_ncdhw]);
  device_shape.push_back(shape[W_ncdhw]);
  device_shape.push_back(C0);
  return device_shape;
}

std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
  // NCDHW -> Frac_Z_3D
  if (shape.size() != 5) {
  if (shape.size() != kNcdhw) {
    MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  }
  std::vector<size_t> device_shape;
  const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
  const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
  device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]);
  device_shape.push_back(shape[D_ncdhw] * C1 * shape[H_ncdhw] * shape[W_ncdhw]);
  device_shape.push_back(N1);
  device_shape.push_back(kCubeSize);
  device_shape.push_back(kCubeSize);

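For reference (illustrative only, not part of this change): Fracz3DDeviceShape above maps an NCDHW shape {N, C, D, H, W} to {D * C1 * H * W, N1, 16, 16}, with C1 = ceil(C / kCubeSize) and N1 = ceil(N / kCubeSize). A small worked example, assuming kCubeSize is 16 as used in trans.cc:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  constexpr size_t kCubeSize = 16;                       // assumed cube size
  const std::vector<size_t> ncdhw = {32, 35, 4, 3, 3};   // N, C, D, H, W
  const size_t C1 = (ncdhw[1] + kCubeSize - 1) / kCubeSize;  // ceil(35 / 16) = 3
  const size_t N1 = (ncdhw[0] + kCubeSize - 1) / kCubeSize;  // ceil(32 / 16) = 2
  const std::vector<size_t> frac_z_3d = {ncdhw[2] * C1 * ncdhw[3] * ncdhw[4],  // 4 * 3 * 3 * 3 = 108
                                         N1, kCubeSize, kCubeSize};
  for (auto dim : frac_z_3d) std::cout << dim << ' ';    // prints: 108 2 16 16
  std::cout << '\n';
  return 0;
}
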
@@ -406,15 +410,15 @@ std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {

std::vector<int64_t> Fracz3DDeviceDynamicShape(const std::vector<int64_t> &shape) {
  // NCDHW -> Frac_Z_3D
  if (shape.size() != 5) {
  if (shape.size() != kNcdhw) {
    MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  }
  std::vector<int64_t> device_shape;
  if (HasShapeDynamic({shape[1], shape[2], shape[3], shape[4]})) {
  if (HasShapeDynamic({shape[C_ncdhw], shape[D_ncdhw], shape[H_ncdhw], shape[W_ncdhw]})) {
    device_shape.push_back(Shape::SHP_ANY);
  } else {
    const int64_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
    device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]);
    device_shape.push_back(shape[D_ncdhw] * C1 * shape[H_ncdhw] * shape[W_ncdhw]);
  }

  const int64_t N1 = (shape[0] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[0] + kCubeSize - 1) / kCubeSize;

@@ -1558,7 +1562,7 @@ bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
  MS_EXCEPTION_IF_NULL(result);

  if (args.host_shape.size() != 5) {
  if (args.host_shape.size() != kNcdhw) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }

@@ -1572,13 +1576,13 @@ bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[0];
  auto c = args.host_shape[1];
  auto d = args.host_shape[2];
  auto h = args.host_shape[3];
  auto w = args.host_shape[4];
  auto c1 = args.device_shape[2];
  auto c0 = args.device_shape[5];
  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];
  auto c1 = args.device_shape[C1_ndc1hwc0];
  auto c0 = args.device_shape[C0_ndc1hwc0];
  const size_t cdhw = c * d * h * w;
  const size_t dhw = d * h * w;
  const size_t hw = h * w;

@@ -1613,7 +1617,7 @@ bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
  MS_EXCEPTION_IF_NULL(result);

  if (args.host_shape.size() != 5) {
  if (args.host_shape.size() != kNcdhw) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }

@@ -1628,11 +1632,11 @@ bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
    return false;
  }

  auto n = args.host_shape[0];
  auto c = args.host_shape[1];
  auto d = args.host_shape[2];
  auto h = args.host_shape[3];
  auto w = args.host_shape[4];
  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];
  auto c0 = kCubeSize;
  auto c1 = DivCeil(c, c0);
  const size_t cdhw = c * d * h * w;

@@ -1672,7 +1676,7 @@ bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
  MS_EXCEPTION_IF_NULL(result);

  if (args.host_shape.size() != 5) {
  if (args.host_shape.size() != kNcdhw) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }

@@ -1687,11 +1691,11 @@ bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
    return false;
  }

  auto n = args.host_shape[0];
  auto c = args.host_shape[1];
  auto d = args.host_shape[2];
  auto h = args.host_shape[3];
  auto w = args.host_shape[4];
  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];

  auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
  const size_t c0 = 16;

@@ -1728,7 +1732,7 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw";
  MS_EXCEPTION_IF_NULL(result);

  if (args.host_shape.size() != 5) {
  if (args.host_shape.size() != kNcdhw) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }

@@ -1742,12 +1746,13 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[0];
  auto c = args.host_shape[1];
  auto d = args.host_shape[2];
  auto h = args.host_shape[3];
  auto w = args.host_shape[4];
  auto c0 = args.device_shape[3];
  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];
  const int kFZ3D_C0 = 3;
  auto c0 = args.device_shape[kFZ3D_C0];
  auto c1 = DivCeil(c, kCubeSize);
  auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
  auto n1n0c0 = n1n0 * c0;

@@ -31,14 +31,22 @@

namespace mindspore {
namespace trans {
enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNcdhw };
enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims };
enum Axis5D : int {
  N_ncdhw = 0,
  C_ncdhw,
  D_ncdhw,
  H_ncdhw,
  W_ncdhw,
  kNcdhw,
  N_ndc1hwc0 = 0,
  D_ndc1hwc0,
  C1_ndc1hwc0,
  H_ndc1hwc0,
  W_ndc1hwc0,
  C0_ndc1hwc0
};

struct TypeIdArgs {
  const void *data;
  size_t host_shape_size;  // Multiply each dimension elements. [a, b, c, d] => a*b*c*d

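The Axis5D enumerators added above are index names for the two 5-D layouts used by the format-trans code: N_ncdhw..W_ncdhw index an NCDHW host shape (kNcdhw = 5 is its rank), and N_ndc1hwc0..C0_ndc1hwc0 index the corresponding NDC1HWC0 device shape. An illustrative sketch of that mapping (not MindSpore code; it assumes kCubeSize is 16 as used in trans.cc):

#include <cstddef>
#include <iostream>
#include <vector>

enum Axis5D : int {
  N_ncdhw = 0, C_ncdhw, D_ncdhw, H_ncdhw, W_ncdhw, kNcdhw,
  N_ndc1hwc0 = 0, D_ndc1hwc0, C1_ndc1hwc0, H_ndc1hwc0, W_ndc1hwc0, C0_ndc1hwc0
};

int main() {
  constexpr size_t kCubeSize = 16;                      // assumed cube size
  const std::vector<size_t> ncdhw = {2, 35, 4, 8, 8};   // N, C, D, H, W on the host side
  // NDC1HWC0 splits C into C1 groups of C0 (= kCubeSize) channels.
  std::vector<size_t> ndc1hwc0(C0_ndc1hwc0 + 1);
  ndc1hwc0[N_ndc1hwc0] = ncdhw[N_ncdhw];
  ndc1hwc0[D_ndc1hwc0] = ncdhw[D_ncdhw];
  ndc1hwc0[C1_ndc1hwc0] = (ncdhw[C_ncdhw] + kCubeSize - 1) / kCubeSize;  // ceil(35 / 16) = 3
  ndc1hwc0[H_ndc1hwc0] = ncdhw[H_ncdhw];
  ndc1hwc0[W_ndc1hwc0] = ncdhw[W_ncdhw];
  ndc1hwc0[C0_ndc1hwc0] = kCubeSize;
  for (auto dim : ndc1hwc0) std::cout << dim << ' ';    // prints: 2 4 3 8 8 16
  std::cout << '\n';
  return 0;
}
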
@@ -118,25 +126,25 @@ std::vector<T> PaddingShapeTo5dDefault(const std::vector<T> &shape) {
  }
  std::vector<T> shape_5d(kNcdhw, 1);
  switch (shape.size()) {
    case 0:
    case N_ncdhw:
      return shape_5d;
    case 1:
      shape_5d[1] = shape[0];
    case C_ncdhw:
      shape_5d[C_ncdhw] = shape[N_ncdhw];
      break;
    case 2:
      shape_5d[1] = shape[0];
      shape_5d[2] = shape[1];
    case D_ncdhw:
      shape_5d[C_ncdhw] = shape[N_ncdhw];
      shape_5d[D_ncdhw] = shape[C_ncdhw];
      break;
    case 3:
      shape_5d[1] = shape[0];
      shape_5d[2] = shape[1];
      shape_5d[3] = shape[2];
    case H_ncdhw:
      shape_5d[C_ncdhw] = shape[N_ncdhw];
      shape_5d[D_ncdhw] = shape[C_ncdhw];
      shape_5d[H_ncdhw] = shape[D_ncdhw];
      break;
    case 4:
      shape_5d[1] = shape[0];
      shape_5d[2] = shape[1];
      shape_5d[3] = shape[2];
      shape_5d[4] = shape[3];
    case W_ncdhw:
      shape_5d[C_ncdhw] = shape[N_ncdhw];
      shape_5d[D_ncdhw] = shape[C_ncdhw];
      shape_5d[H_ncdhw] = shape[D_ncdhw];
      shape_5d[W_ncdhw] = shape[H_ncdhw];
      break;
    default:
      MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();

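In other words, the switch in PaddingShapeTo5dDefault places a rank-0..4 shape at indices 1..rank of a rank-5 shape and fills the remaining positions with 1 (so {a, b, c} becomes {1, a, b, c, 1}). A compact sketch of the same rule, illustrative only and not the MindSpore implementation:

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

// Pad a shape of rank 0..4 to rank 5: original dims go to indices 1..rank,
// every other position is filled with 1 (ranks above 4 are rejected,
// mirroring the MS_LOG(EXCEPTION) default branch above).
std::vector<size_t> PadTo5d(const std::vector<size_t> &shape) {
  if (shape.size() > 4) throw std::invalid_argument("unexpected shape size");
  std::vector<size_t> shape_5d(5, 1);
  for (size_t i = 0; i < shape.size(); ++i) {
    shape_5d[i + 1] = shape[i];
  }
  return shape_5d;
}

int main() {
  for (auto dim : PadTo5d({8, 16, 32})) std::cout << dim << ' ';  // prints: 1 8 16 32 1
  std::cout << '\n';
  return 0;
}
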
@@ -148,21 +156,21 @@ template <typename T>
std::vector<T> PaddingShapeTo4dDefault(const std::vector<T> &shape) {
  std::vector<T> shape_4d(kNchwDims, 1);
  switch (shape.size()) {
    case 0:
    case kN:
      return shape_4d;
    case 1:
    case kC:
      shape_4d[kC] = shape[kN];
      break;
    case 2:
    case kH:
      shape_4d[kC] = shape[kN];
      shape_4d[kH] = shape[kC];
      break;
    case 3:
    case kW:
      shape_4d[kC] = shape[kN];
      shape_4d[kH] = shape[kC];
      shape_4d[kW] = shape[kH];
      break;
    case 4:
    case kNchwDims:
      std::copy(shape.begin(), shape.end(), shape_4d.begin());
      break;
    default:

@@ -30,13 +30,10 @@ std::optional<std::string> Common::GetRealPath(const std::string &input_path) {
  if (input_path.length() >= PATH_MAX) {
    MS_LOG(EXCEPTION) << "The length of path: " << input_path << " exceeds limit: " << PATH_MAX;
  }
#if defined(SYSTEM_ENV_POSIX)
  size_t path_split_pos = input_path.find_last_of('/');
#elif defined(SYSTEM_ENV_WINDOWS)
  size_t path_split_pos = input_path.find_last_of('\\');
#else
  MS_LOG(EXCEPTION) << "Unsupported platform.";
#endif
  auto path_split_pos = input_path.find_last_of('/');
  if (path_split_pos == std::string::npos) {
    path_split_pos = input_path.find_last_of('\\');
  }
  // get real path
  char real_path[PATH_MAX] = {0};
  // input_path is dir + file_name

@@ -29,6 +29,7 @@ namespace ops {
namespace {
constexpr size_t kLenLogProbs = 3;
constexpr size_t kLenTarget = 2;
constexpr int64_t kMulti = 2;
constexpr size_t kInputSize = 4;
abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args) {

@@ -54,7 +55,7 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,

  ShapeVector output_shape;
  std::vector<int64_t> out_dim0 = {N};
  std::vector<int64_t> out_dim1 = {N, T, 2 * S + 1};
  std::vector<int64_t> out_dim1 = {N, T, kMulti * S + 1};
  abstract::ShapePtr neg_log_shape = std::make_shared<abstract::Shape>(out_dim0);
  abstract::ShapePtr log_alpha_shape = std::make_shared<abstract::Shape>(out_dim1);
  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{neg_log_shape, log_alpha_shape});

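As a quick illustration of the shapes inferred above (example values only, not part of the diff): with log_probs of shape (T, N, C) and a maximum target length S, the first output has shape {N} and log_alpha has shape {N, T, 2 * S + 1}:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t T = 50, N = 8, S = 12;   // example sizes: input length, batch, target length
  const int64_t kMulti = 2;              // as introduced by this change
  const std::vector<int64_t> out_dim0 = {N};                     // neg_log_likelihood: {8}
  const std::vector<int64_t> out_dim1 = {N, T, kMulti * S + 1};  // log_alpha: {8, 50, 25}
  std::cout << out_dim1[0] << ' ' << out_dim1[1] << ' ' << out_dim1[2] << '\n';
  return 0;
}
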
@@ -28,6 +28,7 @@ namespace ops {
namespace {
constexpr size_t kLenLogProbs = 3;
constexpr size_t kInputSize = 7;
constexpr size_t kIdx2 = 2;
abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);

@@ -43,7 +44,7 @@ abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
  }
  int64_t T = log_probs_shape[0];
  int64_t N = log_probs_shape[1];
  int64_t C = log_probs_shape[2];
  int64_t C = log_probs_shape[kIdx2];
  ShapeVector output_shape = {N, T, C};
  return std::make_shared<abstract::Shape>(output_shape);
}