!32951 [feat] [assistant] [I48O56] Add new aicpu operator AdjustSaturation

Merge pull request !32951 from 陈慧敏/AdjustSaturation
This commit is contained in:
i-robot 2022-05-18 07:10:53 +00:00 committed by Gitee
commit d21686c2fe
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 439 additions and 1 deletions

View File

@ -0,0 +1,181 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/adjust_saturation_cpu_kernel.h"
#include <Eigen/Dense>
#include <algorithm>
#include <iostream>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
namespace {
const std::int64_t kAdjustSaturationParallelNum = 64 * 1024;
const std::int64_t kAdjustSaturationZero = 0;
const std::int64_t kAdjustSaturationOne = 1;
const std::int64_t kAdjustSaturationTwo = 2;
const std::int64_t kAdjustSaturationThree = 3;
const std::int64_t kAdjustSaturationFour = 4;
const std::int64_t kAdjustSaturationFive = 5;
const std::float_t kAdjustSaturationSix = 6;
} // namespace
namespace detail {
static void rgb_to_hsv(float r, float g, float b, float *h, float *s, float *v) {
float vv = std::max(r, std::max(g, b));
float range = vv - std::min(r, std::min(g, b));
if (vv > 0) {
*s = range / vv;
} else {
*s = 0;
}
float norm = kAdjustSaturationOne / (kAdjustSaturationSix * range);
float hh;
if (r == vv) {
hh = norm * (g - b);
} else if (g == vv) {
hh = norm * (b - r) + kAdjustSaturationTwo / kAdjustSaturationSix;
} else {
hh = norm * (r - g) + kAdjustSaturationFour / kAdjustSaturationSix;
}
if (range <= 0.0) {
hh = 0;
}
if (hh < 0.0) {
hh = hh + kAdjustSaturationOne;
}
*v = vv;
*h = hh;
}
template <typename T>
static void hsv_to_rgb(float h, float s, float v, T *r, T *g, T *b) {
float c = s * v;
float m = v - c;
float dh = h * kAdjustSaturationSix;
float rr, gg, bb;
int h_category = static_cast<int>(dh);
float fmodu = dh;
while (fmodu <= 0) {
fmodu += kAdjustSaturationTwo;
}
while (fmodu >= kAdjustSaturationTwo) {
fmodu -= kAdjustSaturationTwo;
}
float x = c * (1 - std::abs(fmodu - 1));
switch (h_category) {
case kAdjustSaturationZero:
rr = c;
gg = x;
bb = 0;
break;
case kAdjustSaturationOne:
rr = x;
gg = c;
bb = 0;
break;
case kAdjustSaturationTwo:
rr = 0;
gg = c;
bb = x;
break;
case kAdjustSaturationThree:
rr = 0;
gg = x;
bb = c;
break;
case kAdjustSaturationFour:
rr = x;
gg = 0;
bb = c;
break;
case kAdjustSaturationFive:
rr = c;
gg = 0;
bb = x;
break;
default:
rr = 0;
gg = 0;
bb = 0;
}
*r = static_cast<T>(rr + m);
*g = static_cast<T>(gg + m);
*b = static_cast<T>(bb + m);
}
template <typename T>
bool LaunchAdjustSaturationKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
auto input{static_cast<T *>(inputs[0]->addr)};
auto scale{static_cast<std::float_t *>(inputs[1]->addr)};
auto output{static_cast<T *>(outputs[0]->addr)};
constexpr int64_t kChannelSize = 3;
std::int64_t num_elements = inputs[0]->size / sizeof(T);
auto sharder_adjustsaturation = [input, scale, output, kChannelSize](int64_t start, int64_t end) {
for (int64_t i = start * kChannelSize; i < end * kChannelSize; i = i + kChannelSize) {
float h, s, v;
// Convert the RGB color to Hue/V-range.
rgb_to_hsv(static_cast<float>(*(input + i)), static_cast<float>(*(input + i + 1)),
static_cast<float>(*(input + i + 2)), &h, &s, &v);
s = std::min(1.0f, std::max(0.0f, s * scale[0]));
// Convert the hue and v-range back into RGB.
hsv_to_rgb<T>(h, s, v, &output[i], &output[i + 1], &output[i + 2]);
}
};
std::int64_t total = num_elements / kChannelSize;
if (total > kAdjustSaturationParallelNum) {
std::int64_t per_unit_size =
total / std::min(kAdjustSaturationParallelNum - SizeToLong(kAdjustSaturationTwo), total);
CPUKernelUtils::ParallelFor(sharder_adjustsaturation, total, per_unit_size);
} else {
sharder_adjustsaturation(0, total);
}
return true;
}
} // namespace detail
void AdjustSaturationCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
CHECK_KERNEL_INPUTS_NUM(input_num, kAdjustSaturationTwo, kernel_name_);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kAdjustSaturationOne, kernel_name_);
input_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
}
bool AdjustSaturationCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
if (input_type_ == kNumberTypeFloat32) {
return detail::LaunchAdjustSaturationKernel<float>(inputs, outputs);
} else if (input_type_ == kNumberTypeFloat16) {
return detail::LaunchAdjustSaturationKernel<Eigen::half>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', unsupported input data type " << TypeIdLabel(input_type_);
return false;
}
return true;
}
std::vector<KernelAttr> AdjustSaturationCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdjustSaturation, AdjustSaturationCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADJUST_SATURATION_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADJUST_SATURATION_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class AdjustSaturationCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
AdjustSaturationCpuKernelMod() = default;
~AdjustSaturationCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node);
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
private:
TypeId input_type_{kTypeUnknown};
protected:
std::vector<KernelAttr> GetOpSupport() override;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADJUST_SATURATION_CPU_KERNEL_H_

