diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_hue_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_hue_cpu_kernel.cc new file mode 100644 index 00000000000..3186c01b062 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_hue_cpu_kernel.cc @@ -0,0 +1,310 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/adjust_hue_cpu_kernel.h" +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "utils/ms_utils.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kAdjustHueInputNum = 2; +constexpr size_t kAdjustHueOutputNum = 1; +const std::int64_t kAdjustHueParallelNum = 8 * 1024; +const std::int64_t kAdjustHueZero = 0; +const std::int64_t kAdjustHueOne = 1; +const std::int64_t kAdjustHueTwo = 2; +const std::int64_t kAdjustHueThree = 3; +const std::int64_t kAdjustHueFour = 4; +const std::int64_t kAdjustHueFive = 5; +} // namespace + +namespace detail { +static void rgb_to_hv_range(float r, float g, float b, float *h, float *v_min, float *v_max) { + float v_mid; + int h_category; + // According to the figures in: + // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma + // For the conditions, we don't care about the case where two components are + // equal. It is okay to count it in either side in that case. + if (r < g) { + if (b < r) { + // b < r < g + *v_max = g; + v_mid = r; + *v_min = b; + h_category = kAdjustHueOne; + } else if (b > g) { + // r < g < b + *v_max = b; + v_mid = g; + *v_min = r; + h_category = kAdjustHueThree; + } else { + // r < b < g + *v_max = g; + v_mid = b; + *v_min = r; + h_category = kAdjustHueTwo; + } + } else { + // g < r + if (b < g) { + // b < g < r + *v_max = r; + v_mid = g; + *v_min = b; + h_category = kAdjustHueZero; + } else if (b > r) { + // g < r < b + *v_max = b; + v_mid = r; + *v_min = g; + h_category = kAdjustHueFour; + } else { + // g < b < r + *v_max = r; + v_mid = b; + *v_min = g; + h_category = kAdjustHueFive; + } + } + if (*v_max == *v_min) { + *h = 0; + return; + } + auto ratio = (v_mid - *v_min) / (*v_max - *v_min); + bool increase = ((h_category & 0x1) == 0); + *h = h_category + (increase ? ratio : (1 - ratio)); +} + +// Helper function to convert from H-and-V-range to RGB. +template +static void hv_range_to_rgb(float h, float v_min, float v_max, T *r, T *g, T *b) { + int h_category = static_cast(h); + float ratio = h - h_category; + bool increase = ((h_category & 0x1) == 0); + if (!increase) { + ratio = 1 - ratio; + } + float v_mid = v_min + ratio * (v_max - v_min); + // According to the figures in: + // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma + switch (h_category) { + case kAdjustHueZero: + *r = static_cast(v_max); + *g = static_cast(v_mid); + *b = static_cast(v_min); + break; + case kAdjustHueOne: + *r = static_cast(v_mid); + *g = static_cast(v_max); + *b = static_cast(v_min); + break; + case kAdjustHueTwo: + *r = static_cast(v_min); + *g = static_cast(v_max); + *b = static_cast(v_mid); + break; + case kAdjustHueThree: + *r = static_cast(v_min); + *g = static_cast(v_mid); + *b = static_cast(v_max); + break; + case kAdjustHueFour: + *r = static_cast(v_mid); + *g = static_cast(v_min); + *b = static_cast(v_max); + break; + case kAdjustHueFive: + default: + *r = static_cast(v_max); + *g = static_cast(v_min); + *b = static_cast(v_mid); + } +} + +HsvTuple rgb2hsv(const float r, const float g, const float b) { + HsvTuple tuple; + const float M = fmaxf(r, fmaxf(g, b)); + const float m = fminf(r, fminf(g, b)); + const float chroma = M - m; + float h = 0.0f, s = 0.0f; + // hue + if (chroma > 0.0f) { + if (M == r) { + const float num = (g - b) / chroma; + const float sign = copysignf(1.0f, num); + h = ((sign < 0.0f) * 6.0f + sign * fmodf(sign * num, 6.0f)) / 6.0f; + } else if (M == g) { + h = ((b - r) / chroma + 2.0f) / 6.0f; + } else { + h = ((r - g) / chroma + 4.0f) / 6.0f; + } + } else { + h = 0.0f; + } + // saturation + if (M > 0) { + s = chroma / M; + } else { + s = 0.0f; + } + tuple.h = h; + tuple.s = s; + tuple.v = M; + return tuple; +} + +RgbTuple hsv2rgb(const float h, const float s, const float v) { + RgbTuple tuple; + const float new_h = h * 6.0f; + const float chroma = v * s; + const float x = chroma * (1.0f - fabsf(fmodf(new_h, 2.0f) - 1.0f)); + const float new_m = v - chroma; + const bool between_0_and_1 = new_h >= 0.0f && new_h < 1.0f; + const bool between_1_and_2 = new_h >= 1.0f && new_h < 2.0f; + const bool between_2_and_3 = new_h >= 2.0f && new_h < 3.0f; + const bool between_3_and_4 = new_h >= 3.0f && new_h < 4.0f; + const bool between_4_and_5 = new_h >= 4.0f && new_h < 5.0f; + const bool between_5_and_6 = new_h >= 5.0f && new_h < 6.0f; + tuple.r = chroma * (between_0_and_1 || between_5_and_6) + x * (between_1_and_2 || between_4_and_5) + new_m; + tuple.g = chroma * (between_1_and_2 || between_2_and_3) + x * (between_0_and_1 || between_3_and_4) + new_m; + tuple.b = chroma * (between_3_and_4 || between_4_and_5) + x * (between_2_and_3 || between_5_and_6) + new_m; + return tuple; +} + +template +bool LaunchAdjustHueKernel(const std::vector &inputs, + const std::vector &outputs) { + auto input_data = static_cast(inputs[0]->addr); + auto output_data = static_cast(outputs[0]->addr); + auto delta_h = static_cast(inputs[1]->addr)[0]; + std::int64_t num_elements = inputs[0]->size / sizeof(T); + constexpr int64_t kChannelSize = 3; + auto sharder_adjusthue = [input_data, delta_h, output_data, kChannelSize](int64_t start, int64_t end) { + for (int64_t i = start * kChannelSize; i < end * kChannelSize; i = i + kChannelSize) { + // CPU compute + float h, v_min, v_max; + rgb_to_hv_range(static_cast(*(input_data + i)), static_cast(*(input_data + i + 1)), + static_cast(*(input_data + i + 2)), &h, &v_min, &v_max); + + static const int kChannelRange = 6; + // Adjust the hue value. And adjust the hue back into the valid + // range of [0, 6). It is faster than a fmod by avoiding + // a float-point division since h is often very close to this + // range. + h += delta_h * kChannelRange; + while (h < 0) { + h += kChannelRange; + } + while (h >= kChannelRange) { + h -= kChannelRange; + } + + hv_range_to_rgb(h, v_min, v_max, &output_data[i], &output_data[i + 1], &output_data[i + 2]); + } + }; + std::int64_t total = num_elements / kChannelSize; + std::int64_t per_unit_size{total / std::min(kAdjustHueParallelNum - SizeToLong(kAdjustHueInputNum), total)}; + if (total > kAdjustHueParallelNum) { + CPUKernelUtils::ParallelFor(sharder_adjusthue, total, per_unit_size); + } else { + sharder_adjusthue(0, total); + } + return true; +} + +template +bool LaunchAdjustHueKernelHalf(const std::vector &inputs, + const std::vector &outputs) { + auto input_data = static_cast(inputs[0]->addr); + auto output_data = static_cast(outputs[0]->addr); + auto delta_h = static_cast(inputs[1]->addr)[0]; + std::int64_t num_elements = inputs[0]->size / sizeof(T); + constexpr int64_t kChannelSize = 3; + auto sharder_adjusthue = [input_data, delta_h, output_data, kChannelSize](int64_t start, int64_t end) { + for (int64_t i = start * kChannelSize; i < end * kChannelSize; i = i + kChannelSize) { + const HsvTuple hsv = rgb2hsv(static_cast(*(input_data + i)), static_cast(*(input_data + i + 1)), + static_cast(*(input_data + i + 2))); + float new_h = hsv.h; + float new_s = hsv.s; + float new_v = hsv.v; + // hue adjustment + new_h = fmodf(hsv.h + delta_h, 1.0f); + if (new_h < 0.0f) { + new_h = fmodf(1.0f + new_h, 1.0f); + } + const RgbTuple rgb = hsv2rgb(new_h, new_s, new_v); + output_data[i] = static_cast(rgb.r); + output_data[i + 1] = static_cast(rgb.g); + output_data[i + 2] = static_cast(rgb.b); + } + }; + std::int64_t total = num_elements / kChannelSize; + std::int64_t per_unit_size{total / std::min(kAdjustHueParallelNum - SizeToLong(kAdjustHueInputNum), total)}; + if (total > kAdjustHueParallelNum) { + CPUKernelUtils::ParallelFor(sharder_adjusthue, total, per_unit_size); + } else { + sharder_adjusthue(0, total); + } + return true; +} +} // namespace detail + +void AdjustHueCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector image_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); + if (image_shape != output_shape) { + MS_LOG(EXCEPTION) << "For AdjustHue, the data type of the input " << image_shape + << "need be the same as the output " << output_shape << "."; + } + size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); + CHECK_KERNEL_INPUTS_NUM(input_num, kAdjustHueInputNum, kernel_name_); + size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node); + CHECK_KERNEL_OUTPUTS_NUM(output_num, kAdjustHueOutputNum, kernel_name_); +} + +bool AdjustHueCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + switch (dtype_) { + case kNumberTypeFloat16: + detail::LaunchAdjustHueKernelHalf(inputs, outputs); + break; + case kNumberTypeFloat32: + detail::LaunchAdjustHueKernel(inputs, outputs); + break; + default: + MS_LOG(EXCEPTION) << "For AdjustHue, the type of 'image' should be float16, float32, but got " + << TypeIdLabel(dtype_) << "."; + return false; + } + return true; +} + +std::vector AdjustHueCpuKernelMod::GetOpSupport() { + static std::vector support_list = { + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}; + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdjustHue, AdjustHueCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_hue_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_hue_cpu_kernel.h new file mode 100644 index 00000000000..09b39adbfc3 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_hue_cpu_kernel.h @@ -0,0 +1,56 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADJUST_HUE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADJUST_HUE_CPU_KERNEL_H_ +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +struct HsvTuple { + float h; + float s; + float v; +}; +struct RgbTuple { + float r; + float g; + float b; +}; + +class AdjustHueCpuKernelMod : public DeprecatedNativeCpuKernelMod { + public: + AdjustHueCpuKernelMod() = default; + ~AdjustHueCpuKernelMod() override = default; + + void InitKernel(const CNodePtr &Kernel_node); + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override; + + private: + TypeId dtype_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADJUST_HUE_CPU_KERNEL_H_ diff --git a/mindspore/core/ops/adjust_hue.cc b/mindspore/core/ops/adjust_hue.cc new file mode 100644 index 00000000000..4a54505d1b1 --- /dev/null +++ b/mindspore/core/ops/adjust_hue.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/adjust_hue.h" +#include +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr AdjustHueInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + auto prim_name = primitive->name(); + auto input_shape_images = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + const int64_t min_dim = 3; + (void)CheckAndConvertUtils::CheckInteger("dimension of image", SizeToLong(input_shape_images.size()), kGreaterEqual, + min_dim, prim_name); + (void)CheckAndConvertUtils::CheckInteger("last dimension of image", input_shape_images[input_shape_images.size() - 1], + kEqual, min_dim, prim_name); + auto input_shape_delta = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + const int64_t delta_dim = 0; + (void)CheckAndConvertUtils::CheckInteger("dimension of delta", SizeToLong(input_shape_delta.size()), kEqual, + delta_dim, prim_name); + return std::make_shared(input_shape_images); +} + +TypePtr AdjustHueInferType(const PrimitivePtr &primitive, const std::vector &input_args) { + auto prim_name = primitive->name(); + auto input_type_images = input_args[0]->BuildType(); + MS_EXCEPTION_IF_NULL(input_type_images); + const std::set valid_types = {kFloat16, kFloat32}; + (void)CheckAndConvertUtils::CheckTensorTypeValid("images", input_type_images, valid_types, prim_name); + auto input_type_delta = input_args[1]->BuildType(); + MS_EXCEPTION_IF_NULL(input_type_delta); + (void)CheckAndConvertUtils::CheckTensorTypeValid("delta", input_type_delta, {kFloat32}, prim_name); + return input_type_images; +} +} // namespace + +MIND_API_BASE_IMPL(AdjustHue, PrimitiveC, BaseOperator); +AbstractBasePtr AdjustHueInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + const int64_t kInputsNum = 2; + (void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name()); + auto infer_type = AdjustHueInferType(primitive, input_args); + auto infer_shape = AdjustHueInferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); +} +REGISTER_PRIMITIVE_EVAL_IMPL(AdjustHue, prim::kPrimAdjustHue, AdjustHueInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/adjust_hue.h b/mindspore/core/ops/adjust_hue.h new file mode 100644 index 00000000000..1897a7a4813 --- /dev/null +++ b/mindspore/core/ops/adjust_hue.h @@ -0,0 +1,44 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_ADJUST_HUE_H_ +#define MINDSPORE_CORE_OPS_ADJUST_HUE_H_ +#include +#include +#include +#include "ops/primitive_c.h" +#include "ops/op_utils.h" +#include "ops/base_operator.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameAdjustHue = "AdjustHue"; +/// \brief Adjust hue of RGB images. +/// Refer to Python API @ref mindspore.ops.AdjustHue for more details. +class MIND_API AdjustHue : public BaseOperator { + public: + MIND_API_BASE_MEMBER(AdjustHue); + AdjustHue() : BaseOperator(kNameAdjustHue) { InitIOName({"images", "delta"}, {"y"}); } +}; + +abstract::AbstractBasePtr AdjustHueInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimAdjustHuePtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_ADJUST_HUE_H_ diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h index 0c212865e6c..37788417477 100644 --- a/mindspore/core/ops/core_ops.h +++ b/mindspore/core/ops/core_ops.h @@ -32,6 +32,7 @@ GVAR_DEF(ValuePtr, kValueOne, std::make_shared(1)); GVAR_DEF(mindspore::HashMap, kSideEffectPropagate, {{mindspore::GRAPH_FLAG_SIDE_EFFECT_PROPAGATE COMMA kValueOne}}); #undef COMMA +constexpr auto kAdjustHue = "AdjustHue"; constexpr auto kGetNext = "GetNext"; constexpr auto kGather = "Gather"; constexpr auto kAddcdiv = "Addcdiv"; @@ -814,6 +815,7 @@ GVAR_DEF(PrimitivePtr, kPrimZeta, std::make_shared("Zeta")); // Image GVAR_DEF(PrimitivePtr, kPrimNonMaxSuppressionV3, std::make_shared("NonMaxSuppressionV3")); +GVAR_DEF(PrimitivePtr, kPrimAdjustHue, std::make_shared(kAdjustHue)); // Statements GVAR_DEF(PrimitivePtr, kPrimReturn, std::make_shared("Return")); diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py index 816f82f06ad..839b49f257a 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py @@ -164,3 +164,4 @@ from .transpose import _transpose_aicpu from .trace import _trace_aicpu from .tracegrad import _tracegrad_aicpu from .zeta import _zeta_aicpu +from .adjust_hue import _adjust_hue_aicpu diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/adjust_hue.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/adjust_hue.py new file mode 100644 index 00000000000..779884a35cb --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/adjust_hue.py @@ -0,0 +1,31 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""AdjustHue op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType +adjust_hue_op_info = AiCPURegOp("AdjustHue") \ + .fusion_type("OPAQUE") \ + .input(0, "images", "required") \ + .input(1, "delta", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(adjust_hue_op_info) +def _adjust_hue_aicpu(): + """AdjustHue AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/operations/image_ops.py b/mindspore/python/mindspore/ops/operations/image_ops.py index e470ef4931c..77453545fab 100644 --- a/mindspore/python/mindspore/ops/operations/image_ops.py +++ b/mindspore/python/mindspore/ops/operations/image_ops.py @@ -1,4 +1,4 @@ -# Copyright 2020-2021 Huawei Technologies Co., Ltd +# Copyright 2020-2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,6 +21,63 @@ from ...common import dtype as mstype from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive +class AdjustHue(Primitive): + """ + Adjust hue of RGB images. + + Note: + This is a convenience method that converts an RGB image to float + representation, converts it to HSV, adds an offset to the + hue channel, converts back to RGB and then back to the original + data type. If several adjustments are chained it is advisable to minimize + the number of redundant conversions. + + Inputs: + - **image** (Tensor): RGB image or images. The size of the last dimension must be 3. + the dtype is float16 or float32. At least 3-D. + - **delta** (Tensor): How much to add to the hue channel, the dtype is float32. Must be 0-D. + + Output: + Adjusted image(s), same shape and dtype as `image`. + + Raises: + TypeError: If neither `image` nor `delta` is a tensor. + TypeError: If the dtype of image not float16 or float32. + TypeError: If the dtype of delta not float32. + ValueError: If image have at less than 3 dimensions. + + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> class AdjustHue(nn.Cell): + ... def __init__(self): + ... super(AdjustHue, self).__init__() + ... self.adjustHue = P.AdjustHue() + ... def construct(self, image, delta): + ... return self.adjustHue(image, delta) + ... + >>> image = np.array([[[1, 2, 3], [4, 5, 6]], + ... [[7, 8, 9], [10, 11, 12]], + ... [[13, 14, 15], [16, 17, 18]]]).astype(np.float32) + >>> delta = 0.2 + >>> adjust_hue = AdjustHue() + >>> output = adjust_hue(Tensor(image), Tensor(delta)) + >>> print("output", output) + output [[[ 2.3999996 1. 3. ] + [ 5.3999996 4. 6. ]] + [[ 8.4 7. 9. ] + [11.4 10. 12. ]] + [[14.4 13. 15. ] + [17.4 16. 18. ]]] + """ + + @prim_attr_register + def __init__(self): + """Initialize AdjustHue""" + self.init_prim_io_names(inputs=['images', 'delta'], outputs=['y']) + + class CropAndResize(PrimitiveWithInfer): """ Extracts crops from the input image tensor and resizes them. diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 822de9eb466..ff52dd6d6c6 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -26,7 +26,7 @@ from mindspore import ms_function from mindspore.common import dtype as mstype from mindspore.ops import functional as F from mindspore.ops import operations as P -from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes +from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes, AdjustHue from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _inner_ops as inner from mindspore.ops.operations import _quant_ops as Q @@ -3071,6 +3071,16 @@ test_case_array_ops = [ ] test_case_image_ops = [ + ('AdjustHue', { + 'block': AdjustHue(), + 'desc_inputs': [Tensor(np.array([[[1, 2, 3], + [4, 5, 6]], + [[7, 8, 9], + [10, 11, 12]], + [[13, 14, 15], + [16, 17, 18]]]).astype(np.float32)), + Tensor(0.2, mstype.float32)], + 'skip': ['backward']}), ('NonMaxSuppressionV3', { 'block': P.NonMaxSuppressionV3(), 'desc_inputs': [Tensor(np.array([[20, 5, 200, 100],