log1p & softplus infer

simson 2021-04-12 10:22:03 +08:00
parent 57ef41af0f
commit 91f35d925f
11 changed files with 216 additions and 61 deletions

View File

@@ -271,6 +271,7 @@ inline const PrimitivePtr kPrimLrn = std::make_shared<Primitive>("LRN");
inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
+inline const PrimitivePtr kPrimLog1p = std::make_shared<Primitive>("Log1p");
inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
inline const PrimitivePtr kPrimDropoutGrad = std::make_shared<Primitive>("DropoutGrad");
@@ -287,6 +288,7 @@ inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("Elu");
inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU");
+inline const PrimitivePtr kPrimSoftplus = std::make_shared<Primitive>("Softplus");
inline const PrimitivePtr kPrimZeros = std::make_shared<Primitive>("Zeros");
inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike");

View File

@@ -23,13 +23,12 @@ namespace {
abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
-  auto broad_cast_to = primitive->cast<PrimBroadcastToPtr>();
-  MS_EXCEPTION_IF_NULL(broad_cast_to);
-  auto prim_name = broad_cast_to->name();
+  auto prim_name = primitive->name();
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
-  auto input_x = broad_cast_to->get_shape();
+  auto value_ptr = primitive->GetAttr(kShape);
+  auto input_x = GetValue<std::vector<int64_t>>(value_ptr);
  int64_t outer_dim_offset = input_x.size() - x_shape.size();
-  CheckAndConvertUtils::Check("x shape", x_shape, kLessEqual, "input_x", input_x, prim_name);
+  CheckAndConvertUtils::Check("x shape", x_shape.size(), kLessEqual, "input_x", input_x.size(), prim_name);
  bool flag = true;
  if (input_x.end() == find(input_x.begin(), input_x.end(), -1)) {
    flag = false;
@@ -49,7 +48,6 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
      }
    }
  }
-  std::reverse(input_x.begin(), input_x.end());
  return std::make_shared<abstract::Shape>(input_x);
}
@@ -78,8 +76,8 @@ std::vector<int64_t> BroadcastTo::get_shape() const {
AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
  return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args),
-                                                    BroadcastToInferShape(primitive, input_args)->shape());
+                                                    BroadcastToInferShape(primitive, input_args));
}
-REGISTER_PRIMITIVE_C(kNameBroadcastTo, BroadcastTo);
+REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastTo, prim::kPrimBroadcastTo, BroadcastToInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
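The refactor above reads the target shape from the primitive's kShape attribute instead of casting to a concrete PrimBroadcastToPtr, and passes the abstract::Shape object through unmodified, so any min/max shape information for dynamic shapes survives. A minimal sketch of the Python op this infer backs (assuming a MindSpore 1.x install and PyNative mode; a -1 entry keeps the input's size at that position):

```python
import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

x = Tensor(np.array([[1.0], [2.0]]), mindspore.float32)  # shape (2, 1)

# Explicit target shape: (2, 1) broadcasts to (2, 3).
print(P.BroadcastTo((2, 3))(x).shape)   # (2, 3)

# -1 keeps the corresponding input dimension (2 here), so the target
# resolves to (2, 3); a -1 in a leading, non-existing dim is rejected.
print(P.BroadcastTo((-1, 3))(x).shape)  # (2, 3)
```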

View File

@@ -69,6 +69,6 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
                                                GatherDInferShape(primitive, input_args));
  return abs;
}
-REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, false);
+REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
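The only change to GatherD is the final argument of REGISTER_PRIMITIVE_EVAL_IMPL flipping from false to true, which, as with the other ops in this commit, appears to switch the frontend over to the registered C++ infer. Caller-facing behavior is unchanged; a small usage sketch (values illustrative):

```python
import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)

# Gather along dim 1: out[i][j] = x[i][index[i][j]]
print(P.GatherD()(x, 1, index))  # [[1 1]
                                 #  [4 3]]
```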

View File

@@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/log1p.h"
#include <memory>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
// log1p
namespace {
abstract::ShapePtr Log1pInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
  auto in_shape = shape_map[kShape];
  auto min_shape = shape_map[kMinShape];
  auto max_shape = shape_map[kMaxShape];
  return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
}

TypePtr Log1pInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  // check that the input is a tensor of float16 or float32
  std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  auto x_type = CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim_name);
  return x_type;
}
} // namespace
AbstractBasePtr Log1pInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto abs = std::make_shared<abstract::AbstractTensor>(Log1pInferType(primitive, input_args),
                                                        Log1pInferShape(primitive, input_args));
  return abs;
}
REGISTER_PRIMITIVE_EVAL_IMPL(Log1p, prim::kPrimLog1p, Log1pInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
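With Log1pInfer registered above, Log1p's shape propagation and dtype check (float16/float32 only) now happen in C++ rather than in the Python infer_shape/infer_dtype methods removed later in this commit. A minimal sketch of the op itself:

```python
import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

x = Tensor(np.array([0.0, 1.0, 9.0]), mindspore.float32)
print(P.Log1p()(x))  # ln(1 + x) element-wise: [0. 0.6931472 2.3025851]
```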

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_LOG1P_H_
#define MINDSPORE_CORE_OPS_LOG1P_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameLog1p = "Log1p";
class Log1p : public PrimitiveC {
 public:
  Log1p() : PrimitiveC(kNameLog1p) { InitIOName({"x"}, {"y"}); }
  ~Log1p() = default;
  MS_DECLARE_PARENT(Log1p, PrimitiveC);
  void Init() {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_LOG1P_H_

View File

@@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/softplus.h"
#include <memory>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
// softplus
namespace {
abstract::ShapePtr SoftplusInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
  auto in_shape = shape_map[kShape];
  auto min_shape = shape_map[kMinShape];
  auto max_shape = shape_map[kMaxShape];
  return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
}

TypePtr SoftplusInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  // check that the input is a tensor of float16, float32 or float64
  std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
  auto x_type = CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim_name);
  return x_type;
}
} // namespace
AbstractBasePtr SoftplusInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto abs = std::make_shared<abstract::AbstractTensor>(SoftplusInferType(primitive, input_args),
                                                        SoftplusInferShape(primitive, input_args));
  return abs;
}
REGISTER_PRIMITIVE_EVAL_IMPL(Softplus, prim::kPrimSoftplus, SoftplusInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
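SoftplusInfer follows the same template as Log1pInfer, with float64 additionally accepted. A minimal usage sketch:

```python
import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

x = Tensor(np.array([-1.0, 0.0, 1.0]), mindspore.float32)
print(P.Softplus()(x))  # log(1 + exp(x)): [0.3132617 0.6931472 1.3132616]
```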

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SOFTPLUS_H_
#define MINDSPORE_CORE_OPS_SOFTPLUS_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSoftplus = "Softplus";
class Softplus : public PrimitiveC {
 public:
  Softplus() : PrimitiveC(kNameSoftplus) { InitIOName({"x"}, {"output"}); }
  ~Softplus() = default;
  MS_DECLARE_PARENT(Softplus, PrimitiveC);
  void Init() {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SOFTPLUS_H_

View File

@@ -1306,7 +1306,7 @@ class Ones(PrimitiveWithInfer):
        return out

-class Zeros(PrimitiveWithInfer):
+class Zeros(Primitive):
    r"""
    Creates a tensor filled with value zeros.
@@ -4570,7 +4570,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
        return out_shape

-class BroadcastTo(PrimitiveWithInfer):
+class BroadcastTo(Primitive):
    """
    Broadcasts input tensor to a given shape.
@@ -4629,34 +4629,6 @@ class BroadcastTo(PrimitiveWithInfer):
            validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
        self.shape = shape

-    def infer_shape(self, x_shape):
-        validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
-        reversed_x_shape = tuple(reversed(x_shape))
-        reversed_filtered_target = []
-        for i, v in enumerate(tuple(reversed(self.shape))):
-            if v == -1:
-                if i >= len(reversed_x_shape):
-                    raise ValueError("-1 is not valid in a leading, non-existing dimension")
-                reversed_filtered_target.append(reversed_x_shape[i])
-            else:
-                reversed_filtered_target.append(v)
-        self.shape = tuple(reversed(reversed_filtered_target))
-        self.add_prim_attr('shape', self.shape)
-        for i, v in enumerate(reversed_x_shape):
-            if v not in (reversed_filtered_target[i], 1):
-                raise ValueError(f"Not supported shapes for broadcast, "
-                                 f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
-        return self.shape
-
-    def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
-        return x_dtype

class Meshgrid(PrimitiveWithInfer):
    """
@@ -5121,7 +5093,7 @@ class EmbeddingLookup(PrimitiveWithCheck):
            raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp))

-class GatherD(PrimitiveWithInfer):
+class GatherD(Primitive):
    """
    Gathers values along an axis specified by dim.
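Zeros, BroadcastTo, and GatherD above drop PrimitiveWithInfer for plain Primitive, deleting their Python infer_shape/infer_dtype methods now that the C++ implementations are authoritative; callers should see no difference. For instance, with Zeros (a minimal sketch):

```python
import mindspore
from mindspore import context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

zeros = P.Zeros()
print(zeros((2, 2), mindspore.float32))  # 2x2 tensor of zeros
```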

View File

@@ -20,7 +20,7 @@ from mindspore import context
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
-from ..primitive import prim_attr_register, PrimitiveWithInfer
+from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer

def _check_mode(class_name):
@@ -50,7 +50,7 @@ def _check_summary_param(name, value, class_name):
SUMMARY_RETURN_VALUE = {'dtype': mstype.int32, 'shape': [1], 'value': None}

-class ScalarSummary(PrimitiveWithInfer):
+class ScalarSummary(Primitive):
    """
    Outputs a scalar to a protocol buffer through a scalar summary operator.
@@ -141,7 +141,7 @@ class ImageSummary(PrimitiveWithInfer):
        return SUMMARY_RETURN_VALUE

-class TensorSummary(PrimitiveWithInfer):
+class TensorSummary(Primitive):
    """
    Outputs a tensor to a protocol buffer through a tensor summary operator.
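ScalarSummary and TensorSummary likewise become plain Primitives. They are normally invoked inside a Cell.construct while a SummaryRecord is collecting; a sketch of the call shape (tag string first, tensor second; actual recording requires an active SummaryRecord):

```python
import mindspore
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.summary = P.ScalarSummary()

    def construct(self, x):
        self.summary("x", x)  # recorded when a SummaryRecord is active
        return x

print(Net()(Tensor(1.0, mindspore.float32)))
```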

View File

@@ -26,7 +26,7 @@ from ...common import dtype as mstype
from ...common.tensor import Tensor
from ...common._decorator import deprecated
from .._utils import get_broadcast_shape
-from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
+from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op

def _infer_shape_reduce(x, axis, keep_dims, prim_name):
@@ -1873,7 +1873,7 @@ class Log(PrimitiveWithInfer):
        return None

-class Log1p(PrimitiveWithInfer):
+class Log1p(Primitive):
    """
    Returns the natural logarithm of one plus the input tensor element-wise.
@@ -1901,14 +1901,6 @@ class Log1p(PrimitiveWithInfer):
    def __init__(self):
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

-    def infer_shape(self, x_shape):
-        return x_shape
-
-    def infer_dtype(self, x_dtype):
-        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
-        validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
-        return x_dtype

class Erf(PrimitiveWithInfer):
    r"""

View File

@@ -230,7 +230,7 @@ class LogSoftmax(PrimitiveWithInfer):
        return logits

-class Softplus(PrimitiveWithInfer):
+class Softplus(Primitive):
    r"""
    Softplus activation function.
@@ -267,13 +267,6 @@ class Softplus(PrimitiveWithInfer):
        """Initialize Softplus"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

-    def infer_shape(self, x_shape):
-        return x_shape
-
-    def infer_dtype(self, x_dtype):
-        validator.check_tensor_dtype_valid('x', x_dtype, mstype.float_type, self.name)
-        return x_dtype

class Softsign(PrimitiveWithInfer):
    r"""