View File

@ -0,0 +1,76 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/adjust_saturation.h"
#include <memory>
#include <vector>
#include <set>
#include <algorithm>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr AdjustSaturationInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto input_image_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto input_scale_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
const int64_t min_image_dim = 3;
const int64_t scale_dim = 0;
(void)CheckAndConvertUtils::CheckInteger("dimension of AdjustSaturation input image",
SizeToLong(input_image_shape.size()), kGreaterEqual, min_image_dim,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("last dimension of AdjustSaturation input image",
input_image_shape[input_image_shape.size() - 1], kEqual, min_image_dim,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("dimension of AdjustSaturation input scale",
SizeToLong(input_scale_shape.size()), kEqual, scale_dim, prim_name);
return std::make_shared<abstract::Shape>(input_image_shape);
}
TypePtr AdjustSaturationInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = prim->name();
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
auto input_images_type = input_args[0]->BuildType();
auto input_scale_type = input_args[1]->BuildType();
MS_EXCEPTION_IF_NULL(input_images_type);
MS_EXCEPTION_IF_NULL(input_scale_type);
const std::set<TypePtr> valid_images_types = {kFloat16, kFloat32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("image", input_images_type, valid_images_types, prim_name);
const std::set<TypePtr> valid_scale_types = {kFloat32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("scale", input_scale_type, valid_scale_types, prim_name);
return input_images_type;
}
} // namespace
MIND_API_OPERATOR_IMPL(AdjustSaturation, BaseOperator);
AbstractBasePtr AdjustSaturationInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputsNum = 2;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
auto infer_type = AdjustSaturationInferType(primitive, input_args);
auto infer_shape = AdjustSaturationInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(AdjustSaturation, prim::kPrimAdjustSaturation, AdjustSaturationInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_ADJUST_SATURATION_H_
#define MINDSPORE_CORE_OPS_ADJUST_SATURATION_H_
#include <memory>
#include <vector>
#include <string>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "ops/base_operator.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAdjustSaturation = "AdjustSaturation";
/// \brief Convert the images to HSV and multiply the saturation (S) channel by `scale` and clipping.
/// Refer to Python API @ref mindspore.ops.AdjustSaturation for more details.
class MIND_API AdjustSaturation : public BaseOperator {
public:
MIND_API_BASE_MEMBER(AdjustSaturation);
/// \brief Constructor.
AdjustSaturation() : BaseOperator(kNameAdjustSaturation) { InitIOName({"image", "scale"}, {"y"}); }
};
abstract::AbstractBasePtr AdjustSaturationInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimAdjustSaturationPtr = std::shared_ptr<AdjustSaturation>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_ADJUST_SATURATION_H_

View File

@ -34,6 +34,7 @@ GVAR_DEF(mindspore::HashMap<std::string COMMA ValuePtr>, kSideEffectPropagate,
#undef COMMA
constexpr auto kAdjustHue = "AdjustHue";
constexpr auto kAdjustContrastv2 = "AdjustContrastv2";
constexpr auto kAdjustSaturation = "AdjustSaturation";
constexpr auto kGetNext = "GetNext";
constexpr auto kGather = "Gather";
constexpr auto kAddcdiv = "Addcdiv";
@ -841,6 +842,7 @@ GVAR_DEF(PrimitivePtr, kPrimSvd, std::make_shared<Primitive>("Svd"));
GVAR_DEF(PrimitivePtr, kPrimNonMaxSuppressionV3, std::make_shared<Primitive>("NonMaxSuppressionV3"));
GVAR_DEF(PrimitivePtr, kPrimAdjustHue, std::make_shared<Primitive>(kAdjustHue));
GVAR_DEF(PrimitivePtr, kPrimAdjustContrastv2, std::make_shared<Primitive>(kAdjustContrastv2));
GVAR_DEF(PrimitivePtr, kPrimAdjustSaturation, std::make_shared<Primitive>(kAdjustSaturation));
// Statements
GVAR_DEF(PrimitivePtr, kPrimReturn, std::make_shared<Primitive>("Return"));

View File

@ -178,3 +178,4 @@ from .adjust_hue import _adjust_hue_aicpu
from .maxpool_v1 import _maxpool_v1_aicpu
from .maxpool_grad_v1 import _maxpool_grad_v1_aicpu
from .dense_to_csr_sparse_matrix import _dense_to_csr_sparse_matrix_aicpu
from .adjust_saturation import _adjust_saturation_aicpu

View File

@ -0,0 +1,32 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AdjustSaturation op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
adjust_saturation_op_info = AiCPURegOp("AdjustSaturation") \
.fusion_type("OPAQUE") \
.input(0, "image", "required") \
.input(1, "scale", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(adjust_saturation_op_info)
def _adjust_saturation_aicpu():
"""AdjustSaturation aicpu register"""
return

View File

@ -21,6 +21,53 @@ from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive
class AdjustSaturation(Primitive):
"""
Adjust saturation of RGB images.
Note:
This is a convenience method that converts RGB images to float representation, converts them to HSV,
adds an offset to the saturation channel, converts back to RGB and then back to the original data type.
If several adjustments are chained it is advisable to minimize the number of redundant conversions.
inputs:
- **image** (Tensor): Images to adjust. Must be one of the following types: float16, float32.
At least 3-D.The last dimension is interpreted as channels, and must be three.
- **scale** (Tensor): A float scale to add to the saturation. A Tensor of type float32. Must be 0-D.
Output:
Adjusted image(s), same shape and dtype as `image`.
Raises:
TypeError: If any iput is not Tensor.
TypeError: If the type of `image` is not one of the following dtype: float16, float32.
TypeError: If the type of `scale` is not float32.
ValueError: If the dimension of the 'image' is less than 3, or the last dimension of the 'image' is not 3.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> x = Tensor([[[1.0, 2.0, 3.0],
... [4.0, 5.0, 6.0]],
... [[7.0, 8.0, 9.0],
... [10.0, 11.0, 12.0]]])
>>> scale = Tensor(float(0.5))
>>> adjustsaturation = AdjustSaturation()
>>> output = adjustsaturation(x, scale)
>>> print(output)
[[[ 2. 2.4999998 3. ]
[ 5. 5.5 6. ]]
[[ 8. 8.5 9. ]
[11. 11.5 12. ]]]
"""
@prim_attr_register
def __init__(self):
"""Initialize AdjustSaturation"""
self.init_prim_io_names(inputs=['images', 'scale'], outputs=['y'])
class AdjustContrastv2(Primitive):
"""
Adjust contrastv2 of images.

View File

@ -26,7 +26,8 @@ from mindspore import ms_function
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes, AdjustHue, AdjustContrastv2
from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes, AdjustHue, AdjustContrastv2, \
AdjustSaturation
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations import _quant_ops as Q
@ -3106,6 +3107,14 @@ test_case_image_ops = [
[10.0, 11.0, 12.0]]]),
Tensor(0.5, mstype.float32)],
'skip': ['backward']}),
('AdjustSaturation', {
'block': AdjustSaturation(),
'desc_inputs': [Tensor([[[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]],
[[7.0, 8.0, 9.0],
[10.0, 11.0, 12.0]]]),
Tensor(0.5, mstype.float32)],
'skip': ['backward']}),
('NonMaxSuppressionV3', {
'block': P.NonMaxSuppressionV3(),
'desc_inputs': [Tensor(np.array([[20, 5, 200, 100],