forked from mindspore-Ecosystem/mindspore
Add softmax infer
This commit is contained in:
parent
3a761536f9
commit
85ac5c9c95
|
@ -46,15 +46,24 @@ void Softmax::Init(const int64_t axis) {
|
|||
|
||||
abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto Softmax_prim = primitive->cast<PrimSoftmaxPtr>();
|
||||
MS_EXCEPTION_IF_NULL(Softmax_prim);
|
||||
auto op_name = Softmax_prim->name();
|
||||
auto axis = Softmax_prim->get_axis();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name);
|
||||
auto op_name = primitive->name();
|
||||
auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
|
||||
(void)CheckAndConvertUtils::CheckValue<size_t>("length of axis", axis.size(), kGreaterEqual, 1, op_name);
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
if (shape_map.empty()) {
|
||||
// Scalar input, has no shape
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>());
|
||||
}
|
||||
auto in_shape = shape_map[kShape];
|
||||
auto min_shape = shape_map[kMinShape];
|
||||
auto max_shape = shape_map[kMaxShape];
|
||||
auto rank = SizeToLong(in_shape.size());
|
||||
for (auto &item : axis) {
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("axis", item, kIncludeLeft, {-rank, rank}, op_name);
|
||||
}
|
||||
if (min_shape.size() != 0 && max_shape.size() != 0) {
|
||||
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
|
@ -71,6 +80,7 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
return std::make_shared<abstract::AbstractTensor>(SoftMaxInferType(primitive, input_args),
|
||||
SoftMaxInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer, nullptr, true);
|
||||
REGISTER_PRIMITIVE_C(kNameSoftmax, Softmax);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -177,17 +177,6 @@ class Softmax(PrimitiveWithInfer):
|
|||
for item in self.axis:
|
||||
validator.check_value_type("item of axis", item, [int], self.name)
|
||||
|
||||
def infer_shape(self, logits):
|
||||
validator.check_int(len(self.axis), 1, Rel.GE, "length of axis", self.name)
|
||||
rank = len(logits)
|
||||
for axis_v in self.axis:
|
||||
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
|
||||
return logits
|
||||
|
||||
def infer_dtype(self, logits):
|
||||
validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name)
|
||||
return logits
|
||||
|
||||
|
||||
class LogSoftmax(PrimitiveWithInfer):
|
||||
r"""
|
||||
|
|
Loading…
Reference in New Issue