SliceGrad unify mindir adpation for Cangjie frontend

This commit is contained in:
yuchaojie 2021-07-01 17:33:14 +08:00
parent 8195488630
commit 76075c0b36
1 changed files with 23 additions and 9 deletions

View File

@ -31,6 +31,7 @@ namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
constexpr size_t kSliceGradInputTensorNum = 4; constexpr size_t kSliceGradInputTensorNum = 4;
constexpr size_t kSliceGradCangjieInputTensorNum = 2;
std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) { std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
@ -50,11 +51,8 @@ std::vector<int64_t> GetTupleValue(const AnfNodePtr &node) {
} // namespace } // namespace
const BaseRef SliceGradUnifyMindIR::DefinePattern() const { const BaseRef SliceGradUnifyMindIR::DefinePattern() const {
VarPtr X1 = std::make_shared<Var>(); VarPtr Xs = std::make_shared<SeqVar>();
VarPtr X2 = std::make_shared<Var>(); VectorRef slice_grad({std::make_shared<Primitive>("SliceGrad"), Xs});
VarPtr X3 = std::make_shared<Var>();
VarPtr X4 = std::make_shared<Var>();
VectorRef slice_grad({std::make_shared<Primitive>("SliceGrad"), X1, X2, X3, X4});
return slice_grad; return slice_grad;
} }
@ -63,8 +61,16 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto slice_grad = CheckAnfNodeIfCNodeAndInputSize(node, kSliceGradInputTensorNum); auto slice_grad = node->cast<CNodePtr>();
std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)), slice_grad->input(1)}; MS_EXCEPTION_IF_NULL(slice_grad);
auto input_num = AnfAlgo::GetInputTensorNum(slice_grad);
if (input_num != kSliceGradInputTensorNum && input_num != kSliceGradCangjieInputTensorNum) {
MS_LOG(EXCEPTION) << "The input tensor size[" << input_num
<< "] of node " + slice_grad->DebugString() + " is not equal to " << kSliceGradInputTensorNum
<< " or " << kSliceGradCangjieInputTensorNum;
}
std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)),
slice_grad->input(kIndex1)};
auto pad = graph->NewCNode(pad_inputs); auto pad = graph->NewCNode(pad_inputs);
MS_EXCEPTION_IF_NULL(pad); MS_EXCEPTION_IF_NULL(pad);
pad->set_scope(slice_grad->scope()); pad->set_scope(slice_grad->scope());
@ -72,8 +78,16 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
// set attr paddings // set attr paddings
auto x_shape = GetInputXShape(slice_grad); auto x_shape = GetInputXShape(slice_grad);
auto begins = GetTupleValue(slice_grad->input(kIndex3)); std::vector<int64_t> begins;
auto sizes = GetTupleValue(slice_grad->input(kIndex4)); std::vector<int64_t> sizes;
if (input_num == kSliceGradInputTensorNum) {
begins = GetTupleValue(slice_grad->input(kIndex3));
sizes = GetTupleValue(slice_grad->input(kIndex4));
} else {
// if frontend is Cangjie and mode is pynative, input 2, 3 is already converted to attr
begins = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(slice_grad, kAttrBegin);
sizes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(slice_grad, kAttrSize);
}
if (x_shape.size() != begins.size() || begins.size() != sizes.size()) { if (x_shape.size() != begins.size() || begins.size() != sizes.size()) {
MS_LOG(EXCEPTION) << "For SliceGrad, x's shape dim number should be equal to len(begin) and len(size)."; MS_LOG(EXCEPTION) << "For SliceGrad, x's shape dim number should be equal to len(begin) and len(size).";
} }