forked from mindspore-Ecosystem/mindspore
log1p & softplus infer
This commit is contained in:
parent
57ef41af0f
commit
91f35d925f
|
@ -271,6 +271,7 @@ inline const PrimitivePtr kPrimLrn = std::make_shared<Primitive>("LRN");
|
||||||
inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
|
inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
|
||||||
inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
|
inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
|
||||||
inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
|
inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
|
||||||
|
inline const PrimitivePtr kPrimLog1p = std::make_shared<Primitive>("Log1p");
|
||||||
inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
|
inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
|
||||||
inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
|
inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
|
||||||
inline const PrimitivePtr kPrimDropoutGrad = std::make_shared<Primitive>("DropoutGrad");
|
inline const PrimitivePtr kPrimDropoutGrad = std::make_shared<Primitive>("DropoutGrad");
|
||||||
|
@ -287,6 +288,7 @@ inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("Elu");
|
||||||
inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
|
inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
|
||||||
inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
|
inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
|
||||||
inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU");
|
inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU");
|
||||||
|
inline const PrimitivePtr kPrimSoftplus = std::make_shared<Primitive>("Softplus");
|
||||||
inline const PrimitivePtr kPrimZeros = std::make_shared<Primitive>("Zeros");
|
inline const PrimitivePtr kPrimZeros = std::make_shared<Primitive>("Zeros");
|
||||||
inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
|
inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
|
||||||
inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike");
|
inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike");
|
||||||
|
|
|
@ -23,13 +23,12 @@ namespace {
|
||||||
abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
|
abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto broad_cast_to = primitive->cast<PrimBroadcastToPtr>();
|
auto prim_name = primitive->name();
|
||||||
MS_EXCEPTION_IF_NULL(broad_cast_to);
|
|
||||||
auto prim_name = broad_cast_to->name();
|
|
||||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||||
auto input_x = broad_cast_to->get_shape();
|
auto value_ptr = primitive->GetAttr(kShape);
|
||||||
|
auto input_x = GetValue<std::vector<int64_t>>(value_ptr);
|
||||||
int64_t outer_dim_offset = input_x.size() - x_shape.size();
|
int64_t outer_dim_offset = input_x.size() - x_shape.size();
|
||||||
CheckAndConvertUtils::Check("x shape", x_shape, kLessEqual, "input_x", input_x, prim_name);
|
CheckAndConvertUtils::Check("x shape", x_shape.size(), kLessEqual, "input_x", input_x.size(), prim_name);
|
||||||
bool flag = true;
|
bool flag = true;
|
||||||
if (input_x.end() == find(input_x.begin(), input_x.end(), -1)) {
|
if (input_x.end() == find(input_x.begin(), input_x.end(), -1)) {
|
||||||
flag = false;
|
flag = false;
|
||||||
|
@ -49,7 +48,6 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::reverse(input_x.begin(), input_x.end());
|
|
||||||
return std::make_shared<abstract::Shape>(input_x);
|
return std::make_shared<abstract::Shape>(input_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,8 +76,8 @@ std::vector<int64_t> BroadcastTo::get_shape() const {
|
||||||
AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args),
|
return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args),
|
||||||
BroadcastToInferShape(primitive, input_args)->shape());
|
BroadcastToInferShape(primitive, input_args));
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameBroadcastTo, BroadcastTo);
|
REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastTo, prim::kPrimBroadcastTo, BroadcastToInfer, nullptr, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -69,6 +69,6 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
||||||
GatherDInferShape(primitive, input_args));
|
GatherDInferShape(primitive, input_args));
|
||||||
return abs;
|
return abs;
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, false);
|
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ops/log1p.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
// log1p
|
||||||
|
namespace {
|
||||||
|
abstract::ShapePtr Log1pInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||||
|
auto in_shape = shape_map[kShape];
|
||||||
|
auto min_shape = shape_map[kMinShape];
|
||||||
|
auto max_shape = shape_map[kMaxShape];
|
||||||
|
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr Log1pInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
auto prim_name = prim->name();
|
||||||
|
// check
|
||||||
|
std::set<TypePtr> valid_index_types = {kFloat16, kFloat32};
|
||||||
|
auto x_type =
|
||||||
|
CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name);
|
||||||
|
return x_type;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
AbstractBasePtr Log1pInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto abs = std::make_shared<abstract::AbstractTensor>(Log1pInferType(primitive, input_args),
|
||||||
|
Log1pInferShape(primitive, input_args));
|
||||||
|
return abs;
|
||||||
|
}
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(Log1p, prim::kPrimLog1p, Log1pInfer, nullptr, true);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_LOG1P_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_LOG1P_H_
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "ops/primitive_c.h"
|
||||||
|
#include "abstract/abstract_value.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
constexpr auto kNameLog1p = "Log1p";
|
||||||
|
class Log1p : public PrimitiveC {
|
||||||
|
public:
|
||||||
|
Log1p() : PrimitiveC(kNameLog1p) { InitIOName({"x"}, {"y"}); }
|
||||||
|
~Log1p() = default;
|
||||||
|
MS_DECLARE_PARENT(Log1p, PrimitiveC);
|
||||||
|
void Init() {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_LOG1P_H_
|
|
@ -0,0 +1,58 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ops/softplus.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
// softplus
|
||||||
|
namespace {
|
||||||
|
abstract::ShapePtr SoftplusInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||||
|
auto in_shape = shape_map[kShape];
|
||||||
|
auto min_shape = shape_map[kMinShape];
|
||||||
|
auto max_shape = shape_map[kMaxShape];
|
||||||
|
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr SoftplusInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
auto prim_name = prim->name();
|
||||||
|
// check
|
||||||
|
std::set<TypePtr> valid_index_types = {kFloat16, kFloat32, kFloat64};
|
||||||
|
auto x_type =
|
||||||
|
CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name);
|
||||||
|
return x_type;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
AbstractBasePtr SoftplusInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto abs = std::make_shared<abstract::AbstractTensor>(SoftplusInferType(primitive, input_args),
|
||||||
|
SoftplusInferShape(primitive, input_args));
|
||||||
|
return abs;
|
||||||
|
}
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(Softplus, prim::kPrimSoftplus, SoftplusInfer, nullptr, true);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_SOFTPLUS_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_SOFTPLUS_H_
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "ops/primitive_c.h"
|
||||||
|
#include "abstract/abstract_value.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
constexpr auto kNameSoftplus = "Softplus";
|
||||||
|
class Softplus : public PrimitiveC {
|
||||||
|
public:
|
||||||
|
Softplus() : PrimitiveC(kNameSoftplus) { InitIOName({"x"}, {"output"}); }
|
||||||
|
~Softplus() = default;
|
||||||
|
MS_DECLARE_PARENT(Softplus, PrimitiveC);
|
||||||
|
void Init() {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_SOFTPLUS_H_
|
|
@ -1306,7 +1306,7 @@ class Ones(PrimitiveWithInfer):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Zeros(PrimitiveWithInfer):
|
class Zeros(Primitive):
|
||||||
r"""
|
r"""
|
||||||
Creates a tensor filled with value zeros.
|
Creates a tensor filled with value zeros.
|
||||||
|
|
||||||
|
@ -4570,7 +4570,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
|
||||||
return out_shape
|
return out_shape
|
||||||
|
|
||||||
|
|
||||||
class BroadcastTo(PrimitiveWithInfer):
|
class BroadcastTo(Primitive):
|
||||||
"""
|
"""
|
||||||
Broadcasts input tensor to a given shape.
|
Broadcasts input tensor to a given shape.
|
||||||
|
|
||||||
|
@ -4629,34 +4629,6 @@ class BroadcastTo(PrimitiveWithInfer):
|
||||||
validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
|
validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
|
||||||
validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
|
|
||||||
|
|
||||||
reversed_x_shape = tuple(reversed(x_shape))
|
|
||||||
reversed_filtered_target = []
|
|
||||||
for i, v in enumerate(tuple(reversed(self.shape))):
|
|
||||||
if v == -1:
|
|
||||||
if i >= len(reversed_x_shape):
|
|
||||||
raise ValueError("-1 is not valid in a leading, non-existing dimension")
|
|
||||||
|
|
||||||
reversed_filtered_target.append(reversed_x_shape[i])
|
|
||||||
else:
|
|
||||||
reversed_filtered_target.append(v)
|
|
||||||
|
|
||||||
self.shape = tuple(reversed(reversed_filtered_target))
|
|
||||||
self.add_prim_attr('shape', self.shape)
|
|
||||||
|
|
||||||
for i, v in enumerate(reversed_x_shape):
|
|
||||||
if v not in (reversed_filtered_target[i], 1):
|
|
||||||
raise ValueError(f"Not supported shapes for broadcast, "
|
|
||||||
f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
|
|
||||||
|
|
||||||
return self.shape
|
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
|
||||||
validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
|
|
||||||
return x_dtype
|
|
||||||
|
|
||||||
|
|
||||||
class Meshgrid(PrimitiveWithInfer):
|
class Meshgrid(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
|
@ -5121,7 +5093,7 @@ class EmbeddingLookup(PrimitiveWithCheck):
|
||||||
raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp))
|
raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp))
|
||||||
|
|
||||||
|
|
||||||
class GatherD(PrimitiveWithInfer):
|
class GatherD(Primitive):
|
||||||
"""
|
"""
|
||||||
Gathers values along an axis specified by dim.
|
Gathers values along an axis specified by dim.
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ from mindspore import context
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
from ..._checkparam import Rel
|
from ..._checkparam import Rel
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ..primitive import prim_attr_register, PrimitiveWithInfer
|
from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
|
||||||
|
|
||||||
|
|
||||||
def _check_mode(class_name):
|
def _check_mode(class_name):
|
||||||
|
@ -50,7 +50,7 @@ def _check_summary_param(name, value, class_name):
|
||||||
SUMMARY_RETURN_VALUE = {'dtype': mstype.int32, 'shape': [1], 'value': None}
|
SUMMARY_RETURN_VALUE = {'dtype': mstype.int32, 'shape': [1], 'value': None}
|
||||||
|
|
||||||
|
|
||||||
class ScalarSummary(PrimitiveWithInfer):
|
class ScalarSummary(Primitive):
|
||||||
"""
|
"""
|
||||||
Outputs a scalar to a protocol buffer through a scalar summary operator.
|
Outputs a scalar to a protocol buffer through a scalar summary operator.
|
||||||
|
|
||||||
|
@ -141,7 +141,7 @@ class ImageSummary(PrimitiveWithInfer):
|
||||||
return SUMMARY_RETURN_VALUE
|
return SUMMARY_RETURN_VALUE
|
||||||
|
|
||||||
|
|
||||||
class TensorSummary(PrimitiveWithInfer):
|
class TensorSummary(Primitive):
|
||||||
"""
|
"""
|
||||||
Outputs a tensor to a protocol buffer through a tensor summary operator.
|
Outputs a tensor to a protocol buffer through a tensor summary operator.
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ from ...common import dtype as mstype
|
||||||
from ...common.tensor import Tensor
|
from ...common.tensor import Tensor
|
||||||
from ...common._decorator import deprecated
|
from ...common._decorator import deprecated
|
||||||
from .._utils import get_broadcast_shape
|
from .._utils import get_broadcast_shape
|
||||||
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
|
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
|
||||||
|
|
||||||
|
|
||||||
def _infer_shape_reduce(x, axis, keep_dims, prim_name):
|
def _infer_shape_reduce(x, axis, keep_dims, prim_name):
|
||||||
|
@ -1873,7 +1873,7 @@ class Log(PrimitiveWithInfer):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Log1p(PrimitiveWithInfer):
|
class Log1p(Primitive):
|
||||||
"""
|
"""
|
||||||
Returns the natural logarithm of one plus the input tensor element-wise.
|
Returns the natural logarithm of one plus the input tensor element-wise.
|
||||||
|
|
||||||
|
@ -1901,14 +1901,6 @@ class Log1p(PrimitiveWithInfer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
|
||||||
return x_shape
|
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
|
||||||
validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
|
|
||||||
validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
|
|
||||||
return x_dtype
|
|
||||||
|
|
||||||
|
|
||||||
class Erf(PrimitiveWithInfer):
|
class Erf(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
|
|
|
@ -230,7 +230,7 @@ class LogSoftmax(PrimitiveWithInfer):
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
class Softplus(PrimitiveWithInfer):
|
class Softplus(Primitive):
|
||||||
r"""
|
r"""
|
||||||
Softplus activation function.
|
Softplus activation function.
|
||||||
|
|
||||||
|
@ -267,13 +267,6 @@ class Softplus(PrimitiveWithInfer):
|
||||||
"""Initialize Softplus"""
|
"""Initialize Softplus"""
|
||||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
|
||||||
return x_shape
|
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
|
||||||
validator.check_tensor_dtype_valid('x', x_dtype, mstype.float_type, self.name)
|
|
||||||
return x_dtype
|
|
||||||
|
|
||||||
|
|
||||||
class Softsign(PrimitiveWithInfer):
|
class Softsign(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
|
|
Loading…
Reference in New Issue