forked from mindspore-Ecosystem/mindspore
Support dict values that contain nested sequences and None.
parent
ed363cde09
commit
3e920a73fc
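In short: with this change, a graph-compiled (@ms.jit) function can build and return a dict whose values contain None, including None nested inside tuples and lists. A minimal usage sketch, modeled on the tests added further down in this diff (the function name is illustrative; note that a list value comes back as a tuple, as the new test asserts):

    import mindspore as ms

    @ms.jit
    def make_dict():
        x = {'a': 'a', 'b': 'b'}
        y = x.get('a')
        # None may now appear both directly and nested inside sequences.
        return dict(y=y, u=9, v=False, w=(None, None), q=[1, (2, None), None])

    out = make_dict()
    assert out == {'y': 'a', 'u': 9, 'v': False, 'w': (None, None), 'q': (1, (2, None), None)}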
@@ -16,7 +16,7 @@
* limitations under the License.
*/

#include "frontend/optimizer/clean.h"
#include "frontend/optimizer/fallback_rewriter.h"
#include <iterator>
#include <string>
#include <vector>
@@ -98,6 +98,8 @@ class BaseRewriter : protected SimpleRewriter {

bool need_renormalized() const { return need_renormalized_; }

void set_need_renormalized(bool need_renormalized) { need_renormalized_ = need_renormalized; }

virtual bool Execute() {
bool changed = Run();
if (changed) {
@@ -941,26 +943,18 @@ class CleanAfterOptARewriter : public BaseRewriter {
{prim::kPrimMakeDict, &ThisClass::ConvertMakeDict},
};

// Convert ValueNode<None> to PyExecute("None", ("None"), ("None")).
AnfNodePtr NoneConvertPyExecute(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto str_value = std::make_shared<StringImm>("None");
auto script_node = NewValueNode(str_value);
script_node->set_abstract(str_value->ToAbstract());

auto empty_tuple = std::vector<ValuePtr>();
auto empty_tuple_value = std::make_shared<ValueTuple>(empty_tuple);

auto local_key_node = NewValueNode(empty_tuple_value);
local_key_node->set_abstract(empty_tuple_value->ToAbstract());

auto local_value_node = NewValueNode(empty_tuple_value);
local_value_node->set_abstract(empty_tuple_value->ToAbstract());
std::vector<ValuePtr> none_value{str_value};
const auto none_tuple = std::make_shared<ValueTuple>(none_value);
auto none_tuple_node = NewValueNode(none_tuple);

AnfNodePtr none_execute_node =
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimPyExecute), script_node, local_key_node, local_value_node});
ShapeVector shp{abstract::Shape::kShapeRankAny};
auto abs = std::make_shared<abstract::AbstractTensor>(kFloat64, std::make_shared<abstract::Shape>(shp));
none_execute_node->set_abstract(abs);
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimPyExecute), script_node, none_tuple_node, none_tuple_node});
MS_LOG(DEBUG) << "none_execute_node:" << none_execute_node->DebugString();
return none_execute_node;
}
@@ -981,13 +975,9 @@ class CleanAfterOptARewriter : public BaseRewriter {
if (!IsValueNode<None>(input)) {
continue;
}
// Convert ValueNode<None> to PyExecute("None", (), ()).
auto none_py_execute = NoneConvertPyExecute(cur_func);
manager_->Replace(input, none_py_execute);
}
// If the cnode is depend node, need renew the abstract of the cnode.
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && IsPrimitiveCNode(cnode->input(1), prim::kPrimPyExecute)) {
cnode->set_abstract(cnode->input(1)->abstract());
set_need_renormalized(true);
}
}
@@ -1029,26 +1019,53 @@ class CleanAfterOptARewriter : public BaseRewriter {
return value;
}

AnfNodePtr ProcessValueSequence(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueSequence>()) {
auto value_seq = value->cast<ValueSequencePtr>();
MS_EXCEPTION_IF_NULL(value_seq);
auto values = value_seq->value();
std::vector<AnfNodePtr> value_seq_inputs{NewValueNode(prim::kPrimMakeTuple)};
for (auto inner_value : values) {
auto inner_value_seq = ProcessValueSequence(inner_value);
(void)value_seq_inputs.emplace_back(inner_value_seq);
}
auto iter_value = root_graph_->NewCNode(value_seq_inputs);
return iter_value;
}
if (value->isa<None>()) {
return NoneConvertPyExecute(root_graph_);
}
return NewValueNode(value);
}

AnfNodePtr PackDictValue(const ValueDictionaryPtr &dict) {
const auto &keys_values = dict->value();
std::vector<AnfNodePtr> value_list{NewValueNode(prim::kPrimMakeTuple)};
for (const auto &key_value : keys_values) {
auto iter_value = ProcessValueSequence(key_value.second);
(void)value_list.emplace_back(iter_value);
}
auto value_tuple_node = root_graph_->NewCNode(value_list);
return value_tuple_node;
}

// dict(k0:v0, k1:v1, ...) --> PyExecute('dict(zip(keys, values))', ...)
AnfNodePtr RebuildValueDict(const ValueNodePtr &value_node, const ValueDictionaryPtr &dict) {
const auto &keys_values = dict->value();
std::vector<ValuePtr> key_list;
key_list.reserve(keys_values.size());
std::vector<ValuePtr> value_list;
value_list.reserve(keys_values.size());
for (const auto &key_value : keys_values) {
(void)key_list.emplace_back(key_value.first);
(void)value_list.emplace_back(key_value.second);
}

// Local parameters values.
// Pack the key tuple.
std::vector<ValuePtr> key_list;
key_list.reserve(keys_values.size());
for (const auto &key_value : keys_values) {
(void)key_list.emplace_back(key_value.first);
}
const auto key_tuple = std::make_shared<ValueTuple>(key_list);
auto key_tuple_node = NewValueNode(key_tuple);

// Pack the value
const auto value_tuple = std::make_shared<ValueTuple>(value_list);
auto value_tuple_node = NewValueNode(value_tuple);
// Pack the value tuple.
auto value_tuple_node = PackDictValue(dict);

// Generate Make Dict PyExecute Node value
auto make_key_tuple_node = ConstructInternalTupleKeysNode(root_graph_, key_tuple_node);
@@ -1151,7 +1168,7 @@ bool CleanAfterOptA(const FuncGraphPtr &root, const pipeline::ResourcePtr &resou
CleanAfterOptARewriter rewriter(root, manager);
bool change = rewriter.Execute();
// Renormalize for new PyExecute node.
if (change && rewriter.need_renormalized()) {
if (rewriter.need_renormalized()) {
abstract::AbstractBasePtrList new_args_spec;
std::transform(root->parameters().begin(), root->parameters().end(), std::back_inserter(new_args_spec),
[](const AnfNodePtr &param) -> AbstractBasePtr { return param->abstract(); });
@@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,8 +16,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CLEAN_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CLEAN_H_
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_FALLBACK_REWRITRER_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_FALLBACK_REWRITRER_H_

#include "ir/anf.h"
#include "frontend/operator/ops.h"
@@ -34,4 +34,4 @@ bool CleanAfterOptA(const FuncGraphPtr &root, const pipeline::ResourcePtr &resou
} // namespace opt
} // namespace mindspore

#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CLEAN_H_
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_FALLBACK_REWRITRER_H_
@@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@
#include "frontend/optimizer/irpass/branch_culling.h"
#include "frontend/optimizer/irpass/cast_eliminate.h"
#include "frontend/optimizer/irpass/get_grad_eliminate.h"
#include "frontend/optimizer/irpass/convert.h"
#include "frontend/optimizer/irpass/print_converter.h"
#include "frontend/optimizer/irpass/environ_eliminate.h"
#include "frontend/optimizer/irpass/meta_fg_var_prepare.h"
#include "frontend/optimizer/irpass/inline.h"
@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -1553,11 +1553,6 @@ void SetDynamicInputSizeAttr(const CNodePtr &cnode) {
bool is_dyn_input = false;
for (size_t i = 0; i < input_num; ++i) {
auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
// remove monad input
auto abs = input_node->abstract();
if (abs->isa<abstract::AbstractMonad>()) {
continue;
}
if (i < input_obj_types.size() && input_obj_types[i] == kernel::KernelObjectType::TUPLE_UNFOLD) {
auto [input_structural, input_size, dyn_input] = CalOutputTupleSize(input_node);
is_dyn_input |= dyn_input;
@@ -30,7 +30,7 @@
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/cse_pass.h"
#include "frontend/optimizer/clean.h"
#include "frontend/optimizer/fallback_rewriter.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/graph_transform.h"
#include "frontend/optimizer/auto_monad_eliminate.h"
@@ -22,6 +22,7 @@
#include <vector>
#include <string>
#include <memory>
#include <utility>

#include "pybind11/pybind11.h"
#include "pybind_api/pybind_patch.h"
@@ -36,9 +37,34 @@
#include "mindspore/ccsrc/pybind_api/ir/tensor_py.h"
#include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
#include "mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
#include "mindspore/ccsrc/pipeline/jit/parse/resolve.h"

namespace py = pybind11;
namespace mindspore {
namespace abstract {
using PyObjectWrapperPtr = std::shared_ptr<parse::PyObjectWrapper>;
namespace pyexecute_user_data_catcher {
std::pair<bool, ValuePtr> PyExecuteUserDataCatcher(const AbstractBasePtr &element_abs) {
MS_EXCEPTION_IF_NULL(element_abs);
if (element_abs->has_user_data<kernel::PyExecuteOutputUserData>()) {
const auto &data = element_abs->user_data<kernel::PyExecuteOutputUserData>();
MS_EXCEPTION_IF_NULL(data);
auto python_obj = std::make_shared<parse::PyObjectWrapper>(data->obj, "graph python obj");
return {true, python_obj};
}
return {false, nullptr};
}

struct PyExecuteUserDataCatcherRegister {
PyExecuteUserDataCatcherRegister() noexcept {
abstract::AbstractBase::set_pyexecute_user_data_catcher(
[](const AbstractBasePtr &element_abs) { return PyExecuteUserDataCatcher(element_abs); });
}
~PyExecuteUserDataCatcherRegister() {}
} pyexecute_user_data_catcher_register;
} // namespace pyexecute_user_data_catcher
} // namespace abstract

static py::object CallPythonGetGlobalParams() {
constexpr auto python_mod_parse = "mindspore._extends.parse"; // The same as PYTHON_MOD_PARSE_MODULE[]
py::module mod = python_adapter::GetPyModule(python_mod_parse);
@@ -64,15 +90,6 @@ class PyExecuteInitializer {
const auto &keys_tuple_abs = input_args[1];
const auto &keys_tuple = keys_tuple_abs->BuildValue();
const auto &keys = dyn_cast<ValueSequence>(keys_tuple);
// Process PyExecute("None", (), (), io)
// Since the backend converts the empty tuple into an empty tensor(not keep ValueSequence),
// so special handling of None is required.
if (script->ToString() == "None") {
const auto &output = py::none();
PushPyExecuteOutput(script, output);
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
return abstract::MakeAbstract(infer_shape, kFloat64);
}
if (keys == nullptr) {
MS_LOG(DEBUG) << "The keys is not tuple value, but got " << keys_tuple->ToString();
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
@@ -90,19 +107,24 @@ class PyExecuteInitializer {
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
return abstract::MakeAbstract(infer_shape, kFloat64);
}
MS_LOG(DEBUG) << "script: " << script->ToString() << ", keys_tuple: " << keys_tuple->ToString()
MS_LOG(DEBUG) << "The script is: " << script->ToString() << ", keys_tuple: " << keys_tuple->ToString()
<< ", values_tuple: " << values_tuple->ToString();

if (keys->size() != values->size()) {
MS_LOG(EXCEPTION) << "The length of keys(" << keys->size() << ") is not equal of the length of values("
<< values->size() << ").";
}
py::gil_scoped_acquire gil_acquire;
py::dict local_dict;

for (size_t i = 0; i < keys->size(); ++i) {
const auto &key = (*keys)[i];
const auto &key_str = dyn_cast<StringImm>(key);
MS_EXCEPTION_IF_NULL(key_str);
const auto &value = (*values)[i];
MS_LOG(DEBUG) << "input[" << i << "], value : " << value;
MS_LOG(DEBUG) << "input[" << i << "], value : " << value->ToString();
const auto &tuple_abs = values_tuple_abs->cast<abstract::AbstractSequencePtr>();
const auto &value_abs = (*tuple_abs)[i];

if (value_abs->has_user_data<kernel::PyExecuteOutputUserData>()) {
const auto &output_data = value_abs->user_data<kernel::PyExecuteOutputUserData>();
auto obj = output_data->obj;
@@ -337,6 +337,12 @@ static ValueNameToConverterVector value_name_to_converter = {
auto interpreted_object = value->cast<parse::InterpretedObjectPtr>();
return interpreted_object->obj();
}},
// parse::PyObjectWrapper
{parse::PyObjectWrapper::kTypeId,
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
auto py_object = value->cast<parse::PyObjectWrapperPtr>();
return py_object->obj();
}},
// None
{None::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
// AnyValue
@@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -751,12 +751,25 @@ AbstractBasePtrList AbstractSequence::ElementsPartialBroaden() const {
return element_list;
}

std::pair<bool, ValuePtr> GetValueFromUserData(const AbstractBasePtr &element_abs) {
MS_EXCEPTION_IF_NULL(element_abs);
if (abstract::AbstractBase::pyexecute_user_data_catcher()) {
return abstract::AbstractBase::pyexecute_user_data_catcher()(element_abs);
}
return {false, nullptr};
}

template <typename T>
ValuePtr AbstractSequence::ElementsBuildValue() const {
std::vector<ValuePtr> element_value_list;
for (const auto &element : elements_) {
MS_EXCEPTION_IF_NULL(element);
ValuePtr element_value = element->BuildValue();
auto [has_user_data, element_value] = GetValueFromUserData(element);
if (has_user_data && element_value != nullptr) {
element_value_list.push_back(element_value);
continue;
}
element_value = element->BuildValue();
MS_EXCEPTION_IF_NULL(element_value);
if (element_value->isa<AnyValue>()) {
return kAnyValue;
@@ -259,6 +259,14 @@ class MS_CORE_API AbstractBase : public Base {
static void set_interpret_bool_checker(InterpretBoolChecker checker) { interpret_bool_checker_ = checker; }
static inline InterpretBoolChecker interpret_bool_checker() { return interpret_bool_checker_; }

/// \brief Process the user date of abstract with PyExecute node.
using PyExecuteUserDataCatcher = std::pair<bool, ValuePtr> (*)(const AbstractBasePtr &element_abs);
static inline PyExecuteUserDataCatcher pyexecute_user_data_catcher_ = nullptr;
static void set_pyexecute_user_data_catcher(PyExecuteUserDataCatcher catcher) {
pyexecute_user_data_catcher_ = catcher;
}
static inline PyExecuteUserDataCatcher pyexecute_user_data_catcher() { return pyexecute_user_data_catcher_; }

std::string name() const { return name_; }

void set_name(const std::string &name) { name_ = name; }
@@ -1,4 +1,4 @@
# Copyright 2021-2022 Huawei Technologies Co., Ltd
# Copyright 2021-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
import pytest
from tqdm import tqdm
import numpy as np
import mindspore as ms
import mindspore.nn as nn
@@ -80,7 +79,7 @@ def test_auto_monad_layer():
train_net.set_train()
gen_samples = dict()
num_epoch = 21
for epoch in tqdm(range(num_epoch)):
for epoch in range(num_epoch):
loss = []
for _, (batch,) in enumerate(dataloader):
batch = Tensor(batch, dtype=ms.float32)
@@ -356,7 +356,6 @@ def test_none_is_output_of_function_with_side_effect():
assert res == 4


@pytest.mark.skip(reason="No support None in dict return.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@@ -373,11 +372,34 @@ def test_none_is_input_of_dict_return():
def foo():
x = {'a': 'a', 'b': 'b'}
y = x.get('a')
z = dict(y=y, v=False, w=None)
z = dict(y=y, u=9, v=False, w=None)
return z

out = foo()
assert out == {'y': 'a', 'v': False, 'w': None}
assert out == {'y': 'a', 'u': 9, 'v': False, 'w': None}


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_none_nested_input_of_dict_return():
"""
Feature: Support None.
Description: Support None is input of dict, and the dict is return.
Expectation: No exception.
"""

@jit
def foo():
x = {'a': 'a', 'b': 'b'}
y = x.get('a')
z = dict(y=y, u=9, v=False, w=(None, None), q=[1, (2, None), None])
return z

out = foo()
assert out == {'y': 'a', 'u': 9, 'v': False, 'w': (None, None), 'q': (1, (2, None), None)}


@pytest.mark.level0
@@ -194,29 +194,6 @@ def test_dict_return_2():
assert out == {'a': ms.Tensor(np.array(1), ms.int64)}


@pytest.mark.skip(reason="No support None and Scalar in dict.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dict_return_3():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
@ms.jit
def dict_net_3():
x = {'a': 'a', 'b': 'b'}
y = x.get('a')
z = dict(y=y, u=9, v=False, w=None)
return z

out = dict_net_3()
assert out == {'y': 'a', 'u': 9, 'v': False, 'w': None}


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training