!33338 [assistant][ops][I48O5W] Add new AdjustHue operator

Merge pull request !33338 from 邹天宇/AdjustHue
This commit is contained in:
i-robot 2022-05-11 01:51:49 +00:00 committed by Gitee
commit 3aca62f878
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 582 additions and 2 deletions

View File

@ -0,0 +1,310 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/adjust_hue_cpu_kernel.h"
#include <Eigen/Dense>
#include <algorithm>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kAdjustHueInputNum = 2;
constexpr size_t kAdjustHueOutputNum = 1;
const std::int64_t kAdjustHueParallelNum = 8 * 1024;
const std::int64_t kAdjustHueZero = 0;
const std::int64_t kAdjustHueOne = 1;
const std::int64_t kAdjustHueTwo = 2;
const std::int64_t kAdjustHueThree = 3;
const std::int64_t kAdjustHueFour = 4;
const std::int64_t kAdjustHueFive = 5;
} // namespace
namespace detail {
static void rgb_to_hv_range(float r, float g, float b, float *h, float *v_min, float *v_max) {
float v_mid;
int h_category;
// According to the figures in:
// https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma
// For the conditions, we don't care about the case where two components are
// equal. It is okay to count it in either side in that case.
if (r < g) {
if (b < r) {
// b < r < g
*v_max = g;
v_mid = r;
*v_min = b;
h_category = kAdjustHueOne;
} else if (b > g) {
// r < g < b
*v_max = b;
v_mid = g;
*v_min = r;
h_category = kAdjustHueThree;
} else {
// r < b < g
*v_max = g;
v_mid = b;
*v_min = r;
h_category = kAdjustHueTwo;
}
} else {
// g < r
if (b < g) {
// b < g < r
*v_max = r;
v_mid = g;
*v_min = b;
h_category = kAdjustHueZero;
} else if (b > r) {
// g < r < b
*v_max = b;
v_mid = r;
*v_min = g;
h_category = kAdjustHueFour;
} else {
// g < b < r
*v_max = r;
v_mid = b;
*v_min = g;
h_category = kAdjustHueFive;
}
}
if (*v_max == *v_min) {
*h = 0;
return;
}
auto ratio = (v_mid - *v_min) / (*v_max - *v_min);
bool increase = ((h_category & 0x1) == 0);
*h = h_category + (increase ? ratio : (1 - ratio));
}
// Helper function to convert from H-and-V-range to RGB.
template <typename T>
static void hv_range_to_rgb(float h, float v_min, float v_max, T *r, T *g, T *b) {
int h_category = static_cast<int>(h);
float ratio = h - h_category;
bool increase = ((h_category & 0x1) == 0);
if (!increase) {
ratio = 1 - ratio;
}
float v_mid = v_min + ratio * (v_max - v_min);
// According to the figures in:
// https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma
switch (h_category) {
case kAdjustHueZero:
*r = static_cast<T>(v_max);
*g = static_cast<T>(v_mid);
*b = static_cast<T>(v_min);
break;
case kAdjustHueOne:
*r = static_cast<T>(v_mid);
*g = static_cast<T>(v_max);
*b = static_cast<T>(v_min);
break;
case kAdjustHueTwo:
*r = static_cast<T>(v_min);
*g = static_cast<T>(v_max);
*b = static_cast<T>(v_mid);
break;
case kAdjustHueThree:
*r = static_cast<T>(v_min);
*g = static_cast<T>(v_mid);
*b = static_cast<T>(v_max);
break;
case kAdjustHueFour:
*r = static_cast<T>(v_mid);
*g = static_cast<T>(v_min);
*b = static_cast<T>(v_max);
break;
case kAdjustHueFive:
default:
*r = static_cast<T>(v_max);
*g = static_cast<T>(v_min);
*b = static_cast<T>(v_mid);
}
}
HsvTuple rgb2hsv(const float r, const float g, const float b) {
HsvTuple tuple;
const float M = fmaxf(r, fmaxf(g, b));
const float m = fminf(r, fminf(g, b));
const float chroma = M - m;
float h = 0.0f, s = 0.0f;
// hue
if (chroma > 0.0f) {
if (M == r) {
const float num = (g - b) / chroma;
const float sign = copysignf(1.0f, num);
h = ((sign < 0.0f) * 6.0f + sign * fmodf(sign * num, 6.0f)) / 6.0f;
} else if (M == g) {
h = ((b - r) / chroma + 2.0f) / 6.0f;
} else {
h = ((r - g) / chroma + 4.0f) / 6.0f;
}
} else {
h = 0.0f;
}
// saturation
if (M > 0) {
s = chroma / M;
} else {
s = 0.0f;
}
tuple.h = h;
tuple.s = s;
tuple.v = M;
return tuple;
}
RgbTuple hsv2rgb(const float h, const float s, const float v) {
RgbTuple tuple;
const float new_h = h * 6.0f;
const float chroma = v * s;
const float x = chroma * (1.0f - fabsf(fmodf(new_h, 2.0f) - 1.0f));
const float new_m = v - chroma;
const bool between_0_and_1 = new_h >= 0.0f && new_h < 1.0f;
const bool between_1_and_2 = new_h >= 1.0f && new_h < 2.0f;
const bool between_2_and_3 = new_h >= 2.0f && new_h < 3.0f;
const bool between_3_and_4 = new_h >= 3.0f && new_h < 4.0f;
const bool between_4_and_5 = new_h >= 4.0f && new_h < 5.0f;
const bool between_5_and_6 = new_h >= 5.0f && new_h < 6.0f;
tuple.r = chroma * (between_0_and_1 || between_5_and_6) + x * (between_1_and_2 || between_4_and_5) + new_m;
tuple.g = chroma * (between_1_and_2 || between_2_and_3) + x * (between_0_and_1 || between_3_and_4) + new_m;
tuple.b = chroma * (between_3_and_4 || between_4_and_5) + x * (between_2_and_3 || between_5_and_6) + new_m;
return tuple;
}
template <typename T>
bool LaunchAdjustHueKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_data = static_cast<T *>(inputs[0]->addr);
auto output_data = static_cast<T *>(outputs[0]->addr);
auto delta_h = static_cast<std::float_t *>(inputs[1]->addr)[0];
std::int64_t num_elements = inputs[0]->size / sizeof(T);
constexpr int64_t kChannelSize = 3;
auto sharder_adjusthue = [input_data, delta_h, output_data, kChannelSize](int64_t start, int64_t end) {
for (int64_t i = start * kChannelSize; i < end * kChannelSize; i = i + kChannelSize) {
// CPU compute
float h, v_min, v_max;
rgb_to_hv_range(static_cast<float>(*(input_data + i)), static_cast<float>(*(input_data + i + 1)),
static_cast<float>(*(input_data + i + 2)), &h, &v_min, &v_max);
static const int kChannelRange = 6;
// Adjust the hue value. And adjust the hue back into the valid
// range of [0, 6). It is faster than a fmod by avoiding
// a float-point division since h is often very close to this
// range.
h += delta_h * kChannelRange;
while (h < 0) {
h += kChannelRange;
}
while (h >= kChannelRange) {
h -= kChannelRange;
}
hv_range_to_rgb<T>(h, v_min, v_max, &output_data[i], &output_data[i + 1], &output_data[i + 2]);
}
};
std::int64_t total = num_elements / kChannelSize;
std::int64_t per_unit_size{total / std::min(kAdjustHueParallelNum - SizeToLong(kAdjustHueInputNum), total)};
if (total > kAdjustHueParallelNum) {
CPUKernelUtils::ParallelFor(sharder_adjusthue, total, per_unit_size);
} else {
sharder_adjusthue(0, total);
}
return true;
}
template <typename T>
bool LaunchAdjustHueKernelHalf(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_data = static_cast<T *>(inputs[0]->addr);
auto output_data = static_cast<T *>(outputs[0]->addr);
auto delta_h = static_cast<std::float_t *>(inputs[1]->addr)[0];
std::int64_t num_elements = inputs[0]->size / sizeof(T);
constexpr int64_t kChannelSize = 3;
auto sharder_adjusthue = [input_data, delta_h, output_data, kChannelSize](int64_t start, int64_t end) {
for (int64_t i = start * kChannelSize; i < end * kChannelSize; i = i + kChannelSize) {
const HsvTuple hsv = rgb2hsv(static_cast<float>(*(input_data + i)), static_cast<float>(*(input_data + i + 1)),
static_cast<float>(*(input_data + i + 2)));
float new_h = hsv.h;
float new_s = hsv.s;
float new_v = hsv.v;
// hue adjustment
new_h = fmodf(hsv.h + delta_h, 1.0f);
if (new_h < 0.0f) {
new_h = fmodf(1.0f + new_h, 1.0f);
}
const RgbTuple rgb = hsv2rgb(new_h, new_s, new_v);
output_data[i] = static_cast<T>(rgb.r);
output_data[i + 1] = static_cast<T>(rgb.g);
output_data[i + 2] = static_cast<T>(rgb.b);
}
};
std::int64_t total = num_elements / kChannelSize;
std::int64_t per_unit_size{total / std::min(kAdjustHueParallelNum - SizeToLong(kAdjustHueInputNum), total)};
if (total > kAdjustHueParallelNum) {
CPUKernelUtils::ParallelFor(sharder_adjusthue, total, per_unit_size);
} else {
sharder_adjusthue(0, total);
}
return true;
}
} // namespace detail
void AdjustHueCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> image_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
if (image_shape != output_shape) {
MS_LOG(EXCEPTION) << "For AdjustHue, the data type of the input " << image_shape
<< "need be the same as the output " << output_shape << ".";
}
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
CHECK_KERNEL_INPUTS_NUM(input_num, kAdjustHueInputNum, kernel_name_);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kAdjustHueOutputNum, kernel_name_);
}
bool AdjustHueCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
switch (dtype_) {
case kNumberTypeFloat16:
detail::LaunchAdjustHueKernelHalf<Eigen::half>(inputs, outputs);
break;
case kNumberTypeFloat32:
detail::LaunchAdjustHueKernel<float>(inputs, outputs);
break;
default:
MS_LOG(EXCEPTION) << "For AdjustHue, the type of 'image' should be float16, float32, but got "
<< TypeIdLabel(dtype_) << ".";
return false;
}
return true;
}
std::vector<KernelAttr> AdjustHueCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdjustHue, AdjustHueCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADJUST_HUE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADJUST_HUE_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
struct HsvTuple {
float h;
float s;
float v;
};
struct RgbTuple {
float r;
float g;
float b;
};
class AdjustHueCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
AdjustHueCpuKernelMod() = default;
~AdjustHueCpuKernelMod() override = default;
void InitKernel(const CNodePtr &Kernel_node);
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
TypeId dtype_{kTypeUnknown};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADJUST_HUE_CPU_KERNEL_H_

