!32951 [feat] [assistant] [I48O56] Add new aicpu operator AdjustSaturation
Merge pull request !32951 from 陈慧敏/AdjustSaturation
This commit is contained in:
commit
d21686c2fe
|
@ -0,0 +1,181 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/adjust_saturation_cpu_kernel.h"
|
||||
#include <Eigen/Dense>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
const std::int64_t kAdjustSaturationParallelNum = 64 * 1024;
|
||||
const std::int64_t kAdjustSaturationZero = 0;
|
||||
const std::int64_t kAdjustSaturationOne = 1;
|
||||
const std::int64_t kAdjustSaturationTwo = 2;
|
||||
const std::int64_t kAdjustSaturationThree = 3;
|
||||
const std::int64_t kAdjustSaturationFour = 4;
|
||||
const std::int64_t kAdjustSaturationFive = 5;
|
||||
const std::float_t kAdjustSaturationSix = 6;
|
||||
} // namespace
|
||||
|
||||
namespace detail {
|
||||
static void rgb_to_hsv(float r, float g, float b, float *h, float *s, float *v) {
|
||||
float vv = std::max(r, std::max(g, b));
|
||||
float range = vv - std::min(r, std::min(g, b));
|
||||
if (vv > 0) {
|
||||
*s = range / vv;
|
||||
} else {
|
||||
*s = 0;
|
||||
}
|
||||
float norm = kAdjustSaturationOne / (kAdjustSaturationSix * range);
|
||||
float hh;
|
||||
if (r == vv) {
|
||||
hh = norm * (g - b);
|
||||
} else if (g == vv) {
|
||||
hh = norm * (b - r) + kAdjustSaturationTwo / kAdjustSaturationSix;
|
||||
} else {
|
||||
hh = norm * (r - g) + kAdjustSaturationFour / kAdjustSaturationSix;
|
||||
}
|
||||
if (range <= 0.0) {
|
||||
hh = 0;
|
||||
}
|
||||
if (hh < 0.0) {
|
||||
hh = hh + kAdjustSaturationOne;
|
||||
}
|
||||
*v = vv;
|
||||
*h = hh;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void hsv_to_rgb(float h, float s, float v, T *r, T *g, T *b) {
|
||||
float c = s * v;
|
||||
float m = v - c;
|
||||
float dh = h * kAdjustSaturationSix;
|
||||
float rr, gg, bb;
|
||||
int h_category = static_cast<int>(dh);
|
||||
float fmodu = dh;
|
||||
while (fmodu <= 0) {
|
||||
fmodu += kAdjustSaturationTwo;
|
||||
}
|
||||
while (fmodu >= kAdjustSaturationTwo) {
|
||||
fmodu -= kAdjustSaturationTwo;
|
||||
}
|
||||
float x = c * (1 - std::abs(fmodu - 1));
|
||||
switch (h_category) {
|
||||
case kAdjustSaturationZero:
|
||||
rr = c;
|
||||
gg = x;
|
||||
bb = 0;
|
||||
break;
|
||||
case kAdjustSaturationOne:
|
||||
rr = x;
|
||||
gg = c;
|
||||
bb = 0;
|
||||
break;
|
||||
case kAdjustSaturationTwo:
|
||||
rr = 0;
|
||||
gg = c;
|
||||
bb = x;
|
||||
break;
|
||||
case kAdjustSaturationThree:
|
||||
rr = 0;
|
||||
gg = x;
|
||||
bb = c;
|
||||
break;
|
||||
case kAdjustSaturationFour:
|
||||
rr = x;
|
||||
gg = 0;
|
||||
bb = c;
|
||||
break;
|
||||
case kAdjustSaturationFive:
|
||||
rr = c;
|
||||
gg = 0;
|
||||
bb = x;
|
||||
break;
|
||||
default:
|
||||
rr = 0;
|
||||
gg = 0;
|
||||
bb = 0;
|
||||
}
|
||||
*r = static_cast<T>(rr + m);
|
||||
*g = static_cast<T>(gg + m);
|
||||
*b = static_cast<T>(bb + m);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool LaunchAdjustSaturationKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
auto input{static_cast<T *>(inputs[0]->addr)};
|
||||
auto scale{static_cast<std::float_t *>(inputs[1]->addr)};
|
||||
auto output{static_cast<T *>(outputs[0]->addr)};
|
||||
constexpr int64_t kChannelSize = 3;
|
||||
std::int64_t num_elements = inputs[0]->size / sizeof(T);
|
||||
auto sharder_adjustsaturation = [input, scale, output, kChannelSize](int64_t start, int64_t end) {
|
||||
for (int64_t i = start * kChannelSize; i < end * kChannelSize; i = i + kChannelSize) {
|
||||
float h, s, v;
|
||||
// Convert the RGB color to Hue/V-range.
|
||||
rgb_to_hsv(static_cast<float>(*(input + i)), static_cast<float>(*(input + i + 1)),
|
||||
static_cast<float>(*(input + i + 2)), &h, &s, &v);
|
||||
s = std::min(1.0f, std::max(0.0f, s * scale[0]));
|
||||
// Convert the hue and v-range back into RGB.
|
||||
hsv_to_rgb<T>(h, s, v, &output[i], &output[i + 1], &output[i + 2]);
|
||||
}
|
||||
};
|
||||
std::int64_t total = num_elements / kChannelSize;
|
||||
if (total > kAdjustSaturationParallelNum) {
|
||||
std::int64_t per_unit_size =
|
||||
total / std::min(kAdjustSaturationParallelNum - SizeToLong(kAdjustSaturationTwo), total);
|
||||
CPUKernelUtils::ParallelFor(sharder_adjustsaturation, total, per_unit_size);
|
||||
} else {
|
||||
sharder_adjustsaturation(0, total);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
void AdjustSaturationCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_INPUTS_NUM(input_num, kAdjustSaturationTwo, kernel_name_);
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(output_num, kAdjustSaturationOne, kernel_name_);
|
||||
input_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
}
|
||||
|
||||
bool AdjustSaturationCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (input_type_ == kNumberTypeFloat32) {
|
||||
return detail::LaunchAdjustSaturationKernel<float>(inputs, outputs);
|
||||
} else if (input_type_ == kNumberTypeFloat16) {
|
||||
return detail::LaunchAdjustSaturationKernel<Eigen::half>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', unsupported input data type " << TypeIdLabel(input_type_);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
std::vector<KernelAttr> AdjustSaturationCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
|
||||
return support_list;
|
||||
}
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdjustSaturation, AdjustSaturationCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADJUST_SATURATION_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADJUST_SATURATION_CPU_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class AdjustSaturationCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
public:
|
||||
AdjustSaturationCpuKernelMod() = default;
|
||||
~AdjustSaturationCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node);
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
|
||||
private:
|
||||
TypeId input_type_{kTypeUnknown};
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADJUST_SATURATION_CPU_KERNEL_H_
|
|
@ -0,0 +1,76 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/adjust_saturation.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr AdjustSaturationInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto input_image_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto input_scale_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
const int64_t min_image_dim = 3;
|
||||
const int64_t scale_dim = 0;
|
||||
(void)CheckAndConvertUtils::CheckInteger("dimension of AdjustSaturation input image",
|
||||
SizeToLong(input_image_shape.size()), kGreaterEqual, min_image_dim,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("last dimension of AdjustSaturation input image",
|
||||
input_image_shape[input_image_shape.size() - 1], kEqual, min_image_dim,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("dimension of AdjustSaturation input scale",
|
||||
SizeToLong(input_scale_shape.size()), kEqual, scale_dim, prim_name);
|
||||
return std::make_shared<abstract::Shape>(input_image_shape);
|
||||
}
|
||||
|
||||
TypePtr AdjustSaturationInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
|
||||
auto input_images_type = input_args[0]->BuildType();
|
||||
auto input_scale_type = input_args[1]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(input_images_type);
|
||||
MS_EXCEPTION_IF_NULL(input_scale_type);
|
||||
const std::set<TypePtr> valid_images_types = {kFloat16, kFloat32};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("image", input_images_type, valid_images_types, prim_name);
|
||||
const std::set<TypePtr> valid_scale_types = {kFloat32};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("scale", input_scale_type, valid_scale_types, prim_name);
|
||||
return input_images_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(AdjustSaturation, BaseOperator);
|
||||
AbstractBasePtr AdjustSaturationInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kInputsNum = 2;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
|
||||
auto infer_type = AdjustSaturationInferType(primitive, input_args);
|
||||
auto infer_shape = AdjustSaturationInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AdjustSaturation, prim::kPrimAdjustSaturation, AdjustSaturationInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_ADJUST_SATURATION_H_
|
||||
#define MINDSPORE_CORE_OPS_ADJUST_SATURATION_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/base_operator.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAdjustSaturation = "AdjustSaturation";
|
||||
/// \brief Convert the images to HSV and multiply the saturation (S) channel by `scale` and clipping.
|
||||
/// Refer to Python API @ref mindspore.ops.AdjustSaturation for more details.
|
||||
class MIND_API AdjustSaturation : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(AdjustSaturation);
|
||||
/// \brief Constructor.
|
||||
AdjustSaturation() : BaseOperator(kNameAdjustSaturation) { InitIOName({"image", "scale"}, {"y"}); }
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr AdjustSaturationInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAdjustSaturationPtr = std::shared_ptr<AdjustSaturation>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_ADJUST_SATURATION_H_
|
|
@ -34,6 +34,7 @@ GVAR_DEF(mindspore::HashMap<std::string COMMA ValuePtr>, kSideEffectPropagate,
|
|||
#undef COMMA
|
||||
constexpr auto kAdjustHue = "AdjustHue";
|
||||
constexpr auto kAdjustContrastv2 = "AdjustContrastv2";
|
||||
constexpr auto kAdjustSaturation = "AdjustSaturation";
|
||||
constexpr auto kGetNext = "GetNext";
|
||||
constexpr auto kGather = "Gather";
|
||||
constexpr auto kAddcdiv = "Addcdiv";
|
||||
|
@ -841,6 +842,7 @@ GVAR_DEF(PrimitivePtr, kPrimSvd, std::make_shared<Primitive>("Svd"));
|
|||
GVAR_DEF(PrimitivePtr, kPrimNonMaxSuppressionV3, std::make_shared<Primitive>("NonMaxSuppressionV3"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimAdjustHue, std::make_shared<Primitive>(kAdjustHue));
|
||||
GVAR_DEF(PrimitivePtr, kPrimAdjustContrastv2, std::make_shared<Primitive>(kAdjustContrastv2));
|
||||
GVAR_DEF(PrimitivePtr, kPrimAdjustSaturation, std::make_shared<Primitive>(kAdjustSaturation));
|
||||
|
||||
// Statements
|
||||
GVAR_DEF(PrimitivePtr, kPrimReturn, std::make_shared<Primitive>("Return"));
|
||||
|
|
|
@ -178,3 +178,4 @@ from .adjust_hue import _adjust_hue_aicpu
|
|||
from .maxpool_v1 import _maxpool_v1_aicpu
|
||||
from .maxpool_grad_v1 import _maxpool_grad_v1_aicpu
|
||||
from .dense_to_csr_sparse_matrix import _dense_to_csr_sparse_matrix_aicpu
|
||||
from .adjust_saturation import _adjust_saturation_aicpu
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""AdjustSaturation op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
adjust_saturation_op_info = AiCPURegOp("AdjustSaturation") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "image", "required") \
|
||||
.input(1, "scale", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(adjust_saturation_op_info)
|
||||
def _adjust_saturation_aicpu():
|
||||
"""AdjustSaturation aicpu register"""
|
||||
return
|
|
@ -21,6 +21,53 @@ from ...common import dtype as mstype
|
|||
from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive
|
||||
|
||||
|
||||
class AdjustSaturation(Primitive):
|
||||
"""
|
||||
Adjust saturation of RGB images.
|
||||
|
||||
Note:
|
||||
This is a convenience method that converts RGB images to float representation, converts them to HSV,
|
||||
adds an offset to the saturation channel, converts back to RGB and then back to the original data type.
|
||||
If several adjustments are chained it is advisable to minimize the number of redundant conversions.
|
||||
|
||||
inputs:
|
||||
- **image** (Tensor): Images to adjust. Must be one of the following types: float16, float32.
|
||||
At least 3-D.The last dimension is interpreted as channels, and must be three.
|
||||
- **scale** (Tensor): A float scale to add to the saturation. A Tensor of type float32. Must be 0-D.
|
||||
|
||||
Output:
|
||||
Adjusted image(s), same shape and dtype as `image`.
|
||||
|
||||
Raises:
|
||||
TypeError: If any iput is not Tensor.
|
||||
TypeError: If the type of `image` is not one of the following dtype: float16, float32.
|
||||
TypeError: If the type of `scale` is not float32.
|
||||
ValueError: If the dimension of the 'image' is less than 3, or the last dimension of the 'image' is not 3.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor([[[1.0, 2.0, 3.0],
|
||||
... [4.0, 5.0, 6.0]],
|
||||
... [[7.0, 8.0, 9.0],
|
||||
... [10.0, 11.0, 12.0]]])
|
||||
>>> scale = Tensor(float(0.5))
|
||||
>>> adjustsaturation = AdjustSaturation()
|
||||
>>> output = adjustsaturation(x, scale)
|
||||
>>> print(output)
|
||||
[[[ 2. 2.4999998 3. ]
|
||||
[ 5. 5.5 6. ]]
|
||||
[[ 8. 8.5 9. ]
|
||||
[11. 11.5 12. ]]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize AdjustSaturation"""
|
||||
self.init_prim_io_names(inputs=['images', 'scale'], outputs=['y'])
|
||||
|
||||
|
||||
class AdjustContrastv2(Primitive):
|
||||
"""
|
||||
Adjust contrastv2 of images.
|
||||
|
|
|
@ -26,7 +26,8 @@ from mindspore import ms_function
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes, AdjustHue, AdjustContrastv2
|
||||
from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes, AdjustHue, AdjustContrastv2, \
|
||||
AdjustSaturation
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops.operations import _quant_ops as Q
|
||||
|
@ -3106,6 +3107,14 @@ test_case_image_ops = [
|
|||
[10.0, 11.0, 12.0]]]),
|
||||
Tensor(0.5, mstype.float32)],
|
||||
'skip': ['backward']}),
|
||||
('AdjustSaturation', {
|
||||
'block': AdjustSaturation(),
|
||||
'desc_inputs': [Tensor([[[1.0, 2.0, 3.0],
|
||||
[4.0, 5.0, 6.0]],
|
||||
[[7.0, 8.0, 9.0],
|
||||
[10.0, 11.0, 12.0]]]),
|
||||
Tensor(0.5, mstype.float32)],
|
||||
'skip': ['backward']}),
|
||||
('NonMaxSuppressionV3', {
|
||||
'block': P.NonMaxSuppressionV3(),
|
||||
'desc_inputs': [Tensor(np.array([[20, 5, 200, 100],
|
||||
|
|
Loading…
Reference in New Issue