!37801 [bugfix] get output addr failed of gather op
Merge pull request !37801 from zyli2020/embedding_cache_unify_runtime
This commit is contained in:
commit
76f2309534
|
@ -18,127 +18,134 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
EmbeddingLookupKernelMod, double, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, double, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
|
||||
EmbeddingLookupKernelMod, double, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, double, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
EmbeddingLookupKernelMod, float, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, float, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
EmbeddingLookupKernelMod, float, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, float, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
EmbeddingLookupKernelMod, half, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, half, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
EmbeddingLookupKernelMod, half, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, half, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
EmbeddingLookupKernelMod, int, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, int, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
EmbeddingLookupKernelMod, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, int, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
|
||||
EmbeddingLookupKernelMod, int16_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, int16_t, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
|
||||
EmbeddingLookupKernelMod, int16_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, int16_t, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
|
||||
EmbeddingLookupKernelMod, int8_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, int8_t, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
|
||||
EmbeddingLookupKernelMod, int8_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, int8_t, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
|
||||
EmbeddingLookupKernelMod, uint8_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, uint8_t, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
|
||||
EmbeddingLookupKernelMod, uint8_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, uint8_t, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
EmbeddingLookupKernelMod, bool, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
EmbeddingLookupKernelMod, bool, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(
|
||||
EmbeddingLookup,
|
||||
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
EmbeddingLookupKernelMod, bool, int64_t)
|
||||
EmbeddingLookupKernelMod, bool, int64_t, int64_t)
|
||||
// dynamic shape
|
||||
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
EmbeddingLookupKernelMod, double, int)
|
||||
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
EmbeddingLookupKernelMod, double, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
EmbeddingLookupKernelMod, float, int)
|
||||
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
EmbeddingLookupKernelMod, float, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
EmbeddingLookupKernelMod, half, int)
|
||||
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
EmbeddingLookupKernelMod, half, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
EmbeddingLookupKernelMod, bool, int)
|
||||
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
EmbeddingLookupKernelMod, bool, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
EmbeddingLookupKernelMod, double, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
EmbeddingLookupKernelMod, double, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
EmbeddingLookupKernelMod, float, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
EmbeddingLookupKernelMod, float, int, int)
|
||||
MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
EmbeddingLookupKernelMod, float, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
EmbeddingLookupKernelMod, half, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
EmbeddingLookupKernelMod, half, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
EmbeddingLookupKernelMod, bool, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_THREE(EmbeddingLookup,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
EmbeddingLookupKernelMod, bool, int64_t, int64_t)
|
||||
// dynamic shape ends
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
template <typename T, typename S, typename G>
|
||||
class EmbeddingLookupKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
public:
|
||||
EmbeddingLookupKernelMod() { ResetResource(); }
|
||||
|
@ -41,17 +41,17 @@ class EmbeddingLookupKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
if (is_dynamic_shape_) {
|
||||
int64_t *offset_device_address = GetDeviceAddress<int64_t>(inputs, 2); // only get this if in dynamic mode
|
||||
G *offset_device_address = GetDeviceAddress<G>(inputs, 2); // only get this if in dynamic mode
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(&offset_, offset_device_address, sizeof(int64_t),
|
||||
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
cudaMemcpyAsync(&offset_, offset_device_address, sizeof(G), cudaMemcpyDeviceToHost,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync offset_ failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(),
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaDeviceSyncFailed - EmbeddingLookup - in dynamic mode");
|
||||
}
|
||||
auto input_dim1 = input_shapes_[0];
|
||||
CalEmbeddingLookup(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1, offset_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CalEmbeddingLookup(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
|
||||
static_cast<int64_t>(offset_), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
|
@ -81,7 +81,7 @@ class EmbeddingLookupKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
<< input_shapes_.size();
|
||||
}
|
||||
if (!is_dynamic_shape_) {
|
||||
offset_ = GetAttr<int64_t>(kernel_node, "offset");
|
||||
offset_ = static_cast<G>(GetAttr<int64_t>(kernel_node, "offset"));
|
||||
}
|
||||
Reshape();
|
||||
InitSizeLists();
|
||||
|
@ -107,7 +107,7 @@ class EmbeddingLookupKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
size = GetSize(indices_shapes_);
|
||||
input_size_list_.push_back(size);
|
||||
if (is_dynamic_shape_) {
|
||||
input_size_list_.push_back(sizeof(int64_t));
|
||||
input_size_list_.push_back(sizeof(G));
|
||||
}
|
||||
size = GetSize(output_shapes_);
|
||||
output_size_list_.push_back(size);
|
||||
|
@ -148,7 +148,7 @@ class EmbeddingLookupKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
std::vector<int64_t> indices_shapes_;
|
||||
std::vector<int64_t> output_shapes_;
|
||||
size_t dims_[3] = {};
|
||||
int64_t offset_;
|
||||
G offset_;
|
||||
bool is_dynamic_shape_;
|
||||
bool is_null_input_;
|
||||
};
|
||||
|
|
|
@ -31,6 +31,10 @@ using mindspore::session::KernelGraph;
|
|||
const ShapeVector kOneDimensionalShape = {1};
|
||||
const ShapeVector kTwoDimensionalShape = {1, 1};
|
||||
|
||||
const size_t kInputIndexZero = 0;
|
||||
const size_t kInputIndexOne = 1;
|
||||
const size_t kInputIndexTwo = 2;
|
||||
|
||||
// Maximum number of threads for concurrent accelerated cache processing.
|
||||
constexpr size_t kMaxThreadNum = 16;
|
||||
// Maximum number of feature ids processed per thread.
|
||||
|
@ -60,11 +64,14 @@ ParameterPtr NewParameter(const KernelGraphPtr &graph, TypePtr type, const Shape
|
|||
return param;
|
||||
}
|
||||
|
||||
ValueNodePtr NewValueNode(int64_t value) {
|
||||
auto tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(0), kInt32);
|
||||
ValueNodePtr NewValueNode(int64_t value, const DeviceContext *device_context, size_t stream_id) {
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
|
||||
auto tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(value), kInt32);
|
||||
auto value_node = NewValueNode(tensor);
|
||||
value_node->set_abstract(tensor->ToAbstract());
|
||||
|
||||
// Create kernel build info.
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
std::vector<std::string> formats = {kOpFormat_DEFAULT};
|
||||
std::vector<TypeId> types = {kInt32->type_id()};
|
||||
|
@ -76,6 +83,28 @@ ValueNodePtr NewValueNode(int64_t value) {
|
|||
value_node->set_kernel_info(kernel_info);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get());
|
||||
|
||||
// Create device address.
|
||||
size_t output_idx = 0;
|
||||
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
|
||||
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
|
||||
std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
|
||||
auto value_addr = device_context->device_res_manager_->AllocateMemory(tensor_size);
|
||||
MS_EXCEPTION_IF_NULL(value_addr);
|
||||
auto address = device_context->device_res_manager_->CreateDeviceAddress(
|
||||
value_addr, tensor_size, output_format, output_type_id, trans::GetRuntimePaddingShape(value_node, output_idx));
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
|
||||
// Sync tensor value.
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(address->AsyncHostToDevice({}, tensor_size, output_type_id, tensor->data_c(),
|
||||
device_context->device_res_manager_->GetStream(stream_id)),
|
||||
"Async memcpy host to device failed.");
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(device_context->device_res_manager_->SyncStream(stream_id), "Synchronize stream failed.");
|
||||
|
||||
address->set_from_persistent_mem(true);
|
||||
AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
|
||||
|
||||
return value_node;
|
||||
}
|
||||
|
||||
|
@ -261,16 +290,16 @@ void EmbeddingCachePrefetchActor::BuildEmbeddingCacheLookupKernel() {
|
|||
// 1. Create the input nodes of the embedding cache lookup kernel (operator name: 'EmbeddingLookup').
|
||||
ParameterPtr input_param = NewParameter(graph, kFloat32, kTwoDimensionalShape);
|
||||
ParameterPtr input_indices = NewParameter(graph, kInt32, kOneDimensionalShape);
|
||||
ValueNodePtr axis_value_node = NewValueNode(0);
|
||||
ValueNodePtr offset_value_node = NewValueNode(0, device_context_, stream_id_);
|
||||
|
||||
// 2. Create a CNode for operator Gather.
|
||||
PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kGatherV2OpName);
|
||||
// 2. Create a CNode for operator EmbeddingLookup.
|
||||
PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
|
||||
emb_lookup_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));
|
||||
emb_lookup_primitive->set_attr(kAttrOutputIsDynamicShape, MakeValue(true));
|
||||
emb_lookup_primitive->set_attr(kAttrStream, MakeValue(stream_id_));
|
||||
|
||||
std::vector<AnfNodePtr> emb_lookup_input_nodes{NewValueNode(emb_lookup_primitive), input_param, input_indices,
|
||||
axis_value_node};
|
||||
offset_value_node};
|
||||
embedding_cache_lookup_node_ = graph->NewCNode(emb_lookup_input_nodes);
|
||||
MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node_);
|
||||
auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimensionalShape);
|
||||
|
@ -319,18 +348,23 @@ bool EmbeddingCachePrefetchActor::LookupDeviceCache(void *indices, void *embeddi
|
|||
MS_ERROR_IF_NULL(embedding_cache_lookup_node_);
|
||||
|
||||
// 1. Update parameter nodes' shape.
|
||||
auto input_param_node = common::AnfAlgo::GetInputNode(embedding_cache_lookup_node_, 0);
|
||||
auto input_param_node = common::AnfAlgo::GetInputNode(embedding_cache_lookup_node_, kInputIndexZero);
|
||||
MS_ERROR_IF_NULL(input_param_node);
|
||||
const ShapeVector input_param_shape = {SizeToLong(cache_size), SizeToLong(embedding_size)};
|
||||
auto input_param_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, input_param_shape);
|
||||
input_param_node->set_abstract(input_param_abstract);
|
||||
|
||||
auto input_indices_node = common::AnfAlgo::GetInputNode(embedding_cache_lookup_node_, 1);
|
||||
auto input_indices_node = common::AnfAlgo::GetInputNode(embedding_cache_lookup_node_, kInputIndexOne);
|
||||
MS_ERROR_IF_NULL(input_indices_node);
|
||||
const ShapeVector input_indices_shape = {SizeToLong(indices_num)};
|
||||
auto input_indices_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, input_indices_shape);
|
||||
input_indices_node->set_abstract(input_indices_abstract);
|
||||
|
||||
auto input_offset_node = common::AnfAlgo::GetInputNode(embedding_cache_lookup_node_, kInputIndexTwo);
|
||||
MS_ERROR_IF_NULL(input_offset_node);
|
||||
auto offset_address = AnfAlgo::GetMutableOutputAddr(input_offset_node, 0);
|
||||
MS_ERROR_IF_NULL(offset_address);
|
||||
|
||||
// 2. Infer shape for the embedding cache lookup kernel (operator name: 'EmbeddingLookup'), which is a dynamic shape kernel.
|
||||
if (!InferOpShape(embedding_cache_lookup_node_)) {
|
||||
MS_LOG(ERROR) << "Infer operator shape failed, op name: " << embedding_cache_lookup_node_->fullname_with_scope();
|
||||
|
@ -340,7 +374,8 @@ bool EmbeddingCachePrefetchActor::LookupDeviceCache(void *indices, void *embeddi
|
|||
// 3. Do embedding cache look up on device.
|
||||
AddressPtrList kernel_inputs = {
|
||||
std::make_shared<Address>(embedding_cache, cache_size * embedding_size * sizeof(float)),
|
||||
std::make_shared<Address>(indices, indices_num * sizeof(int))};
|
||||
std::make_shared<Address>(indices, indices_num * sizeof(int)),
|
||||
std::make_shared<Address>(offset_address->GetMutablePtr(), offset_address->GetSize())};
|
||||
AddressPtrList kernel_outputs = {std::make_shared<Address>(outputs, indices_num * embedding_size * sizeof(float))};
|
||||
|
||||
MS_ERROR_IF_NULL(device_context_);
|
||||
|
@ -362,19 +397,19 @@ bool EmbeddingCachePrefetchActor::UpdateDeviceCache(void *indices, void *update_
|
|||
MS_ERROR_IF_NULL(embedding_cache_update_node_);
|
||||
|
||||
// 1. Update parameter nodes' shape.
|
||||
auto input_param_node = common::AnfAlgo::GetInputNode(embedding_cache_update_node_, 0);
|
||||
auto input_param_node = common::AnfAlgo::GetInputNode(embedding_cache_update_node_, kInputIndexZero);
|
||||
MS_ERROR_IF_NULL(input_param_node);
|
||||
const ShapeVector input_param_shape = {SizeToLong(cache_size), SizeToLong(embedding_size)};
|
||||
auto input_param_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, input_param_shape);
|
||||
input_param_node->set_abstract(input_param_abstract);
|
||||
|
||||
auto input_indices_node = common::AnfAlgo::GetInputNode(embedding_cache_update_node_, 1);
|
||||
auto input_indices_node = common::AnfAlgo::GetInputNode(embedding_cache_update_node_, kInputIndexOne);
|
||||
MS_ERROR_IF_NULL(input_indices_node);
|
||||
const ShapeVector input_indices_shape = {SizeToLong(indices_num)};
|
||||
auto input_indices_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, input_indices_shape);
|
||||
input_indices_node->set_abstract(input_indices_abstract);
|
||||
|
||||
auto update_values_node = common::AnfAlgo::GetInputNode(embedding_cache_update_node_, 2);
|
||||
auto update_values_node = common::AnfAlgo::GetInputNode(embedding_cache_update_node_, kInputIndexTwo);
|
||||
MS_ERROR_IF_NULL(update_values_node);
|
||||
const ShapeVector update_values_shape = {SizeToLong(indices_num), SizeToLong(embedding_size)};
|
||||
auto update_values_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, update_values_shape);
|
||||
|
|
Loading…
Reference in New Issue