commit
c6127eed14
|
@ -13,8 +13,11 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <limits>
|
||||
#include "plugin/device/cpu/kernel/igammac_cpu_kernel.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -243,7 +246,7 @@ void IgammacCpuKernelMod::BcastCompute(const std::vector<kernel::AddressPtr> &in
|
|||
auto a_data_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto x_data_addr = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto z_data_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
size_t data_num = get_element_num(z_shape_);
|
||||
size_t data_num = LongToSize(get_element_num(z_shape_));
|
||||
auto output_shape = CPUKernelUtils::GetBroadcastShape(a_shape_, x_shape_);
|
||||
BroadcastIterator iter(a_shape_, x_shape_, output_shape);
|
||||
if (data_num < kParallelDataNums) {
|
||||
|
@ -318,13 +321,13 @@ void IgammacCpuKernelMod::NoBcastCompute(const std::vector<kernel::AddressPtr> &
|
|||
auto in0 = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto in1 = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto out0 = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
size_t in0_elements_nums = get_element_num(a_shape_);
|
||||
size_t in1_elements_nums = get_element_num(x_shape_);
|
||||
size_t data_num = get_element_num(z_shape_);
|
||||
size_t in0_elements_nums = LongToSize(get_element_num(a_shape_));
|
||||
size_t in1_elements_nums = LongToSize(get_element_num(x_shape_));
|
||||
size_t data_num = LongToSize(get_element_num(z_shape_));
|
||||
int64_t type =
|
||||
in0_elements_nums == in1_elements_nums ? kSameShape : (in0_elements_nums == 1 ? kXOneElement : kYOneElement);
|
||||
if (data_num < kParallelDataNums) {
|
||||
SpecialCompute<T>(type, 0, data_num, in0, in1, out0);
|
||||
SpecialCompute<T>(type, 0, SizeToLong(data_num), in0, in1, out0);
|
||||
} else {
|
||||
auto shard_igammac = [type, in0, in1, out0, this](int64_t start, int64_t end) {
|
||||
SpecialCompute<T>(type, start, end, in0, in1, out0 + start);
|
||||
|
@ -359,8 +362,8 @@ void IgammacCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
|
|||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
|
||||
size_t in0_elements_nums = get_element_num(a_shape_);
|
||||
size_t in1_elements_nums = get_element_num(x_shape_);
|
||||
size_t in0_elements_nums = LongToSize(get_element_num(a_shape_));
|
||||
size_t in1_elements_nums = LongToSize(get_element_num(x_shape_));
|
||||
bool isNeedBcast = (a_shape_ == x_shape_) || (in0_elements_nums == 1) || (in1_elements_nums == 1);
|
||||
if (isNeedBcast) {
|
||||
NoBcastCompute<T>(inputs, outputs);
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include <vector>
|
||||
#include <array>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
|
Loading…
Reference in New Issue