[assistant][Fmax] Add new operator Fmax

This commit is contained in:
mikkassa 2022-11-07 21:43:21 +08:00
parent a0beca7efe
commit a9978adad7
12 changed files with 674 additions and 0 deletions

View File

@ -179,6 +179,7 @@ constexpr auto kZerosLike = "ZerosLike";
constexpr auto kEqual = "Equal";
constexpr auto kOnesLike = "OnesLike";
constexpr auto kSign = "Sign";
constexpr auto kFmax = "Fmax";
constexpr auto kGLU = "GLU";
constexpr auto kFmin = "Fmin";
constexpr auto kArgmax = "Argmax";
@ -334,6 +335,7 @@ const std::map<std::string, std::string> kOpNameToAicpuOpNameMap{
{kACos, "Acos"},
{kHSigmoid, "HardSigmoid"},
{kFmin, "Minimum"},
{kFmax, "Maximum"},
{kHSigmoidGrad, "HardSigmoidGrad"},
{kArgmax, "ArgMax"},
{kArgmin, "ArgMin"},

View File

@ -0,0 +1,297 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/fmax_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include "mindspore/core/ops/fmax.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
namespace mindspore {
namespace kernel {
namespace {
// Axis positions inside the seven-entry padded broadcast shape vectors.
constexpr int kShapeIndexZero{0};
constexpr int kShapeIndex1st{1};
constexpr int kShapeIndex2nd{2};
constexpr int kShapeIndex3rd{3};
constexpr int kShapeIndex4th{4};
constexpr int kShapeIndex5th{5};
constexpr int kShapeIndex6th{6};
// Number of inputs / outputs checked by LaunchKernel.
constexpr size_t kFmaxInputsNum{2};
constexpr size_t kFmaxOutputsNum{1};
}  // namespace
// Binds the kernel to its operator and selects the typed launch function.
// Returns false when base_operator cannot be cast to ops::Fmax or when
// MatchKernelFunc rejects the given inputs/outputs.
bool FmaxCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                            const std::vector<KernelTensorPtr> &outputs) {
  const auto fmax_op = std::dynamic_pointer_cast<ops::Fmax>(base_operator);
  if (fmax_op == nullptr) {
    MS_LOG(ERROR) << "cast Fmax ops failed!";
    return false;
  }
  kernel_name_ = fmax_op->name();
  return MatchKernelFunc(base_operator, inputs, outputs);
}
int FmaxCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
int ret = 0;
if ((ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost)) != 0) {
return ret;
}
input_x_shape_ = inputs[0]->GetShapeVector();
input_y_shape_ = inputs[1]->GetShapeVector();
output_shape_ = outputs[0]->GetShapeVector();
TypeId input_x_dtype = inputs[0]->GetDtype();
TypeId input_y_dtype = inputs[1]->GetDtype();
size_t max_input_shape_size =
input_x_shape_.size() > input_y_shape_.size() ? input_x_shape_.size() : input_y_shape_.size();
for (size_t i = 0; i < output_shape_.size(); i++) {
output_num_ *= static_cast<size_t>(output_shape_[i]);
}
if ((input_x_shape_.size() == 0 && input_y_shape_.size() != 0) ||
(input_x_shape_.size() != 0 && input_y_shape_.size() == 0)) {
InitInputTensorAndScalar(max_input_shape_size);
} else if (max_input_shape_size == output_shape_.size() && output_shape_.size() != 0) {
InitInputTensors(input_x_dtype, input_y_dtype);
}
return 0;
}
void FmaxCpuKernelMod::InitInputTensorAndScalar(size_t max_input_shape_size) {
if (max_input_shape_size != output_shape_.size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dimension of output tensor must be equal to the max "
"dimension of inputs, but got the dimension of output tensor: "
<< output_shape_.size() << " and the max dimension of inputs: " << max_input_shape_size;
}
need_broadcast_ = false;
}
// Tensor-with-tensor setup: rejects bool/bool inputs, then decides whether the
// two shapes differ and, if so, prepares the padded broadcast shapes.
void FmaxCpuKernelMod::InitInputTensors(TypeId input_x_dtype, TypeId input_y_dtype) {
  const bool both_bool = (input_x_dtype == kNumberTypeBool) && (input_y_dtype == kNumberTypeBool);
  if (both_bool) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input tensor types can not be both bool.";
  }
  need_broadcast_ = IsBroadcast();
  if (!need_broadcast_) {
    return;
  }
  InitTensorBroadcastShape();
}
// Typed entry point: validates the number of inputs/outputs, reinterprets the
// raw buffers as T and delegates to BroadcastArith. The workspace is unused.
template <typename T>
bool FmaxCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                    const std::vector<kernel::AddressPtr> &outputs) const {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kFmaxInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kFmaxOutputsNum, kernel_name_);
  T *lhs_data = reinterpret_cast<T *>(inputs[0]->addr);
  T *rhs_data = reinterpret_cast<T *>(inputs[1]->addr);
  T *result_data = reinterpret_cast<T *>(outputs[0]->addr);
  BroadcastArith(lhs_data, rhs_data, result_data);
  return true;
}
// Dispatches to one of three element loops: scalar-with-tensor, same-shape
// tensors, or the general seven-dimensional broadcast.
template <typename T>
void FmaxCpuKernelMod::BroadcastArith(const T *input_x, const T *input_y, T *output) const {
  MS_EXCEPTION_IF_NULL(input_x);
  MS_EXCEPTION_IF_NULL(input_y);
  MS_EXCEPTION_IF_NULL(output);
  if (!need_broadcast_) {
    const bool has_scalar_operand = input_x_shape_.size() == 0 || input_y_shape_.size() == 0;
    if (has_scalar_operand) {
      BroadcastArithOneScalarOneTensor(input_x, input_y, output);
    } else {
      BroadcastArithTensors(input_x, input_y, output);
    }
    return;
  }
  const auto &lhs_dims = broadcast_input_x_shape_;
  const auto &rhs_dims = broadcast_input_y_shape_;
  const auto &out_dims = broadcast_output_shape_;
  BroadcastArithKernel(lhs_dims[kShapeIndexZero], lhs_dims[kShapeIndex1st], lhs_dims[kShapeIndex2nd],
                       lhs_dims[kShapeIndex3rd], lhs_dims[kShapeIndex4th], lhs_dims[kShapeIndex5th],
                       lhs_dims[kShapeIndex6th], rhs_dims[kShapeIndexZero], rhs_dims[kShapeIndex1st],
                       rhs_dims[kShapeIndex2nd], rhs_dims[kShapeIndex3rd], rhs_dims[kShapeIndex4th],
                       rhs_dims[kShapeIndex5th], rhs_dims[kShapeIndex6th], out_dims[kShapeIndexZero],
                       out_dims[kShapeIndex1st], out_dims[kShapeIndex2nd], out_dims[kShapeIndex3rd],
                       out_dims[kShapeIndex4th], out_dims[kShapeIndex5th], out_dims[kShapeIndex6th], input_x, input_y,
                       output);
}
// Broadcasting is needed whenever the two input shapes differ in rank or in any dimension.
bool FmaxCpuKernelMod::IsBroadcast() const { return input_x_shape_ != input_y_shape_; }
// Builds the max_dims_-entry shapes consumed by BroadcastArithKernel.
// The output shape fills the leading entries; each input shape is placed so
// that its last dimension lines up with the last output dimension. Every other
// entry is 1. Logs an exception when the output has more than max_dims_ dimensions.
void FmaxCpuKernelMod::InitTensorBroadcastShape() {
  if (output_shape_.size() > max_dims_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', the dimension of output must be less than or "
                         "equal to 7, but got "
                      << output_shape_.size() << ".";
  }
  // assign() instead of resize(): once the vectors already hold max_dims_
  // entries, resize() leaves them untouched, so dimensions written for an
  // earlier shape would survive into this one.
  broadcast_input_x_shape_.assign(max_dims_, 1);
  broadcast_input_y_shape_.assign(max_dims_, 1);
  broadcast_output_shape_.assign(max_dims_, 1);
  // The element counts are products over the current shape only.
  input_x_num_ = 1;
  input_y_num_ = 1;

  for (size_t i = 0; i < output_shape_.size(); i++) {
    broadcast_output_shape_[i] = output_shape_[i];
  }

  // IntToSize is kept so that an input with more dimensions than the output is
  // handled exactly as before instead of wrapping around silently.
  const size_t x_offset =
    IntToSize(static_cast<int>(output_shape_.size()) - static_cast<int>(input_x_shape_.size()));
  for (size_t j = 0; j < input_x_shape_.size(); j++) {
    broadcast_input_x_shape_[j + x_offset] = input_x_shape_[j];
    input_x_num_ *= static_cast<size_t>(input_x_shape_[j]);
  }

  const size_t y_offset =
    IntToSize(static_cast<int>(output_shape_.size()) - static_cast<int>(input_y_shape_.size()));
  for (size_t k = 0; k < input_y_shape_.size(); k++) {
    broadcast_input_y_shape_[k + y_offset] = input_y_shape_[k];
    input_y_num_ *= static_cast<size_t>(input_y_shape_[k]);
  }
}
// FmaxFunc
// Returns the larger of lhs and rhs. For float16, float and double a NaN
// operand is skipped: if lhs is NaN the result is rhs, otherwise if rhs is NaN
// the result is lhs. The float16 NaN test is done on the value converted to float.
template <typename T>
T FmaxCpuKernelMod::FmaxFunc(const T &lhs, const T &rhs) const {
  if constexpr (std::is_same_v<T, float16>) {
    if (std::isnan(static_cast<float>(lhs))) {
      return rhs;
    }
    if (std::isnan(static_cast<float>(rhs))) {
      return lhs;
    }
  } else if constexpr (((std::is_same_v<T, float>) || (std::is_same_v<T, double>))) {  // NOLINT
    if (std::isnan(lhs)) {
      return rhs;
    }
    if (std::isnan(rhs)) {
      return lhs;
    }
  }
  // Plain comparison: integer types, and floating types with no NaN operand.
  return lhs > rhs ? lhs : rhs;
}
// Broadcast comparison
// Maps a coordinate onto an axis of extent `dim`: an axis of extent 1 always
// reads position 0, any other axis reads `index` unchanged.
int64_t FmaxCpuKernelMod::Index(const int64_t &index, const int64_t &dim) const {
  if (dim == 1) {
    return 0;
  }
  return index;
}
// Broadcast Arithmetic
// General broadcast loop over a seven-dimensional output.
//   l0..l6 : padded shape of input_x
//   r0..r6 : padded shape of input_y
//   d0..d6 : padded shape of the output
// For each of the output_num_ output positions the seven coordinates are
// recovered from the flat position, mapped onto each input through Index()
// (an axis of extent 1 always reads element 0) and combined with FmaxFunc.
template <typename T>
void FmaxCpuKernelMod::BroadcastArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3,
                                            const size_t l4, const size_t l5, const size_t l6, const size_t r0,
                                            const size_t r1, const size_t r2, const size_t r3, const size_t r4,
                                            const size_t r5, const size_t r6, const size_t d0, const size_t d1,
                                            const size_t d2, const size_t d3, const size_t d4, const size_t d5,
                                            const size_t d6, const T *input_x, const T *input_y, T *output) const {
  // The stride products do not depend on the position, so compute them once
  // instead of once per output element.
  const size_t d_stride5 = d6;
  const size_t d_stride4 = d5 * d_stride5;
  const size_t d_stride3 = d4 * d_stride4;
  const size_t d_stride2 = d3 * d_stride3;
  const size_t d_stride1 = d2 * d_stride2;
  const size_t d_stride0 = d1 * d_stride1;

  const size_t l_stride5 = l6;
  const size_t l_stride4 = l5 * l_stride5;
  const size_t l_stride3 = l4 * l_stride4;
  const size_t l_stride2 = l3 * l_stride3;
  const size_t l_stride1 = l2 * l_stride2;
  const size_t l_stride0 = l1 * l_stride1;

  const size_t r_stride5 = r6;
  const size_t r_stride4 = r5 * r_stride5;
  const size_t r_stride3 = r4 * r_stride4;
  const size_t r_stride2 = r3 * r_stride3;
  const size_t r_stride1 = r2 * r_stride2;
  const size_t r_stride0 = r1 * r_stride1;

  for (size_t pos = 0; pos < output_num_; pos++) {
    // Coordinates of `pos` in the output.
    const size_t i = pos / d_stride0 % d0;
    const size_t j = pos / d_stride1 % d1;
    const size_t k = pos / d_stride2 % d2;
    const size_t l = pos / d_stride3 % d3;
    const size_t m = pos / d_stride4 % d4;
    const size_t n = pos / d_stride5 % d5;
    const size_t o = pos % d6;

    size_t l_index = Index(i, l0) * l_stride0;
    l_index += Index(j, l1) * l_stride1;
    l_index += Index(k, l2) * l_stride2;
    l_index += Index(l, l3) * l_stride3;
    l_index += Index(m, l4) * l_stride4;
    l_index += Index(n, l5) * l_stride5;
    l_index += Index(o, l6);

    size_t r_index = Index(i, r0) * r_stride0;
    r_index += Index(j, r1) * r_stride1;
    r_index += Index(k, r2) * r_stride2;
    r_index += Index(l, r3) * r_stride3;
    r_index += Index(m, r4) * r_stride4;
    r_index += Index(n, r5) * r_stride5;
    r_index += Index(o, r6);

    output[pos] = FmaxFunc(input_x[l_index], input_y[r_index]);
  }
}
// Element loop for a scalar operand: when input_x has an empty shape its single
// element is paired with every element of input_y, otherwise the single element
// of input_y is paired with every element of input_x.
template <typename T>
void FmaxCpuKernelMod::BroadcastArithOneScalarOneTensor(const T *input_x, const T *input_y, T *output) const {
  const bool x_is_scalar = input_x_shape_.size() == 0;
  for (size_t idx = 0; idx < output_num_; ++idx) {
    output[idx] = x_is_scalar ? FmaxFunc(input_x[0], input_y[idx]) : FmaxFunc(input_x[idx], input_y[0]);
  }
}
template <typename T>
void FmaxCpuKernelMod::BroadcastArithTensors(const T *input_x, const T *input_y, T *output) const {
for (size_t i = 0; i < output_num_; ++i) {
output[i] = FmaxFunc(input_x[i], input_y[i]);
}
}
// Table that pairs each supported dtype signature with its typed launch
// function. Every signature uses one dtype for both inputs and the output.
const std::vector<std::pair<KernelAttr, FmaxCpuKernelMod::KernelRunFunc>> &FmaxCpuKernelMod::GetFuncList() const {
  using FuncEntry = std::pair<KernelAttr, FmaxCpuKernelMod::KernelRunFunc>;
  static const auto same_dtype = [](TypeId dtype) {
    return KernelAttr().AddInputAttr(dtype).AddInputAttr(dtype).AddOutputAttr(dtype);
  };
  static const std::vector<FuncEntry> func_list = {
    {same_dtype(kNumberTypeFloat16), &FmaxCpuKernelMod::LaunchKernel<float16>},
    {same_dtype(kNumberTypeInt32), &FmaxCpuKernelMod::LaunchKernel<int32_t>},
    {same_dtype(kNumberTypeInt64), &FmaxCpuKernelMod::LaunchKernel<int64_t>},
    {same_dtype(kNumberTypeFloat32), &FmaxCpuKernelMod::LaunchKernel<float>},
    {same_dtype(kNumberTypeFloat64), &FmaxCpuKernelMod::LaunchKernel<double>},
  };
  return func_list;
}
std::vector<KernelAttr> FmaxCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
};
return kernel_attr_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Fmax, FmaxCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,92 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FMAX_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FMAX_CPU_KERNEL_H_
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
// CPU kernel for the Fmax operator: an element-wise maximum of two inputs in
// which a NaN operand is skipped in favour of the other operand (see FmaxFunc).
// Shape-dependent members are filled in by Resize(); Launch() forwards to the
// typed function stored in kernel_func_.
class FmaxCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<FmaxCpuKernelMod> {
 public:
  FmaxCpuKernelMod() = default;
  ~FmaxCpuKernelMod() override = default;
  // Stores the kernel name and matches the typed launch function; returns false on failure.
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  // Reads the input/output shapes and prepares the broadcast state; returns 0 on success.
  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
  // Runs the kernel through the function selected for the matched dtype signature.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, workspace, outputs);
  }
  // Table of (dtype signature, typed launch function) pairs.
  const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;

 protected:
  // List of supported dtype signatures.
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  // True when the two input shapes differ in rank or in any dimension.
  bool IsBroadcast() const;
  // Returns 0 when dim == 1, otherwise index.
  int64_t Index(const int64_t &index, const int64_t &dim) const;
  // Fills the max_dims_-entry broadcast shapes from the current input/output shapes.
  void InitTensorBroadcastShape();
  // Setup for the case where exactly one input has an empty shape.
  void InitInputTensorAndScalar(size_t max_input_shape_size);
  // Setup for the case where both inputs are tensors.
  void InitInputTensors(TypeId input_x_dtype, TypeId input_y_dtype);
  // Broadcast Arithmetic
  // l0..l6 / r0..r6 / d0..d6 are the padded shapes of input_x / input_y / output.
  template <typename T>
  void BroadcastArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3, const size_t l4,
                            const size_t l5, const size_t l6, const size_t r0, const size_t r1, const size_t r2,
                            const size_t r3, const size_t r4, const size_t r5, const size_t r6, const size_t d0,
                            const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5,
                            const size_t d6, const T *input_x, const T *input_y, T *output) const;
  // Maximum of two values; for floating types a NaN operand yields the other operand.
  template <typename T>
  T FmaxFunc(const T &lhs, const T &rhs) const;
  // Element loop used when one input has an empty shape.
  template <typename T>
  void BroadcastArithOneScalarOneTensor(const T *input_x, const T *input_y, T *output) const;
  // Element loop used when no broadcast is needed and neither input has an empty shape.
  template <typename T>
  void BroadcastArithTensors(const T *input_x, const T *input_y, T *output) const;
  // Chooses between the three element loops above.
  template <typename T>
  void BroadcastArith(const T *input_x, const T *input_y, T *output) const;
  // Typed launch function stored in the function list.
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
                    const std::vector<kernel::AddressPtr> &outputs) const;
  // Whether the general broadcast loop must be used.
  bool need_broadcast_{false};
  // Products of the dimensions of input_x, input_y and the output.
  size_t input_x_num_{1};
  size_t input_y_num_{1};
  size_t output_num_{1};
  // Shapes as reported by the kernel tensors in Resize().
  std::vector<int64_t> input_x_shape_;
  std::vector<int64_t> input_y_shape_;
  std::vector<int64_t> output_shape_;
  // Shapes padded to max_dims_ entries for BroadcastArithKernel.
  std::vector<int64_t> broadcast_input_x_shape_;
  std::vector<int64_t> broadcast_input_y_shape_;
  std::vector<int64_t> broadcast_output_shape_;
  // Largest output rank accepted by InitTensorBroadcastShape().
  const size_t max_dims_{7};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FMAX_CPU_KERNEL_H_

View File

@ -148,6 +148,7 @@ constexpr auto kTridiagonalSolve = "TridiagonalSolve";
constexpr auto kFFTWithSize = "FFTWithSize";
constexpr auto kTriuIndices = "TriuIndices";
constexpr auto kTrilIndices = "TrilIndices";
constexpr auto kFmax = "Fmax";
constexpr auto kTrace = "Trace";
constexpr auto kTraceGrad = "TraceGrad";
constexpr auto kMatrixLogarithm = "MatrixLogarithm";
@ -1398,6 +1399,7 @@ GVAR_DEF(PrimitivePtr, kPrimTrace, std::make_shared<Primitive>("Trace"));
GVAR_DEF(PrimitivePtr, kPrimTraceGrad, std::make_shared<Primitive>("TraceGrad"));
GVAR_DEF(PrimitivePtr, kPrimTridiagonalMatMul, std::make_shared<Primitive>(kTridiagonalMatMul));
GVAR_DEF(PrimitivePtr, kPrimZeta, std::make_shared<Primitive>("Zeta"));
GVAR_DEF(PrimitivePtr, kPrimFmax, std::make_shared<Primitive>(kFmax));
GVAR_DEF(PrimitivePtr, kPrimIgamma, std::make_shared<Primitive>("Igamma"));
GVAR_DEF(PrimitivePtr, kPrimIgammac, std::make_shared<Primitive>("Igammac"));
GVAR_DEF(PrimitivePtr, kPrimIgammaGradA, std::make_shared<Primitive>("IgammaGradA"));

View File

@ -0,0 +1,59 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/fmax.h"
#include <string>
#include <memory>
#include <set>
#include "abstract/ops/op_infer.h"
#include "abstract/dshape.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
namespace {
// Output shape of Fmax: delegated to BroadCastInferShape under the primitive's name.
abstract::ShapePtr FmaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  return BroadCastInferShape(primitive->name(), input_args);
}
// Output type of Fmax: checks that exactly two inputs are given and that both
// are tensors of float16, float32, float64, int32 or int64, then returns the
// type of the first input.
TypePtr FmaxInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  const int64_t kInputNum = 2;
  const auto prim_name = primitive->name();
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, prim_name);
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kInt32, kInt64};
  const auto x1_type = input_args[0]->BuildType();
  const auto x2_type = input_args[1]->BuildType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x1", x1_type, valid_types, prim_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x2", x2_type, valid_types, prim_name);
  return x1_type;
}
} // namespace
MIND_API_OPERATOR_IMPL(Fmax, BaseOperator);
// Infer entry point registered for Fmax: infers the type first, then the
// shape, and wraps both into an abstract value.
AbstractBasePtr FmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const auto out_type = FmaxInferType(primitive, input_args);
  const auto out_shape = FmaxInferShape(primitive, input_args);
  return abstract::MakeAbstract(out_shape, out_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Fmax, prim::kPrimFmax, FmaxInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

41
mindspore/core/ops/fmax.h Normal file
View File

@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_FMAX_H_
#define MINDSPORE_CORE_OPS_FMAX_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameFmax = "Fmax";
/// Operator definition for Fmax. Its inputs are named "x1" and "x2" and its
/// output "y".
class MIND_API Fmax : public BaseOperator {
 public:
  /// Constructs the operator under the name kNameFmax and registers its IO names.
  Fmax() : BaseOperator(kNameFmax) { InitIOName({"x1", "x2"}, {"y"}); }
  MIND_API_BASE_MEMBER(Fmax);
};
abstract::AbstractBasePtr FmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimFmaxPtr = std::shared_ptr<Fmax>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_FMAX_H_

View File

@ -69,6 +69,7 @@ from mindspore.ops.operations.math_ops import TridiagonalSolve
from mindspore.ops.operations.math_ops import Logit
from mindspore.ops.operations.math_ops import Diagonal
from mindspore.ops.operations.array_ops import Transpose, MatrixSetDiagV3
from mindspore.ops.operations.math_ops import Fmax
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.primitive import constexpr
@ -1376,6 +1377,61 @@ def get_bprop_fmin(self):
return bprop
@bprop_getters.register(Fmax)
def get_bprop_fmax(self):
    """Grad definition for 'Fmax' operation"""
    shape_ = P.Shape()
    masked_fill_op = P.MaskedFill()
    logical_or_op = P.LogicalOr()
    logical_not_op = P.LogicalNot()
    logical_and_op = P.LogicalAnd()
    mul_op = P.Mul()
    is_nan_op = P.IsNan()
    reshape_ = P.Reshape()

    def bprop(x1, x2, out, dout):
        """
        Route `dout` to whichever input produced each output element.

        `x1` receives the gradient where it is selected (x1 >= x2 and x1 is not NaN,
        or x2 is NaN); `x2` receives it in the remaining non-NaN cases. The results
        are reduced back to the input shapes and returned in the input dtypes.
        """
        x1_dtype = F.dtype(x1)
        x2_dtype = F.dtype(x2)
        # The mask arithmetic below is carried out in float32.
        if x1_dtype != mstype.float32:
            x1 = F.cast(x1, mstype.float32)
        if x2_dtype != mstype.float32:
            x2 = F.cast(x2, mstype.float32)
        if x1_dtype != mstype.float32 or x2_dtype != mstype.float32:
            dout = F.cast(dout, mstype.float32)

        # b1: positions where x1 is the selected operand; b2: positions where x2 is.
        b1 = logical_or_op(logical_and_op((x1 >= x2), logical_not_op(is_nan_op(x1))), is_nan_op(x2))
        b2 = logical_or_op(logical_and_op(x2 > x1, logical_not_op(is_nan_op(x2))),
                           logical_and_op(is_nan_op(x1), logical_not_op(is_nan_op(x2))))

        # Turn the boolean masks into 0/1 tensors.
        rx1 = masked_fill_op(x1, b1, 1.)
        rx1 = masked_fill_op(rx1, logical_not_op(b1), 0.)
        rx2 = masked_fill_op(x2, b2, 1.)
        rx2 = masked_fill_op(rx2, logical_not_op(b2), 0.)
        rrx1 = mul_op(rx1, dout)
        rrx2 = mul_op(rx2, dout)

        shape_of_x1 = shape_(x1)
        shape_of_x2 = shape_(x2)
        x1_dim = len(shape_of_x1)
        x2_dim = len(shape_of_x2)
        # Undo the broadcast: a scalar input collects the sum of all contributions.
        if x1_dim == 0 and x2_dim != 0:
            sum_r1 = rrx1.sum()
            sum_r2 = rrx2
        elif x1_dim == 0 and x2_dim == 0:
            sum_r1 = rrx1.sum()
            sum_r2 = rrx2.sum()
        elif x1_dim != 0 and x2_dim == 0:
            sum_r2 = rrx2.sum()
            sum_r1 = rrx1
        else:
            rx, ry = DynamicBroadcastGradientArgs()(shape_of_x1, shape_of_x2)
            sum_r1 = sum_grad_reduce_axis(rrx1, rx)
            sum_r2 = sum_grad_reduce_axis(rrx2, ry)

        # Cast back so each gradient has the dtype of the input it belongs to;
        # without this, non-float32 inputs would receive float32 gradients.
        brrx1 = F.cast(reshape_(sum_r1, shape_of_x1), x1_dtype)
        brrx2 = F.cast(reshape_(sum_r2, shape_of_x2), x2_dtype)
        return brrx1, brrx2

    return bprop
@bprop_getters.register(G.MinimumGrad)
def get_bprop_minimum_grad(self):
"""Grad definition for 'MinimumGrad' operation"""

View File

@ -0,0 +1,36 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Fmax op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# AiCPU registration info for "Fmax": two required inputs (x1, x2), one required
# output (y), a bool attribute "ignore_nan", and one dtype_format entry each for
# F16, I32, F32, I64 and F64 (inputs and output share the type).
fmax_op_info = AiCPURegOp("Fmax") \
    .fusion_type("OPAQUE") \
    .input(0, "x1", "required") \
    .input(1, "x2", "required") \
    .output(0, "y", "required") \
    .attr("ignore_nan", "bool") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
    .get_op_info()


@op_info_register(fmax_op_info)
def _fmax_aicpu():
    """Fmax AiCPU register"""
    return

View File

@ -48,6 +48,7 @@ from mindspore.ops.operations.math_ops import (
MatrixExp,
MatrixSolve,
Median,
Fmax,
Orgqr,
Fmin,
Renorm,
@ -3472,6 +3473,47 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
return _nan_to_num(x)
def fmax(x1, x2):
    r"""
    Computes the maximum of input tensors element-wise.

    Note:
        - Inputs of `x1` and `x2` comply with the implicit type conversion rules to make the data types consistent.
        - The inputs must be two tensors.
        - Types of them are one of the following: float16, float32, float64, int32, int64.
        - Shapes of them are supposed to be broadcast.
        - If one of the elements to be compared is NaN, another element is returned.

    .. math::
        output_i = max(x1_i, x2_i)

    Args:
        x1 (Tensor): The first input is a tensor whose data type is number.
        x2 (Tensor): The second input is a tensor whose data type is number.

    Returns:
        A Tensor, the shape is the same as the one after broadcasting,
        and the data type is the one with higher precision or higher digits among the two inputs.

    Raises:
        TypeError: If `x1` and `x2` is not Tensor.
        TypeError: If dtype of `x1` and `x2` is not one of: float16, float32, float64, int32, int64.
        ValueError: If `x1` and `x2` are not the same shape after broadcast.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> x1 = Tensor(np.array([1.0, 5.0, 3.0]), mindspore.float32)
        >>> x2 = Tensor(np.array([4.0, 2.0, 6.0]), mindspore.float32)
        >>> output = ops.fmax(x1, x2)
        >>> print(output)
        [4. 5. 6.]
    """
    # Build the primitive and apply it in one step.
    return Fmax()(x1, x2)
def maximum(x, y):
"""
Computes the maximum of input tensors element-wise.

View File

@ -7574,6 +7574,33 @@ class Fmin(Primitive):
self.init_prim_io_names(inputs=['x1, x2'], outputs=['y'])
class Fmax(Primitive):
    """
    Computes the maximum of input tensors element-wise.

    Refer to :func:`mindspore.ops.fmax` for more detail.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> x1 = Tensor(np.array([1.0, 5.0, 3.0]), mindspore.float32)
        >>> x2 = Tensor(np.array([4.0, 2.0, 6.0]), mindspore.float32)
        >>> fmax = ops.Fmax()
        >>> output = fmax(x1, x2)
        >>> print(output)
        [4. 5. 6.]
    """

    # Both inputs share the dtype group T.
    __mindspore_signature__ = (sig.sig_dtype.T, sig.sig_dtype.T)

    @prim_attr_register
    def __init__(self):
        """Initialize Fmax"""
        self.add_prim_attr('ignore_nan', True)
        # Two separate input names; the single string 'x1, x2' would register
        # one input with that literal name.
        self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
class Eig(Primitive):
"""
Computes the eigenvalues and eigenvectors of a square matrix(batch square matrices).

View File

@ -321,6 +321,15 @@ class DiagNet(nn.Cell):
return x - self.diag(self.fill(mstype.float32, (3,), 1.0))
class FmaxFunc(nn.Cell):
    """Cell that forwards its two inputs to the functional `fmax` interface."""

    def __init__(self):
        super().__init__()
        self.fmax_ = ops.function.math_func.fmax

    def construct(self, x1, x2):
        result = self.fmax_(x1, x2)
        return result
class NetWithLossCumSum(nn.Cell):
""" NetWithLossCumSum definition """
@ -954,6 +963,11 @@ raise_set = [
'block': Zeta(),
'desc_inputs': [Tensor(np.array([1, 1, 1, 1], np.float32)),
Tensor([0.5, 0.5, 0.5, 0.5], mstype.float32)]}),
('Fmax', {
'block': FmaxFunc(),
'desc_inputs': [Tensor(np.array([1.0, 2.0, 3.0], np.float32)),
Tensor(np.array([2.0, 1.0, 4.0], np.float32))],
'desc_bprop': [Tensor(np.array([1.0, 2.0, 3.0], np.float32))]}),
('Lcm', {
'block': LcmFunc(),
'desc_inputs': [Tensor(np.array([2, 5, 8]).astype(np.int32)),

View File

@ -56,6 +56,7 @@ from mindspore.ops.operations.math_ops import MatrixLogarithm
from mindspore.ops.operations.math_ops import CholeskySolve
from mindspore.ops.operations.math_ops import InplaceIndexAdd
from mindspore.ops.operations.math_ops import NextAfter
from mindspore.ops.operations.math_ops import Fmax
from mindspore.ops.operations.math_ops import ComplexAbs
from mindspore.ops.operations.math_ops import Orgqr
from mindspore.ops.operations.math_ops import CompareAndBitpack
@ -2601,6 +2602,11 @@ test_case_math_ops = [
'desc_inputs': [Tensor(np.array([[2, 2, 3]]).astype(np.float32)),
Tensor(np.array([[2, 2, 3]]).astype(np.float32))],
'desc_bprop': [Tensor(np.array([[2, 2, 3]]).astype(np.float32))]}),
('Fmax', {
'block': Fmax(),
'desc_inputs': [Tensor(np.array([2, 2, 3]).astype(np.float32)),
Tensor(np.array([2, 2, 3]).astype(np.float32))],
'desc_bprop': [Tensor(np.array([2, 2, 3]).astype(np.float32))]}),
('Trace', {
'block': Trace(),
'desc_inputs': [Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).astype(np.float32))],