!49808 change the tensor value type from int32 to int64
Merge pull request !49808 from maoyaomin/mym_fix_step
This commit is contained in:
commit
fd3bc689af
|
@ -543,7 +543,7 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::Kerne
|
|||
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
|
||||
StreamSwitchKind kind) const {
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
|
||||
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
|
||||
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt64});
|
||||
auto typeNone_abstract = std::make_shared<abstract::AbstractNone>();
|
||||
auto stream_switch = std::make_shared<Primitive>(kStreamSwitchOpName);
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
|
@ -567,7 +567,7 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::Kerne
|
|||
ValuePtr cond = MakeValue(condition);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app);
|
||||
// set attr:data_type
|
||||
int data_type = static_cast<int>(RT_SWITCH_INT32);
|
||||
int data_type = static_cast<int>(RT_SWITCH_INT64);
|
||||
ValuePtr dt = MakeValue(data_type);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app);
|
||||
// set distinction label and graph id
|
||||
|
@ -658,9 +658,9 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP(const std::shared_ptr<session::K
|
|||
bool cur_loop) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
|
||||
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
|
||||
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt64});
|
||||
selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32});
|
||||
selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt64});
|
||||
// AssignAdd
|
||||
auto assign_add = std::make_shared<Primitive>(kAssignAddOpName);
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
|
@ -1081,11 +1081,11 @@ void KernelAdjust::InsertDynamicLossScaleCheckOperations(const std::shared_ptr<s
|
|||
}
|
||||
|
||||
// device loop control
|
||||
std::shared_ptr<Tensor> KernelAdjust::CreateTensor(int32_t initial_value) const {
|
||||
std::shared_ptr<Tensor> KernelAdjust::CreateTensor(int64_t initial_value) const {
|
||||
ShapeVector shp = {1};
|
||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
|
||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kInt64->type_id(), shp);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
auto val = static_cast<int32_t *>(tensor->data_c());
|
||||
auto val = static_cast<int64_t *>(tensor->data_c());
|
||||
MS_EXCEPTION_IF_NULL(val);
|
||||
*val = initial_value;
|
||||
return tensor;
|
||||
|
@ -1094,7 +1094,7 @@ std::shared_ptr<Tensor> KernelAdjust::CreateTensor(int32_t initial_value) const
|
|||
std::shared_ptr<Parameter> KernelAdjust::CreateParameter(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||
const string parameter_name) const {
|
||||
ShapeVector shp = {1};
|
||||
tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
|
||||
tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(kInt64->type_id(), shp);
|
||||
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||
mindspore::abstract::AbstractBasePtr parameter_abstract_ptr = tensor_ptr->ToAbstract();
|
||||
if (parameter_abstract_ptr == nullptr) {
|
||||
|
@ -1131,7 +1131,7 @@ void KernelAdjust::InsertDeviceLoopCtrl(const std::shared_ptr<session::KernelGra
|
|||
device_loop_ctrl_params[kConstOneName] = CreateParameter(kernel_graph_ptr, kConstOneName);
|
||||
|
||||
// constant loop num in epoch tensor
|
||||
int32_t initial_value = 0;
|
||||
int64_t initial_value = 0;
|
||||
if (NeedLoopSink()) {
|
||||
// iter_num minus one because the device side counts from 0
|
||||
initial_value = SizeToInt(LongToSize(ConfigManager::GetInstance().iter_num() - 1));
|
||||
|
@ -1212,7 +1212,7 @@ void KernelAdjust::AssignLoopCtrlMemory(const session::KernelGraph &kernel_graph
|
|||
}
|
||||
|
||||
void KernelAdjust::SetDeviceLoopCtrlTensor(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||
const std::string name, int32_t value) const {
|
||||
const std::string name, int64_t value) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
auto device_loop_control_tensors = kernel_graph_ptr->device_loop_control_tensors();
|
||||
if (device_loop_control_tensors.count(name) == 0) {
|
||||
|
@ -1221,7 +1221,7 @@ void KernelAdjust::SetDeviceLoopCtrlTensor(const std::shared_ptr<session::Kernel
|
|||
}
|
||||
auto tensor = device_loop_control_tensors.at(name);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
auto *cur_val = static_cast<int32_t *>(tensor->data_c());
|
||||
auto *cur_val = static_cast<int64_t *>(tensor->data_c());
|
||||
MS_EXCEPTION_IF_NULL(cur_val);
|
||||
*cur_val = value;
|
||||
tensor->set_sync_status(kNeedSyncHostToDevice);
|
||||
|
|
|
@ -74,7 +74,7 @@ class KernelAdjust {
|
|||
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id) const;
|
||||
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id) const;
|
||||
void SetDeviceLoopCtrlTensor(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, const string name,
|
||||
int32_t value) const;
|
||||
int64_t value) const;
|
||||
|
||||
private:
|
||||
KernelAdjust() = default;
|
||||
|
@ -170,7 +170,7 @@ class KernelAdjust {
|
|||
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) const;
|
||||
void InsertDynamicLossScaleCheckOperations(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||
std::vector<AnfNodePtr> *dynamic_loss_scale_param_list) const;
|
||||
std::shared_ptr<Tensor> CreateTensor(int32_t initial_value) const;
|
||||
std::shared_ptr<Tensor> CreateTensor(int64_t initial_value) const;
|
||||
std::shared_ptr<Parameter> CreateParameter(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||
const string parameter_name) const;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue