forked from mindspore-Ecosystem/mindspore
merge canndev code to mindspore
This commit is contained in:
parent d94921848c
commit 6e430b8524
@ -104,6 +104,8 @@
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "zerodivcond"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "noConstructor"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "noExplicitConstructor"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "syntaxError"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "unsignedLessThanZero"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "identicalConditionAfterEarlyExit"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "uninitMemberVar"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "redundantInitialization"
@ -0,0 +1,336 @@
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "densetodense_set_operation.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/allocator_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "kernel_log.h"
|
||||
#include "status.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kOutputNum = 3;
|
||||
const uint32_t kInputNum = 2;
|
||||
const uint32_t kOutputIndex0 = 0;
|
||||
const uint32_t kOutputIndex1 = 1;
|
||||
const uint32_t kOutputIndex2 = 2;
|
||||
const char *kDenseToDenseSetOperation = "DenseToDenseSetOperation";
|
||||
const int64_t kParallelNum{512};
|
||||
|
||||
#define DTOD_SET_OPE_COMPUTE_CASE(DTYPE, TYPE, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = DoCompute<TYPE>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("DenseToDenseSetOperation kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
const std::vector<int64_t> Strides(const std::vector<int64_t> &shape) {
|
||||
std::vector<int64_t> result(shape.size());
|
||||
int64_t product = 1;
|
||||
for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 0; --i) {
|
||||
result[i] = product;
|
||||
product *= shape[i];
|
||||
}
|
||||
return result;
|
||||
}
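// Worked example (illustrative sketch, not taken from the original sources): Strides() returns
// row-major strides, so Strides({2, 3, 4}) yields {12, 4, 1}, and the flat offset of element
// (1, 2, 3) is 1*12 + 2*4 + 3*1 = 23.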
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t GetNumElements(const std::vector<int64_t> &input_shape, int64_t &res) {
|
||||
int64_t result = 1;
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
KERNEL_CHECK_FALSE(MulWithoutOverflow(input_shape[i], result, result), KERNEL_STATUS_PARAM_INVALID,
                   "Overflow when calculating shape size");
|
||||
}
|
||||
res = result;
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t GroupShape(const std::vector<int64_t> &input_shape, std::vector<int64_t> &grouped_shape) {
|
||||
if (input_shape.size() < 2) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
// grouped_shape is input_shape[:-1]
|
||||
grouped_shape.assign(input_shape.begin(), input_shape.end() - 1);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t ChecksShapesMatch(const std::vector<int64_t> &shape1, const std::vector<int64_t> &shape2) {
|
||||
if (shape1.size() != shape2.size()) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
for (size_t i = 0; i < shape1.size(); i++) {
|
||||
if (shape1[i] != shape2[i]) return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t GroupShapeFromDenseInputs(const std::vector<int64_t> &shape1, const std::vector<int64_t> &shape2,
|
||||
std::vector<int64_t> &group_shape) {
|
||||
std::vector<int64_t> group_shape_1;
|
||||
KERNEL_HANDLE_ERROR(GroupShape(shape1, group_shape_1), "Shape rank is less than 2");
|
||||
std::vector<int64_t> group_shape_2;
|
||||
KERNEL_HANDLE_ERROR(GroupShape(shape2, group_shape_2), "Shape rank is less than 2");
|
||||
KERNEL_HANDLE_ERROR(ChecksShapesMatch(group_shape_1, group_shape_2), "Two shapes mismatch with each other.");
|
||||
group_shape.assign(group_shape_1.begin(), group_shape_1.end());
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
// Split `flat_group_index` into separate dimensions based on `group_shape`.
|
||||
void PopulatesGroupIndices(const int64_t flat_group_index, const std::vector<int64_t> &group_shape,
|
||||
std::vector<int64_t> &group_indices) {
|
||||
group_indices.clear();
|
||||
int64_t running_flat_group_index = flat_group_index;
|
||||
for (int64_t group_dim_index = static_cast<int64_t>(group_shape.size()) - 1; group_dim_index >= 0;
|
||||
--group_dim_index) {
|
||||
const auto group_dim = group_shape[group_dim_index];
|
||||
group_indices.insert(group_indices.begin(), running_flat_group_index % group_dim);
|
||||
running_flat_group_index /= group_dim;
|
||||
}
|
||||
}
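// Worked example (illustrative sketch): with group_shape = {2, 3}, flat_group_index 4 is
// decomposed row-major into group_indices {1, 1}: 4 % 3 = 1 for the last dimension, then
// 4 / 3 = 1 and 1 % 2 = 1 for the first dimension; flat index 5 maps to {1, 2}.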
|
||||
|
||||
template <typename T>
|
||||
uint32_t PopulateFromDenseGroup(Tensor *input_tensor, const std::vector<int64_t> &input_strides,
|
||||
const std::vector<int64_t> &group_indices, std::set<T> &result) {
|
||||
KERNEL_CHECK_FALSE(group_indices.size() == input_strides.size() - 1, KERNEL_STATUS_PARAM_INVALID,
                   "group_indices.size() must be equal to input_strides.size() - 1")
|
||||
result.clear();
|
||||
EigenTensor input_tensor_eigen(input_tensor, input_tensor->GetData());
|
||||
auto input_flat = input_tensor_eigen.flat<T>();
|
||||
const auto start = std::inner_product(group_indices.begin(), group_indices.end(), input_strides.begin(), 0LL);
|
||||
auto input_shape = input_tensor->GetTensorShape();
|
||||
const auto end = start + input_shape->GetDimSize(input_shape->GetDims() - 1);
|
||||
for (int64_t i = start; i < end; ++i) {
|
||||
result.insert(input_flat(i));
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
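// Worked example (illustrative sketch): for an input of shape {2, 3, 4} the strides are
// {12, 4, 1}; with group_indices {1, 2} the inner product over the first two strides gives
// start = 1*12 + 2*4 = 20 and end = start + 4 = 24, so the flat elements 20..23 (the last-axis
// slice of that group) are inserted into `result` as a deduplicated, sorted std::set.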
|
||||
|
||||
uint32_t DenseToDenseSetOperationCpuKernel::Check(const CpuKernelContext &ctx) {
|
||||
AttrValue *set_operation = ctx.GetAttr("set_operation");
|
||||
std::string set_operation_str;
|
||||
if (set_operation != nullptr) {
|
||||
set_operation_str = set_operation->GetString();
|
||||
}
|
||||
std::transform(set_operation_str.begin(), set_operation_str.end(), set_operation_str.begin(), ::tolower);
|
||||
if ("a-b" == set_operation_str) {
|
||||
set_operation_ = A_MINUS_B;
|
||||
} else if ("b-a" == set_operation_str) {
|
||||
set_operation_ = B_MINUS_A;
|
||||
} else if ("intersection" == set_operation_str) {
|
||||
set_operation_ = INTERSECTION;
|
||||
} else if ("union" == set_operation_str) {
|
||||
set_operation_ = UNION;
|
||||
} else {
|
||||
KERNEL_LOG_ERROR("Invalid set_operation");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
AttrValue *validate_indices = ctx.GetAttr("validate_indices");
|
||||
if (validate_indices != nullptr) {
|
||||
validate_indices_ = validate_indices->GetBool();
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t DenseToDenseSetOperationCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum),
|
||||
"DenseToDenseSetOperation check input and output number failed.");
|
||||
KERNEL_HANDLE_ERROR(Check(ctx), "DenseToDenseSetOperation check params failed.");
|
||||
auto data_type_x1 = ctx.Input(0)->GetDataType();
|
||||
auto data_type_x2 = ctx.Input(1)->GetDataType();
|
||||
KERNEL_CHECK_FALSE(data_type_x1 == data_type_x2, KERNEL_STATUS_PARAM_INVALID,
|
||||
"The type of x1 must be the same as x2");
|
||||
switch (data_type_x1) {
|
||||
DTOD_SET_OPE_COMPUTE_CASE(DT_INT8, int8_t, ctx)
|
||||
DTOD_SET_OPE_COMPUTE_CASE(DT_INT16, int16_t, ctx)
|
||||
DTOD_SET_OPE_COMPUTE_CASE(DT_INT32, int32_t, ctx)
|
||||
DTOD_SET_OPE_COMPUTE_CASE(DT_INT64, int64_t, ctx)
|
||||
DTOD_SET_OPE_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
|
||||
DTOD_SET_OPE_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
|
||||
default:
|
||||
KERNEL_LOG_ERROR("DenseToDenseSetOperation kernel data type [%s] not support.", DTypeStr(data_type_x1).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DenseToDenseSetOperationCpuKernel::ApplySetOperation(const std::set<T> &set1, const std::set<T> &set2,
|
||||
std::set<T> &result) {
|
||||
switch (set_operation_) {
|
||||
case A_MINUS_B:
|
||||
std::set_difference(set1.begin(), set1.end(), set2.begin(), set2.end(), std::inserter(result, result.begin()));
|
||||
break;
|
||||
case B_MINUS_A:
|
||||
std::set_difference(set2.begin(), set2.end(), set1.begin(), set1.end(), std::inserter(result, result.begin()));
|
||||
break;
|
||||
case INTERSECTION:
|
||||
std::set_intersection(set1.begin(), set1.end(), set2.begin(), set2.end(), std::inserter(result, result.begin()));
|
||||
break;
|
||||
case UNION:
|
||||
std::set_union(set1.begin(), set1.end(), set2.begin(), set2.end(), std::inserter(result, result.begin()));
|
||||
break;
|
||||
}
|
||||
}
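// Worked example (illustrative sketch): with set1 = {1, 2, 3} and set2 = {2, 3, 4} the four
// modes produce "a-b" -> {1}, "b-a" -> {4}, "intersection" -> {2, 3}, "union" -> {1, 2, 3, 4}.
// Because std::set iterates in sorted order, the std::set_* algorithms above need no extra sort.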
|
||||
|
||||
template <typename T>
|
||||
uint32_t DenseToDenseSetOperationCpuKernel::OutputSparseTensor(
|
||||
CpuKernelContext &ctx, const std::vector<int64_t> &output_shape, const int64_t num_values,
|
||||
const std::map<std::vector<int64_t>, std::set<T>> &sets) {
|
||||
Tensor *out_indices_t, *out_values_t, *out_shape_t;
|
||||
|
||||
out_indices_t = ctx.Output(kOutputIndex0);
|
||||
out_values_t = ctx.Output(kOutputIndex1);
|
||||
out_shape_t = ctx.Output(kOutputIndex2);
|
||||
|
||||
auto out_indices_shape = out_indices_t->GetTensorShape();
|
||||
auto out_values_shape = out_values_t->GetTensorShape();
|
||||
auto out_shape_shape = out_shape_t->GetTensorShape();
|
||||
|
||||
int64_t output_shape_size = output_shape.size();
|
||||
|
||||
out_indices_shape->SetDimSizes({num_values, output_shape_size});
|
||||
out_values_shape->SetDimSizes({num_values});
|
||||
out_shape_shape->SetDimSizes({output_shape_size});
|
||||
|
||||
EigenTensor out_indices_tensor(out_indices_t, out_indices_t->GetData());
|
||||
EigenTensor out_values_tensor(out_values_t, out_values_t->GetData());
|
||||
EigenTensor out_shape_tensor(out_shape_t, out_shape_t->GetData());
|
||||
|
||||
auto out_indices_mat = out_indices_tensor.matrix<int64_t>();
|
||||
auto out_values_flat = out_values_tensor.flat<T>();
|
||||
auto out_shape_flat = out_shape_tensor.flat<int64_t>();
|
||||
|
||||
// For each set, write its indices and values to output tensors.
|
||||
int64_t value_index = 0;
|
||||
for (auto it = sets.begin(); it != sets.end(); ++it) {
|
||||
const auto &group_indices = it->first;
|
||||
KERNEL_CHECK_FALSE(group_indices.size() == output_shape.size() - 1, KERNEL_STATUS_PARAM_INVALID,
                   "Invalid number of indices.")
|
||||
const auto &set = it->second;
|
||||
|
||||
int64_t group_value_index = 0;
|
||||
for (auto value = set.begin(); value != set.end(); ++value, ++value_index, ++group_value_index) {
|
||||
// First n-1 dimensions are the group, last dimension is the position in
|
||||
// the set.
|
||||
for (size_t i = 0; i < group_indices.size(); ++i) {
|
||||
out_indices_mat(value_index, i) = group_indices[i];
|
||||
}
|
||||
out_indices_mat(value_index, group_indices.size()) = group_value_index;
|
||||
|
||||
out_values_flat(value_index) = *value;
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < output_shape_size; ++i) {
|
||||
out_shape_flat(i) = output_shape[i];
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
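// Worked example (illustrative sketch): for group_shape {2} with result sets {5, 7} for group {0}
// and {6} for group {1}, output_shape becomes {2, 2} (max_set_size = 2) and the outputs are
//   indices = [[0, 0], [0, 1], [1, 0]], values = [5, 7, 6], shape = [2, 2],
// i.e. the leading index columns identify the group and the last column is the position of the
// value inside its sorted result set.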
|
||||
|
||||
template <typename T>
|
||||
uint32_t DenseToDenseSetOperationCpuKernel::DoCompute(CpuKernelContext &ctx) {
|
||||
Tensor *set1_t = ctx.Input(0);
|
||||
Tensor *set2_t = ctx.Input(1);
|
||||
std::vector<int64_t> group_shape;
|
||||
const auto shape1 = set1_t->GetTensorShape()->GetDimSizes();
|
||||
const auto shape2 = set2_t->GetTensorShape()->GetDimSizes();
|
||||
KERNEL_HANDLE_ERROR(GroupShapeFromDenseInputs(shape1, shape2, group_shape), "Create group shape error.");
|
||||
|
||||
const auto set1_strides = Strides(shape1);
|
||||
const auto set2_strides = Strides(shape2);
|
||||
|
||||
std::map<std::vector<int64_t>, std::set<T>> group_sets;
|
||||
std::atomic<int64_t> num_result_values(0);
|
||||
std::atomic<int64_t> max_set_size(0);
|
||||
|
||||
int64_t num_elements;
|
||||
KERNEL_HANDLE_ERROR(GetNumElements(group_shape, num_elements), "Get numelements failed.");
|
||||
|
||||
if (num_elements <= kParallelNum) {
|
||||
std::set<T> set1_group_set;
|
||||
std::set<T> set2_group_set;
|
||||
std::vector<int64_t> group_indices;
|
||||
for (int64_t flat_group_index = 0; flat_group_index < num_elements; ++flat_group_index) {
|
||||
PopulatesGroupIndices(flat_group_index, group_shape, group_indices);
|
||||
KERNEL_HANDLE_ERROR(PopulateFromDenseGroup<T>(set1_t, set1_strides, group_indices, set1_group_set),
|
||||
"PopulateFromDenseGroup set1 compute failed");
|
||||
KERNEL_HANDLE_ERROR(PopulateFromDenseGroup<T>(set2_t, set2_strides, group_indices, set2_group_set),
|
||||
"PopulateFromDenseGroup set2 compute failed");
|
||||
|
||||
std::set<T> group_set;
|
||||
ApplySetOperation(set1_group_set, set2_group_set, group_set);
|
||||
if (!group_set.empty()) {
|
||||
group_sets[group_indices] = group_set;
|
||||
int64_t set_size = group_set.size();
|
||||
if (set_size > max_set_size) {
|
||||
max_set_size = set_size;
|
||||
}
|
||||
num_result_values += set_size;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::mutex mt;
|
||||
int64_t total = num_elements;
|
||||
uint32_t cores = CpuKernelUtils::GetCPUNum(ctx);
|
||||
int64_t per_unit_size = total / std::min(std::max(static_cast<int64_t>(1), static_cast<int64_t>(cores) - 2), total);
|
||||
uint32_t ret = CpuKernelUtils::ParallelFor(ctx, total, per_unit_size, [&](int64_t begin, int64_t end) {
|
||||
std::set<T> set1_group_set;
|
||||
std::set<T> set2_group_set;
|
||||
std::vector<int64_t> group_indices;
|
||||
for (int64_t flat_group_index = begin; flat_group_index < end; ++flat_group_index) {
|
||||
PopulatesGroupIndices(flat_group_index, group_shape, group_indices);
|
||||
KERNEL_HANDLE_ERROR(PopulateFromDenseGroup<T>(set1_t, set1_strides, group_indices, set1_group_set),
|
||||
"PopulateFromDenseGroup set1 compute failed");
|
||||
KERNEL_HANDLE_ERROR(PopulateFromDenseGroup<T>(set2_t, set2_strides, group_indices, set2_group_set),
|
||||
"PopulateFromDenseGroup set2 compute failed");
|
||||
std::set<T> group_set;
|
||||
ApplySetOperation(set1_group_set, set2_group_set, group_set);
|
||||
if (!group_set.empty()) {
|
||||
std::lock_guard<std::mutex> lck(mt);
|
||||
group_sets[group_indices] = group_set;
|
||||
int64_t set_size = group_set.size();
|
||||
if (set_size > max_set_size) {
|
||||
max_set_size = set_size;
|
||||
}
|
||||
num_result_values += set_size;
|
||||
}
|
||||
}
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_OK);
|
||||
});
|
||||
KERNEL_CHECK_FALSE((ret == KERNEL_STATUS_OK), KERNEL_STATUS_INNER_ERROR, "DenseToDenseSetOperation compute failed.");
|
||||
}
|
||||
|
||||
group_shape.push_back(max_set_size);
|
||||
return OutputSparseTensor<T>(ctx, group_shape, num_result_values, group_sets);
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kDenseToDenseSetOperation, DenseToDenseSetOperationCpuKernel);
|
||||
} // namespace aicpu
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_DENSE_TO_DENSE_SET_OPERATION_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_DENSE_TO_DENSE_SET_OPERATION_H_
|
||||
|
||||
#include <map>
#include <set>
#include <vector>

#include "cpu_ops_kernel.h"
|
||||
namespace aicpu {
|
||||
enum SetOperation { A_MINUS_B = 0, B_MINUS_A = 1, INTERSECTION = 2, UNION = 3 };
|
||||
|
||||
class DenseToDenseSetOperationCpuKernel : public CpuKernel {
|
||||
public:
|
||||
~DenseToDenseSetOperationCpuKernel() = default;
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t Check(const CpuKernelContext &ctx);
|
||||
template <typename T>
|
||||
uint32_t DoCompute(CpuKernelContext &ctx);
|
||||
template <typename T>
|
||||
uint32_t OutputSparseTensor(CpuKernelContext &ctx, const std::vector<int64_t> &output_shape, const int64_t num_values,
|
||||
const std::map<std::vector<int64_t>, std::set<T>> &sets);
|
||||
template <typename T>
|
||||
void ApplySetOperation(const std::set<T> &set1, const std::set<T> &set2, std::set<T> &result);
|
||||
|
||||
SetOperation set_operation_ = A_MINUS_B;
|
||||
bool validate_indices_ = true;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
@ -50,7 +50,7 @@ const std::vector<int64_t> Strides(const std::vector<int64_t> &shape) {
|
|||
return result;
|
||||
}
|
||||
|
||||
uint32_t GroupShape(const std::vector<int64_t> input_shape, std::vector<int64_t> &grouped_shape) {
|
||||
uint32_t GroupsShape(const std::vector<int64_t> input_shape, std::vector<int64_t> &grouped_shape) {
|
||||
if (input_shape.size() < 2) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
@ -69,18 +69,18 @@ uint32_t CheckShapesMatch(const std::vector<int64_t> &shape1, const std::vector<
|
|||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t GroupShapeFromInputs(const std::vector<int64_t> &shape1, const std::vector<int64_t> &shape2,
|
||||
std::vector<int64_t> &group_shape) {
|
||||
uint32_t GroupsShapeFromInputs(const std::vector<int64_t> &shape1, const std::vector<int64_t> &shape2,
|
||||
std::vector<int64_t> &group_shape) {
|
||||
std::vector<int64_t> group_shape_1;
|
||||
KERNEL_HANDLE_ERROR(GroupShape(shape1, group_shape_1), "X1_Shape rank is less than 2.");
|
||||
KERNEL_HANDLE_ERROR(GroupsShape(shape1, group_shape_1), "X1_Shape rank is less than 2.");
|
||||
std::vector<int64_t> group_shape_2;
|
||||
KERNEL_HANDLE_ERROR(GroupShape(shape2, group_shape_2), "X2_Shape rank is less than 2.");
|
||||
KERNEL_HANDLE_ERROR(GroupsShape(shape2, group_shape_2), "X2_Shape rank is less than 2.");
|
||||
KERNEL_HANDLE_ERROR(CheckShapesMatch(group_shape_1, group_shape_2), "Two shapes mismatch with each other.");
|
||||
group_shape.assign(group_shape_1.begin(), group_shape_1.end());
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t GetNumElements(const std::vector<int64_t> input_shape, int64_t &res) {
|
||||
uint32_t GetsNumElements(const std::vector<int64_t> input_shape, int64_t &res) {
|
||||
int64_t result = 1;
|
||||
for (uint32_t i = 0; i < input_shape.size(); i++) {
|
||||
KERNEL_CHECK_FALSE(MulWithoutOverflow(input_shape[i], result, result), KERNEL_STATUS_PARAM_INVALID,
|
||||
|
@ -278,13 +278,13 @@ uint32_t DenseToSparseSetOperationCpuKernel::ComputeDenseToSparse(DataBank &data
|
|||
std::vector<int64_t> group_shape;
|
||||
const auto shape1 = set1_t->GetTensorShape()->GetDimSizes();
|
||||
|
||||
KERNEL_HANDLE_ERROR(GroupShapeFromInputs(shape1, shape2, group_shape), "GroupShapeFromInputs error.");
|
||||
KERNEL_HANDLE_ERROR(GroupsShapeFromInputs(shape1, shape2, group_shape), "GroupsShapeFromInputs error.");
|
||||
const std::vector<int64_t> set1_strides = Strides(shape1);
|
||||
std::map<std::vector<int64_t>, std::set<T>> group_sets;
|
||||
int64_t num_result_values = 0;
|
||||
int64_t max_set_size = 0;
|
||||
int64_t num_elements;
|
||||
KERNEL_HANDLE_ERROR(GetNumElements(group_shape, num_elements), "NumElements error.");
|
||||
KERNEL_HANDLE_ERROR(GetsNumElements(group_shape, num_elements), "NumElements error.");
|
||||
if (num_elements <= kParallelNum) {
|
||||
std::set<T> set1_group_set;
|
||||
std::set<T> set2_group_set;
|
||||
|
@ -0,0 +1,379 @@
|
|||
/**
|
||||
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "instance_norm_v2.h"
|
||||
|
||||
#include "Eigen/Dense"
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "cpu_types.h"
|
||||
#include "securec.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
|
||||
namespace {
|
||||
const char *const kInstanceNormV2 = "InstanceNormV2";
|
||||
constexpr float float_init_zero = 0.0;
|
||||
constexpr float float_init_one = 1.0;
|
||||
constexpr float momentum_min = 0.0;
|
||||
constexpr float momentum_max = 1.0;
|
||||
constexpr double double_init_zero = 0.0;
|
||||
constexpr double double_init_one = 1.0;
|
||||
constexpr int32_t int32_init_one = 1;
|
||||
constexpr int64_t int64_init_one = 1;
|
||||
constexpr auto kInstanceNormV2InputsNum = 5;
|
||||
constexpr auto kInstanceNormV2OutputNum = 3;
|
||||
// GRAIN_SIZE for Parallel
|
||||
constexpr auto kGrainSize = 4 * 1024;
|
||||
constexpr auto kDim3 = 3;
|
||||
constexpr auto kDim4 = 4;
|
||||
constexpr auto kDim5 = 5;
|
||||
constexpr auto InstanceNormV2InXIndex = 0;
|
||||
constexpr auto InstanceNormV2InGammaIndex = 1;
|
||||
constexpr auto InstanceNormV2InBetaIndex = 2;
|
||||
constexpr auto InstanceNormV2InMeanIndex = 3;
|
||||
constexpr auto InstanceNormV2InVarianceIndex = 4;
|
||||
constexpr auto InstanceNormV2OutYIndex = 0;
|
||||
constexpr auto InstanceNormV2OutBatchMeanIndex = 1;
|
||||
constexpr auto InstanceNormV2OutBatchVarianceIndex = 2;
|
||||
|
||||
template <typename T>
|
||||
struct InvStd {
|
||||
T operator()(T var, double epsilon) const {
|
||||
T invstd = 0;
|
||||
if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
|
||||
invstd = static_cast<T>(int32_init_one) / std::sqrt(var + epsilon);
|
||||
}
|
||||
return invstd;
|
||||
}
|
||||
};
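// Note (illustrative sketch): InvStd maps a variance to 1 / sqrt(var + epsilon); for example
// var = 0.25 with epsilon = 0 yields 2.0. The special case var == 0 && epsilon == 0 returns 0
// instead of dividing by zero.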
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t InstanceNormV2CpuKernel::InstanceNormV2TypeCheck(const CpuKernelContext &ctx) {
|
||||
auto x_type = ctx.Input(InstanceNormV2InXIndex)->GetDataType();
|
||||
auto gamma_type = ctx.Input(InstanceNormV2InGammaIndex)->GetDataType();
|
||||
auto beta_type = ctx.Input(InstanceNormV2InBetaIndex)->GetDataType();
|
||||
auto mean_type = ctx.Input(InstanceNormV2InMeanIndex)->GetDataType();
|
||||
auto variance_type = ctx.Input(InstanceNormV2InVarianceIndex)->GetDataType();
|
||||
auto y_type = ctx.Output(InstanceNormV2OutYIndex)->GetDataType();
|
||||
auto batch_mean_type = ctx.Output(InstanceNormV2OutBatchMeanIndex)->GetDataType();
|
||||
auto batch_variance_type = ctx.Output(InstanceNormV2OutBatchVarianceIndex)->GetDataType();
|
||||
|
||||
if (x_type != y_type) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"For primitive[%s]'s input arguments x should have the same "
|
||||
"data type with output arguments y, but y type is [%s], x type is "
|
||||
"[%s].",
|
||||
kInstanceNormV2, DTypeStr(y_type).c_str(), DTypeStr(x_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
const std::map<std::string, DataType> types = {{"gamma", gamma_type},
|
||||
{"beta", beta_type},
|
||||
{"mean", mean_type},
|
||||
{"variance", variance_type},
|
||||
{"batch_mean", batch_mean_type},
|
||||
{"batch_variance", batch_variance_type}};
|
||||
return CheckTensorTypeSame(types, DT_FLOAT, kInstanceNormV2);
|
||||
}
|
||||
|
||||
uint32_t InstanceNormV2CpuKernel::InstanceNormV2ShapeCheck(const CpuKernelContext &ctx) {
|
||||
auto x_shape_ptr = ctx.Input(InstanceNormV2InXIndex)->GetTensorShape();
|
||||
auto gamma_shape_ptr = ctx.Input(InstanceNormV2InGammaIndex)->GetTensorShape();
|
||||
auto beta_shape_ptr = ctx.Input(InstanceNormV2InBetaIndex)->GetTensorShape();
|
||||
auto mean_shape_ptr = ctx.Input(InstanceNormV2InMeanIndex)->GetTensorShape();
|
||||
auto variance_shape_ptr = ctx.Input(InstanceNormV2InVarianceIndex)->GetTensorShape();
|
||||
auto y_shape_ptr = ctx.Output(InstanceNormV2OutYIndex)->GetTensorShape();
|
||||
auto batch_mean_shape_ptr = ctx.Output(InstanceNormV2OutBatchMeanIndex)->GetTensorShape();
|
||||
auto batch_variance_shape_ptr = ctx.Output(InstanceNormV2OutBatchVarianceIndex)->GetTensorShape();
|
||||
|
||||
auto y_shape = y_shape_ptr->GetDimSizes();
|
||||
auto x_shape = x_shape_ptr->GetDimSizes();
|
||||
auto res = CheckTensorShapeSame({{"x", x_shape_ptr}}, y_shape, kInstanceNormV2);
|
||||
if (res != KERNEL_STATUS_OK) {
return res;
}
|
||||
auto x_format = x_shape_ptr->GetFormat();
|
||||
std::vector<int64_t> check_shape;
|
||||
check_shape = y_shape;
|
||||
int64_t image_size = 0;
|
||||
if (x_shape.size() == kDim4) {
|
||||
if (x_format == FORMAT_NCHW || x_format == FORMAT_ND) {
|
||||
// consider (N, C, H, W) as (N*C, H, W, 1), similar to (N, H, W, C)
|
||||
check_shape[kFormatNCHWIndexH] = int64_init_one;
|
||||
check_shape[kFormatNCHWIndexW] = int64_init_one;
|
||||
image_size = x_shape[kFormatNCHWIndexH] * x_shape[kFormatNCHWIndexW];
|
||||
instance_num = x_shape[kFormatNCHWIndexN] * x_shape[kFormatNCHWIndexC];
|
||||
constexpr int64_t kNumberOne = 1;
|
||||
x_shape_4d_ = {x_shape[kFormatNCHWIndexN] * x_shape[kFormatNCHWIndexC], x_shape[kFormatNCHWIndexH],
|
||||
x_shape[kFormatNCHWIndexW], kNumberOne};
|
||||
batch_channels_2d_ = {x_shape[kFormatNCHWIndexN] * x_shape[kFormatNCHWIndexC], kNumberOne};
|
||||
} else if (x_format == FORMAT_NHWC) {
|
||||
check_shape[kFormatNHWCIndexH] = int64_init_one;
|
||||
check_shape[kFormatNHWCIndexW] = int64_init_one;
|
||||
image_size = x_shape[kFormatNHWCIndexH] * x_shape[kFormatNHWCIndexW];
|
||||
instance_num = x_shape[kFormatNHWCIndexN] * x_shape[kFormatNHWCIndexC];
|
||||
x_shape_4d_ = x_shape;
|
||||
batch_channels_2d_ = {x_shape[kFormatNHWCIndexN], x_shape[kFormatNHWCIndexC]};
|
||||
}
|
||||
} else if (x_format == FORMAT_NC1HWC0 || x_shape.size() == kDim5) {
|
||||
// consider (N, C1, H, W, C0) as (N*C1, H, W, C0), similar to (N, H, W, C)
|
||||
check_shape[kFormatNC1HWC0IndexH] = int64_init_one;
|
||||
check_shape[kFormatNC1HWC0IndexW] = int64_init_one;
|
||||
image_size = x_shape[kFormatNC1HWC0IndexH] * x_shape[kFormatNC1HWC0IndexW];
|
||||
instance_num = x_shape[kFormatNC1HWC0IndexN] * x_shape[kFormatNC1HWC0IndexC1] * x_shape[kFormatNC1HWC0IndexC0];
|
||||
x_shape_4d_ = {x_shape[kFormatNC1HWC0IndexN] * x_shape[kFormatNC1HWC0IndexC1], x_shape[kFormatNC1HWC0IndexH],
|
||||
x_shape[kFormatNC1HWC0IndexW], x_shape[kFormatNC1HWC0IndexC0]};
|
||||
batch_channels_2d_ = {x_shape[kFormatNC1HWC0IndexN] * x_shape[kFormatNC1HWC0IndexC1],
|
||||
x_shape[kFormatNC1HWC0IndexC0]};
|
||||
} else {
|
||||
KERNEL_LOG_ERROR(
|
||||
"For primitive[%s]'s input arguments x only "
|
||||
"support NHWC, NCHW and NC1HWC0, but get data format [%s]",
|
||||
kInstanceNormV2, FormatToSerialString(x_format).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
constexpr int64_t image_min = 1;
|
||||
if (image_size <= image_min) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"For primitive[%s], expected more than 1 value per instance, but get "
|
||||
"[%ld] value per instance.",
|
||||
kInstanceNormV2, image_size);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
const std::map<std::string, TensorShapePtr> shapes = {{"gamma", gamma_shape_ptr},
|
||||
{"beta", beta_shape_ptr},
|
||||
{"mean", mean_shape_ptr},
|
||||
{"variance", variance_shape_ptr},
|
||||
{"batch_mean", batch_mean_shape_ptr},
|
||||
{"batch_variance", batch_variance_shape_ptr}};
|
||||
return CheckTensorShapeSame(shapes, check_shape, kInstanceNormV2);
|
||||
}
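// Worked example (illustrative sketch): an NCHW input of shape (2, 3, 4, 5) is viewed as the
// NHWC-like 4-D shape x_shape_4d_ = {2*3, 4, 5, 1}, so instance_num = 6, image_size = 4*5 = 20,
// and batch_channels_2d_ = {6, 1}; gamma/beta/mean/variance are then checked against the
// per-instance shape (2, 3, 1, 1).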
|
||||
|
||||
uint32_t InstanceNormV2CpuKernel::InstanceNormV2AttrCheck(const CpuKernelContext &ctx) {
|
||||
constexpr float epsilon_min = 0.0;
|
||||
constexpr float epsilon_max = 1.0;
|
||||
auto epsilon_ptr = ctx.GetAttr("epsilon");
|
||||
if (epsilon_ptr) {
|
||||
epsilon_ = epsilon_ptr->GetFloat();
|
||||
}
|
||||
if (epsilon_ < epsilon_min || epsilon_ >= epsilon_max) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"For primitive[%s], attr epsilon value should be in [0, 1), but get "
|
||||
"[%f].",
|
||||
kInstanceNormV2, epsilon_);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
auto momentum_ptr = ctx.GetAttr("momentum");
|
||||
if (momentum_ptr) {
|
||||
momentum_ = momentum_ptr->GetFloat();
|
||||
}
|
||||
if (momentum_ < epsilon_min || momentum_ > epsilon_max) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"For primitive[%s], attr momentum value should be in [0, 1], but get "
|
||||
"[%f].",
|
||||
kInstanceNormV2, momentum_);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
auto is_training_ptr = ctx.GetAttr("is_training");
|
||||
if (is_training_ptr) {
|
||||
is_training_ = is_training_ptr->GetBool();
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t InstanceNormV2CpuKernel::InstanceNormV2ParamCheck(const CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(InstanceNormV2TypeCheck(ctx), "InstanceNormV2 check type failed.");
|
||||
KERNEL_HANDLE_ERROR(InstanceNormV2ShapeCheck(ctx), "InstanceNormV2 check shape failed.");
|
||||
KERNEL_HANDLE_ERROR(InstanceNormV2AttrCheck(ctx), "InstanceNormV2 check attr failed.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t InstanceNormV2CpuKernel::CollectStatsKernel(const CpuKernelContext &ctx, float *_mean_, float *_var_sum) {
|
||||
const int64_t batch = x_shape_4d_[kFormatNHWCIndexN];
|
||||
const int64_t channel = x_shape_4d_[kFormatNHWCIndexC];
|
||||
const int64_t image_size = x_shape_4d_[kFormatNHWCIndexH] * x_shape_4d_[kFormatNHWCIndexW];
|
||||
KERNEL_CHECK_FALSE((channel != 0), KERNEL_STATUS_PARAM_INVALID, "Channel can not be zero!");
|
||||
KERNEL_CHECK_FALSE((image_size != 0), KERNEL_STATUS_PARAM_INVALID, "image_size can not be zero!");
|
||||
std::vector<int64_t> shape_3d = {batch, image_size, channel};
|
||||
auto x_3d = EigenTensor(shape_3d, ctx.Input(InstanceNormV2InXIndex)->GetData()).tensor<T, kDim3>();
|
||||
auto loop_batch = [&](int64_t begin, int64_t end) {
|
||||
for (int64_t batch_idx = begin; batch_idx < end; ++batch_idx) {
|
||||
for (int64_t channel_idx = 0; channel_idx < channel; ++channel_idx) {
|
||||
// compute mean per input
|
||||
double sum = double_init_zero;
|
||||
for (int64_t idx = 0; idx < image_size; ++idx) {
|
||||
sum += static_cast<double>(x_3d(batch_idx, idx, channel_idx));
|
||||
}
|
||||
double cur_mean = sum / static_cast<double>(image_size);
|
||||
_mean_[batch_idx * channel + channel_idx] = static_cast<float>(cur_mean);
|
||||
// compute variance per input
|
||||
double cur_var_sum = double_init_zero;
|
||||
for (int64_t idx = 0; idx < image_size; ++idx) {
|
||||
double cur_pixel = static_cast<double>(x_3d(batch_idx, idx, channel_idx));
cur_var_sum += (cur_pixel - cur_mean) * (cur_pixel - cur_mean);
|
||||
}
|
||||
_var_sum[batch_idx * channel + channel_idx] = static_cast<float>(cur_var_sum);
|
||||
}
|
||||
}
|
||||
};
|
||||
int64_t block_size = std::max(int64_init_one, (kGrainSize / (channel * image_size)));
|
||||
|
||||
return CpuKernelUtils::ParallelFor(ctx, batch, block_size, loop_batch);
|
||||
}
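// Worked example (illustrative sketch): for one (batch, channel) instance whose pixels are
// {1, 2, 3}, _mean_ gets 2.0 and _var_sum gets (1-2)^2 + (2-2)^2 + (3-2)^2 = 2.0; _var_sum holds
// the un-normalized sum of squared deviations, which UpdateStatsTemplate later divides by
// image_size (biased) or image_size - 1 (unbiased).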
|
||||
|
||||
template <typename T, template <typename S> class VarTransform>
|
||||
uint32_t InstanceNormV2CpuKernel::UpdateStatsTemplate(const CpuKernelContext &ctx) {
|
||||
std::vector<float> _var_sum(instance_num, float_init_zero);
|
||||
std::vector<float> _mean_(instance_num, float_init_zero);
|
||||
(void)CollectStatsKernel<T>(ctx, _mean_.data(), _var_sum.data());
|
||||
const int64_t image_size = x_shape_4d_[kFormatNHWCIndexH] * x_shape_4d_[kFormatNHWCIndexW];
|
||||
KERNEL_CHECK_FALSE((image_size != 0), KERNEL_STATUS_PARAM_INVALID, "image_size can not be zero!");
|
||||
std::vector<int64_t> batch_channels_1d_ = {batch_channels_2d_.front() * batch_channels_2d_.back()};
|
||||
auto running_mean_vec = EigenTensor(batch_channels_1d_, ctx.Input(InstanceNormV2InMeanIndex)->GetData()).vec<float>();
|
||||
auto running_var_vec =
|
||||
EigenTensor(batch_channels_1d_, ctx.Input(InstanceNormV2InVarianceIndex)->GetData()).vec<float>();
|
||||
auto save_mean_vec =
|
||||
EigenTensor(batch_channels_1d_, ctx.Output(InstanceNormV2OutBatchMeanIndex)->GetData()).vec<float>();
|
||||
auto save_var_vec =
|
||||
EigenTensor(batch_channels_1d_, ctx.Output(InstanceNormV2OutBatchVarianceIndex)->GetData()).vec<float>();
|
||||
auto loop_momentum = [&](int64_t begin, int64_t end) {
|
||||
for (int64_t idx = begin; idx < end; ++idx) {
|
||||
save_mean_vec(idx) = _mean_[idx];
|
||||
save_var_vec(idx) =
|
||||
VarTransform<double>{}(static_cast<double>(_var_sum[idx]) / static_cast<double>(image_size), epsilon_);
|
||||
running_mean_vec(idx) =
|
||||
static_cast<float>(momentum_ * static_cast<double>(_mean_[idx]) +
|
||||
(double_init_one - momentum_) * static_cast<double>(running_mean_vec(idx)));
|
||||
double unbiased_var = double_init_zero;
|
||||
if (image_size - int64_init_one == 0) {
|
||||
unbiased_var = static_cast<double>(_var_sum[idx]);
|
||||
} else {
|
||||
unbiased_var = static_cast<double>(_var_sum[idx]) / static_cast<double>(image_size - int64_init_one);
|
||||
}
|
||||
running_var_vec(idx) = static_cast<float>(momentum_ * unbiased_var + (double_init_one - momentum_) *
|
||||
static_cast<double>(running_var_vec(idx)));
|
||||
}
|
||||
};
|
||||
return CpuKernelUtils::ParallelFor(ctx, instance_num, kGrainSize, loop_momentum);
|
||||
}
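// Numeric sketch (illustrative): with momentum_ = 0.1, a batch mean of 2.0 and a previous
// running mean of 0.0, the running mean becomes 0.1 * 2.0 + 0.9 * 0.0 = 0.2; the running
// variance is updated the same way from the unbiased estimate _var_sum / (image_size - 1),
// while save_var stores InvStd(_var_sum / image_size, epsilon_) for use in TransformInput.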
|
||||
|
||||
uint32_t InstanceNormV2CpuKernel::CollectLinearAndConstant(
|
||||
const CpuKernelContext &ctx, const typename TTypes<float>::Vec &gamma, const typename TTypes<float>::Vec &beta,
|
||||
const typename TTypes<float>::Vec &running_mean, const typename TTypes<float>::Vec &running_var,
|
||||
const typename TTypes<float>::Vec &save_mean, const typename TTypes<float>::Vec &save_invstd, float *_alpha_,
|
||||
float *_beta_) {
|
||||
auto loop_instance = [&](int64_t begin, int64_t end) {
|
||||
for (int64_t idx = begin; idx < end; ++idx) {
|
||||
float mean = float_init_zero, invstd = float_init_zero;
|
||||
if (is_training_) {
|
||||
mean = save_mean(idx);
|
||||
invstd = save_invstd(idx);
|
||||
} else {
|
||||
mean = running_mean(idx);
|
||||
float _std_ = std::sqrt(running_var(idx) + static_cast<float>(epsilon_));
|
||||
KERNEL_CHECK_FALSE((_std_ != 0), KERNEL_STATUS_PARAM_INVALID, "_std_ can not be zero!");
|
||||
invstd = float_init_one / _std_;
|
||||
}
|
||||
_alpha_[idx] = invstd * gamma(idx);
|
||||
_beta_[idx] = beta(idx) - mean * _alpha_[idx];
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
};
|
||||
return CpuKernelUtils::ParallelFor(ctx, instance_num, kGrainSize, loop_instance);
|
||||
}
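// Note (illustrative sketch): the per-instance normalization y = gamma * (x - mean) * invstd + beta
// is folded into a single affine map y = _alpha_ * x + _beta_ with _alpha_ = gamma * invstd and
// _beta_ = beta - mean * _alpha_, so TransformInput only needs one multiply-add per element.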
|
||||
|
||||
template <typename T>
|
||||
uint32_t InstanceNormV2CpuKernel::TransformInput(const CpuKernelContext &ctx) {
|
||||
const int64_t batch = x_shape_4d_[kFormatNHWCIndexN];
|
||||
const int64_t channel = x_shape_4d_[kFormatNHWCIndexC];
|
||||
const int64_t image_size = x_shape_4d_[kFormatNHWCIndexH] * x_shape_4d_[kFormatNHWCIndexW];
|
||||
std::vector<float> _alpha_(instance_num, float_init_zero);
|
||||
std::vector<float> _beta_(instance_num, float_init_zero);
|
||||
std::vector<int64_t> batch_channels_1d_ = {batch_channels_2d_.front() * batch_channels_2d_.back()};
|
||||
auto gamma = EigenTensor(batch_channels_1d_, ctx.Input(InstanceNormV2InGammaIndex)->GetData()).vec<float>();
|
||||
auto beta = EigenTensor(batch_channels_1d_, ctx.Input(InstanceNormV2InBetaIndex)->GetData()).vec<float>();
|
||||
auto running_mean = EigenTensor(batch_channels_1d_, ctx.Input(InstanceNormV2InMeanIndex)->GetData()).vec<float>();
|
||||
auto running_var = EigenTensor(batch_channels_1d_, ctx.Input(InstanceNormV2InVarianceIndex)->GetData()).vec<float>();
|
||||
auto save_mean = EigenTensor(batch_channels_1d_, ctx.Output(InstanceNormV2OutBatchMeanIndex)->GetData()).vec<float>();
|
||||
auto save_invstd =
|
||||
EigenTensor(batch_channels_1d_, ctx.Output(InstanceNormV2OutBatchVarianceIndex)->GetData()).vec<float>();
|
||||
CollectLinearAndConstant(ctx, gamma, beta, running_mean, running_var, save_mean, save_invstd, _alpha_.data(),
|
||||
_beta_.data());
|
||||
// cast (B, H, W, C) to (B, H*W, C)
|
||||
std::vector<int64_t> shape_3d = {batch, image_size, channel};
|
||||
auto x_3d = EigenTensor(shape_3d, ctx.Input(InstanceNormV2InXIndex)->GetData()).tensor<T, kDim3>();
|
||||
auto y_3d = EigenTensor(shape_3d, ctx.Output(InstanceNormV2OutYIndex)->GetData()).tensor<T, kDim3>();
|
||||
// Apply the linear terms to the input,
|
||||
auto loop_transform = [&](int64_t begin, int64_t end) {
|
||||
for (int64_t batch_idx = begin; batch_idx < end; ++batch_idx) {
|
||||
for (int64_t idx = 0; idx < image_size; ++idx) {
|
||||
for (int64_t channel_idx = 0; channel_idx < channel; ++channel_idx) {
|
||||
float alpha = _alpha_[batch_idx * channel + channel_idx];
|
||||
float beta = _beta_[batch_idx * channel + channel_idx];
|
||||
y_3d(batch_idx, idx, channel_idx) =
|
||||
static_cast<T>(alpha * static_cast<float>(x_3d(batch_idx, idx, channel_idx)) + beta);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
int64_t block_size = std::max(int64_init_one, (kGrainSize / (channel * image_size)));
|
||||
return CpuKernelUtils::ParallelFor(ctx, batch, block_size, loop_transform);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t InstanceNormV2CpuKernel::DoCompute(CpuKernelContext &ctx) {
|
||||
auto batch_mean_ptr = static_cast<float *>(ctx.Output(InstanceNormV2OutBatchMeanIndex)->GetData());
|
||||
auto batch_var_ptr = static_cast<float *>(ctx.Output(InstanceNormV2OutBatchVarianceIndex)->GetData());
|
||||
(void)std::fill_n(batch_mean_ptr, instance_num, float_init_zero);
|
||||
(void)std::fill_n(batch_var_ptr, instance_num, float_init_zero);
|
||||
|
||||
if (is_training_) {
|
||||
// UpdateStatsTemplate to init save_mean and save_var
|
||||
(void)UpdateStatsTemplate<T, InvStd>(ctx);
|
||||
}
|
||||
return TransformInput<T>(ctx);
|
||||
}
|
||||
|
||||
uint32_t InstanceNormV2CpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
// check params
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInstanceNormV2InputsNum, kInstanceNormV2OutputNum),
|
||||
"InstanceNormV2 check input and output number failed.");
|
||||
KERNEL_HANDLE_ERROR(InstanceNormV2ParamCheck(ctx), "InstanceNormV2 check params failed.");
|
||||
auto data_type = ctx.Input(InstanceNormV2InXIndex)->GetDataType();
|
||||
uint32_t result;
|
||||
switch (data_type) {
|
||||
case (DT_FLOAT16):
|
||||
result = DoCompute<Eigen::half>(ctx);
|
||||
break;
|
||||
case (DT_FLOAT):
|
||||
result = DoCompute<float>(ctx);
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_ERROR("InstanceNormV2 kernel data type [%s] not support.", DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (result != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("InstanceNormV2 kernel compute failed.");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kInstanceNormV2, InstanceNormV2CpuKernel);
|
||||
} // namespace aicpu
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_INSTANCE_NORM_V2_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_INSTANCE_NORM_V2_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include <vector>
|
||||
#include "utils/eigen_tensor.h"
|
||||
|
||||
namespace aicpu {
|
||||
class InstanceNormV2CpuKernel : public CpuKernel {
|
||||
public:
|
||||
InstanceNormV2CpuKernel() = default;
|
||||
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t InstanceNormV2ParamCheck(const CpuKernelContext &ctx);
|
||||
uint32_t InstanceNormV2ShapeCheck(const CpuKernelContext &ctx);
|
||||
uint32_t InstanceNormV2TypeCheck(const CpuKernelContext &ctx);
|
||||
uint32_t InstanceNormV2AttrCheck(const CpuKernelContext &ctx);
|
||||
|
||||
template <typename T>
|
||||
uint32_t CollectStatsKernel(const CpuKernelContext &ctx, float *_mean_, float *_var_sum);
|
||||
|
||||
uint32_t CollectLinearAndConstant(const CpuKernelContext &ctx, const typename TTypes<float>::Vec &gamma,
|
||||
const typename TTypes<float>::Vec &beta,
|
||||
const typename TTypes<float>::Vec &running_mean,
|
||||
const typename TTypes<float>::Vec &running_var,
|
||||
const typename TTypes<float>::Vec &save_mean,
|
||||
const typename TTypes<float>::Vec &save_invstd, float *_alpha_, float *_beta_);
|
||||
|
||||
template <typename T>
|
||||
uint32_t TransformInput(const CpuKernelContext &ctx);
|
||||
|
||||
template <typename T, template <typename S> class VarTransform>
|
||||
uint32_t UpdateStatsTemplate(const CpuKernelContext &ctx);
|
||||
|
||||
template <typename T>
|
||||
uint32_t DoCompute(CpuKernelContext &ctx);
|
||||
|
||||
bool is_training_ = true;
|
||||
float momentum_ = 0.1;
|
||||
float epsilon_ = 0.00001;
|
||||
std::vector<int64_t> x_shape_4d_;
|
||||
std::vector<int64_t> batch_channels_2d_;
|
||||
int64_t instance_num = 0;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
@ -0,0 +1,161 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "list_diff.h"
|
||||
#include <unordered_set>
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "kernel_log.h"
|
||||
#include "securec.h"
|
||||
#include "status.h"
|
||||
#include "unsupported/Eigen/CXX11/Tensor"
|
||||
#include "utils/allocator_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const char *kListDiff = "ListDiff";
|
||||
constexpr uint32_t kListDiffInputNum = 2;
|
||||
constexpr uint32_t kListDiffOutputNum = 2;
|
||||
|
||||
#define LIST_DIFF_COMPUTE_CASE(DTYPE, TYPE, OUT_IDX, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = KERNEL_STATUS_INNER_ERROR; \
|
||||
if ((OUT_IDX) == DT_INT32) { \
|
||||
result = DoCompute<TYPE, int32_t>(CTX); \
|
||||
} else { \
|
||||
result = DoCompute<TYPE, int64_t>(CTX); \
|
||||
} \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("Less kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t ListDiffCpuKernel::ParamCheck(CpuKernelContext &ctx) {
|
||||
// check input number and output number
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kListDiffInputNum, kListDiffOutputNum), "[%s] check params failed.", kListDiff);
|
||||
// get all input and output
|
||||
const Tensor *x = ctx.Input(0);
|
||||
const Tensor *y = ctx.Input(1);
|
||||
const Tensor *out = ctx.Output(0);
|
||||
const Tensor *idx = ctx.Output(1);
|
||||
|
||||
// input tensor must be 1D vector
|
||||
KERNEL_CHECK_FALSE(IsVector(x->GetTensorShape()->GetDimSizes()), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input Tensor x should be a 1D vector.");
|
||||
KERNEL_CHECK_FALSE(IsVector(y->GetTensorShape()->GetDimSizes()), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input Tensor y should be a 1D vector.");
|
||||
// out_idx type check
|
||||
AttrValue *out_idx_att = ctx.GetAttr("out_idx");
|
||||
if (out_idx_att) {
|
||||
// the private member out_idx stores the attr value
out_idx = out_idx_att->GetDataType();
KERNEL_CHECK_FALSE((out_idx == DT_INT32 || out_idx == DT_INT64), KERNEL_STATUS_PARAM_INVALID,
                   "attr 'out_idx' should be DT_INT32 or DT_INT64");
|
||||
}
|
||||
// data type check for x, y, out, idx
KERNEL_CHECK_FALSE(x->GetDataType() == y->GetDataType() && y->GetDataType() == out->GetDataType(),
                   KERNEL_STATUS_PARAM_INVALID, "The DataType of input x, y and output out should be the same");
KERNEL_CHECK_FALSE(idx->GetDataType() == out_idx, KERNEL_STATUS_PARAM_INVALID,
                   "The DataType of idx should match attr out_idx");
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T, typename Tidx>
|
||||
uint32_t ListDiffCpuKernel::DoCompute(CpuKernelContext &ctx) {
|
||||
Tensor *x = ctx.Input(0);
|
||||
Tensor *y = ctx.Input(1);
|
||||
Tensor *out = ctx.Output(0);
|
||||
Tensor *idx = ctx.Output(1);
|
||||
// construct EigenTensor
|
||||
EigenTensor x_et(x, x->GetData());
|
||||
EigenTensor y_et(y, y->GetData());
|
||||
|
||||
const auto x_vec = x_et.vec<T>();
|
||||
const size_t x_size = x_vec.size();
|
||||
const auto y_vec = y_et.vec<T>();
|
||||
const size_t y_size = y_vec.size();
|
||||
|
||||
std::unordered_set<T> y_set;
|
||||
y_set.reserve(y_size);
|
||||
for (size_t i = 0; i < y_size; ++i) {
|
||||
y_set.insert(y_vec(i));
|
||||
}
|
||||
|
||||
// Compute the size of the output.
|
||||
uint64_t out_size = 0;
|
||||
for (size_t i = 0; i < x_size; ++i) {
|
||||
if (0 == y_set.count(x_vec(i))) {
|
||||
++out_size;
|
||||
}
|
||||
}
|
||||
// allocate memory for out and idx
|
||||
DataType out_type = out->GetDataType();
|
||||
// this function only allocates memory; the TensorShape must be updated by hand afterwards
|
||||
uint32_t ret = CpuKernelAllocatorUtils::AllocateOutputTensorDataMemory({out_size}, out_type, out);
|
||||
KERNEL_CHECK_FALSE((ret == KERNEL_STATUS_OK), KERNEL_STATUS_INNER_ERROR, "Allocate memory for out tensor failed.");
|
||||
ret = CpuKernelAllocatorUtils::AllocateOutputTensorDataMemory({out_size}, out_idx, idx);
|
||||
KERNEL_CHECK_FALSE((ret == KERNEL_STATUS_OK), KERNEL_STATUS_INNER_ERROR, "Allocate memory for idx tensor failed.")
|
||||
// construct EigenTensor
|
||||
EigenTensor out_et(out, out->GetData());
|
||||
EigenTensor idx_et(idx, idx->GetData());
|
||||
auto out_vec = out_et.vec<T>();
|
||||
auto idx_vec = idx_et.vec<Tidx>();
|
||||
|
||||
// calculate results
|
||||
for (Tidx i = 0, p = 0; i < static_cast<Tidx>(x_size); ++i) {
|
||||
if (0 == y_set.count(x_vec(i))) {
|
||||
KERNEL_CHECK_FALSE(p < static_cast<Tidx>(out_size), KERNEL_STATUS_INNER_ERROR,
|
||||
"Tried to set output index failure for index out of out_size");
|
||||
out_vec(p) = x_vec(i);
|
||||
idx_vec(p) = i;
|
||||
p++;
|
||||
}
|
||||
}
|
||||
// update out tensor shape information required by mindspore
|
||||
std::vector<int64_t> shapes = {static_cast<int64_t>(out_size)};
|
||||
auto out_shape = out->GetTensorShape();
|
||||
out_shape->SetDimSizes(shapes);
|
||||
auto idx_shape = idx->GetTensorShape();
|
||||
idx_shape->SetDimSizes(shapes);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
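// Worked example (illustrative sketch): for x = [1, 2, 3, 4, 5, 6] and y = [1, 3, 5] the kernel
// keeps the elements of x that are absent from y, preserving their order, so out = [2, 4, 6]
// and idx = [1, 3, 5] (the positions of the kept elements in x).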
|
||||
|
||||
uint32_t ListDiffCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(ParamCheck(ctx), "ListDiffCpuKernel check params failed");
|
||||
auto data_type = ctx.Input(0)->GetDataType();
|
||||
switch (data_type) {
|
||||
LIST_DIFF_COMPUTE_CASE(DT_INT8, int8_t, out_idx, ctx)
|
||||
LIST_DIFF_COMPUTE_CASE(DT_INT16, int16_t, out_idx, ctx)
|
||||
LIST_DIFF_COMPUTE_CASE(DT_INT32, int32_t, out_idx, ctx)
|
||||
LIST_DIFF_COMPUTE_CASE(DT_INT64, int64_t, out_idx, ctx)
|
||||
LIST_DIFF_COMPUTE_CASE(DT_UINT8, uint8_t, out_idx, ctx)
|
||||
LIST_DIFF_COMPUTE_CASE(DT_UINT16, uint16_t, out_idx, ctx)
|
||||
LIST_DIFF_COMPUTE_CASE(DT_FLOAT16, Eigen::half, out_idx, ctx)
|
||||
LIST_DIFF_COMPUTE_CASE(DT_FLOAT, float, out_idx, ctx)
|
||||
LIST_DIFF_COMPUTE_CASE(DT_DOUBLE, double, out_idx, ctx)
|
||||
default:
|
||||
KERNEL_LOG_ERROR("ListDiff kernel data type [%s] not support.", DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kListDiff, ListDiffCpuKernel);
|
||||
} // namespace aicpu
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_LIST_DIFF_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_LIST_DIFF_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "cpu_types.h"
|
||||
|
||||
namespace aicpu {
|
||||
class ListDiffCpuKernel : public CpuKernel {
|
||||
public:
|
||||
ListDiffCpuKernel() = default;
|
||||
~ListDiffCpuKernel() override = default;
|
||||
|
||||
protected:
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t ParamCheck(CpuKernelContext &ctx);
|
||||
|
||||
template <typename T, typename Tidx>
|
||||
uint32_t DoCompute(CpuKernelContext &ctx);
|
||||
// default DT_INT32
|
||||
DataType out_idx = DT_INT32;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
@ -0,0 +1,132 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "mvlgamma.h"
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "igamma_utils.cc"
|
||||
#include "igamma_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const char *kMvlgamma = "Mvlgamma";
|
||||
|
||||
#define MVLGAMMA_COMPUTE_CASE(DTYPE, TYPE, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = MvlgammaCompute<TYPE>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("Mvlgamma kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
constexpr double HALF = 0.5;
|
||||
constexpr double QUARTER = 0.25;
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t MvlgammaCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
// check params
|
||||
KERNEL_HANDLE_ERROR(MvlgammaCheck(ctx), "Mvlgamma check params failed.");
|
||||
|
||||
const Tensor *input_x = ctx.Input(0);
|
||||
auto data_type = input_x->GetDataType();
|
||||
|
||||
switch (data_type) {
|
||||
MVLGAMMA_COMPUTE_CASE(DT_FLOAT, float, ctx)
|
||||
MVLGAMMA_COMPUTE_CASE(DT_DOUBLE, double, ctx)
|
||||
default:
|
||||
KERNEL_LOG_ERROR("Mvlgamma kernel data type [%s] not support.", DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t MvlgammaCpuKernel::MvlgammaCheck(CpuKernelContext &ctx) {
|
||||
// check input, output and attr not null
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input data failed.")
|
||||
KERNEL_CHECK_NULLPTR(ctx.Output(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output data failed")
|
||||
KERNEL_CHECK_NULLPTR(ctx.GetAttr("p"), KERNEL_STATUS_PARAM_INVALID, "Get attr failed.")
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, 1, 1, {"p"}), "Mvlgamma normal check failed.");
|
||||
|
||||
// check input and output datatype as the same
|
||||
DataType input_datatype = ctx.Input(0)->GetDataType();
|
||||
DataType output_datatype = ctx.Output(0)->GetDataType();
|
||||
KERNEL_CHECK_FALSE((input_datatype == output_datatype), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input data type[%d] must be the same as Output data type[%d].", input_datatype, output_datatype)
|
||||
|
||||
auto attr_value = ctx.GetAttr("p")->GetInt();
|
||||
KERNEL_CHECK_FALSE((attr_value >= 1), KERNEL_STATUS_PARAM_INVALID, "p has to be greater than or equal to 1, but got [%lld]",
                   attr_value)  // attr "p" has already been fetched via GetAttr above
|
||||
|
||||
KERNEL_LOG_INFO("MvlgammaCpuKernel[%s], input: size[%llu], output: size[%llu].", ctx.GetOpType().c_str(),
|
||||
ctx.Input(0)->GetDataSize(), ctx.Output(0)->GetDataSize());
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T MvlgammaCpuKernel::MvlgammaSingle(T &x, const int &p, bool &error) {
|
||||
if (!(x > HALF * (p - 1))) {
|
||||
error = true;
|
||||
KERNEL_LOG_ERROR("All elements of `x` must be greater than (p-1)/2");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
const auto p2_sub_p = static_cast<T>(p * (p - 1));
|
||||
T output = p2_sub_p * std::log(M_PI) * QUARTER;
|
||||
for (int i = 0; i < p; i++) {
|
||||
output += Lgamma(x - HALF * i);
|
||||
}
|
||||
return output;
|
||||
}
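// Formula note (illustrative sketch): the loop above evaluates the multivariate log-gamma
//   log Gamma_p(x) = (p * (p - 1) / 4) * log(pi) + sum_{i=0}^{p-1} lgamma(x - i / 2),
// which is only defined for x > (p - 1) / 2; for p = 1 it reduces to the ordinary lgamma(x).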
|
||||
|
||||
template <typename T>
|
||||
uint32_t MvlgammaCpuKernel::MvlgammaCompute(CpuKernelContext &ctx) {
|
||||
auto input_x = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto output_y = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
auto attr_p = ctx.GetAttr("p")->GetInt();
|
||||
|
||||
auto input0_shape = ctx.Input(0)->GetTensorShape();
|
||||
int64_t data_num = input0_shape->NumElements();
|
||||
uint32_t min_core_num = 1;
|
||||
int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
|
||||
if (max_core_num > data_num) {
|
||||
max_core_num = data_num;
|
||||
}
|
||||
|
||||
bool error = false;
|
||||
auto shard_mvlgamma = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
*(output_y + i) = MvlgammaSingle<T>(*(input_x + i), attr_p, error);
|
||||
}
|
||||
};
|
||||
|
||||
if (max_core_num == 0) {
KERNEL_LOG_ERROR("max_core_num cannot be 0.");
return KERNEL_STATUS_PARAM_INVALID;
}
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shard_mvlgamma),
|
||||
"Mvlgamma Compute failed.");
|
||||
if (error == true) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
} else {
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kMvlgamma, MvlgammaCpuKernel);
|
||||
} // namespace aicpu
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_MVLGAMMA_UTILS_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_MVLGAMMA_UTILS_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "utils/bcast.h"
|
||||
|
||||
namespace aicpu {
|
||||
class MvlgammaCpuKernel : public CpuKernel {
|
||||
public:
|
||||
MvlgammaCpuKernel() = default;
|
||||
~MvlgammaCpuKernel() override = default;
|
||||
|
||||
protected:
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
static uint32_t MvlgammaCheck(CpuKernelContext &ctx);
|
||||
|
||||
template <typename T>
|
||||
static T MvlgammaSingle(T &x, const int &p, bool &error);
|
||||
|
||||
template <typename T>
|
||||
static uint32_t MvlgammaCompute(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,131 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "mvlgamma_grad.h"
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "igamma_utils.cc"
|
||||
#include "igamma_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const char *kMvlgammaGrad = "MvlgammaGrad";
|
||||
|
||||
#define MVLGAMMAGRAD_COMPUTE_CASE(DTYPE, TYPE, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = MvlgammaGradCompute<TYPE>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("MvlgammaGrad kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
constexpr double HALF = 0.5;
|
||||
constexpr double QUARTER = 0.25;
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t MvlgammaGradCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
// check params
|
||||
KERNEL_HANDLE_ERROR(MvlgammaGradCheck(ctx), "MvlgammaGrad check params failed.");
|
||||
|
||||
auto data_type = ctx.Input(0)->GetDataType();
|
||||
switch (data_type) {
|
||||
MVLGAMMAGRAD_COMPUTE_CASE(DT_FLOAT, float, ctx)
|
||||
MVLGAMMAGRAD_COMPUTE_CASE(DT_DOUBLE, double, ctx)
|
||||
default:
|
||||
KERNEL_LOG_ERROR("MvlgammaGrad kernel data type [%s] not support.", DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t MvlgammaGradCpuKernel::MvlgammaGradCheck(CpuKernelContext &ctx) {
|
||||
// check input, output and attr not null
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 0 data failed.")
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(1)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 1 data failed.")
|
||||
KERNEL_CHECK_NULLPTR(ctx.Output(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output data failed")
|
||||
KERNEL_CHECK_NULLPTR(ctx.GetAttr("p"), KERNEL_STATUS_PARAM_INVALID, "Get attr failed.")
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, 2, 1, {"p"}), "MvlgammaGrad normal check failed.");
|
||||
|
||||
// check input and output datatype as the same
|
||||
DataType input0_type = ctx.Input(0)->GetDataType();
|
||||
DataType input1_type = ctx.Input(1)->GetDataType();
|
||||
DataType output_type = ctx.Output(0)->GetDataType();
|
||||
KERNEL_CHECK_FALSE((input0_type == input1_type), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of input0 [%d] need be same with "
|
||||
"input1 [%d].",
|
||||
input0_type, input1_type)
|
||||
KERNEL_CHECK_FALSE((input0_type == output_type), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of input0 [%d] need be same with "
|
||||
"output [%d].",
|
||||
input0_type, output_type)
|
||||
|
||||
auto attr_value = ctx.GetAttr("p")->GetInt();
|
||||
KERNEL_CHECK_FALSE((attr_value >= 1), KERNEL_STATUS_PARAM_INVALID, "p has to be greater than or equal to 1, got [%lld].",
|
||||
attr_value)  // p was already fetched via GetAttr above
|
||||
|
||||
KERNEL_LOG_INFO(
|
||||
"MvlgammaGradCpuKernel[%s], input0: size[%llu];"
|
||||
"input1: size[%llu], output: size[%llu].",
|
||||
ctx.GetOpType().c_str(), ctx.Input(0)->GetDataSize(), ctx.Input(1)->GetDataSize(), ctx.Output(0)->GetDataSize());
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T MvlgammaGradCpuKernel::MvlgammaGradSingle(T &y_grad, T &x, const int &p) {
|
||||
T output = 0;
|
||||
for (int i = 0; i < p; i++) {
|
||||
output += Digamma(x - HALF * i);
|
||||
}
|
||||
output *= y_grad;
|
||||
return output;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t MvlgammaGradCpuKernel::MvlgammaGradCompute(CpuKernelContext &ctx) {
|
||||
auto input_y_grad = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto input_x = reinterpret_cast<T *>(ctx.Input(1)->GetData());
|
||||
auto output_x_grad = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
auto attr_p = ctx.GetAttr("p")->GetInt();
|
||||
|
||||
auto input0_shape = ctx.Input(0)->GetTensorShape();
|
||||
int64_t data_num = input0_shape->NumElements();
|
||||
uint32_t min_core_num = 1;
|
||||
int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
|
||||
if (max_core_num > data_num) {
|
||||
max_core_num = data_num;
|
||||
}
|
||||
|
||||
auto shard_mvlgammagrad = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
*(output_x_grad + i) = MvlgammaGradSingle<T>(*(input_y_grad + i), *(input_x + i), attr_p);
|
||||
}
|
||||
};
|
||||
|
||||
if (max_core_num == 0) {
|
||||
KERNEL_LOG_ERROR("max_core_num could not be 0,");
|
||||
}
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shard_mvlgammagrad),
|
||||
"MvlgammaGrad Compute failed.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kMvlgammaGrad, MvlgammaGradCpuKernel);
|
||||
} // namespace aicpu
|
|
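// Illustrative sketch (not part of the kernel source): MvlgammaGradSingle above relies on
//   d/dx log(Gamma_p(x)) = sum_{i=0}^{p-1} digamma(x - i / 2).
// Standard C++ has no digamma, so this standalone check approximates the derivative of the
// lgamma sum with a central difference; the two quantities should agree to several decimal
// places. Names and the sample point are assumptions made for the example.
#include <cmath>
#include <cstdio>

double LogMvGamma(double x, int p) {
  double out = 0.25 * p * (p - 1) * std::log(M_PI);
  for (int i = 0; i < p; ++i) out += std::lgamma(x - 0.5 * i);
  return out;
}

int main() {
  const double x = 4.2, h = 1e-5;
  const int p = 3;
  // Central difference of log(Gamma_p) approximates the digamma sum that the kernel
  // multiplies by the incoming gradient y_grad.
  double grad = (LogMvGamma(x + h, p) - LogMvGamma(x - h, p)) / (2 * h);
  std::printf("d/dx log Gamma_p(x) ~= %f\n", grad);
  return 0;
}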
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_MVLGAMMAGRAD_UTILS_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_MVLGAMMAGRAD_UTILS_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "utils/bcast.h"
|
||||
|
||||
namespace aicpu {
|
||||
|
||||
class MvlgammaGradCpuKernel : public CpuKernel {
|
||||
public:
|
||||
MvlgammaGradCpuKernel() = default;
|
||||
~MvlgammaGradCpuKernel() override = default;
|
||||
|
||||
protected:
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
static uint32_t MvlgammaGradCheck(CpuKernelContext &ctx);
|
||||
|
||||
template <typename T>
|
||||
static T MvlgammaGradSingle(T &y_grad, T &x, const int &p);
|
||||
|
||||
template <typename T>
|
||||
static uint32_t MvlgammaGradCompute(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,273 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "setsize.h"
|
||||
#include <securec.h>
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "status.h"
|
||||
#include "utils/kernel_util.h"
|
||||
using namespace std;
|
||||
|
||||
namespace {
|
||||
const uint32_t kOutputNum = 1;
|
||||
const uint32_t kInputNum = 3;
|
||||
const char *kSetSize = "SetSize";
|
||||
// when input data size is more than kParallelDataNum, use Parallel func
|
||||
const int64_t kParallelDataNumMid = 16 * 1024;
|
||||
const int64_t kParallelDataNumSameShape = 7 * 1024;
|
||||
|
||||
#define SETSIZE_COMPUTE_CASE(DTYPE, TYPE, CTX, ST) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result; \
|
||||
result = SetSizeCompute<TYPE>(CTX, ST); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("SetSize kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t SetSizeCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
set_indices_ = ctx.Input(0);
|
||||
set_values_ = ctx.Input(1);
|
||||
int64_t input_index = 2;
|
||||
set_shape_ = ctx.Input(input_index);
|
||||
output_ = ctx.Output(0);
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "SetSize check input and output number failed.");
|
||||
// check dim
|
||||
KERNEL_CHECK_FALSE((set_indices_->GetTensorShape()->GetDims() == input_index), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Indices tensor dim size equal to 2, got size [%d].", set_indices_->GetTensorShape()->GetDims())
|
||||
KERNEL_CHECK_FALSE((set_values_->GetTensorShape()->GetDims() == 1), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Values tensor dim size equal to 1, got size [%d].", set_values_->GetTensorShape()->GetDims())
|
||||
KERNEL_CHECK_FALSE((set_shape_->GetTensorShape()->GetDims() == 1), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Shape tensor dim size equal to 1, got size [%d].", set_shape_->GetTensorShape()->GetDims())
|
||||
auto data_type_0 = set_indices_->GetDataType();
|
||||
auto data_type_1 = set_values_->GetDataType();
|
||||
auto data_type_2 = set_shape_->GetDataType();
|
||||
KERNEL_CHECK_FALSE((data_type_0 == DT_INT64), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of input0 requested dtype int64 for Tensor "
|
||||
"with dtype [%s]",
|
||||
DTypeStr(data_type_0).c_str())
|
||||
KERNEL_CHECK_FALSE((data_type_2 == DT_INT64), KERNEL_STATUS_PARAM_INVALID,
|
||||
"The data type of input2 requested dtype int64 for Tensor "
|
||||
"with dtype [%s]",
|
||||
DTypeStr(data_type_2).c_str())
|
||||
dims_ = set_indices_->GetTensorShape()->GetDimSize(1);
|
||||
AttrValue *validate_indices = ctx.GetAttr("validate_indices");
|
||||
validate_indices_ = (validate_indices == nullptr) ? true : (validate_indices->GetBool());
|
||||
SparseTensor st;
|
||||
if (SparseTensorFromContext(ctx, validate_indices_, st) != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("Create sparse tensor failed.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
switch (data_type_1) {
|
||||
SETSIZE_COMPUTE_CASE(DT_INT8, int8_t, ctx, st)
|
||||
SETSIZE_COMPUTE_CASE(DT_INT16, int16_t, ctx, st)
|
||||
SETSIZE_COMPUTE_CASE(DT_INT32, int32_t, ctx, st)
|
||||
SETSIZE_COMPUTE_CASE(DT_INT64, int64_t, ctx, st)
|
||||
SETSIZE_COMPUTE_CASE(DT_UINT8, uint8_t, ctx, st)
|
||||
SETSIZE_COMPUTE_CASE(DT_UINT16, uint16_t, ctx, st)
|
||||
case DT_STRING:
|
||||
uint32_t result;
|
||||
result = SetSizeCompute_string(ctx, st);
|
||||
if (result != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("SetSize kernel compute failed.");
|
||||
return result;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_ERROR(
|
||||
"Value passed to parameter 'set_values_' has DataType [%s] not in "
|
||||
"list of allowed values: int8, int16, int32, int64, uint8, uint16, "
|
||||
"string.",
|
||||
DTypeStr(data_type_1).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t SetSizeCpuKernel::SparseTensorFromContext(CpuKernelContext &ctx, const bool validate_indices,
|
||||
SparseTensor &st) {
|
||||
auto sparse_shape = set_shape_->GetTensorShape();
|
||||
std::vector<int64_t> dense_shape;
|
||||
std::vector<int64_t> order;
|
||||
for (int32_t index = 0; index < sparse_shape->GetDimSize(0); ++index) {
|
||||
int64_t *temp_dim = reinterpret_cast<int64_t *>(set_shape_->GetData());
|
||||
dense_shape.emplace_back(temp_dim[index]);
|
||||
order.push_back(dense_shape[index]);
|
||||
}
|
||||
shape_.assign(dense_shape.begin(), dense_shape.end());
|
||||
order_.assign(order.begin(), order.end());
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
uint32_t result = st.CreateSparseTensor(set_indices_, set_values_, dense_shape, order);
|
||||
if (!validate_indices || result != KERNEL_STATUS_OK) {
|
||||
return result;
|
||||
}
|
||||
return IndicesValid(ctx, st);
|
||||
}
|
||||
|
||||
uint32_t SetSizeCpuKernel::IndicesValid(CpuKernelContext &ctx, SparseTensor &st) {
|
||||
int64_t dim_size =
|
||||
(set_indices_->GetTensorShape()->GetDims() == 0) ? 1 : set_indices_->GetTensorShape()->GetDimSize(0);
|
||||
if (dim_size >= kParallelDataNumSameShape) {
|
||||
uint32_t min_core_num = 1;
|
||||
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
|
||||
if (dim_size <= kParallelDataNumMid) {
|
||||
max_core_num = std::min(max_core_num, 4U);
|
||||
}
|
||||
if (max_core_num > dim_size) {
|
||||
max_core_num = dim_size;
|
||||
}
|
||||
auto invalid_setsize = [&](int64_t start, int64_t end) {
|
||||
for (int64_t i = start; i < end; ++i) {
|
||||
if (st.EigenTensorIndicesValid<int64_t>(ctx) != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("Indices valid failed.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
};
|
||||
if (max_core_num == 0) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, dim_size, dim_size / max_core_num, invalid_setsize),
|
||||
"SetSize Compute failed.");
|
||||
} else {
|
||||
for (int64_t n = 0; n < dim_size; ++n) {
|
||||
if (st.EigenTensorIndicesValid<int64_t>(ctx) != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("Indices valid failed.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t SetSizeCpuKernel::CheckGroup(CpuKernelContext &ctx, const Group &group,
|
||||
const std::vector<int64_t> &sparse_tensor_shape) {
|
||||
const int64_t num_values = ctx.Input(0)->GetTensorShape()->GetDimSize(0);
|
||||
const auto indices_t = reinterpret_cast<int64_t *>(ctx.Input(0)->GetData());
|
||||
for (int32_t j = 0; j < dims_; ++j) {
|
||||
const auto dim_size = sparse_tensor_shape[j];
|
||||
KERNEL_CHECK_FALSE((dim_size > 0), KERNEL_STATUS_PARAM_INVALID, "Invalid dim_size [%d] = [%d].", j, dim_size)
|
||||
for (int64_t i = 0; i < num_values; ++i) {
|
||||
const auto index = *(indices_t + (i * dims_) + j);
|
||||
KERNEL_CHECK_FALSE((dim_size > 0), KERNEL_STATUS_PARAM_INVALID, "indices[%d, %d] expected < %d, got %d.", i, j,
|
||||
dim_size, index)
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t SetSizeCpuKernel::PopulateFromSparseGroup(CpuKernelContext &ctx, const Group &group,
|
||||
const std::vector<int64_t> &sparse_tensor_shape,
|
||||
std::unordered_set<T> *result) {
|
||||
if (validate_indices_ == false) {
KERNEL_HANDLE_ERROR(CheckGroup<T>(ctx, group, sparse_tensor_shape), "SetSize check group failed.");
}
|
||||
result->clear();
|
||||
const auto &group_values = group.values<T>();
|
||||
int64_t dim_size = group_values.size();
|
||||
if (dim_size >= kParallelDataNumSameShape) {
|
||||
uint32_t min_core_num = 1;
|
||||
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
|
||||
if (dim_size <= kParallelDataNumMid) {
|
||||
max_core_num = std::min(max_core_num, 4U);
|
||||
}
|
||||
if (max_core_num > dim_size) {
|
||||
max_core_num = dim_size;
|
||||
}
|
||||
auto group_value_setsize = [&](int64_t start, int64_t end) {
|
||||
for (int64_t i = start; i < end; ++i) {
|
||||
result->insert(group_values(i));
|
||||
}
|
||||
};
|
||||
if (max_core_num == 0) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, dim_size, dim_size / max_core_num, group_value_setsize),
|
||||
"SetSize Compute failed.");
|
||||
} else {
|
||||
for (int64_t i = 0; i < group_values.size(); ++i) {
|
||||
result->insert(group_values(i));
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t SetSizeCpuKernel::SetSizeCompute(CpuKernelContext &ctx, SparseTensor &st) {
|
||||
auto output_t = reinterpret_cast<int32_t *>(ctx.Output(0)->GetData());
|
||||
std::vector<int64_t> group_ix(dims_ - 1);
|
||||
std::iota(group_ix.begin(), group_ix.end(), 0);
|
||||
std::vector<int64_t> strides(dims_);
|
||||
int64_t num2 = 2;
|
||||
auto shape_t = reinterpret_cast<int64_t *>(ctx.Input(num2)->GetData());
|
||||
if (dims_ > 1) {
|
||||
strides[dims_ - num2] = 1;
|
||||
}
|
||||
for (int32_t d = dims_ - 3; d >= 0; --d) {
|
||||
strides[d] = strides[d + 1] * shape_t[d + 1];
|
||||
}
|
||||
int64_t output_size = 1;
|
||||
for (int32_t d = 0; d < dims_ - 1; ++d) {
|
||||
output_size = output_size * shape_t[d];
|
||||
}
|
||||
memset_s(output_t, sizeof(int32_t) * output_size, 0, sizeof(int32_t) * output_size);
|
||||
std::unordered_set<T> group_set;
|
||||
for (const auto &group : st.group(group_ix)) {
|
||||
uint32_t result = PopulateFromSparseGroup<T>(ctx, group, shape_, &group_set);
|
||||
if (result != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("SetSize kernel compute failed.");
|
||||
return result;
|
||||
}
|
||||
const auto group_key = group.group();
|
||||
const auto output_index = std::inner_product(group_key.begin(), group_key.end(), strides.begin(), 0LL);
|
||||
*(output_t + output_index) = (int32_t)group_set.size();
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t SetSizeCpuKernel::SetSizeCompute_string(CpuKernelContext &ctx, SparseTensor &st) {
|
||||
auto output_t = reinterpret_cast<int32_t *>(ctx.Output(0)->GetData());
|
||||
std::vector<int64_t> group_ix(dims_ - 1);
|
||||
std::iota(group_ix.begin(), group_ix.end(), 0);
|
||||
std::vector<int64_t> strides(dims_);
|
||||
auto shape_t = reinterpret_cast<int64_t *>(ctx.Input(2)->GetData());
|
||||
int64_t num2 = 2;
|
||||
if (dims_ > 1) {
|
||||
strides[dims_ - num2] = 1;
|
||||
}
|
||||
for (int32_t d = dims_ - 3; d >= 0; --d) {
|
||||
strides[d] = strides[d + 1] * shape_t[d + 1];
|
||||
}
|
||||
int32_t output_size = 1;
|
||||
for (int32_t d = 0; d < dims_ - 1; ++d) {
|
||||
output_size = output_size * shape_t[d];
|
||||
}
|
||||
memset_s(output_t, sizeof(int32_t) * output_size, 0, sizeof(int32_t) * output_size);
|
||||
std::unordered_set<std::string> group_set;
|
||||
for (const auto &group : st.group(group_ix)) {
|
||||
uint32_t result = PopulateFromSparseGroup<std::string>(ctx, group, shape_, &group_set);
if (result != KERNEL_STATUS_OK) {
KERNEL_LOG_ERROR("SetSize kernel compute failed.");
return result;
}
|
||||
const auto group_key = group.group();
|
||||
const auto output_index = std::inner_product(group_key.begin(), group_key.end(), strides.begin(), 0LL);
|
||||
*(output_t + output_index) = (int32_t)group_set.size();
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kSetSize, SetSizeCpuKernel);
|
||||
} // namespace aicpu
|
|
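// Illustrative sketch (not part of the kernel source): SetSizeCompute maps each sparse group
// key to a flat position in the dense output by building row-major strides over the leading
// dims_ - 1 dimensions and taking an inner product with the key. The standalone code below
// reproduces that mapping for an assumed shape {2, 3, 4}, where the last dimension holds the
// set values and the output covers the first two dimensions.
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<long long> shape = {2, 3, 4};   // assumed dense shape; last dim holds set values
  int dims = static_cast<int>(shape.size());
  std::vector<long long> strides(dims, 0);
  if (dims > 1) strides[dims - 2] = 1;
  for (int d = dims - 3; d >= 0; --d) strides[d] = strides[d + 1] * shape[d + 1];

  std::vector<long long> group_key = {1, 2};  // indices into the first two dimensions
  long long flat = std::inner_product(group_key.begin(), group_key.end(), strides.begin(), 0LL);
  std::printf("flat output index = %lld\n", flat);  // 1 * 3 + 2 = 5
  return 0;
}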
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_SETSIZE_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_SETSIZE_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include <unordered_set>
|
||||
#include <string>
|
||||
#include "utils/sparse_tensor.h"
|
||||
|
||||
namespace aicpu {
|
||||
class SetSizeCpuKernel : public CpuKernel {
|
||||
public:
|
||||
SetSizeCpuKernel() = default;
|
||||
~SetSizeCpuKernel() override = default;
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
uint32_t SetSizeCompute(CpuKernelContext &ctx, SparseTensor &st);
|
||||
uint32_t SetSizeCompute_string(CpuKernelContext &ctx, SparseTensor &st);
|
||||
uint32_t SparseTensorFromContext(CpuKernelContext &ctx, const bool validate_indices, SparseTensor &st);
|
||||
template <typename T>
|
||||
uint32_t PopulateFromSparseGroup(CpuKernelContext &ctx, const Group &group,
|
||||
const std::vector<int64_t> &sparse_tensor_shape, std::unordered_set<T> *result);
|
||||
template <typename T>
|
||||
uint32_t CheckGroup(CpuKernelContext &ctx, const Group &group, const std::vector<int64_t> &sparse_tensor_shape);
|
||||
bool validate_indices_ = true;
|
||||
uint32_t IndicesValid(CpuKernelContext &ctx, SparseTensor &st);
|
||||
|
||||
int32_t dims_;
|
||||
std::unordered_set<std::string> all_indices_;
|
||||
std::vector<int64_t> shape_;
|
||||
std::vector<int64_t> order_;
|
||||
Tensor *set_indices_ = nullptr;
|
||||
Tensor *set_values_ = nullptr;
|
||||
Tensor *set_shape_ = nullptr;
|
||||
Tensor *output_ = nullptr;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -152,4 +152,72 @@ uint32_t CpuKernelAllocatorUtils::DeleteOutputDataPtr(const uint64_t data_ptr) {
|
|||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t CpuKernelAllocatorUtils::AllocateOutputTensorDataMemory(const std::vector<uint64_t> &shape, DataType type,
|
||||
Tensor *&outputResultTensor) {
|
||||
KERNEL_CHECK_NULLPTR(outputResultTensor, KERNEL_STATUS_PARAM_INVALID, "outputResultTensor nullptr");
|
||||
KERNEL_LOG_INFO("AllocateOutputTensorDataMemory::START!!");
|
||||
if (shape.empty()) {
|
||||
KERNEL_LOG_ERROR("AllocateOutputTensorDataMemory shape size == 0.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
int64_t num_elements = 1;
|
||||
int64_t dim_size = 0;
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
dim_size = shape[i];
|
||||
KERNEL_CHECK_ASSIGN_64S_MULTI(num_elements, dim_size, num_elements, KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
|
||||
uint64_t data_size = 0;
|
||||
int32_t element_size = GetSizeByDataType(type);
|
||||
KERNEL_CHECK_ASSIGN_64S_MULTI(num_elements, element_size, data_size, KERNEL_STATUS_PARAM_INVALID);
|
||||
if (data_size < 0) {
|
||||
KERNEL_LOG_ERROR("AllocateOutputTensorDataMemory data_size[%u].", data_size);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
uint64_t shape_buffer_size = 0;
|
||||
KERNEL_CHECK_ASSIGN_64S_MULTI(shape.size(), sizeof(int64_t), shape_buffer_size, KERNEL_STATUS_PARAM_INVALID);
|
||||
|
||||
void *output_shape_ptr = malloc(shape_buffer_size);
|
||||
KERNEL_CHECK_NULLPTR(output_shape_ptr, KERNEL_STATUS_PARAM_INVALID, "malloc error, size[%llu]!", shape_buffer_size);
|
||||
int32_t ret = memcpy_s(output_shape_ptr, shape_buffer_size, shape.data(), shape_buffer_size);
|
||||
if (ret != EOK) {
|
||||
free(output_shape_ptr);
|
||||
KERNEL_LOG_ERROR("memcpy error, size[%llu], ret[%d]!", shape_buffer_size, ret);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
aicpu::FWKAdapter::ResultSummary *result_summary =
|
||||
reinterpret_cast<aicpu::FWKAdapter::ResultSummary *>(outputResultTensor->GetData());
|
||||
if (data_size == 0) {
|
||||
result_summary->raw_data_ptr = reinterpret_cast<uint64_t>(nullptr);
|
||||
result_summary->raw_data_size = 0;
|
||||
result_summary->shape_data_ptr = reinterpret_cast<uint64_t>(output_shape_ptr);
|
||||
result_summary->shape_data_size = shape_buffer_size;
|
||||
(void)g_allocated_ptr.insert(result_summary->shape_data_ptr);
|
||||
KERNEL_LOG_INFO("AllocateOutputTensorDataMemory:: empty tensor END!!");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
void *output_data_ptr = malloc(data_size);
|
||||
if (output_data_ptr == nullptr) {
|
||||
KERNEL_LOG_ERROR("malloc error, size[%lu]!", data_size);
|
||||
free(output_shape_ptr);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
|
||||
result_summary->raw_data_size = data_size;
|
||||
result_summary->raw_data_ptr = reinterpret_cast<uint64_t>(output_data_ptr);
|
||||
result_summary->shape_data_size = shape_buffer_size;
|
||||
result_summary->shape_data_ptr = reinterpret_cast<uint64_t>(output_shape_ptr);
|
||||
|
||||
KERNEL_LOG_INFO("raw_data_ptr [%llu]", output_data_ptr);
|
||||
KERNEL_LOG_INFO("shape_data_ptr [%llu]", output_shape_ptr);
|
||||
|
||||
(void)g_allocated_ptr.insert(result_summary->raw_data_ptr);
|
||||
(void)g_allocated_ptr.insert(result_summary->shape_data_ptr);
|
||||
KERNEL_LOG_INFO("AllocateOutputTensorDataMemory :: END!!");
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
} // namespace aicpu
|
||||
|
|
|
@ -33,6 +33,8 @@ class AICPU_VISIBILITY CpuKernelAllocatorUtils {
|
|||
static uint32_t CheckOutputDataPtr(const uint64_t data_ptr);
|
||||
static uint32_t DeleteOutputDataPtr(const uint64_t data_ptr);
|
||||
static int64_t GetInputDataSize(const std::vector<int64_t> &dims, DataType type);
|
||||
static uint32_t AllocateOutputTensorDataMemory(const std::vector<uint64_t> &shape, DataType type,
|
||||
Tensor *&outputResultTensor);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_UTILS_ALLOCATOR_UTILS_H_
|
||||
|
|
|
@ -70,10 +70,17 @@ struct TTypes {
|
|||
|
||||
namespace aicpu {
|
||||
|
||||
namespace {
|
||||
using ShapeVector = std::vector<int64_t>;
|
||||
}
|
||||
|
||||
class EigenTensor {
|
||||
public:
|
||||
EigenTensor() = delete;
|
||||
EigenTensor(Tensor *tensor, void *data) : tensor_(tensor), tensor_data_(data) {}
|
||||
EigenTensor(Tensor *tensor, void *data) : tensor_(tensor), tensor_data_(data) {
|
||||
tensor_shape_ = tensor->GetTensorShape()->GetDimSizes();
|
||||
}
|
||||
EigenTensor(ShapeVector &shape, void *data_ptr) : tensor_shape_(shape), tensor_data_(data_ptr) {}
|
||||
~EigenTensor() = default;
|
||||
|
||||
/*
|
||||
|
@ -133,7 +140,9 @@ class EigenTensor {
|
|||
*/
|
||||
template <typename T>
|
||||
typename TTypes<T>::Flat flat() {
|
||||
return typename TTypes<T>::Flat(reinterpret_cast<T *>(tensor_data_), {tensor_->GetTensorShape()->NumElements()});
|
||||
return typename TTypes<T>::Flat(
|
||||
reinterpret_cast<T *>(tensor_data_),
|
||||
{std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies<int64_t>())});
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -143,10 +152,10 @@ class EigenTensor {
|
|||
template <int NDIMS, typename IndexType>
|
||||
Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const {
|
||||
Eigen::DSizes<IndexType, NDIMS> dsizes;
|
||||
for (int d = 0; d < tensor_->GetTensorShape()->GetDims(); d++) {
|
||||
dsizes[d] = static_cast<IndexType>(tensor_->GetTensorShape()->GetDimSize(d));
|
||||
for (size_t d = 0; d < tensor_shape_.size(); d++) {
|
||||
dsizes[d] = static_cast<IndexType>(tensor_shape_[d]);
|
||||
}
|
||||
for (int d = tensor_->GetTensorShape()->GetDims(); d < NDIMS; d++) {
|
||||
for (size_t d = tensor_shape_.size(); d < NDIMS; d++) {
|
||||
dsizes[d] = 1;
|
||||
}
|
||||
return dsizes;
|
||||
|
@ -163,6 +172,7 @@ class EigenTensor {
|
|||
|
||||
private:
|
||||
Tensor *tensor_;
|
||||
ShapeVector tensor_shape_;
|
||||
void *tensor_data_;
|
||||
};
|
||||
} // namespace aicpu
|
||||
|
|
|
@ -0,0 +1,445 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "igamma_utils.h"
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
namespace {
|
||||
/**
|
||||
* Coefficients for the Lanczos approximation of the gamma function. The
|
||||
* coefficients are uniquely determined by the choice of g and n (kLanczosGamma
|
||||
* and kLanczosCoefficients.size() + 1). The coefficients below correspond to
|
||||
* [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7,
|
||||
* 9] seemed to be the least sensitive to the quality of the log function. In
|
||||
* particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
|
||||
* for a particularly inaccurate log function.
|
||||
* */
|
||||
static constexpr double kLanczosGamma = 7; // aka g
|
||||
static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
|
||||
static constexpr std::array<double, 8> kLanczosCoefficients = {
|
||||
676.520368121885098567009190444019, -1259.13921672240287047156078755283,
|
||||
771.3234287776530788486528258894, -176.61502916214059906584551354,
|
||||
12.507343278686904814458936853, -0.13857109526572011689554707,
|
||||
9.984369578019570859563e-6, 1.50563273514931155834e-7};
|
||||
double log_lanczos_gamma_plus_one_half = std::log(kLanczosGamma + 0.5);
|
||||
} // namespace
|
||||
|
||||
/** Compute the Lgamma function using Lanczos' approximation from "A Precision
|
||||
* Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
|
||||
* series B. Vol. 1:
|
||||
* lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z)
|
||||
* t(z) = z + kLanczosGamma + 1/2
|
||||
* A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
|
||||
*/
|
||||
template <typename T>
|
||||
T Lgamma(const T &input) {
|
||||
T log_pi = std::log(M_PI);
|
||||
T log_sqrt_two_pi = (std::log(2) + std::log(M_PI)) / 2;
|
||||
|
||||
/** If the input is less than 0.5 use Euler's reflection formula:
|
||||
* gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
|
||||
*/
|
||||
bool need_to_reflect = (input < 0.5);
|
||||
T input_after_reflect = need_to_reflect ? -input : input - 1; // aka z
|
||||
|
||||
T sum = kBaseLanczosCoeff; // aka x
|
||||
for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
|
||||
T lanczos_coefficient = kLanczosCoefficients[i];
|
||||
|
||||
sum += lanczos_coefficient / (input_after_reflect + i + 1);
|
||||
}
|
||||
|
||||
/** To improve accuracy on platforms with less-precise log implementations,
|
||||
* compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
|
||||
* the device.
|
||||
* log(t) = log(kLanczosGamma + 0.5 + z)
|
||||
* = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
|
||||
* */
|
||||
T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + input_after_reflect; // aka t
|
||||
|
||||
T log_t = log_lanczos_gamma_plus_one_half + std::log1p(input_after_reflect / (kLanczosGamma + 0.5));
|
||||
|
||||
/** Compute the final result (modulo reflection). t(z) may be large, and we
|
||||
* need to be careful not to overflow to infinity in the first term of
|
||||
* (z + 1/2) * log(t(z)) - t(z).
|
||||
* Therefore we compute this as
|
||||
* (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
|
||||
*/
|
||||
T log_y = log_sqrt_two_pi + (input_after_reflect + 0.5 - gamma_plus_onehalf_plus_z / log_t) * log_t + std::log(sum);
|
||||
|
||||
/** Compute the reflected value, used when x < 0.5:
|
||||
*
|
||||
* lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
|
||||
*
|
||||
* (The abs is because lgamma is the log of the absolute value of the gamma
|
||||
* function.)
|
||||
*
|
||||
* We have to be careful when computing the final term above. gamma(x) goes
|
||||
* to +/-inf at every integer x < 0, and this is controlled by the
|
||||
* sin(pi * x) term. The slope is large, so precision is particularly
|
||||
* important.
|
||||
*
|
||||
* Because abs(sin(pi * x)) has period 1, we can equivalently use
|
||||
* abs(sin(pi * frac(x))), where frac(x) is the fractional part of x. This
|
||||
* is more numerically accurate: It doesn't overflow to inf like pi * x can,
|
||||
* and if x is an integer, it evaluates to 0 exactly, which is significant
|
||||
* because we then take the log of this value, and log(0) is inf.
|
||||
*
|
||||
* We don't have a frac(x) primitive in XLA and computing it is tricky, but
|
||||
* because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
|
||||
* our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
|
||||
*
|
||||
* Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
|
||||
* to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
|
||||
* [0, 1] is symmetric across the line Y=0.5.
|
||||
*/
|
||||
T abs_input = std::abs(input);
|
||||
T abs_frac_input = abs_input - std::floor(abs_input);
|
||||
|
||||
/* Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve
|
||||
* precision of pi * abs_frac_input for values of abs_frac_input close to 1.
|
||||
*/
|
||||
T reduced_frac_input = (abs_frac_input > 0.5) ? 1 - abs_frac_input : abs_frac_input;
|
||||
T reflection_denom = std::log(std::sin(M_PI * reduced_frac_input));
|
||||
|
||||
/* Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
|
||||
* then it "wins" and the result is +/-inf.
|
||||
*/
|
||||
T reflection = std::isfinite(reflection_denom) ? log_pi - reflection_denom - log_y : -reflection_denom;
|
||||
|
||||
T result = need_to_reflect ? reflection : log_y;
|
||||
|
||||
return std::isinf(input) ? std::numeric_limits<T>::infinity() : result;
|
||||
};
|
||||
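/* Illustrative check (not part of the original source): the Lanczos Lgamma above should track
 * the C library's lgamma closely away from the poles. The helper below is only a sketch for
 * manual verification, its name is an assumption, and it is never called by the kernels.
 */
template <typename T>
T LgammaAbsErrorSketch(const T &input) {
  // |Lanczos approximation - std::lgamma|; expected to be tiny (~1e-12 for double) at moderate
  // arguments such as 0.7 or 5.3.
  return std::abs(Lgamma<T>(input) - static_cast<T>(std::lgamma(static_cast<double>(input))));
}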
|
||||
/* Compute the Digamma function using Lanczos' approximation from "A Precision
|
||||
* Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
|
||||
* series B. Vol. 1:
|
||||
* digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z)
|
||||
* t(z) = z + kLanczosGamma + 1/2
|
||||
* A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
|
||||
* A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
|
||||
*/
|
||||
template <typename T>
|
||||
T Digamma(const T &input) {
|
||||
/* If the input is less than 0.5 use Euler's reflection formula:
|
||||
* digamma(x) = digamma(1 - x) - pi * cot(pi * x)
|
||||
*/
|
||||
bool need_to_reflect = (input < 0.5);
|
||||
T reflected_input = need_to_reflect ? -input : input - 1; // aka z
|
||||
|
||||
T num = 0;
|
||||
T denom = kBaseLanczosCoeff;
|
||||
|
||||
for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
|
||||
T lanczos_coefficient = kLanczosCoefficients[i];
|
||||
num -= lanczos_coefficient / ((reflected_input + i + 1) * (reflected_input + i + 1));
|
||||
denom += lanczos_coefficient / (reflected_input + i + 1);
|
||||
}
|
||||
|
||||
/* To improve accuracy on platforms with less-precise log implementations,
|
||||
* compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
|
||||
* the device.
|
||||
* log(t) = log(kLanczosGamma + 0.5 + z)
|
||||
* = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
|
||||
*/
|
||||
|
||||
T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + reflected_input; // aka t
|
||||
T log_t = log_lanczos_gamma_plus_one_half + std::log1p(reflected_input / (kLanczosGamma + 0.5));
|
||||
|
||||
T result = log_t + num / denom - kLanczosGamma / gamma_plus_onehalf_plus_z; // aka y
|
||||
|
||||
/* We need to be careful how we compute cot(pi * input) below: For
|
||||
* near-integral values of `input`, pi * input can lose precision.
|
||||
*
|
||||
* Input is already known to be less than 0.5 (otherwise we don't have to
|
||||
* reflect). We shift values smaller than -0.5 into the range [-.5, .5] to
|
||||
* increase precision of pi * input and the resulting cotangent.
|
||||
*/
|
||||
|
||||
T reduced_input = input + std::abs(std::floor(input + 0.5));
|
||||
T reflection = result - M_PI * std::cos(M_PI * reduced_input) / std::sin(M_PI * reduced_input);
|
||||
T real_result = need_to_reflect ? reflection : result;
|
||||
|
||||
// Digamma has poles at negative integers and zero; return nan for those.
|
||||
return (input < 0 && input == std::floor(input)) ? std::numeric_limits<T>::quiet_NaN() : real_result;
|
||||
};
|
||||
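/* Illustrative check (not part of the original source): Digamma above is the derivative of
 * lgamma, so it should match a central difference of the Lgamma approximation for a small
 * step. Sketch only; the name is an assumption and the kernels never call it.
 */
template <typename T>
T DigammaFiniteDiffSketch(const T &input, const T &step) {
  // (lgamma(x + h) - lgamma(x - h)) / (2h) ~= digamma(x) for small h.
  return (Lgamma<T>(input + step) - Lgamma<T>(input - step)) / (2 * step);
}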
|
||||
template <typename T>
|
||||
void IgammaSeriesLoop(std::vector<T> &vals, const int &mode) {
|
||||
while (vals[0]) {
|
||||
T enabled = vals[0];
|
||||
T r = vals[1];
|
||||
T c = vals[2];
|
||||
T ans = vals[3];
|
||||
T x = vals[4];
|
||||
T dc_da = vals[5];
|
||||
T dans_da = vals[6];
|
||||
|
||||
r += 1;
|
||||
dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r);
|
||||
dans_da = dans_da + dc_da;
|
||||
c = c * (x / r);
|
||||
ans = ans + c;
|
||||
T conditional;
|
||||
if (mode == 1) {
|
||||
conditional = enabled && (c / ans > std::numeric_limits<T>::epsilon());
|
||||
} else {
|
||||
conditional = enabled && (std::abs(dc_da / dans_da) > std::numeric_limits<T>::epsilon());
|
||||
}
|
||||
|
||||
vals[0] = conditional;
|
||||
if (enabled) {
|
||||
vals = {conditional, r, c, ans, x, dc_da, dans_da};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function for computing Igamma using a power series.
|
||||
template <typename T>
|
||||
T IgammaSeries(const T &ax, const T &x, const T &a, const T &enabled, const int &mode) {
|
||||
/* vals: (enabled, r, c, ans, x)
|
||||
* 'enabled' is a predication mask that says for which elements we should
|
||||
* execute the loop body. Disabled elements have no effect in the loop body.
|
||||
*/
|
||||
std::vector<T> vals = {enabled, a, 1, 1, x, 0, 0};
|
||||
IgammaSeriesLoop<T>(vals, mode);
|
||||
|
||||
T ans = vals[3];
|
||||
T dans_da = vals[6];
|
||||
auto base_num = a == 0 ? 1 : a;
|
||||
if (mode == 1) {
|
||||
return (ans * ax) / base_num;
|
||||
}
|
||||
|
||||
T dlogax_da = std::log(x) - Digamma<T>(a + 1);
|
||||
switch (mode) {
|
||||
case 2:
|
||||
return ax * (ans * dlogax_da + dans_da) / base_num;
|
||||
default:
|
||||
return -(dans_da + ans * dlogax_da) * x / base_num;
|
||||
}
|
||||
}
|
||||
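/* Illustrative sketch (not part of the original source): with mode == 1, IgammaSeries above
 * evaluates the classic power series
 *   P(a, x) = x^a * e^{-x} / Gamma(a + 1) * sum_{k>=0} x^k / ((a + 1) * ... * (a + k)).
 * The direct reference below follows the same recurrence (c *= x / r) without the gradient
 * bookkeeping, assumes a > 0 and x >= 0, and is never called by the kernels.
 */
template <typename T>
T IgammaSeriesReferenceSketch(const T &a, const T &x) {
  T r = a;
  T c = 1;
  T ans = 1;
  while (c / ans > std::numeric_limits<T>::epsilon()) {
    r += 1;
    c *= x / r;
    ans += c;
  }
  // ax = x^a * e^{-x} / Gamma(a); dividing by a turns the prefactor into 1 / Gamma(a + 1).
  T ax = std::exp(a * std::log(x) - x - Lgamma<T>(a));
  return ans * ax / a;
}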
|
||||
template <typename T>
|
||||
void IgammacCFLoop(std::vector<T> &vals, const int &mode) {
|
||||
while (vals[0] && vals[5] < 2000) {
|
||||
T enabled = vals[0];
|
||||
T ans = vals[1];
|
||||
T tmp_var_t = vals[2];
|
||||
T tmp_var_y = vals[3];
|
||||
T tmp_var_z = vals[4];
|
||||
T tmp_var_c = vals[5];
|
||||
T pkm1 = vals[6];
|
||||
T qkm1 = vals[7];
|
||||
T pkm2 = vals[8];
|
||||
T qkm2 = vals[9];
|
||||
T dpkm2_da = vals[10];
|
||||
T dqkm2_da = vals[11];
|
||||
T dpkm1_da = vals[12];
|
||||
T dqkm1_da = vals[13];
|
||||
T dans_da = vals[14];
|
||||
|
||||
tmp_var_c += 1;
|
||||
tmp_var_y += 1;
|
||||
tmp_var_z += 2;
|
||||
|
||||
T yc = tmp_var_y * tmp_var_c;
|
||||
T pk = pkm1 * tmp_var_z - pkm2 * yc;
|
||||
T qk = qkm1 * tmp_var_z - qkm2 * yc;
|
||||
bool qk_is_nonzero = (qk != 0);
|
||||
T r = pk / qk;
|
||||
|
||||
T t = qk_is_nonzero ? std::abs((ans - r) / r) : 1;
|
||||
ans = qk_is_nonzero ? r : ans;
|
||||
|
||||
T dpk_da = dpkm1_da * tmp_var_z - pkm1 - dpkm2_da * yc + pkm2 * tmp_var_c;
|
||||
T dqk_da = dqkm1_da * tmp_var_z - qkm1 - dqkm2_da * yc + qkm2 * tmp_var_c;
|
||||
T dans_da_new = qk_is_nonzero ? (dpk_da - ans * dqk_da) / qk : dans_da;
|
||||
T grad_conditional = qk_is_nonzero ? std::abs(dans_da_new - dans_da) : 1;
|
||||
|
||||
pkm2 = pkm1;
|
||||
pkm1 = pk;
|
||||
qkm2 = qkm1;
|
||||
qkm1 = qk;
|
||||
|
||||
dpkm2_da = dpkm1_da;
|
||||
dqkm2_da = dqkm1_da;
|
||||
dpkm1_da = dpk_da;
|
||||
dqkm1_da = dqk_da;
|
||||
|
||||
bool rescale = std::abs(pk) > (1 / std::numeric_limits<T>::epsilon());
|
||||
|
||||
pkm2 = rescale ? pkm2 * std::numeric_limits<T>::epsilon() : pkm2;
|
||||
pkm1 = rescale ? pkm1 * std::numeric_limits<T>::epsilon() : pkm1;
|
||||
qkm2 = rescale ? qkm2 * std::numeric_limits<T>::epsilon() : qkm2;
|
||||
qkm1 = rescale ? qkm1 * std::numeric_limits<T>::epsilon() : qkm1;
|
||||
|
||||
dpkm2_da = rescale ? dpkm2_da * std::numeric_limits<T>::epsilon() : dpkm2_da;
|
||||
dqkm2_da = rescale ? dqkm2_da * std::numeric_limits<T>::epsilon() : dqkm2_da;
|
||||
dpkm1_da = rescale ? dpkm1_da * std::numeric_limits<T>::epsilon() : dpkm1_da;
|
||||
dqkm1_da = rescale ? dqkm1_da * std::numeric_limits<T>::epsilon() : dqkm1_da;
|
||||
|
||||
T conditional;
|
||||
|
||||
if (mode == 1) {
|
||||
conditional = enabled && (t > std::numeric_limits<T>::epsilon());
|
||||
} else {
|
||||
conditional = enabled && (grad_conditional > std::numeric_limits<T>::epsilon());
|
||||
}
|
||||
|
||||
vals[0] = conditional;
|
||||
vals[5] = tmp_var_c;
|
||||
if (enabled) {
|
||||
vals = {conditional, ans, tmp_var_t, tmp_var_y, tmp_var_z, tmp_var_c, pkm1, qkm1,
|
||||
pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da_new};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T IgammacContinuedFraction(const T &ax, const T &x, const T &a, const T &enabled, const int &mode) {
|
||||
// vals: enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2
|
||||
T y = 1 - a;
|
||||
T z = x + y + 1;
|
||||
T c = 0;
|
||||
T pkm2 = 1;
|
||||
T qkm2 = x;
|
||||
T pkm1 = x + 1;
|
||||
T qkm1 = z * x;
|
||||
T ans = pkm1 / qkm1;
|
||||
T t = 1;
|
||||
T dpkm2_da = 0;
|
||||
T dqkm2_da = 0;
|
||||
T dpkm1_da = 0;
|
||||
T dqkm1_da = -x;
|
||||
T dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
|
||||
std::vector<T> vals = {enabled, ans, t, y, z, c, pkm1, qkm1,
|
||||
pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da};
|
||||
|
||||
IgammacCFLoop<T>(vals, mode);
|
||||
|
||||
ans = vals[1];
|
||||
if (mode == 1) {
|
||||
return ans * ax;
|
||||
}
|
||||
|
||||
dans_da = vals[14];
|
||||
T dlogax_da = std::log(x) - Digamma<T>(a);
|
||||
switch (mode) {
|
||||
case 2:
|
||||
return ax * (ans * dlogax_da + dans_da);
|
||||
default:
|
||||
return -(dans_da + ans * dlogax_da) * x;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T IgammaSingle(const T &a, const T &x) {
|
||||
if (!std::isinf(a) && (a > 0) && std::isinf(x) && x > 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
T is_nan = std::isnan(a) || std::isnan(x);
|
||||
T x_is_zero = (x == 0);
|
||||
T domain_error = (x < 0) || (a <= 0);
|
||||
T use_igammac = (x > 1) && (x > a);
|
||||
|
||||
T ax = a * std::log(x) - x - Lgamma<T>(a);
|
||||
|
||||
T underflow = (ax < -std::log(std::numeric_limits<T>::max()));
|
||||
|
||||
ax = std::exp(ax);
|
||||
T enabled = !(x_is_zero || domain_error || underflow || is_nan);
|
||||
|
||||
T output = use_igammac ? 1 - IgammacContinuedFraction<T>(ax, x, a, enabled && use_igammac, 1)
|
||||
: IgammaSeries<T>(ax, x, a, (enabled && !(use_igammac)), 1);
|
||||
|
||||
output = (domain_error || is_nan || std::isnan(output)) ? std::numeric_limits<T>::quiet_NaN() : output;
|
||||
|
||||
output = x_is_zero ? 0 : output;
|
||||
return output;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Igamma(T *a, T *x, T *output, int size) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
*(output + i) = IgammaSingle<T>(*(a + i), *(x + i));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T IgammacSingle(const T &a, const T &x) {
|
||||
T out_of_range = (x <= 0) || (a <= 0);
|
||||
T use_igamma = (x < 1) || (x < a);
|
||||
T ax = a * std::log(x) - x - Lgamma<T>(a);
|
||||
T underflow = (ax < -std::log(std::numeric_limits<T>::max()));
|
||||
|
||||
T enabled = !(out_of_range || underflow);
|
||||
|
||||
ax = std::exp(ax);
|
||||
T output = use_igamma ? 1 - IgammaSeries<T>(ax, x, a, (enabled && use_igamma), 1)
|
||||
: IgammacContinuedFraction<T>(ax, x, a, enabled && !use_igamma, 1);
|
||||
|
||||
output = out_of_range ? 1 : output;
|
||||
|
||||
output = x < 0 || a <= 0 || std::isnan(x) || (std::isinf(x) && (x > 0)) || std::isnan(a)
|
||||
? std::numeric_limits<T>::quiet_NaN()
|
||||
: output;
|
||||
output = std::isinf(x) && x > 0 && a > 0 ? 0 : output;
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Igammac(T *a, T *x, T *output, int size) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
*(output + i) = IgammacSingle<T>(*(a + i), *(x + i));
|
||||
}
|
||||
}
|
||||
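/* Illustrative sketch (not part of the original source): the regularized lower and upper
 * incomplete gamma functions are complementary, P(a, x) + Q(a, x) = 1 for a > 0 and x > 0.
 * Assuming double inputs, the helper below exercises the array entry points Igamma and
 * Igammac defined in this file; it is only a manual sanity check and is never called by the
 * kernels.
 */
inline void IgammaComplementSketch() {
  double a[2] = {0.5, 3.0};
  double x[2] = {1.2, 2.5};
  double p[2] = {0, 0};
  double q[2] = {0, 0};
  Igamma<double>(a, x, p, 2);
  Igammac<double>(a, x, q, 2);
  for (int i = 0; i < 2; ++i) {
    std::cout << "P + Q = " << p[i] + q[i] << std::endl;  // expected ~1.0
  }
}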
|
||||
template <typename T>
|
||||
void IgammaGradA(T *a, T *x, T *output, int size) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
*(output + i) = IgammaGradASingle<T>(*(a + i), *(x + i));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T IgammaGradASingle(const T &a, const T &x) {
|
||||
T is_nan = std::isnan(a) || std::isnan(x);
|
||||
T x_is_zero = (x == 0);
|
||||
T domain_error = (x < 0) || (a <= 0);
|
||||
T use_igammac = (x > 1) && (x > a);
|
||||
T ax = a * std::log(x) - x - Lgamma<T>(a);
|
||||
T underflow = (ax < -std::log(std::numeric_limits<T>::max()));
|
||||
ax = std::exp(ax);
|
||||
T enabled = !(x_is_zero || domain_error || underflow || is_nan);
|
||||
T output = use_igammac ? -IgammacContinuedFraction<T>(ax, x, a, enabled && use_igammac, 2)
|
||||
: IgammaSeries<T>(ax, x, a, (enabled && !(use_igammac)), 2);
|
||||
|
||||
output = (domain_error || is_nan || std::isnan(output)) ? std::numeric_limits<T>::quiet_NaN() : output;
|
||||
output = x_is_zero || (std::isinf(x) && !is_nan && !domain_error && !std::isinf(a)) ? 0 : output;
|
||||
|
||||
return output;
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_IGAMMA_UTILS_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_IGAMMA_UTILS_H_
|
||||
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
template <typename T>
|
||||
T IgammaSeries(const T &ax, const T &x, const T &a, const T &enabled, const int &mode);
|
||||
|
||||
template <typename T>
|
||||
T IgammacContinuedFraction(const T &ax, const T &x, const T &a, const T &enabled, const int &mode);
|
||||
|
||||
// Computes an approximation of the lgamma function.
|
||||
template <typename T>
|
||||
T Lgamma(const T &input);
|
||||
|
||||
// Computes an approximation of the digamma function.
|
||||
template <typename T>
|
||||
T Digamma(const T &input);
|
||||
|
||||
template <typename T>
|
||||
T IgammaSingle(const T &a, const T &x);
|
||||
|
||||
// Computes an approximation of the incomplete gamma function.
|
||||
template <typename T>
|
||||
void Igamma(T *a, T *x, T *output, int size);
|
||||
|
||||
template <typename T>
|
||||
T IgammaGradASingle(const T &a, const T &x);
|
||||
|
||||
/** an approximation of the derivative of the incomplete gamma function
|
||||
* with respect to a.
|
||||
*/
|
||||
template <typename T>
|
||||
void IgammaGradA(T *a, T *x, T *output, int size);
|
||||
|
||||
template <typename T>
|
||||
T IgammacSingle(const T &a, const T &x);
|
||||
|
||||
// Computes an approximation of the complementary incomplete gamma function.
|
||||
template <typename T>
|
||||
void Igammac(T *a, T *x, T *output, int size);
|
||||
|
||||
#endif
|
|
@ -235,4 +235,48 @@ std::string DTypeStr(DataType dtype) {
|
|||
return std::string("DT_UNDEFINED");
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t CheckTensorTypeSame(const std::map<std::string, DataType> &types, const DataType &check_type,
|
||||
const std::string &prim_name) {
|
||||
if (types.empty()) {
|
||||
KERNEL_LOG_ERROR("Trying to use the function to check a empty types map!");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
for (const auto &type : types) {
|
||||
auto _type_ = type.second;
|
||||
if (_type_ != check_type) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"For primitive[%s]'s input arguments [%s] type should equal to [%s] , "
|
||||
"but get the real type [%s].",
|
||||
prim_name.c_str(), type.first.c_str(), DTypeStr(check_type).c_str(), DTypeStr(_type_).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t CheckTensorShapeSame(const std::map<std::string, TensorShapePtr> &shapes,
|
||||
const std::vector<int64_t> &check_shape, const std::string &prim_name) {
|
||||
if (shapes.empty()) {
|
||||
KERNEL_LOG_ERROR("Trying to use the function to check a empty types map!");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
for (const auto &shape : shapes) {
|
||||
auto _shape_ptr_ = shape.second;
|
||||
KERNEL_CHECK_NULLPTR(_shape_ptr_, KERNEL_STATUS_PARAM_INVALID,
|
||||
"For primitive[%s]'s input arguments [%s] TensorShapePtr "
|
||||
"should not be nullptr.",
|
||||
prim_name.c_str(), shape.first.c_str());
|
||||
auto _shape_ = _shape_ptr_->GetDimSizes();
|
||||
if (!ShapeVectorIsSame(_shape_, check_shape)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"For primitive[%s]'s input arguments [%s] shape should equal to (%s) , "
|
||||
"but get the real shape (%s).",
|
||||
prim_name.c_str(), shape.first.c_str(), VectorToString(check_shape).c_str(), VectorToString(_shape_).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
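/* Usage sketch (illustrative only; names and values are hypothetical): a kernel that requires
 * all of its inputs to share the first input's dtype could call the helper above as
 *   std::map<std::string, DataType> types = {{"x", x_type}, {"y", y_type}};
 *   KERNEL_HANDLE_ERROR(CheckTensorTypeSame(types, x_type, "AddN"), "dtype check failed.");
 * CheckTensorShapeSame is used the same way with a map of TensorShapePtr and a reference shape.
 */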
|
||||
} // namespace aicpu
|
||||
|
|
|
@ -44,6 +44,12 @@ constexpr uint64_t kFormatNCHWIndexC = 1;
|
|||
constexpr uint64_t kFormatNCHWIndexH = 2;
|
||||
constexpr uint64_t kFormatNCHWIndexW = 3;
|
||||
|
||||
constexpr uint64_t kFormatNC1HWC0IndexN = 0;
|
||||
constexpr uint64_t kFormatNC1HWC0IndexC1 = 1;
|
||||
constexpr uint64_t kFormatNC1HWC0IndexH = 2;
|
||||
constexpr uint64_t kFormatNC1HWC0IndexW = 3;
|
||||
constexpr uint64_t kFormatNC1HWC0IndexC0 = 4;
|
||||
|
||||
constexpr uint64_t kFormatCHWIndexC = 0;
|
||||
constexpr uint64_t kFormatCHWIndexH = 1;
|
||||
constexpr uint64_t kFormatCHWIndexW = 2;
|
||||
|
@ -68,6 +74,9 @@ const size_t INPUT_NUM7 = 7;
|
|||
const size_t INPUT_NUM8 = 8;
|
||||
const size_t INPUT_NUM9 = 9;
|
||||
const size_t INPUT_NUM32 = 32;
|
||||
|
||||
using TensorShapePtr = std::shared_ptr<TensorShape>;
|
||||
|
||||
/*
|
||||
* str cat util function
|
||||
* param[in] params need concat to string
|
||||
|
@ -193,6 +202,25 @@ inline bool AddWithoutOverflow(const int64_t x, const int64_t y, int64_t &sum) {
|
|||
return !(((x >= 0) == (y >= 0)) && ((sum >= 0) != (x >= 0)));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief check two shape vector are same
|
||||
* @param shape shape
|
||||
* @param check_shape check_shape
|
||||
* @return true: same, false: different
|
||||
*/
|
||||
inline bool ShapeVectorIsSame(const std::vector<int64_t> &shape, const std::vector<int64_t> &check_shape) {
|
||||
if (shape.size() != check_shape.size()) {
|
||||
return false;
|
||||
} else {
|
||||
for (size_t idx = 0; idx < shape.size(); ++idx) {
|
||||
if (shape[idx] != check_shape[idx]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
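// Usage sketch (illustrative only): shapes compare equal exactly when the rank and every
// dimension match, e.g. ShapeVectorIsSame({2, 3}, {2, 3}) is true while
// ShapeVectorIsSame({2, 3}, {2, 3, 1}) is false.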
|
||||
/**
|
||||
* @brief normal check for calculation
|
||||
* @param ctx context
|
||||
|
@ -250,5 +278,25 @@ DataType DType(std::string dtype_str);
|
|||
*/
|
||||
std::string DTypeStr(DataType dtype);
|
||||
|
||||
/**
|
||||
* @brief check tensor type is same
|
||||
* @param types a map with name and data type
|
||||
* @param check_type check_type
|
||||
* @param prim_name ops name
|
||||
* @return status code
|
||||
*/
|
||||
uint32_t CheckTensorTypeSame(const std::map<std::string, DataType> &types, const DataType &check_type,
|
||||
const std::string &prim_name);
|
||||
|
||||
/**
|
||||
* @brief check tensor type is same
|
||||
* @param shapes a map with name and shape
|
||||
* @param check_type check_shape
|
||||
* @param prim_name ops name
|
||||
* @return status code
|
||||
*/
|
||||
uint32_t CheckTensorShapeSame(const std::map<std::string, TensorShapePtr> &shapes,
|
||||
const std::vector<int64_t> &check_shape, const std::string &prim_name);
|
||||
|
||||
} // namespace aicpu
|
||||
#endif
|
||||
|
|