forked from mindspore-Ecosystem/mindspore
!40401 fix TypeError in GetDynShape
Merge pull request !40401 from 王禹程/fix_shape
This commit is contained in:
commit
8558573afa
|
@ -464,6 +464,12 @@ void DynamicShape::CheckPreviousTopCellCanBeDynamicShape(const py::object &cell,
|
|||
|
||||
py::object DynamicShape::GetDynShape(const py::args &args) const {
|
||||
const auto &obj = args[0];
|
||||
// infer type
|
||||
const auto &v = PyNativeAlgo::DataConvert::PyObjToValue(obj);
|
||||
auto abs = v->ToAbstract();
|
||||
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||
(void)CheckAndConvertUtils::CheckSubClass("shape type", abs->BuildType(), valid_params_types, "Shape");
|
||||
// infer shape
|
||||
const auto &base_shape_ptr = obj.cast<tensor::TensorPtr>()->base_shape_ptr();
|
||||
if (base_shape_ptr != nullptr) {
|
||||
auto value = MakeValue(base_shape_ptr->cast<abstract::ShapePtr>()->shape());
|
||||
|
|
|
@ -20,8 +20,10 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include "pipeline/pynative/pynative_utils.h"
|
||||
#include "pipeline/pynative/grad/top_cell.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace pynative {
|
||||
|
|
|
@ -61,3 +61,22 @@ def test_tile_eliminate():
|
|||
assert out.shape == (1, 448, 448)
|
||||
out = expand_tensor(tensor_, (1, 1, 1, 1))
|
||||
assert out.shape == (1, 1, 448, 448)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_shape_raise():
|
||||
"""
|
||||
Feature: shape raise.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
tensor0 = Tensor(np.ndarray([1, 448, 448]), dtype=dtype.float32)
|
||||
tensor1 = Tensor(np.ndarray([1, 448, 448]), dtype=dtype.float32)
|
||||
with pytest.raises(TypeError):
|
||||
ops.shape([tensor0, tensor1])
|
||||
|
|
Loading…
Reference in New Issue