SliceGrad unify mindir adpation for Cangjie frontend
This commit is contained in:
parent
8195488630
commit
76075c0b36
|
@ -31,6 +31,7 @@ namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr size_t kSliceGradInputTensorNum = 4;
|
constexpr size_t kSliceGradInputTensorNum = 4;
|
||||||
|
constexpr size_t kSliceGradCangjieInputTensorNum = 2;
|
||||||
|
|
||||||
std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
|
std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
@ -50,11 +51,8 @@ std::vector<int64_t> GetTupleValue(const AnfNodePtr &node) {
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
const BaseRef SliceGradUnifyMindIR::DefinePattern() const {
|
const BaseRef SliceGradUnifyMindIR::DefinePattern() const {
|
||||||
VarPtr X1 = std::make_shared<Var>();
|
VarPtr Xs = std::make_shared<SeqVar>();
|
||||||
VarPtr X2 = std::make_shared<Var>();
|
VectorRef slice_grad({std::make_shared<Primitive>("SliceGrad"), Xs});
|
||||||
VarPtr X3 = std::make_shared<Var>();
|
|
||||||
VarPtr X4 = std::make_shared<Var>();
|
|
||||||
VectorRef slice_grad({std::make_shared<Primitive>("SliceGrad"), X1, X2, X3, X4});
|
|
||||||
return slice_grad;
|
return slice_grad;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,8 +61,16 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
|
||||||
auto slice_grad = CheckAnfNodeIfCNodeAndInputSize(node, kSliceGradInputTensorNum);
|
auto slice_grad = node->cast<CNodePtr>();
|
||||||
std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)), slice_grad->input(1)};
|
MS_EXCEPTION_IF_NULL(slice_grad);
|
||||||
|
auto input_num = AnfAlgo::GetInputTensorNum(slice_grad);
|
||||||
|
if (input_num != kSliceGradInputTensorNum && input_num != kSliceGradCangjieInputTensorNum) {
|
||||||
|
MS_LOG(EXCEPTION) << "The input tensor size[" << input_num
|
||||||
|
<< "] of node " + slice_grad->DebugString() + " is not equal to " << kSliceGradInputTensorNum
|
||||||
|
<< " or " << kSliceGradCangjieInputTensorNum;
|
||||||
|
}
|
||||||
|
std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)),
|
||||||
|
slice_grad->input(kIndex1)};
|
||||||
auto pad = graph->NewCNode(pad_inputs);
|
auto pad = graph->NewCNode(pad_inputs);
|
||||||
MS_EXCEPTION_IF_NULL(pad);
|
MS_EXCEPTION_IF_NULL(pad);
|
||||||
pad->set_scope(slice_grad->scope());
|
pad->set_scope(slice_grad->scope());
|
||||||
|
@ -72,8 +78,16 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
|
||||||
|
|
||||||
// set attr paddings
|
// set attr paddings
|
||||||
auto x_shape = GetInputXShape(slice_grad);
|
auto x_shape = GetInputXShape(slice_grad);
|
||||||
auto begins = GetTupleValue(slice_grad->input(kIndex3));
|
std::vector<int64_t> begins;
|
||||||
auto sizes = GetTupleValue(slice_grad->input(kIndex4));
|
std::vector<int64_t> sizes;
|
||||||
|
if (input_num == kSliceGradInputTensorNum) {
|
||||||
|
begins = GetTupleValue(slice_grad->input(kIndex3));
|
||||||
|
sizes = GetTupleValue(slice_grad->input(kIndex4));
|
||||||
|
} else {
|
||||||
|
// if frontend is Cangjie and mode is pynative, input 2, 3 is already converted to attr
|
||||||
|
begins = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(slice_grad, kAttrBegin);
|
||||||
|
sizes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(slice_grad, kAttrSize);
|
||||||
|
}
|
||||||
if (x_shape.size() != begins.size() || begins.size() != sizes.size()) {
|
if (x_shape.size() != begins.size() || begins.size() != sizes.size()) {
|
||||||
MS_LOG(EXCEPTION) << "For SliceGrad, x's shape dim number should be equal to len(begin) and len(size).";
|
MS_LOG(EXCEPTION) << "For SliceGrad, x's shape dim number should be equal to len(begin) and len(size).";
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue