diff --git a/mindspore/core/ops/rpc_recv.cc b/mindspore/core/ops/rpc_recv.cc index 58594d6f6a0..5e5da33260b 100644 --- a/mindspore/core/ops/rpc_recv.cc +++ b/mindspore/core/ops/rpc_recv.cc @@ -20,6 +20,7 @@ #include "abstract/ops/primitive_infer_map.h" #include "ir/anf.h" #include "ops/core_ops.h" +#include "ops/op_name.h" #include "ops/primitive_c.h" #include "utils/log_adapter.h" #include "mindapi/src/helper.h" @@ -29,9 +30,16 @@ namespace ops { MIND_API_OPERATOR_IMPL(RpcRecv, BaseOperator); AbstractBasePtr RpcRecvInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, const std::vector &input_args) { - abstract::AbstractTuplePtr rpc_send_abs = std::make_shared(input_args); - MS_EXCEPTION_IF_NULL(rpc_send_abs); - return rpc_send_abs; + if (input_args.empty()) { + MS_LOG(EXCEPTION) << "The input size of RpcRecv is 0."; + } + if (input_args.size() == static_cast(kDim1)) { + return input_args[kInputIndex0]; + } else { + abstract::AbstractTuplePtr rpc_recv_abs = std::make_shared(input_args); + MS_EXCEPTION_IF_NULL(rpc_recv_abs); + return rpc_recv_abs; + } } REGISTER_PRIMITIVE_EVAL_IMPL(RpcRecv, prim::kPrimRpcRecv, RpcRecvInfer, nullptr, true); } // namespace ops