!24 Change strategy for structure output
Merge pull request !24 from 步学/structure-output
This commit is contained in:
commit
c1c8fef9ca
|
@ -725,23 +725,15 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) {
|
||||||
return BaseRefToPyData(value);
|
return BaseRefToPyData(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
py::object StructureOutput(const AbstractBasePtr& output, const py::tuple& data, size_t* count) {
|
py::object ExtractGeneralCnodeRet(const AbstractBasePtr& cnode_data, const py::tuple& data, size_t* count) {
|
||||||
MS_EXCEPTION_IF_NULL(output);
|
MS_EXCEPTION_IF_NULL(cnode_data);
|
||||||
|
if (*count >= data.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
|
||||||
|
<< " less than the number of elements required. ";
|
||||||
|
}
|
||||||
|
|
||||||
if (!output->isa<AbstractTuple>()) {
|
if (cnode_data->isa<AbstractTensor>()) {
|
||||||
ValuePtr value = output->BuildValue();
|
BaseShapePtr shape = cnode_data->BuildShape();
|
||||||
if (value != kAnyValue) {
|
|
||||||
return ValuePtrToPyData(value);
|
|
||||||
}
|
|
||||||
if (!output->isa<AbstractTensor>()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Output can only be tensor except for constants, but got "
|
|
||||||
<< output->BuildValue()->ToString() << ".";
|
|
||||||
}
|
|
||||||
if (*count >= data.size()) {
|
|
||||||
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
|
|
||||||
<< " less than the number of elements required. ";
|
|
||||||
}
|
|
||||||
auto shape = output->BuildShape();
|
|
||||||
auto shape_act = shape->cast<abstract::ShapePtr>()->shape();
|
auto shape_act = shape->cast<abstract::ShapePtr>()->shape();
|
||||||
Tensor tensor_exp = py::cast<Tensor>(data[*count]);
|
Tensor tensor_exp = py::cast<Tensor>(data[*count]);
|
||||||
if (shape_act != tensor_exp.shape()) {
|
if (shape_act != tensor_exp.shape()) {
|
||||||
|
@ -751,16 +743,58 @@ py::object StructureOutput(const AbstractBasePtr& output, const py::tuple& data,
|
||||||
return data[(*count)++];
|
return data[(*count)++];
|
||||||
}
|
}
|
||||||
|
|
||||||
auto tuple_output = output->cast<AbstractTuplePtr>();
|
if (!cnode_data->isa<AbstractTuple>()) {
|
||||||
AbstractBasePtrList elements = tuple_output->elements();
|
MS_LOG(EXCEPTION) << "The output of operator in the final anf graph could "
|
||||||
size_t size = elements.size();
|
<< "only be a tensor or a tuple of tensor, but got " << cnode_data->BuildValue()->ToString()
|
||||||
|
<< ".";
|
||||||
|
}
|
||||||
|
auto data_tp = cnode_data->cast<AbstractTuplePtr>();
|
||||||
|
auto elements = data_tp->elements();
|
||||||
|
size_t size = data_tp->size();
|
||||||
py::tuple tp = py::tuple(size);
|
py::tuple tp = py::tuple(size);
|
||||||
for (size_t i = 0; i < size; i++) {
|
for (size_t i = 0; i < size; i++) {
|
||||||
tp[i] = StructureOutput(elements[i], data, count);
|
tp[i] = ExtractGeneralCnodeRet(elements[i], data, count);
|
||||||
}
|
}
|
||||||
return std::move(tp);
|
return std::move(tp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
py::object StructureOutput(const AnfNodePtr& output_node, const py::tuple& data, size_t* count) {
|
||||||
|
MS_EXCEPTION_IF_NULL(output_node);
|
||||||
|
|
||||||
|
if (output_node->isa<ValueNode>()) {
|
||||||
|
return ValuePtrToPyData(GetValueNode(output_node));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (*count >= data.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
|
||||||
|
<< " less than the number of elements required. ";
|
||||||
|
}
|
||||||
|
if (output_node->isa<Parameter>()) {
|
||||||
|
return data[(*count)++];
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output_c = output_node->cast<CNodePtr>();
|
||||||
|
if (output_c == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "The final anf graph could only have constant, parameter, and operator, but got "
|
||||||
|
<< output_node->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (output_c->IsApply(prim::kPrimMakeTuple)) {
|
||||||
|
auto input_list = output_c->inputs();
|
||||||
|
size_t size = input_list.size();
|
||||||
|
py::tuple tp = py::tuple(size - 1);
|
||||||
|
for (size_t i = 1; i < size; i++) {
|
||||||
|
tp[i - 1] = StructureOutput(input_list[i], data, count);
|
||||||
|
}
|
||||||
|
return std::move(tp);
|
||||||
|
}
|
||||||
|
if (output_c->IsApply(prim::kPrimDepend)) {
|
||||||
|
return StructureOutput(output_c->input(1), data, count);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ExtractGeneralCnodeRet(output_c->abstract(), data, count);
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::vector<MeTensorPtr>& inputs,
|
std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::vector<MeTensorPtr>& inputs,
|
||||||
const std::string& phase) {
|
const std::string& phase) {
|
||||||
std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW);
|
std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW);
|
||||||
|
@ -806,11 +840,10 @@ std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::ve
|
||||||
std::shared_ptr<py::object> ret = nullptr;
|
std::shared_ptr<py::object> ret = nullptr;
|
||||||
|
|
||||||
#ifdef ENABLE_GE
|
#ifdef ENABLE_GE
|
||||||
AnfNodePtr root = graph->get_return();
|
AnfNodePtr output_node = graph->get_return()->input(1);
|
||||||
MS_EXCEPTION_IF_NULL(root);
|
MS_EXCEPTION_IF_NULL(output_node);
|
||||||
AbstractBasePtr output = root->abstract();
|
|
||||||
size_t count = 0;
|
size_t count = 0;
|
||||||
py::object oj = StructureOutput(output, outputs, &count);
|
py::object oj = StructureOutput(output_node, outputs, &count);
|
||||||
ret = std::make_shared<py::object>(oj);
|
ret = std::make_shared<py::object>(oj);
|
||||||
#else
|
#else
|
||||||
if (outputs.size() == 1) {
|
if (outputs.size() == 1) {
|
||||||
|
|
|
@ -236,7 +236,7 @@ def test_soft():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(SoftmaxCrossEntropyWithLogitsNet, self).__init__()
|
super(SoftmaxCrossEntropyWithLogitsNet, self).__init__()
|
||||||
self.soft = P.SoftmaxCrossEntropyWithLogits()
|
self.soft = P.SoftmaxCrossEntropyWithLogits()
|
||||||
self.value = (Tensor(np.zeros((2,)).astype(np.float32)), Tensor(np.ones((2,)).astype(np.float32)))
|
self.value = (Tensor(np.zeros((2, 2)).astype(np.float32)), Tensor(np.ones((2, 2)).astype(np.float32)))
|
||||||
|
|
||||||
def construct(self, x, y, z):
|
def construct(self, x, y, z):
|
||||||
xx = x + y
|
xx = x + y
|
||||||
|
@ -246,8 +246,30 @@ def test_soft():
|
||||||
ret = (ret, self.value)
|
ret = (ret, self.value)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
input1 = Tensor(np.zeros((2,)).astype(np.float32))
|
input1 = Tensor(np.zeros((2, 2)).astype(np.float32))
|
||||||
input2 = Tensor(np.ones((2,)).astype(np.float32))
|
input2 = Tensor(np.ones((2, 2)).astype(np.float32))
|
||||||
input3 = Tensor((np.ones((2,)) + np.ones((2,))).astype(np.float32))
|
input3 = Tensor((np.ones((2, 2)) + np.ones((2, 2))).astype(np.float32))
|
||||||
net = SoftmaxCrossEntropyWithLogitsNet()
|
net = SoftmaxCrossEntropyWithLogitsNet()
|
||||||
print(net(input1, input2, input3))
|
net(input1, input2, input3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_const_depend():
|
||||||
|
class ConstDepend(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(ConstDepend, self).__init__()
|
||||||
|
self.value = (Tensor(np.zeros((2, 3)).astype(np.float32)), Tensor(np.ones((2, 3)).astype(np.float32)))
|
||||||
|
self.soft = P.SoftmaxCrossEntropyWithLogits()
|
||||||
|
self.depend = depend
|
||||||
|
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
ret = x + y
|
||||||
|
ret = ret * z
|
||||||
|
ret = self.depend(self.value, ret)
|
||||||
|
ret = (ret, self.soft(x, y))
|
||||||
|
return ret
|
||||||
|
|
||||||
|
input1 = Tensor(np.zeros((2, 2)).astype(np.float32))
|
||||||
|
input2 = Tensor(np.ones((2, 2)).astype(np.float32))
|
||||||
|
input3 = Tensor((np.ones((2, 2)) + np.ones((2, 2))).astype(np.float32))
|
||||||
|
net = ConstDepend()
|
||||||
|
net(input1, input2, input3)
|
||||||
|
|
Loading…
Reference in New Issue