!9134 Equal op dynamic shape

From: @jonwe
Reviewed-by: @robingrosman,@tom__chen
Signed-off-by: @tom__chen
This commit is contained in:
mindspore-ci-bot 2020-12-01 00:55:27 +08:00 committed by Gitee
commit 5835095e07
2 changed files with 45 additions and 9 deletions

View File

@ -80,24 +80,31 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_x->shape());
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());
ShapeVector x_shape = x->shape()->shape();
ShapeVector x_shape_min = x->shape()->min_shape().empty() ? x_shape : x->shape()->min_shape();
ShapeVector x_shape_max = x->shape()->max_shape().empty() ? x_shape : x->shape()->max_shape();
auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(input_y);
MS_EXCEPTION_IF_NULL(input_y->shape());
auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(y);
MS_EXCEPTION_IF_NULL(y->shape());
ShapeVector y_shape = y->shape()->shape();
ShapeVector y_shape_min = y->shape()->min_shape().empty() ? y_shape : y->shape()->min_shape();
ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape();
auto x_shape = input_x->shape()->shape();
auto y_shape = input_y->shape()->shape();
auto out_shape = BroadcastShape(x_shape, y_shape);
if (out_shape.empty()) {
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
<< args_spec_list[1]->ToString();
}
auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min);
auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max);
auto output_type = std::make_shared<Bool>();
auto ret = std::make_shared<AbstractTensor>(output_type, out_shape);
auto ret =
std::make_shared<AbstractTensor>(output_type, std::make_shared<Shape>(out_shape, out_shape_min, out_shape_max));
return ret;
}

View File

@ -20,6 +20,7 @@ import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
class NetEqual(Cell):
@ -30,6 +31,17 @@ class NetEqual(Cell):
def construct(self, x, y):
return self.Equal(x, y)
class NetEqualDynamic(Cell):
    """Equal network that routes both operands through the GPU
    dynamic-shape conversion op before comparing them."""

    def __init__(self):
        super(NetEqualDynamic, self).__init__()
        self.conv = inner.GpuConvertToDynamicShape()
        self.Equal = P.Equal()

    def construct(self, x, y):
        # Force both inputs onto the dynamic-shape path, then compare.
        dyn_x = self.conv(x)
        dyn_y = self.conv(y)
        return self.Equal(dyn_x, dyn_y)
class NetNotEqual(Cell):
def __init__(self):
super(NetNotEqual, self).__init__()
@ -211,3 +223,20 @@ def test_greaterqual():
output2 = gequal(x2, y2)
assert np.all(output2.asnumpy() == expect2)
assert output2.shape == expect2.shape
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_equal_dynamic_shape():
    """Equal on dynamically-shaped GPU tensors matches np.equal."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    lhs_np = np.arange(24, dtype=np.float32).reshape((4, 3, 2))
    rhs_np = np.arange(24, dtype=np.float32).reshape((4, 3, 2))
    expected = np.equal(lhs_np, rhs_np)

    net = NetEqualDynamic()
    result = net(Tensor(lhs_np), Tensor(rhs_np))

    assert np.all(result.asnumpy() == expected)
    assert result.shape == expected.shape