forked from mindspore-Ecosystem/mindspore
!45394 Add insert_default_value attribute for the MapParameter.get() api.
Merge pull request !45394 from Margaret_wangrui/insert_attr
This commit is contained in:
commit
fa6f403612
|
@ -10,6 +10,7 @@ mindspore/mindspore/lite/tools/converter/graphdef_transform.cc:mindspore::lite::
|
|||
mindspore/mindspore/lite/providers/nnie_proposal/src/proposal.cc:mindspore::proposal::Rpn
|
||||
mindspore/mindspore/core/abstract/ops/primitive_infer_map.cc:mindspore::abstract::GetPrimitiveToEvalImplMap
|
||||
mindspore/mindspore/core/abstract/ops/primitive_infer_map.cc:mindspore::abstract::GetHostDependsMap
|
||||
mindspore/mindspore/core/ir/tensor.cc:mindspore::tensor::MakeTensorData
|
||||
mindspore/mindspore/ccsrc/frontend/optimizer/irpass.cc:mindspore::opt::irpass::OptimizeIRPassLib::OptimizeIRPassLib
|
||||
mindspore/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc:mindspore::parallel::GatherV2PInfo::CheckStrategy
|
||||
mindspore/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_kernel_runtime.cc:mindspore::device::gpu::GPUKernelRuntime::LaunchKernelDynamic
|
||||
|
|
|
@ -675,6 +675,7 @@ constexpr auto kAttrUseEmbeddingStore = "UseEmbeddingStore";
|
|||
constexpr auto kAttrParameterKey = "ParameterKey";
|
||||
constexpr auto kAttrMsFunctionControl = "ms_function_control";
|
||||
constexpr auto kAttrFuncGraphCellId = "func_graph_cell_id";
|
||||
constexpr auto kAttrInsertDefaultValue = "insert_default_value";
|
||||
|
||||
// FuncGraph Flags
|
||||
constexpr auto kFlagsIsCutGraph = "is_cut_graph";
|
||||
|
|
|
@ -170,16 +170,17 @@ void GPUHashTable<Key, Value, Allocator>::FreeMemory(void *ptr) {
|
|||
}
|
||||
|
||||
template <typename Key, typename Value, typename Allocator>
|
||||
bool GPUHashTable<Key, Value, Allocator>::Find(const Key *keys, size_t key_num, Value *outputs, void *stream) {
|
||||
bool GPUHashTable<Key, Value, Allocator>::Find(const Key *keys, size_t key_num, bool insert_default_value,
|
||||
Value *outputs, void *stream) {
|
||||
if (!initializer_.empty()) {
|
||||
return Find(keys, key_num, initializer_, outputs, stream);
|
||||
return Find(keys, key_num, insert_default_value, initializer_, outputs, stream);
|
||||
}
|
||||
return Find(keys, key_num, default_value_, outputs, stream);
|
||||
return Find(keys, key_num, insert_default_value, default_value_, outputs, stream);
|
||||
}
|
||||
|
||||
template <typename Key, typename Value, typename Allocator>
|
||||
bool GPUHashTable<Key, Value, Allocator>::Find(const Key *keys, size_t key_num, const std::string &initializer,
|
||||
Value *outputs, void *stream) {
|
||||
bool GPUHashTable<Key, Value, Allocator>::Find(const Key *keys, size_t key_num, bool insert_default_value,
|
||||
const std::string &initializer, Value *outputs, void *stream) {
|
||||
MS_ERROR_IF_NULL(keys);
|
||||
MS_ERROR_IF_NULL(outputs);
|
||||
MS_ERROR_IF_NULL(stream);
|
||||
|
@ -192,9 +193,11 @@ bool GPUHashTable<Key, Value, Allocator>::Find(const Key *keys, size_t key_num,
|
|||
|
||||
// 1. Get all indices in blocks according to the keys.
|
||||
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
|
||||
RETURN_IF_FALSE_WITH_LOG(GetIndicesByKeys(keys, key_num, true, indices, cuda_stream), "Get indices by keys failed.");
|
||||
RETURN_IF_FALSE_WITH_LOG(GetIndicesByKeys(keys, key_num, insert_default_value, indices, cuda_stream),
|
||||
"Get indices by keys failed.");
|
||||
|
||||
RETURN_IF_FALSE_WITH_LOG(UpdateSize(key_num, indices, cuda_stream), "Update hash table size failed.");
|
||||
RETURN_IF_FALSE_WITH_LOG(UpdateSize(key_num, indices, cuda_stream, insert_default_value),
|
||||
"Update hash table size failed.");
|
||||
|
||||
// 2. Insert default value according to initializer, initializer can be 'normal', 'zeros' or 'ones'.
|
||||
RETURN_IF_FALSE_WITH_LOG(InsertDefaultValueByInitializer(key_num, initializer, indices, cuda_stream),
|
||||
|
@ -209,8 +212,8 @@ bool GPUHashTable<Key, Value, Allocator>::Find(const Key *keys, size_t key_num,
|
|||
}
|
||||
|
||||
template <typename Key, typename Value, typename Allocator>
|
||||
bool GPUHashTable<Key, Value, Allocator>::Find(const Key *keys, size_t key_num, const Value &default_value,
|
||||
Value *outputs, void *stream) {
|
||||
bool GPUHashTable<Key, Value, Allocator>::Find(const Key *keys, size_t key_num, bool insert_default_value,
|
||||
const Value &default_value, Value *outputs, void *stream) {
|
||||
MS_ERROR_IF_NULL(keys);
|
||||
MS_ERROR_IF_NULL(outputs);
|
||||
MS_ERROR_IF_NULL(stream);
|
||||
|
@ -223,9 +226,11 @@ bool GPUHashTable<Key, Value, Allocator>::Find(const Key *keys, size_t key_num,
|
|||
|
||||
// 1. Get all indices in blocks according to the keys.
|
||||
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
|
||||
RETURN_IF_FALSE_WITH_LOG(GetIndicesByKeys(keys, key_num, true, indices, cuda_stream), "Get indices by keys failed.");
|
||||
RETURN_IF_FALSE_WITH_LOG(GetIndicesByKeys(keys, key_num, insert_default_value, indices, cuda_stream),
|
||||
"Get indices by keys failed.");
|
||||
|
||||
RETURN_IF_FALSE_WITH_LOG(UpdateSize(key_num, indices, cuda_stream), "Update hash table size failed.");
|
||||
RETURN_IF_FALSE_WITH_LOG(UpdateSize(key_num, indices, cuda_stream, insert_default_value),
|
||||
"Update hash table size failed.");
|
||||
|
||||
// 2. Insert default value into map by specific value.
|
||||
InsertDefaultValue<<<GET_BLOCKS(key_num), GET_THREADS, 0, cuda_stream>>>(
|
||||
|
|
|
@ -67,7 +67,7 @@ class GPUHashTable : public HashTable<Key, Value> {
|
|||
// Find elements with specific keys, if a key does not exist, initialize the value for the key based on the
|
||||
// initialzer and insert the key-value pair into map. The initializer can be 'normal', 'zero' or 'one', and also
|
||||
// could be a specific 'Value' type scalar.
|
||||
bool Find(const Key *keys, size_t key_num, Value *outputs, void *stream) override;
|
||||
bool Find(const Key *keys, size_t key_num, bool insert_default_value, Value *outputs, void *stream) override;
|
||||
|
||||
// Insert elements with specific keys. If key exists, update the value of the key.
|
||||
bool Insert(const Key *keys, size_t key_num, const Value *value, void *stream) override;
|
||||
|
@ -102,11 +102,13 @@ class GPUHashTable : public HashTable<Key, Value> {
|
|||
private:
|
||||
// Find elements with specific keys, if the key does not exist, initialize the value for the key based on the
|
||||
// initialzer and insert the key-value pair into map.The initializer can be 'normal', 'zeros' or 'ones'.
|
||||
bool Find(const Key *keys, size_t key_num, const std::string &initializer, Value *outputs, void *stream);
|
||||
bool Find(const Key *keys, size_t key_num, bool insert_default_value, const std::string &initializer, Value *outputs,
|
||||
void *stream);
|
||||
|
||||
// Find elements with specific keys, if the key does not exist, initialize the value for the key by 'default_value'
|
||||
// and insert the key-value pair into map.
|
||||
bool Find(const Key *keys, size_t key_num, const Value &default_value, Value *outputs, void *stream);
|
||||
bool Find(const Key *keys, size_t key_num, bool insert_default_value, const Value &default_value, Value *outputs,
|
||||
void *stream);
|
||||
|
||||
// Get all indices in blocks according to the key.
|
||||
bool GetIndicesByKeys(const Key *key, size_t key_num, bool insert_miss_key, int32_t *indices, cudaStream_t stream);
|
||||
|
|
|
@ -29,12 +29,12 @@ std::vector<std::pair<KernelAttr, MapTensorGetDataGpuKernelMod::MapTensorGetData
|
|||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeMapTensorType)
|
||||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat),
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&MapTensorGetDataGpuKernelMod::LaunchKernel<int32_t, float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeMapTensorType)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat),
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&MapTensorGetDataGpuKernelMod::LaunchKernel<int64_t, float>}};
|
||||
|
||||
std::vector<KernelAttr> MapTensorGetDataGpuKernelMod::GetOpSupport() {
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <string>
|
||||
#include "mindspore/core/abstract/utils.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -46,8 +47,10 @@ std::vector<KernelAttr> MapTensorGetGpuKernelMod::GetOpSupport() {
|
|||
bool MapTensorGetGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
MS_EXCEPTION_IF_NULL(base_operator->GetPrim());
|
||||
kernel_name_ = base_operator->GetPrim()->name();
|
||||
auto prim = base_operator->GetPrim();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
kernel_name_ = prim->name();
|
||||
insert_default_value_ = GetValue<bool>(prim->GetAttr(kAttrInsertDefaultValue));
|
||||
// Check the inputs and outputs num.
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMapTensorGetInputNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMapTensorGetOutputNum, kernel_name_);
|
||||
|
@ -92,7 +95,7 @@ bool MapTensorGetGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &input
|
|||
auto hash_table_ptr = user_data->get<GPUHashTable<KeyType, ValueType>>(kUserDataData);
|
||||
MS_EXCEPTION_IF_NULL(hash_table_ptr);
|
||||
return hash_table_ptr->Find(static_cast<KeyType *>(inputs[kIndex1]->addr), inputs[kIndex1]->size / sizeof(KeyType),
|
||||
static_cast<ValueType *>(outputs[kIndex0]->addr), stream_ptr);
|
||||
insert_default_value_, static_cast<ValueType *>(outputs[kIndex0]->addr), stream_ptr);
|
||||
}
|
||||
|
||||
bool MapTensorGetGpuKernelMod::InitSize(const BaseOperatorPtr &, const std::vector<KernelTensorPtr> &inputs,
|
||||
|
|
|
@ -61,6 +61,7 @@ class MapTensorGetGpuKernelMod : public MapTensorGpuKernelMod {
|
|||
const std::vector<AddressPtr> &, void *)>;
|
||||
static std::vector<std::pair<KernelAttr, MapTensorGetLaunchFunc>> map_tensor_get_func_list_;
|
||||
MapTensorGetLaunchFunc kernel_launch_func_;
|
||||
bool insert_default_value_{true};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -88,9 +88,10 @@ std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const Map
|
|||
}
|
||||
|
||||
// Python wrapper for MapTensor::Get.
|
||||
static tensor::TensorPtr PyMapTensorGet(const MapTensorPtr &map_tensor, const tensor::TensorPtr &key_tensor) {
|
||||
static tensor::TensorPtr PyMapTensorGet(const MapTensorPtr &map_tensor, const tensor::TensorPtr &key_tensor,
|
||||
bool insert_default_value) {
|
||||
MS_EXCEPTION_IF_NULL(map_tensor);
|
||||
return map_tensor->Get(key_tensor);
|
||||
return map_tensor->Get(key_tensor, insert_default_value);
|
||||
}
|
||||
|
||||
static tensor::TensorPtr PyMapTensorGetKeys(const MapTensorPtr &map_tensor) {
|
||||
|
|
|
@ -38,7 +38,7 @@ class HashTable {
|
|||
// Find elements with specific keys, if a key does not exist, initialize the value for the key based on the
|
||||
// initialzer and insert the key-value pair into map. The initializer can be 'normal', 'zero' or 'one', and also
|
||||
// could be a specific Value type scalar.
|
||||
virtual bool Find(const Key *keys, size_t key_num, Value *outputs, void *stream) = 0;
|
||||
virtual bool Find(const Key *keys, size_t key_num, bool insert_default_value, Value *outputs, void *stream) = 0;
|
||||
|
||||
// Insert elements with specific keys. If key exists, update the value of the key.
|
||||
virtual bool Insert(const Key *keys, size_t key_num, const Value *value, void *stream) = 0;
|
||||
|
|
|
@ -61,7 +61,7 @@ std::string MapTensor::ToString() const {
|
|||
", default_value=" + (default_value_ == nullptr ? "<null>" : default_value_->ToString()) + ")";
|
||||
}
|
||||
|
||||
TensorPtr MapTensor::Get(const TensorPtr &key_tensor) {
|
||||
TensorPtr MapTensor::Get(const TensorPtr &key_tensor, bool insert_default_value) {
|
||||
MS_EXCEPTION_IF_NULL(key_tensor);
|
||||
// Check input.
|
||||
if (key_tensor->shape().size() != 1) {
|
||||
|
|
|
@ -145,8 +145,9 @@ class MS_CORE_API MapTensor final : public Tensor {
|
|||
/// \brief Get or create values.
|
||||
///
|
||||
/// \param[in] key_tensor [Tensor] The key tensor.
|
||||
/// \param[in] insert_default_value [bool] The flag of insert default_value.
|
||||
/// \return The value tensor according the key tensor, return default_value if key_tensor is not exist.
|
||||
TensorPtr Get(const TensorPtr &key_tensor);
|
||||
TensorPtr Get(const TensorPtr &key_tensor, bool insert_default_value);
|
||||
|
||||
/// \brief Put or insert key value pairs.
|
||||
///
|
||||
|
|
|
@ -590,6 +590,7 @@ TensorDataPtr MakeTensorData(TypeId data_type, Args &&... args) {
|
|||
case kObjectTypeString:
|
||||
return std::make_shared<ImplClass<uint8_t>>(std::forward<Args>(args)...);
|
||||
case kObjectTypeTensorType:
|
||||
case kObjectTypeMapTensorType:
|
||||
return std::make_shared<ImplClass<int>>(std::forward<Args>(args)...);
|
||||
default:
|
||||
break;
|
||||
|
|
|
@ -3542,11 +3542,11 @@ def cholesky_inverse(input_x, upper=False):
|
|||
return F.cholesky_inverse(input_x, upper=upper)
|
||||
|
||||
|
||||
def map_tensor_get(map_tensor, key_tensor):
|
||||
def map_tensor_get(map_tensor, key_tensor, insert_default_value=True):
|
||||
r"""
|
||||
Get or create value according the key tensor from a map tensor.
|
||||
"""
|
||||
return _map_tensor_ops.get(map_tensor, key_tensor)
|
||||
return _map_tensor_ops.MapTensorGet(insert_default_value)(map_tensor, key_tensor)
|
||||
|
||||
|
||||
def map_tensor_put(map_tensor, key_tensor, value_tensor):
|
||||
|
|
|
@ -4592,18 +4592,19 @@ class MapTensor(MapTensor_):
|
|||
"""
|
||||
return self.value_tensor.shape
|
||||
|
||||
def get(self, key_tensor):
|
||||
def get(self, key_tensor, insert_default_value=True):
|
||||
"""
|
||||
Get value tensor according the key tensor, fill and return the default value in map parameter if key is not
|
||||
existed.
|
||||
|
||||
Args:
|
||||
key_tensor (Tensor): The key tensor.
|
||||
insert_default_value (bool): The flag of insert default_value.
|
||||
|
||||
Returns:
|
||||
Tensor, the value tensor for the key tensor.
|
||||
"""
|
||||
result_tensor = self.get(key_tensor)
|
||||
result_tensor = self.get(key_tensor, insert_default_value)
|
||||
return Tensor(result_tensor, internal=True)
|
||||
|
||||
def get_keys(self):
|
||||
|
|
|
@ -94,7 +94,7 @@ class MapParameter(Parameter):
|
|||
self.permit_filter_value, self.evict_filter_value)
|
||||
|
||||
def __getitem__(self, key_tensor):
|
||||
return self.get(key_tensor)
|
||||
return self.get(key_tensor, True)
|
||||
|
||||
def __setitem__(self, key_tensor, value_tensor):
|
||||
return self.put(key_tensor, value_tensor)
|
||||
|
@ -135,18 +135,19 @@ class MapParameter(Parameter):
|
|||
x.evict_filter_value)
|
||||
return x
|
||||
|
||||
def get(self, key_tensor):
|
||||
def get(self, key_tensor, insert_default_value=True):
|
||||
"""
|
||||
Get value tensor according the key tensor, fill and return the default value in map parameter if key is not
|
||||
existed.
|
||||
|
||||
Args:
|
||||
key_tensor (Tensor): The key tensor.
|
||||
insert_default_value (bool): The flag of insert default_value.
|
||||
|
||||
Returns:
|
||||
Tensor, the value tensor for the key tensor.
|
||||
"""
|
||||
result_tensor = self._map_tensor.get(key_tensor)
|
||||
result_tensor = self._map_tensor.get(key_tensor, insert_default_value)
|
||||
return Tensor(result_tensor, internal=True)
|
||||
|
||||
def get_keys(self):
|
||||
|
|
|
@ -323,4 +323,4 @@ def _map_tensor_getitem(map_tensor, key_tensor):
|
|||
Outputs:
|
||||
Tensor, value tensor according the key tensor.
|
||||
"""
|
||||
return _map_tensor_ops.get(map_tensor, key_tensor)
|
||||
return _map_tensor_ops.MapTensorGet(True)(map_tensor, key_tensor)
|
||||
|
|
|
@ -28,10 +28,11 @@ class MapTensorGet(Primitive):
|
|||
sig.make_sig('key_tensor'))
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
def __init__(self, insert_default_value):
|
||||
"""Initialize MapTensorGet"""
|
||||
self.init_prim_io_names(inputs=['map_tensor', 'key_tensor'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.insert_default_value = insert_default_value
|
||||
|
||||
|
||||
class MapTensorPut(Primitive):
|
||||
|
@ -104,7 +105,6 @@ class MapTensorGetData(Primitive):
|
|||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
|
||||
get = MapTensorGet()
|
||||
put = MapTensorPut()
|
||||
erase = MapTensorErase()
|
||||
get_keys = MapTensorGetKeys()
|
||||
|
|
|
@ -12,12 +12,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from packaging import version
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.experimental import MapParameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.run_check._check_version import GPUEnvChecker
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -41,8 +43,51 @@ def test_simple_graph_compile_export():
|
|||
return self.p, self.m
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
net = MyNet()
|
||||
out = net()
|
||||
print("out:", out)
|
||||
data = net.m.export()
|
||||
print("data:", data)
|
||||
env_checker = GPUEnvChecker()
|
||||
v = env_checker.get_cudart_version()
|
||||
env_version = version.parse(v)
|
||||
if env_version.major >= 11:
|
||||
net = MyNet()
|
||||
out = net()
|
||||
print("out:", out)
|
||||
data = net.m.export()
|
||||
print("data:", data)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_maptensor_put_get_export():
|
||||
"""
|
||||
Feature: MapParameter
|
||||
Description: Test IR graph compiled with MapParameter, test put, get and export api.
|
||||
Expectation: IR graph with MapParameter created without exceptions.
|
||||
"""
|
||||
class MyNet(nn.Cell):
|
||||
def __init__(self):
|
||||
nn.Cell.__init__(self)
|
||||
self.p = Parameter(initializer('ones', (2, 3), ms.float32))
|
||||
self.m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
|
||||
self.keys = Tensor([1, 2], dtype=ms.int32)
|
||||
self.values = Tensor([[11, 11, 11], [22, 22, 22]], dtype=ms.float32)
|
||||
|
||||
def construct(self):
|
||||
self.m.put(self.keys, self.values)
|
||||
key1 = Tensor([3], dtype=ms.int32)
|
||||
value1 = self.m.get(key1, True)
|
||||
key2 = Tensor([4], dtype=ms.int32)
|
||||
value2 = self.m.get(key2, True)
|
||||
return value1, value2, self.m
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
env_checker = GPUEnvChecker()
|
||||
v = env_checker.get_cudart_version()
|
||||
env_version = version.parse(v)
|
||||
if env_version.major >= 11:
|
||||
net = MyNet()
|
||||
out1, out2, out3 = net()
|
||||
print("out1:", out1)
|
||||
print("out2:", out2)
|
||||
print("out3:", out3)
|
||||
data = net.m.export()
|
||||
print("data:", data)
|
||||
|
|
Loading…
Reference in New Issue