From 6915a9853613512106ba4c4b84b27c6e1bcdcecc Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Wed, 18 Nov 2020 21:00:17 -0500 Subject: [PATCH] Add broadcast for a and x support for IGamma --- mindspore/nn/layer/math.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py index 3543fa5bab4..2f2e830c7b3 100644 --- a/mindspore/nn/layer/math.py +++ b/mindspore/nn/layer/math.py @@ -646,10 +646,13 @@ class IGamma(Cell): x_dtype = self.dtype(x) _check_input_dtype("input_a", a_dtype, [mstype.float16, mstype.float32], self.cls_name) _check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name) - x_is_zero = self.equal(x, 0) domain_error = self.logicalor(self.less(x, 0), self.less(a, 0)) use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a)) ax = a * self.log(x) - x - self.lgamma(a) + boradcastto = P.BroadcastTo(self.shape(ax)) + a = boradcastto(a) + x = boradcastto(x) + x_is_zero = self.equal(x, 0) if a_dtype == mstype.float16: log_maxfloat = self.log_maxfloat16 else: