!34252 [CPU operator] Add ReNorm CPU operator

Merge pull request !34252 from Xiaoda/127-add-renorm-op
This commit is contained in:
i-robot 2022-06-07 05:47:46 +00:00 committed by Gitee
commit ec65b96074
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 543 additions and 2 deletions

View File

@ -0,0 +1,174 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/renorm_cpu_kernel.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <map>
#include <cmath>
#include "mindspore/core/ops/renorm.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "include/common/thread_pool.h"
namespace mindspore {
namespace kernel {
namespace {
// Number of input tensors the Renorm kernel expects; checked in Init and again in LaunchKernel.
constexpr size_t kRenormInputsNum = 1;
// Number of output tensors the Renorm kernel expects; checked in Init and again in LaunchKernel.
constexpr size_t kRenormOutputsNum = 1;
}  // namespace
// Binds this kernel to a Renorm operator.
//
// Steps: downcast the operator to ops::Renorm, verify the input/output tensor counts, then look up the
// typed launch function whose KernelAttr matches the tensors' data types.
// Returns false (after logging an error) if the cast fails, the tensor counts are wrong, or no entry of
// the support list matches; returns true once base_operator_ and kernel_func_ have been set.
bool RenormCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                              const std::vector<KernelTensorPtr> &outputs) {
  const auto renorm_op = std::dynamic_pointer_cast<ops::Renorm>(base_operator);
  if (renorm_op == nullptr) {
    MS_LOG(ERROR) << "cast Renorm ops failed!";
    return false;
  }
  kernel_name_ = renorm_op->name();

  const size_t input_num = inputs.size();
  const size_t output_num = outputs.size();
  const bool tensor_num_ok = (input_num == kRenormInputsNum) && (output_num == kRenormOutputsNum);
  if (!tensor_num_ok) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output tensor number must be " << kRenormInputsNum
                  << " and " << kRenormOutputsNum << ", but got " << input_num << " and " << output_num;
    return false;
  }

  // Select the launch function by the data types of the actual tensors.
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  const auto [matched, func_index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!matched) {
    MS_LOG(ERROR) << "Renorm does not support this kernel data type: " << kernel_attr;
    return false;
  }

  base_operator_ = base_operator;
  kernel_func_ = func_list_[func_index].second;
  return true;
}
// Refreshes the shape-dependent state of the kernel.
//
// Delegates to NativeCpuKernelMod::Resize first and forwards its non-zero return code on failure.
// On success it caches the shape of input 0 and the operator attributes "dim", "p" and "maxnorm".
// All three attributes are read from the `base_operator` argument so they always come from the same
// operator instance (the original mixed the cached member and the argument).
// Returns 0 on success.
int RenormCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                               const std::vector<KernelTensorPtr> &outputs,
                               const std::map<uint32_t, tensor::TensorPtr> &others) {
  const int ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others);
  if (ret != 0) {
    MS_LOG(WARNING) << kernel_name_ << " resize failed.";
    return ret;
  }
  x_shape_ = inputs[kIndex0]->GetShapeVector();
  axis_ = GetValue<int64_t>(base_operator->GetAttr("dim"));
  p_ = GetValue<float>(base_operator->GetAttr("p"));
  max_norm_ = GetValue<float>(base_operator->GetAttr("maxnorm"));
  return 0;
}
// Validates the cached attributes and derives the sizes used to walk the input buffer.
//
// Raises (via MS_LOG(EXCEPTION)) when p_ <= 0, max_norm_ < 0, or axis_ is outside [-rank, rank).
// A negative axis_ is normalised to its non-negative equivalent. Afterwards:
//   stride_size_ = product of the dimensions before the axis,
//   axis_size_   = size of the axis dimension,
//   inner_size_  = product of the dimensions after the axis,
//   total_size_  = product of all dimensions.
void RenormCpuKernelMod::CheckAndInitParams() {
  if (p_ <= 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the attribute norm 'p' must be positive, but got " << p_;
  }
  if (max_norm_ < 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the attribute 'maxnorm' must be non-negative, but got "
                      << max_norm_;
  }
  const size_t x_rank = x_shape_.size();
  // Work with a signed rank: negating the unsigned size would wrap around and print a huge bound.
  const int64_t signed_rank = SizeToLong(x_rank);
  if (axis_ < -signed_rank || axis_ >= signed_rank) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the attribute 'dim' must be in range [" << -signed_rank
                      << ", " << signed_rank << "), but got " << axis_;
  }
  if (axis_ < 0) {
    axis_ += signed_rank;
  }
  stride_size_ = 1;
  inner_size_ = 1;
  axis_size_ = 1;
  total_size_ = 1;
  for (size_t i = 0; i < x_rank; ++i) {
    const auto dim_size = static_cast<size_t>(x_shape_[i]);
    const int64_t dim_index = SizeToLong(i);
    if (dim_index == axis_) {
      axis_size_ *= dim_size;
    } else if (dim_index < axis_) {
      stride_size_ *= dim_size;
    } else {
      inner_size_ *= dim_size;
    }
    total_size_ *= dim_size;
  }
}
// Renormalises every sub-tensor along the configured axis.
//
// For each index `ith` of the axis, the p-norm of all elements whose axis coordinate equals `ith` is
// computed; if it exceeds maxnorm the elements are scaled by maxnorm / norm, otherwise they are copied
// unchanged. Work is parallelised over the axis indices.
//
// inputs[0]  - source buffer, interpreted as T.
// outputs[0] - destination buffer, interpreted as T.
// The workspace argument is unused. Always returns true; invalid attributes raise inside
// CheckAndInitParams.
template <typename T>
bool RenormCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> &,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kRenormInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kRenormOutputsNum, kernel_name_);
  auto *x = reinterpret_cast<T *>(inputs[kIndex0]->addr);
  auto *output = reinterpret_cast<T *>(outputs[kIndex0]->addr);
  CheckAndInitParams();
  const size_t axis_size = axis_size_;    // maximum parallel number
  const size_t inner_size = inner_size_;  // continuous number
  const size_t total_size = total_size_;  // total number
  // Distance between two consecutive outer slices. Computed as a product rather than
  // total_size / stride_size so that a zero-sized leading dimension cannot cause a division by zero.
  const size_t step_len = axis_size * inner_size;
  const auto p = static_cast<T>(p_);
  const auto maxnorm = static_cast<T>(max_norm_);
  auto task = [&](const size_t start, const size_t end) {
    for (size_t ith = start; ith < end; ++ith) {
      // Pass 1: accumulate sum(|x|^p) over the sub-tensor selected by axis index `ith`.
      // NOTE(review): abs/pow are called unqualified, so the overload chosen depends on what the included
      // headers bring into scope for T - confirm a floating-point overload is selected for float/double.
      T single_sum = static_cast<T>(0.0);
      for (size_t pos_ith = ith * inner_size; pos_ith < total_size; pos_ith += step_len) {
        for (size_t j = 0; j < inner_size; ++j) {
          size_t index = pos_ith + j;
          single_sum += pow(abs(x[index]), p);
        }
      }
      const T norm = static_cast<T>(pow(single_sum, static_cast<T>(1) / p));
      // Pass 2: scale the sub-tensor if its norm exceeds maxnorm, otherwise copy it through.
      const bool need_scale = norm > maxnorm;
      for (size_t pos_ith = ith * inner_size; pos_ith < total_size; pos_ith += step_len) {
        for (size_t j = 0; j < inner_size; ++j) {
          size_t index = pos_ith + j;
          if (need_scale) {
            output[index] = x[index] / norm * maxnorm;
          } else {
            output[index] = x[index];
          }
        }
      }
    }
  };
  ParallelLaunchAutoSearch(task, axis_size, this, &parallel_search_info_);
  return true;
}
// Dispatch table pairing each supported input/output data type with the LaunchKernel instantiation for
// that element type. Init stores the `.second` of the entry at the index returned by MatchKernelAttr,
// and GetOpSupport exposes the `.first` of every entry.
std::vector<std::pair<KernelAttr, RenormCpuKernelMod::RenormFunc>> RenormCpuKernelMod::func_list_ = {
  // float64 input -> float64 output
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   &RenormCpuKernelMod::LaunchKernel<double>},
  // float32 input -> float32 output
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   &RenormCpuKernelMod::LaunchKernel<float>},
  // float16 input -> float16 output
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   &RenormCpuKernelMod::LaunchKernel<float16>}};
std::vector<KernelAttr> RenormCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, RenormFunc> &pair) { return pair.first; });
return support_list;
}
// Presumably registers RenormCpuKernelMod in the NativeCpuKernelMod factory under the key "Renorm"
// (macro defined elsewhere - confirm against plugin/factory/ms_factory.h).
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Renorm, RenormCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,81 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_RENORM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_RENORM_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <utility>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
// CPU kernel for the Renorm operator.
//
// Init selects a typed launch function, Resize caches the input shape and the operator attributes, and
// Launch forwards to the selected LaunchKernel<T> instantiation.
// NOTE(review): std::function is used below, but this header does not include <functional> directly.
class RenormCpuKernelMod : public NativeCpuKernelMod {
 public:
  RenormCpuKernelMod() = default;
  ~RenormCpuKernelMod() override = default;
  // Checks the operator type and tensor counts and picks the launch function matching the tensor data
  // types. Returns false on any mismatch.
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  // Calls the base-class Resize, then caches the shape of input 0 and the attributes "dim", "p" and
  // "maxnorm". Returns 0 on success, otherwise the base-class error code.
  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs,
             const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;
  // Runs the launch function chosen by Init; raises if none has been chosen yet.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    MS_EXCEPTION_IF_NULL(kernel_func_);
    return kernel_func_(this, inputs, workspace, outputs);
  }

 protected:
  // Returns the KernelAttr entries of func_list_.
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  // Validates p_, max_norm_ and axis_, normalises a negative axis_, and computes the *_size_ members.
  void CheckAndInitParams();
  // Typed implementation: computes the p-norm of each sub-tensor along the axis and rescales those
  // whose norm exceeds max_norm_.
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
                    const std::vector<kernel::AddressPtr> &outputs);
  // Signature of a launch function: (kernel instance, inputs, workspace, outputs) -> success flag.
  using RenormFunc =
    std::function<bool(RenormCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
  // Table of supported data types and their launch functions.
  static std::vector<std::pair<KernelAttr, RenormFunc>> func_list_;
  // Launch function selected in Init; null until Init succeeds.
  RenormFunc kernel_func_{nullptr};
  // Operator handed to Init.
  BaseOperatorPtr base_operator_;
  // input shape
  std::vector<int64_t> x_shape_;
  // axis attribute of the primitive (read from attribute "dim"); normalised to be non-negative by
  // CheckAndInitParams
  int64_t axis_{0};
  // the p norm attribute of the primitive; CheckAndInitParams rejects values <= 0
  float p_{0};
  // the maxnorm attribute of the primitive; CheckAndInitParams rejects negative values
  float max_norm_{0.0};
  // sizes that are used in calculation
  // product of the dimensions before the axis
  size_t stride_size_{1};
  // product of the dimensions after the axis
  size_t inner_size_{1};
  // size of the axis dimension
  size_t axis_size_{1};
  // product of all dimensions
  size_t total_size_{1};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_RENORM_CPU_KERNEL_H_

View File

@ -439,6 +439,46 @@ def get_lp_norm_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register(P.Renorm)
def get_renorm_rule(prim, axis_size):
    """
    VmapRule for Renorm.

    The batch dimension is moved next to the renorm axis and merged with it, so that a single Renorm
    call normalises every (batch, axis) slice independently; the result is then reshaped and moved back.
    """
    pnorm = prim.p
    axis = prim.dim
    maxnorm = prim.maxnorm

    def vmap_rule(x_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim)
        if is_all_none:
            return result

        x, batch_dim = x_bdim
        x_rank = F.rank(x)
        if batch_dim < 0:
            batch_dim = batch_dim + x_rank
        src_dim = batch_dim

        # `axis` refers to the un-batched input, whose rank is one less than that of `x`.
        origin_axis = axis
        if axis < 0:
            origin_axis = axis + x_rank - 1

        # Put the batch dimension directly before the renorm axis when it started at or before it,
        # otherwise directly after it. Either way the pair then occupies positions
        # origin_axis and origin_axis + 1.
        des_dim = origin_axis if batch_dim <= origin_axis else origin_axis + 1
        x = mnp.moveaxis(x, src_dim, des_dim)

        from_shape = F.shape(x)
        merged_size = from_shape[origin_axis] * from_shape[origin_axis + 1]
        to_shape = from_shape[:origin_axis] + (merged_size,) + from_shape[origin_axis + 2:]

        merged_x = F.reshape(x, to_shape)
        out = P.Renorm(int(pnorm), origin_axis, maxnorm)(merged_x)
        out = F.reshape(out, from_shape)
        out = mnp.moveaxis(out, des_dim, src_dim)
        return (out, batch_dim)

    return vmap_rule
# Register the existing get_assign_vmap_rule for AssignAdd and AssignSub as well.
get_assign_vmap_rule = vmap_rules_getters.register(P.AssignAdd)(get_assign_vmap_rule)
get_assign_vmap_rule = vmap_rules_getters.register(P.AssignSub)(get_assign_vmap_rule)
# Unary vmap

View File

@ -69,7 +69,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Ceil, Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
LogicalNot, LogicalOr, LogicalXor, LpNorm, MatMul, Maximum, MulNoNan,
MatrixDeterminant, LogMatrixDeterminant, Minimum, Mul, Neg, NMSWithMask, NotEqual,
NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace, Einsum,
NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace, Einsum, Renorm,
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
Sin, Sqrt, Rsqrt, BesselI0, BesselI1, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Addcdiv,
@ -203,6 +203,7 @@ __all__ = [
'FusedWeightScaleApplyMomentum',
'ExpandDims',
'Einsum',
'Renorm',
'Cast',
'IsSubClass',
'IsInstance',

View File

@ -5875,7 +5875,7 @@ class Renorm(Primitive):
Refer to :func:`mindspore.ops.renorm` for more detail.
Supported Platforms:
``Ascend``
``Ascend`` ``CPU``
Example:
>>> x = Tensor(np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]), mindspore.float32)

