diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc index 1c5802ebf46..40463f4d486 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc @@ -208,7 +208,8 @@ ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) { } std::vector OneHotInfo::GenerateOpStrategies(int64_t stage_id) { - Shapes splittable_inputs = {{1, 1}, {}, {}}; + Shape input0_split(outputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split, {}, {}}; std::vector sp_vector; if (inputs_shape_.size() != 3) { MS_LOG(EXCEPTION) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size();