!40401 fix TypeError in GetDynShape

Merge pull request !40401 from 王禹程/fix_shape
This commit is contained in:
i-robot 2022-08-16 06:25:30 +00:00 committed by Gitee
commit 8558573afa
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 27 additions and 0 deletions

View File

@ -464,6 +464,12 @@ void DynamicShape::CheckPreviousTopCellCanBeDynamicShape(const py::object &cell,
py::object DynamicShape::GetDynShape(const py::args &args) const {
const auto &obj = args[0];
// infer type
const auto &v = PyNativeAlgo::DataConvert::PyObjToValue(obj);
auto abs = v->ToAbstract();
std::set<TypePtr> valid_params_types = {kTensorType};
(void)CheckAndConvertUtils::CheckSubClass("shape type", abs->BuildType(), valid_params_types, "Shape");
// infer shape
const auto &base_shape_ptr = obj.cast<tensor::TensorPtr>()->base_shape_ptr();
if (base_shape_ptr != nullptr) {
auto value = MakeValue(base_shape_ptr->cast<abstract::ShapePtr>()->shape());

View File

@ -20,8 +20,10 @@
#include <memory>
#include <string>
#include <vector>
#include <set>
#include "pipeline/pynative/pynative_utils.h"
#include "pipeline/pynative/grad/top_cell.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace pynative {

View File

@ -61,3 +61,22 @@ def test_tile_eliminate():
assert out.shape == (1, 448, 448)
out = expand_tensor(tensor_, (1, 1, 1, 1))
assert out.shape == (1, 1, 448, 448)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_shape_raise():
"""
Feature: shape raise.
Description: Test raise.
Expectation: No exception.
"""
context.set_context(mode=context.PYNATIVE_MODE)
tensor0 = Tensor(np.ndarray([1, 448, 448]), dtype=dtype.float32)
tensor1 = Tensor(np.ndarray([1, 448, 448]), dtype=dtype.float32)
with pytest.raises(TypeError):
ops.shape([tensor0, tensor1])