forked from mindspore-Ecosystem/mindspore
!9134 Equal op dynamic shape
From: @jonwe Reviewed-by: @robingrosman,@tom__chen Signed-off-by: @tom__chen
This commit is contained in:
commit
5835095e07
|
@ -80,24 +80,31 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(input_x);
|
||||
MS_EXCEPTION_IF_NULL(input_x->shape());
|
||||
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
MS_EXCEPTION_IF_NULL(x->shape());
|
||||
ShapeVector x_shape = x->shape()->shape();
|
||||
ShapeVector x_shape_min = x->shape()->min_shape().empty() ? x_shape : x->shape()->min_shape();
|
||||
ShapeVector x_shape_max = x->shape()->max_shape().empty() ? x_shape : x->shape()->max_shape();
|
||||
|
||||
auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
MS_EXCEPTION_IF_NULL(input_y);
|
||||
MS_EXCEPTION_IF_NULL(input_y->shape());
|
||||
auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
MS_EXCEPTION_IF_NULL(y);
|
||||
MS_EXCEPTION_IF_NULL(y->shape());
|
||||
ShapeVector y_shape = y->shape()->shape();
|
||||
ShapeVector y_shape_min = y->shape()->min_shape().empty() ? y_shape : y->shape()->min_shape();
|
||||
ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape();
|
||||
|
||||
auto x_shape = input_x->shape()->shape();
|
||||
auto y_shape = input_y->shape()->shape();
|
||||
auto out_shape = BroadcastShape(x_shape, y_shape);
|
||||
if (out_shape.empty()) {
|
||||
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
|
||||
<< args_spec_list[1]->ToString();
|
||||
}
|
||||
auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min);
|
||||
auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max);
|
||||
|
||||
auto output_type = std::make_shared<Bool>();
|
||||
auto ret = std::make_shared<AbstractTensor>(output_type, out_shape);
|
||||
auto ret =
|
||||
std::make_shared<AbstractTensor>(output_type, std::make_shared<Shape>(out_shape, out_shape_min, out_shape_max));
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import mindspore.context as context
|
|||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
|
||||
|
||||
class NetEqual(Cell):
|
||||
|
@ -30,6 +31,17 @@ class NetEqual(Cell):
|
|||
def construct(self, x, y):
|
||||
return self.Equal(x, y)
|
||||
|
||||
class NetEqualDynamic(Cell):
|
||||
def __init__(self):
|
||||
super(NetEqualDynamic, self).__init__()
|
||||
self.conv = inner.GpuConvertToDynamicShape()
|
||||
self.Equal = P.Equal()
|
||||
|
||||
def construct(self, x, y):
|
||||
x_conv = self.conv(x)
|
||||
y_conv = self.conv(y)
|
||||
return self.Equal(x_conv, y_conv)
|
||||
|
||||
class NetNotEqual(Cell):
|
||||
def __init__(self):
|
||||
super(NetNotEqual, self).__init__()
|
||||
|
@ -211,3 +223,20 @@ def test_greaterqual():
|
|||
output2 = gequal(x2, y2)
|
||||
assert np.all(output2.asnumpy() == expect2)
|
||||
assert output2.shape == expect2.shape
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_equal_dynamic_shape():
|
||||
x0_np = np.arange(24).reshape((4, 3, 2)).astype(np.float32)
|
||||
x0 = Tensor(x0_np)
|
||||
y0_np = np.arange(24).reshape((4, 3, 2)).astype(np.float32)
|
||||
y0 = Tensor(y0_np)
|
||||
expect0 = np.equal(x0_np, y0_np)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
equal = NetEqualDynamic()
|
||||
output0 = equal(x0, y0)
|
||||
assert np.all(output0.asnumpy() == expect0)
|
||||
assert output0.shape == expect0.shape
|
||||
|
|
Loading…
Reference in New Issue