View File

@ -0,0 +1,245 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.functional import vmap
from mindspore.common.api import ms_function
class ReNormNet(nn.Cell):
    """Minimal Cell that applies P.Renorm(p, axis, maxnorm) to its single input."""

    def __init__(self, p=1, axis=0, maxnorm=10.0):
        super().__init__()
        self.renorm = P.Renorm(p, axis, maxnorm)

    def construct(self, input_x):
        return self.renorm(input_x)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_renorm_op_float32(data_type=np.float32):
    """
    Feature: test Renorm with using float32.
    Description: 2x2x4 float32 input with the default attributes (p=1, axis=0, maxnorm=10.0), run on CPU
                 in graph mode and then in PyNative mode.
    Expectation: the result match with expect.
    """
    error = 1e-6
    # Pin the device target so the CPU kernel is exercised regardless of the build's default target.
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    input_x = np.array([[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                        [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]]).astype(data_type)
    benchmark_output = np.array([[[0.27777779, 0.55555558, 0.83333337, 1.11111116],
                                  [1.38888896, 1.66666675, 1.94444454, 2.22222233]],
                                 [[0.90000004, 1.00000000, 1.10000002, 1.20000005],
                                  [1.30000007, 1.39999998, 1.50000000, 1.60000002]]]).astype(data_type)
    re_norm = ReNormNet()
    output = re_norm(Tensor(input_x))
    np.testing.assert_allclose(output.asnumpy(), benchmark_output, rtol=error)
    context.set_context(mode=context.PYNATIVE_MODE)
    output = re_norm(Tensor(input_x))
    np.testing.assert_allclose(output.asnumpy(), benchmark_output, rtol=error)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_renorm_op_float16(data_type=np.float16):
    """
    Feature: test Renorm using float16.
    Description: 2x2x4 float16 input with the default attributes (p=1, axis=0, maxnorm=10.0), run on CPU
                 in graph mode and then in PyNative mode.
    Expectation: the result match with expect.
    """
    error = 1e-3
    # Pin the device target so the CPU kernel is exercised regardless of the build's default target.
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    input_x = np.array([[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                        [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]]).astype(data_type)
    benchmark_output = np.array([[[0.27783203, 0.55566406, 0.83349609, 1.11132812],
                                  [1.38867188, 1.66699219, 1.94531250, 2.22265625]],
                                 [[0.89990234, 1.00000000, 1.09960938, 1.19921875],
                                  [1.29980469, 1.39941406, 1.50000000, 1.59960938]]]).astype(data_type)
    re_norm = ReNormNet()
    output = re_norm(Tensor(input_x))
    np.testing.assert_allclose(output.asnumpy(), benchmark_output, rtol=error)
    context.set_context(mode=context.PYNATIVE_MODE)
    output = re_norm(Tensor(input_x))
    np.testing.assert_allclose(output.asnumpy(), benchmark_output, rtol=error)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_renorm_op1_float32(data_type=np.float32):
    """
    Feature: test Renorm using float32.
    Description: 2x2x4 float32 input with p=2, axis=1, maxnorm=10.0, run on CPU in graph mode and then
                 in PyNative mode.
    Expectation: the result match with expect.
    """
    error = 1e-6
    # Pin the device target so the CPU kernel is exercised regardless of the build's default target.
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    input_x = np.array([[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                        [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]]).astype(data_type)
    benchmark_output = np.array([[[0.45834923, 0.91669846, 1.37504768, 1.83339691],
                                  [1.56556070, 1.87867284, 2.19178486, 2.50489712]],
                                 [[4.12514305, 4.58349228, 5.04184151, 5.50019073],
                                  [4.07045794, 4.38356972, 4.69668198, 5.00979424]]]).astype(data_type)
    re_norm = ReNormNet(p=2, axis=1, maxnorm=10.0)
    output = re_norm(Tensor(input_x))
    np.testing.assert_allclose(output.asnumpy(), benchmark_output, rtol=error)
    context.set_context(mode=context.PYNATIVE_MODE)
    output = re_norm(Tensor(input_x))
    np.testing.assert_allclose(output.asnumpy(), benchmark_output, rtol=error)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_renorm_op2_float16(data_type=np.float16):
    """
    Feature: test Renorm using float16.
    Description: 2x2x4 float16 input with p=2, axis=2, maxnorm=10.0, run on CPU in graph mode and then
                 in PyNative mode.
    Expectation: the result match with expect.
    """
    error = 1e-3
    # Pin the device target so the CPU kernel is exercised regardless of the build's default target.
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    input_x = np.array([[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                        [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]]).astype(data_type)
    benchmark_output = np.array([[[0.60192931, 1.09108937, 1.49255586, 1.82574177],
                                  [3.00964642, 3.27326822, 3.48263025, 3.65148354]],
                                 [[5.41736364, 5.45544672, 5.47270441, 5.47722530],
                                  [7.82508087, 7.63762569, 7.46277905, 7.30296707]]]).astype(data_type)
    re_norm = ReNormNet(p=2, axis=2, maxnorm=10.0)
    output = re_norm(Tensor(input_x))
    np.testing.assert_allclose(output.asnumpy(), benchmark_output, rtol=error)
    context.set_context(mode=context.PYNATIVE_MODE)
    output = re_norm(Tensor(input_x))
    np.testing.assert_allclose(output.asnumpy(), benchmark_output, rtol=error)
def vmap_case():
    """Check that vmap over axis 0 of Renorm matches applying Renorm to each slice in a Python loop."""
    class Net(nn.Cell):
        def __init__(self, p, axis, maxnorm):
            super(Net, self).__init__()
            self.renorm = P.Renorm(p, axis, maxnorm)

        def construct(self, x):
            return self.renorm(x)

    class VmapNet(nn.Cell):
        def __init__(self, net, in_axes, out_axes):
            super(VmapNet, self).__init__()
            self.net = net
            self.in_axes = in_axes
            self.out_axes = out_axes

        def construct(self, x):
            return vmap(self.net, self.in_axes, self.out_axes)(x)

    @ms_function
    def for_net(input_x, p, axis, maxnorm):
        # split and concat along dimension 0
        # Iterate over the function's own argument rather than the enclosing scope's tensor.
        output = []
        for i in range(input_x.shape[0]):
            out = P.Renorm(p, axis, maxnorm)(input_x[i])
            output.append(out)
        return F.stack(output)

    x = Tensor(np.array([[[[4, 3, 2, 1], [8, 7, 6, 5], [12, 11, 10, 9]],
                          [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]],
                         [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                          [[4, 3, 2, 1], [8, 7, 6, 5], [12, 11, 10, 9]]]], dtype=np.float32))
    output = VmapNet(Net(1, 0, 10.0), 0, 0)(x)
    fornet_output = for_net(x, 1, 0, 10.0)
    np.testing.assert_allclose(output.asnumpy(), fornet_output.asnumpy(), rtol=1e-6)
def vmap_nested_case():
    """Check that a doubly nested vmap of Renorm matches applying Renorm to each x[i][j] in Python loops."""
    class Net(nn.Cell):
        def __init__(self, p, axis, maxnorm):
            super(Net, self).__init__()
            self.renorm = P.Renorm(p, axis, maxnorm)

        def construct(self, x):
            return self.renorm(x)

    class WrapNet(nn.Cell):
        def __init__(self, net, inin_axes, inout_axes, outin_axes, outout_axes):
            super(WrapNet, self).__init__()
            self.net = net
            self.ii = inin_axes
            self.io = inout_axes
            self.oi = outin_axes
            self.oo = outout_axes

        def construct(self, x):
            return vmap(vmap(self.net, self.ii, self.io), self.oi, self.oo)(x)

    @ms_function
    def for_net(input_x, p, axis, maxnorm):
        # split and concat along dimension 0 and 1
        # Iterate over the function's own argument rather than the enclosing scope's tensor.
        output = []
        for i in range(input_x.shape[0]):
            inner_output = []
            for j in range(input_x.shape[1]):
                out = P.Renorm(p, axis, maxnorm)(input_x[i][j])
                inner_output.append(out)
            output.append(F.stack(inner_output))
        return F.stack(output)

    x = Tensor(np.array([[[[4, 3, 2, 1], [8, 7, 6, 5], [12, 11, 10, 9]],
                          [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]],
                         [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                          [[4, 3, 2, 1], [8, 7, 6, 5], [12, 11, 10, 9]]]], dtype=np.float32))
    output = WrapNet(Net(1, 0, 10.0), 0, 0, 1, 1)(x)
    fornet_output = for_net(x, 1, 0, 10.0)
    np.testing.assert_allclose(output.asnumpy(), fornet_output.asnumpy(), rtol=1e-6)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_renorm_vmap_cpu():
    """
    Feature: test Renorm vmap on CPU.
    Description: 2x2x3x4 float32 input vmapped over axis 0, compared with a per-slice Python loop.
    Expectation: the result match with expect.
    """
    # Graph dumping (save_graphs) was a debugging leftover: it wrote files into ./rank0 and stayed
    # enabled in the global context for every test that ran afterwards.
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    vmap_case()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_renorm_vmap_cpu_nested():
    """
    Feature: test nested Renorm vmap on CPU.
    Description: 2x2x3x4 float32 input vmapped over axes 0 and 1 with nested vmap, compared with nested
                 per-slice Python loops.
    Expectation: the result match with expect.
    """
    # Graph dumping (save_graphs) was a debugging leftover: it wrote files into ./rank0 and stayed
    # enabled in the global context for every test that ran afterwards.
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    vmap_nested_case()