!37801 [bugfix] Fix failure to get the output address of the Gather op

Merge pull request !37801 from zyli2020/embedding_cache_unify_runtime
i-robot 2022-07-12 01:28:22 +00:00 committed by Gitee
commit 76f2309534
3 changed files with 152 additions and 110 deletions


@@ -18,127 +18,134 @@
namespace mindspore {
namespace kernel {
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
-EmbeddingLookupKernelMod, double, int)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, double, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
-EmbeddingLookupKernelMod, double, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, double, int64_t, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
-EmbeddingLookupKernelMod, float, int)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, float, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
-EmbeddingLookupKernelMod, float, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, float, int64_t, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
-EmbeddingLookupKernelMod, half, int)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, half, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
-EmbeddingLookupKernelMod, half, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, half, int64_t, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-EmbeddingLookupKernelMod, int, int)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, int, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
-EmbeddingLookupKernelMod, int, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, int, int64_t, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
-EmbeddingLookupKernelMod, int16_t, int)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, int16_t, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
-EmbeddingLookupKernelMod, int16_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, int16_t, int64_t, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
-EmbeddingLookupKernelMod, int8_t, int)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, int8_t, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
-EmbeddingLookupKernelMod, int8_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, int8_t, int64_t, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
-EmbeddingLookupKernelMod, uint8_t, int)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, uint8_t, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
-EmbeddingLookupKernelMod, uint8_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, uint8_t, int64_t, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
-EmbeddingLookupKernelMod, bool, int)
-MS_REG_GPU_KERNEL_TWO(
+EmbeddingLookupKernelMod, bool, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
-EmbeddingLookupKernelMod, bool, int64_t)
+EmbeddingLookupKernelMod, bool, int64_t, int64_t)
// dynamic shape
-MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
-KernelAttr()
-.AddInputAttr(kNumberTypeFloat64)
-.AddInputAttr(kNumberTypeInt32)
-.AddInputAttr(kNumberTypeInt64)
-.AddOutputAttr(kNumberTypeFloat64),
-EmbeddingLookupKernelMod, double, int)
-MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
-KernelAttr()
-.AddInputAttr(kNumberTypeFloat64)
-.AddInputAttr(kNumberTypeInt64)
-.AddInputAttr(kNumberTypeInt64)
-.AddOutputAttr(kNumberTypeFloat64),
-EmbeddingLookupKernelMod, double, int64_t)
-MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
-KernelAttr()
-.AddInputAttr(kNumberTypeFloat32)
-.AddInputAttr(kNumberTypeInt32)
-.AddInputAttr(kNumberTypeInt64)
-.AddOutputAttr(kNumberTypeFloat32),
-EmbeddingLookupKernelMod, float, int)
-MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
-KernelAttr()
-.AddInputAttr(kNumberTypeFloat32)
-.AddInputAttr(kNumberTypeInt64)
-.AddInputAttr(kNumberTypeInt64)
-.AddOutputAttr(kNumberTypeFloat32),
-EmbeddingLookupKernelMod, float, int64_t)
-MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
-KernelAttr()
-.AddInputAttr(kNumberTypeFloat16)
-.AddInputAttr(kNumberTypeInt32)
-.AddInputAttr(kNumberTypeInt64)
-.AddOutputAttr(kNumberTypeFloat16),
-EmbeddingLookupKernelMod, half, int)
-MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
-KernelAttr()
-.AddInputAttr(kNumberTypeFloat16)
-.AddInputAttr(kNumberTypeInt64)
-.AddInputAttr(kNumberTypeInt64)
-.AddOutputAttr(kNumberTypeFloat16),
-EmbeddingLookupKernelMod, half, int64_t)
-MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
-KernelAttr()
-.AddInputAttr(kNumberTypeBool)
-.AddInputAttr(kNumberTypeInt32)
-.AddInputAttr(kNumberTypeInt64)
-.AddOutputAttr(kNumberTypeBool),
-EmbeddingLookupKernelMod, bool, int)
-MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
-KernelAttr()
-.AddInputAttr(kNumberTypeBool)
-.AddInputAttr(kNumberTypeInt64)
-.AddInputAttr(kNumberTypeInt64)
-.AddOutputAttr(kNumberTypeBool),
-EmbeddingLookupKernelMod, bool, int64_t)
+MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
+KernelAttr()
+.AddInputAttr(kNumberTypeFloat64)
+.AddInputAttr(kNumberTypeInt32)
+.AddInputAttr(kNumberTypeInt64)
+.AddOutputAttr(kNumberTypeFloat64),
+EmbeddingLookupKernelMod, double, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
+KernelAttr()
+.AddInputAttr(kNumberTypeFloat64)
+.AddInputAttr(kNumberTypeInt64)
+.AddInputAttr(kNumberTypeInt64)
+.AddOutputAttr(kNumberTypeFloat64),
+EmbeddingLookupKernelMod, double, int64_t, int64_t)
+MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
+KernelAttr()
+.AddInputAttr(kNumberTypeFloat32)
+.AddInputAttr(kNumberTypeInt32)
+.AddInputAttr(kNumberTypeInt64)
+.AddOutputAttr(kNumberTypeFloat32),
+EmbeddingLookupKernelMod, float, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
+KernelAttr()
+.AddInputAttr(kNumberTypeFloat32)
+.AddInputAttr(kNumberTypeInt32)
+.AddInputAttr(kNumberTypeInt32)
+.AddOutputAttr(kNumberTypeFloat32),
+EmbeddingLookupKernelMod, float, int, int)
+MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
+KernelAttr()
+.AddInputAttr(kNumberTypeFloat32)
+.AddInputAttr(kNumberTypeInt64)
+.AddInputAttr(kNumberTypeInt64)
+.AddOutputAttr(kNumberTypeFloat32),
+EmbeddingLookupKernelMod, float, int64_t, int64_t)
+MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
+KernelAttr()
+.AddInputAttr(kNumberTypeFloat16)
+.AddInputAttr(kNumberTypeInt32)
+.AddInputAttr(kNumberTypeInt64)
+.AddOutputAttr(kNumberTypeFloat16),
+EmbeddingLookupKernelMod, half, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
+KernelAttr()
+.AddInputAttr(kNumberTypeFloat16)
+.AddInputAttr(kNumberTypeInt64)
+.AddInputAttr(kNumberTypeInt64)
+.AddOutputAttr(kNumberTypeFloat16),
+EmbeddingLookupKernelMod, half, int64_t, int64_t)
+MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
+KernelAttr()
+.AddInputAttr(kNumberTypeBool)
+.AddInputAttr(kNumberTypeInt32)
+.AddInputAttr(kNumberTypeInt64)
+.AddOutputAttr(kNumberTypeBool),
+EmbeddingLookupKernelMod, bool, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
+KernelAttr()
+.AddInputAttr(kNumberTypeBool)
+.AddInputAttr(kNumberTypeInt64)
+.AddInputAttr(kNumberTypeInt64)
+.AddOutputAttr(kNumberTypeBool),
+EmbeddingLookupKernelMod, bool, int64_t, int64_t)
// dynamic shape ends
} // namespace kernel
} // namespace mindspore
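
Each registration above now carries three template arguments: the value type T, the index type S, and the offset type G (int32_t or int64_t), matching the extra offset input of the dynamic-shape variant. For reference, a small host-side sketch of the lookup-with-offset semantics in plain C++ (illustrative names only, assuming the usual indices-minus-offset behaviour with zero-filled out-of-range rows; this is not the CUDA kernel itself):

#include <cstdint>
#include <vector>

// Host-side reference of EmbeddingLookup-with-offset semantics.
// T: value type, S: index type, G: offset type, mirroring the three
// template parameters of the registrations above. Hypothetical helper.
template <typename T, typename S, typename G>
std::vector<T> EmbeddingLookupReference(const std::vector<T> &table, size_t rows, size_t cols,
                                        const std::vector<S> &indices, G offset) {
  std::vector<T> out(indices.size() * cols, T{});
  for (size_t i = 0; i < indices.size(); ++i) {
    // Indices are shifted by 'offset'; rows outside the table stay zero.
    int64_t row = static_cast<int64_t>(indices[i]) - static_cast<int64_t>(offset);
    if (row < 0 || row >= static_cast<int64_t>(rows)) continue;
    for (size_t c = 0; c < cols; ++c) {
      out[i * cols + c] = table[static_cast<size_t>(row) * cols + c];
    }
  }
  return out;
}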


@@ -25,7 +25,7 @@
namespace mindspore {
namespace kernel {
-template <typename T, typename S>
+template <typename T, typename S, typename G>
class EmbeddingLookupKernelMod : public DeprecatedNativeGpuKernelMod {
public:
EmbeddingLookupKernelMod() { ResetResource(); }
@@ -41,17 +41,17 @@ class EmbeddingLookupKernelMod : public DeprecatedNativeGpuKernelMod {
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
if (is_dynamic_shape_) {
-int64_t *offset_device_address = GetDeviceAddress<int64_t>(inputs, 2);  // only get this if in dynamic mode
+G *offset_device_address = GetDeviceAddress<G>(inputs, 2);  // only get this if in dynamic mode
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-cudaMemcpyAsync(&offset_, offset_device_address, sizeof(int64_t),
-cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
+cudaMemcpyAsync(&offset_, offset_device_address, sizeof(G), cudaMemcpyDeviceToHost,
+reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync offset_ failed");
-CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(),
+CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaDeviceSyncFailed - EmbeddingLookup - in dynamic mode");
}
auto input_dim1 = input_shapes_[0];
-CalEmbeddingLookup(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1, offset_,
-reinterpret_cast<cudaStream_t>(stream_ptr));
+CalEmbeddingLookup(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
+static_cast<int64_t>(offset_), reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
@@ -81,7 +81,7 @@ class EmbeddingLookupKernelMod : public DeprecatedNativeGpuKernelMod {
<< input_shapes_.size();
}
if (!is_dynamic_shape_) {
-offset_ = GetAttr<int64_t>(kernel_node, "offset");
+offset_ = static_cast<G>(GetAttr<int64_t>(kernel_node, "offset"));
}
Reshape();
InitSizeLists();
@@ -107,7 +107,7 @@ class EmbeddingLookupKernelMod : public DeprecatedNativeGpuKernelMod {
size = GetSize(indices_shapes_);
input_size_list_.push_back(size);
if (is_dynamic_shape_) {
-input_size_list_.push_back(sizeof(int64_t));
+input_size_list_.push_back(sizeof(G));
}
size = GetSize(output_shapes_);
output_size_list_.push_back(size);
@@ -148,7 +148,7 @@ class EmbeddingLookupKernelMod : public DeprecatedNativeGpuKernelMod {
std::vector<int64_t> indices_shapes_;
std::vector<int64_t> output_shapes_;
size_t dims_[3] = {};
-int64_t offset_;
+G offset_;
bool is_dynamic_shape_;
bool is_null_input_;
};
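
In the dynamic-shape path above, the offset arrives as a device tensor of templated type G and is copied back to the host with cudaMemcpyAsync sized by sizeof(G), then waited on with cudaStreamSynchronize on the kernel's own stream rather than cudaDeviceSynchronize. A minimal standalone sketch of that pattern using only the CUDA runtime API (illustrative function name; not the MindSpore kernel itself):

#include <cuda_runtime.h>

// Read a scalar offset of templated type G from device memory on the
// caller's stream, as the Launch() above does. Illustrative sketch only.
template <typename G>
cudaError_t ReadOffsetFromDevice(const G *offset_device, G *offset_host, cudaStream_t stream) {
  cudaError_t err = cudaMemcpyAsync(offset_host, offset_device, sizeof(G),
                                    cudaMemcpyDeviceToHost, stream);
  if (err != cudaSuccess) return err;
  // Synchronize only this stream (cheaper than cudaDeviceSynchronize when
  // other streams are in flight); offset_host is valid afterwards.
  return cudaStreamSynchronize(stream);
}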


@@ -31,6 +31,10 @@ using mindspore::session::KernelGraph;
const ShapeVector kOneDimensionalShape = {1};
const ShapeVector kTwoDimensionalShape = {1, 1};
+const size_t kInputIndexZero = 0;
+const size_t kInputIndexOne = 1;
+const size_t kInputIndexTwo = 2;
// Maximum number of threads for concurrent accelerated cache processing.
constexpr size_t kMaxThreadNum = 16;
// Maximum number of feature ids processed per thread.
@@ -60,11 +64,14 @@ ParameterPtr NewParameter(const KernelGraphPtr &graph, TypePtr type, const Shape
return param;
}
-ValueNodePtr NewValueNode(int64_t value) {
-auto tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(0), kInt32);
+ValueNodePtr NewValueNode(int64_t value, const DeviceContext *device_context, size_t stream_id) {
+MS_EXCEPTION_IF_NULL(device_context);
+auto tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(value), kInt32);
auto value_node = NewValueNode(tensor);
value_node->set_abstract(tensor->ToAbstract());
+// Create kernel build info.
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
std::vector<std::string> formats = {kOpFormat_DEFAULT};
std::vector<TypeId> types = {kInt32->type_id()};
@@ -76,6 +83,28 @@ ValueNodePtr NewValueNode(int64_t value) {
value_node->set_kernel_info(kernel_info);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get());
+// Create device address.
+size_t output_idx = 0;
+size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
+TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
+std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
+MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
+auto value_addr = device_context->device_res_manager_->AllocateMemory(tensor_size);
+MS_EXCEPTION_IF_NULL(value_addr);
+auto address = device_context->device_res_manager_->CreateDeviceAddress(
+value_addr, tensor_size, output_format, output_type_id, trans::GetRuntimePaddingShape(value_node, output_idx));
+MS_EXCEPTION_IF_NULL(address);
+// Sync tensor value.
+MS_EXCEPTION_IF_CHECK_FAIL(address->AsyncHostToDevice({}, tensor_size, output_type_id, tensor->data_c(),
+device_context->device_res_manager_->GetStream(stream_id)),
+"Async memcpy host to device failed.");
+MS_EXCEPTION_IF_CHECK_FAIL(device_context->device_res_manager_->SyncStream(stream_id), "Synchronize stream failed.");
+address->set_from_persistent_mem(true);
+AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
return value_node;
}
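
The value node above is backed by real device memory: space is allocated from the device resource manager, the host tensor value is copied over asynchronously, the stream is synchronized, and the resulting address is registered as the node's output address so a kernel launch can later read it as an input. A hedged sketch of the same idea with bare CUDA runtime calls (the DeviceResManager and DeviceAddress plumbing is MindSpore-specific and omitted here; names are illustrative):

#include <cuda_runtime.h>
#include <cstdint>

// Materialize a host scalar as device-resident memory so a kernel can later
// consume it as an input address. Plain CUDA runtime calls for illustration.
int32_t *MakeDeviceResidentOffset(int32_t host_value, cudaStream_t stream) {
  int32_t *device_ptr = nullptr;
  if (cudaMalloc(&device_ptr, sizeof(int32_t)) != cudaSuccess) return nullptr;
  if (cudaMemcpyAsync(device_ptr, &host_value, sizeof(int32_t),
                      cudaMemcpyHostToDevice, stream) != cudaSuccess ||
      cudaStreamSynchronize(stream) != cudaSuccess) {
    cudaFree(device_ptr);
    return nullptr;
  }
  return device_ptr;  // Persistent: the caller keeps this for the node's lifetime.
}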
@@ -261,16 +290,16 @@ void EmbeddingCachePrefetchActor::BuildEmbeddingCacheLookupKernel() {
// 1. Create parameter nodes which are inputs of embedding cache look up kernel(operator name: 'Gather').
ParameterPtr input_param = NewParameter(graph, kFloat32, kTwoDimensionalShape);
ParameterPtr input_indices = NewParameter(graph, kInt32, kOneDimensionalShape);
-ValueNodePtr axis_value_node = NewValueNode(0);
+ValueNodePtr offset_value_node = NewValueNode(0, device_context_, stream_id_);
-// 2. Create a CNode for operator Gather.
-PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kGatherV2OpName);
+// 2. Create a CNode for operator EmbeddingLookup.
+PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
emb_lookup_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));
emb_lookup_primitive->set_attr(kAttrOutputIsDynamicShape, MakeValue(true));
emb_lookup_primitive->set_attr(kAttrStream, MakeValue(stream_id_));
std::vector<AnfNodePtr> emb_lookup_input_nodes{NewValueNode(emb_lookup_primitive), input_param, input_indices,
-axis_value_node};
+offset_value_node};
embedding_cache_lookup_node_ = graph->NewCNode(emb_lookup_input_nodes);
MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node_);
auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimensionalShape);
@@ -319,18 +348,23 @@ bool EmbeddingCachePrefetchActor::LookupDeviceCache(void *indices, void *embeddi
MS_ERROR_IF_NULL(embedding_cache_lookup_node_);
// 1. Update parameter nodes' shape.
-auto input_param_node = common::AnfAlgo::GetInputNode(embedding_cache_lookup_node_, 0);
+auto input_param_node = common::AnfAlgo::GetInputNode(embedding_cache_lookup_node_, kInputIndexZero);
MS_ERROR_IF_NULL(input_param_node);
const ShapeVector input_param_shape = {SizeToLong(cache_size), SizeToLong(embedding_size)};
auto input_param_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, input_param_shape);
input_param_node->set_abstract(input_param_abstract);
-auto input_indices_node = common::AnfAlgo::GetInputNode(embedding_cache_lookup_node_, 1);
+auto input_indices_node = common::AnfAlgo::GetInputNode(embedding_cache_lookup_node_, kInputIndexOne);
MS_ERROR_IF_NULL(input_indices_node);
const ShapeVector input_indices_shape = {SizeToLong(indices_num)};
auto input_indices_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, input_indices_shape);
input_indices_node->set_abstract(input_indices_abstract);
+auto input_offset_node = common::AnfAlgo::GetInputNode(embedding_cache_lookup_node_, kInputIndexTwo);
+MS_ERROR_IF_NULL(input_offset_node);
+auto offset_address = AnfAlgo::GetMutableOutputAddr(input_offset_node, 0);
+MS_ERROR_IF_NULL(offset_address);
// 2. Infer shape for embedding cache look up kernel(operator name: 'Gather') which is dynamic shape kernel.
if (!InferOpShape(embedding_cache_lookup_node_)) {
MS_LOG(ERROR) << "Infer operator shape failed, op name: " << embedding_cache_lookup_node_->fullname_with_scope();
@@ -340,7 +374,8 @@ bool EmbeddingCachePrefetchActor::LookupDeviceCache(void *indices, void *embeddi
// 3. Do embedding cache look up on device.
AddressPtrList kernel_inputs = {
std::make_shared<Address>(embedding_cache, cache_size * embedding_size * sizeof(float)),
-std::make_shared<Address>(indices, indices_num * sizeof(int))};
+std::make_shared<Address>(indices, indices_num * sizeof(int)),
+std::make_shared<Address>(offset_address->GetMutablePtr(), offset_address->GetSize())};
AddressPtrList kernel_outputs = {std::make_shared<Address>(outputs, indices_num * embedding_size * sizeof(float))};
MS_ERROR_IF_NULL(device_context_);
@@ -362,19 +397,19 @@ bool EmbeddingCachePrefetchActor::UpdateDeviceCache(void *indices, void *update_
MS_ERROR_IF_NULL(embedding_cache_update_node_);
// 1. Update parameter nodes' shape.
-auto input_param_node = common::AnfAlgo::GetInputNode(embedding_cache_update_node_, 0);
+auto input_param_node = common::AnfAlgo::GetInputNode(embedding_cache_update_node_, kInputIndexZero);
MS_ERROR_IF_NULL(input_param_node);
const ShapeVector input_param_shape = {SizeToLong(cache_size), SizeToLong(embedding_size)};
auto input_param_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, input_param_shape);
input_param_node->set_abstract(input_param_abstract);
-auto input_indices_node = common::AnfAlgo::GetInputNode(embedding_cache_update_node_, 1);
+auto input_indices_node = common::AnfAlgo::GetInputNode(embedding_cache_update_node_, kInputIndexOne);
MS_ERROR_IF_NULL(input_indices_node);
const ShapeVector input_indices_shape = {SizeToLong(indices_num)};
auto input_indices_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, input_indices_shape);
input_indices_node->set_abstract(input_indices_abstract);
-auto update_values_node = common::AnfAlgo::GetInputNode(embedding_cache_update_node_, 2);
+auto update_values_node = common::AnfAlgo::GetInputNode(embedding_cache_update_node_, kInputIndexTwo);
MS_ERROR_IF_NULL(update_values_node);
const ShapeVector update_values_shape = {SizeToLong(indices_num), SizeToLong(embedding_size)};
auto update_values_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, update_values_shape);