Commit b86ce1e832 (forked from mindspore-Ecosystem/mindspore)

@@ -107,7 +107,6 @@ def build_op(build_type, json_str, tune_mode=None):
         op_module = __import__("impl.dynamic." + op_name, globals(), locals(), [op_name], 0)
         op_module_name = "impl.dynamic." + op_name
     else:
-        # op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
         op_module_name = "impl." + op_name
     # get function
     if build_type == op_build:

@@ -143,7 +143,6 @@ def single_to_fusion(json_file, tune_mode):
         "l1_size": -1,
         "op_list": ops
     }
-    # op_info = {"fusion_op": end_file}
     res = json.dumps(end_file, ensure_ascii=False)
     return res
 

@@ -25,6 +25,7 @@
 
 namespace mindspore {
 namespace opt {
+constexpr size_t kInsertIdx = 3;
 const BaseRef InsertPlaceholderForDynamicRNN::DefinePattern() const {
   std::shared_ptr<Var> V = std::make_shared<CondVar>(UnVisited);
   std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();

@@ -52,7 +53,7 @@ const AnfNodePtr InsertPlaceholderForDynamicRNN::Process(const FuncGraphPtr &fun
   std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
   for (size_t in_idx = 0; in_idx < input_num; in_idx++) {
     auto input_node = AnfAlgo::GetInputNode(cnode, in_idx);
-    if (in_idx == 3) {
+    if (in_idx == kInsertIdx) {
       auto value = std::make_shared<None>();
       auto value_node = NewValueNode(value);
       value_node->set_abstract(std::make_shared<abstract::AbstractNone>());

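The hunk above, together with the kInsertIdx constant introduced earlier in the same file, replaces a bare 3 with a named input slot. Below is a minimal standalone sketch of the pattern with illustrative names rather than MindSpore's node types; whether the None placeholder replaces or precedes the original fourth input is not visible in the hunk, so this sketch inserts before it.

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Named slot: one place to change if the operator's signature moves.
constexpr size_t kInsertIdx = 3;

int main() {
  std::vector<std::string> inputs = {"x", "w", "b", "h0", "c0"};
  std::vector<std::string> new_inputs = {"primitive"};
  for (size_t in_idx = 0; in_idx < inputs.size(); in_idx++) {
    if (in_idx == kInsertIdx) {
      new_inputs.push_back("<None>");  // placeholder for the optional input slot (assumed semantics)
    }
    new_inputs.push_back(inputs[in_idx]);
  }
  for (const auto &s : new_inputs) {
    std::cout << s << ' ';
  }
  std::cout << '\n';  // primitive x w b <None> h0 c0
}
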
@@ -26,6 +26,7 @@ namespace opt {
 namespace {
 constexpr size_t kDynamicRNNGradInputNum = 16;
 constexpr size_t kSplitVOutputNum = 2;
+constexpr size_t kBasicCellOutputNum = 2;
 
 void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                      std::vector<std::vector<AnfNodePtr>> *result_nodes) {

@@ -66,8 +67,8 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
   // Create split
   std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
   auto split_v = func_graph->NewCNode(splitv_input);
-  auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2);
-  auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 3);
+  auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex2);
+  auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex3);
   std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[kDim1], origin_output2_shape[kDim2]};
   std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[kDim0], origin_output3_shape[kDim1]};
   AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},

@@ -210,7 +211,8 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
     AnfAlgo::CopyNodeAttrs(basic_lstm_cell_c_state_grad_nodes[i], basic_lstm_cell_c_state_grad);
     // Create outputs for current basic_lstm_cell_c_state_grad node
     std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_outputs;
-    CreateMultipleOutputsOfAnfNode(func_graph, basic_lstm_cell_c_state_grad, 2, &basic_lstm_cell_c_state_grad_outputs);
+    CreateMultipleOutputsOfAnfNode(func_graph, basic_lstm_cell_c_state_grad, kBasicCellOutputNum,
+                                   &basic_lstm_cell_c_state_grad_outputs);
     pre_basic_lstm_cell_c_state_grad_outputs = basic_lstm_cell_c_state_grad_outputs;
 
     // Create MatMul

@@ -232,7 +234,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
 
     // Create outputs for current split node
     std::vector<AnfNodePtr> split_outputs;
-    CreateMultipleOutputsOfAnfNode(func_graph, split_v, 2, &split_outputs);
+    CreateMultipleOutputsOfAnfNode(func_graph, split_v, kSplitVOutputNum, &split_outputs);
     pre_split_outputs = split_outputs;
 
     lstm_x_concat_input[idx + 1] = split_outputs[0];

@@ -376,7 +378,8 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
   auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
   // Create reshape to change shape
   std::vector<size_t> shape_tmp;
-  if (origin_input4_shape.size() == 3) {
+  constexpr size_t kShapeDimNum = 3;
+  if (origin_input4_shape.size() == kShapeDimNum) {
     shape_tmp = origin_input4_shape;
   } else {
     shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};

@@ -506,7 +509,7 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
   MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
   if (dynamic_rnn_grad_cnode->inputs().size() < kDynamicRNNGradInputNum + 1) {
     MS_LOG(INFO) << "The node " << dynamic_rnn_grad_cnode->DebugString() << " has less than "
-                 << kDynamicRNNGradInputNum + 1 << " inputs";
+                 << (kDynamicRNNGradInputNum + 1) << " inputs";
     return nullptr;
   }
   std::vector<AnfNodePtr> new_outputs;

@@ -30,18 +30,22 @@
 using mindspore::abstract::Shape;
 namespace mindspore {
 namespace trans {
+const int b1 = 1;
+const int b2 = 2;
+const int b4 = 4;
+const int b8 = 8;
 inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
   switch (size) {
-    case 1:
+    case b1:
       static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
       break;
-    case 2:
+    case b2:
       static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
       break;
-    case 4:
+    case b4:
       static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
       break;
-    case 8:
+    case b8:
       static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
       break;
    default:

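For context, SetData keys on the element's byte width and moves raw bits through the fixed-width unsigned type of matching size, which is why one helper covers every 1-, 2-, 4-, and 8-byte dtype (fp16, fp32, int64, ...) without ever interpreting values. A stripped-down sketch of the same idea follows; it uses memcpy for strictly portable bit copies where the source casts through typed pointers.

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>

// Copy (or zero) one element of width sizeof(U), treating it purely as raw bits.
template <typename U>
void CopyOrZero(const void *src, void *dst, size_t src_idx, size_t dst_idx, bool pad_zero) {
  U bits{0};
  if (!pad_zero) {
    std::memcpy(&bits, static_cast<const char *>(src) + src_idx * sizeof(U), sizeof(U));
  }
  std::memcpy(static_cast<char *>(dst) + dst_idx * sizeof(U), &bits, sizeof(U));
}

int main() {
  float src[2] = {1.5F, -2.0F};
  float dst[2] = {0.0F, 0.0F};
  CopyOrZero<uint32_t>(src, dst, 1, 0, false);  // moves the bit pattern of -2.0f into dst[0]
  std::cout << dst[0] << '\n';                  // prints -2
}
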
@@ -357,47 +361,47 @@ std::vector<int64_t> Nc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape
 
 std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
   // NCDHW
-  if (shape.size() != 5) {
+  if (shape.size() != kNcdhw) {
     MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
   }
   std::vector<size_t> device_shape;
   const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
   const size_t C0 = kCubeSize;
-  device_shape.push_back(shape[0]);
-  device_shape.push_back(shape[2]);
+  device_shape.push_back(shape[N_ncdhw]);
+  device_shape.push_back(shape[D_ncdhw]);
   device_shape.push_back(C1);
-  device_shape.push_back(shape[3]);
-  device_shape.push_back(shape[4]);
+  device_shape.push_back(shape[H_ncdhw]);
+  device_shape.push_back(shape[W_ncdhw]);
   device_shape.push_back(C0);
   return device_shape;
 }
 
 std::vector<int64_t> Ndc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
   // NCDHW
-  if (shape.size() != 5) {
+  if (shape.size() != kNcdhw) {
     MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
   }
   std::vector<int64_t> device_shape;
   const int64_t C1 = (shape[1] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[1] + kCubeSize - 1) / kCubeSize;
   const int64_t C0 = kCubeSize;
-  device_shape.push_back(shape[0]);
-  device_shape.push_back(shape[2]);
+  device_shape.push_back(shape[N_ncdhw]);
+  device_shape.push_back(shape[D_ncdhw]);
   device_shape.push_back(C1);
-  device_shape.push_back(shape[3]);
-  device_shape.push_back(shape[4]);
+  device_shape.push_back(shape[H_ncdhw]);
+  device_shape.push_back(shape[W_ncdhw]);
   device_shape.push_back(C0);
   return device_shape;
 }
 
 std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
   // NCDHW -> Frac_Z_3D
-  if (shape.size() != 5) {
+  if (shape.size() != kNcdhw) {
     MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
   }
   std::vector<size_t> device_shape;
   const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
   const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
-  device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]);
+  device_shape.push_back(shape[D_ncdhw] * C1 * shape[H_ncdhw] * shape[W_ncdhw]);
   device_shape.push_back(N1);
   device_shape.push_back(kCubeSize);
   device_shape.push_back(kCubeSize);

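The renamed indices spell out the layout math: NCDHW becomes NDC1HWC0 by splitting C into C1 blocks of C0 = kCubeSize channels. A standalone recomputation as a sketch, assuming kCubeSize is 16 (consistent with the c0 = 16 literal that appears later in this file):

#include <cstddef>
#include <iostream>
#include <vector>

constexpr size_t kCubeSize = 16;  // assumed value, matching the c0 = 16 literal in trans.cc
enum Axis5D : int { N_ncdhw = 0, C_ncdhw, D_ncdhw, H_ncdhw, W_ncdhw };

// NCDHW host shape -> NDC1HWC0 device shape, as in Ndc1hwc0DeviceShape.
std::vector<size_t> Ndc1hwc0(const std::vector<size_t> &shape) {
  const size_t C1 = (shape[C_ncdhw] + kCubeSize - 1) / kCubeSize;  // ceil(C / 16)
  return {shape[N_ncdhw], shape[D_ncdhw], C1, shape[H_ncdhw], shape[W_ncdhw], kCubeSize};
}

int main() {
  for (size_t d : Ndc1hwc0({2, 35, 4, 5, 6})) std::cout << d << ' ';  // 2 4 3 5 6 16
  std::cout << '\n';
}
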
@@ -406,15 +410,15 @@ std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
 
 std::vector<int64_t> Fracz3DDeviceDynamicShape(const std::vector<int64_t> &shape) {
   // NCDHW -> Frac_Z_3D
-  if (shape.size() != 5) {
+  if (shape.size() != kNcdhw) {
     MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
   }
   std::vector<int64_t> device_shape;
-  if (HasShapeDynamic({shape[1], shape[2], shape[3], shape[4]})) {
+  if (HasShapeDynamic({shape[C_ncdhw], shape[D_ncdhw], shape[H_ncdhw], shape[W_ncdhw]})) {
     device_shape.push_back(Shape::SHP_ANY);
   } else {
     const int64_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
-    device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]);
+    device_shape.push_back(shape[D_ncdhw] * C1 * shape[H_ncdhw] * shape[W_ncdhw]);
   }
 
   const int64_t N1 = (shape[0] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[0] + kCubeSize - 1) / kCubeSize;

@@ -1558,7 +1562,7 @@ bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
   MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
   MS_EXCEPTION_IF_NULL(result);
 
-  if (args.host_shape.size() != 5) {
+  if (args.host_shape.size() != kNcdhw) {
     MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
     return false;
   }

@@ -1572,13 +1576,13 @@ bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
   }
-  auto n = args.host_shape[0];
-  auto c = args.host_shape[1];
-  auto d = args.host_shape[2];
-  auto h = args.host_shape[3];
-  auto w = args.host_shape[4];
-  auto c1 = args.device_shape[2];
-  auto c0 = args.device_shape[5];
+  auto n = args.host_shape[N_ncdhw];
+  auto c = args.host_shape[C_ncdhw];
+  auto d = args.host_shape[D_ncdhw];
+  auto h = args.host_shape[H_ncdhw];
+  auto w = args.host_shape[W_ncdhw];
+  auto c1 = args.device_shape[C1_ndc1hwc0];
+  auto c0 = args.device_shape[C0_ndc1hwc0];
   const size_t cdhw = c * d * h * w;
   const size_t dhw = d * h * w;
   const size_t hw = h * w;

@@ -1613,7 +1617,7 @@ bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
   MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
   MS_EXCEPTION_IF_NULL(result);
 
-  if (args.host_shape.size() != 5) {
+  if (args.host_shape.size() != kNcdhw) {
     MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
     return false;
   }

@@ -1628,11 +1632,11 @@ bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
     return false;
   }
 
-  auto n = args.host_shape[0];
-  auto c = args.host_shape[1];
-  auto d = args.host_shape[2];
-  auto h = args.host_shape[3];
-  auto w = args.host_shape[4];
+  auto n = args.host_shape[N_ncdhw];
+  auto c = args.host_shape[C_ncdhw];
+  auto d = args.host_shape[D_ncdhw];
+  auto h = args.host_shape[H_ncdhw];
+  auto w = args.host_shape[W_ncdhw];
   auto c0 = kCubeSize;
   auto c1 = DivCeil(c, c0);
   const size_t cdhw = c * d * h * w;

@@ -1672,7 +1676,7 @@ bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
   MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
   MS_EXCEPTION_IF_NULL(result);
 
-  if (args.host_shape.size() != 5) {
+  if (args.host_shape.size() != kNcdhw) {
     MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
     return false;
   }

@@ -1687,11 +1691,11 @@ bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
     return false;
   }
 
-  auto n = args.host_shape[0];
-  auto c = args.host_shape[1];
-  auto d = args.host_shape[2];
-  auto h = args.host_shape[3];
-  auto w = args.host_shape[4];
+  auto n = args.host_shape[N_ncdhw];
+  auto c = args.host_shape[C_ncdhw];
+  auto d = args.host_shape[D_ncdhw];
+  auto h = args.host_shape[H_ncdhw];
+  auto w = args.host_shape[W_ncdhw];
 
   auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
   const size_t c0 = 16;

@@ -1728,7 +1732,7 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
   MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw";
   MS_EXCEPTION_IF_NULL(result);
 
-  if (args.host_shape.size() != 5) {
+  if (args.host_shape.size() != kNcdhw) {
     MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
     return false;
   }

@@ -1742,12 +1746,13 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
   }
-  auto n = args.host_shape[0];
-  auto c = args.host_shape[1];
-  auto d = args.host_shape[2];
-  auto h = args.host_shape[3];
-  auto w = args.host_shape[4];
-  auto c0 = args.device_shape[3];
+  auto n = args.host_shape[N_ncdhw];
+  auto c = args.host_shape[C_ncdhw];
+  auto d = args.host_shape[D_ncdhw];
+  auto h = args.host_shape[H_ncdhw];
+  auto w = args.host_shape[W_ncdhw];
+  const int kFZ3D_C0 = 3;
+  auto c0 = args.device_shape[kFZ3D_C0];
   auto c1 = DivCeil(c, kCubeSize);
   auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
   auto n1n0c0 = n1n0 * c0;

@@ -31,14 +31,22 @@
 
 namespace mindspore {
 namespace trans {
-enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNcdhw };
+enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims };
 enum Axis5D : int {
   N_ncdhw = 0,
   C_ncdhw,
   D_ncdhw,
   H_ncdhw,
   W_ncdhw,
+  kNcdhw,
+  N_ndc1hwc0 = 0,
+  D_ndc1hwc0,
+  C1_ndc1hwc0,
+  H_ndc1hwc0,
+  W_ndc1hwc0,
+  C0_ndc1hwc0
 };
 
 struct TypeIdArgs {
   const void *data;
   size_t host_shape_size;  // Multiply each dimension elements. [a, b, c, d] => a*b*c*d

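This header hunk defines the names used by the trans.cc changes above. Two properties the call sites rely on: kNcdhw is the sixth enumerator of Axis5D, so its value is 5 and it doubles as the expected NCDHW rank in the size checks; and the *_ndc1hwc0 enumerators restart at 0 (duplicate values are legal in a C++ enum), giving axis positions in the six-dimensional device layout. A compile-time check of both, as a standalone sketch:

enum Axis5D : int {
  N_ncdhw = 0, C_ncdhw, D_ncdhw, H_ncdhw, W_ncdhw,
  kNcdhw,  // == 5, reused as the NCDHW rank
  N_ndc1hwc0 = 0, D_ndc1hwc0, C1_ndc1hwc0, H_ndc1hwc0, W_ndc1hwc0, C0_ndc1hwc0
};
static_assert(kNcdhw == 5, "kNcdhw doubles as the NCDHW rank");
static_assert(C1_ndc1hwc0 == 2 && C0_ndc1hwc0 == 5, "NDC1HWC0 axis positions used by Ndc1hwc0ToNcdhw");

int main() { return 0; }
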
@@ -118,25 +126,25 @@ std::vector<T> PaddingShapeTo5dDefault(const std::vector<T> &shape) {
   }
   std::vector<T> shape_5d(kNcdhw, 1);
   switch (shape.size()) {
-    case 0:
+    case N_ncdhw:
       return shape_5d;
-    case 1:
-      shape_5d[1] = shape[0];
+    case C_ncdhw:
+      shape_5d[C_ncdhw] = shape[N_ncdhw];
       break;
-    case 2:
-      shape_5d[1] = shape[0];
-      shape_5d[2] = shape[1];
+    case D_ncdhw:
+      shape_5d[C_ncdhw] = shape[N_ncdhw];
+      shape_5d[D_ncdhw] = shape[C_ncdhw];
       break;
-    case 3:
-      shape_5d[1] = shape[0];
-      shape_5d[2] = shape[1];
-      shape_5d[3] = shape[2];
+    case H_ncdhw:
+      shape_5d[C_ncdhw] = shape[N_ncdhw];
+      shape_5d[D_ncdhw] = shape[C_ncdhw];
+      shape_5d[H_ncdhw] = shape[D_ncdhw];
       break;
-    case 4:
-      shape_5d[1] = shape[0];
-      shape_5d[2] = shape[1];
-      shape_5d[3] = shape[2];
-      shape_5d[4] = shape[3];
+    case W_ncdhw:
+      shape_5d[C_ncdhw] = shape[N_ncdhw];
+      shape_5d[D_ncdhw] = shape[C_ncdhw];
+      shape_5d[H_ncdhw] = shape[D_ncdhw];
+      shape_5d[W_ncdhw] = shape[H_ncdhw];
       break;
     default:
       MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();

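One readability quirk of reusing Axis5D in this switch: the case labels read as axis names but act as ranks, e.g. case C_ncdhw: is the rank-1 case only because C_ncdhw == 1. Behaviorally the switch copies a rank-k shape (k < 5) into the axes starting at C and leaves the rest at 1, so {7, 9} pads to {1, 7, 9, 1, 1}. An equivalent loop form, as a sketch rather than the project's code:

#include <cstddef>
#include <iostream>
#include <vector>

// Equivalent restatement of PaddingShapeTo5dDefault's switch for rank < 5.
std::vector<int> PadTo5d(const std::vector<int> &shape) {
  std::vector<int> out(5, 1);
  for (size_t i = 0; i < shape.size() && i < 4; ++i) {
    out[i + 1] = shape[i];  // fill C, D, H, W in order; N stays 1
  }
  return out;
}

int main() {
  for (int d : PadTo5d({7, 9})) std::cout << d << ' ';  // 1 7 9 1 1
  std::cout << '\n';
}
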
@@ -148,21 +156,21 @@ template <typename T>
 std::vector<T> PaddingShapeTo4dDefault(const std::vector<T> &shape) {
   std::vector<T> shape_4d(kNchwDims, 1);
   switch (shape.size()) {
-    case 0:
+    case kN:
       return shape_4d;
-    case 1:
+    case kC:
       shape_4d[kC] = shape[kN];
       break;
-    case 2:
+    case kH:
       shape_4d[kC] = shape[kN];
       shape_4d[kH] = shape[kC];
       break;
-    case 3:
+    case kW:
       shape_4d[kC] = shape[kN];
       shape_4d[kH] = shape[kC];
       shape_4d[kW] = shape[kH];
       break;
-    case 4:
+    case kNchwDims:
       std::copy(shape.begin(), shape.end(), shape_4d.begin());
       break;
     default:

@@ -30,13 +30,10 @@ std::optional<std::string> Common::GetRealPath(const std::string &input_path) {
   if (input_path.length() >= PATH_MAX) {
     MS_LOG(EXCEPTION) << "The length of path: " << input_path << " exceeds limit: " << PATH_MAX;
   }
-#if defined(SYSTEM_ENV_POSIX)
-  size_t path_split_pos = input_path.find_last_of('/');
-#elif defined(SYSTEM_ENV_WINDOWS)
-  size_t path_split_pos = input_path.find_last_of('\\');
-#else
-  MS_LOG(EXCEPTION) << "Unsupported platform.";
-#endif
+  auto path_split_pos = input_path.find_last_of('/');
+  if (path_split_pos == std::string::npos) {
+    path_split_pos = input_path.find_last_of('\\');
+  }
   // get real path
   char real_path[PATH_MAX] = {0};
   // input_path is dir + file_name

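The GetRealPath hunk swaps compile-time platform selection for a runtime fallback: look for the last '/', and only if none exists try '\\'. That removes the unsupported-platform branch and lets either separator convention split on any build. A sketch of the resulting behavior; note that a path containing both separators now splits at the last '/', whereas the old Windows build split at the last '\\':

#include <iostream>
#include <string>

// Runtime fallback: prefer '/', fall back to '\\'.
std::string::size_type SplitPos(const std::string &path) {
  auto pos = path.find_last_of('/');
  if (pos == std::string::npos) {
    pos = path.find_last_of('\\');
  }
  return pos;  // std::string::npos if the path has no directory part
}

int main() {
  std::cout << SplitPos("a/b/c.txt") << ' ' << SplitPos("a\\b\\c.txt") << '\n';  // 3 3
}
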
@@ -29,6 +29,7 @@ namespace ops {
 namespace {
 constexpr size_t kLenLogProbs = 3;
 constexpr size_t kLenTarget = 2;
+constexpr int64_t kMulti = 2;
 constexpr size_t kInputSize = 4;
 abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
                                             const std::vector<AbstractBasePtr> &input_args) {

@@ -54,7 +55,7 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
 
   ShapeVector output_shape;
   std::vector<int64_t> out_dim0 = {N};
-  std::vector<int64_t> out_dim1 = {N, T, 2 * S + 1};
+  std::vector<int64_t> out_dim1 = {N, T, kMulti * S + 1};
   abstract::ShapePtr neg_log_shape = std::make_shared<abstract::Shape>(out_dim0);
   abstract::ShapePtr log_alpha_shape = std::make_shared<abstract::Shape>(out_dim1);
   return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{neg_log_shape, log_alpha_shape});

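kMulti names the 2 in 2 * S + 1: CTC's forward variable log_alpha runs over the target sequence extended with blanks (blank, l1, blank, ..., lS, blank), which has 2 * S + 1 entries, hence the last axis of out_dim1. An illustrative construction of that extended sequence, not MindSpore code:

#include <iostream>
#include <vector>

// Extend labels with blanks: blank, l1, blank, l2, ..., lS, blank.
std::vector<int> ExtendWithBlanks(const std::vector<int> &labels, int blank = 0) {
  std::vector<int> ext = {blank};
  for (int l : labels) {
    ext.push_back(l);
    ext.push_back(blank);
  }
  return ext;  // size() == 2 * labels.size() + 1
}

int main() {
  std::cout << ExtendWithBlanks({3, 1, 4}).size() << '\n';  // 7 == 2 * 3 + 1
}
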
@@ -28,6 +28,7 @@ namespace ops {
 namespace {
 constexpr size_t kLenLogProbs = 3;
 constexpr size_t kInputSize = 7;
+constexpr size_t kIdx2 = 2;
 abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);

|
@ -43,7 +44,7 @@ abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
|
||||||
}
|
}
|
||||||
int64_t T = log_probs_shape[0];
|
int64_t T = log_probs_shape[0];
|
||||||
int64_t N = log_probs_shape[1];
|
int64_t N = log_probs_shape[1];
|
||||||
int64_t C = log_probs_shape[2];
|
int64_t C = log_probs_shape[kIdx2];
|
||||||
ShapeVector output_shape = {N, T, C};
|
ShapeVector output_shape = {N, T, C};
|
||||||
return std::make_shared<abstract::Shape>(output_shape);
|
return std::make_shared<abstract::Shape>(output_shape);
|
||||||
}
|
}
|
||||||
|
|