View File

@ -0,0 +1,69 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/adjust_hue.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr AdjustHueInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto input_shape_images = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
const int64_t min_dim = 3;
(void)CheckAndConvertUtils::CheckInteger("dimension of image", SizeToLong(input_shape_images.size()), kGreaterEqual,
min_dim, prim_name);
(void)CheckAndConvertUtils::CheckInteger("last dimension of image", input_shape_images[input_shape_images.size() - 1],
kEqual, min_dim, prim_name);
auto input_shape_delta = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
const int64_t delta_dim = 0;
(void)CheckAndConvertUtils::CheckInteger("dimension of delta", SizeToLong(input_shape_delta.size()), kEqual,
delta_dim, prim_name);
return std::make_shared<abstract::Shape>(input_shape_images);
}
TypePtr AdjustHueInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto input_type_images = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(input_type_images);
const std::set valid_types = {kFloat16, kFloat32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("images", input_type_images, valid_types, prim_name);
auto input_type_delta = input_args[1]->BuildType();
MS_EXCEPTION_IF_NULL(input_type_delta);
(void)CheckAndConvertUtils::CheckTensorTypeValid("delta", input_type_delta, {kFloat32}, prim_name);
return input_type_images;
}
} // namespace
MIND_API_BASE_IMPL(AdjustHue, PrimitiveC, BaseOperator);
AbstractBasePtr AdjustHueInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputsNum = 2;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
auto infer_type = AdjustHueInferType(primitive, input_args);
auto infer_shape = AdjustHueInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(AdjustHue, prim::kPrimAdjustHue, AdjustHueInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_ADJUST_HUE_H_
#define MINDSPORE_CORE_OPS_ADJUST_HUE_H_
#include <memory>
#include <vector>
#include <string>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "ops/base_operator.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAdjustHue = "AdjustHue";
/// \brief Adjust hue of RGB images.
/// Refer to Python API @ref mindspore.ops.AdjustHue for more details.
class MIND_API AdjustHue : public BaseOperator {
public:
MIND_API_BASE_MEMBER(AdjustHue);
AdjustHue() : BaseOperator(kNameAdjustHue) { InitIOName({"images", "delta"}, {"y"}); }
};
abstract::AbstractBasePtr AdjustHueInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimAdjustHuePtr = std::shared_ptr<AdjustHue>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_ADJUST_HUE_H_

View File

@ -32,6 +32,7 @@ GVAR_DEF(ValuePtr, kValueOne, std::make_shared<Int64Imm>(1));
GVAR_DEF(mindspore::HashMap<std::string COMMA ValuePtr>, kSideEffectPropagate,
{{mindspore::GRAPH_FLAG_SIDE_EFFECT_PROPAGATE COMMA kValueOne}});
#undef COMMA
constexpr auto kAdjustHue = "AdjustHue";
constexpr auto kGetNext = "GetNext";
constexpr auto kGather = "Gather";
constexpr auto kAddcdiv = "Addcdiv";
@ -814,6 +815,7 @@ GVAR_DEF(PrimitivePtr, kPrimZeta, std::make_shared<Primitive>("Zeta"));
// Image
GVAR_DEF(PrimitivePtr, kPrimNonMaxSuppressionV3, std::make_shared<Primitive>("NonMaxSuppressionV3"));
GVAR_DEF(PrimitivePtr, kPrimAdjustHue, std::make_shared<Primitive>(kAdjustHue));
// Statements
GVAR_DEF(PrimitivePtr, kPrimReturn, std::make_shared<Primitive>("Return"));

View File

@ -164,3 +164,4 @@ from .transpose import _transpose_aicpu
from .trace import _trace_aicpu
from .tracegrad import _tracegrad_aicpu
from .zeta import _zeta_aicpu
from .adjust_hue import _adjust_hue_aicpu

View File

@ -0,0 +1,31 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AdjustHue op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
adjust_hue_op_info = AiCPURegOp("AdjustHue") \
.fusion_type("OPAQUE") \
.input(0, "images", "required") \
.input(1, "delta", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(adjust_hue_op_info)
def _adjust_hue_aicpu():
"""AdjustHue AiCPU register"""
return

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -21,6 +21,63 @@ from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive
class AdjustHue(Primitive):
"""
Adjust hue of RGB images.
Note:
This is a convenience method that converts an RGB image to float
representation, converts it to HSV, adds an offset to the
hue channel, converts back to RGB and then back to the original
data type. If several adjustments are chained it is advisable to minimize
the number of redundant conversions.
Inputs:
- **image** (Tensor): RGB image or images. The size of the last dimension must be 3.
the dtype is float16 or float32. At least 3-D.
- **delta** (Tensor): How much to add to the hue channel, the dtype is float32. Must be 0-D.
Output:
Adjusted image(s), same shape and dtype as `image`.
Raises:
TypeError: If neither `image` nor `delta` is a tensor.
TypeError: If the dtype of image not float16 or float32.
TypeError: If the dtype of delta not float32.
ValueError: If image have at less than 3 dimensions.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> class AdjustHue(nn.Cell):
... def __init__(self):
... super(AdjustHue, self).__init__()
... self.adjustHue = P.AdjustHue()
... def construct(self, image, delta):
... return self.adjustHue(image, delta)
...
>>> image = np.array([[[1, 2, 3], [4, 5, 6]],
... [[7, 8, 9], [10, 11, 12]],
... [[13, 14, 15], [16, 17, 18]]]).astype(np.float32)
>>> delta = 0.2
>>> adjust_hue = AdjustHue()
>>> output = adjust_hue(Tensor(image), Tensor(delta))
>>> print("output", output)
output [[[ 2.3999996 1. 3. ]
[ 5.3999996 4. 6. ]]
[[ 8.4 7. 9. ]
[11.4 10. 12. ]]
[[14.4 13. 15. ]
[17.4 16. 18. ]]]
"""
@prim_attr_register
def __init__(self):
"""Initialize AdjustHue"""
self.init_prim_io_names(inputs=['images', 'delta'], outputs=['y'])
class CropAndResize(PrimitiveWithInfer):
"""
Extracts crops from the input image tensor and resizes them.

View File

@ -26,7 +26,7 @@ from mindspore import ms_function
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes
from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes, AdjustHue
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations import _quant_ops as Q
@ -3071,6 +3071,16 @@ test_case_array_ops = [
]
test_case_image_ops = [
('AdjustHue', {
'block': AdjustHue(),
'desc_inputs': [Tensor(np.array([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]],
[[13, 14, 15],
[16, 17, 18]]]).astype(np.float32)),
Tensor(0.2, mstype.float32)],
'skip': ['backward']}),
('NonMaxSuppressionV3', {
'block': P.NonMaxSuppressionV3(),
'desc_inputs': [Tensor(np.array([[20, 5, 200, 100],