Add softmax infer

This commit is contained in:
l00591931 2021-04-01 09:52:10 +08:00
parent 3a761536f9
commit 85ac5c9c95
2 changed files with 15 additions and 16 deletions

View File

@ -46,15 +46,24 @@ void Softmax::Init(const int64_t axis) {
abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto Softmax_prim = primitive->cast<PrimSoftmaxPtr>();
MS_EXCEPTION_IF_NULL(Softmax_prim);
auto op_name = Softmax_prim->name();
auto axis = Softmax_prim->get_axis();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name);
auto op_name = primitive->name();
auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
(void)CheckAndConvertUtils::CheckValue<size_t>("length of axis", axis.size(), kGreaterEqual, 1, op_name);
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
if (shape_map.empty()) {
// Scalar input, has no shape
return std::make_shared<abstract::Shape>(std::vector<int64_t>());
}
auto in_shape = shape_map[kShape];
auto min_shape = shape_map[kMinShape];
auto max_shape = shape_map[kMaxShape];
auto rank = SizeToLong(in_shape.size());
for (auto &item : axis) {
CheckAndConvertUtils::CheckInRange<int64_t>("axis", item, kIncludeLeft, {-rank, rank}, op_name);
}
if (min_shape.size() != 0 && max_shape.size() != 0) {
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
}
return std::make_shared<abstract::Shape>(in_shape);
}
@ -71,6 +80,7 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv
return std::make_shared<abstract::AbstractTensor>(SoftMaxInferType(primitive, input_args),
SoftMaxInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer, nullptr, true);
REGISTER_PRIMITIVE_C(kNameSoftmax, Softmax);
} // namespace ops
} // namespace mindspore

View File

@ -177,17 +177,6 @@ class Softmax(PrimitiveWithInfer):
for item in self.axis:
validator.check_value_type("item of axis", item, [int], self.name)
def infer_shape(self, logits):
validator.check_int(len(self.axis), 1, Rel.GE, "length of axis", self.name)
rank = len(logits)
for axis_v in self.axis:
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
return logits
def infer_dtype(self, logits):
validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name)
return logits
class LogSoftmax(PrimitiveWithInfer):
r"""