SliceGrad unify mindir adpation for Cangjie frontend
This commit is contained in:
parent
8195488630
commit
76075c0b36
|
@ -31,6 +31,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kSliceGradInputTensorNum = 4;
|
||||
constexpr size_t kSliceGradCangjieInputTensorNum = 2;
|
||||
|
||||
std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -50,11 +51,8 @@ std::vector<int64_t> GetTupleValue(const AnfNodePtr &node) {
|
|||
} // namespace
|
||||
|
||||
const BaseRef SliceGradUnifyMindIR::DefinePattern() const {
|
||||
VarPtr X1 = std::make_shared<Var>();
|
||||
VarPtr X2 = std::make_shared<Var>();
|
||||
VarPtr X3 = std::make_shared<Var>();
|
||||
VarPtr X4 = std::make_shared<Var>();
|
||||
VectorRef slice_grad({std::make_shared<Primitive>("SliceGrad"), X1, X2, X3, X4});
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
VectorRef slice_grad({std::make_shared<Primitive>("SliceGrad"), Xs});
|
||||
return slice_grad;
|
||||
}
|
||||
|
||||
|
@ -63,8 +61,16 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
auto slice_grad = CheckAnfNodeIfCNodeAndInputSize(node, kSliceGradInputTensorNum);
|
||||
std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)), slice_grad->input(1)};
|
||||
auto slice_grad = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(slice_grad);
|
||||
auto input_num = AnfAlgo::GetInputTensorNum(slice_grad);
|
||||
if (input_num != kSliceGradInputTensorNum && input_num != kSliceGradCangjieInputTensorNum) {
|
||||
MS_LOG(EXCEPTION) << "The input tensor size[" << input_num
|
||||
<< "] of node " + slice_grad->DebugString() + " is not equal to " << kSliceGradInputTensorNum
|
||||
<< " or " << kSliceGradCangjieInputTensorNum;
|
||||
}
|
||||
std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)),
|
||||
slice_grad->input(kIndex1)};
|
||||
auto pad = graph->NewCNode(pad_inputs);
|
||||
MS_EXCEPTION_IF_NULL(pad);
|
||||
pad->set_scope(slice_grad->scope());
|
||||
|
@ -72,8 +78,16 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
|
|||
|
||||
// set attr paddings
|
||||
auto x_shape = GetInputXShape(slice_grad);
|
||||
auto begins = GetTupleValue(slice_grad->input(kIndex3));
|
||||
auto sizes = GetTupleValue(slice_grad->input(kIndex4));
|
||||
std::vector<int64_t> begins;
|
||||
std::vector<int64_t> sizes;
|
||||
if (input_num == kSliceGradInputTensorNum) {
|
||||
begins = GetTupleValue(slice_grad->input(kIndex3));
|
||||
sizes = GetTupleValue(slice_grad->input(kIndex4));
|
||||
} else {
|
||||
// if frontend is Cangjie and mode is pynative, input 2, 3 is already converted to attr
|
||||
begins = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(slice_grad, kAttrBegin);
|
||||
sizes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(slice_grad, kAttrSize);
|
||||
}
|
||||
if (x_shape.size() != begins.size() || begins.size() != sizes.size()) {
|
||||
MS_LOG(EXCEPTION) << "For SliceGrad, x's shape dim number should be equal to len(begin) and len(size).";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue