Fixes RNG_seed bug in StandardNormal operator

From: @huangbo77
Reviewed-by: @wuxuejian,@liangchenghui
Signed-off-by: @wuxuejian
This commit is contained in:
mindspore-ci-bot 2021-05-08 09:33:15 +08:00 committed by Gitee
commit d632a34b16
1 changed files with 18 additions and 5 deletions
mindspore/ccsrc/backend/kernel_compiler/cpu

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,6 +15,7 @@
*/
#include <random>
#include <thread>
#include "common/thread_pool.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/random_cpu_kernel.h"
@ -41,11 +42,23 @@ void LaunchStandardNormal(int seed, int seed2, const std::vector<AddressPtr> &ou
auto output = reinterpret_cast<float *>(outputs[0]->addr);
size_t lens = outputs[0]->size / sizeof(float);
std::normal_distribution<float> distribution;
auto task = [&](size_t start, size_t end) {
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
const float block_size = 128.0;
size_t thread_num = lens < block_size * max_thread_num ? std::ceil(lens / block_size) : max_thread_num;
std::vector<common::Task> tasks;
size_t start = 0;
size_t once_compute_size = (lens + thread_num - 1) / thread_num;
while (start < lens) {
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
std::default_random_engine random_generator(++RNG_seed);
StandardNormal(output, distribution, random_generator, start, end);
};
CPUKernelUtils::ParallelFor(task, lens);
auto block = [&, start, end]() {
StandardNormal(output, distribution, random_generator, start, end);
return common::SUCCESS;
};
tasks.emplace_back(block);
start += once_compute_size;
}
common::ThreadPool::GetInstance().SyncRun(tasks);
}
void RandomCPUKernel::InitKernel(const CNodePtr &kernel_node) {