!45695 add aicpu dropout gen mask
Merge pull request !45695 from 王禹程/add_drop
commit 02bf0de98e

@@ -7,6 +7,7 @@
"mindspore/mindspore/core/abstract/ops/prim_nn.cc" "zerodivcond"
"mindspore/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc" "zerodivcond"
"mindspore/mindspore/ccsrc/pipeline/jit/pipeline_split.cc" "zerodivcond"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/drop_out_gen_mask_kernels.cc" "uninitMemberVar"
"mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion/adaptive_max_pool2d_fusion.cc" "zerodivcond"
"mindspore/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_stream_assign.cc" "useStlAlgorithm"
"mindspore/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc" "uninitvar"
@@ -21,6 +21,7 @@
"mindspore/mindspore/ccsrc/runtime/hardware/device_context.h" "readability/braces"
"mindspore/mindspore/ccsrc/transform/graph_ir/convert.h" "runtime/references"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/gather_grad_kernels.cc" "build/include"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/drop_out_gen_mask_kernels.cc" "build/include"
"mindspore/mindspore/ccsrc/backend/common/optimizer/op_adaptation_info_factory.h" "runtime/explicit"
"mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/concatv2_impl.cu" "runtime/int"
"mindspore/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc" "runtime/threadsafe_fn"
@@ -21,6 +21,7 @@ mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc:mindspore::abstra
mindspore/mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc:mindspore::PyExceptionInitializer::HandleExceptionPy
mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/unary_op_gpu_kernel.h:mindspore::kernel::UnaryOpGpuKernel::Launch
mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_rnn_grad_fission_v2.cc:mindspore::opt::AddLSTMInputGradNode
+mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/drop_out_gen_mask_kernels.cc:aicpu::ARMDropOutGenMaskKernel
mindspore/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py:__init__
mindspore/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py:__init__
mindspore/mindspore/ccsrc/pipeline/jit/resource.cc:mindspore::pipeline::GetMethodMap
@@ -39,6 +39,7 @@ if(EXISTS ${CMAKE_C_COMPILER} AND EXISTS ${CMAKE_CXX_COMPILER})
    ${CMAKE_CURRENT_SOURCE_DIR}/replay_buffer/priority_replay_buffer.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/replay_buffer/priority_replay_buffer_kernels.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/concat_offset_kernel.cc
+   ${CMAKE_CURRENT_SOURCE_DIR}/drop_out_gen_mask_kernels.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/slice_grad_kernel.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/random_shuffle_kernel.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/range_kernel.cc
@@ -0,0 +1,455 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "drop_out_gen_mask_kernels.h"

#include <cfloat>
#include <cmath>
#include <ctime>
#include <random>
#include <memory.h>

#include "aicpu_sharder/aicpu_sharder.h"
#include "common/kernel_errcode.h"
#include "common/kernel_log.h"

#include "Eigen/Core"

namespace aicpu {
std::random_device e;
size_t kIndexOutput = 4;

#if (defined __ARM_ARCH) || (defined PLATFORM_AARCH64)  // compiled on arm arch
#define CONFIG_ENABLE_PERIOD_64BIT
static void OffsetAdd(uint64_t number, const uint64_t *baseOffset, uint64_t *offset) {
  uint64_t tmpBaseOffset0 = baseOffset[0];
  uint64_t tmpBaseOffset1 = baseOffset[1];
  offset[0] = tmpBaseOffset0 + number;
  offset[1] = tmpBaseOffset1;
  if (offset[0] < tmpBaseOffset0) {
    offset[1]++;
  }
}
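
// OffsetAdd computes the 128-bit sum offset = baseOffset + number with a manual
// carry: if the low word wraps around (offset[0] < tmpBaseOffset0), the high word
// is incremented. baseOffset is copied into temporaries first, so the call
// OffsetAdd(n, g_offset, g_offset) below stays correct when input and output alias.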
static void ARMDropOutGenMaskKernel(const uint64_t count, const float prob, const uint8_t *offset, const uint8_t *key,
                                    uint8_t *out) {
  const uint16_t threshold = static_cast<uint16_t>(UINT16_MAX * prob);
  const uint8_t in_offset[16] = {0x01, 0, 0, 0, 0, 0, 0, 0, 0x01};
  const uint8_t inc_step[16] = {0x02};

  // a const key. reference paper: https://dl.acm.org/citation.cfm?id=206340
  const uint8_t key_const[16] = {0xBB, 0x67, 0xAE, 0x85, 0x84, 0xCA, 0xA7, 0x3B,
                                 0x9E, 0x37, 0x79, 0xB9, 0x7F, 0x4A, 0x7C, 0x15};

  const uint8_t *key_const_ptr = &(key_const[0]);
  const uint8_t *inc_step_ptr = &(inc_step[0]);

  // Each iteration generates 4 bits * 8 elements (per vector register) * 4 (repeated code blocks) = 128 mask bits
  const uint64_t loop_time = count / 4 / 8 / 4;
  __asm volatile(
    ".arch armv8-a+crypto \n"

    "ldr x0, %[loop_time] \n"

    "ldr x16, %[key_const_ptr] \n"
    "ld1 {v2.16b}, [x16] \n"

    // generate in1
    "ldr x1, %[offset] \n"
    "ld1 {v0.16b}, [x1] \n"  // tmp input

    "ldr x2, %[key] \n"
    "ld1 {v1.16b}, [x2] \n"       // first round key
    "add v5.2d, v1.2d, v2.2d \n"  // second round key

    // generate in2
    "ldp x10, x11, %[in_offset] \n"
    "ldp x12, x13, [x1] \n"
    "adds x14, x12, x10 \n"
    "adc x15, x13, x11 \n"
    "mov v10.d[0], x14 \n"
    "mov v10.d[1], x15 \n"

    // generate in3 = in1
    "mov v3.16b, v0.16b \n"
    // generate in4 = in2
    "mov v13.16b, v10.16b \n"

    // load input inc step
#ifdef CONFIG_ENABLE_PERIOD_64BIT
    "ldr x17, %[inc_step_ptr] \n"
    "ld1 {v4.16b}, [x17] \n"
#else
    "ldp x10, x11, %[inc_step] \n"
#endif

    "ldr w7, %[threshold] \n"
    "dup v20.8h, w7 \n"

    // Generate 16 bitmasks into 16 registers
    "mov w7, #0x8000 \n"
    "dup v21.8h, w7 \n"
    "mov w7, #0x4000 \n"
    "dup v12.8h, w7 \n"
    "mov w7, #0x2000 \n"
    "dup v2.8h, w7 \n"
    "mov w7, #0x1000 \n"
    "dup v6.8h, w7 \n"

    "mov w7, #0x800 \n"
    "dup v7.8h, w7 \n"
    "mov w7, #0x400 \n"
    "dup v8.8h, w7 \n"
    "mov w7, #0x200 \n"
    "dup v14.8h, w7 \n"
    "mov w7, #0x100 \n"
    "dup v15.8h, w7 \n"

    "mov w7, #0x80 \n"
    "dup v26.8h, w7 \n"
    "mov w7, #0x40 \n"
    "dup v27.8h, w7 \n"
    "mov w7, #0x20 \n"
    "dup v9.8h, w7 \n"
    "mov w7, #0x10 \n"
    "dup v11.8h, w7 \n"

    "mov w7, #0x8 \n"
    "dup v16.8h, w7 \n"
    "mov w7, #0x4 \n"
    "dup v17.8h, w7 \n"
    "mov w7, #0x2 \n"
    "dup v18.8h, w7 \n"
    "mov w7, #0x1 \n"
    "dup v19.8h, w7 \n"

    // load out pointer addr to register
    "ldr x5, %[out] \n"

    // Iteration begins
    ".ARS: \n"

    /* Mix v0 with v1 */
    "aese v0.16b, v1.16b \n"
    "aesmc v0.16b, v0.16b \n"

    /* Mix v10 with v5 */
    "aese v10.16b, v5.16b \n"
    "aesmc v10.16b, v10.16b \n"

    /* Compare the random number v0 against threshold */
    "cmhs v22.8h, v20.8h, v0.8h \n"
    /* Update the output register with v0 */
    "bit v29.16b, v22.16b, v21.16b \n"

    /* Mix v13 with v1 */
    "aese v13.16b, v1.16b \n"
    "aesmc v13.16b, v13.16b \n"

    /* Compare the random number v10 against threshold */
    "cmhs v23.8h, v20.8h, v10.8h \n"
    /* Update the output register with v10 */
    "bit v29.16b, v23.16b, v12.16b \n"

    /* Mix v3 with v5 */
    "aese v3.16b, v5.16b \n"
    "aesmc v3.16b, v3.16b \n"

    /* Compare the random number v13 against threshold */
    "cmhs v25.8h, v20.8h, v13.8h \n"
    /* Update the output register with v13 */
    "bit v29.16b, v25.16b, v6.16b \n"

    /* Mix v0 with v1 */
    "aese v0.16b, v1.16b \n"
    "aesmc v0.16b, v0.16b \n"

    /* Compare the random number v3 against threshold */
    "cmhs v24.8h, v20.8h, v3.8h \n"
    /* Update the output register with v3 */
    "bit v29.16b, v24.16b, v2.16b \n"

    /* Mix v10 with v5 */
    "aese v10.16b, v5.16b \n"
    "aesmc v10.16b, v10.16b \n"

    /* Compare the random number v0 against threshold */
    "cmhs v22.8h, v20.8h, v0.8h \n"
    /* Update the output register with v0 */
    "bit v29.16b, v22.16b, v7.16b \n"

    /* Mix v13 with v1 */
    "aese v13.16b, v1.16b \n"
    "aesmc v13.16b, v13.16b \n"

    /* Compare the random number v10 against threshold */
    "cmhs v23.8h, v20.8h, v10.8h \n"
    /* Update the output register with v10 */
    "bit v29.16b, v23.16b, v8.16b \n"

    /* Mix v3 with v5 */
    "aese v3.16b, v5.16b \n"
    "aesmc v3.16b, v3.16b \n"

    /* Compare the random number v13 against threshold */
    "cmhs v25.8h, v20.8h, v13.8h \n"
    /* Update the output register with v13 */
    "bit v29.16b, v25.16b, v14.16b \n"

    /* Mix v0 with v1 */
    "aese v0.16b, v1.16b \n"
    "aesmc v0.16b, v0.16b \n"

    /* Compare the random number v3 against threshold */
    "cmhs v24.8h, v20.8h, v3.8h \n"
    /* Update the output register with v3 */
    "bit v29.16b, v24.16b, v15.16b \n"

    /* Mix v10 with v5 */
    "aese v10.16b, v5.16b \n"
    "aesmc v10.16b, v10.16b \n"

    /* Compare the random number v0 against threshold */
    "cmhs v22.8h, v20.8h, v0.8h \n"
    /* Update the output register with v0 */
    "bit v29.16b, v22.16b, v26.16b \n"

    /* Mix v13 with v1 */
    "aese v13.16b, v1.16b \n"
    "aesmc v13.16b, v13.16b \n"

    /* Compare the random number v10 against threshold */
    "cmhs v23.8h, v20.8h, v10.8h \n"
    /* Update the output register with v10 */
    "bit v29.16b, v23.16b, v27.16b \n"

    /* Mix v3 with v5 */
    "aese v3.16b, v5.16b \n"
    "aesmc v3.16b, v3.16b \n"

    /* Compare the random number v13 against threshold */
    "cmhs v25.8h, v20.8h, v13.8h \n"
    /* Update the output register with v13 */
    "bit v29.16b, v25.16b, v9.16b \n"

    /* Mix v0 with v1 */
    "aese v0.16b, v1.16b \n"
    "aesmc v0.16b, v0.16b \n"

    /* Compare the random number v3 against threshold */
    "cmhs v24.8h, v20.8h, v3.8h \n"
    /* Update the output register with v3 */
    "bit v29.16b, v24.16b, v11.16b \n"

    /* Mix v10 with v5 */
    "aese v10.16b, v5.16b \n"
    "aesmc v10.16b, v10.16b \n"

    /* Compare the random number v0 against threshold */
    "cmhs v22.8h, v20.8h, v0.8h \n"
    /* Update the output register with v0 */
    "bit v29.16b, v22.16b, v16.16b \n"

    /* Mix v13 with v1 */
    "aese v13.16b, v1.16b \n"
    "aesmc v13.16b, v13.16b \n"

    /* Compare the random number v10 against threshold */
    "cmhs v23.8h, v20.8h, v10.8h \n"
    /* Update the output register with v10 */
    "bit v29.16b, v23.16b, v17.16b \n"

    /* Mix v3 with v5 */
    "aese v3.16b, v5.16b \n"
    "aesmc v3.16b, v3.16b \n"

    /* Compare the random number v13 against threshold */
    "cmhs v25.8h, v20.8h, v13.8h \n"
    /* Update the output register with v13 */
    "bit v29.16b, v25.16b, v18.16b \n"

    // Update the key
#ifdef CONFIG_ENABLE_PERIOD_64BIT
    "add v1.2d, v1.2d, v4.2d \n"
    "add v5.2d, v5.2d, v4.2d \n"
#else
    "mov x12, v1.d[0] \n"
    "mov x13, v1.d[1] \n"
    "adds x14, x12, x10 \n"
    "adc x15, x13, x11 \n"
    "mov v1.d[0], x14 \n"
    "mov v1.d[1], x15 \n"

    "mov x12, v5.d[0] \n"
    "mov x13, v5.d[1] \n"
    "adds x14, x12, x10 \n"
    "adc x15, x13, x11 \n"
    "mov v5.d[0], x14 \n"
    "mov v5.d[1], x15 \n"
#endif

    /* Compare the random number v3 against threshold */
    "cmhs v24.8h, v20.8h, v3.8h \n"
    /* Update the output register with v3 */
    "bit v29.16b, v24.16b, v19.16b \n"

    // Store the output register to memory
    "st1 {v29.16b}, [x5] \n"
    "add x5, x5, 16 \n"

    // Next iteration
    "subs x0, x0, 1 \n"
    "bne .ARS \n"
    :
    : [ offset ] "m"(offset), [ out ] "m"(out), [ in_offset ] "m"(in_offset), [ key ] "m"(key),
      [ key_const ] "m"(key_const), [ inc_step ] "m"(inc_step), [ loop_time ] "m"(loop_time),
      [ threshold ] "m"(threshold), [ key_const_ptr ] "m"(key_const_ptr), [ inc_step_ptr ] "m"(inc_step_ptr)
    : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "w7", "x7", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17",
      "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
      "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v29");
}
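
// The kernel above is a counter-based generator: each .ARS iteration mixes four
// 128-bit counter blocks with the round keys via AES instructions (aese/aesmc),
// compares the resulting 16-bit lanes against the threshold (cmhs), and merges the
// compare masks into the accumulator v29 with bit-select (bit), yielding 128 mask
// bits per iteration, stored as one 16-byte block.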
uint32_t DropOutGenMaskKernel::DoCompute() {
  float prob = keep_prob_;

  uint64_t bit_count = static_cast<uint64_t>(count_);
  // align the bit count to 128, rounding up
  bit_count = (bit_count + 127) & (~127);
  // transfer bit count to byte count
  uint64_t byte_count = bit_count >> 3;

  // if prob is 0, set all bits to 0
  if (prob <= FLT_EPSILON) {
    memset_s(reinterpret_cast<void *>(io_addrs_[kIndexOutput]), byte_count, 0x00, byte_count);
    return kAicpuKernelStateSucess;
  }
  // if prob is 1, set all bits to 1
  if (std::abs(prob - 1.0f) <= FLT_EPSILON) {
    memset_s(reinterpret_cast<void *>(io_addrs_[kIndexOutput]), byte_count, 0xff, byte_count);
    return kAicpuKernelStateSucess;
  }

  uint8_t *outBuff = reinterpret_cast<uint8_t *>(io_addrs_[kIndexOutput]);
  // snapshot the key and base offset shared by all shards (bit count is already aligned to 128)
  uint64_t key[2] = {g_key[0], g_key[1]};
  uint64_t baseOffset[2] = {g_offset[0], g_offset[1]};
  auto shards = [prob, baseOffset, key, outBuff](int64_t start, int64_t limit) {
    uint64_t bitCount = (static_cast<uint64_t>(limit - start)) << 7;  // convert 128-bit units to bits
    uint8_t *pOut = outBuff;
    pOut += ((static_cast<uint64_t>(start)) << 4);  // calculate skipped bytes
    uint64_t offset[2] = {0, 0};
    OffsetAdd(start << 7, baseOffset, offset);
    ARMDropOutGenMaskKernel(bitCount, prob, reinterpret_cast<const uint8_t *>(&offset),
                            reinterpret_cast<const uint8_t *>(&key), pOut);
  };
  const int64_t total_unit = static_cast<int64_t>(byte_count >> 4);
  const int64_t perUnitSize = 1;  // shard unit size
  aicpu::SharderNonBlock::GetInstance().ParallelFor(total_unit, perUnitSize, shards);
  const int64_t margin = 1021;  // the margin of the offset
  OffsetAdd(bit_count + margin, g_offset, g_offset);
  auto offset0 = reinterpret_cast<uint64_t *>(io_addrs_[2]);
  auto offset1 = reinterpret_cast<uint64_t *>(io_addrs_[3]);
  offset0[0] = g_offset[0];
  offset1[0] = g_offset[1];
  outBuff = nullptr;
  return kAicpuKernelStateSucess;
}
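
// Sharding granularity above: one unit covers 16 output bytes (start << 4), i.e.
// 128 mask bits (start << 7), so each shard restarts the generator at
// baseOffset + 128 * start and the shards tile the output buffer without overlap.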
#else  // compiled on x86 arch

uint32_t DropOutGenMaskKernel::DoCompute() {
  std::default_random_engine te(time(0));
  std::bernoulli_distribution b(keep_prob_);
  const uint8_t mask[8] = {0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01};
  uint64_t byteCount = count_ >> 3;
  out_ = reinterpret_cast<uint8_t *>(io_addrs_[kIndexOutput]);
  for (uint64_t i = 0; i < byteCount; ++i) {
    out_[i] = 0x00;
    for (const auto &m : mask) {
      if (b(te)) {
        out_[i] = out_[i] | m;
      }
    }
  }
  out_ = nullptr;
  return kAicpuKernelStateSucess;
}

#endif
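
// Note that this fallback seeds std::default_random_engine from time(0), so it does
// not consume seed0_/seed1_ or advance g_offset; it simply packs count_ Bernoulli
// draws with probability keep_prob_ into bytes, most significant bit first.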
uint32_t DropOutGenMaskKernel::ParseKernelParam() {
  ::google::protobuf::Map<::std::string, ::aicpuops::AttrValue> nodedef_map = node_def_.attrs();
  AICPU_LOGEVENT("InputNum=[%zu], OutputNum=[%zu], ioAddrNum=[%zu], seed exist: %d, seed2 exist: %d.",
                 node_def_.inputs_size(), node_def_.outputs_size(), io_addrs_.size(), nodedef_map.contains("seed"),
                 nodedef_map.contains("seed2"));

  aicpuops::AttrValue seed0 = nodedef_map["seed"];
  aicpuops::AttrValue seed1 = nodedef_map["seed2"];
  seed0_ = seed0.i();
  seed1_ = seed1.i();
  if (seed0_ == 0 && seed1_ == 0) {
    seed0_ = e();
    seed1_ = e();
  }
  g_key[0] = static_cast<uint64_t>(seed1_);
  g_key[1] = static_cast<uint64_t>(seed0_);
  g_offset[0] = *reinterpret_cast<uint64_t *>(io_addrs_[2]);
  g_offset[1] = *reinterpret_cast<uint64_t *>(io_addrs_[3]);

  uint64_t tmp_count = 1;
  aicpuops::Tensor shape_tensor = node_def_.inputs(0);
  aicpuops::TensorShape input_shape = shape_tensor.tensor_shape();
  aicpuops::DataType shape_dt = static_cast<::aicpuops::DataType>(shape_tensor.tensor_type());
  for (int j = 0; j < input_shape.dim_size(); j++) {
    tmp_count *= input_shape.dim(j).size();
  }
  if (shape_dt == aicpuops::MS_INT32) {
    auto input0 = reinterpret_cast<int32_t *>(io_addrs_[0]);
    count_ = 1;
    for (uint64_t index = 0; index < tmp_count; index++) {
      count_ *= input0[index];
    }
  } else {
    auto input0 = reinterpret_cast<int64_t *>(io_addrs_[0]);
    count_ = 1;
    for (uint64_t index = 0; index < tmp_count; index++) {
      count_ *= input0[index];
    }
  }

  aicpuops::Tensor prob_tensor = node_def_.inputs(1);
  aicpuops::DataType dt = static_cast<::aicpuops::DataType>(prob_tensor.tensor_type());
  if (dt == aicpuops::MS_FLOAT16) {
#if (defined __ARM_ARCH) || (defined PLATFORM_AARCH64)  // compiled on arm arch
    keep_prob_ = *reinterpret_cast<float16_t *>(io_addrs_[1]);
#else
    keep_prob_ = *reinterpret_cast<float *>(io_addrs_[1]);
#endif
  } else {
    keep_prob_ = *reinterpret_cast<float *>(io_addrs_[1]);
  }
  if ((keep_prob_ < 0.0f) || (keep_prob_ > 1.0f)) {
    AICPU_LOGE("The prob must be in the [0, 1] range, but got %f.", keep_prob_);
    return kAicpuKernelStateInvalid;
  }
  AICPU_LOGI("DropoutGenMask mask count and prob: %lu %f", count_, keep_prob_);
  return kAicpuKernelStateSucess;
}
}  // namespace aicpu

extern "C" {
__attribute__((visibility("default"))) uint32_t DropoutGenMask(void *param) {
  aicpu::DropOutGenMaskKernel dropoutGenMaskKernel;
  return dropoutGenMaskKernel.Compute(param);
}
}
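For reference, the x86 fallback's bit-packing reduces to the following standalone sketch; the function name GenMask, the demo seed, and main() are illustrative, not part of the patch:

#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

// Packs Bernoulli(keep_prob) draws into a byte mask, most significant bit first,
// one byte per eight mask bits, mirroring the x86 DoCompute above.
std::vector<uint8_t> GenMask(uint64_t count, float keep_prob, uint64_t seed) {
  std::default_random_engine eng(seed);
  std::bernoulli_distribution keep(keep_prob);
  const uint8_t mask[8] = {0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01};
  std::vector<uint8_t> out((count + 7) / 8, 0x00);
  for (auto &byte : out) {
    for (const auto &m : mask) {
      if (keep(eng)) byte |= m;  // set the bit when the element is kept
    }
  }
  return out;
}

int main() {
  auto m = GenMask(64, 0.5f, 1);  // 64 mask bits -> 8 bytes
  for (auto b : m) printf("%02x ", b);
  printf("\n");
  return 0;
}
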
@@ -0,0 +1,43 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef AICPU_OPS_DROP_GEN_MASK_KERNELS_H_
#define AICPU_OPS_DROP_GEN_MASK_KERNELS_H_

#include "common/kernel_base.h"

namespace aicpu {
class DropOutGenMaskKernel : public KernelBase {
 public:
  DropOutGenMaskKernel() : KernelBase("DropOutGenMask") {}

  ~DropOutGenMaskKernel() = default;

  uint64_t seed0_;
  uint64_t seed1_;
  float keep_prob_;
  uint64_t count_;
  uint64_t g_key[2];
  uint64_t g_offset[2];

 protected:
  uint32_t DoCompute() override;

  uint32_t ParseKernelParam() override;

  uint8_t *out_;
};
}  // namespace aicpu
#endif  // AICPU_OPS_DROP_GEN_MASK_KERNELS_H_
@@ -30,6 +30,7 @@ constexpr auto kLibAicpuKernelSoName = "libaicpu_kernels.so";
constexpr auto kLibCpuKernelSoName = "libcpu_kernels.so";
constexpr auto kFormat = "format";
constexpr auto kDataFormat = "data_format";
+constexpr auto kDropoutGenMaskOpName = "DropoutGenMask";
constexpr auto kInitDataSetQueue = "InitDataSetQueue";
constexpr auto kInitData = "InitData";
constexpr auto kGetNext = "GetNext";
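kLibAicpuKernelSoName names the shared object that carries these kernels, and kDropoutGenMaskOpName matches the extern "C" DropoutGenMask entry point added above. A minimal sketch of resolving such an entry point by name; the library path and the explicit dlopen here are assumptions for illustration, since the AICPU framework performs its own loading:

#include <cstdint>
#include <cstdio>
#include <dlfcn.h>

// build: g++ demo.cc -ldl
int main() {
  // Hypothetical loader-side lookup; the real parameter block is framework-provided.
  void *handle = dlopen("./libaicpu_kernels.so", RTLD_NOW);
  if (handle == nullptr) {
    printf("dlopen failed: %s\n", dlerror());
    return 1;
  }
  using KernelFn = uint32_t (*)(void *);
  auto fn = reinterpret_cast<KernelFn>(dlsym(handle, "DropoutGenMask"));
  if (fn == nullptr) {
    printf("DropoutGenMask not found\n");
    dlclose(handle);
    return 1;
  }
  // fn(param) would run the kernel on a serialized node definition plus io addresses.
  dlclose(handle);
  return 0;
}
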
@@ -252,7 +253,8 @@ const std::set<std::string> kCpuKernelOps{kIdentity,
                                          kSign};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift, kDropout3D,
                                            kDropout2D, kNonMaxSuppressionV3, kGetNext, kInitData, kPrint};
-const std::set<std::string> kCpuKernelBaseOps{kRandomChoiceWithMask,
+const std::set<std::string> kCpuKernelBaseOps{kDropoutGenMaskOpName,
+                                              kRandomChoiceWithMask,
                                              kEnvironCreate,
                                              kEnvironSet,
                                              kEnvironGet,
@@ -362,7 +362,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
  ir_fusion_pm->AddPass(std::make_shared<DropoutGenMaskFusion>());
  ir_fusion_pm->AddPass(std::make_shared<StatelessDropOutGenMaskReplace>());
  ir_fusion_pm->AddPass(std::make_shared<SeedAdapter>());
  ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
  ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
@@ -658,7 +657,6 @@ void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph
  auto unify_mindir_pm = std::make_shared<opt::PassManager>("unify_mindir_pm");
  unify_mindir_pm->AddPass(std::make_shared<opt::SpaceToBatchNDAttrUpdate>());
  unify_mindir_pm->AddPass(std::make_shared<opt::BatchToSpaceNDAttrUpdate>());
-  unify_mindir_pm->AddPass(std::make_shared<opt::AICpuLibSelectPass>());
  unify_mindir_pm->AddPass(std::make_shared<opt::MaxPool2MaxPoolWithArgmax>());
  unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolWithArgmaxUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolGradWithArgmaxUnifyMindIR>());
@@ -694,6 +692,7 @@ void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph
  unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2GradUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::AllToAllUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::AscendVmOpAdapter>());
+  unify_mindir_pm->AddPass(std::make_shared<opt::AICpuLibSelectPass>());

  optimizer->AddPassManager(unify_mindir_pm);
  (void)optimizer->Optimize(kernel_graph);
@@ -31,25 +31,29 @@
namespace mindspore::opt {
namespace {
const std::set<std::string> kNodeWithSeedOperators = {kGammaOpName, kPoissonOpName, kStandardLaplaceOpName,
-                                                      kStandardNormalOpName, kUniformIntOpName, kUniformRealOpName};
-tensor::TensorPtr CreateTensor(int64_t seed) {
+                                                      kStandardNormalOpName, kUniformIntOpName, kUniformRealOpName,
+                                                      kDropoutGenMaskOpName};
+template <typename T>
+tensor::TensorPtr CreateTensor(T seed) {
  // 1 create seed tensor
  std::vector<int64_t> indices_shape = {1};
+  auto type = std::is_same<T, int64_t>::value ? kInt64 : kUInt64;
  TensorTypePtr tensor_type = std::make_shared<TensorType>(kInt64);
  MS_EXCEPTION_IF_NULL(tensor_type);
  tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type};
-  tensor::TensorPtr indices_tensor = std::make_shared<tensor::Tensor>(kInt64->type_id(), indices_shape);
+  tensor::TensorPtr indices_tensor = std::make_shared<tensor::Tensor>(type->type_id(), indices_shape);
  MS_EXCEPTION_IF_NULL(indices_tensor);
  indices_tensor->set_device_info(device_info);
  // 2 set value of tensor
  auto data_ptr = indices_tensor->data_c();
  MS_EXCEPTION_IF_NULL(data_ptr);
-  auto ptr = static_cast<int64_t *>(data_ptr);
+  auto ptr = static_cast<T *>(data_ptr);
  *ptr = seed;
  return indices_tensor;
}

-ValueNodePtr CreateValueNode(int64_t seed) {
+template <typename T>
+ValueNodePtr CreateValueNode(T seed) {
  tensor::TensorPtr tensor = CreateTensor(seed);
  MS_EXCEPTION_IF_NULL(tensor);
  auto value_node = std::make_shared<ValueNode>(tensor);
@@ -61,7 +65,11 @@ ValueNodePtr CreateValueNode(int64_t seed) {
  value_node->set_kernel_info(indices_kernel_info);
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetOutputsFormat({kOpFormat_DEFAULT});
-  builder.SetOutputsDeviceType({kNumberTypeInt64});
+  if (std::is_same<T, int64_t>::value) {
+    builder.SetOutputsDeviceType({kNumberTypeInt64});
+  } else {
+    builder.SetOutputsDeviceType({kNumberTypeUInt64});
+  }
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), value_node.get());
  return value_node;
}
@@ -70,29 +78,41 @@ std::vector<ValueNodePtr> ConvertAttrToValueNode(const std::shared_ptr<kernel::O
                                                 const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(op_info);
  MS_EXCEPTION_IF_NULL(cnode);
-  // get seed
  std::vector<ValueNodePtr> ret = {};
-  auto attrs = op_info->attrs_ptr();
-  if (attrs.empty()) {
-    MS_LOG(EXCEPTION) << "Node(" << cnode->DebugString() << ") doesn't have any attrs."
-                      << trace::DumpSourceLines(cnode);
-  }
-  for (const auto &attr : attrs) {
-    if (!common::AnfAlgo::HasNodeAttr(attr->name(), cnode)) {
-      MS_LOG(EXCEPTION) << "Node(" << cnode->DebugString() << ") doesn't have attr(" << attr->name() << ")."
-                        << trace::DumpSourceLines(cnode);
-    }
-    auto attr_value = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, attr->name());
-    auto value_node = CreateValueNode(attr_value);
-    if (value_node == nullptr) {
-      MS_LOG(EXCEPTION) << "Create value node error, node: " << cnode->DebugString() << ", seed value: " << attr_value
-                        << trace::DumpSourceLines(cnode);
-    }
-    (void)ret.emplace_back(value_node);
-  }
-  if (ret.empty()) {
-    MS_LOG(EXCEPTION) << "Node(" << cnode->DebugString() << ") doesn't have any matched attrs."
-                      << trace::DumpSourceLines(cnode);
-  }
+  // DropoutGenMask only create offset
+  if (op_info->op_name() == kDropoutGenMaskOpName) {
+    uint64_t offset = 0;
+    auto offset0 = CreateValueNode(offset);
+    auto offset1 = CreateValueNode(offset);
+    if (offset0 == nullptr || offset1 == nullptr) {
+      MS_LOG(EXCEPTION) << "Create value node error, node: " << cnode->DebugString() << trace::DumpSourceLines(cnode);
+    }
+    (void)ret.emplace_back(offset0);
+    (void)ret.emplace_back(offset1);
+  } else {
+    // Get seed to create value node
+    auto attrs = op_info->attrs_ptr();
+    if (attrs.empty()) {
+      MS_LOG(EXCEPTION) << "Node(" << cnode->DebugString() << ") doesn't have any attrs."
+                        << trace::DumpSourceLines(cnode);
+    }
+    for (const auto &attr : attrs) {
+      if (!common::AnfAlgo::HasNodeAttr(attr->name(), cnode)) {
+        MS_LOG(EXCEPTION) << "Node(" << cnode->DebugString() << ") doesn't have attr(" << attr->name() << ")."
+                          << trace::DumpSourceLines(cnode);
+      }
+      auto attr_value = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, attr->name());
+      auto value_node = CreateValueNode(attr_value);
+      if (value_node == nullptr) {
+        MS_LOG(EXCEPTION) << "Create value node error, node: " << cnode->DebugString() << ", seed value: " << attr_value
+                          << trace::DumpSourceLines(cnode);
+      }
+      (void)ret.emplace_back(value_node);
+    }
+    if (ret.empty()) {
+      MS_LOG(EXCEPTION) << "Node(" << cnode->DebugString() << ") doesn't have any matched attrs."
+                        << trace::DumpSourceLines(cnode);
+    }
+  }
  return ret;
}
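The two zero-valued uint64 nodes created here become the offset1/offset2 inputs of DropoutGenMask; the kernel reads them as a 128-bit counter and writes the advanced counter back (see OffsetAdd in the kernel source), so successive launches continue one random stream. A self-contained sketch of that carry arithmetic, assuming the two words are the low/high halves of the counter:

#include <cstdint>
#include <cstdio>

// 128-bit counter advance, mirroring the kernel's OffsetAdd: low word plus n,
// with a carry into the high word when the low word wraps.
static void OffsetAdd(uint64_t n, const uint64_t *base, uint64_t *out) {
  const uint64_t base0 = base[0];  // copy first: base and out may alias
  const uint64_t base1 = base[1];
  out[0] = base0 + n;
  out[1] = base1 + (out[0] < base0 ? 1 : 0);
}

int main() {
  uint64_t state[2] = {0, 0};           // matches the zero-initialized offset value nodes
  OffsetAdd(UINT64_MAX, state, state);  // consume almost a full low word of bits
  OffsetAdd(1, state, state);           // low word wraps, carry reaches the high word
  printf("low=%llu high=%llu\n", (unsigned long long)state[0], (unsigned long long)state[1]);
  return 0;
}
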
@@ -29,7 +29,8 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);

-  static const std::set<std::string> kAICpuOpNames = {kEnvironCreateOpName,
+  static const std::set<std::string> kAICpuOpNames = {kDropoutGenMaskOpName,
+                                                      kEnvironCreateOpName,
                                                      kEnvironSetOpName,
                                                      kEnvironGetOpName,
                                                      kEnvironDestroyAllOpName,
@@ -20,10 +20,13 @@ dropout_genmask_op_info = AiCPURegOp("DropoutGenMask") \
    .fusion_type("OPAQUE") \
    .input(0, "x1", "required") \
    .input(1, "x2", "required") \
+   .input(2, "offset1", "required") \
+   .input(3, "offset2", "required") \
    .output(0, "y", "required") \
    .attr("Seed0", "int") \
    .attr("Seed1", "int") \
-   .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.U8_Default) \
+   .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.U64_Default, DataType.U64_Default,
+                 DataType.U8_Default) \
    .get_op_info()

@op_info_register(dropout_genmask_op_info)
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
+import random
import numpy as np
import pytest

@@ -19,6 +20,7 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
+from mindspore.common import set_seed

context.set_context(mode=context.GRAPH_MODE,
                    device_target="Ascend")
@@ -49,3 +51,117 @@ def test_net():
    output = mask(tx, ty)
    print(output.asnumpy())
    assert ([255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255] == output.asnumpy()).all()
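# Every expected byte above is 0xff (255): a mask with all bits kept, which is the
# kernel's all-ones fast path when the keep probability is 1.
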

class Drop(nn.Cell):
    def __init__(self):
        super(Drop, self).__init__()
        self.drop = nn.Dropout(1.0 - 0.5)

    def construct(self, out):
        out = self.drop(out)
        return out


def train(net, data):
    net.set_train(True)
    res_list = []
    for _ in range(5):
        res = net(data)
        res_list.append(res.asnumpy())
    return res_list


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_drop():
    """
    Feature: test that the dropout gen mask differs from step to step.
    Description: dropout gen mask.
    Expectation: No exception.
    """
    set_seed(1)
    np.random.seed(1)
    random.seed(1)
    data = Tensor(np.ones([1, 50]).astype(np.float32))

    net = Drop()
    out_list = train(net, data)

    for i in range(len(out_list)):
        for j in range(len(out_list)):
            if i == j:
                continue
            assert np.allclose(out_list[i], out_list[j], 0, 0) is False

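# Net0, Net1 and Net2 below differ only in the (Seed0, Seed1) pairs passed to
# P.DropoutGenMask, so their outputs should come from different random streams.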
class Net0(nn.Cell):
    def __init__(self):
        super(Net0, self).__init__()
        self.mask_1 = P.DropoutGenMask(10, 10)
        self.mask_2 = P.DropoutGenMask(10, 10)
        self.shape = P.Shape()

    def construct(self, x_, y_):
        shape_x = self.shape(x_)
        out_1 = self.mask_1(shape_x, y_)
        out_2 = self.mask_2(shape_x, y_)
        return out_1, out_2


class Net1(nn.Cell):
    def __init__(self):
        super(Net1, self).__init__()
        self.mask_1 = P.DropoutGenMask(20, 20)
        self.mask_2 = P.DropoutGenMask(10, 10)
        self.shape = P.Shape()

    def construct(self, x_, y_):
        shape_x = self.shape(x_)
        out_1 = self.mask_1(shape_x, y_)
        out_2 = self.mask_2(shape_x, y_)
        return out_1, out_2


class Net2(nn.Cell):
    def __init__(self):
        super(Net2, self).__init__()
        self.mask_1 = P.DropoutGenMask(10, 10)
        self.mask_2 = P.DropoutGenMask(20, 20)
        self.shape = P.Shape()

    def construct(self, x_, y_):
        shape_x = self.shape(x_)
        out_1 = self.mask_1(shape_x, y_)
        out_2 = self.mask_2(shape_x, y_)
        return out_1, out_2


px = np.ones([2, 4, 2, 2]).astype(np.int32)
py = np.array([0.5]).astype(np.float32)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_diff_seed():
    """
    Feature: test that the dropout gen mask differs for different seeds.
    Description: dropout gen mask.
    Expectation: No exception.
    """
    net_0 = Net0()
    net_1 = Net1()
    net_2 = Net2()

    net0_out0, net0_out1 = net_0(Tensor(px), Tensor(py))
    net1_out0, net1_out1 = net_1(Tensor(px), Tensor(py))
    net2_out0, net2_out1 = net_2(Tensor(px), Tensor(py))

    assert (np.allclose(net0_out0.asnumpy(), net1_out0.asnumpy(), 0, 0) is False) or \
           (np.allclose(net0_out1.asnumpy(), net1_out1.asnumpy(), 0, 0) is False)
    assert (np.allclose(net0_out0.asnumpy(), net2_out0.asnumpy(), 0, 0) is False) or \
           (np.allclose(net0_out1.asnumpy(), net2_out1.asnumpy(), 0, 0) is False)