forked from mindspore-Ecosystem/mindspore
!24 Change strategy for structure output
Merge pull request !24 from 步学/structure-output
This commit is contained in:
commit
c1c8fef9ca
|
@ -725,23 +725,15 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) {
|
|||
return BaseRefToPyData(value);
|
||||
}
|
||||
|
||||
py::object StructureOutput(const AbstractBasePtr& output, const py::tuple& data, size_t* count) {
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
py::object ExtractGeneralCnodeRet(const AbstractBasePtr& cnode_data, const py::tuple& data, size_t* count) {
|
||||
MS_EXCEPTION_IF_NULL(cnode_data);
|
||||
if (*count >= data.size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
|
||||
<< " less than the number of elements required. ";
|
||||
}
|
||||
|
||||
if (!output->isa<AbstractTuple>()) {
|
||||
ValuePtr value = output->BuildValue();
|
||||
if (value != kAnyValue) {
|
||||
return ValuePtrToPyData(value);
|
||||
}
|
||||
if (!output->isa<AbstractTensor>()) {
|
||||
MS_LOG(EXCEPTION) << "Output can only be tensor except for constants, but got "
|
||||
<< output->BuildValue()->ToString() << ".";
|
||||
}
|
||||
if (*count >= data.size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
|
||||
<< " less than the number of elements required. ";
|
||||
}
|
||||
auto shape = output->BuildShape();
|
||||
if (cnode_data->isa<AbstractTensor>()) {
|
||||
BaseShapePtr shape = cnode_data->BuildShape();
|
||||
auto shape_act = shape->cast<abstract::ShapePtr>()->shape();
|
||||
Tensor tensor_exp = py::cast<Tensor>(data[*count]);
|
||||
if (shape_act != tensor_exp.shape()) {
|
||||
|
@ -751,16 +743,58 @@ py::object StructureOutput(const AbstractBasePtr& output, const py::tuple& data,
|
|||
return data[(*count)++];
|
||||
}
|
||||
|
||||
auto tuple_output = output->cast<AbstractTuplePtr>();
|
||||
AbstractBasePtrList elements = tuple_output->elements();
|
||||
size_t size = elements.size();
|
||||
if (!cnode_data->isa<AbstractTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "The output of operator in the final anf graph could "
|
||||
<< "only be a tensor or a tuple of tensor, but got " << cnode_data->BuildValue()->ToString()
|
||||
<< ".";
|
||||
}
|
||||
auto data_tp = cnode_data->cast<AbstractTuplePtr>();
|
||||
auto elements = data_tp->elements();
|
||||
size_t size = data_tp->size();
|
||||
py::tuple tp = py::tuple(size);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
tp[i] = StructureOutput(elements[i], data, count);
|
||||
tp[i] = ExtractGeneralCnodeRet(elements[i], data, count);
|
||||
}
|
||||
return std::move(tp);
|
||||
}
|
||||
|
||||
py::object StructureOutput(const AnfNodePtr& output_node, const py::tuple& data, size_t* count) {
|
||||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
|
||||
if (output_node->isa<ValueNode>()) {
|
||||
return ValuePtrToPyData(GetValueNode(output_node));
|
||||
}
|
||||
|
||||
if (*count >= data.size()) {
|
||||
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
|
||||
<< " less than the number of elements required. ";
|
||||
}
|
||||
if (output_node->isa<Parameter>()) {
|
||||
return data[(*count)++];
|
||||
}
|
||||
|
||||
auto output_c = output_node->cast<CNodePtr>();
|
||||
if (output_c == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The final anf graph could only have constant, parameter, and operator, but got "
|
||||
<< output_node->ToString();
|
||||
}
|
||||
|
||||
if (output_c->IsApply(prim::kPrimMakeTuple)) {
|
||||
auto input_list = output_c->inputs();
|
||||
size_t size = input_list.size();
|
||||
py::tuple tp = py::tuple(size - 1);
|
||||
for (size_t i = 1; i < size; i++) {
|
||||
tp[i - 1] = StructureOutput(input_list[i], data, count);
|
||||
}
|
||||
return std::move(tp);
|
||||
}
|
||||
if (output_c->IsApply(prim::kPrimDepend)) {
|
||||
return StructureOutput(output_c->input(1), data, count);
|
||||
}
|
||||
|
||||
return ExtractGeneralCnodeRet(output_c->abstract(), data, count);
|
||||
}
|
||||
|
||||
std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::vector<MeTensorPtr>& inputs,
|
||||
const std::string& phase) {
|
||||
std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW);
|
||||
|
@ -806,11 +840,10 @@ std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::ve
|
|||
std::shared_ptr<py::object> ret = nullptr;
|
||||
|
||||
#ifdef ENABLE_GE
|
||||
AnfNodePtr root = graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
AbstractBasePtr output = root->abstract();
|
||||
AnfNodePtr output_node = graph->get_return()->input(1);
|
||||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
size_t count = 0;
|
||||
py::object oj = StructureOutput(output, outputs, &count);
|
||||
py::object oj = StructureOutput(output_node, outputs, &count);
|
||||
ret = std::make_shared<py::object>(oj);
|
||||
#else
|
||||
if (outputs.size() == 1) {
|
||||
|
|
|
@ -236,7 +236,7 @@ def test_soft():
|
|||
def __init__(self):
|
||||
super(SoftmaxCrossEntropyWithLogitsNet, self).__init__()
|
||||
self.soft = P.SoftmaxCrossEntropyWithLogits()
|
||||
self.value = (Tensor(np.zeros((2,)).astype(np.float32)), Tensor(np.ones((2,)).astype(np.float32)))
|
||||
self.value = (Tensor(np.zeros((2, 2)).astype(np.float32)), Tensor(np.ones((2, 2)).astype(np.float32)))
|
||||
|
||||
def construct(self, x, y, z):
|
||||
xx = x + y
|
||||
|
@ -246,8 +246,30 @@ def test_soft():
|
|||
ret = (ret, self.value)
|
||||
return ret
|
||||
|
||||
input1 = Tensor(np.zeros((2,)).astype(np.float32))
|
||||
input2 = Tensor(np.ones((2,)).astype(np.float32))
|
||||
input3 = Tensor((np.ones((2,)) + np.ones((2,))).astype(np.float32))
|
||||
input1 = Tensor(np.zeros((2, 2)).astype(np.float32))
|
||||
input2 = Tensor(np.ones((2, 2)).astype(np.float32))
|
||||
input3 = Tensor((np.ones((2, 2)) + np.ones((2, 2))).astype(np.float32))
|
||||
net = SoftmaxCrossEntropyWithLogitsNet()
|
||||
print(net(input1, input2, input3))
|
||||
net(input1, input2, input3)
|
||||
|
||||
|
||||
def test_const_depend():
|
||||
class ConstDepend(Cell):
|
||||
def __init__(self):
|
||||
super(ConstDepend, self).__init__()
|
||||
self.value = (Tensor(np.zeros((2, 3)).astype(np.float32)), Tensor(np.ones((2, 3)).astype(np.float32)))
|
||||
self.soft = P.SoftmaxCrossEntropyWithLogits()
|
||||
self.depend = depend
|
||||
|
||||
def construct(self, x, y, z):
|
||||
ret = x + y
|
||||
ret = ret * z
|
||||
ret = self.depend(self.value, ret)
|
||||
ret = (ret, self.soft(x, y))
|
||||
return ret
|
||||
|
||||
input1 = Tensor(np.zeros((2, 2)).astype(np.float32))
|
||||
input2 = Tensor(np.ones((2, 2)).astype(np.float32))
|
||||
input3 = Tensor((np.ones((2, 2)) + np.ones((2, 2))).astype(np.float32))
|
||||
net = ConstDepend()
|
||||
net(input1, input2, input3)
|
||||
|
|
Loading…
Reference in New Issue