forked from OSSInnovation/mindspore
fix topk bug
This commit is contained in:
parent
9cbed69ee5
commit
1778ec0135
|
@ -35,7 +35,7 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
|
||||||
// 1 create tensor
|
// 1 create tensor
|
||||||
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
||||||
auto last_dim = shape[shape.size() - 1];
|
auto last_dim = shape[shape.size() - 1];
|
||||||
std::vector<int> indices_shape = {SizeToInt(last_dim)};
|
std::vector<int> indices_shape = {SizeToInt(last_dim * 2)};
|
||||||
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
|
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
|
||||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||||
tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type};
|
tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type};
|
||||||
|
@ -50,7 +50,11 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
|
||||||
for (size_t i = 0; i < last_dim; ++i) {
|
for (size_t i = 0; i < last_dim; ++i) {
|
||||||
half_data.emplace_back(Eigen::half(static_cast<float>(i)));
|
half_data.emplace_back(Eigen::half(static_cast<float>(i)));
|
||||||
}
|
}
|
||||||
auto elem_num = last_dim * kFloat16Len;
|
for (size_t i = 0; i < last_dim; ++i) {
|
||||||
|
auto gap = static_cast<int>(i) - static_cast<int>(Eigen::half(static_cast<float>(i)));
|
||||||
|
half_data.emplace_back(Eigen::half(static_cast<float>(gap)));
|
||||||
|
}
|
||||||
|
auto elem_num = last_dim * kFloat16Len * 2;
|
||||||
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(indices_tensor->data().nbytes()), half_data.data(), elem_num);
|
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(indices_tensor->data().nbytes()), half_data.data(), elem_num);
|
||||||
if (ret_code != 0) {
|
if (ret_code != 0) {
|
||||||
MS_LOG(ERROR) << "Failed to copy data into Tensor.";
|
MS_LOG(ERROR) << "Failed to copy data into Tensor.";
|
||||||
|
@ -108,6 +112,13 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
|
||||||
MS_LOG(INFO) << "The input k of topk has been converted to attr";
|
MS_LOG(INFO) << "The input k of topk has been converted to attr";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
||||||
|
auto last_dim = shape[shape.size() - 1];
|
||||||
|
const size_t kMaxFloat16 = 65500;
|
||||||
|
if (last_dim > kMaxFloat16) {
|
||||||
|
MS_LOG(INFO) << "The last dim is more than 65500, switch to aicpu ops.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
// Copy a new node to check supported.
|
// Copy a new node to check supported.
|
||||||
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kTopKOpName))};
|
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kTopKOpName))};
|
||||||
new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||||
|
|
|
@ -59,6 +59,7 @@ do
|
||||||
mkdir ./train_parallel$i
|
mkdir ./train_parallel$i
|
||||||
cp ../*.py ./train_parallel$i
|
cp ../*.py ./train_parallel$i
|
||||||
cp *.sh ./train_parallel$i
|
cp *.sh ./train_parallel$i
|
||||||
|
cp -r ../src ./train_parallel$i
|
||||||
cd ./train_parallel$i || exit
|
cd ./train_parallel$i || exit
|
||||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||||
env > env.log
|
env > env.log
|
||||||
|
|
|
@ -57,6 +57,7 @@ fi
|
||||||
mkdir ./eval
|
mkdir ./eval
|
||||||
cp ../*.py ./eval
|
cp ../*.py ./eval
|
||||||
cp *.sh ./eval
|
cp *.sh ./eval
|
||||||
|
cp -r ../src ./eval
|
||||||
cd ./eval || exit
|
cd ./eval || exit
|
||||||
env > env.log
|
env > env.log
|
||||||
echo "start eval for device $DEVICE_ID"
|
echo "start eval for device $DEVICE_ID"
|
||||||
|
|
|
@ -49,6 +49,7 @@ fi
|
||||||
mkdir ./train
|
mkdir ./train
|
||||||
cp ../*.py ./train
|
cp ../*.py ./train
|
||||||
cp *.sh ./train
|
cp *.sh ./train
|
||||||
|
cp -r ../src ./train
|
||||||
cd ./train || exit
|
cd ./train || exit
|
||||||
echo "start training for device $DEVICE_ID"
|
echo "start training for device $DEVICE_ID"
|
||||||
env > env.log
|
env > env.log
|
||||||
|
|
|
@ -134,7 +134,7 @@ config = ed({
|
||||||
"keep_checkpoint_max": 10,
|
"keep_checkpoint_max": 10,
|
||||||
"save_checkpoint_path": "./checkpoint",
|
"save_checkpoint_path": "./checkpoint",
|
||||||
|
|
||||||
"mindrecord_dir": "../MindRecoid_COCO_TRAIN",
|
"mindrecord_dir": "../MindRecord_COCO_TRAIN",
|
||||||
"coco_root": "./cocodataset/",
|
"coco_root": "./cocodataset/",
|
||||||
"train_data_type": "train2017",
|
"train_data_type": "train2017",
|
||||||
"val_data_type": "val2017",
|
"val_data_type": "val2017",
|
||||||
|
|
|
@ -24,7 +24,7 @@ import mmcv
|
||||||
import mindspore.dataset as de
|
import mindspore.dataset as de
|
||||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||||
from mindspore.mindrecord import FileWriter
|
from mindspore.mindrecord import FileWriter
|
||||||
from config import config
|
from src.config import config
|
||||||
|
|
||||||
|
|
||||||
def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
|
def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
|
||||||
|
|
|
@ -90,7 +90,7 @@ TEST_F(TestHWTopKSplit, test_topk_split) {
|
||||||
EXPECT_TRUE(value_node->value()->isa<tensor::Tensor>());
|
EXPECT_TRUE(value_node->value()->isa<tensor::Tensor>());
|
||||||
auto tensor = value_node->value()->cast<tensor::TensorPtr>();
|
auto tensor = value_node->value()->cast<tensor::TensorPtr>();
|
||||||
EXPECT_EQ(tensor->shape().size(), 1);
|
EXPECT_EQ(tensor->shape().size(), 1);
|
||||||
EXPECT_EQ(tensor->shape()[0], 4);
|
EXPECT_EQ(tensor->shape()[0], 8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestHWTopKSplit, test_topk_no_split) {
|
TEST_F(TestHWTopKSplit, test_topk_no_split) {
|
||||||
|
|
Loading…
Reference in New Issue