forked from mindspore-Ecosystem/mindspore
log1p & softplus infer
This commit is contained in:
parent
57ef41af0f
commit
91f35d925f
|
@ -271,6 +271,7 @@ inline const PrimitivePtr kPrimLrn = std::make_shared<Primitive>("LRN");
|
|||
inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
|
||||
inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
|
||||
inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
|
||||
inline const PrimitivePtr kPrimLog1p = std::make_shared<Primitive>("Log1p");
|
||||
inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
|
||||
inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
|
||||
inline const PrimitivePtr kPrimDropoutGrad = std::make_shared<Primitive>("DropoutGrad");
|
||||
|
@ -287,6 +288,7 @@ inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("Elu");
|
|||
inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
|
||||
inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
|
||||
inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU");
|
||||
inline const PrimitivePtr kPrimSoftplus = std::make_shared<Primitive>("Softplus");
|
||||
inline const PrimitivePtr kPrimZeros = std::make_shared<Primitive>("Zeros");
|
||||
inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
|
||||
inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike");
|
||||
|
|
|
@ -23,13 +23,12 @@ namespace {
|
|||
abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto broad_cast_to = primitive->cast<PrimBroadcastToPtr>();
|
||||
MS_EXCEPTION_IF_NULL(broad_cast_to);
|
||||
auto prim_name = broad_cast_to->name();
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto input_x = broad_cast_to->get_shape();
|
||||
auto value_ptr = primitive->GetAttr(kShape);
|
||||
auto input_x = GetValue<std::vector<int64_t>>(value_ptr);
|
||||
int64_t outer_dim_offset = input_x.size() - x_shape.size();
|
||||
CheckAndConvertUtils::Check("x shape", x_shape, kLessEqual, "input_x", input_x, prim_name);
|
||||
CheckAndConvertUtils::Check("x shape", x_shape.size(), kLessEqual, "input_x", input_x.size(), prim_name);
|
||||
bool flag = true;
|
||||
if (input_x.end() == find(input_x.begin(), input_x.end(), -1)) {
|
||||
flag = false;
|
||||
|
@ -49,7 +48,6 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
}
|
||||
}
|
||||
std::reverse(input_x.begin(), input_x.end());
|
||||
return std::make_shared<abstract::Shape>(input_x);
|
||||
}
|
||||
|
||||
|
@ -78,8 +76,8 @@ std::vector<int64_t> BroadcastTo::get_shape() const {
|
|||
AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args),
|
||||
BroadcastToInferShape(primitive, input_args)->shape());
|
||||
BroadcastToInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameBroadcastTo, BroadcastTo);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastTo, prim::kPrimBroadcastTo, BroadcastToInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -69,6 +69,6 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
GatherDInferShape(primitive, input_args));
|
||||
return abs;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/log1p.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
// log1p
|
||||
namespace {
|
||||
abstract::ShapePtr Log1pInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto in_shape = shape_map[kShape];
|
||||
auto min_shape = shape_map[kMinShape];
|
||||
auto max_shape = shape_map[kMaxShape];
|
||||
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
|
||||
}
|
||||
|
||||
TypePtr Log1pInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
// check
|
||||
std::set<TypePtr> valid_index_types = {kFloat16, kFloat32};
|
||||
auto x_type =
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name);
|
||||
return x_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr Log1pInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto abs = std::make_shared<abstract::AbstractTensor>(Log1pInferType(primitive, input_args),
|
||||
Log1pInferShape(primitive, input_args));
|
||||
return abs;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Log1p, prim::kPrimLog1p, Log1pInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_LOG1P_H_
|
||||
#define MINDSPORE_CORE_OPS_LOG1P_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameLog1p = "Log1p";
|
||||
class Log1p : public PrimitiveC {
|
||||
public:
|
||||
Log1p() : PrimitiveC(kNameLog1p) { InitIOName({"x"}, {"y"}); }
|
||||
~Log1p() = default;
|
||||
MS_DECLARE_PARENT(Log1p, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_LOG1P_H_
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/softplus.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
// softplus
|
||||
namespace {
|
||||
abstract::ShapePtr SoftplusInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto in_shape = shape_map[kShape];
|
||||
auto min_shape = shape_map[kMinShape];
|
||||
auto max_shape = shape_map[kMaxShape];
|
||||
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
|
||||
}
|
||||
|
||||
TypePtr SoftplusInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
// check
|
||||
std::set<TypePtr> valid_index_types = {kFloat16, kFloat32, kFloat64};
|
||||
auto x_type =
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name);
|
||||
return x_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr SoftplusInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto abs = std::make_shared<abstract::AbstractTensor>(SoftplusInferType(primitive, input_args),
|
||||
SoftplusInferShape(primitive, input_args));
|
||||
return abs;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Softplus, prim::kPrimSoftplus, SoftplusInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SOFTPLUS_H_
|
||||
#define MINDSPORE_CORE_OPS_SOFTPLUS_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameSoftplus = "Softplus";
|
||||
class Softplus : public PrimitiveC {
|
||||
public:
|
||||
Softplus() : PrimitiveC(kNameSoftplus) { InitIOName({"x"}, {"output"}); }
|
||||
~Softplus() = default;
|
||||
MS_DECLARE_PARENT(Softplus, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SOFTPLUS_H_
|
|
@ -1306,7 +1306,7 @@ class Ones(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class Zeros(PrimitiveWithInfer):
|
||||
class Zeros(Primitive):
|
||||
r"""
|
||||
Creates a tensor filled with value zeros.
|
||||
|
||||
|
@ -4570,7 +4570,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
|
|||
return out_shape
|
||||
|
||||
|
||||
class BroadcastTo(PrimitiveWithInfer):
|
||||
class BroadcastTo(Primitive):
|
||||
"""
|
||||
Broadcasts input tensor to a given shape.
|
||||
|
||||
|
@ -4629,34 +4629,6 @@ class BroadcastTo(PrimitiveWithInfer):
|
|||
validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
|
||||
self.shape = shape
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
|
||||
|
||||
reversed_x_shape = tuple(reversed(x_shape))
|
||||
reversed_filtered_target = []
|
||||
for i, v in enumerate(tuple(reversed(self.shape))):
|
||||
if v == -1:
|
||||
if i >= len(reversed_x_shape):
|
||||
raise ValueError("-1 is not valid in a leading, non-existing dimension")
|
||||
|
||||
reversed_filtered_target.append(reversed_x_shape[i])
|
||||
else:
|
||||
reversed_filtered_target.append(v)
|
||||
|
||||
self.shape = tuple(reversed(reversed_filtered_target))
|
||||
self.add_prim_attr('shape', self.shape)
|
||||
|
||||
for i, v in enumerate(reversed_x_shape):
|
||||
if v not in (reversed_filtered_target[i], 1):
|
||||
raise ValueError(f"Not supported shapes for broadcast, "
|
||||
f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
|
||||
|
||||
return self.shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class Meshgrid(PrimitiveWithInfer):
|
||||
"""
|
||||
|
@ -5121,7 +5093,7 @@ class EmbeddingLookup(PrimitiveWithCheck):
|
|||
raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp))
|
||||
|
||||
|
||||
class GatherD(PrimitiveWithInfer):
|
||||
class GatherD(Primitive):
|
||||
"""
|
||||
Gathers values along an axis specified by dim.
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from mindspore import context
|
|||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import prim_attr_register, PrimitiveWithInfer
|
||||
from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
|
||||
|
||||
|
||||
def _check_mode(class_name):
|
||||
|
@ -50,7 +50,7 @@ def _check_summary_param(name, value, class_name):
|
|||
SUMMARY_RETURN_VALUE = {'dtype': mstype.int32, 'shape': [1], 'value': None}
|
||||
|
||||
|
||||
class ScalarSummary(PrimitiveWithInfer):
|
||||
class ScalarSummary(Primitive):
|
||||
"""
|
||||
Outputs a scalar to a protocol buffer through a scalar summary operator.
|
||||
|
||||
|
@ -141,7 +141,7 @@ class ImageSummary(PrimitiveWithInfer):
|
|||
return SUMMARY_RETURN_VALUE
|
||||
|
||||
|
||||
class TensorSummary(PrimitiveWithInfer):
|
||||
class TensorSummary(Primitive):
|
||||
"""
|
||||
Outputs a tensor to a protocol buffer through a tensor summary operator.
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ from ...common import dtype as mstype
|
|||
from ...common.tensor import Tensor
|
||||
from ...common._decorator import deprecated
|
||||
from .._utils import get_broadcast_shape
|
||||
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
|
||||
|
||||
|
||||
def _infer_shape_reduce(x, axis, keep_dims, prim_name):
|
||||
|
@ -1873,7 +1873,7 @@ class Log(PrimitiveWithInfer):
|
|||
return None
|
||||
|
||||
|
||||
class Log1p(PrimitiveWithInfer):
|
||||
class Log1p(Primitive):
|
||||
"""
|
||||
Returns the natural logarithm of one plus the input tensor element-wise.
|
||||
|
||||
|
@ -1901,14 +1901,6 @@ class Log1p(PrimitiveWithInfer):
|
|||
def __init__(self):
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
|
||||
validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class Erf(PrimitiveWithInfer):
|
||||
r"""
|
||||
|
|
|
@ -230,7 +230,7 @@ class LogSoftmax(PrimitiveWithInfer):
|
|||
return logits
|
||||
|
||||
|
||||
class Softplus(PrimitiveWithInfer):
|
||||
class Softplus(Primitive):
|
||||
r"""
|
||||
Softplus activation function.
|
||||
|
||||
|
@ -267,13 +267,6 @@ class Softplus(PrimitiveWithInfer):
|
|||
"""Initialize Softplus"""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_dtype_valid('x', x_dtype, mstype.float_type, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class Softsign(PrimitiveWithInfer):
|
||||
r"""
|
||||
|
|
Loading…
Reference in New Issue