!49808 change the tensor value type from int32 to int64

Merge pull request !49808 from maoyaomin/mym_fix_step
This commit is contained in:
i-robot 2023-03-08 01:51:23 +00:00 committed by Gitee
commit fd3bc689af
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 13 additions and 13 deletions

View File

@@ -543,7 +543,7 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::Kerne
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
StreamSwitchKind kind) const {
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt64});
auto typeNone_abstract = std::make_shared<abstract::AbstractNone>();
auto stream_switch = std::make_shared<Primitive>(kStreamSwitchOpName);
std::vector<AnfNodePtr> inputs;
@@ -567,7 +567,7 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::Kerne
ValuePtr cond = MakeValue(condition);
common::AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app);
// set attr:data_type
int data_type = static_cast<int>(RT_SWITCH_INT32);
int data_type = static_cast<int>(RT_SWITCH_INT64);
ValuePtr dt = MakeValue(data_type);
common::AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app);
// set distinction label and graph id
@@ -658,9 +658,9 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP(const std::shared_ptr<session::K
bool cur_loop) const {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt64});
selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32});
selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt64});
// AssignAdd
auto assign_add = std::make_shared<Primitive>(kAssignAddOpName);
std::vector<AnfNodePtr> inputs;
@@ -1081,11 +1081,11 @@ void KernelAdjust::InsertDynamicLossScaleCheckOperations(const std::shared_ptr<s
}
// device loop control
std::shared_ptr<Tensor> KernelAdjust::CreateTensor(int32_t initial_value) const {
std::shared_ptr<Tensor> KernelAdjust::CreateTensor(int64_t initial_value) const {
ShapeVector shp = {1};
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kInt64->type_id(), shp);
MS_EXCEPTION_IF_NULL(tensor);
auto val = static_cast<int32_t *>(tensor->data_c());
auto val = static_cast<int64_t *>(tensor->data_c());
MS_EXCEPTION_IF_NULL(val);
*val = initial_value;
return tensor;
@@ -1094,7 +1094,7 @@ std::shared_ptr<Tensor> KernelAdjust::CreateTensor(int32_t initial_value) const
std::shared_ptr<Parameter> KernelAdjust::CreateParameter(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
const string parameter_name) const {
ShapeVector shp = {1};
tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(kInt64->type_id(), shp);
MS_EXCEPTION_IF_NULL(tensor_ptr);
mindspore::abstract::AbstractBasePtr parameter_abstract_ptr = tensor_ptr->ToAbstract();
if (parameter_abstract_ptr == nullptr) {
@@ -1131,7 +1131,7 @@ void KernelAdjust::InsertDeviceLoopCtrl(const std::shared_ptr<session::KernelGra
device_loop_ctrl_params[kConstOneName] = CreateParameter(kernel_graph_ptr, kConstOneName);
// constant loop num in epoch tensor
int32_t initial_value = 0;
int64_t initial_value = 0;
if (NeedLoopSink()) {
// iter_num minus one because the device side counts from 0
initial_value = SizeToInt(LongToSize(ConfigManager::GetInstance().iter_num() - 1));
@@ -1212,7 +1212,7 @@ void KernelAdjust::AssignLoopCtrlMemory(const session::KernelGraph &kernel_graph
}
void KernelAdjust::SetDeviceLoopCtrlTensor(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
const std::string name, int32_t value) const {
const std::string name, int64_t value) const {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
auto device_loop_control_tensors = kernel_graph_ptr->device_loop_control_tensors();
if (device_loop_control_tensors.count(name) == 0) {
@@ -1221,7 +1221,7 @@ void KernelAdjust::SetDeviceLoopCtrlTensor(const std::shared_ptr<session::Kernel
}
auto tensor = device_loop_control_tensors.at(name);
MS_EXCEPTION_IF_NULL(tensor);
auto *cur_val = static_cast<int32_t *>(tensor->data_c());
auto *cur_val = static_cast<int64_t *>(tensor->data_c());
MS_EXCEPTION_IF_NULL(cur_val);
*cur_val = value;
tensor->set_sync_status(kNeedSyncHostToDevice);

View File

@@ -74,7 +74,7 @@ class KernelAdjust {
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id) const;
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id) const;
void SetDeviceLoopCtrlTensor(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, const string name,
int32_t value) const;
int64_t value) const;
private:
KernelAdjust() = default;
@@ -170,7 +170,7 @@ class KernelAdjust {
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) const;
void InsertDynamicLossScaleCheckOperations(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
std::vector<AnfNodePtr> *dynamic_loss_scale_param_list) const;
std::shared_ptr<Tensor> CreateTensor(int32_t initial_value) const;
std::shared_ptr<Tensor> CreateTensor(int64_t initial_value) const;
std::shared_ptr<Parameter> CreateParameter(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
const string parameter_name) const;
};