add: metaretriever

parent 7c59c46682
commit c2f193af79

@ -0,0 +1,18 @@
# UIE
/data
/pretrain_data
/hf_models
/pd_models
/runs
/models
*.lock

# mac
.DS_Store

# env
/.vscode
/.idea
**/__pycache__
*.pyc
.pytest_cache
@ -0,0 +1,90 @@
# Universal Information Extraction with Meta-Pretrained Self-Retrieval

This code is for the ACL 2023 Findings paper "Universal Information Extraction with Meta-Pretrained Self-Retrieval".

## Overview

![](img/MetaRetriever.png)

Universal Information Extraction (Universal IE) aims to solve different extraction tasks in a uniform text-to-structure generation manner. Such a generation procedure tends to struggle when complex information structures have to be extracted. Retrieving knowledge from external knowledge bases may help models overcome this problem, but it is impossible to construct a knowledge base suitable for every IE task. Inspired by the fact that a large amount of knowledge is stored in pretrained language models (PLMs) and can be retrieved explicitly, in this paper we propose MetaRetriever, which retrieves task-specific knowledge from PLMs to enhance universal IE. Since different IE tasks need different knowledge, we further propose a Meta-Pretraining Algorithm that allows MetaRetriever to quickly reach its maximum task-specific retrieval performance when fine-tuning on downstream IE tasks. Experimental results show that MetaRetriever achieves a new state-of-the-art on 4 IE tasks and 12 datasets under fully-supervised, low-resource, and few-shot scenarios.
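The Meta-Pretraining Algorithm follows the usual meta-learning recipe of an inner, task-specific adaptation loop wrapped in an outer update of the meta-parameters. As a rough orientation only, the snippet below sketches a first-order (Reptile-style) version of this idea. It assumes a Hugging Face style seq2seq model whose forward pass returns an object with a `.loss`, and it is not the repository's actual `meta_pretrain_v2` trainer.

``` python
# Illustrative first-order meta-pretraining sketch (the real trainer in this
# repo is `meta_pretrain_v2` and may differ in detail).
import copy
import torch
import torch.nn as nn

def meta_pretrain_epoch(model: nn.Module, task_loaders, meta_lr=0.1,
                        inner_lr=1e-4, inner_steps=1):
    for task_loader in task_loaders:          # one pretraining "task" per loader
        # Inner loop: adapt a task-specific copy of the model.
        fast_model = copy.deepcopy(model)
        inner_opt = torch.optim.SGD(fast_model.parameters(), lr=inner_lr)
        for _, batch in zip(range(inner_steps), task_loader):
            loss = fast_model(**batch).loss   # seq2seq loss on (text, record) pairs
            inner_opt.zero_grad()
            loss.backward()
            inner_opt.step()
        # Outer loop: move the meta-parameters toward the adapted parameters so
        # that downstream fine-tuning starts close to each task's optimum.
        with torch.no_grad():
            for p, fast_p in zip(model.parameters(), fast_model.parameters()):
                p.add_(meta_lr * (fast_p - p))
    return model
```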
## Requirements

General

- Python (verified on 3.8)
- CUDA (verified on 10.2)

Python Packages

``` bash
conda create -n metaretriever python=3.8
conda install -y pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch
pip install -r requirements.txt
```

**NOTE**: Different versions of packages (such as `pytorch`, `transformers`, etc.) may lead to results that differ from the paper. However, the overall trend should hold regardless of the package versions you use.

## Usage

### Data Preprocessing

``` bash
cd ./dataset_processing/ours
bash download_and_preprocess_data_clean.sh > clean_log.txt
```

### Model Preparation

Please refer to [UIE](https://github.com/universal-ie/UIE) to download the UIE model checkpoint and put it under the `models` dir.

### Meta-Pretraining

``` bash
bash run_seq2seq_pretrain.bash -v -d 0,1,2,3,4,5,6,7 -b 64 -k 1 --lr 1e-4 --warmup_ratio 0.06 -i relation/ours_clean --spot_noise 0.0 --asoc_noise 0.0 -f spotasoc --map_config config/offset_map/closest_offset_en.yaml -m ./models/uie-base-en --random_prompt --epoch 4 --trainer_type meta_pretrain_v2 --use_prompt_tuning_model False --output_dir output/meta-pretrained-model
```

### Meta-Finetuning

1. Full Supervision Scenario

``` bash
. config/exp_conf/large_model_conf.ini && trainer_type=meta_finetune_v2 model_name=meta-pretrained-model dataset_name=relation/conll04 selected_gpus=0,1,2,3,4,5,6,7 BATCH_SIZE=4 use_prompt_tuning_model=False run_time=1 bash scripts_exp/run_exp.bash
```

2. Few-Shot Scenario

``` bash
. config/exp_conf/base_model_conf_sa_shot.ini && trainer_type=meta_finetune_v2 model_name=meta-pretrained-model dataset_name=relation/conll04 selected_gpus=0,1,2,3,4,5,6,7 BATCH_SIZE=16 use_prompt_tuning_model=False bash scripts_exp/run_exp_shot.bash
```

3. Low-Resource Scenario

``` bash
. config/exp_conf/base_model_conf_sa_ratio.ini && trainer_type=meta_finetune_v2 model_name=meta-pretrained-model dataset_name=relation/conll04 selected_gpus=0,1,2,3,4,5,6,7 BATCH_SIZE=16 use_prompt_tuning_model=False bash scripts_exp/run_exp_ratio.bash
```

## Citation

If this repository helps you, please cite this paper:

```
@inproceedings{cong-etal-2023-universal,
    title = "Universal Information Extraction with Meta-Pretrained Self-Retrieval",
    author = "Cong, Xin and
      Yu, Bowen and
      Fang, Mengcheng and
      Liu, Tingwen and
      Yu, Haiyang and
      Hu, Zhongkai and
      Huang, Fei and
      Li, Yongbin and
      Wang, Bin",
    editor = "Rogers, Anna and
      Boyd-Graber, Jordan and
      Okazaki, Naoaki",
    booktitle = "Findings of the Association for Computational Linguistics: ACL 2023",
    month = jul,
    year = "2023",
    address = "Toronto, Canada",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2023.findings-acl.251",
    doi = "10.18653/v1/2023.findings-acl.251",
}
```
@ -0,0 +1,16 @@
|
|||
|
||||
export k8s_gpu_cards=1
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=5
|
||||
export max_source_length=384
|
||||
|
||||
export BATCH_SIZE="16"
|
||||
export LR_RATE="1e-4 3e-4 5e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.1"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
|
@ -0,0 +1,21 @@
|
|||
export job_name=FT_Multi
|
||||
|
||||
export k8s_gpu_cards=4
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=384
|
||||
|
||||
export BATCH_SIZE="16"
|
||||
export LR_RATE="3e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.1"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
|
||||
export job_tags=""
|
||||
export job_remark="3e-4,0.1"
|
||||
export eval_match_mode="set"
|
|
@ -0,0 +1,23 @@
|
|||
|
||||
export k8s_gpu_cards=1
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=5
|
||||
export max_source_length=384
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="16"
|
||||
export LR_RATE="5e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.1"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="5e-4,0.1"
|
||||
export start_eval_step=3000
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
export k8s_gpu_cards=4
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=384
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="8"
|
||||
export LR_RATE="1e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.1 0.2"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="1e-4,0.1,0.2"
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
export k8s_gpu_cards=4
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=384
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="8"
|
||||
export LR_RATE="1e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.1 0.2"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="1e-4,0.1,0.2"
|
|
@ -0,0 +1,22 @@
|
|||
export job_name=FT_spotasocname
|
||||
|
||||
export k8s_gpu_cards=4
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=2000
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=256
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="8"
|
||||
export LR_RATE="1e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.1"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
export start_eval_step=15000
|
||||
export job_remark="1e-4,0.1"
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
export k8s_gpu_cards=4
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=384
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="8"
|
||||
export LR_RATE="1e-4 3e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.2"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="1e-4,3e-4,0.2"
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
export k8s_gpu_cards=4
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=256
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="8"
|
||||
export LR_RATE="3e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.2"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="3e-4,0.2"
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
export k8s_gpu_cards=4
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=256
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="8"
|
||||
export LR_RATE="1e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.1"
|
||||
export map_config='config/offset_map/first_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="1e-4,0.1"
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
export k8s_gpu_cards=1
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=256
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="8"
|
||||
export LR_RATE="5e-5"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.1"
|
||||
export map_config='config/offset_map/first_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="5e-5,0.1"
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
export k8s_gpu_cards=4
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=384
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="8"
|
||||
export LR_RATE="3e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.2 0.1"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="3e-4,0.1,0.2"
|
|
@ -0,0 +1,16 @@
|
|||
|
||||
export k8s_gpu_cards=1
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=384
|
||||
|
||||
export BATCH_SIZE="8"
|
||||
export LR_RATE="1e-4 3e-5 5e-5"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.1"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
|
@ -0,0 +1,23 @@
|
|||
|
||||
export k8s_gpu_cards=4
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=384
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="8"
|
||||
export LR_RATE="5e-5"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.2"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="5e-5,0.2"
|
||||
export eval_match_mode="set"
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
export k8s_gpu_cards=4
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=384
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="8"
|
||||
export LR_RATE="3e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.2 0.1 0"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="3e-4,0,0.1,0.2"
|
|
@ -0,0 +1,22 @@
|
|||
description: uie
|
||||
|
||||
environment:
|
||||
image: docker.cipsup.cn/uie/uie:transformers4.6.2
|
||||
environment_variables:
|
||||
- DET_TASK_OWNER=luyaojie
|
||||
|
||||
resources:
|
||||
slots: 4
|
||||
|
||||
bind_mounts:
|
||||
# Data Folder Bind Mount
|
||||
- host_path: /shared_home/luyaojie/uie/data
|
||||
container_path: /run/determined/workdir/data
|
||||
|
||||
# Pre-trained Model Folder Bind Mount
|
||||
- host_path: /shared_home/luyaojie/uie/model
|
||||
container_path: /run/determined/workdir/hf_models
|
||||
|
||||
# Output Folder Bind Mount
|
||||
- host_path: /shared_home/luyaojie/uie/output
|
||||
container_path: /run/determined/workdir/output
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
export k8s_gpu_cards=4
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=384
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="16"
|
||||
export LR_RATE="1e-4 3e-4 5e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0 0.1 0.2"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="1e-4,3e-4,5e-5,0,0.1,0.2"
|
|
@ -0,0 +1,21 @@
|
|||
|
||||
export gpu_node=1
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=384
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="16"
|
||||
export LR_RATE="1e-4 3e-4 5e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.1 0 0.2"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="1e-4,3e-4,5e-5,0,0.1,0.2"
|
|
@ -0,0 +1,15 @@
|
|||
export gpu_node=1
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=200
|
||||
export run_time=10
|
||||
export max_source_length=256
|
||||
export decoding_format="spotasoc"
|
||||
|
||||
export BATCH_SIZE="16"
|
||||
export LR_RATE="1e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.1"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
|
@ -0,0 +1,15 @@
|
|||
export gpu_node=1
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=200
|
||||
export run_time=10
|
||||
export max_source_length=256
|
||||
export decoding_format="spotasoc"
|
||||
|
||||
export BATCH_SIZE="16"
|
||||
export LR_RATE="1e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.1"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
export k8s_gpu_cards=4
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=384
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="8"
|
||||
export LR_RATE="5e-5 1e-4 3e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.2 0.1 0"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="5e-5,1e-4,3e-4,0,0.1,0.2"
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
export k8s_gpu_cards=4
|
||||
export gpu_node=${k8s_gpu_cards}
|
||||
|
||||
export eval_steps=0
|
||||
export epoch=50
|
||||
export run_time=3
|
||||
export max_source_length=384
|
||||
|
||||
export job_tags="${dataset_name},${model_name}_rp"
|
||||
export job_remark="d${dataset_name},m${model_name}"
|
||||
|
||||
export BATCH_SIZE="8"
|
||||
export LR_RATE="5e-5 1e-4 3e-4"
|
||||
export WARMUP_PROP="0.06"
|
||||
export LABEL_SMOOTHING="0"
|
||||
export NEGATIVE="-1"
|
||||
export NOISE="0.2 0.1 0"
|
||||
export map_config='config/offset_map/closest_offset_en.yaml'
|
||||
|
||||
export job_tags="${dataset_name},${model_name}"
|
||||
export job_remark="5e-5,1e-4,3e-4,0,0.1,0.2"
|
|
@ -0,0 +1,3 @@
|
|||
map_strategy: "closest"
|
||||
de_duplicate: True
|
||||
span_to_token: "space"
|
|
@ -0,0 +1,3 @@
|
|||
map_strategy: "closest"
|
||||
de_duplicate: True
|
||||
span_to_token: "list"
|
|
@ -0,0 +1,3 @@
|
|||
map_strategy: "first"
|
||||
de_duplicate: True
|
||||
span_to_token: "space"
|
|
@ -0,0 +1,3 @@
|
|||
map_strategy: "first"
|
||||
de_duplicate: True
|
||||
span_to_token: "list"
|
|
@ -0,0 +1,3 @@
|
|||
map_strategy: "longer_first"
|
||||
de_duplicate: True
|
||||
span_to_token: "list"
|
|
@ -0,0 +1,25 @@
|
|||
/data
|
||||
/converted_data
|
||||
/lightning_logs
|
||||
/model
|
||||
/models
|
||||
/*log
|
||||
/thirdparty
|
||||
/tmp
|
||||
|
||||
.lock
|
||||
# mac
|
||||
.DS_Store
|
||||
|
||||
# env
|
||||
/.vscode
|
||||
/.idea
|
||||
**/__pycache__
|
||||
*.pyc
|
||||
.pytest_cache
|
||||
|
||||
# doc
|
||||
docs/build
|
||||
docs/.vscode
|
||||
docs/source/_build
|
||||
|
|
@ -0,0 +1,3 @@
# Universal IE Dataset Preparation

Please refer to [UIE](https://github.com/universal-ie/UIE).
@ -0,0 +1,15 @@
|
|||
name: 14lap
|
||||
path: data/absa/pengb/14lap
|
||||
data_class: ABSA
|
||||
split:
|
||||
train: train_convert.json
|
||||
val: dev_convert.json
|
||||
test: test_convert.json
|
||||
language: en
|
||||
|
||||
mapper:
|
||||
POS: positive
|
||||
NEG: negative
|
||||
NEU: neutral
|
||||
aspect: aspect
|
||||
opinion: opinion
|
|
@ -0,0 +1,15 @@
|
|||
name: 14res
|
||||
path: data/absa/pengb/14res
|
||||
data_class: ABSA
|
||||
split:
|
||||
train: train_convert.json
|
||||
val: dev_convert.json
|
||||
test: test_convert.json
|
||||
language: en
|
||||
|
||||
mapper:
|
||||
POS: positive
|
||||
NEG: negative
|
||||
NEU: neutral
|
||||
aspect: aspect
|
||||
opinion: opinion
|
|
@ -0,0 +1,15 @@
|
|||
name: 15res
|
||||
path: data/absa/pengb/15res
|
||||
data_class: ABSA
|
||||
split:
|
||||
train: train_convert.json
|
||||
val: dev_convert.json
|
||||
test: test_convert.json
|
||||
language: en
|
||||
|
||||
mapper:
|
||||
POS: positive
|
||||
NEG: negative
|
||||
NEU: neutral
|
||||
aspect: aspect
|
||||
opinion: opinion
|
|
@ -0,0 +1,15 @@
|
|||
name: 16res
|
||||
path: data/absa/pengb/16res
|
||||
data_class: ABSA
|
||||
split:
|
||||
train: train_convert.json
|
||||
val: dev_convert.json
|
||||
test: test_convert.json
|
||||
language: en
|
||||
|
||||
mapper:
|
||||
POS: positive
|
||||
NEG: negative
|
||||
NEU: neutral
|
||||
aspect: aspect
|
||||
opinion: opinion
|
|
@ -0,0 +1,13 @@
|
|||
name: conll03
|
||||
path: data/conll03/conll03
|
||||
data_class: CoNLL03
|
||||
split:
|
||||
train: eng.train
|
||||
val: eng.testa
|
||||
test: eng.testb
|
||||
language: en
|
||||
mapper:
|
||||
LOC: location
|
||||
ORG: organization
|
||||
PER: person
|
||||
MISC: miscellaneous
|
|
@ -0,0 +1,17 @@
|
|||
name: mrc_ace04
|
||||
path: data/mrc_ner/ace2004
|
||||
data_class: MRCNER
|
||||
split:
|
||||
train: mrc-ner.train
|
||||
val: mrc-ner.dev
|
||||
test: mrc-ner.test
|
||||
language: en
|
||||
|
||||
mapper:
|
||||
FAC: facility
|
||||
GPE: geographical social political
|
||||
LOC: location
|
||||
ORG: organization
|
||||
PER: person
|
||||
VEH: vehicle
|
||||
WEA: weapon
|
|
@ -0,0 +1,17 @@
|
|||
name: mrc_ace05
|
||||
path: data/mrc_ner/ace2005
|
||||
data_class: MRCNER
|
||||
split:
|
||||
train: mrc-ner.train
|
||||
val: mrc-ner.dev
|
||||
test: mrc-ner.test
|
||||
language: en
|
||||
|
||||
mapper:
|
||||
FAC: facility
|
||||
GPE: geographical social political
|
||||
LOC: location
|
||||
ORG: organization
|
||||
PER: person
|
||||
VEH: vehicle
|
||||
WEA: weapon
|
|
@ -0,0 +1,62 @@
|
|||
name: casie
|
||||
path: data/casie
|
||||
data_class: CASIE
|
||||
split:
|
||||
train: train.jsonlines
|
||||
val: dev.jsonlines
|
||||
test: test.jsonlines
|
||||
language: en
|
||||
# https://github.com/Ebiquity/CASIE
|
||||
# https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9006444
|
||||
mapper:
|
||||
File: file
|
||||
System: system
|
||||
Person: person
|
||||
Phishing: phishing
|
||||
Data: data
|
||||
Purpose: purpose
|
||||
Website: website
|
||||
Organization: organization
|
||||
Capabilities: capabilities
|
||||
Malware: malware
|
||||
Software: software
|
||||
PII: personally identifiable information
|
||||
Databreach: databreach
|
||||
Time: time
|
||||
Number: number
|
||||
GPE: geopolitical entity
|
||||
Ransom: ransom
|
||||
Money: money
|
||||
Device: device
|
||||
Vulnerability: vulnerability
|
||||
DiscoverVulnerability: discover vulnerability
|
||||
Patch: patch
|
||||
PatchVulnerability: patch vulnerability
|
||||
Version: version
|
||||
PaymentMethod: payment method
|
||||
CVE: common vulnerabilities and exposures
|
||||
Issues-Addressed: issues addressed
|
||||
Vulnerable_System: vulnerable system
|
||||
Number-of-Data: number of data
|
||||
Capabilities: capabilities
|
||||
Patch: patch
|
||||
Time: time
|
||||
Vulnerable_System_Version: vulnerable system version
|
||||
Releaser: releaser
|
||||
Damage-Amount: damage amount
|
||||
Number-of-Victim: number of victim
|
||||
Tool: tool
|
||||
Attack-Pattern: attack pattern
|
||||
Compromised-Data: compromised data
|
||||
Attacker: attacker
|
||||
Price: price
|
||||
Discoverer: discoverer
|
||||
Patch-Number: patch number
|
||||
Payment-Method: payment method
|
||||
Supported_Platform: supported platform
|
||||
Vulnerability: vulnerability
|
||||
Place: place
|
||||
Vulnerable_System_Owner: vulnerable system owner
|
||||
Victim: victim
|
||||
Trusted-Entity: trusted entity
|
||||
Purpose: purpose
|
|
@ -0,0 +1,78 @@
|
|||
name: oneie_ace05_en_event
|
||||
path: data/oneie/ace05-EN
|
||||
data_class: OneIEEvent
|
||||
split:
|
||||
train: train.oneie.json
|
||||
val: dev.oneie.json
|
||||
test: test.oneie.json
|
||||
language: en
|
||||
|
||||
mapper:
|
||||
FAC: facility
|
||||
GPE: geographical social political
|
||||
LOC: location
|
||||
ORG: organization
|
||||
PER: person
|
||||
VEH: vehicle
|
||||
WEA: weapon
|
||||
ORG-AFF: organization affiliation
|
||||
GEN-AFF: general affiliation
|
||||
PHYS: physical
|
||||
PART-WHOLE: part whole
|
||||
PER-SOC: personal social
|
||||
ART: agent artifact
|
||||
Personnel:Elect: elect
|
||||
Life:Be-Born: born
|
||||
Movement:Transport: transport
|
||||
Contact:Phone-Write: phone write
|
||||
Life:Marry: marry
|
||||
Life:Die: die
|
||||
Personnel:Start-Position: start position
|
||||
Life:Injure: injure
|
||||
Transaction:Transfer-Ownership: transfer ownership
|
||||
Contact:Meet: meet
|
||||
Personnel:Nominate: nominate
|
||||
Conflict:Attack: attack
|
||||
Business:Start-Org: start organization
|
||||
Justice:Trial-Hearing: trial hearing
|
||||
Justice:Convict: convict
|
||||
Justice:Sentence: sentence
|
||||
Personnel:End-Position: end position
|
||||
Life:Divorce: divorce
|
||||
Justice:Acquit: acquit
|
||||
Justice:Charge-Indict: charge indict
|
||||
Transaction:Transfer-Money: transfer money
|
||||
Justice:Appeal: appeal
|
||||
Justice:Sue: sue
|
||||
Business:Merge-Org: merge organization
|
||||
Business:Declare-Bankruptcy: declare bankruptcy
|
||||
Justice:Execute: execute
|
||||
Justice:Arrest-Jail: arrest jail
|
||||
Justice:Extradite: extradite
|
||||
Conflict:Demonstrate: demonstrate
|
||||
Business:End-Org: end organization
|
||||
Justice:Release-Parole: release parole
|
||||
Justice:Fine: fine
|
||||
Justice:Pardon: pardon
|
||||
Defendant: defendant
|
||||
Prosecutor: prosecutor
|
||||
Person: person
|
||||
Origin: origin
|
||||
Buyer: buyer
|
||||
Plaintiff: plaintiff
|
||||
Victim: victim
|
||||
Org: organization
|
||||
Adjudicator: adjudicator
|
||||
Seller: seller
|
||||
Beneficiary: beneficiary
|
||||
Giver: giver
|
||||
Target: target
|
||||
Agent: agent
|
||||
Instrument: instrument
|
||||
Vehicle: vehicle
|
||||
Entity: entity
|
||||
Destination: destination
|
||||
Recipient: recipient
|
||||
Attacker: attacker
|
||||
Artifact: artifact
|
||||
Place: place
|
|
@ -0,0 +1,36 @@
|
|||
name: NYT
|
||||
path: data/NYT-multi
|
||||
data_class: JointER
|
||||
split:
|
||||
train: train.json
|
||||
val: dev.json
|
||||
test: test.json
|
||||
language: en
|
||||
mapper:
|
||||
ORGANIZATION: organization
|
||||
LOCATION: location
|
||||
PERSON: person
|
||||
/location/location/contains: contains
|
||||
/people/person/place_of_birth: place of birth
|
||||
/business/person/company: company
|
||||
/people/person/place_lived: place lived
|
||||
/location/administrative_division/country: country
|
||||
/location/country/administrative_divisions: administrative divisions
|
||||
/people/person/religion: religion
|
||||
/people/person/nationality: nationality
|
||||
/people/person/children: children
|
||||
/location/country/capital: capital
|
||||
/business/company/place_founded: place founded
|
||||
/people/deceased_person/place_of_death: place of death
|
||||
/business/company/founders: founders
|
||||
/location/neighborhood/neighborhood_of: neighborhood of
|
||||
/business/company/advisors: advisors
|
||||
/people/ethnicity/geographic_distribution: geographic distribution
|
||||
/sports/sports_team_location/teams: teams
|
||||
/sports/sports_team/location: location
|
||||
/business/company_shareholder/major_shareholder_of: major shareholder of
|
||||
/business/company/major_shareholders: major shareholders
|
||||
/people/person/ethnicity: ethnicity
|
||||
/people/ethnicity/people: people
|
||||
/people/person/profession: profession
|
||||
/business/company/industry: industry
|
|
@ -0,0 +1,23 @@
|
|||
name: ace05-rel
|
||||
path: data/spannet_data/relation/ace05
|
||||
data_class: Spannet
|
||||
split:
|
||||
train: train.jsonlines
|
||||
val: dev.jsonlines
|
||||
test: test.jsonlines
|
||||
language: en
|
||||
|
||||
mapper:
|
||||
FAC: facility
|
||||
GPE: geographical social political
|
||||
LOC: location
|
||||
ORG: organization
|
||||
PER: person
|
||||
VEH: vehicle
|
||||
WEA: weapon
|
||||
ORG-AFF: organization affiliation
|
||||
GEN-AFF: general affiliation
|
||||
PHYS: physical
|
||||
PART-WHOLE: part whole
|
||||
PER-SOC: personal social
|
||||
ART: agent artifact
|
|
@ -0,0 +1,19 @@
|
|||
name: conll04
|
||||
path: data/spannet_data/relation/conll04
|
||||
data_class: Spannet
|
||||
split:
|
||||
train: train.jsonlines
|
||||
val: dev.jsonlines
|
||||
test: test.jsonlines
|
||||
language: en
|
||||
|
||||
mapper:
|
||||
Loc: location
|
||||
Org: organization
|
||||
Other: other
|
||||
Peop: people
|
||||
OrgBased_In: organization in
|
||||
Work_For: work for
|
||||
Located_In: located in
|
||||
Live_In: live in
|
||||
Kill: kill
|
|
@ -0,0 +1,22 @@
|
|||
name: scierc
|
||||
path: data/spannet_data/relation/dygiepp/scierc
|
||||
data_class: Spannet
|
||||
split:
|
||||
train: train.jsonlines
|
||||
val: dev.jsonlines
|
||||
test: test.jsonlines
|
||||
language: en
|
||||
mapper:
|
||||
Method: method
|
||||
Generic: generic
|
||||
Material: material
|
||||
Task: task
|
||||
Metric: metric
|
||||
OtherScientificTerm: other scientific term
|
||||
USED-FOR: used for
|
||||
FEATURE-OF: feature of
|
||||
COMPARE: compare
|
||||
EVALUATE-FOR: evaluate for
|
||||
CONJUNCTION: conjunction
|
||||
HYPONYM-OF: hyponym of
|
||||
PART-OF: part of
|
|
@ -0,0 +1,10 @@
# Data Statistics Script

``` bash
python scripts/data_statistics.py \
    -data converted_data/text2spotasoc/ \
    -f csv
```

- data: target folder; the script traverses its subfolders that contain a `record.schema` file and skips any folder whose name contains "shot" or "ratio" (traversal sketched below)
- f: output table format; common choices are simple (default), latex, and html
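A minimal sketch of the traversal rule described in the `-data` bullet, assuming only the folder layout above; it is an illustration, not the actual `scripts/data_statistics.py`:

``` python
# Walk the data folder, keep subfolders that contain record.schema, and skip
# folders whose names contain "shot" or "ratio" (sampled subsets).
import os

def find_dataset_dirs(root: str):
    dataset_dirs = []
    for dirpath, dirnames, filenames in os.walk(root):
        name = os.path.basename(dirpath)
        if "shot" in name or "ratio" in name:
            dirnames[:] = []          # do not descend into sampled folders
            continue
        if "record.schema" in filenames:
            dataset_dirs.append(dirpath)
    return sorted(dataset_dirs)

if __name__ == "__main__":
    for d in find_dataset_dirs("converted_data/text2spotasoc/"):
        print(d)
```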
|
@ -0,0 +1,54 @@
# Low-Resource Data Sampling

See `run_sample.bash` for the detailed script, which generates all sampled data automatically.


## Low-Ratio Data Sampling

``` text
$ python scripts/sample_data_ratio.py -h
usage: sample_data_ratio.py [-h] [-src SRC] [-tgt TGT] [-seed SEED]

optional arguments:
  -h, --help  show this help message and exit
  -src SRC
  -tgt TGT
  -seed SEED
```

Example:

``` bash
python scripts/sample_data_ratio.py \
    -src converted_data/text2spotasoc/entity/mrc_conll03 \
    -tgt test_conll03_ratio
```

For every data folder, this takes the specified ratios (0.01, 0.05, 0.1) of its train.json (a rough sketch of the idea follows below).
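The following is a hedged sketch of what such ratio sampling amounts to; the output file names are hypothetical and the real logic lives in `scripts/sample_data_ratio.py`:

``` python
# Keep a random fraction of the lines of a JSON-lines training file.
import random

def sample_ratio(src_train: str, tgt_train: str, ratio: float, seed: int = 42):
    random.seed(seed)
    with open(src_train) as f:
        lines = [line for line in f if line.strip()]
    k = max(1, int(len(lines) * ratio))
    with open(tgt_train, "w") as f:
        f.writelines(random.sample(lines, k))

# e.g. one sampled copy per ratio (names are illustrative):
# for ratio in (0.01, 0.05, 0.1):
#     sample_ratio("train.json", f"train_ratio{ratio}.json", ratio)
```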
## N-Shot Data Sampling

``` text
$ python scripts/sample_data_shot.py -h
usage: sample_data_shot.py [-h] -src SRC -tgt TGT -task {entity,relation,event} [-seed SEED]

optional arguments:
  -h, --help            show this help message and exit
  -src SRC              Source Folder Name
  -tgt TGT              Target Folder Name, n shot sampled
  -task {entity,relation,event}
                        N-Shot Task name
  -seed SEED            Default is None, no random
```

Example:

``` bash
python scripts/sample_data_shot.py \
    -src converted_data/text2spotasoc/entity/mrc_conll03 \
    -tgt test_conll03_shot \
    -task entity
```

1. Read the data folder's `entity.schema`.
2. For each category, sample 1, 5, and 10 examples and merge them into the final datasets (see the sketch below).
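A minimal sketch of this per-category sampling, assuming the `entity.schema` / `train.json` layout produced by the preprocessing scripts (the first line of `entity.schema` is the JSON list of entity types, and each instance lists its entity types under `"spot"`); output names are hypothetical and the real implementation is `scripts/sample_data_shot.py`:

``` python
import json
import random

def sample_k_shot(src_dir: str, tgt_dir: str, shots=(1, 5, 10), seed: int = 42):
    random.seed(seed)
    with open(f"{src_dir}/entity.schema") as f:
        entity_types = json.loads(f.readline())      # first line: list of entity types
    with open(f"{src_dir}/train.json") as f:
        instances = [json.loads(line) for line in f if line.strip()]
    for k in shots:
        sampled = []
        for entity_type in entity_types:
            pool = [ins for ins in instances if entity_type in ins.get("spot", [])]
            sampled.extend(random.sample(pool, min(k, len(pool))))
        with open(f"{tgt_dir}/train_{k}shot.json", "w") as f:  # illustrative name
            for ins in sampled:
                f.write(json.dumps(ins) + "\n")
```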
|
|
@ -0,0 +1,131 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import pdb
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str)
|
||||
parser.add_argument("-o", "--output_dir", default="./output_fewshot", type=str)
|
||||
parser.add_argument("--entity_category_dir", default="entity_category", type=str)
|
||||
parser.add_argument("--relation_category_dir", default="relation_category", type=str)
|
||||
parser.add_argument("--step", default=1, type=int)
|
||||
opt = parser.parse_args()
|
||||
|
||||
data_dir = opt.data_dir
|
||||
output_dir = opt.output_dir
|
||||
entity_category_dir = opt.entity_category_dir
|
||||
relation_category_dir = opt.relation_category_dir
|
||||
step = opt.step
|
||||
|
||||
all_file = os.path.join(output_dir, "all.json")
|
||||
|
||||
entity_stat_file = os.path.join(output_dir, "entity_stat.json")
|
||||
relation_stat_file = os.path.join(output_dir, "relation_stat.json")
|
||||
|
||||
target_entity_category_dir = os.path.join(output_dir, entity_category_dir)
|
||||
target_relation_category_dir = os.path.join(output_dir, relation_category_dir)
|
||||
|
||||
if not os.path.exists(target_entity_category_dir):
|
||||
os.makedirs(target_entity_category_dir)
|
||||
if not os.path.exists(target_relation_category_dir):
|
||||
os.makedirs(target_relation_category_dir)
|
||||
|
||||
entity_instance_dict_file = os.path.join(output_dir, "entity_instance_dict.json")
|
||||
relation_instance_dict_file = os.path.join(output_dir, "relation_instance_dict.json")
|
||||
|
||||
metainfo_file = os.path.join(output_dir, "metainfo.json")
|
||||
|
||||
# %% load all instance line
|
||||
|
||||
print("Reading all data...")
|
||||
instance_list = []
|
||||
with open(all_file) as all:
|
||||
for idx, line in tqdm(enumerate(all)):
|
||||
if len(line) == 0:
|
||||
continue
|
||||
instance_list.append(line)
|
||||
print("All data read.")
|
||||
|
||||
# %% rearrange instance by class
|
||||
|
||||
print("Stat entity type and relation type...")
|
||||
entity_type_instance_dict = {}
|
||||
relation_type_instance_dict = {}
|
||||
for line in tqdm(instance_list):
|
||||
if len(line) == 0:
|
||||
continue
|
||||
|
||||
record = json.loads(line)
|
||||
|
||||
entity_type_list = record["spot"]
|
||||
relation_type_list = record["asoc"]
|
||||
|
||||
for entity_type in entity_type_list:
|
||||
if entity_type not in entity_type_instance_dict:
|
||||
entity_type_instance_dict[entity_type] = {
|
||||
"type_id": len(entity_type_instance_dict),
|
||||
"instance_list": []
|
||||
}
|
||||
entity_type_instance_dict[entity_type]["instance_list"].append(line)
|
||||
|
||||
for relation_type in relation_type_list:
|
||||
if relation_type not in relation_type_instance_dict:
|
||||
relation_type_instance_dict[relation_type] = {
|
||||
"type_id": len(relation_type_instance_dict),
|
||||
"instance_list": []
|
||||
}
|
||||
relation_type_instance_dict[relation_type]["instance_list"].append(line)
|
||||
|
||||
print("Stat over.")
|
||||
|
||||
# %% save data by category
|
||||
|
||||
metainfo = {
|
||||
"entity": [],
|
||||
"relation": [],
|
||||
}
|
||||
|
||||
print("Saving entity by category...")
|
||||
for entity_type, data in tqdm(entity_type_instance_dict.items()):
|
||||
type_id = data["type_id"]
|
||||
instance_list = data["instance_list"]
|
||||
|
||||
current_metainfo = {
|
||||
"entity_type": entity_type,
|
||||
"type_id": type_id
|
||||
}
|
||||
metainfo["entity"].append(current_metainfo)
|
||||
|
||||
entity_type_file = os.path.join(target_entity_category_dir, str(type_id)+".json")
|
||||
with open(entity_type_file, "w") as f:
|
||||
for instance in instance_list:
|
||||
f.write(instance)
|
||||
print("Entity saved.")
|
||||
|
||||
print("Saving relation by category...")
|
||||
for relation_type, data in tqdm(relation_type_instance_dict.items()):
|
||||
type_id = data["type_id"]
|
||||
instance_list = data["instance_list"]
|
||||
|
||||
current_metainfo = {
|
||||
"relation_type": relation_type,
|
||||
"type_id": type_id
|
||||
}
|
||||
metainfo["relation"].append(current_metainfo)
|
||||
|
||||
relation_type_file = os.path.join(target_relation_category_dir, str(type_id)+".json")
|
||||
with open(relation_type_file, "w") as f:
|
||||
for instance in instance_list:
|
||||
f.write(instance)
|
||||
print("Relation saved.")
|
||||
|
||||
print("Saving metainfo...")
|
||||
with open(metainfo_file, "w") as f:
|
||||
json.dump(metainfo, f)
|
||||
print("Metainfo saved.")
|
|
@ -0,0 +1,38 @@
|
|||
import os
|
||||
import json
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
|
||||
in_files = [
|
||||
'original_train.json',
|
||||
]
|
||||
out_files = [
|
||||
'train.json',
|
||||
]
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-d", "--dir", default="output", type=str)
|
||||
opt = parser.parse_args()
|
||||
|
||||
dir_path = opt.dir
|
||||
|
||||
|
||||
for in_file, out_file in zip(in_files, out_files):
|
||||
in_file_path = os.path.join(dir_path, in_file)
|
||||
out_file_path = os.path.join(dir_path, out_file)
|
||||
|
||||
print(f"{in_file_path} -> {out_file_path}")
|
||||
|
||||
fin = open(in_file_path)
|
||||
fout = open(out_file_path,'w')
|
||||
for line in tqdm(fin):
|
||||
obj = json.loads(line)
|
||||
flag = 0
|
||||
tmp_relation = []
|
||||
for tmp_obj in obj['relation']:
|
||||
tmp = json.dumps(tmp_obj)
|
||||
tmp_relation.append(tmp)
|
||||
obj['relation'] = tmp_relation
|
||||
fout.write(json.dumps(obj)+"\n")
|
||||
fin.close()
|
||||
fout.close()
|
|
@ -0,0 +1,29 @@
|
|||
if [ ! -e out_clean.zip ];
|
||||
then
|
||||
echo "downloading out_clean ..."
|
||||
wget -c http://url/to/dataset/out_clean.zip
|
||||
else
|
||||
echo "out_clean has been downloaded."
|
||||
fi
|
||||
|
||||
if [ ! -d out_clean ];
|
||||
then
|
||||
echo "unziping out_clean"
|
||||
unzip out_clean.zip
|
||||
else
|
||||
echo "out_clean has been unzipped"
|
||||
fi
|
||||
|
||||
# preprocess
|
||||
python explore.py --data_dir ./out_clean --output_dir ./output --max_instance_num -1
|
||||
|
||||
# fewshot sampling
|
||||
python stat_category.py --source_dir ./output --output_dir ./output
|
||||
python partition.py --source_dir ./output --output_dir ./output
|
||||
python match.py --source_dir ./output --output_dir ./output --step 100
|
||||
python rearrange_dataset.py --source_dir ./output --output_dir ./output
|
||||
|
||||
# generate dataset
|
||||
python noise.py --output_dir ./output --all_file rearrange_all.json
|
||||
python change_data_format_for_relation.py -d ./output
|
||||
ln -s ../../../ours/output ../converted_data/text2spotasoc/relation/ours
|
|
@ -0,0 +1,323 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from nltk.tokenize import WordPunctTokenizer
|
||||
word_tokenizer = WordPunctTokenizer()
|
||||
|
||||
import numpy as np
|
||||
np.set_printoptions(suppress=True)
|
||||
|
||||
import pdb
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str)
|
||||
parser.add_argument("-o", "--output_dir", default="./output", type=str)
|
||||
parser.add_argument("-n", "--max_instance_num", default=-1, type=int)
|
||||
opt = parser.parse_args()
|
||||
|
||||
data_dir = opt.data_dir
|
||||
output_dir = opt.output_dir
|
||||
max_instance_num = opt.max_instance_num
|
||||
|
||||
entity_schema_file = os.path.join(output_dir, "entity.schema")
|
||||
relation_schema_file = os.path.join(output_dir, "relation.schema")
|
||||
event_schema_file = os.path.join(output_dir, "event.schema")
|
||||
record_schema_file = os.path.join(output_dir, "record.schema")
|
||||
|
||||
all_file = os.path.join(output_dir, "all.json")
|
||||
train_file = os.path.join(output_dir, "original_train.json")
|
||||
dev_file = os.path.join(output_dir, "original_val.json")
|
||||
test_file = os.path.join(output_dir, "original_test.json")
|
||||
|
||||
ENTITY_SEARCH_RANGE = 0
|
||||
|
||||
ALL_ENTITY_CNT = 0
|
||||
NOMATCH_ENTITY_CNT = 0
|
||||
|
||||
NON_OFFSET_ENTITY_CNT = 0
|
||||
|
||||
def word_tokenize(text):
|
||||
return word_tokenizer.tokenize(text)
|
||||
|
||||
def record2instance(record):
|
||||
instance = {
|
||||
"text": None,
|
||||
"tokens": None,
|
||||
"record": None,
|
||||
"entity": None,
|
||||
"relation": None,
|
||||
"event": [],
|
||||
"spot": None,
|
||||
"asoc": None,
|
||||
"spot_asoc": None,
|
||||
}
|
||||
|
||||
# create text field
|
||||
text = record["sentence_value"]
|
||||
instance["text"] = text
|
||||
|
||||
# create tokens field
|
||||
tokens = word_tokenize(text)
|
||||
text_length_list.append(len(tokens))
|
||||
instance["tokens"] = tokens
|
||||
|
||||
# create entity field
|
||||
entities = record["sentence_entities"]
|
||||
instance_entity_list = []
|
||||
for entity in entities:
|
||||
entity_uri = entity["uri"]
|
||||
entity_mention = entity["surfaceform"]
|
||||
entity_type = entity["tag"]
|
||||
entity_offset = entity["boundaries_token"]
|
||||
|
||||
if entity_type == "#dateTime":
|
||||
entity_type = "date time"
|
||||
elif entity_type == "#decimal":
|
||||
entity_type = "decimal"
|
||||
elif entity_type == "":
|
||||
entity_type = "other"
|
||||
|
||||
if entity_mention == "":
|
||||
continue
|
||||
|
||||
try:
|
||||
start_index, end_index = entity_offset[0], entity_offset[-1]
|
||||
except:
|
||||
global NON_OFFSET_ENTITY_CNT
|
||||
NON_OFFSET_ENTITY_CNT += 1
|
||||
return None
|
||||
current_mention = " ".join(tokens[start_index:end_index+1])
|
||||
original_mention = " ".join(word_tokenize(entity_mention))
|
||||
if current_mention != original_mention:
|
||||
global NOMATCH_ENTITY_CNT
|
||||
NOMATCH_ENTITY_CNT += 1
|
||||
global ALL_ENTITY_CNT
|
||||
ALL_ENTITY_CNT += 1
|
||||
entity_offset = list(range(start_index, end_index+1))
|
||||
|
||||
instance_entity = {
|
||||
"type": entity_type,
|
||||
"offset": entity_offset,
|
||||
"text": entity_mention,
|
||||
"uri": entity_uri
|
||||
}
|
||||
instance_entity_list.append(instance_entity)
|
||||
instance["entity"] = instance_entity_list
|
||||
|
||||
# create spot field
|
||||
instance_entity_type_list = [i["type"] for i in instance_entity_list]
|
||||
instance["spot"] = list(set(instance_entity_type_list))
|
||||
entity_type_list.extend(instance_entity_type_list)
|
||||
|
||||
# create relation field
|
||||
triples = record["sentence_triples"]
|
||||
instance_relation_list = []
|
||||
for triple in triples:
|
||||
subj = triple["subject"]
|
||||
obj = triple["object"]
|
||||
predicate = triple["predicate"]
|
||||
relation_type = predicate["surfaceform"]
|
||||
|
||||
try:
|
||||
head_entity = [i for i in instance_entity_list if i["uri"] == subj["uri"]][0]
|
||||
except IndexError:
|
||||
continue
|
||||
|
||||
try:
|
||||
tail_entity = [i for i in instance_entity_list if i["uri"] == obj["uri"]][0]
|
||||
except IndexError:
|
||||
continue
|
||||
|
||||
head_entity_type = head_entity["type"]
|
||||
tail_entity_type = tail_entity["type"]
|
||||
|
||||
triple_type = (head_entity_type, relation_type, tail_entity_type)
|
||||
triple_type_list.append(triple_type)
|
||||
|
||||
instance_relation = {
|
||||
"type": relation_type,
|
||||
"args": [
|
||||
head_entity,
|
||||
tail_entity
|
||||
]
|
||||
}
|
||||
instance_relation_list.append(instance_relation)
|
||||
instance["relation"] = instance_relation_list
|
||||
|
||||
# create asoc field
|
||||
instance_asoc_list = [i["type"] for i in instance_relation_list]
|
||||
instance["asoc"] = list(set(instance_asoc_list))
|
||||
relation_list.extend(instance_asoc_list)
|
||||
|
||||
# create spot_asoc field
|
||||
instance_spot_asoc_list = []
|
||||
for entity in instance_entity_list:
|
||||
instance_spot_asoc = {
|
||||
"span": entity["text"],
|
||||
"label": entity["type"],
|
||||
"asoc": []
|
||||
}
|
||||
|
||||
for triple in instance_relation_list:
|
||||
if triple["args"][0]["uri"] == entity["uri"]:
|
||||
asoc_record = [triple["type"], triple["args"][1]["text"]]
|
||||
instance_spot_asoc["asoc"].append(asoc_record)
|
||||
|
||||
instance_spot_asoc_list.append(instance_spot_asoc)
|
||||
instance["spot_asoc"] = instance_spot_asoc_list
|
||||
|
||||
# create record field
|
||||
instance_record = "<extra_id_0> "
|
||||
for instance_spot_asoc in instance_spot_asoc_list:
|
||||
instance_record += "<extra_id_0> "
|
||||
|
||||
instance_record += instance_spot_asoc["label"] + " "
|
||||
instance_record += "<extra_id_5> "
|
||||
instance_record += instance_spot_asoc["span"] + " "
|
||||
|
||||
if len(instance_spot_asoc["asoc"]) != 0:
|
||||
for asoc in instance_spot_asoc["asoc"]:
|
||||
instance_record += "<extra_id_0> "
|
||||
|
||||
instance_record += asoc[0] + " "
|
||||
instance_record += "<extra_id_5> "
|
||||
instance_record += asoc[1] + " "
|
||||
|
||||
instance_record += "<extra_id_1> "
|
||||
|
||||
instance_record += "<extra_id_1> "
|
||||
instance_record += "<extra_id_1>"
|
||||
instance["record"] = instance_record
|
||||
|
||||
return instance
|
||||
|
||||
# %% read data
|
||||
|
||||
file_list = os.listdir(data_dir)
|
||||
|
||||
text_length_list = []
|
||||
record_cnt = 0
|
||||
|
||||
relation_list = []
|
||||
entity_type_list = []
|
||||
triple_type_list = []
|
||||
json_str_length_list = []
|
||||
instance_num = 0
|
||||
with open(all_file, "w") as all:
|
||||
for file_name in tqdm(file_list):
|
||||
file_path = os.path.join(data_dir, file_name)
|
||||
|
||||
with open(file_path) as f:
|
||||
for line in f:
|
||||
if len(line) == 0:
|
||||
continue
|
||||
|
||||
record = json.loads(line)
|
||||
record_cnt += 1
|
||||
|
||||
instance = record2instance(record)
|
||||
if instance is None:
|
||||
continue
|
||||
|
||||
json_str = json.dumps(instance)
|
||||
json_str_length_list.append(len(json_str))
|
||||
all.write(json_str + "\n")
|
||||
instance_num += 1
|
||||
|
||||
if max_instance_num != -1 and instance_num == max_instance_num:
|
||||
break
|
||||
|
||||
if max_instance_num != -1 and instance_num == max_instance_num:
|
||||
break
|
||||
|
||||
print(f"Total number of all entities: {ALL_ENTITY_CNT}")
|
||||
|
||||
print(f"Those entities non-match raw text: {NOMATCH_ENTITY_CNT}")
|
||||
print(f"Non-match rate: {NOMATCH_ENTITY_CNT / ALL_ENTITY_CNT}")
|
||||
|
||||
print(f"Total number of all non-offset entities: {NON_OFFSET_ENTITY_CNT}")
|
||||
print(f"Non-offset rate: {NON_OFFSET_ENTITY_CNT / ALL_ENTITY_CNT}")
|
||||
|
||||
print(f"Total record: {record_cnt}")
|
||||
print(f"Total instance: {instance_num}")
|
||||
|
||||
print()
|
||||
|
||||
# %% stat of text length
|
||||
max_len = max(text_length_list)
|
||||
min_len = min(text_length_list)
|
||||
|
||||
print(f"Max length: {max_len}, Min length: {min_len}")
|
||||
|
||||
bins = 20
|
||||
|
||||
hist, bin_edges = np.histogram(text_length_list, bins=bins, density=False)
|
||||
print("Hist:", hist)
|
||||
print("Edge:", bin_edges)
|
||||
|
||||
satisfied_length_cnt = len([i for i in text_length_list if i <= 512])
|
||||
print(f"Satisfied length cnt: {satisfied_length_cnt} ({satisfied_length_cnt/len(text_length_list)})")
|
||||
print()
|
||||
|
||||
# %% stat of json string length
|
||||
max_json_len = max(json_str_length_list)
|
||||
min_json_len = min(json_str_length_list)
|
||||
|
||||
print(f"Max json length: {max_json_len}, Min json length: {min_json_len}")
|
||||
|
||||
bins = 20
|
||||
|
||||
json_hist, json_bin_edges = np.histogram(json_str_length_list, bins=bins, density=False)
|
||||
print("Hist:", json_hist)
|
||||
print("Edge:", json_bin_edges)
|
||||
|
||||
satisfied_json_length_cnt = len([i for i in json_str_length_list if i <= 4096])
|
||||
print(f"Satisfied json length cnt: {satisfied_json_length_cnt} ({satisfied_json_length_cnt/len(json_str_length_list)})")
|
||||
|
||||
print()
|
||||
|
||||
# %% create schema
|
||||
|
||||
entity_type_list = list(set(entity_type_list))
|
||||
relation_list = list(set(relation_list))
|
||||
|
||||
print(f"Num of entity type: {len(entity_type_list)}")
|
||||
print(f"Num of relation type: {len(relation_list)}")
|
||||
|
||||
record_type_list = {}
|
||||
for head_entity_type, realtion_type, tail_entity_type in triple_type_list:
|
||||
if record_type_list.get(head_entity_type) is None:
|
||||
record_type_list[head_entity_type] = []
|
||||
record_type_list[head_entity_type].append(realtion_type)
|
||||
for head_entity_type, record_relation_list in record_type_list.items():
|
||||
record_type_list[head_entity_type] = list(set(record_relation_list))
|
||||
|
||||
with open(entity_schema_file, "w") as f:
|
||||
f.write(json.dumps(entity_type_list) + "\n")
|
||||
f.write(json.dumps([]) + "\n")
|
||||
f.write(json.dumps({}) + "\n")
|
||||
print("entity.schema saved")
|
||||
|
||||
with open(relation_schema_file, "w") as f:
|
||||
f.write(json.dumps(relation_list) + "\n")
|
||||
f.write(json.dumps(entity_type_list) + "\n")
|
||||
f.write(json.dumps({i: [] for i in relation_list}) + "\n")
|
||||
print("relation.schema saved")
|
||||
|
||||
with open(event_schema_file, "w") as f:
|
||||
f.write(json.dumps([]) + "\n")
|
||||
f.write(json.dumps([]) + "\n")
|
||||
f.write(json.dumps({}) + "\n")
|
||||
print("event.schema saved")
|
||||
|
||||
with open(record_schema_file, "w") as f:
|
||||
f.write(json.dumps(entity_type_list) + "\n")
|
||||
f.write(json.dumps(relation_list) + "\n")
|
||||
f.write(json.dumps(record_type_list) + "\n")
|
||||
print("record.schema saved")
|
||||
|
||||
print()
|
|
@ -0,0 +1,117 @@
|
|||
import os
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
import networkx as nx
|
||||
|
||||
import pdb
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-s", "--source_dir", default="./", type=str)
|
||||
parser.add_argument("-o", "--output_dir", default="./", type=str)
|
||||
parser.add_argument("--step", default=100, type=int)
|
||||
opt = parser.parse_args()
|
||||
|
||||
source_dir = opt.source_dir
|
||||
output_dir = opt.output_dir
|
||||
step = opt.step
|
||||
|
||||
instance_label_file = os.path.join(output_dir, "instance_label.json")
|
||||
partition_file = os.path.join(output_dir, "partition.json")
|
||||
match_group_file = os.path.join(output_dir, "match_group.json")
|
||||
|
||||
# %%
|
||||
|
||||
print("Loading partition...")
|
||||
partition = []
|
||||
with open(partition_file) as f:
|
||||
for line in f:
|
||||
partition.append(json.loads(line))
|
||||
|
||||
# %%
|
||||
|
||||
print("Loading instance label list...")
|
||||
instance_label_list = []
|
||||
with open(instance_label_file) as f:
|
||||
for line in tqdm(f):
|
||||
instance_label = json.loads(line)
|
||||
instance_label_list.append(instance_label)
|
||||
instance_label_dict = {i: j for i, j in instance_label_list}
|
||||
total = len(instance_label_dict)
|
||||
|
||||
# %%
|
||||
|
||||
def score(x_label, y_label, add_coef=True):
|
||||
x_label = set(x_label)
|
||||
y_label = set(y_label)
|
||||
|
||||
y2x_score = len(x_label & y_label) / len(x_label)
|
||||
if add_coef:
|
||||
y2x_score += 1 / len(y_label)
|
||||
x2y_score = len(x_label & y_label) / len(y_label)
|
||||
if add_coef:
|
||||
x2y_score += 1 / len(x_label)
|
||||
|
||||
if x2y_score > y2x_score:
|
||||
final_score = x2y_score
|
||||
flag = True
|
||||
else:
|
||||
final_score = y2x_score
|
||||
flag = False
|
||||
|
||||
return final_score, flag
|
||||
|
||||
# %%
|
||||
|
||||
print("Matching...")
|
||||
match_group = []
|
||||
for curr_partition in tqdm(partition):
|
||||
type_name, category, instance_list = curr_partition
|
||||
|
||||
if len(instance_list) == 1:
|
||||
match_group.append((instance_list[0], instance_list[0], 1.0))
|
||||
else:
|
||||
# pdb.set_trace()
|
||||
total_epoch = math.ceil(len(instance_list) / step)
|
||||
|
||||
for epoch in tqdm(range(total_epoch), leave=False):
|
||||
batch = instance_list[epoch*step:(epoch+1)*step]
|
||||
|
||||
edges = []
|
||||
for i in range(len(batch)):
|
||||
for j in range(i+1, len(batch)):
|
||||
x_id, y_id = batch[i], batch[j]
|
||||
|
||||
x_label = instance_label_dict[x_id]
|
||||
y_label = instance_label_dict[y_id]
|
||||
|
||||
edge_weight, _ = score(x_label, y_label)
|
||||
|
||||
edges.append((x_id, y_id, edge_weight))
|
||||
|
||||
G = nx.Graph()
|
||||
G.add_weighted_edges_from(edges)
|
||||
|
||||
match_result = nx.max_weight_matching(G)
|
||||
|
||||
for edge in match_result:
|
||||
x_id, y_id = edge
|
||||
x_label = instance_label_dict[x_id]
|
||||
y_label = instance_label_dict[y_id]
|
||||
match_score, flag = score(x_label, y_label, add_coef=False)
|
||||
|
||||
if flag:
|
||||
match_group.append((x_id, y_id, match_score))
|
||||
else:
|
||||
match_group.append((y_id, x_id, match_score))
|
||||
|
||||
scores = [i[-1] for i in match_group]
|
||||
average_score = sum(scores) / len(scores)
|
||||
print(f"Average match score: {average_score}")
|
||||
|
||||
print("Saving match group...")
|
||||
with open(match_group_file, "w") as f:
|
||||
for record in match_group:
|
||||
f.write(json.dumps(record)+"\n")
|
|
@ -0,0 +1,242 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
|
||||
import pdb
|
||||
|
||||
seed = 0
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-o", "--output_dir", default="./output", type=str)
|
||||
parser.add_argument("-a", "--all_file", default="all.json", type=str)
|
||||
parser.add_argument("-n", "--noise", default=4, type=int)
|
||||
opt = parser.parse_args()
|
||||
|
||||
output_dir = opt.output_dir
|
||||
all_file = opt.all_file
|
||||
noise = opt.noise
|
||||
|
||||
original_all_file = os.path.join(output_dir, all_file)
|
||||
noised_all_file = os.path.join(output_dir, "noised_all.json")
|
||||
|
||||
train_file = os.path.join(output_dir, "original_train.json")
|
||||
dev_file = os.path.join(output_dir, "original_val.json")
|
||||
test_file = os.path.join(output_dir, "original_test.json")
|
||||
|
||||
# %% noise function
|
||||
|
||||
NOISE_NUM = noise
|
||||
|
||||
THRESHOLD = 0.8
|
||||
TRIPLE_THRESHOLD = [0.6, 0.8]
|
||||
|
||||
DECAY_COEF = 0.8
|
||||
NOISE_OFFSET_THRESHOLD = 3
|
||||
NOISE_OFFSET_RANGE = list(range(NOISE_OFFSET_THRESHOLD))
|
||||
NOISE_OFFSET_WEIGHT = np.exp(- DECAY_COEF * np.array(NOISE_OFFSET_RANGE))
|
||||
NOISE_OFFSET_WEIGHT = NOISE_OFFSET_WEIGHT / NOISE_OFFSET_WEIGHT.sum()
|
||||
|
||||
def noise_entity_type(entity_list):
|
||||
entity_type_list = []
|
||||
for entity in entity_list:
|
||||
entity_type_list.append(entity["type"])
|
||||
entity_type_list = list(set(entity_type_list))
|
||||
|
||||
noised_entity_list = []
|
||||
for entity in entity_list:
|
||||
noised_entity = deepcopy(entity)
|
||||
if np.random.rand() > THRESHOLD:
|
||||
noised_entity_type = random.choice(entity_type_list)
|
||||
noised_entity["type"] = noised_entity_type
|
||||
noised_entity_list.append(noised_entity)
|
||||
return noised_entity_list
|
||||
|
||||
|
||||
def noise_entity_offset(entity_list, tokens):
|
||||
noised_entity_list = []
|
||||
for entity in entity_list:
|
||||
noised_entity = deepcopy(entity)
|
||||
|
||||
entity_offset = noised_entity["offset"]
|
||||
start_index, end_index = entity_offset[0], entity_offset[-1]
|
||||
|
||||
start_noise = np.random.choice(NOISE_OFFSET_RANGE, p=NOISE_OFFSET_WEIGHT)
|
||||
end_noise = np.random.choice(NOISE_OFFSET_RANGE, p=NOISE_OFFSET_WEIGHT)
|
||||
|
||||
noised_start_index = max(start_index-start_noise, 0)
|
||||
noised_end_index = min(end_index+end_noise, len(tokens)-1)
|
||||
noised_entity_offset = list(range(noised_start_index, noised_end_index+1))
|
||||
|
||||
noised_entity_mention = " ".join(tokens[noised_start_index:noised_end_index+1])
|
||||
|
||||
noised_entity["offset"] = noised_entity_offset
|
||||
noised_entity["text"] = noised_entity_mention
|
||||
|
||||
noised_entity_list.append(noised_entity)
|
||||
return noised_entity_list
|
||||
|
||||
def noise_entity_with_other_entity(entity_list):
|
||||
type_entity_mapping = {}
|
||||
for entity in entity_list:
|
||||
entity_type = entity["type"]
|
||||
if entity_type not in type_entity_mapping:
|
||||
type_entity_mapping[entity_type] = []
|
||||
type_entity_mapping[entity_type].append(entity)
|
||||
|
||||
noised_entity_list = []
|
||||
for entity in entity_list:
|
||||
noised_entity = deepcopy(entity)
|
||||
if np.random.rand() > THRESHOLD:
|
||||
entity_type = noised_entity["type"]
|
||||
other_entity = random.choice(type_entity_mapping[entity_type])
|
||||
noised_entity["text"] = other_entity["text"]
|
||||
noised_entity["offset"] = other_entity["offset"]
|
||||
noised_entity_list.append(noised_entity)
|
||||
return noised_entity_list
|
||||
|
||||
def noise_relation_type(triple_list):
|
||||
relation_type_list = []
|
||||
for triple in triple_list:
|
||||
relation_type_list.append(triple["type"])
|
||||
relation_type_list = list(set(relation_type_list))
|
||||
|
||||
noised_triple_list = []
|
||||
for triple in triple_list:
|
||||
noised_triple = deepcopy(triple)
|
||||
if np.random.rand() > THRESHOLD:
|
||||
noised_relation_type = random.choice(relation_type_list)
|
||||
noised_triple["type"] = noised_relation_type
|
||||
noised_triple_list.append(noised_triple)
|
||||
return noised_triple_list
|
||||
|
||||
def noise_triple_num(triple_list, entity_list):
|
||||
noised_triple_list = []
|
||||
for triple in triple_list:
|
||||
p = np.random.rand()
|
||||
if p < TRIPLE_THRESHOLD[0]: # do nothing
|
||||
noised_triple_list.append(triple)
|
||||
elif p < TRIPLE_THRESHOLD[1]: # add noised triple
|
||||
noised_triple_list.append(triple)
|
||||
|
||||
noised_triple = deepcopy(triple)
|
||||
replaced_tail = random.choice(entity_list)
|
||||
noised_triple["args"][1] = replaced_tail
|
||||
noised_triple_list.append(noised_triple)
|
||||
else: # remove triple
|
||||
pass
|
||||
|
||||
return noised_triple_list
|
||||
|
||||
# %% utils
|
||||
|
||||
def build_entity_dict(entity_list):
|
||||
entity_dict = {}
|
||||
for entity in entity_list:
|
||||
entity_uri = entity["uri"]
|
||||
entity_dict[entity_uri] = entity
|
||||
return entity_dict
|
||||
|
||||
def update_relation_triple_by_noised_entity(triple_list, noised_entity_dict):
|
||||
noised_triple_list = []
|
||||
for triple in triple_list:
|
||||
noised_triple = deepcopy(triple)
|
||||
head, tail = noised_triple["args"]
|
||||
noised_head, noised_tail = noised_entity_dict[head["uri"]], noised_entity_dict[tail["uri"]]
|
||||
noised_triple["args"] = [noised_head, noised_tail]
|
||||
noised_triple_list.append(noised_triple)
|
||||
return noised_triple_list
|
||||
|
||||
def create_spot_asoc_field(instance_entity_list, instance_triple_list):
|
||||
instance_spot_asoc_list = []
|
||||
for entity in instance_entity_list:
|
||||
instance_spot_asoc = {
|
||||
"span": entity["text"],
|
||||
"label": entity["type"],
|
||||
"asoc": []
|
||||
}
|
||||
|
||||
for triple in instance_triple_list:
|
||||
if triple["args"][0]["uri"] == entity["uri"]:
|
||||
asoc_record = [triple["type"], triple["args"][1]["text"]]
|
||||
instance_spot_asoc["asoc"].append(asoc_record)
|
||||
|
||||
instance_spot_asoc_list.append(instance_spot_asoc)
|
||||
return instance_spot_asoc_list
|
||||
|
||||
def create_record_field(instance_spot_asoc_list):
|
||||
instance_record = "<extra_id_0> "
|
||||
for instance_spot_asoc in instance_spot_asoc_list:
|
||||
instance_record += "<extra_id_0> "
|
||||
|
||||
instance_record += instance_spot_asoc["label"] + " "
|
||||
instance_record += "<extra_id_5> "
|
||||
instance_record += instance_spot_asoc["span"] + " "
|
||||
|
||||
if len(instance_spot_asoc["asoc"]) != 0:
|
||||
for asoc in instance_spot_asoc["asoc"]:
|
||||
instance_record += "<extra_id_0> "
|
||||
|
||||
instance_record += asoc[0] + " "
|
||||
instance_record += "<extra_id_5> "
|
||||
instance_record += asoc[1] + " "
|
||||
|
||||
instance_record += "<extra_id_1> "
|
||||
|
||||
instance_record += "<extra_id_1> "
|
||||
instance_record += "<extra_id_1>"
|
||||
|
||||
return instance_record
|
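# Illustrative sketch of the linearized record this builds (hypothetical spans and
# types, only to show the <extra_id_*> layout): for a single "person" spot "Trump"
# carrying one "work for" association to "White House", the returned string is
# "<extra_id_0> <extra_id_0> person <extra_id_5> Trump <extra_id_0> work for <extra_id_5> White House <extra_id_1> <extra_id_1> <extra_id_1>".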
||||
|
||||
# %% create noised record for all
|
||||
|
||||
with open(original_all_file) as src, open(noised_all_file, "w") as tgt:
|
||||
for line in tqdm(src):
|
||||
instance = json.loads(line)
|
||||
|
||||
tokens = instance["tokens"]
|
||||
entity_list = instance["entity"]
|
||||
triple_list = instance["relation"]
|
||||
spot_asoc_list = instance["spot_asoc"]
|
||||
record = instance["record"]
|
||||
|
||||
noised_record_list = []
|
||||
for _ in range(NOISE_NUM):
|
||||
# noise entity
|
||||
noised_entity_list = noise_entity_offset(entity_list, tokens)
|
||||
noised_entity_list = noise_entity_with_other_entity(noised_entity_list)
|
||||
noised_entity_list = noise_entity_type(noised_entity_list)
|
||||
|
||||
noised_entity_dict = build_entity_dict(noised_entity_list)
|
||||
|
||||
# noise triple
|
||||
noised_triple_list = update_relation_triple_by_noised_entity(triple_list, noised_entity_dict)
|
||||
|
||||
noised_triple_list = noise_relation_type(noised_triple_list)
|
||||
noised_triple_list = noise_triple_num(noised_triple_list, noised_entity_list)
|
||||
|
||||
# create noised record
|
||||
noised_spot_asoc_list = create_spot_asoc_field(noised_entity_list, noised_triple_list)
|
||||
noised_record = create_record_field(noised_spot_asoc_list)
|
||||
noised_record_list.append(noised_record)
|
||||
|
||||
# remove uri field
|
||||
for entity in entity_list:
|
||||
del entity["uri"]
|
||||
|
||||
instance["noised_record"] = noised_record_list
|
||||
|
||||
json_str = json.dumps(instance)
|
||||
tgt.write(json_str + "\n")
|
||||
|
||||
# %% create train/dev/test data
|
||||
|
||||
with open(noised_all_file) as all, open(train_file, "w") as train, open(dev_file, "w") as dev, open(test_file, "w") as test:
|
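# Note: every noised instance is written to the train split below; the dev and test
# files are opened here but left empty by this script.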
||||
for i, line in tqdm(enumerate(all)):
|
||||
train.write(line)
|
||||
print("train/dev/test saved.")
|
|
@ -0,0 +1,96 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import pdb
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-s", "--source_dir", default="./", type=str)
|
||||
parser.add_argument("-o", "--output_dir", default="./", type=str)
|
||||
opt = parser.parse_args()
|
||||
|
||||
source_dir = opt.source_dir
|
||||
output_dir = opt.output_dir
|
||||
|
||||
all_file = os.path.join(source_dir, "all.json")
|
||||
|
||||
entity_stat_file = os.path.join(output_dir, "entity_stat.json")
|
||||
relation_stat_file = os.path.join(output_dir, "relation_stat.json")
|
||||
|
||||
partition_file = os.path.join(output_dir, "partition.json")
|
||||
|
||||
entity_stat_list = []
|
||||
relation_stat_list = []
|
||||
with open(entity_stat_file) as f:
|
||||
for line in f:
|
||||
category = json.loads(line)
|
||||
category[1]["type"] = "entity"
|
||||
entity_stat_list.append(category)
|
||||
with open(relation_stat_file) as f:
|
||||
for line in f:
|
||||
category = json.loads(line)
|
||||
category[1]["type"] = "relation"
|
||||
relation_stat_list.append(category)
|
||||
|
||||
all_stat_list = entity_stat_list + relation_stat_list
|
||||
all_stat_list = sorted(all_stat_list, key=lambda x: len(x[1]["instance_id_list"]))
|
||||
|
||||
instance_type_dict = {}
|
||||
for curr_type, curr_record in tqdm(all_stat_list):
|
||||
instance_id_list = curr_record["instance_id_list"]
|
||||
for instance_id in instance_id_list:
|
||||
if instance_id not in instance_type_dict:
|
||||
instance_type_dict[instance_id] = set()
|
||||
instance_type_dict[instance_id].add(curr_type)
|
||||
|
||||
def get_visited_type(instance_id_list, instance_type_dict):
|
||||
visited_type = set()
|
||||
for i, instance_id in enumerate(instance_id_list):
|
||||
if i == 0:
|
||||
visited_type |= instance_type_dict[instance_id]
|
||||
else:
|
||||
visited_type &= instance_type_dict[instance_id]
|
||||
return visited_type
|
||||
|
||||
print("Begining partition...")
|
||||
visited_instance = set()
|
||||
visited_type = set()
|
||||
partition = []
|
||||
empty_set_cnt = 0
|
||||
duplicated_instance_cnt = 0
|
||||
for curr_type, curr_record in tqdm(all_stat_list):
|
||||
category_type = curr_record["type"]
|
||||
instance_id_list = curr_record["instance_id_list"]
|
||||
|
||||
instance_id_set = set(instance_id_list)
|
||||
instance_id_set = instance_id_set - visited_instance
|
||||
|
||||
curr_visited_type = get_visited_type(instance_id_list, instance_type_dict)
|
||||
|
||||
if len(instance_id_set) == 0:
|
||||
if curr_type in visited_type:
|
||||
continue
|
||||
else:
|
||||
non_visited_type = curr_visited_type - visited_type
|
||||
instance_id_set = set(instance_id_list)
|
||||
empty_set_cnt += 1
|
||||
duplicated_instance_cnt += len(instance_id_list)
|
||||
|
||||
curr_partition = [curr_type, category_type, list(instance_id_set)]
|
||||
partition.append(curr_partition)
|
||||
|
||||
visited_instance.update(instance_id_set)
|
||||
visited_type.update(curr_visited_type)
|
||||
|
||||
print(f"Empty set rate: {empty_set_cnt / len(all_stat_list)}")
|
||||
print(f"Duplication rate: {duplicated_instance_cnt / len(instance_type_dict)}")
|
||||
|
||||
print("Saving partition...")
|
||||
with open(partition_file, "w") as f:
|
||||
for record in partition:
|
||||
f.write(json.dumps(record)+"\n")
|
|
@ -0,0 +1,50 @@
|
|||
import os
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
import random
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
|
||||
import pdb
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-s", "--source_dir", default="./", type=str)
|
||||
parser.add_argument("-o", "--output_dir", default="./", type=str)
|
||||
opt = parser.parse_args()
|
||||
|
||||
source_dir = opt.source_dir
|
||||
output_dir = opt.output_dir
|
||||
|
||||
all_file = os.path.join(source_dir, "all.json")
|
||||
match_group_file = os.path.join(output_dir, "match_group.json")
|
||||
rearrange_all_file = os.path.join(output_dir, "rearrange_all.json")
|
||||
|
||||
# %%
|
||||
|
||||
print("Loading match group...")
|
||||
match_group = []
|
||||
with open(match_group_file) as f:
|
||||
for line in tqdm(f):
|
||||
match_group.append(json.loads(line))
|
||||
|
||||
# %%
|
||||
|
||||
print("Loading instance...")
|
||||
instance_list = []
|
||||
with open(all_file) as f:
|
||||
for line in tqdm(f):
|
||||
instance_list.append(line)
|
||||
|
||||
# %%
|
||||
|
||||
print("Rearrange dataset...")
|
||||
with open(rearrange_all_file, "w") as f:
|
||||
for edge in tqdm(match_group):
|
||||
support_id, query_id, _ = edge
|
||||
|
||||
support = instance_list[support_id]
|
||||
query = instance_list[query_id]
|
||||
|
||||
f.write(support)
|
||||
f.write(query)
|
|
@ -0,0 +1,150 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import pdb
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str)
|
||||
parser.add_argument("-o", "--output_dir", default="./output_fewshot", type=str)
|
||||
parser.add_argument("--entity_category_dir", default="entity_category", type=str)
|
||||
parser.add_argument("--relation_category_dir", default="relation_category", type=str)
|
||||
parser.add_argument("--task_num", default=10000, type=int)
|
||||
parser.add_argument("--N", default=5, type=int)
|
||||
parser.add_argument("--K", default=5, type=int)
|
||||
parser.add_argument("--Q", default=5, type=int)
|
||||
opt = parser.parse_args()
|
||||
|
||||
data_dir = opt.data_dir
|
||||
output_dir = opt.output_dir
|
||||
entity_category_dir = opt.entity_category_dir
|
||||
relation_category_dir = opt.relation_category_dir
|
||||
task_num = opt.task_num
|
||||
N = opt.N
|
||||
K = opt.K
|
||||
Q = opt.Q
|
||||
|
||||
target_entity_category_dir = os.path.join(output_dir, entity_category_dir)
|
||||
target_relation_category_dir = os.path.join(output_dir, relation_category_dir)
|
||||
|
||||
metainfo_file = os.path.join(output_dir, "metainfo.json")
|
||||
|
||||
task_file = os.path.join(output_dir, "sampled_task.json")
|
||||
|
||||
# %% read instance dict
|
||||
|
||||
print("Reading metainfo...")
|
||||
with open(metainfo_file) as f:
|
||||
metainfo = json.load(f)
|
||||
print("Metainfo read.")
|
||||
|
||||
print("Loading entity instance dict...")
|
||||
entity_type_instance_dict = {}
|
||||
for current_metainfo in tqdm(metainfo["entity"]):
|
||||
entity_type = current_metainfo["entity_type"]
|
||||
type_id = current_metainfo["type_id"]
|
||||
|
||||
entity_type_file = os.path.join(target_entity_category_dir, str(type_id)+".json")
|
||||
instance_list = []
|
||||
with open(entity_type_file) as f:
|
||||
for line in f:
|
||||
instance_list.append(line)
|
||||
|
||||
entity_type_instance_dict[entity_type] = instance_list
|
||||
entity_type_list = list(entity_type_instance_dict.keys())
|
||||
print("Entity instance dict loaded")
|
||||
|
||||
print("Loading relation instance dict...")
|
||||
relation_type_instance_dict = {}
|
||||
for current_metainfo in tqdm(metainfo["relation"]):
|
||||
relation_type = current_metainfo["relation_type"]
|
||||
type_id = current_metainfo["type_id"]
|
||||
|
||||
relation_type_file = os.path.join(target_relation_category_dir, str(type_id)+".json")
|
||||
instance_list = []
|
||||
with open(relation_type_file) as f:
|
||||
for line in f:
|
||||
instance_list.append(line)
|
||||
|
||||
relation_type_instance_dict[relation_type] = instance_list
|
||||
relation_type_list = list(relation_type_instance_dict.keys())
|
||||
print("Relation instance dict loaded.")
|
||||
|
||||
# %% n-way-k-shot sampling
|
||||
|
||||
print("Sampling N-Way K-Shot task...")
|
||||
task_list = []
|
||||
for i in tqdm(range(task_num//2)):
|
||||
# sample entity task
|
||||
target_entity_type_list = random.sample(entity_type_list, N)
|
||||
|
||||
task = {
|
||||
"target_entity_type_list": target_entity_type_list,
|
||||
"target_relation_type_list": [],
|
||||
"N": N,
|
||||
"K": K,
|
||||
"Q": Q,
|
||||
"support": None,
|
||||
"query": None
|
||||
}
|
||||
|
||||
support = []
|
||||
query = []
|
||||
|
||||
for entity_type in target_entity_type_list:
|
||||
instance_candidates = entity_type_instance_dict[entity_type]
|
||||
|
||||
if len(instance_candidates) > K+Q:
|
||||
sampled_instance_list = random.sample(instance_candidates, K+Q)
|
||||
else:
|
||||
sampled_instance_list = random.choices(instance_candidates, k=K+Q)
|
||||
|
||||
support.extend(sampled_instance_list[:K])
|
||||
query.extend(sampled_instance_list[K:])
|
||||
|
||||
task["support"] = support
|
||||
task["query"] = query
|
||||
|
||||
task_list.append(task)
|
||||
|
||||
# sample relation task
|
||||
target_relation_type_list = random.sample(relation_type_list, N)
|
||||
|
||||
task = {
|
||||
"target_entity_type_list": [],
|
||||
"target_relation_type_list": target_relation_type_list,
|
||||
"N": N,
|
||||
"K": K,
|
||||
"Q": Q,
|
||||
"support": None,
|
||||
"query": None
|
||||
}
|
||||
|
||||
support = []
|
||||
query = []
|
||||
|
||||
for relation_type in target_relation_type_list:
|
||||
instance_candidates = relation_type_instance_dict[relation_type]
|
||||
|
||||
if len(instance_candidates) > K+Q:
|
||||
sampled_instance_list = random.sample(instance_candidates, K+Q)
|
||||
else:
|
||||
sampled_instance_list = random.choices(instance_candidates, k=K+Q)
|
||||
|
||||
support.extend(sampled_instance_list[:K])
|
||||
query.extend(sampled_instance_list[K:])
|
||||
|
||||
task["support"] = support
|
||||
task["query"] = query
|
||||
|
||||
task_list.append(task)
|
||||
print("Sampling over.")
|
||||
|
||||
print("Saving task...")
|
||||
with open(task_file, "w") as f:
|
||||
for task in tqdm(task_list):
|
||||
f.write(json.dumps(task) + "\n")
|
||||
print("Task saved.")
|
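# Each line of sampled_task.json is one self-contained N-way-K-shot episode with the
# keys built above; roughly (hypothetical types, raw instance lines abbreviated):
# {"target_entity_type_list": ["person", ...], "target_relation_type_list": [],
#  "N": 5, "K": 5, "Q": 5, "support": ["<instance line>", ...], "query": ["<instance line>", ...]}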
|
@ -0,0 +1,64 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import pdb
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str)
|
||||
parser.add_argument("-s", "--source_dir", default="./output", type=str)
|
||||
parser.add_argument("-o", "--output_dir", default="./output_fewshot", type=str)
|
||||
opt = parser.parse_args()
|
||||
|
||||
data_dir = opt.data_dir
|
||||
source_dir = opt.source_dir
|
||||
output_dir = opt.output_dir
|
||||
|
||||
all_file = os.path.join(source_dir, "all.json")
|
||||
|
||||
entity_stat_file = os.path.join(output_dir, "entity_stat.json")
|
||||
relation_stat_file = os.path.join(output_dir, "relation_stat.json")
|
||||
|
||||
# %% read data and stat
|
||||
|
||||
entity_stat_dict = {}
|
||||
relation_stat_dict = {}
|
||||
|
||||
record_cnt = 0
|
||||
with open(all_file) as all:
|
||||
for line in tqdm(all):
|
||||
if len(line.strip()) == 0:
|
||||
continue
|
||||
|
||||
record = json.loads(line)
|
||||
|
||||
entity_type_list = record["spot"]
|
||||
relation_type_list = record["asoc"]
|
||||
|
||||
for entity_type in entity_type_list:
|
||||
if entity_type not in entity_stat_dict:
|
||||
entity_stat_dict[entity_type] = {
|
||||
"type_id": len(entity_stat_dict),
|
||||
"instance_id_list": []
|
||||
}
|
||||
entity_stat_dict[entity_type]["instance_id_list"].append(record_cnt)
|
||||
|
||||
for relation_type in relation_type_list:
|
||||
if relation_type not in relation_stat_dict:
|
||||
relation_stat_dict[relation_type] = {
|
||||
"type_id": len(relation_stat_dict),
|
||||
"instance_id_list": []
|
||||
}
|
||||
relation_stat_dict[relation_type]["instance_id_list"].append(record_cnt)
|
||||
|
||||
record_cnt += 1
|
||||
|
||||
with open(entity_stat_file, "w") as f:
|
||||
json.dump(entity_stat_dict, f)
|
||||
with open(relation_stat_file, "w") as f:
|
||||
json.dump(relation_stat_dict, f)
|
|
@ -0,0 +1,79 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import pdb
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-s", "--source_dir", default="./", type=str)
|
||||
parser.add_argument("-o", "--output_dir", default="./", type=str)
|
||||
opt = parser.parse_args()
|
||||
|
||||
source_dir = opt.source_dir
|
||||
output_dir = opt.output_dir
|
||||
|
||||
all_file = os.path.join(source_dir, "all.json")
|
||||
|
||||
entity_stat_file = os.path.join(output_dir, "entity_stat.json")
|
||||
relation_stat_file = os.path.join(output_dir, "relation_stat.json")
|
||||
|
||||
instance_label_file = os.path.join(output_dir, "instance_label.json")
|
||||
|
||||
# %% read data and stat
|
||||
|
||||
instance_label_list = []
|
||||
entity_stat_dict = {}
|
||||
relation_stat_dict = {}
|
||||
|
||||
print("Stating label...")
|
||||
record_cnt = 0
|
||||
with open(all_file) as all:
|
||||
for line in tqdm(all):
|
||||
if len(line.strip()) == 0:
|
||||
continue
|
||||
|
||||
record = json.loads(line)
|
||||
|
||||
entity_type_list = record["spot"]
|
||||
relation_type_list = record["asoc"]
|
||||
|
||||
labels = entity_type_list + relation_type_list
|
||||
instance_label_list.append((record_cnt, labels))
|
||||
|
||||
for entity_type in entity_type_list:
|
||||
if entity_type not in entity_stat_dict:
|
||||
entity_stat_dict[entity_type] = {
|
||||
"type_id": len(entity_stat_dict),
|
||||
"instance_id_list": []
|
||||
}
|
||||
entity_stat_dict[entity_type]["instance_id_list"].append(record_cnt)
|
||||
|
||||
for relation_type in relation_type_list:
|
||||
if relation_type not in relation_stat_dict:
|
||||
relation_stat_dict[relation_type] = {
|
||||
"type_id": len(relation_stat_dict),
|
||||
"instance_id_list": []
|
||||
}
|
||||
relation_stat_dict[relation_type]["instance_id_list"].append(record_cnt)
|
||||
|
||||
record_cnt += 1
|
||||
|
||||
print("Saving entity stat...")
|
||||
with open(entity_stat_file, "w") as f:
|
||||
for key, value in tqdm(entity_stat_dict.items()):
|
||||
f.write(json.dumps([key, value])+"\n")
|
||||
print("Saving relation stat...")
|
||||
with open(relation_stat_file, "w") as f:
|
||||
for key, value in tqdm(relation_stat_dict.items()):
|
||||
f.write(json.dumps([key, value])+"\n")
|
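# Each stat line written above is a [type_name, info] pair, e.g. (hypothetical values)
# ["person", {"type_id": 0, "instance_id_list": [0, 7, 42]}]
# which matches the per-line format the partition script reads back with json.loads.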
||||
|
||||
print("Saving instance label stat...")
|
||||
instance_label_list = sorted(instance_label_list, key=lambda x: len(x[1]), reverse=True)
|
||||
with open(instance_label_file, "w") as f:
|
||||
for instance_label in tqdm(instance_label_list):
|
||||
f.write(json.dumps(instance_label)+"\n")
|
|
@ -0,0 +1,187 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import pdb
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str)
|
||||
opt = parser.parse_args()
|
||||
|
||||
data_dir = opt.data_dir
|
||||
output_dir = opt.output_dir
|
||||
|
||||
task_file = os.path.join(output_dir, "sampled_task.json")
|
||||
|
||||
sampled_all_file = os.path.join(output_dir, "sampled_all.json")
|
||||
|
||||
# %% utils
|
||||
|
||||
def create_spot_asoc_field(instance_entity_list, instance_triple_list):
|
||||
instance_spot_asoc_list = []
|
||||
for entity in instance_entity_list:
|
||||
instance_spot_asoc = {
|
||||
"span": entity["text"],
|
||||
"label": entity["type"],
|
||||
"asoc": []
|
||||
}
|
||||
|
||||
for triple in instance_triple_list:
|
||||
if triple["args"][0]["uri"] == entity["uri"]:
|
||||
asoc_record = [triple["type"], triple["args"][1]["text"]]
|
||||
instance_spot_asoc["asoc"].append(asoc_record)
|
||||
|
||||
instance_spot_asoc_list.append(instance_spot_asoc)
|
||||
return instance_spot_asoc_list
|
||||
|
||||
def create_record_field(instance_spot_asoc_list):
|
||||
instance_record = "<extra_id_0> "
|
||||
for instance_spot_asoc in instance_spot_asoc_list:
|
||||
instance_record += "<extra_id_0> "
|
||||
|
||||
instance_record += instance_spot_asoc["label"] + " "
|
||||
instance_record += "<extra_id_5> "
|
||||
instance_record += instance_spot_asoc["span"] + " "
|
||||
|
||||
if len(instance_spot_asoc["asoc"]) != 0:
|
||||
for asoc in instance_spot_asoc["asoc"]:
|
||||
instance_record += "<extra_id_0> "
|
||||
|
||||
instance_record += asoc[0] + " "
|
||||
instance_record += "<extra_id_5> "
|
||||
instance_record += asoc[1] + " "
|
||||
|
||||
instance_record += "<extra_id_1> "
|
||||
|
||||
instance_record += "<extra_id_1> "
|
||||
instance_record += "<extra_id_1>"
|
||||
|
||||
return instance_record
|
||||
|
||||
def filter_entity_by_entity_type(entity_list, target_entity_type_list):
|
||||
'''
|
||||
{"type": "rocket stage", "offset": [11, 12, 13], "text": "S-II", "uri": "Q1093699"}
|
||||
'''
|
||||
filtered_entity_list = [entity for entity in entity_list if entity["type"] in target_entity_type_list]
|
||||
return filtered_entity_list
|
||||
|
||||
def filter_triple_by_entity_list(triple_list, filtered_entity_list):
|
||||
'''
|
||||
{"type": "part of", "args": [{"type": "rocket stage", "offset": [1, 2, 3], "text": "MS-II", "uri": "Q6717655"}, {"type": "rocket stage", "offset": [11, 12, 13], "text": "S-II", "uri": "Q1093699"}]}
|
||||
'''
|
||||
filtered_triple_list = []
|
||||
for triple in triple_list:
|
||||
head, tail = triple["args"]
|
||||
if head in filtered_entity_list and tail in filtered_entity_list:
|
||||
filtered_triple_list.append(triple)
|
||||
return filtered_triple_list
|
||||
|
||||
def build_target_relation_type_list(filtered_triple_list):
|
||||
target_relation_type_list = [triple["type"] for triple in filtered_triple_list]
|
||||
target_relation_type_list = list(set(target_relation_type_list))
|
||||
return target_relation_type_list
|
||||
|
||||
def filter_triple_by_relation_type(triple_list, target_relation_type_list):
|
||||
'''
|
||||
{"type": "part of", "args": [{"type": "rocket stage", "offset": [1, 2, 3], "text": "MS-II", "uri": "Q6717655"}, {"type": "rocket stage", "offset": [11, 12, 13], "text": "S-II", "uri": "Q1093699"}]}
|
||||
'''
|
||||
filtered_triple_list = [triple for triple in triple_list if triple["type"] in target_relation_type_list]
|
||||
return filtered_triple_list
|
||||
|
||||
def filter_entity_by_triple_list(entity_list, filtered_triple_list):
|
||||
filtered_entity_list = []
|
||||
for triple in filtered_triple_list:
|
||||
head, tail = triple["args"]
|
||||
filtered_entity_list.append(head)
|
||||
filtered_entity_list.append(tail)
|
||||
entity_uri_set = set()
|
||||
unique_filtered_entity_list = []
|
||||
for entity in filtered_entity_list:
|
||||
uri = entity["uri"]
|
||||
if uri not in entity_uri_set:
|
||||
entity_uri_set.add(uri)
|
||||
unique_filtered_entity_list.append(entity)
|
||||
return unique_filtered_entity_list
|
||||
|
||||
def build_target_entity_type_list(filtered_entity_list):
|
||||
target_entity_type_list = [entity["type"] for entity in filtered_entity_list]
|
||||
target_entity_type_list = list(set(target_entity_type_list))
|
||||
return target_entity_type_list
|
||||
|
||||
def create_instance(instance_line, target_entity_type_list, target_relation_type_list):
|
||||
instance = json.loads(instance_line)
|
||||
|
||||
entity_list = instance["entity"]
|
||||
triple_list = instance["relation"]
|
||||
spot_asoc_list = instance["spot_asoc"]
|
||||
record = instance["record"]
|
||||
|
||||
if len(target_relation_type_list) == 0:
|
||||
filtered_entity_list = filter_entity_by_entity_type(entity_list, target_entity_type_list)
|
||||
filtered_triple_list = filter_triple_by_entity_list(triple_list, filtered_entity_list)
|
||||
|
||||
current_target_entity_type_list = target_entity_type_list
|
||||
current_target_relation_type_list = build_target_relation_type_list(filtered_triple_list)
|
||||
else:
|
||||
filtered_triple_list = filter_triple_by_relation_type(triple_list, target_relation_type_list)
|
||||
filtered_entity_list = filter_entity_by_triple_list(entity_list, filtered_triple_list)
|
||||
|
||||
current_target_entity_type_list = build_target_entity_type_list(filtered_entity_list)
|
||||
current_target_relation_type_list = target_relation_type_list
|
||||
|
||||
filtered_spot_asoc_list = create_spot_asoc_field(filtered_entity_list, filtered_triple_list)
|
||||
filtered_record = create_record_field(filtered_spot_asoc_list)
|
||||
|
||||
instance["entity"] = filtered_entity_list
|
||||
instance["relation"] = filtered_triple_list
|
||||
instance["spot"] = current_target_entity_type_list
|
||||
instance["asoc"] = current_target_relation_type_list
|
||||
instance["spot_asoc"] = filtered_spot_asoc_list
|
||||
instance["record"] = filtered_record
|
||||
|
||||
return instance
|
||||
|
||||
# %% read task
|
||||
|
||||
print("Reading task...")
|
||||
task_list = []
|
||||
with open(task_file) as f:
|
||||
for line in tqdm(f):
|
||||
task_list.append(line)
|
||||
print("Task read.")
|
||||
|
||||
# %% write to sampled all
|
||||
|
||||
print("Changing task format...")
|
||||
with open(sampled_all_file, "w") as f:
|
||||
for task_line in tqdm(task_list):
|
||||
task = json.loads(task_line)
|
||||
|
||||
target_entity_type_list = task["target_entity_type_list"]
|
||||
target_relation_type_list = task["target_relation_type_list"]
|
||||
|
||||
support = task["support"]
|
||||
query = task["query"]
|
||||
|
||||
support_instance_list = []
|
||||
for instance_line in support:
|
||||
instance = create_instance(instance_line, target_entity_type_list, target_relation_type_list)
|
||||
|
||||
support_instance_list.append(instance)
|
||||
|
||||
query_instance_list = []
|
||||
for instance_line in query:
|
||||
instance = create_instance(instance_line, target_entity_type_list, target_relation_type_list)
|
||||
|
||||
query_instance_list.append(instance)
|
||||
|
||||
random.shuffle(support_instance_list)
|
||||
random.shuffle(query_instance_list)
|
||||
for instance in support_instance_list:
|
||||
f.write(json.dumps(instance) + "\n")
|
||||
for instance in query_instance_list:
|
||||
f.write(json.dumps(instance) + "\n")
|
||||
print("Task format changed.")
|
|
@ -0,0 +1,9 @@
|
|||
#!/usr/bin/env bash
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
for data_format in entity relation event absa
|
||||
do
|
||||
python uie_convert.py -format spotasoc -config data_config/${data_format} -output ${data_format}
|
||||
done
|
||||
|
||||
python scripts/data_statistics.py -data converted_data/text2spotasoc/
|
|
@ -0,0 +1,53 @@
|
|||
#!/usr/bin/env bash
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
export PYTHONPATH="${PYTHONPATH}:./"
|
||||
|
||||
for data_format in entity relation event absa
|
||||
do
|
||||
for dataset in $(ls converted_data/text2spotasoc/${data_format} | grep -v shot | grep -v ratio)
|
||||
do
|
||||
for seed in 1 2 3 4 5 6 7 8 9 10
|
||||
do
|
||||
rm -r converted_data/text2spotasoc/${data_format}/${dataset}_ratio/seed${seed}
|
||||
echo "Convert" converted_data/text2spotasoc/${data_format}/${dataset} "To" converted_data/text2spotasoc/${data_format}/${dataset}_ratio/seed${seed}
|
||||
python scripts/sample_data_ratio.py -seed ${seed} \
|
||||
-src converted_data/text2spotasoc/${data_format}/${dataset} \
|
||||
-tgt converted_data/text2spotasoc/${data_format}/${dataset}_ratio/seed${seed}
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
for data_format in entity relation event
|
||||
do
|
||||
for dataset in $(ls converted_data/text2spotasoc/${data_format} | grep -v shot | grep -v ratio)
|
||||
do
|
||||
for seed in 1 2 3 4 5 6 7 8 9 10
|
||||
do
|
||||
rm -r converted_data/text2spotasoc/${data_format}/${dataset}_shot/seed${seed}
|
||||
echo "Convert" converted_data/text2spotasoc/${data_format}/${dataset} "To" converted_data/text2spotasoc/${data_format}/${dataset}_shot/seed${seed}
|
||||
python scripts/sample_data_shot.py -seed ${seed} \
|
||||
-src converted_data/text2spotasoc/${data_format}/${dataset} \
|
||||
-tgt converted_data/text2spotasoc/${data_format}/${dataset}_shot/seed${seed} \
|
||||
-task ${data_format}
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
for data_format in absa
|
||||
do
|
||||
for dataset in $(ls converted_data/text2spotasoc/${data_format} | grep -v shot | grep -v ratio)
|
||||
do
|
||||
for seed in 1 2 3 4 5 6 7 8 9 10
|
||||
do
|
||||
rm -r converted_data/text2spotasoc/${data_format}/${dataset}_shot/seed${seed}
|
||||
echo "Convert" converted_data/text2spotasoc/${data_format}/${dataset} "To" converted_data/text2spotasoc/${data_format}/${dataset}_shot/seed${seed}
|
||||
python scripts/sample_data_shot.py -seed ${seed} \
|
||||
-src converted_data/text2spotasoc/${data_format}/${dataset} \
|
||||
-tgt converted_data/text2spotasoc/${data_format}/${dataset}_shot/seed${seed} \
|
||||
-task relation
|
||||
done
|
||||
done
|
||||
done
|
|
@ -0,0 +1,95 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import Counter
|
||||
import tabulate
|
||||
|
||||
|
||||
def count_line_in_file(filename):
|
||||
return sum([1 for _ in open(filename)])
|
||||
|
||||
|
||||
def count_record_in_file(filename, key):
|
||||
counter = Counter()
|
||||
for line in open(filename):
|
||||
instance = json.loads(line)
|
||||
counter.update([key + ' entity'] * len(instance['entity']))
|
||||
counter.update([key + ' relation'] * len(instance['relation']))
|
||||
counter.update([key + ' event'] * len(instance['event']))
|
||||
for event in instance['event']:
|
||||
counter.update([key + ' role'] * len(event['args']))
|
||||
return counter
|
||||
|
||||
|
||||
def count_folder(folder_name):
|
||||
data_map = {
|
||||
'train': 'train.json',
|
||||
'val': 'val.json',
|
||||
'test': 'test.json',
|
||||
}
|
||||
intance_counter = {'name': folder_name}
|
||||
for key, name in data_map.items():
|
||||
filename = f"{folder_name}/{name}"
|
||||
if not os.path.exists(filename):
|
||||
sys.stderr.write(f'[warn] {filename} does not exist.\n')
|
||||
continue
|
||||
intance_counter[key] = count_line_in_file(filename)
|
||||
intance_counter.update(count_record_in_file(filename, key))
|
||||
|
||||
for key in ['entity', 'relation', 'event']:
|
||||
filename = f"{folder_name}/{key}.schema"
|
||||
if not os.path.exists(filename):
|
||||
sys.stderr.write(f'[warn] {filename} does not exist.\n')
|
||||
intance_counter[key] = 0
|
||||
continue
|
||||
intance_counter[key] = len(json.loads(open(filename).readline()))
|
||||
|
||||
return intance_counter
|
||||
|
||||
|
||||
def walk_dir(folder_name):
|
||||
|
||||
for root, dirs, files in os.walk(folder_name):
|
||||
for file in dirs:
|
||||
folder_name = os.path.join(root, file)
|
||||
if os.path.exists(f"{os.path.join(root, file)}/record.schema"):
|
||||
yield os.path.join(root, file)
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-data')
|
||||
parser.add_argument('-f', dest='format', default='simple')
|
||||
options = parser.parse_args()
|
||||
|
||||
folder_list = list()
|
||||
|
||||
for folder_name in walk_dir(options.data):
|
||||
if 'shot' in folder_name or 'ratio' in folder_name:
|
||||
continue
|
||||
folder_list += [count_folder(folder_name)]
|
||||
|
||||
col_name = ['name',
|
||||
'entity', 'relation', 'event',
|
||||
'train', 'val', 'test',
|
||||
'train entity', 'train relation', 'train event', 'train role',
|
||||
'val entity', 'val relation', 'val event', 'val role',
|
||||
'test entity', 'test relation', 'test event', 'test role',
|
||||
]
|
||||
table = []
|
||||
for data_info in folder_list:
|
||||
row = [data_info.get(col, 0) for col in col_name]
|
||||
table += [row]
|
||||
table.sort()
|
||||
print(
|
||||
tabulate.tabulate(
|
||||
tabular_data=table,
|
||||
headers=col_name,
|
||||
tablefmt=options.format,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,52 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
import math
|
||||
import shutil
|
||||
import random
|
||||
import argparse
|
||||
|
||||
|
||||
def split_ratio_file(in_filename, out_filename, ratio=0.1, seed=None):
|
||||
lines = open(in_filename).readlines()
|
||||
if seed:
|
||||
random.seed(seed)
|
||||
random.shuffle(lines)
|
||||
lines = lines[:math.ceil(len(lines) * ratio)]
|
||||
with open(out_filename, 'w') as output:
|
||||
for line in lines:
|
||||
output.write(line.strip() + '\n')
|
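# Worked example: a 1,000-line train.json with ratio=0.05 is shuffled and truncated to
# math.ceil(1000 * 0.05) = 50 lines; the remaining split files are copied unchanged by
# main() below.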
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-src')
|
||||
parser.add_argument('-tgt')
|
||||
parser.add_argument('-seed')
|
||||
options = parser.parse_args()
|
||||
|
||||
source_folder = options.src
|
||||
target_folder = options.tgt
|
||||
|
||||
os.makedirs(target_folder, exist_ok=True)
|
||||
|
||||
for ratio in [0.01, 0.05, 0.1]:
|
||||
ratio_folder = os.path.join(target_folder, "%s" % ratio)
|
||||
|
||||
os.makedirs(ratio_folder, exist_ok=True)
|
||||
split_ratio_file(
|
||||
in_filename=os.path.join(source_folder, 'train.json'),
|
||||
out_filename=os.path.join(ratio_folder, 'train.json'),
|
||||
ratio=ratio,
|
||||
seed=options.seed,
|
||||
)
|
||||
for filename in os.listdir(source_folder):
|
||||
if filename != 'train.json':
|
||||
shutil.copy(
|
||||
os.path.join(source_folder, filename),
|
||||
os.path.join(ratio_folder, filename),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,101 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
import shutil
|
||||
import random
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
import json
|
||||
import sys
|
||||
from universal_ie.record_schema import RecordSchema
|
||||
|
||||
|
||||
def n_shot_sample(source_filename, target_filename, record_schema,
|
||||
spot_asoc_key='spot', num_shot=5, min_len=None, seed=None):
|
||||
|
||||
train_data = [json.loads(line.strip()) for line in open(source_filename)]
|
||||
|
||||
if seed:
|
||||
random.seed(seed)
|
||||
random.shuffle(train_data)
|
||||
|
||||
# record which label types appear in each sentence
|
||||
type_to_sentence_dict = defaultdict(list)
|
||||
for index, instance in enumerate(train_data):
|
||||
for spot in instance[spot_asoc_key]:
|
||||
if spot not in record_schema.type_list:
|
||||
continue
|
||||
if min_len is not None and len(instance['tokens']) < min_len:
|
||||
continue
|
||||
type_to_sentence_dict[spot] += [index]
|
||||
|
||||
sampled_data = list()
|
||||
for entity in type_to_sentence_dict:
|
||||
|
||||
if len(type_to_sentence_dict[entity]) < num_shot:
|
||||
sys.stderr.write(
|
||||
f'[WARN] {entity} in {source_filename} has fewer than {num_shot} instances\n'
|
||||
)
|
||||
sampled = type_to_sentence_dict[entity]
|
||||
else:
|
||||
sampled = random.sample(type_to_sentence_dict[entity], num_shot)
|
||||
|
||||
sampled_data += [train_data[index] for index in sampled]
|
||||
|
||||
with open(target_filename, 'w') as output:
|
||||
for instance in sampled_data:
|
||||
output.write(json.dumps(instance) + '\n')
|
||||
|
||||
return sampled_data
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-src', help='Source Folder Name', required=True)
|
||||
parser.add_argument('-tgt', help='Target Folder Name, n shot sampled',
|
||||
required=True)
|
||||
parser.add_argument('-task', help='N-Shot Task name', required=True,
|
||||
choices=['entity', 'relation', 'event'])
|
||||
parser.add_argument('-seed', help='Default is None, no random')
|
||||
parser.add_argument('-min_len', dest='min_len', help='Default is None', type=int)
|
||||
options = parser.parse_args()
|
||||
|
||||
source_folder = options.src
|
||||
target_folder = options.tgt
|
||||
|
||||
task_name = options.task
|
||||
|
||||
if task_name in ['relation']:
|
||||
spot_asoc_key = 'asoc'
|
||||
else:
|
||||
spot_asoc_key = 'spot'
|
||||
|
||||
os.makedirs(target_folder, exist_ok=True)
|
||||
|
||||
for shot in [1, 5, 10]:
|
||||
shot_folder = os.path.join(target_folder, "%sshot" % shot)
|
||||
|
||||
os.makedirs(shot_folder, exist_ok=True)
|
||||
|
||||
n_shot_sample(
|
||||
source_filename=os.path.join(source_folder, 'train.json'),
|
||||
target_filename=os.path.join(shot_folder, 'train.json'),
|
||||
record_schema=RecordSchema.read_from_file(
|
||||
os.path.join(source_folder, f'{task_name}.schema'),
|
||||
),
|
||||
spot_asoc_key=spot_asoc_key,
|
||||
num_shot=shot,
|
||||
seed=options.seed,
|
||||
min_len=options.min_len
|
||||
)
|
||||
|
||||
for filename in os.listdir(source_folder):
|
||||
if filename != 'train.json':
|
||||
shutil.copy(
|
||||
os.path.join(source_folder, filename),
|
||||
os.path.join(shot_folder, filename),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,86 @@
|
|||
from transformers import AutoTokenizer
|
||||
import json
|
||||
import argparse
|
||||
import tabulate
|
||||
from universal_ie.record_schema import RecordSchema
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-m', '--model', default='t5-base')
|
||||
parser.add_argument('-d', '--data', required=True)
|
||||
parser.add_argument('-s', '--schema', default='event')
|
||||
options = parser.parse_args()
|
||||
|
||||
if "chinese_t5_pegasus" in options.model:
|
||||
tokenizer = T5PegasusTokenizer.from_pretrained(options.model)
|
||||
tokenizer.bos_token = tokenizer.cls_token
|
||||
tokenizer.eos_token = tokenizer.sep_token
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
options.model,
|
||||
use_fast=False,
|
||||
mirror='tuna',
|
||||
)
|
||||
|
||||
tokenizer.add_special_tokens(
|
||||
{"additional_special_tokens": ["<extra_id_0>", "<extra_id_1>"]}
|
||||
)
|
||||
|
||||
folder_path = options.data
|
||||
|
||||
schema_file = f"{folder_path}/{options.schema}.schema"
|
||||
|
||||
event_schema = RecordSchema.read_from_file(schema_file)
|
||||
|
||||
table = list()
|
||||
for typename in event_schema.type_list:
|
||||
typename = typename.replace('_', ' ')
|
||||
after_tokenzied = tokenizer.encode(typename, add_special_tokens=False)
|
||||
table += [[typename,
|
||||
after_tokenzied,
|
||||
tokenizer.convert_ids_to_tokens(after_tokenzied)]]
|
||||
|
||||
print(tokenizer)
|
||||
print(type(tokenizer))
|
||||
|
||||
print("===============Event Schema=================")
|
||||
print(tabulate.tabulate(
|
||||
table,
|
||||
headers=['type', 'token id', 'tokenized'],
|
||||
tablefmt='grid',
|
||||
))
|
||||
|
||||
print("===============Instance=================")
|
||||
|
||||
table = list()
|
||||
for index, instance in enumerate(open(folder_path + "/val.json").readlines()[:10]):
|
||||
instance = json.loads(instance)
|
||||
table += [["Text %s" % index] + [instance['text']]]
|
||||
table += [["Token %s" % index] +
|
||||
['|'.join(tokenizer.tokenize(instance['text']))]]
|
||||
if 'entity' in instance:
|
||||
table += [["Entity %s" % index] +
|
||||
['|'.join(tokenizer.tokenize(instance['entity']))]]
|
||||
if 'relation' in instance:
|
||||
table += [["Relation %s" % index] +
|
||||
['|'.join(tokenizer.tokenize(instance['relation']))]]
|
||||
if 'event' in instance:
|
||||
table += [["Event %s" % index] +
|
||||
['|'.join(tokenizer.tokenize(instance['event']))]]
|
||||
print(tabulate.tabulate(table, headers=['text', 'event'], tablefmt='grid'))
|
||||
|
||||
print("===============Specical Symbol=================")
|
||||
table = list()
|
||||
|
||||
for name in ['<extra_id_0>', '<extra_id_1>']:
|
||||
table += [[name, tokenizer.encode(name), tokenizer.tokenize(name)]]
|
||||
print(tabulate.tabulate(
|
||||
table,
|
||||
headers=['special symbol', 'token id', 'tokenized'],
|
||||
tablefmt='grid'
|
||||
))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,216 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
from collections import Counter
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List
|
||||
from tqdm import tqdm
|
||||
from universal_ie.generation_format.generation_format import GenerationFormat
|
||||
from universal_ie.generation_format import generation_format_dict
|
||||
from universal_ie.generation_format.structure_marker import BaseStructureMarker
|
||||
from universal_ie.dataset import Dataset
|
||||
from universal_ie.ie_format import Sentence
|
||||
|
||||
|
||||
def convert_graph(
|
||||
generation_class: GenerationFormat,
|
||||
output_folder: str,
|
||||
datasets: Dict[str, List[Sentence]],
|
||||
language: str = "en",
|
||||
label_mapper: Dict = None,
|
||||
):
|
||||
convertor = generation_class(
|
||||
structure_maker=BaseStructureMarker(),
|
||||
language=language,
|
||||
label_mapper=label_mapper,
|
||||
)
|
||||
|
||||
counter = Counter()
|
||||
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
schema_counter = {
|
||||
"entity": list(),
|
||||
"relation": list(),
|
||||
"event": list(),
|
||||
}
|
||||
for data_type, instance_list in datasets.items():
|
||||
with open(os.path.join(output_folder, f"{data_type}.json"), "w") as output:
|
||||
for instance in tqdm(instance_list):
|
||||
counter.update([f"{data_type} sent"])
|
||||
converted_graph = convertor.annonote_graph(
|
||||
tokens=instance.tokens,
|
||||
entities=instance.entities,
|
||||
relations=instance.relations,
|
||||
events=instance.events,
|
||||
)
|
||||
src, tgt, spot_labels, asoc_labels = converted_graph[:4]
|
||||
spot_asoc = converted_graph[4]
|
||||
|
||||
schema_counter["entity"] += instance.entities
|
||||
schema_counter["relation"] += instance.relations
|
||||
schema_counter["event"] += instance.events
|
||||
|
||||
output.write(
|
||||
"%s\n"
|
||||
% json.dumps(
|
||||
{
|
||||
"text": src,
|
||||
"tokens": instance.tokens,
|
||||
"record": tgt,
|
||||
"entity": [
|
||||
entity.to_offset(label_mapper)
|
||||
for entity in instance.entities
|
||||
],
|
||||
"relation": [
|
||||
relation.to_offset(
|
||||
ent_label_mapper=label_mapper,
|
||||
rel_label_mapper=label_mapper,
|
||||
)
|
||||
for relation in instance.relations
|
||||
],
|
||||
"event": [
|
||||
event.to_offset(evt_label_mapper=label_mapper)
|
||||
for event in instance.events
|
||||
],
|
||||
"spot": list(spot_labels),
|
||||
"asoc": list(asoc_labels),
|
||||
"spot_asoc": spot_asoc,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
convertor.output_schema(os.path.join(output_folder, "record.schema"))
|
||||
convertor.get_entity_schema(schema_counter["entity"]).write_to_file(
|
||||
os.path.join(output_folder, f"entity.schema")
|
||||
)
|
||||
convertor.get_relation_schema(schema_counter["relation"]).write_to_file(
|
||||
os.path.join(output_folder, f"relation.schema")
|
||||
)
|
||||
convertor.get_event_schema(schema_counter["event"]).write_to_file(
|
||||
os.path.join(output_folder, f"event.schema")
|
||||
)
|
||||
print(counter)
|
||||
print(output_folder)
|
||||
print("==========================")
|
||||
|
||||
|
||||
def convert_to_oneie(output_folder: str, datasets: Dict[str, List[Sentence]]):
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
counter = Counter()
|
||||
|
||||
for data_type, instance_list in datasets.items():
|
||||
with open(
|
||||
os.path.join(output_folder, f"{data_type}.oneie.json"), "w"
|
||||
) as output:
|
||||
for instance in tqdm(instance_list):
|
||||
counter.update([f"{data_type} sent"])
|
||||
entity_mentions = [
|
||||
{
|
||||
"id": entity.record_id,
|
||||
"entity_type": str(entity.label),
|
||||
"text": entity.span.text,
|
||||
"start": entity.span.indexes[0],
|
||||
"end": entity.span.indexes[-1] + 1,
|
||||
}
|
||||
for entity in instance.entities
|
||||
]
|
||||
relation_mentions = [
|
||||
{
|
||||
"id": relation.record_id,
|
||||
"relation_type": str(relation.label),
|
||||
"argument": [
|
||||
{
|
||||
"entity_id": relation.arg1.record_id,
|
||||
"text": relation.arg1.span.text,
|
||||
"role": "Arg-1",
|
||||
},
|
||||
{
|
||||
"entity_id": relation.arg2.record_id,
|
||||
"text": relation.arg2.span.text,
|
||||
"role": "Arg-2",
|
||||
},
|
||||
],
|
||||
}
|
||||
for relation in instance.relations
|
||||
]
|
||||
event_mentions = [
|
||||
{
|
||||
"id": event.record_id,
|
||||
"event_type": str(event.label),
|
||||
"trigger": {
|
||||
"text": event.span.text,
|
||||
"start": event.span.indexes[0],
|
||||
"end": event.span.indexes[-1] + 1,
|
||||
},
|
||||
"argument": [
|
||||
{
|
||||
"id": arg[1].record_id,
|
||||
"text": arg[1].span.text,
|
||||
"role": str(arg[0]),
|
||||
}
|
||||
for arg in event.args
|
||||
],
|
||||
}
|
||||
for event in instance.events
|
||||
]
|
||||
|
||||
instance_dict = {
|
||||
"tokens": instance.tokens,
|
||||
"sent_id": instance.text_id,
|
||||
"entity_mentions": entity_mentions,
|
||||
"relation_mentions": relation_mentions,
|
||||
"event_mentions": event_mentions,
|
||||
}
|
||||
instance_str = json.dumps(instance_dict, ensure_ascii=False)
|
||||
output.write(f"{instance_str}\n")
|
||||
|
||||
print(counter)
|
||||
print(output_folder)
|
||||
print("==========================")
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-format", dest="generation_format", default="spotasoc")
|
||||
parser.add_argument("-config", dest="config", default="data_config/relation")
|
||||
parser.add_argument("-output", dest="output", default="relation")
|
||||
options = parser.parse_args()
|
||||
|
||||
generation_class = generation_format_dict.get(options.generation_format)
|
||||
|
||||
if os.path.isfile(options.config):
|
||||
config_list = [options.config]
|
||||
else:
|
||||
config_list = [
|
||||
os.path.join(options.config, x) for x in os.listdir(options.config)
|
||||
]
|
||||
|
||||
for filename in config_list:
|
||||
dataset = Dataset.load_yaml_file(filename)
|
||||
|
||||
datasets = dataset.load_dataset()
|
||||
label_mapper = dataset.mapper
|
||||
print(label_mapper)
|
||||
|
||||
output_name = (
|
||||
f"converted_data/text2{options.generation_format}/{options.output}/"
|
||||
+ dataset.name
|
||||
)
|
||||
|
||||
if generation_class:
|
||||
convert_graph(
|
||||
generation_class,
|
||||
output_name,
|
||||
datasets=datasets,
|
||||
language=dataset.language,
|
||||
label_mapper=label_mapper,
|
||||
)
|
||||
elif options.generation_format == "oneie":
|
||||
convert_to_oneie(output_name, datasets=datasets)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,49 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
from universal_ie.utils import label_format
|
||||
import yaml
|
||||
import os
|
||||
from typing import Dict
|
||||
import universal_ie.task_format as task_format
|
||||
|
||||
|
||||
class Dataset:
|
||||
def __init__(self, name: str, path: str, data_class: task_format.TaskFormat, split_dict: Dict, language: str, mapper: Dict, other: Dict = None) -> None:
|
||||
self.name = name
|
||||
self.path = path
|
||||
self.data_class = data_class
|
||||
self.split_dict = split_dict
|
||||
self.language = language
|
||||
self.mapper = mapper
|
||||
self.other = other
|
||||
|
||||
def load_dataset(self):
|
||||
datasets = {}
|
||||
for split_name, filename in self.split_dict.items():
|
||||
datasets[split_name] = self.data_class.load_from_file(
|
||||
filename=os.path.join(self.path, filename),
|
||||
language=self.language,
|
||||
**self.other,
|
||||
)
|
||||
return datasets
|
||||
|
||||
@staticmethod
|
||||
def load_yaml_file(yaml_file):
|
||||
dataset_config = yaml.load(open(yaml_file), Loader=yaml.FullLoader)
|
||||
if 'mapper' in dataset_config:
|
||||
mapper = dataset_config['mapper']
|
||||
for key in mapper:
|
||||
mapper[key] = label_format(mapper[key])
|
||||
else:
|
||||
print(f"{dataset_config['name']} without label mapper.")
|
||||
mapper = None
|
||||
|
||||
return Dataset(
|
||||
name=dataset_config['name'], # Name of Dataset
|
||||
path=dataset_config['path'], # Path of Dataset
|
||||
data_class=getattr(task_format, dataset_config['data_class']), # Task Format class used as the raw data loader
|
||||
split_dict=dataset_config['split'], # file path of each data split
|
||||
language=dataset_config['language'], # Dataset Language
|
||||
mapper=mapper,
|
||||
other=dataset_config.get('other', {}),
|
||||
)
|
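# A minimal sketch of the YAML config this loader expects; the field names come from the
# code above, while the concrete values and the task_format class name are hypothetical:
#
# name: my_dataset
# path: data/my_dataset
# data_class: SomeTaskFormat   # must name a class exposed by universal_ie.task_format
# split:
#   train: train.json
#   val: val.json
#   test: test.json
# language: en
# mapper:
#   PER: person
#   ORG: organization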
|
@ -0,0 +1,8 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
from universal_ie.generation_format.text2spotasoc import Text2SpotAsoc
|
||||
|
||||
|
||||
generation_format_dict = {
|
||||
'spotasoc': Text2SpotAsoc
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
from typing import List, Dict, Union
|
||||
from collections import defaultdict
|
||||
from universal_ie.record_schema import RecordSchema
|
||||
from universal_ie.generation_format.structure_marker import StructureMarker
|
||||
from universal_ie.ie_format import Entity, Relation, Event, Label
|
||||
import abc
|
||||
|
||||
|
||||
class GenerationFormat:
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
def __init__(self,
|
||||
structure_maker: StructureMarker,
|
||||
label_mapper: Dict = None,
|
||||
language: str = 'en') -> None:
|
||||
self.structure_maker = structure_maker
|
||||
self.language = language
|
||||
self.label_mapper = {} if label_mapper is None else label_mapper
|
||||
|
||||
# used to collect schema statistics from the data
|
||||
self.record_role_map = defaultdict(set)
|
||||
|
||||
def get_label_str(self, label: Label):
|
||||
return self.label_mapper.get(label.__repr__(), label.__repr__())
|
||||
|
||||
@abc.abstractmethod
|
||||
def annotate_entities(
|
||||
self, tokens: List[str], entities: List[Entity]): pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def annotate_given_entities(self, tokens: List[str], entities: Union[List[Entity], Entity]): pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def annotate_events(self, tokens: List[str], events: List[Event]): pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def annotate_event_given_predicate(self, tokens: List[str], event: Event): pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def annotate_relation_extraction(self, tokens: List[str],
|
||||
relations: List[Relation]): pass
|
||||
|
||||
def output_schema(self, filename: str):
|
||||
"""自动导出 Schema 文件
|
||||
Each schema file contains three lines:
|
||||
- Line 1: the list of record type names
|
||||
- Line 2: the list of role type names
|
||||
- Line 3: the record-to-role mapping dictionary
|
||||
Args:
|
||||
filename (str): path of the schema file to write
|
||||
"""
|
||||
record_list = list(self.record_role_map.keys())
|
||||
role_set = set()
|
||||
for record in self.record_role_map:
|
||||
role_set.update(self.record_role_map[record])
|
||||
self.record_role_map[record] = list(self.record_role_map[record])
|
||||
role_list = list(role_set)
|
||||
|
||||
record_schema = RecordSchema(type_list=record_list,
|
||||
role_list=role_list,
|
||||
type_role_dict=self.record_role_map
|
||||
)
|
||||
record_schema.write_to_file(filename)
|
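# Per the docstring above, the exported schema file has three lines: the record type
# list, the role list, and the record-to-role mapping. With hypothetical types it
# would look roughly like:
#   ["person", "organization"]
#   ["work for"]
#   {"person": ["work for"], "organization": []}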
||||
|
||||
def get_entity_schema(self, entities: List[Entity]):
|
||||
schema_role_map = set()
|
||||
for entity in entities:
|
||||
schema_role_map.add(self.get_label_str(entity.label))
|
||||
return RecordSchema(
|
||||
type_list=list(schema_role_map),
|
||||
role_list=list(),
|
||||
type_role_dict=dict()
|
||||
)
|
||||
|
||||
def get_relation_schema(self, relations: List[Relation]):
|
||||
record_role_map = defaultdict(set)
|
||||
role_set = set()
|
||||
|
||||
for relation in relations:
|
||||
record_role_map[self.get_label_str(relation.label)].add(self.get_label_str(relation.arg1.label))
|
||||
record_role_map[self.get_label_str(relation.label)].add(self.get_label_str(relation.arg2.label))
|
||||
|
||||
for record in record_role_map:
|
||||
role_set.update(record_role_map[record])
|
||||
record_role_map[record] = list(record_role_map[record])
|
||||
|
||||
return RecordSchema(
|
||||
type_list=list(record_role_map.keys()),
|
||||
role_list=list(role_set),
|
||||
type_role_dict=record_role_map
|
||||
)
|
||||
|
||||
def get_event_schema(self, events: List[Event]):
|
||||
record_role_map = defaultdict(set)
|
||||
role_set = set()
|
||||
|
||||
for event in events:
|
||||
for role, _ in event.args:
|
||||
record_role_map[self.get_label_str(event.label)].add(self.get_label_str(role))
|
||||
|
||||
for record in record_role_map:
|
||||
role_set.update(record_role_map[record])
|
||||
record_role_map[record] = list(record_role_map[record])
|
||||
|
||||
return RecordSchema(
|
||||
type_list=list(record_role_map.keys()),
|
||||
role_list=list(role_set),
|
||||
type_role_dict=record_role_map
|
||||
)
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# structure markers
|
||||
|
||||
|
||||
class StructureMarker:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class BaseStructureMarker(StructureMarker):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.sent_start = '<extra_id_0>'
|
||||
self.sent_end = '<extra_id_1>'
|
||||
self.record_start = '<extra_id_0>'
|
||||
self.record_end = '<extra_id_1>'
|
||||
self.span_start = '<extra_id_0>'
|
||||
self.span_end = '<extra_id_1>'
|
||||
self.sep_marker = '<extra_id_2>'
|
||||
self.source_span_start = '<extra_id_3>'
|
||||
self.source_span_end = '<extra_id_4>'
|
||||
self.target_span_start = '<extra_id_5>'
|
||||
|
||||
|
||||
class VisualStructureMarker(StructureMarker):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.sent_start = '{'
|
||||
self.sent_end = '}'
|
||||
self.record_start = '['
|
||||
self.record_end = ']'
|
||||
self.span_start = '('
|
||||
self.span_end = ')'
|
||||
self.source_span_start = '<'
|
||||
self.source_span_end = '>'
|
||||
self.target_span_start = ':'
|
||||
self.sep_marker = ':'
|
|
@ -0,0 +1,258 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
from collections import defaultdict
|
||||
from typing import List, Dict
|
||||
from universal_ie.utils import tokens_to_str
|
||||
from universal_ie.generation_format.generation_format import GenerationFormat, StructureMarker
|
||||
from universal_ie.ie_format import Entity, Event, Label, Relation, Span
|
||||
|
||||
|
||||
def convert_spot_asoc(spot_asoc_instance, structure_maker):
|
||||
spot_instance_str_rep_list = list()
|
||||
for spot in spot_asoc_instance:
|
||||
spot_str_rep = [
|
||||
spot['label'],
|
||||
structure_maker.target_span_start,
|
||||
spot['span'],
|
||||
]
|
||||
for asoc_label, asoc_span in spot.get('asoc', list()):
|
||||
asoc_str_rep = [
|
||||
structure_maker.span_start,
|
||||
asoc_label,
|
||||
structure_maker.target_span_start,
|
||||
asoc_span,
|
||||
structure_maker.span_end,
|
||||
]
|
||||
spot_str_rep += [' '.join(asoc_str_rep)]
|
||||
spot_instance_str_rep_list += [' '.join([
|
||||
structure_maker.record_start,
|
||||
' '.join(spot_str_rep),
|
||||
structure_maker.record_end,
|
||||
])]
|
||||
target_text = ' '.join([
|
||||
structure_maker.sent_start,
|
||||
' '.join(spot_instance_str_rep_list),
|
||||
structure_maker.sent_end,
|
||||
])
|
||||
return target_text
|
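# Hedged example of convert_spot_asoc with BaseStructureMarker (hypothetical spot/asoc
# values; output shown as one string, wrapped here for readability):
#   convert_spot_asoc([{"span": "Trump", "label": "person",
#                       "asoc": [["work for", "White House"]]}], BaseStructureMarker())
#   -> "<extra_id_0> <extra_id_0> person <extra_id_5> Trump <extra_id_0> work for
#       <extra_id_5> White House <extra_id_1> <extra_id_1> <extra_id_1>"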
||||
|
||||
|
||||
class Text2SpotAsoc(GenerationFormat):
|
||||
def __init__(self, structure_maker: StructureMarker, label_mapper: Dict = None, language: str = 'en') -> None:
|
||||
super().__init__(
|
||||
structure_maker=structure_maker,
|
||||
label_mapper=label_mapper,
|
||||
language=language
|
||||
)
|
||||
|
||||
def annotate_entities(self, tokens: List[str], entities: List[Entity]):
|
||||
""" Convert Entities
|
||||
|
||||
Args:
|
||||
tokens (List[str]): ['Trump', 'visits', 'China', '.']
|
||||
entities (List[Entity]): [description]
|
||||
|
||||
Returns:
|
||||
source (str): Trump visits China.
|
||||
target (str): { [ Person : Trump ] [ Geo-political : China ] }
|
||||
"""
|
||||
return self.annonote_graph(tokens=tokens, entities=entities)[:2]
|
||||
|
||||
def augment_source_span(self, tokens: List[str], span: Span):
|
||||
"""[summary]
|
||||
|
||||
Args:
|
||||
tokens (List[str]):
|
||||
['Trump', 'visits', 'China', '.']
|
||||
span (Span):
|
||||
Trump
|
||||
|
||||
Returns:
|
||||
List[str]:
|
||||
['<', 'Trump', '>', 'visits', 'China', '.']
|
||||
"""
|
||||
return tokens[:span.indexes[0]] \
|
||||
+ [self.structure_maker.source_span_start] \
|
||||
+ tokens[span.indexes[0]:span.indexes[-1] + 1] \
|
||||
+ [self.structure_maker.source_span_end] \
|
||||
+ tokens[span.indexes[-1] + 1:]
|
||||
|
||||
def annotate_given_entities(self, tokens: List[str], entities):
|
||||
"""
|
||||
entities is a List
|
||||
:param tokens:
|
||||
['Trump', 'visits', 'China', '.']
|
||||
:param entities:
|
||||
['Trump', 'China']
|
||||
:return:
|
||||
source (str): ( Trump ) ( China ) : Trump visits China .
|
||||
target (str): { [ Person : Trump ] [ Geo-political : China ] }
|
||||
|
||||
entities is an Entity
|
||||
:param tokens:
|
||||
['Trump', 'visits', 'China', '.']
|
||||
:param entities:
|
||||
'Trump'
|
||||
:return:
|
||||
source (str): < Trump > visits China .
|
||||
target (str): { [ Person : Trump ] }
|
||||
"""
|
||||
if isinstance(entities, list):
|
||||
entitytokens = []
|
||||
for entity in entities:
|
||||
entitytokens += [self.structure_maker.span_start]
|
||||
entitytokens += entity.span.tokens
|
||||
entitytokens += [self.structure_maker.span_end]
|
||||
source_text = tokens_to_str(
|
||||
entitytokens + [self.structure_maker.sep_marker] + tokens,
|
||||
language=self.language,
|
||||
)
|
||||
_, target_text = self.annonote_graph(tokens=tokens, entities=entities)[:2]
|
||||
|
||||
elif isinstance(entities, Entity):
|
||||
marked_tokens = self.augment_source_span(tokens=tokens, span=entities.span)
|
||||
source_text = tokens_to_str(marked_tokens, language=self.language)
|
||||
_, target_text = self.annonote_graph(tokens=tokens, entities=[entities])[:2]
|
||||
|
||||
return source_text, target_text
|
||||
|
||||
def annotate_events(self, tokens: List[str], events: List[Event]):
|
||||
"""
|
||||
:param tokens:
|
||||
['Trump', 'visits', 'China', '.']
|
||||
:param events:
|
||||
|
||||
:return:
|
||||
source (str): Trump visits China.
|
||||
target (str): { [ Visit : visits ( Person : Trump ) ( Location : China ) ] }
|
||||
"""
|
||||
return self.annonote_graph(tokens=tokens, events=events)[:2]
|
||||
|
||||
def annotate_event_given_predicate(self, tokens: List[str], event: Event):
|
||||
"""Annotate Event Given Predicate
|
||||
|
||||
Args:
|
||||
tokens (List[str]):
|
||||
['Trump', 'visits', 'China', '.']
|
||||
event (Event): Given Predicate
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: source text and linearized target record
|
||||
"""
|
||||
marked_tokens = self.augment_source_span(tokens=tokens, span=event.span)
|
||||
source_text = tokens_to_str(marked_tokens, language=self.language)
|
||||
_, target_text = self.annonote_graph(tokens=tokens, events=[event])[:2]
|
||||
return source_text, target_text
|
||||
|
||||
def annotate_relation_extraction(self,
|
||||
tokens: List[str],
|
||||
relations: List[Relation]):
|
||||
"""
|
||||
:param tokens:
|
||||
['Trump', 'visits', 'China', '.']
|
||||
:param relations:
|
||||
|
||||
:return:
|
||||
source (str): Trump visits China.
|
||||
target (str): { [ Person : Trump ( Visit : China ) ] }
|
||||
"""
|
||||
return self.annonote_graph(tokens=tokens, relations=relations)[:2]
|
||||
|
||||
def annotate_entities_and_relation_extraction(self,
|
||||
tokens: List[str],
|
||||
entities: List[Entity],
|
||||
relations: List[Relation]):
|
||||
"""
|
||||
:param tokens:
|
||||
['Trump', 'visits', 'China', '.']
|
||||
:param relations:
|
||||
|
||||
:return:
|
||||
source (str): Trump visits China.
|
||||
target (str): { [ Person : Trump ( Visit : China ) ] [ Geo-political : China ] }
|
||||
"""
|
||||
return self.annonote_graph(tokens=tokens, entities=entities, relations=relations)[:2]
|
||||
|
||||
def annonote_graph(self,
|
||||
tokens: List[str],
|
||||
entities: List[Entity] = [],
|
||||
relations: List[Relation] = [],
|
||||
events: List[Event] = []):
|
||||
"""Convert Entity Relation Event to Spot-Assocation Graph
|
||||
|
||||
Args:
|
||||
tokens (List[str]): Token List
|
||||
entities (List[Entity], optional): Entity List. Defaults to [].
|
||||
relations (List[Relation], optional): Relation List. Defaults to [].
|
||||
events (List[Event], optional): Event List. Defaults to [].
|
||||
|
||||
Returns:
|
||||
str: linearized target record, e.g.
|
||||
{
|
||||
[ Person : Trump ( Visit : China ) ]
|
||||
[ Visit : visits ( Person : Trump ) ( Location : China ) ]
|
||||
[ Geo-political : China ]
|
||||
}
|
||||
set: Set of Spot
|
||||
set: Set of Asoc
|
||||
"""
|
||||
spot_dict = dict()
|
||||
asoc_dict = defaultdict(list)
|
||||
spot_str_rep_list = list()
|
||||
|
||||
def add_spot(spot):
|
||||
spot_key = (tuple(spot.span.indexes), self.get_label_str(spot.label))
|
||||
spot_dict[spot_key] = spot
|
||||
|
||||
if self.get_label_str(spot.label) not in self.record_role_map:
|
||||
self.record_role_map[self.get_label_str(spot.label)] = set()
|
||||
|
||||
def add_asoc(spot, asoc: Label, tail):
|
||||
spot_key = (tuple(spot.span.indexes), self.get_label_str(spot.label))
|
||||
asoc_dict[spot_key] += [(tail.span.indexes, tail, self.get_label_str(asoc))]
|
||||
|
||||
self.record_role_map[self.get_label_str(spot.label)].add(self.get_label_str(asoc))
|
||||
|
||||
for entity in entities:
|
||||
add_spot(spot=entity)
|
||||
|
||||
for relation in relations:
|
||||
add_spot(spot=relation.arg1)
|
||||
add_asoc(spot=relation.arg1, asoc=relation.label, tail=relation.arg2)
|
||||
|
||||
for event in events:
|
||||
add_spot(spot=event)
|
||||
for arg_role, argument in event.args:
|
||||
add_asoc(spot=event, asoc=arg_role, tail=argument)
|
||||
|
||||
spot_asoc_instance = list()
|
||||
for spot_key in sorted(spot_dict.keys()):
|
||||
offset, label = spot_key
|
||||
|
||||
if spot_dict[spot_key].span.is_empty_span():
|
||||
continue
|
||||
|
||||
spot_instance = {'span': spot_dict[spot_key].span.text,
|
||||
'label': label,
|
||||
'asoc': list(),
|
||||
}
|
||||
for _, tail, asoc in sorted(asoc_dict.get(spot_key, [])):
|
||||
|
||||
if tail.span.is_empty_span():
|
||||
continue
|
||||
|
||||
spot_instance['asoc'] += [(asoc, tail.span.text)]
|
||||
spot_asoc_instance += [spot_instance]
|
||||
|
||||
target_text = convert_spot_asoc(
|
||||
spot_asoc_instance,
|
||||
structure_maker=self.structure_maker,
|
||||
)
|
||||
|
||||
source_text = tokens_to_str(tokens, language=self.language)
|
||||
spot_labels = set([label for _, label in spot_dict.keys()])
|
||||
asoc_labels = set()
|
||||
for _, asoc_list in asoc_dict.items():
|
||||
for _, _, asoc in asoc_list:
|
||||
asoc_labels.add(asoc)
|
||||
return source_text, target_text, spot_labels, asoc_labels, spot_asoc_instance
|
|
@@ -0,0 +1,211 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import List, Union, Tuple
|
||||
from universal_ie.utils import change_name_using_label_mapper
|
||||
|
||||
|
||||
# All Entity Relation Events are structured records.
|
||||
# They all have the attributes text_id and record_id
|
||||
|
||||
class Record:
|
||||
def __init__(self,
|
||||
text_id: Union[str, None] = None,
|
||||
record_id: Union[str, None] = None,
|
||||
) -> None:
|
||||
self.text_id = text_id
|
||||
self.record_id = record_id
|
||||
|
||||
@abstractmethod
|
||||
def to_offset(self):
|
||||
pass
|
||||
|
||||
|
||||
# Text span
|
||||
# A contiguous or discontiguous span of text
|
||||
class Span:
|
||||
def __init__(self,
|
||||
tokens: List[str],
|
||||
indexes: List[int],
|
||||
text: str,
|
||||
text_id: Union[str, None] = None,
|
||||
) -> None:
|
||||
self.tokens = tokens
|
||||
self.indexes = indexes
|
||||
self.text = text
|
||||
self.text_id = text_id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "[%s](%s)" % (self.text, self.indexes)
|
||||
|
||||
@staticmethod
|
||||
def get_empty_span(text_id: Union[str, None] = None,):
|
||||
return Span(
|
||||
tokens=list(),
|
||||
indexes=list(),
|
||||
text="",
|
||||
text_id=text_id
|
||||
)
|
||||
|
||||
def is_empty_span(self):
|
||||
"""Check is empty span.
|
||||
|
||||
Returns:
|
||||
bool: True if the span is empty, False otherwise
|
||||
"""
|
||||
return len(self.tokens) == 0 and len(self.indexes) == 0
|
||||
|
||||
|
||||
# Label Name
|
||||
class Label:
|
||||
def __init__(self, label_name: Union[str, List[str]]) -> None:
|
||||
self.label_name = label_name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.label_name
|
||||
|
||||
def __lt__(self, other):
|
||||
if not isinstance(other, Label):
|
||||
return NotImplemented
|
||||
return self.label_name < other.label_name
|
||||
|
||||
|
||||
# Entity, Span
|
||||
# Entity: a unary structure centered on a text span
|
||||
class Entity(Record):
|
||||
def __init__(self,
|
||||
span: Span,
|
||||
label: Label,
|
||||
text_id: Union[str, None] = None,
|
||||
record_id: Union[str, None] = None,
|
||||
) -> None:
|
||||
super().__init__(text_id=text_id, record_id=record_id)
|
||||
self.span = span
|
||||
self.label = label
|
||||
|
||||
def __lt__(self, other):
|
||||
if not isinstance(other, Entity):
|
||||
return NotImplemented
|
||||
return self.span.indexes < other.span.indexes
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.span.__repr__() + self.label.__repr__()
|
||||
|
||||
def to_offset(self, ent_label_mapper=None):
|
||||
if self.span.is_empty_span():
|
||||
# If span is empty, skip entity
|
||||
return {}
|
||||
return {'type': change_name_using_label_mapper(self.label.label_name,
|
||||
ent_label_mapper),
|
||||
'offset': self.span.indexes,
|
||||
'text': self.span.text}
|
||||
|
||||
|
||||
# Relation Span Pair
|
||||
# Relation: a binary structure centered on a pair of text spans
|
||||
class Relation(Record):
|
||||
def __init__(self,
|
||||
arg1: Entity,
|
||||
arg2: Entity,
|
||||
label: Label,
|
||||
text_id: Union[str, None] = None,
|
||||
record_id: Union[str, None] = None,
|
||||
) -> None:
|
||||
super().__init__(text_id=text_id, record_id=record_id)
|
||||
self.arg1 = arg1
|
||||
self.arg2 = arg2
|
||||
self.label = label
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.arg1.__repr__() + self.label.__repr__() + self.arg2.__repr__()
|
||||
|
||||
def to_offset(self, rel_label_mapper=None, ent_label_mapper=None):
|
||||
if self.arg1.span.is_empty_span() or self.arg2.span.is_empty_span():
|
||||
# If span is empty, skip relation
|
||||
return {}
|
||||
return {'type': change_name_using_label_mapper(self.label.label_name,
|
||||
rel_label_mapper),
|
||||
'args': [self.arg1.to_offset(ent_label_mapper=ent_label_mapper),
|
||||
self.arg2.to_offset(ent_label_mapper=ent_label_mapper),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# Event, Trigger-Mult-Argument
|
||||
# Event: a multi-argument (predicate-argument) structure centered on a trigger word
|
||||
class Event(Record):
|
||||
def __init__(self,
|
||||
span: Span,
|
||||
label: Label,
|
||||
args: List[Tuple[Label, Entity]],
|
||||
text_id: Union[str, None] = None,
|
||||
record_id: Union[str, None] = None,
|
||||
) -> None:
|
||||
super().__init__(text_id=text_id, record_id=record_id)
|
||||
self.span = span
|
||||
self.label = label
|
||||
self.args = args
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.span.__repr__() + self.label.__repr__()
|
||||
|
||||
def to_offset(self, evt_label_mapper=None):
|
||||
|
||||
if self.span.is_empty_span():
|
||||
# If span is empty, skip event
|
||||
return {}
|
||||
|
||||
args = list()
|
||||
for role, arg in self.args:
|
||||
if arg.span.is_empty_span():
|
||||
continue
|
||||
args += [{
|
||||
'type': change_name_using_label_mapper(
|
||||
role.label_name,
|
||||
evt_label_mapper,
|
||||
),
|
||||
'offset': arg.span.indexes,
|
||||
'text': arg.span.text
|
||||
}]
|
||||
|
||||
return {'type': change_name_using_label_mapper(self.label.label_name,
|
||||
evt_label_mapper),
|
||||
'offset': self.span.indexes,
|
||||
'text': self.span.text,
|
||||
'args': args}
|
||||
|
||||
|
||||
class Sentence:
|
||||
def __init__(self,
|
||||
tokens: List[str],
|
||||
entities: List[Entity] = None,
|
||||
relations: List[Relation] = None,
|
||||
events: List[Event] = None,
|
||||
text_id: Union[str, None] = None,
|
||||
) -> None:
|
||||
self.tokens = tokens
|
||||
self.entities = entities or list()
|
||||
self.relations = relations or list()
|
||||
self.events = events or list()
|
||||
self.text_id = text_id
|
||||
|
||||
def count_entity_without_relation(self):
|
||||
entity_set = set()
|
||||
entity_counter = defaultdict(int)
|
||||
for entity in self.entities:
|
||||
entity_set.add((tuple(entity.span.indexes), entity.label.label_name))
|
||||
|
||||
for relation in self.relations:
|
||||
entity1 = (tuple(relation.arg1.span.indexes), relation.arg1.label.label_name)
|
||||
entity2 = (tuple(relation.arg2.span.indexes), relation.arg2.label.label_name)
|
||||
entity_counter[entity1] += 1
|
||||
entity_counter[entity2] += 1
|
||||
entity_set.discard(entity1)
|
||||
entity_set.discard(entity2)
|
||||
overlap_entity = sum([1 if v > 1 else 0 for k, v in entity_counter.items()])
|
||||
return {'entity': len(self.entities),
|
||||
'entity_without_relation': len(entity_set),
|
||||
'overlap_entity': overlap_entity,
|
||||
}
|
|
@@ -0,0 +1,51 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
|
||||
class RecordSchema:
|
||||
def __init__(self, type_list, role_list, type_role_dict):
|
||||
self.type_list = type_list
|
||||
self.role_list = role_list
|
||||
self.type_role_dict = type_role_dict
|
||||
|
||||
@staticmethod
|
||||
def read_from_file(filename):
|
||||
lines = open(filename).readlines()
|
||||
type_list = json.loads(lines[0])
|
||||
role_list = json.loads(lines[1])
|
||||
type_role_dict = json.loads(lines[2])
|
||||
return RecordSchema(type_list, role_list, type_role_dict)
|
||||
|
||||
def write_to_file(self, filename):
|
||||
with open(filename, 'w') as output:
|
||||
output.write(json.dumps(self.type_list, ensure_ascii=False) + '\n')
|
||||
output.write(json.dumps(self.role_list, ensure_ascii=False) + '\n')
|
||||
output.write(json.dumps(self.type_role_dict, ensure_ascii=False) + '\n')
|
||||
|
||||
|
||||
def merge_schema(schema_list: List[RecordSchema]):
|
||||
type_set = set()
|
||||
role_set = set()
|
||||
type_role_dict = defaultdict(list)
|
||||
|
||||
for schema in schema_list:
|
||||
|
||||
for type_name in schema.type_list:
|
||||
type_set.add(type_name)
|
||||
|
||||
for role_name in schema.role_list:
|
||||
role_set.add(role_name)
|
||||
|
||||
for type_name in schema.type_role_dict:
|
||||
type_role_dict[type_name] += schema.type_role_dict[type_name]
|
||||
|
||||
for type_name in type_role_dict:
|
||||
type_role_dict[type_name] = list(set(type_role_dict[type_name]))
|
||||
|
||||
return RecordSchema(type_list=list(type_set),
|
||||
role_list=list(role_set),
|
||||
type_role_dict=type_role_dict,
|
||||
)
|
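# Usage sketch (illustrative only; file names below are hypothetical).
# A schema file stores three JSON lines: type_list, role_list, type_role_dict.
#
#   schema_a = RecordSchema(['person'], ['work for'], {'person': ['work for']})
#   schema_b = RecordSchema(['organization'], [], {'organization': []})
#   merged = merge_schema([schema_a, schema_b])
#   merged.write_to_file('merged.schema')
#   reloaded = RecordSchema.read_from_file('merged.schema')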
|
@@ -0,0 +1,16 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
from universal_ie.task_format.task_format import TaskFormat
|
||||
from universal_ie.task_format.oneie import OneIEEvent
|
||||
from universal_ie.task_format.jointer import JointER
|
||||
from universal_ie.task_format.mrc_ner import MRCNER
|
||||
from universal_ie.task_format.absa import ABSA
|
||||
from universal_ie.task_format.spannet import Spannet
|
||||
from universal_ie.task_format.casie import CASIE
|
||||
from universal_ie.task_format.cols import (
|
||||
TokenTagCols,
|
||||
I2b2Conll,
|
||||
TagTokenCols,
|
||||
TokenTagJson,
|
||||
CoNLL03,
|
||||
)
|
|
@@ -0,0 +1,82 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
|
||||
import json
|
||||
from typing import List
|
||||
from universal_ie.utils import tokens_to_str, change_ptb_token_back
|
||||
from universal_ie.ie_format import Entity, Label, Relation, Sentence, Span
|
||||
from universal_ie.task_format.task_format import TaskFormat
|
||||
|
||||
|
||||
class ABSA(TaskFormat):
|
||||
""" Aspect-Based Sentiment Analysis Data format at https://github.com/yhcc/BARTABSA."""
|
||||
|
||||
def __init__(self, sentence_json, language='en'):
|
||||
super().__init__(
|
||||
language=language
|
||||
)
|
||||
self.tokens = sentence_json['words']
|
||||
if self.tokens is None:
|
||||
print('[sentence without tokens]:', sentence_json)
|
||||
exit(1)
|
||||
for index in range(len(self.tokens)):
|
||||
self.tokens[index] = change_ptb_token_back(self.tokens[index])
|
||||
self.aspects = sentence_json['aspects']
|
||||
self.opinions = sentence_json['opinions']
|
||||
|
||||
def generate_instance(self):
|
||||
entities = dict()
|
||||
relations = list()
|
||||
entity_map = dict()
|
||||
|
||||
for aspect, opinion in zip(self.aspects, self.opinions):
|
||||
aspect_span = (aspect['from'], aspect['to'])
|
||||
opinion_span = (opinion['from'], opinion['to'])
|
||||
|
||||
if aspect_span not in entities:
|
||||
tokens = self.tokens[aspect_span[0]:aspect_span[1]]
|
||||
entities[aspect_span] = Entity(
|
||||
span=Span(
|
||||
tokens=tokens,
|
||||
indexes=list(range(aspect_span[0], aspect_span[1])),
|
||||
text=tokens_to_str(tokens, language=self.language),
|
||||
),
|
||||
label=Label('aspect')
|
||||
)
|
||||
|
||||
if opinion_span not in entities:
|
||||
tokens = self.tokens[opinion_span[0]:opinion_span[1]]
|
||||
entities[opinion_span] = Entity(
|
||||
span=Span(
|
||||
tokens=tokens,
|
||||
indexes=list(range(opinion_span[0], opinion_span[1])),
|
||||
text=tokens_to_str(tokens, language=self.language),
|
||||
),
|
||||
label=Label('opinion')
|
||||
)
|
||||
|
||||
relations += [Relation(
|
||||
arg1=entities[aspect_span],
|
||||
arg2=entities[opinion_span],
|
||||
label=Label(aspect['polarity']),
|
||||
)]
|
||||
|
||||
return Sentence(
|
||||
tokens=self.tokens,
|
||||
entities=entities.values(),
|
||||
relations=relations,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_from_file(filename, language='en') -> List[Sentence]:
|
||||
sentence_list = list()
|
||||
raw_instance_list = json.load(open(filename))
|
||||
print(f"{filename}: {len(raw_instance_list)}")
|
||||
for instance in raw_instance_list:
|
||||
instance = ABSA(
|
||||
sentence_json=instance,
|
||||
language=language
|
||||
).generate_instance()
|
||||
sentence_list += [instance]
|
||||
return sentence_list
|
|
@@ -0,0 +1,505 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
from collections import Counter
|
||||
import json
|
||||
from typing import List, Optional, Tuple, Set
|
||||
from tqdm import tqdm
|
||||
from universal_ie.task_format.task_format import TaskFormat
|
||||
from universal_ie.utils import tokens_to_str
|
||||
from universal_ie.ie_format import Entity, Label, Sentence, Span
|
||||
|
||||
|
||||
# https://github.com/allenai/allennlp/blob/main/allennlp/data/dataset_readers/dataset_utils/span_utils.py
|
||||
# ### Start Code
|
||||
def bio_tags_to_spans(
|
||||
tag_sequence: List[str], classes_to_ignore: List[str] = None
|
||||
) -> List[Tuple[str, Tuple[int, int]]]:
|
||||
"""
|
||||
Given a sequence corresponding to BIO tags, extracts spans.
|
||||
Spans are inclusive and can be of zero length, representing a single word span.
|
||||
Ill-formed spans are also included (i.e. those which do not start with a "B-LABEL"),
|
||||
as otherwise it is possible to get a perfect precision score whilst still predicting
|
||||
ill-formed spans in addition to the correct spans. This function works properly when
|
||||
the spans are unlabeled (i.e., your labels are simply "B", "I", and "O").
|
||||
# Parameters
|
||||
tag_sequence : `List[str]`, required.
|
||||
The integer class labels for a sequence.
|
||||
classes_to_ignore : `List[str]`, optional (default = `None`).
|
||||
A list of string class labels `excluding` the bio tag
|
||||
which should be ignored when extracting spans.
|
||||
# Returns
|
||||
spans : `List[TypedStringSpan]`
|
||||
The typed, extracted spans from the sequence, in the format (label, (span_start, span_end)).
|
||||
Note that the label `does not` contain any BIO tag prefixes.
|
||||
"""
|
||||
classes_to_ignore = classes_to_ignore or []
|
||||
spans: Set[Tuple[str, Tuple[int, int]]] = set()
|
||||
span_start = 0
|
||||
span_end = 0
|
||||
active_conll_tag = None
|
||||
for index, string_tag in enumerate(tag_sequence):
|
||||
# Actual BIO tag.
|
||||
bio_tag = string_tag[0]
|
||||
if bio_tag not in ["B", "I", "O"]:
|
||||
raise RuntimeError('Invalid tag sequence %s' % tag_sequence)
|
||||
conll_tag = string_tag[2:]
|
||||
if bio_tag == "O" or conll_tag in classes_to_ignore:
|
||||
# The span has ended.
|
||||
if active_conll_tag is not None:
|
||||
spans.add((active_conll_tag, (span_start, span_end)))
|
||||
active_conll_tag = None
|
||||
# We don't care about tags we are
|
||||
# told to ignore, so we do nothing.
|
||||
continue
|
||||
elif bio_tag == "B":
|
||||
# We are entering a new span; reset indices
|
||||
# and active tag to new span.
|
||||
if active_conll_tag is not None:
|
||||
spans.add((active_conll_tag, (span_start, span_end)))
|
||||
active_conll_tag = conll_tag
|
||||
span_start = index
|
||||
span_end = index
|
||||
elif bio_tag == "I" and conll_tag == active_conll_tag:
|
||||
# We're inside a span.
|
||||
span_end += 1
|
||||
else:
|
||||
# This is the case the bio label is an "I", but either:
|
||||
# 1) the span hasn't started - i.e. an ill formed span.
|
||||
# 2) The span is an I tag for a different conll annotation.
|
||||
# We'll process the previous span if it exists, but also
|
||||
# include this span. This is important, because otherwise,
|
||||
# a model may get a perfect F1 score whilst still including
|
||||
# false positive ill-formed spans.
|
||||
if active_conll_tag is not None:
|
||||
spans.add((active_conll_tag, (span_start, span_end)))
|
||||
active_conll_tag = conll_tag
|
||||
span_start = index
|
||||
span_end = index
|
||||
# Last token might have been a part of a valid span.
|
||||
if active_conll_tag is not None:
|
||||
spans.add((active_conll_tag, (span_start, span_end)))
|
||||
return list(spans)
|
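# Example (added for illustration): the BIO decoder above turns
#   ['B-PER', 'I-PER', 'O', 'B-LOC']
# into the typed spans {('PER', (0, 1)), ('LOC', (3, 3))}
# (returned as a list; ordering is not guaranteed because a set is used internally).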
||||
|
||||
|
||||
def _iob1_start_of_chunk(
|
||||
prev_bio_tag: Optional[str],
|
||||
prev_conll_tag: Optional[str],
|
||||
curr_bio_tag: str,
|
||||
curr_conll_tag: str,
|
||||
) -> bool:
|
||||
if curr_bio_tag == "B":
|
||||
return True
|
||||
if curr_bio_tag == "I" and prev_bio_tag == "O":
|
||||
return True
|
||||
if curr_bio_tag != "O" and prev_conll_tag != curr_conll_tag:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def iob1_tags_to_spans(
|
||||
tag_sequence: List[str], classes_to_ignore: List[str] = None
|
||||
) -> List[Tuple[str, Tuple[int, int]]]:
|
||||
"""
|
||||
Given a sequence corresponding to IOB1 tags, extracts spans.
|
||||
Spans are inclusive and can be of zero length, representing a single word span.
|
||||
Ill-formed spans are also included (i.e., those where "B-LABEL" is not preceded
|
||||
by "I-LABEL" or "B-LABEL").
|
||||
# Parameters
|
||||
tag_sequence : `List[str]`, required.
|
||||
The integer class labels for a sequence.
|
||||
classes_to_ignore : `List[str]`, optional (default = `None`).
|
||||
A list of string class labels `excluding` the bio tag
|
||||
which should be ignored when extracting spans.
|
||||
# Returns
|
||||
spans : `List[TypedStringSpan]`
|
||||
The typed, extracted spans from the sequence, in the format (label, (span_start, span_end)).
|
||||
Note that the label `does not` contain any BIO tag prefixes.
|
||||
"""
|
||||
classes_to_ignore = classes_to_ignore or []
|
||||
spans: Set[Tuple[str, Tuple[int, int]]] = set()
|
||||
span_start = 0
|
||||
span_end = 0
|
||||
active_conll_tag = None
|
||||
prev_bio_tag = None
|
||||
prev_conll_tag = None
|
||||
for index, string_tag in enumerate(tag_sequence):
|
||||
curr_bio_tag = string_tag[0]
|
||||
curr_conll_tag = string_tag[2:]
|
||||
|
||||
if curr_bio_tag not in ["B", "I", "O"]:
|
||||
raise RuntimeError('Invalid tag sequence %s' % tag_sequence)
|
||||
if curr_bio_tag == "O" or curr_conll_tag in classes_to_ignore:
|
||||
# The span has ended.
|
||||
if active_conll_tag is not None:
|
||||
spans.add((active_conll_tag, (span_start, span_end)))
|
||||
active_conll_tag = None
|
||||
elif _iob1_start_of_chunk(prev_bio_tag, prev_conll_tag, curr_bio_tag, curr_conll_tag):
|
||||
# We are entering a new span; reset indices
|
||||
# and active tag to new span.
|
||||
if active_conll_tag is not None:
|
||||
spans.add((active_conll_tag, (span_start, span_end)))
|
||||
active_conll_tag = curr_conll_tag
|
||||
span_start = index
|
||||
span_end = index
|
||||
else:
|
||||
# bio_tag == "I" and curr_conll_tag == active_conll_tag
|
||||
# We're continuing a span.
|
||||
span_end += 1
|
||||
|
||||
prev_bio_tag = string_tag[0]
|
||||
prev_conll_tag = string_tag[2:]
|
||||
# Last token might have been a part of a valid span.
|
||||
if active_conll_tag is not None:
|
||||
spans.add((active_conll_tag, (span_start, span_end)))
|
||||
return list(spans)
|
||||
|
||||
|
||||
def bmes_tags_to_spans(
|
||||
tag_sequence: List[str], classes_to_ignore: List[str] = None
|
||||
) -> List[Tuple[str, Tuple[int, int]]]:
|
||||
"""
|
||||
Given a sequence corresponding to BMES tags, extracts spans.
|
||||
Spans are inclusive and can be of zero length, representing a single word span.
|
||||
Ill-formed spans are also included (i.e. those which do not start with a "B-LABEL"),
|
||||
as otherwise it is possible to get a perfect precision score whilst still predicting
|
||||
ill-formed spans in addition to the correct spans.
|
||||
This function works properly when the spans are unlabeled (i.e., your labels are
|
||||
simply "B", "M", "E" and "S").
|
||||
# Parameters
|
||||
tag_sequence : `List[str]`, required.
|
||||
The integer class labels for a sequence.
|
||||
classes_to_ignore : `List[str]`, optional (default = `None`).
|
||||
A list of string class labels `excluding` the bio tag
|
||||
which should be ignored when extracting spans.
|
||||
# Returns
|
||||
spans : `List[TypedStringSpan]`
|
||||
The typed, extracted spans from the sequence, in the format (label, (span_start, span_end)).
|
||||
Note that the label `does not` contain any BIO tag prefixes.
|
||||
"""
|
||||
|
||||
def extract_bmes_tag_label(text):
|
||||
bmes_tag = text[0]
|
||||
label = text[2:]
|
||||
return bmes_tag, label
|
||||
|
||||
spans: List[Tuple[str, List[int]]] = []
|
||||
prev_bmes_tag: Optional[str] = None
|
||||
for index, tag in enumerate(tag_sequence):
|
||||
bmes_tag, label = extract_bmes_tag_label(tag)
|
||||
if bmes_tag in ("B", "S"):
|
||||
# Regardless of tag, we start a new span when reaching B & S.
|
||||
spans.append((label, [index, index]))
|
||||
elif bmes_tag in ("M", "E") and prev_bmes_tag in ("B", "M") and spans[-1][0] == label:
|
||||
# Only expand the span if
|
||||
# 1. Valid transition: B/M -> M/E.
|
||||
# 2. Matched label.
|
||||
spans[-1][1][1] = index
|
||||
else:
|
||||
# Best effort split for invalid span.
|
||||
spans.append((label, [index, index]))
|
||||
# update previous BMES tag.
|
||||
prev_bmes_tag = bmes_tag
|
||||
|
||||
classes_to_ignore = classes_to_ignore or []
|
||||
return [
|
||||
# to tuple.
|
||||
(span[0], (span[1][0], span[1][1]))
|
||||
for span in spans
|
||||
if span[0] not in classes_to_ignore
|
||||
]
|
||||
|
||||
|
||||
def bioul_tags_to_spans(
|
||||
tag_sequence: List[str], classes_to_ignore: List[str] = None
|
||||
) -> List[Tuple[str, Tuple[int, int]]]:
|
||||
"""
|
||||
Given a sequence corresponding to BIOUL tags, extracts spans.
|
||||
Spans are inclusive and can be of zero length, representing a single word span.
|
||||
Ill-formed spans are not allowed and will raise `InvalidTagSequence`.
|
||||
This function works properly when the spans are unlabeled (i.e., your labels are
|
||||
simply "B", "I", "O", "U", and "L").
|
||||
# Parameters
|
||||
tag_sequence : `List[str]`, required.
|
||||
The tag sequence encoded in BIOUL, e.g. ["B-PER", "L-PER", "O"].
|
||||
classes_to_ignore : `List[str]`, optional (default = `None`).
|
||||
A list of string class labels `excluding` the bio tag
|
||||
which should be ignored when extracting spans.
|
||||
# Returns
|
||||
spans : `List[TypedStringSpan]`
|
||||
The typed, extracted spans from the sequence, in the format (label, (span_start, span_end)).
|
||||
"""
|
||||
spans = []
|
||||
classes_to_ignore = classes_to_ignore or []
|
||||
index = 0
|
||||
while index < len(tag_sequence):
|
||||
label = tag_sequence[index]
|
||||
if label[0] == "U":
|
||||
spans.append((label.partition("-")[2], (index, index)))
|
||||
elif label[0] == "B":
|
||||
start = index
|
||||
while label[0] != "L":
|
||||
index += 1
|
||||
if index >= len(tag_sequence):
|
||||
raise RuntimeError('Invalid tag sequence %s' % tag_sequence)
|
||||
# raise InvalidTagSequence(tag_sequence)
|
||||
label = tag_sequence[index]
|
||||
if not (label[0] == "I" or label[0] == "L"):
|
||||
raise RuntimeError('Invalid tag sequence %s' % tag_sequence)
|
||||
# raise InvalidTagSequence(tag_sequence)
|
||||
spans.append((label.partition("-")[2], (start, index)))
|
||||
else:
|
||||
if label != "O":
|
||||
raise RuntimeError('Invalid tag sequence %s' % tag_sequence)
|
||||
# raise InvalidTagSequence(tag_sequence)
|
||||
index += 1
|
||||
return [span for span in spans if span[0] not in classes_to_ignore]
|
||||
|
||||
|
||||
def bmeso_tags_to_spans(
|
||||
tag_sequence: List[str], classes_to_ignore: List[str] = None
|
||||
) -> List[Tuple[str, Tuple[int, int]]]:
|
||||
"""
|
||||
bmeso -> bioul
|
||||
B = Beginning
|
||||
I/M = Inside / Middle
|
||||
L/E = Last / End
|
||||
O = Outside
|
||||
U/W/S = Unit-length / Whole / Singleton
|
||||
"""
|
||||
new_tag = list()
|
||||
for label in tag_sequence:
|
||||
if label[0] == 'M':
|
||||
new_tag += ['I-' + label.partition("-")[2]]
|
||||
elif label[0] == 'E':
|
||||
new_tag += ['L-' + label.partition("-")[2]]
|
||||
elif label[0] == 'S':
|
||||
new_tag += ['U-' + label.partition("-")[2]]
|
||||
else:
|
||||
new_tag += [label]
|
||||
|
||||
return bioul_tags_to_spans(tag_sequence=new_tag, classes_to_ignore=classes_to_ignore)
|
||||
|
||||
|
||||
def bieso_tags_to_spans(
|
||||
tag_sequence: List[str], classes_to_ignore: List[str] = None
|
||||
) -> List[Tuple[str, Tuple[int, int]]]:
|
||||
"""
|
||||
bieso -> bioul
|
||||
B = Beginning
|
||||
I/M = Inside / Middle
|
||||
L/E = Last / End
|
||||
O = Outside
|
||||
U/W/S = Unit-length / Whole / Singleton
|
||||
"""
|
||||
new_tag = list()
|
||||
for label in tag_sequence:
|
||||
if label[0] == 'E':
|
||||
new_tag += ['L-' + label.partition("-")[2]]
|
||||
elif label[0] == 'S':
|
||||
new_tag += ['U-' + label.partition("-")[2]]
|
||||
else:
|
||||
new_tag += [label]
|
||||
|
||||
return bioul_tags_to_spans(tag_sequence=new_tag, classes_to_ignore=classes_to_ignore)
|
||||
# ### End Code
|
||||
|
||||
|
||||
_tagging_span_function = {
|
||||
'bioul': bioul_tags_to_spans,
|
||||
'bmes': bmes_tags_to_spans,
|
||||
'bio': bio_tags_to_spans,
|
||||
'iob1': iob1_tags_to_spans,
|
||||
'bmeso': bmeso_tags_to_spans,
|
||||
'bieso': bieso_tags_to_spans,
|
||||
}
|
||||
|
||||
|
||||
class Cols(TaskFormat):
|
||||
|
||||
def __init__(self, tokens: List[str], spans: List[Tuple[Tuple[int, int], str]], language='en', instance_id=None) -> None:
|
||||
super().__init__(
|
||||
language=language
|
||||
)
|
||||
self.instance_id = instance_id
|
||||
self.tokens = tokens
|
||||
self.spans = spans
|
||||
|
||||
def generate_instance(self):
|
||||
entities = list()
|
||||
for span_index, span in enumerate(self.spans):
|
||||
tokens = self.tokens[span['start']: span['end'] + 1]
|
||||
indexes = list(range(span['start'], span['end'] + 1))
|
||||
entities += [
|
||||
Entity(
|
||||
span=Span(
|
||||
tokens=tokens,
|
||||
indexes=indexes,
|
||||
text=tokens_to_str(tokens, language=self.language),
|
||||
text_id=self.instance_id
|
||||
),
|
||||
label=Label(span['type']),
|
||||
text_id=self.instance_id,
|
||||
record_id=self.instance_id + "#%s" % span_index if self.instance_id else None)
|
||||
]
|
||||
return Sentence(tokens=self.tokens,
|
||||
entities=entities,
|
||||
text_id=self.instance_id)
|
||||
|
||||
@staticmethod
|
||||
def generate_sentence(filename):
|
||||
sentence = list()
|
||||
with open(filename) as fin:
|
||||
for line in fin:
|
||||
if line.strip() == '':
|
||||
if len(sentence) != 0:
|
||||
yield sentence
|
||||
sentence = list()
|
||||
|
||||
else:
|
||||
sentence += [line.strip().split()]
|
||||
|
||||
if len(sentence) != 0:
|
||||
yield sentence
|
||||
|
||||
|
||||
class TokenTagCols(Cols):
|
||||
|
||||
@staticmethod
|
||||
def load_from_file(filename, language='en', tagging='bio') -> List[Sentence]:
|
||||
sentence_list = list()
|
||||
counter = Counter()
|
||||
for rows in tqdm(Cols.generate_sentence(filename)):
|
||||
tokens = [token[0] for token in rows]
|
||||
ner = [token[1] for token in rows]
|
||||
spans = _tagging_span_function[tagging](ner)
|
||||
spans = list(filter(lambda x: x[0] != "", spans))
|
||||
spans = [
|
||||
{'start': span[1][0], 'end': span[1][1], 'type': span[0]}
|
||||
for span in spans
|
||||
]
|
||||
sentence = Cols(
|
||||
tokens=tokens,
|
||||
spans=spans,
|
||||
language=language,
|
||||
)
|
||||
counter.update(['token'] * len(tokens))
|
||||
counter.update(['sentence'])
|
||||
counter.update(['span'] * len(spans))
|
||||
sentence_list += [sentence.generate_instance()]
|
||||
print(filename, counter)
|
||||
return sentence_list
|
||||
|
||||
|
||||
class TagTokenCols(Cols):
|
||||
|
||||
@staticmethod
|
||||
def load_from_file(filename, language='en', tagging='bio') -> List[Sentence]:
|
||||
sentence_list = list()
|
||||
counter = Counter()
|
||||
for rows in tqdm(Cols.generate_sentence(filename)):
|
||||
tokens = [token[1] for token in rows]
|
||||
ner = [token[0] for token in rows]
|
||||
spans = _tagging_span_function[tagging](ner)
|
||||
spans = [
|
||||
{'start': span[1][0], 'end': span[1][1], 'type': span[0]}
|
||||
for span in spans
|
||||
]
|
||||
sentence = Cols(
|
||||
tokens=tokens,
|
||||
spans=spans,
|
||||
language=language,
|
||||
)
|
||||
counter.update(['token'] * len(tokens))
|
||||
counter.update(['sentence'])
|
||||
counter.update(['span'] * len(spans))
|
||||
sentence_list += [sentence.generate_instance()]
|
||||
print(filename, counter)
|
||||
return sentence_list
|
||||
|
||||
|
||||
class TokenTagJson(Cols):
|
||||
@staticmethod
|
||||
def load_from_file(filename, language='en', tagging='bio') -> List[Sentence]:
|
||||
sentence_list = list()
|
||||
counter = Counter()
|
||||
for line in open(filename):
|
||||
instance = json.loads(line.strip())
|
||||
tokens = instance['tokens']
|
||||
ner = instance['ner_tags']
|
||||
spans = _tagging_span_function[tagging](ner)
|
||||
spans = list(filter(lambda x: x[0] != "", spans))
|
||||
spans = [
|
||||
{'start': span[1][0], 'end': span[1][1], 'type': span[0]}
|
||||
for span in spans
|
||||
]
|
||||
sentence = Cols(
|
||||
tokens=tokens,
|
||||
spans=spans,
|
||||
language=language,
|
||||
)
|
||||
counter.update(['token'] * len(tokens))
|
||||
counter.update(['sentence'])
|
||||
counter.update(['span'] * len(spans))
|
||||
sentence_list += [sentence.generate_instance()]
|
||||
print(filename, counter)
|
||||
return sentence_list
|
||||
|
||||
|
||||
class I2b2Conll(Cols):
|
||||
|
||||
@staticmethod
|
||||
def load_from_file(filename, language='en') -> List[Sentence]:
|
||||
sentence_list = list()
|
||||
counter = Counter()
|
||||
for rows in tqdm(Cols.generate_sentence(filename)):
|
||||
tokens = [token[0] for token in rows]
|
||||
ner = [token[4] for token in rows]
|
||||
spans = bio_tags_to_spans(ner)
|
||||
spans = [
|
||||
{'start': span[1][0], 'end': span[1][1], 'type': span[0]}
|
||||
for span in spans
|
||||
]
|
||||
sentence = Cols(
|
||||
tokens=tokens,
|
||||
spans=spans,
|
||||
language=language,
|
||||
)
|
||||
counter.update(['token'] * len(tokens))
|
||||
counter.update(['sentence'])
|
||||
counter.update(['span'] * len(spans))
|
||||
sentence_list += [sentence.generate_instance()]
|
||||
print(filename, counter)
|
||||
return sentence_list
|
||||
|
||||
|
||||
class CoNLL03(Cols):
|
||||
|
||||
@staticmethod
|
||||
def load_from_file(filename, language='en') -> List[Sentence]:
|
||||
sentence_list = list()
|
||||
counter = Counter()
|
||||
for rows in tqdm(Cols.generate_sentence(filename)):
|
||||
if rows[0][0] == '-DOCSTART-':
|
||||
continue
|
||||
tokens = [token[0] for token in rows]
|
||||
ner = [token[3] for token in rows]
|
||||
spans = iob1_tags_to_spans(ner)
|
||||
spans = [
|
||||
{'start': span[1][0], 'end': span[1][1], 'type': span[0]}
|
||||
for span in spans
|
||||
]
|
||||
sentence = Cols(
|
||||
tokens=tokens,
|
||||
spans=spans,
|
||||
language=language,
|
||||
)
|
||||
counter.update(['token'] * len(tokens))
|
||||
counter.update(['sentence'])
|
||||
counter.update(['span'] * len(spans))
|
||||
sentence_list += [sentence.generate_instance()]
|
||||
print(filename, counter)
|
||||
return sentence_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
|
@@ -0,0 +1,84 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
|
||||
import json
|
||||
from typing import List
|
||||
from universal_ie.utils import tokens_to_str, change_ptb_token_back
|
||||
from universal_ie.ie_format import Entity, Label, Relation, Sentence, Span
|
||||
from universal_ie.task_format.task_format import TaskFormat
|
||||
|
||||
|
||||
class JointER(TaskFormat):
|
||||
""" Joint Entity Relation Data format at https://github.com/yubowen-ph/JointER"""
|
||||
|
||||
def __init__(self, sentence_json, language='en'):
|
||||
super().__init__(
|
||||
language=language
|
||||
)
|
||||
self.tokens = sentence_json['tokens']
|
||||
if self.tokens is None:
|
||||
print('[sentence without tokens]:', sentence_json)
|
||||
exit(1)
|
||||
for index in range(len(self.tokens)):
|
||||
self.tokens[index] = change_ptb_token_back(self.tokens[index])
|
||||
self.spo_list = sentence_json['spo_list']
|
||||
self.spo_details = sentence_json['spo_details']
|
||||
self.pos_tags = sentence_json['pos_tags']
|
||||
|
||||
def generate_instance(self):
|
||||
entities = dict()
|
||||
relations = dict()
|
||||
entity_map = dict()
|
||||
|
||||
for spo_index, spo in enumerate(self.spo_details):
|
||||
s_s, s_e, s_t = spo[0], spo[1], spo[2]
|
||||
tokens = self.tokens[s_s: s_e]
|
||||
indexes = list(range(s_s, s_e))
|
||||
if (s_s, s_e, s_t) not in entities:
|
||||
entities[(s_s, s_e, s_t)] = Entity(
|
||||
span=Span(
|
||||
tokens=tokens,
|
||||
indexes=indexes,
|
||||
text=tokens_to_str(tokens, language=self.language),
|
||||
),
|
||||
label=Label(s_t)
|
||||
)
|
||||
|
||||
o_s, o_e, o_t = spo[4], spo[5], spo[6]
|
||||
tokens = self.tokens[o_s: o_e]
|
||||
indexes = list(range(o_s, o_e))
|
||||
if (o_s, o_e, o_t) not in entities:
|
||||
entities[(o_s, o_e, o_t)] = Entity(
|
||||
span=Span(
|
||||
tokens=tokens,
|
||||
indexes=indexes,
|
||||
text=tokens_to_str(tokens, language=self.language),
|
||||
),
|
||||
label=Label(o_t)
|
||||
)
|
||||
|
||||
relations[spo_index] = Relation(
|
||||
arg1=entities[(s_s, s_e, s_t)],
|
||||
arg2=entities[(o_s, o_e, o_t)],
|
||||
label=Label(spo[3]),
|
||||
)
|
||||
|
||||
return Sentence(
|
||||
tokens=self.tokens,
|
||||
entities=entities.values(),
|
||||
relations=relations.values(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_from_file(filename, language='en') -> List[Sentence]:
|
||||
sentence_list = list()
|
||||
raw_instance_list = json.load(open(filename))
|
||||
print(f"{filename}: {len(raw_instance_list)}")
|
||||
for instance in raw_instance_list:
|
||||
instance = JointER(
|
||||
sentence_json=instance,
|
||||
language=language
|
||||
).generate_instance()
|
||||
sentence_list += [instance]
|
||||
return sentence_list
|
|
@@ -0,0 +1,55 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Dict, List
|
||||
from universal_ie.task_format.spannet import Spannet
|
||||
from universal_ie.ie_format import Sentence
|
||||
|
||||
|
||||
class MRCNER(Spannet):
|
||||
""" MRC NER format at https://github.com/ShannonAI/mrc-for-flat-nested-ner"""
|
||||
id_template = "%s#%s"
|
||||
|
||||
def __init__(self, instance_json: Dict, language='en'):
|
||||
super().__init__(
|
||||
instance_json=instance_json,
|
||||
language=language
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_from_file(filename, language='en') -> List[Sentence]:
|
||||
counter = Counter()
|
||||
dataset = defaultdict(dict)
|
||||
with open(filename) as fin:
|
||||
for instance in json.load(fin):
|
||||
counter.update(['label sentence'])
|
||||
key, _ = instance['qas_id'].split('.')
|
||||
dataset[key]['tokens'] = instance['context'].split()
|
||||
if 'spans' not in dataset[key]:
|
||||
dataset[key]['spans'] = list()
|
||||
for start, end in zip(instance['start_position'],
|
||||
instance['end_position']):
|
||||
dataset[key]['spans'] += [{
|
||||
'start': start,
|
||||
'end': end,
|
||||
'type': instance['entity_label']
|
||||
}]
|
||||
counter.update(['span'])
|
||||
|
||||
sentence_list = list()
|
||||
for sentence_id, sentence in dataset.items():
|
||||
counter.update(['sentence'])
|
||||
mrc_instance = MRCNER(
|
||||
instance_json={
|
||||
'tokens': sentence['tokens'],
|
||||
'span_list': sentence['spans'],
|
||||
'id': sentence_id
|
||||
},
|
||||
language=language
|
||||
)
|
||||
sentence_list += [mrc_instance.generate_instance()]
|
||||
|
||||
print(filename, counter)
|
||||
|
||||
return sentence_list
|
|
@@ -0,0 +1,128 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
from typing import List
|
||||
from universal_ie.task_format.task_format import TaskFormat
|
||||
from universal_ie.utils import tokens_to_str
|
||||
from universal_ie.ie_format import Entity, Event, Label, Sentence, Span
|
||||
|
||||
|
||||
"""
|
||||
{
|
||||
"doc_id": "AFP_ENG_20030427.0118",
|
||||
"sent_id": "AFP_ENG_20030427.0118-1",
|
||||
"tokens": ["A", "Pakistani", "court", "in", "central", "Punjab", "province", "has", "sentenced", "a", "Christian", "man", "to", "life", "imprisonment", "for", "a", "blasphemy", "conviction", ",", "police", "said", "Sunday", "."], "pieces": ["A", "Pakistani", "court", "in", "central", "Punjab", "province", "has", "sentenced", "a", "Christian", "man", "to", "life", "imprisonment", "for", "a", "b", "##lasp", "##hem", "##y", "conviction", ",", "police", "said", "Sunday", "."],
|
||||
"token_lens": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1],
|
||||
"sentence": "A Pakistani court in central Punjab province has sentenced a Christian man to life imprisonment for a blasphemy conviction, police said Sunday.",
|
||||
"entity_mentions": [
|
||||
{"id": "AFP_ENG_20030427.0118-E15-53", "text": "Pakistani", "entity_type": "GPE", "mention_type": "NAM", "entity_subtype": "Nation", "start": 1, "end": 2},
|
||||
{"id": "AFP_ENG_20030427.0118-E35-52", "text": "court", "entity_type": "ORG", "mention_type": "NOM", "entity_subtype": "Government", "start": 2, "end": 3},
|
||||
{"id": "AFP_ENG_20030427.0118-E37-54", "text": "province", "entity_type": "LOC", "mention_type": "NOM", "entity_subtype": "Region-General", "start": 6, "end": 7},
|
||||
{"id": "AFP_ENG_20030427.0118-E27-48", "text": "Christian", "entity_type": "PER", "mention_type": "NOM", "entity_subtype": "Group", "start": 10, "end": 11},
|
||||
{"id": "AFP_ENG_20030427.0118-E38-55", "text": "man", "entity_type": "PER", "mention_type": "NOM", "entity_subtype": "Individual", "start": 11, "end": 12},
|
||||
{"id": "AFP_ENG_20030427.0118-E39-56", "text": "police", "entity_type": "PER", "mention_type": "NOM", "entity_subtype": "Group", "start": 20, "end": 21}],
|
||||
"relation_mentions": [
|
||||
{"id": "AFP_ENG_20030427.0118-R1-1", "relation_type": "GEN-AFF", "relation_subtype": "GEN-AFF:Citizen-Resident-Religion-Ethnicity",
|
||||
"arguments": [
|
||||
{"entity_id": "AFP_ENG_20030427.0118-E38-55", "text": "man", "role": "Arg-1"},
|
||||
{"entity_id": "AFP_ENG_20030427.0118-E27-48", "text": "Christian", "role": "Arg-2"}
|
||||
]
|
||||
},
|
||||
{"id": "AFP_ENG_20030427.0118-R3-1", "relation_type": "PART-WHOLE", "relation_subtype": "PART-WHOLE:Subsidiary",
|
||||
"arguments": [
|
||||
{"entity_id": "AFP_ENG_20030427.0118-E35-52", "text": "court", "role": "Arg-1"},
|
||||
{"entity_id": "AFP_ENG_20030427.0118-E15-53", "text": "Pakistani", "role": "Arg-2"}
|
||||
]
|
||||
},
|
||||
{"id": "AFP_ENG_20030427.0118-R4-1", "relation_type": "GEN-AFF", "relation_subtype": "GEN-AFF:Org-Location",
|
||||
"arguments": [
|
||||
{"entity_id": "AFP_ENG_20030427.0118-E35-52", "text": "court", "role": "Arg-1"},
|
||||
{"entity_id": "AFP_ENG_20030427.0118-E37-54", "text": "province", "role": "Arg-2"}
|
||||
]
|
||||
}
|
||||
],
|
||||
"event_mentions": [
|
||||
{"id": "AFP_ENG_20030427.0118-EV1-1", "event_type": "Justice:Sentence",
|
||||
"trigger": {"text": "sentenced", "start": 8, "end": 9},
|
||||
"arguments": [
|
||||
{"entity_id": "AFP_ENG_20030427.0118-E35-52", "text": "court", "role": "Adjudicator"},
|
||||
{"entity_id": "AFP_ENG_20030427.0118-E38-55", "text": "man", "role": "Defendant"},
|
||||
{"entity_id": "AFP_ENG_20030427.0118-E37-54", "text": "province", "role": "Place"}
|
||||
]},
|
||||
{"id": "AFP_ENG_20030427.0118-EV2-1", "event_type": "Justice:Convict",
|
||||
"trigger": {"text": "conviction", "start": 18, "end": 19},
|
||||
"arguments": [{"entity_id": "AFP_ENG_20030427.0118-E38-55", "text": "man", "role": "Defendant"}
|
||||
]}
|
||||
]}
|
||||
"""
|
||||
|
||||
|
||||
class OneIEEvent(TaskFormat):
|
||||
def __init__(self, doc_json, language='en'):
|
||||
super().__init__(
|
||||
language=language
|
||||
)
|
||||
self.doc_id = doc_json['doc_id']
|
||||
self.sent_id = doc_json['sent_id']
|
||||
self.tokens = doc_json['tokens']
|
||||
self.entities = doc_json['entity_mentions']
|
||||
self.relations = doc_json['relation_mentions']
|
||||
self.events = doc_json['event_mentions']
|
||||
|
||||
def generate_instance(self):
|
||||
events = dict()
|
||||
entities = dict()
|
||||
|
||||
for span_index, span in enumerate(self.entities):
|
||||
tokens = self.tokens[span['start']: span['end']]
|
||||
indexes = list(range(span['start'], span['end']))
|
||||
entities[span['id']] = Entity(
|
||||
span=Span(
|
||||
tokens=tokens,
|
||||
indexes=indexes,
|
||||
text=tokens_to_str(tokens, language=self.language),
|
||||
text_id=self.sent_id
|
||||
),
|
||||
label=Label(span['entity_type']),
|
||||
text_id=self.sent_id,
|
||||
record_id=span['id']
|
||||
)
|
||||
|
||||
for event_index, event in enumerate(self.events):
|
||||
start = event['trigger']['start']
|
||||
end = event['trigger']['end']
|
||||
tokens = self.tokens[start:end]
|
||||
indexes = list(range(start, end))
|
||||
events[event['id']] = Event(
|
||||
span=Span(
|
||||
tokens=tokens,
|
||||
indexes=indexes,
|
||||
text=tokens_to_str(tokens, language=self.language),
|
||||
text_id=self.sent_id
|
||||
),
|
||||
label=Label(event['event_type']),
|
||||
args=[(Label(x['role']), entities[x['entity_id']])
|
||||
for x in event['arguments']],
|
||||
text_id=self.sent_id,
|
||||
record_id=event['id']
|
||||
)
|
||||
|
||||
return Sentence(
|
||||
tokens=self.tokens,
|
||||
entities=list(),
|
||||
relations=list(),
|
||||
events=events.values(),
|
||||
text_id=self.sent_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_from_file(filename, language='en') -> List[Sentence]:
|
||||
sentence_list = list()
|
||||
with open(filename) as fin:
|
||||
for line in fin:
|
||||
instance = OneIEEvent(
|
||||
json.loads(line.strip()),
|
||||
language=language
|
||||
).generate_instance()
|
||||
sentence_list += [instance]
|
||||
return sentence_list
|
|
@@ -0,0 +1,87 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
from collections import Counter
|
||||
import json
|
||||
from typing import List, Dict
|
||||
from universal_ie.task_format.task_format import TaskFormat
|
||||
from universal_ie.utils import change_ptb_token_back, tokens_to_str
|
||||
from universal_ie.ie_format import Entity, Label, Relation, Sentence, Span
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class Spannet(TaskFormat):
|
||||
"""
|
||||
{
|
||||
"tokens": ["An", "art", "exhibit", "at", "the", "Hakawati", "Theatre",
|
||||
"in", "Arab", "east", "Jerusalem", "was", "a", "series",
|
||||
"of", "portraits", "of", "Palestinians", "killed", "in",
|
||||
"the", "rebellion", "."],
|
||||
"span_pair_list": [
|
||||
{"type": "OrgBased_In", "head": 0, "tail": 2}
|
||||
],
|
||||
"span_list": [
|
||||
{"type": "Org", "start": 5, "end": 6},
|
||||
{"type": "Other", "start": 8, "end": 8},
|
||||
{"type": "Loc", "start": 10, "end": 10},
|
||||
{"type": "Other", "start": 17, "end": 17}
|
||||
]
|
||||
}
|
||||
"""
|
||||
def __init__(self, instance_json: Dict, language='en') -> None:
|
||||
super().__init__(
|
||||
language=language
|
||||
)
|
||||
self.tokens = [change_ptb_token_back(token) for token in instance_json['tokens']]
|
||||
self.span_list = instance_json.get('span_list', [])
|
||||
self.span_pair_list = instance_json.get('span_pair_list', [])
|
||||
self.instance_id = instance_json.get('id', None)
|
||||
|
||||
def generate_instance(self):
|
||||
entities = list()
|
||||
relations = list()
|
||||
for span_index, span in enumerate(self.span_list):
|
||||
tokens = self.tokens[span['start']: span['end'] + 1]
|
||||
indexes = list(range(span['start'], span['end'] + 1))
|
||||
entities += [
|
||||
Entity(
|
||||
span=Span(
|
||||
tokens=tokens,
|
||||
indexes=indexes,
|
||||
text=tokens_to_str(tokens, language=self.language),
|
||||
text_id=self.instance_id
|
||||
),
|
||||
label=Label(span['type']),
|
||||
text_id=self.instance_id,
|
||||
record_id=self.instance_id + "#%s" % span_index if self.instance_id else None)
|
||||
]
|
||||
for spanpair_index, span_pair in enumerate(self.span_pair_list):
|
||||
relations += [
|
||||
Relation(
|
||||
arg1=entities[span_pair['head']],
|
||||
arg2=entities[span_pair['tail']],
|
||||
label=Label(span_pair['type']),
|
||||
text_id=self.instance_id,
|
||||
record_id=self.instance_id + "##%s" % spanpair_index if self.instance_id else None
|
||||
)
|
||||
]
|
||||
return Sentence(tokens=self.tokens,
|
||||
entities=entities,
|
||||
relations=relations,
|
||||
text_id=self.instance_id)
|
||||
|
||||
@staticmethod
|
||||
def load_from_file(filename, language='en') -> List[Sentence]:
|
||||
sentence_list = list()
|
||||
counter = Counter()
|
||||
with open(filename) as fin:
|
||||
for line in tqdm(fin):
|
||||
spannet = Spannet(
|
||||
json.loads(line.strip()),
|
||||
language=language
|
||||
)
|
||||
instance = spannet.generate_instance()
|
||||
sentence_list += [instance]
|
||||
counter.update(['sentence'])
|
||||
counter.update(['span'] * len(spannet.span_list))
|
||||
print(filename, counter)
|
||||
return sentence_list
|
|
@@ -0,0 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
import abc
|
||||
|
||||
|
||||
class TaskFormat:
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self, language='en'):
|
||||
self.language = language
|
||||
|
||||
@abc.abstractmethod
|
||||
def generate_instance(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def load_from_file(filename, language='en'):
|
||||
pass
|
|
@@ -0,0 +1,83 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
from typing import List
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
global_mislabel_log = set()
|
||||
|
||||
|
||||
def tokens_to_str(tokens: List[str], language: str = 'en') -> str:
|
||||
if language == 'en':
|
||||
return ' '.join(tokens)
|
||||
elif language == 'zh':
|
||||
return ''.join(tokens)
|
||||
else:
|
||||
raise NotImplementedError('Language %s not supported' % language)
|
||||
|
||||
|
||||
def label_format(s):
|
||||
import re
|
||||
|
||||
def uncamelize(s):
|
||||
re_outer = re.compile(r'([^A-Z ])([A-Z])')
|
||||
re_inner = re.compile(r'\b[A-Z]+(?=[A-Z][a-z])')
|
||||
sub = re_inner.sub(r'\g<0> ', re_outer.sub(r'\1 \2', s)).lower()
|
||||
return sub
|
||||
|
||||
def remove(s):
|
||||
return s.replace("_", " ").replace("-", " ").replace(".", " ")
|
||||
|
||||
s = remove(uncamelize(s)).split()
|
||||
if len(s) > 1 and s[0] == s[1]:
|
||||
s = s[1:]
|
||||
return " ".join(s)
|
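# Example (added for illustration):
#   label_format('Geo-Political_Entity')  ->  'geo political entity'
# i.e. camel case is split, '_', '-' and '.' become spaces, and a duplicated
# leading word would be dropped.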
||||
|
||||
|
||||
def load_dict_ini_file(filename):
|
||||
print("Warning: `load_dict_ini_file` is deprecated.")
|
||||
if not os.path.exists(filename):
|
||||
sys.stderr.write(f'[warning] cannot load label mapper from {filename}\n')
|
||||
return {}
|
||||
mapper = dict()
|
||||
for line in open(filename):
|
||||
key, value = line.strip().split('=')
|
||||
mapper[key] = label_format(value)
|
||||
return mapper
|
||||
|
||||
|
||||
def change_ptb_token_back(token):
|
||||
"""将 PTBTokenized 的 Token 转换会原始字符串
|
||||
|
||||
Args:
|
||||
token (str): token string produced by PTB tokenization
|
||||
|
||||
Returns:
|
||||
str: the original token string
|
||||
"""
|
||||
ptb_token_map = {
|
||||
'``': '"',
|
||||
"''": '"',
|
||||
'-LRB-': '(',
|
||||
'-RRB-': ')',
|
||||
'-LSB-': '[',
|
||||
'-RSB-': ']',
|
||||
'-LCB-': '{',
|
||||
'-RCB-': '}',
|
||||
}
|
||||
for ptb_token, raw_token in ptb_token_map.items():
|
||||
if token == ptb_token:
|
||||
return raw_token
|
||||
return token
|
||||
|
||||
|
||||
def change_name_using_label_mapper(label_name, label_mapper):
|
||||
if label_mapper is None or len(label_mapper) == 0:
|
||||
return label_name
|
||||
if label_name not in label_mapper:
|
||||
print(f"{label_name} not found in mapper")
|
||||
global global_mislabel_log
|
||||
if label_name not in global_mislabel_log:
|
||||
global_mislabel_log.add(label_name)
|
||||
return label_mapper.get(label_name, label_name)
|
|
@@ -0,0 +1,29 @@
|
|||
FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu18.04
|
||||
LABEL maintainer="Yaojie Lu"
|
||||
LABEL repository="uie"
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3 \
|
||||
python3-pip && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
WORKDIR /pre_env
|
||||
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir mkl && \
|
||||
python3 -m pip install --no-cache-dir torch==1.7.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
|
||||
RUN git clone https://github.com/NVIDIA/apex
|
||||
RUN cd apex && \
|
||||
python3 setup.py install && \
|
||||
python3 -m pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
|
||||
|
||||
COPY ./requirements.txt .
|
||||
RUN python3 -m pip install -r ./requirements.txt
|
||||
|
||||
CMD ["/bin/bash"]
|
|
@@ -0,0 +1,53 @@
|
|||
# Data Description
|
||||
|
||||
``` json
|
||||
{
|
||||
"text": "MULTAN , Pakistan , April 27 ( AFP )",
|
||||
"tokens": ["MULTAN", ",", "Pakistan", ",", "April", "27", "(", "AFP", ")"],
|
||||
"record": "<extra_id_0> <extra_id_0> geographical social political <extra_id_5> MULTAN <extra_id_0> part whole <extra_id_5> Pakistan <extra_id_1> <extra_id_1> <extra_id_0> geographical social political <extra_id_5> Pakistan <extra_id_1> <extra_id_0> organization <extra_id_5> AFP <extra_id_1> <extra_id_1>",
|
||||
"entity": [
|
||||
{"type": "geographical social political", "offset": [0], "text": "MULTAN"},
|
||||
{"type": "geographical social political", "offset": [2], "text": "Pakistan"},
|
||||
{"type": "organization", "offset": [7], "text": "AFP"}
|
||||
],
|
||||
"relation": [
|
||||
{
|
||||
"type": "part whole",
|
||||
"args": [
|
||||
{"type": "geographical social political", "offset": [0], "text": "MULTAN"},
|
||||
{"type": "geographical social political", "offset": [2], "text": "Pakistan"}
|
||||
]
|
||||
}
|
||||
],
|
||||
"event": [],
|
||||
"spot": ["geographical social political", "organization"],
|
||||
"asoc": ["part whole"],
|
||||
"spot_asoc": [
|
||||
{
|
||||
"span": "MULTAN",
|
||||
"label": "geographical social political",
|
||||
"asoc": [["part whole", "Pakistan"]]
|
||||
},
|
||||
{
|
||||
"span": "Pakistan",
|
||||
"label": "geographical social political", "asoc": []
|
||||
},
|
||||
{
|
||||
"span": "AFP", "label": "organization", "asoc": []
|
||||
}
|
||||
],
|
||||
"task": 'record'
|
||||
}
|
||||
```
|
||||
|
||||
- task: `seq`, `record`, `t5mlm`
|
||||
- mlm only requires Text
|
||||
- seq only requires Record
|
||||
- record requires paired Text-record data
|
||||
- if task is missing, the instance is treated as Text-record data by default
|
||||
- spot, asoc
|
||||
- positive label types that appear in the text
|
||||
- spot_asoc
|
||||
- structured representation of the record
|
||||
- entity, relation, event
|
||||
- gold-standard offset annotations, used for model evaluation (see the loading sketch below).
|
|
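A minimal loading sketch (not part of the repository; the file name is a placeholder) that reads one instance of the format above and sanity-checks its fields:

``` python
import json

# Read one pretraining instance and check the fields described above.
with open("pretrain_data/train.json") as fin:      # placeholder path
    instance = json.loads(fin.readline())

# Every spot label used in spot_asoc should be listed in "spot",
# and every association label should be listed in "asoc".
assert {s["label"] for s in instance["spot_asoc"]} <= set(instance["spot"])
for spot in instance["spot_asoc"]:
    for asoc_label, _asoc_span in spot["asoc"]:
        assert asoc_label in instance["asoc"]

print(instance["task"], instance["text"])
```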
@@ -0,0 +1,41 @@
|
|||
# Tools for UIE
|
||||
|
||||
### Evaluate Model Performance
|
||||
Evaluate model performance (eval_extraction.py)
|
||||
```text
|
||||
$ python scripts/eval_extraction.py -h
|
||||
usage: eval_extraction.py [-h] [-g GOLD_FOLDER] [-p PRED_FOLDER [PRED_FOLDER ...]] [-v] [-w] [-m] [-case]
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
-g GOLD_FOLDER Golden Dataset folder
|
||||
-p PRED_FOLDER [PRED_FOLDER ...]
|
||||
Predicted model folder
|
||||
-v Show more information during running
|
||||
-w Write evaluation results to predicted folder
|
||||
-m Match predicted result multiple times
|
||||
-case Show case study
|
||||
```
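For example, to score one prediction folder against the gold data (both folder names below are placeholders):

``` bash
# Placeholder folders; -w writes the evaluation results back into the prediction folder.
python3 scripts/eval_extraction.py \
    -g data/text2spotasoc/absa/14lap \
    -p output/meta_finetuned_run1 \
    -w -v
```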
|
||||
|
||||
### Check Offset Mapping Performance
|
||||
Check the performance of offset re-mapping (check_offset_map_gold_as_pred.bash)
|
||||
``` bash
|
||||
bash scripts/check_offset_map_gold_as_pred.bash <data-folder> <map-config>
|
||||
```
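For example, with the data folder and offset-map config used elsewhere in this repository:

``` bash
# The script treats the gold annotations as predictions, so the reported score
# reflects only the loss introduced by the offset-mapping strategy itself.
bash scripts/check_offset_map_gold_as_pred.bash \
    data/text2spotasoc/absa/14lap \
    config/offset_map/closest_offset_en.yaml
```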
|
||||
|
||||
### Convert SEL to Record
|
||||
Convert structured expressions (SEL) into Record structures (sel2record.py)
|
||||
``` text
|
||||
$ python scripts/sel2record.py -h
|
||||
usage: sel2record.py [-h] [-g GOLD_FOLDER] [-p PRED_FOLDER [PRED_FOLDER ...]] [-c MAP_CONFIG] [-d DECODING] [-v]
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
-g GOLD_FOLDER Gold (reference) dataset folder
|
||||
-p PRED_FOLDER [PRED_FOLDER ...]
|
||||
One or more different prediction (Pred) folders
|
||||
-c MAP_CONFIG, --config MAP_CONFIG
|
||||
Config file for the offset matching strategy
|
||||
-d DECODING Parse the structured expressions with the SpotAsoc parser
|
||||
-v, --verbose Print more detailed logging information
|
||||
```
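For example, mirroring how the run scripts invoke it (the folder names are placeholders):

``` bash
# Converts predicted SEL sequences in the prediction folder back into Record structures.
python3 scripts/sel2record.py \
    -g data/text2spotasoc/absa/14lap \
    -p output/meta_finetuned_run1 \
    -c config/offset_map/closest_offset_en.yaml \
    -d spotasoc -v
```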
|
Binary file not shown.
Binary file not shown.
|
@@ -0,0 +1,163 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import re
|
||||
from tqdm import tqdm
|
||||
import transformers as huggingface_transformers
|
||||
from uie.extraction.record_schema import RecordSchema
|
||||
from uie.sel2record.record import MapConfig
|
||||
from uie.extraction.scorer import *
|
||||
from uie.sel2record.sel2record import SEL2Record
|
||||
import math
|
||||
import os
|
||||
|
||||
|
||||
split_bracket = re.compile(r"\s*<extra_id_\d>\s*")
|
||||
special_to_remove = {'<pad>', '</s>'}
|
||||
|
||||
|
||||
def read_json_file(file_name):
|
||||
return [json.loads(line) for line in open(file_name)]
|
||||
|
||||
|
||||
def schema_to_ssi(schema: RecordSchema):
|
||||
ssi = "<spot> " + "<spot> ".join(sorted(schema.type_list))
|
||||
ssi += "<asoc> " + "<asoc> ".join(sorted(schema.role_list))
|
||||
ssi += "<extra_id_2> "
|
||||
return ssi
|
||||
|
||||
|
||||
def post_processing(x):
|
||||
for special in special_to_remove:
|
||||
x = x.replace(special, '')
|
||||
return x.strip()
|
||||
|
||||
|
||||
class HuggingfacePredictor:
|
||||
def __init__(self, model_path, schema_file, max_source_length=256, max_target_length=192) -> None:
|
||||
self._tokenizer = huggingface_transformers.T5TokenizerFast.from_pretrained(
|
||||
model_path)
|
||||
self._model = huggingface_transformers.T5ForConditionalGeneration.from_pretrained(
|
||||
model_path)
|
||||
self._model.cuda()
|
||||
self._schema = RecordSchema.read_from_file(schema_file)
|
||||
self._ssi = schema_to_ssi(self._schema)
|
||||
self._max_source_length = max_source_length
|
||||
self._max_target_length = max_target_length
|
||||
|
||||
def predict(self, text):
|
||||
text = [self._ssi + x for x in text]
|
||||
inputs = self._tokenizer(
|
||||
text, padding=True, return_tensors='pt').to(self._model.device)
|
||||
|
||||
inputs['input_ids'] = inputs['input_ids'][:, :self._max_source_length]
|
||||
inputs['attention_mask'] = inputs['attention_mask'][:,
|
||||
:self._max_source_length]
|
||||
|
||||
result = self._model.generate(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask'],
|
||||
max_length=self._max_target_length,
|
||||
)
|
||||
return self._tokenizer.batch_decode(result, skip_special_tokens=False, clean_up_tokenization_spaces=False)
|
||||
|
||||
|
||||
task_dict = {
|
||||
'entity': EntityScorer,
|
||||
'relation': RelationScorer,
|
||||
'event': EventScorer,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--data', '-d', default='data/text2spotasoc/absa/14lap')
|
||||
parser.add_argument(
|
||||
'--model', '-m', default='./models/uie_n10_21_50w_absa_14lap')
|
||||
parser.add_argument('--max_source_length', default=256, type=int)
|
||||
parser.add_argument('--max_target_length', default=192, type=int)
|
||||
parser.add_argument('--batch_size', default=16, type=int)
|
||||
parser.add_argument('-c', '--config', dest='map_config',
|
||||
help='Offset Re-mapping Config',
|
||||
default='config/offset_map/closest_offset_en.yaml')
|
||||
parser.add_argument('--decoding', default='spotasoc')
|
||||
parser.add_argument('--verbose', action='store_true')
|
||||
parser.add_argument('--match_mode', default='normal',
|
||||
choices=['set', 'normal', 'multimatch'])
|
||||
options = parser.parse_args()
|
||||
|
||||
data_folder = options.data
|
||||
model_path = options.model
|
||||
|
||||
predictor = HuggingfacePredictor(
|
||||
model_path=model_path,
|
||||
schema_file=f"{data_folder}/record.schema",
|
||||
max_source_length=options.max_source_length,
|
||||
max_target_length=options.max_target_length,
|
||||
)
|
||||
|
||||
map_config = MapConfig.load_from_yaml(options.map_config)
|
||||
schema_dict = SEL2Record.load_schema_dict(data_folder)
|
||||
sel2record = SEL2Record(
|
||||
schema_dict=schema_dict,
|
||||
decoding_schema=options.decoding,
|
||||
map_config=map_config,
|
||||
)
|
||||
|
||||
for split, split_name in [('val', 'eval'), ('test', 'test')]:
|
||||
gold_filename = f"{data_folder}/{split}.json"
|
||||
|
||||
text_list = [x['text'] for x in read_json_file(gold_filename)]
|
||||
token_list = [x['tokens'] for x in read_json_file(gold_filename)]
|
||||
|
||||
batch_num = math.ceil(len(text_list) / options.batch_size)
|
||||
|
||||
predict = list()
|
||||
for index in tqdm(range(batch_num)):
|
||||
start = index * options.batch_size
|
||||
end = index * options.batch_size + options.batch_size
|
||||
|
||||
pred_seq2seq = predictor.predict(text_list[start: end])
|
||||
pred_seq2seq = [post_processing(x) for x in pred_seq2seq]
|
||||
|
||||
predict += pred_seq2seq
|
||||
|
||||
records = list()
|
||||
for p, text, tokens in zip(predict, text_list, token_list):
|
||||
r = sel2record.sel2record(pred=p, text=text, tokens=tokens)
|
||||
records += [r]
|
||||
|
||||
results = dict()
|
||||
for task, scorer in task_dict.items():
|
||||
|
||||
gold_list = [x[task] for x in read_json_file(gold_filename)]
|
||||
pred_list = [x[task] for x in records]
|
||||
|
||||
gold_instance_list = scorer.load_gold_list(gold_list)
|
||||
pred_instance_list = scorer.load_pred_list(pred_list)
|
||||
|
||||
sub_results = scorer.eval_instance_list(
|
||||
gold_instance_list=gold_instance_list,
|
||||
pred_instance_list=pred_instance_list,
|
||||
verbose=options.verbose,
|
||||
match_mode=options.match_mode,
|
||||
)
|
||||
results.update(sub_results)
|
||||
|
||||
with open(os.path.join(options.model, f'{split_name}_preds_record.txt'), 'w') as output:
|
||||
for record in records:
|
||||
output.write(f'{json.dumps(record)}\n')
|
||||
|
||||
with open(os.path.join(options.model, f'{split_name}_preds_seq2seq.txt'), 'w') as output:
|
||||
for pred in predict:
|
||||
output.write(f'{pred}\n')
|
||||
|
||||
with open(os.path.join(options.model, f'{split_name}_results.txt'), 'w') as output:
|
||||
for key, value in results.items():
|
||||
output.write(f'{split_name}_{key}={value}\n')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@@ -0,0 +1,46 @@
|
|||
from tensorboard.backend.event_processing import event_accumulator
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def read_tensorboard_data(tensorboard_log_path, val_name):
|
||||
ea = event_accumulator.EventAccumulator(tensorboard_log_path)
|
||||
ea.Reload()
|
||||
|
||||
print("All scalers:")
|
||||
print(ea.scalars.Keys())
|
||||
|
||||
val = ea.scalars.Items(val_name)
|
||||
return val
|
||||
|
||||
def plot(vals, val_names, max_step=None):
|
||||
plt.figure()
|
||||
|
||||
for val, val_name in zip(vals, val_names):
|
||||
x = [i.step for i in val]
|
||||
y = [i.value for i in val]
|
||||
|
||||
if max_step is not None:
|
||||
x = [i for i in x if i < max_step]
|
||||
y = y[:len(x)]
|
||||
|
||||
plt.plot(x, y, label=val_name)
|
||||
|
||||
plt.xlabel("step")
|
||||
plt.ylabel("loss")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
refine_uie_tensorboard_log_path = "tensorboard_logs/events.out.tfevents.1654419004.dsw32050-7df697f45c-6bwkm.44438.0"
|
||||
refine_t5_tensorboard_log_path = "tensorboard_logs/events.out.tfevents.1654361305.g64h07153.cloud.sqa.nt12.129194.0"
|
||||
uie_t5_tensorboard_log_path = "tensorboard_logs/events.out.tfevents.1654275965.eflops-common033255085104.NT12.106708.0"
|
||||
|
||||
val_name = "train/loss"
|
||||
|
||||
refine_uie_val = read_tensorboard_data(refine_uie_tensorboard_log_path, val_name)
|
||||
refine_t5_val = read_tensorboard_data(refine_t5_tensorboard_log_path, val_name)
|
||||
uie_t5_val = read_tensorboard_data(uie_t5_tensorboard_log_path, val_name)
|
||||
|
||||
vals = [refine_uie_val, refine_t5_val, uie_t5_val]
|
||||
val_names = ["refine_uie_loss", "refine_t5_loss", "uie_t5_loss"]
|
||||
max_step = 20000
|
||||
plot(vals, val_names, max_step=max_step)
|
|
@@ -0,0 +1,173 @@
|
|||
absl-py==1.0.0
|
||||
altair==4.2.0
|
||||
anyio==3.5.0
|
||||
anytree==2.8.0
|
||||
argon2-cffi==21.3.0
|
||||
argon2-cffi-bindings==21.2.0
|
||||
asgiref==3.5.0
|
||||
astor==0.8.1
|
||||
asttokens==2.0.5
|
||||
attrs==21.4.0
|
||||
autopep8==1.6.0
|
||||
backcall==0.2.0
|
||||
backports.zoneinfo==0.2.1
|
||||
base58==2.1.1
|
||||
black==21.12b0
|
||||
bleach==4.1.0
|
||||
blinker==1.4
|
||||
cachetools==4.2.4
|
||||
certifi==2021.10.8
|
||||
cffi==1.15.0
|
||||
charset-normalizer==2.0.10
|
||||
click==8.0.3
|
||||
colorama==0.4.4
|
||||
colorlog==6.6.0
|
||||
commonmark==0.9.1
|
||||
conllu==4.4.1
|
||||
cycler==0.11.0
|
||||
dataclasses==0.6
|
||||
datasets==1.9.0
|
||||
debugpy==1.5.1
|
||||
decorator==5.1.1
|
||||
defusedxml==0.7.1
|
||||
dill==0.3.4
|
||||
elasticsearch==7.16.3
|
||||
entrypoints==0.3
|
||||
executing==0.8.2
|
||||
faiss-cpu==1.7.2
|
||||
fastapi==0.74.1
|
||||
filelock==3.0.12
|
||||
fire==0.4.0
|
||||
fonttools==4.28.5
|
||||
fsspec==2022.1.0
|
||||
future==0.18.2
|
||||
git-python==1.0.3
|
||||
gitdb==4.0.9
|
||||
GitPython==3.1.26
|
||||
google-auth==2.3.3
|
||||
google-auth-oauthlib==0.4.6
|
||||
googleapis-common-protos==1.54.0
|
||||
grpcio==1.43.0
|
||||
h11==0.13.0
|
||||
h5py==3.6.0
|
||||
huggingface-hub==0.0.8
|
||||
idna==3.3
|
||||
importlib-metadata==4.10.1
|
||||
importlib-resources==5.4.0
|
||||
iniconfig==1.1.1
|
||||
ipykernel==6.7.0
|
||||
ipython
|
||||
ipython-genutils==0.2.0
|
||||
ipywidgets==7.6.5
|
||||
jedi==0.18.1
|
||||
jieba==0.42.1
|
||||
Jinja2==3.0.3
|
||||
joblib==1.1.0
|
||||
jsonschema==4.4.0
|
||||
jupyter-client==7.1.1
|
||||
jupyter-core==4.9.1
|
||||
jupyterlab-pygments==0.1.2
|
||||
jupyterlab-widgets==1.0.2
|
||||
kiwisolver==1.3.2
|
||||
Markdown==3.3.6
|
||||
MarkupSafe==2.0.1
|
||||
matplotlib==3.5.1
|
||||
matplotlib-inline==0.1.3
|
||||
mistune==0.8.4
|
||||
multiprocess==0.70.12.2
|
||||
mypy-extensions==0.4.3
|
||||
nbclient==0.5.10
|
||||
nbconvert==6.4.0
|
||||
nbformat==5.1.3
|
||||
nest-asyncio==1.5.4
|
||||
nltk==3.6.7
|
||||
notebook==6.4.7
|
||||
numpy==1.19.5
|
||||
oauthlib==3.1.1
|
||||
packaging==21.3
|
||||
pandas==1.3.5
|
||||
pandocfilters==1.5.0
|
||||
parso==0.8.3
|
||||
pathspec==0.9.0
|
||||
pexpect==4.8.0
|
||||
pickleshare==0.7.5
|
||||
Pillow==9.0.0
|
||||
platformdirs==2.4.1
|
||||
pluggy==1.0.0
|
||||
portalocker==2.3.2
|
||||
prometheus-client==0.12.0
|
||||
promise==2.3
|
||||
prompt-toolkit==3.0.24
|
||||
protobuf==3.19.3
|
||||
psutil==5.9.0
|
||||
ptyprocess==0.7.0
|
||||
pure-eval==0.2.1
|
||||
py==1.11.0
|
||||
pyarrow==4.0.1
|
||||
pyasn1==0.4.8
|
||||
pyasn1-modules==0.2.8
|
||||
pycodestyle==2.8.0
|
||||
pycparser==2.21
|
||||
pydantic==1.9.0
|
||||
pydeck==0.7.1
|
||||
Pygments==2.11.2
|
||||
Pympler==1.0.1
|
||||
pyparsing==3.0.6
|
||||
pyrsistent==0.18.1
|
||||
pytest==6.2.5
|
||||
python-dateutil==2.8.2
|
||||
pytz==2021.3
|
||||
pytz-deprecation-shim==0.1.0.post0
|
||||
PyYAML==6.0
|
||||
pyzmq==22.3.0
|
||||
regex==2022.1.18
|
||||
requests==2.27.1
|
||||
requests-oauthlib==1.3.0
|
||||
rich==9.8.2
|
||||
rouge-score==0.0.4
|
||||
rsa==4.8
|
||||
sacrebleu==1.4.14
|
||||
sacremoses==0.0.47
|
||||
scikit-learn==1.0.2
|
||||
scipy==1.7.3
|
||||
Send2Trash==1.8.0
|
||||
sentencepiece==0.1.96
|
||||
seqeval==1.2.2
|
||||
six==1.16.0
|
||||
smmap==5.0.0
|
||||
sniffio==1.2.0
|
||||
stack-data==0.1.4
|
||||
starlette==0.17.1
|
||||
streamlit==1.4.0
|
||||
tabulate==0.8.9
|
||||
tensorboard==2.7.0
|
||||
tensorboard-data-server==0.6.1
|
||||
tensorboard-plugin-wit==1.8.1
|
||||
tensorflow-datasets==4.4.0
|
||||
tensorflow-metadata==1.6.0
|
||||
termcolor==1.1.0
|
||||
terminado==0.12.1
|
||||
testpath==0.5.0
|
||||
threadpoolctl==3.0.0
|
||||
tokenizers==0.10.3
|
||||
toml==0.10.2
|
||||
tomli==1.2.3
|
||||
toolz==0.11.2
|
||||
tornado==6.1
|
||||
tqdm==4.62.3
|
||||
traitlets==5.1.1
|
||||
transformers==4.6.1
|
||||
typing-extensions==3.10.0.2
|
||||
tzdata==2021.5
|
||||
tzlocal==4.1
|
||||
urllib3==1.26.8
|
||||
uvicorn==0.17.5
|
||||
validators==0.18.2
|
||||
watchdog==2.1.6
|
||||
wcwidth==0.2.5
|
||||
webencodings==0.5.1
|
||||
Werkzeug==2.0.2
|
||||
widgetsnbextension==3.5.2
|
||||
xxhash==2.0.2
|
||||
zipp==3.7.0
|
||||
learn2learn
|
|
@@ -0,0 +1,87 @@
|
|||
device="0"
|
||||
model_path=""
|
||||
data_folder=data/text2spotasoc/absa/14lap
|
||||
task_name="meta"
|
||||
batch=16
|
||||
decoding_format='spotasoc'
|
||||
beam_size=1
|
||||
map_config=config/offset_map/closest_offset_en.yaml
|
||||
|
||||
export PYTHONPATH="${PYTHONPATH}:./"
|
||||
|
||||
OPTS=$(getopt -o b:d:m:i:t:co:f:e: --long batch:,device:,model:,data:,task:,constraint_decoding,output:,format:,map_config:,extra_cmd:,beam: -n 'parse-options' -- "$@")
|
||||
|
||||
if [ $? != 0 ]; then
|
||||
echo "Failed parsing options." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
eval set -- "$OPTS"
|
||||
|
||||
while true; do
|
||||
case "$1" in
|
||||
-b | --batch) batch="$2"
|
||||
shift 2 ;;
|
||||
-d | --device) device="$2"
|
||||
shift 2 ;;
|
||||
-m | --model) model_path="$2"
|
||||
shift 2 ;;
|
||||
-i | --data) data_folder="$2"
|
||||
shift 2 ;;
|
||||
-t | --task) task_name="$2"
|
||||
shift 2 ;;
|
||||
-c | --constraint_decoding) constraint_decoding="--constraint_decoding"
|
||||
shift ;;
|
||||
-o | --output) output_dir="$2"
|
||||
shift 2 ;;
|
||||
-f | --format) decoding_format="$2"
|
||||
shift 2 ;;
|
||||
-e | --extra_cmd) extra_cmd="$2"
|
||||
shift 2 ;;
|
||||
--beam) beam_size="$2"
|
||||
shift 2 ;;
|
||||
--map_config) map_config="$2"
|
||||
shift 2 ;;
|
||||
--)
|
||||
shift
|
||||
break
|
||||
;;
|
||||
*)
|
||||
echo "$1" not recognize.
|
||||
exit
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
echo "Extra CMD: " "${extra_cmd}"
|
||||
|
||||
if [[ ${output_dir} == "" ]]
|
||||
then
|
||||
output_dir=${model_path}_eval
|
||||
if [[ ${constraint_decoding} != "" ]]
|
||||
then
|
||||
output_dir=${output_dir}_CD
|
||||
fi
|
||||
fi
|
||||
|
||||
CUDA_VISIBLE_DEVICES=${device} python3 run_seq2seq.py \
|
||||
--use_fast_tokenizer=True \
|
||||
--max_source_length=${max_source_length:-"256"} \
|
||||
--max_target_length=${max_target_length:-"192"} \
|
||||
--do_eval --do_predict --task=record --predict_with_generate \
|
||||
--validation_file=${data_folder}/val.json \
|
||||
--test_file=${data_folder}/test.json \
|
||||
--record_schema=${data_folder}/record.schema \
|
||||
--model_name_or_path=${model_path} \
|
||||
--output_dir=${output_dir} \
|
||||
--source_prefix="${task_name}: " \
|
||||
--no_remove_unused_columns \
|
||||
--num_beams=${beam_size} \
|
||||
${constraint_decoding} ${extra_cmd} \
|
||||
--per_device_eval_batch_size=${batch} \
|
||||
--decoding_format ${decoding_format}
|
||||
|
||||
python3 scripts/sel2record.py -p ${output_dir} -g ${data_folder} -v -d ${decoding_format} -c ${map_config}
|
||||
python3 scripts/eval_extraction.py -p ${output_dir} -g ${data_folder} -w -m ${eval_match_mode:-"normal"}
|
||||
|
||||
echo "Output Dir:" ${output_dir}
|
|
@@ -0,0 +1,779 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning the library models for sequence to sequence.
|
||||
"""
|
||||
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
HfArgumentParser,
|
||||
default_data_collator,
|
||||
set_seed
|
||||
)
|
||||
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
||||
|
||||
from uie.extraction import constants
|
||||
from uie.extraction.record_schema import RecordSchema
|
||||
from uie.extraction.predict_parser import decoding_format_dict
|
||||
from uie.extraction.extraction_metrics import get_extract_metrics
|
||||
from uie.extraction.noiser.spot_asoc_noiser import SpotAsocNoiser
|
||||
from uie.extraction.dataset_processer import PrefixGenerator
|
||||
from uie.seq2seq.constrained_seq2seq import (
|
||||
ConstraintSeq2SeqTrainingArguments,
|
||||
ConstraintSeq2SeqTrainer,
|
||||
OriginalConstraintSeq2SeqTrainer,
|
||||
UIEPretrainConstraintSeq2SeqTrainer,
|
||||
UIEFinetuneConstraintSeq2SeqTrainer,
|
||||
MetaPretrainConstraintSeq2SeqTrainer,
|
||||
MetaFinetuneConstraintSeq2SeqTrainer,
|
||||
)
|
||||
from uie.seq2seq.data_collator import (
|
||||
DataCollatorForMetaSeq2Seq,
|
||||
DynamicSSIGenerator,
|
||||
)
|
||||
from uie.seq2seq.features import RecordFeature
|
||||
from uie.seq2seq.model import PromptSeq2SeqTransformer
|
||||
from uie.seq2seq.noise_record import create_noised_record
|
||||
|
||||
import pdb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
from_checkpoint: bool = field(
|
||||
default=False, metadata={"help": "Whether load from checkpoint to continue learning"}
|
||||
)
|
||||
load_config_only: bool = field(
|
||||
default=False, metadata={"help": "Whether load model config only from checkpoint"}
|
||||
)
|
||||
use_prompt_tuning_model: bool = field(
|
||||
default=False, metadata={"help": "Whether use prompt tuning model"}
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
task: str = field(
|
||||
default="summarization",
|
||||
metadata={
|
||||
"help": "The name of the task, should be summarization (or summarization_{dataset} for evaluating "
|
||||
"pegasus) or translation (or translation_{xx}_to_{yy})."
|
||||
},
|
||||
)
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
text_column: Optional[str] = field(
|
||||
default='text',
|
||||
metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
|
||||
)
|
||||
record_column: Optional[str] = field(
|
||||
default='record',
|
||||
metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
|
||||
)
|
||||
train_file: Optional[str] = field(
|
||||
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
|
||||
)
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input evaluation data file to evaluate the metrics (rouge/sacreblue) on "
|
||||
"(a jsonlines or csv file)."
|
||||
},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input test data file to evaluate the metrics (rouge/sacreblue) on "
|
||||
"(a jsonlines or csv file)."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
max_prefix_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum prefix length."
|
||||
},
|
||||
)
|
||||
val_max_target_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||
"during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to model maximum sentence length. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||
"efficient on GPU but very bad for TPU."
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_test_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||
"which is used during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
ignore_pad_token_for_loss: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
||||
},
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
)
|
||||
meta_negative: int = field(
|
||||
default=-1, metadata={"help": "Negative Schema Number in Training."}
|
||||
)
|
||||
ordered_prompt: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to sort the spot prompt and asoc prompt or not."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
if self.val_max_target_length is None:
|
||||
self.val_max_target_length = self.max_target_length
|
||||
|
||||
decoding_format: str = field(
|
||||
default='tree',
|
||||
metadata={"help": "Decoding Format, valid in %s" % decoding_format_dict.keys()}
|
||||
)
|
||||
record_schema: str = field(
|
||||
default=None, metadata={"help": "The input event schema file."}
|
||||
)
|
||||
spot_noise: float = field(
|
||||
default=0., metadata={"help": "The noise rate of null spot."}
|
||||
)
|
||||
asoc_noise: float = field(
|
||||
default=0., metadata={"help": "The noise rate of null asoc."}
|
||||
)
|
||||
meta_positive_rate: float = field(
|
||||
default=1., metadata={"help": "The keep rate of positive spot."}
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ConstraintSeq2SeqTrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
# Detecting last checkpoint.
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
elif last_checkpoint is not None:
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
|
||||
|
||||
logger.info("Options:")
|
||||
logger.info(model_args)
|
||||
logger.info(data_args)
|
||||
logger.info(training_args)
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(training_args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
#
|
||||
# For CSV/JSON files in the summarization task, this script will use the first column for the full texts and the
|
||||
# second column for the summaries (unless you specify column names for this with the `text_column` and
|
||||
# `record_column` arguments).
|
||||
# For translation, only JSON files are supported, with one field named "translation" containing two keys for the
|
||||
# source and target languages (unless you adapt what follows).
|
||||
#
|
||||
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if data_args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
|
||||
else:
|
||||
data_files = {}
|
||||
if data_args.train_file is not None:
|
||||
data_files["train"] = data_args.train_file
|
||||
extension = data_args.train_file.split(".")[-1]
|
||||
if training_args.do_eval and data_args.validation_file is not None:
|
||||
data_files["validation"] = data_args.validation_file
|
||||
extension = data_args.validation_file.split(".")[-1]
|
||||
if training_args.do_predict and data_args.test_file is not None:
|
||||
data_files["test"] = data_args.test_file
|
||||
extension = data_args.test_file.split(".")[-1]
|
||||
logger.info(data_files)
|
||||
|
||||
datasets = load_dataset("uie_json.py", data_files=data_files, block_size=(10<<22))
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
logger.info(datasets)
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
logger.info("Load Config: %s" % model_args.config_name if model_args.config_name else model_args.model_name_or_path)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
config.max_length = data_args.max_target_length
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
to_remove_token_list = list()
|
||||
if tokenizer.bos_token:
|
||||
to_remove_token_list += [tokenizer.bos_token]
|
||||
if tokenizer.eos_token:
|
||||
to_remove_token_list += [tokenizer.eos_token]
|
||||
if tokenizer.pad_token:
|
||||
to_remove_token_list += [tokenizer.pad_token]
|
||||
|
||||
if model_args.use_prompt_tuning_model:
|
||||
MODEL = PromptSeq2SeqTransformer
|
||||
else:
|
||||
MODEL = AutoModelForSeq2SeqLM
|
||||
|
||||
if model_args.load_config_only:
|
||||
model = MODEL.from_config(config)
|
||||
else:
|
||||
model = MODEL.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
mirror='tuna',
|
||||
)
|
||||
|
||||
if training_args.do_train:
|
||||
to_add_special_token = list()
|
||||
for special_token in [constants.type_start, constants.type_end, constants.text_start, constants.span_start, constants.spot_prompt, constants.asoc_prompt]:
|
||||
if special_token not in tokenizer.get_vocab():
|
||||
to_add_special_token += [special_token]
|
||||
|
||||
tokenizer.add_special_tokens(
|
||||
{"additional_special_tokens": tokenizer.special_tokens_map_extended['additional_special_tokens'] + to_add_special_token}
|
||||
)
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
logger.info(tokenizer)
|
||||
|
||||
# Set decoder_start_token_id
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
if data_args.record_schema and os.path.exists(data_args.record_schema):
|
||||
record_schema = RecordSchema.read_from_file(data_args.record_schema)
|
||||
else:
|
||||
record_schema = None
|
||||
|
||||
if data_args.source_prefix is not None:
|
||||
if data_args.source_prefix == 'schema':
|
||||
prefix = PrefixGenerator.get_schema_prefix(schema=record_schema)
|
||||
elif data_args.source_prefix.startswith('meta'):
|
||||
prefix = ""
|
||||
else:
|
||||
prefix = data_args.source_prefix
|
||||
else:
|
||||
prefix = ""
|
||||
logger.info(f"Prefix: {prefix}")
|
||||
logger.info(f"Prefix Length: {len(tokenizer.tokenize(prefix))}")
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize inputs and targets.
|
||||
if training_args.do_train:
|
||||
column_names = datasets["train"].column_names
|
||||
elif training_args.do_eval:
|
||||
column_names = datasets["validation"].column_names
|
||||
elif training_args.do_predict:
|
||||
column_names = datasets["test"].column_names
|
||||
else:
|
||||
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
||||
return
|
||||
|
||||
# To serialize preprocess_function below, each of those four variables needs to be defined (even if we won't use
|
||||
# them all).
|
||||
|
||||
text_column = data_args.text_column
|
||||
record_column = data_args.record_column
|
||||
logger.info('Using src: %s and tgt: %s' % (text_column, record_column))
|
||||
|
||||
# Temporarily set max_target_length for training.
|
||||
max_target_length = data_args.max_target_length
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
||||
logger.error(
|
||||
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
|
||||
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
|
||||
)
|
||||
|
||||
def preprocess_function(examples):
|
||||
inputs = examples[text_column]
|
||||
targets = examples[record_column]
|
||||
inputs = [prefix + inp for inp in inputs]
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
|
||||
model_inputs["text"] = inputs
|
||||
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(_label if _label != tokenizer.pad_token_id else -100) for _label in label] for label in labels["input_ids"]
|
||||
]
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
# set noised record inputs
|
||||
noised_record_list = []
|
||||
for idx, noised_record in enumerate(examples["noised_record"]):
|
||||
if noised_record is None:
|
||||
tokens = examples["tokens"][idx]
|
||||
entity_list = examples["entity"][idx]
|
||||
triple_list = examples["relation"][idx]
|
||||
event_list = examples["event"][idx]
|
||||
|
||||
noised_record = create_noised_record(tokens, entity_list, triple_list, event_list)
|
||||
noised_record_list.append(noised_record)
|
||||
model_inputs["noised_record"] = noised_record_list
|
||||
# model_inputs["noised_record"] = examples["noised_record"]
|
||||
|
||||
# others
|
||||
model_inputs['sample_prompt'] = [False] * len(model_inputs['input_ids'])
|
||||
if data_args.source_prefix is not None and data_args.source_prefix.startswith('meta'):
|
||||
model_inputs['spots'] = examples['spot']
|
||||
model_inputs['asocs'] = examples['asoc']
|
||||
model_inputs['spot_asoc'] = examples['spot_asoc']
|
||||
# sample_prompt=True for Finetune and Pretrain
|
||||
model_inputs['sample_prompt'] = [True] * len(model_inputs['input_ids'])
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_function_eval(examples):
|
||||
model_inputs = preprocess_function(examples)
|
||||
# sample_prompt=False for evaluation
|
||||
model_inputs['sample_prompt'] = [False] * len(model_inputs['input_ids'])
|
||||
return model_inputs
|
||||
|
||||
def postprocess_text(x_str):
|
||||
# Clean `bos` `eos` `pad` for cleaned text
|
||||
for to_remove_token in to_remove_token_list:
|
||||
x_str = x_str.replace(to_remove_token, '')
|
||||
|
||||
return x_str.strip()
|
||||
|
||||
logger.info("Start Data Preprocessing ...")
|
||||
|
||||
if training_args.do_train:
|
||||
train_dataset = datasets["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
train_dataset = train_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
features=RecordFeature,
|
||||
)
|
||||
|
||||
if training_args.do_eval:
|
||||
max_target_length = data_args.val_max_target_length
|
||||
eval_dataset = datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
eval_dataset = eval_dataset.map(
|
||||
preprocess_function_eval,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
features=RecordFeature,
|
||||
)
|
||||
|
||||
if training_args.do_predict:
|
||||
max_target_length = data_args.val_max_target_length
|
||||
test_dataset = datasets["test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
test_dataset = test_dataset.map(
|
||||
preprocess_function_eval,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
features=RecordFeature,
|
||||
)
|
||||
|
||||
logger.info("End Data Preprocessing ...")
|
||||
|
||||
# Data collator
|
||||
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
if data_args.pad_to_max_length:
|
||||
data_collator = default_data_collator
|
||||
elif data_args.source_prefix.startswith('meta'):
|
||||
|
||||
if data_args.spot_noise > 0 or data_args.asoc_noise > 0:
|
||||
if data_args.decoding_format == 'spotasoc':
|
||||
spot_asoc_nosier = SpotAsocNoiser(
|
||||
spot_noise_ratio=data_args.spot_noise,
|
||||
asoc_noise_ratio=data_args.asoc_noise,
|
||||
null_span=constants.null_span,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"decoding_format `spotasoc` is not implemented."
|
||||
)
|
||||
else:
|
||||
spot_asoc_nosier = None
|
||||
|
||||
data_collator = DataCollatorForMetaSeq2Seq(
|
||||
tokenizer,
|
||||
model=model,
|
||||
label_pad_token_id=label_pad_token_id,
|
||||
pad_to_multiple_of=8 if training_args.fp16 else None,
|
||||
max_length=data_args.max_source_length,
|
||||
max_prefix_length=data_args.max_prefix_length,
|
||||
max_target_length=data_args.max_target_length,
|
||||
negative_sampler=DynamicSSIGenerator(
|
||||
tokenizer=tokenizer,
|
||||
schema=record_schema,
|
||||
positive_rate=data_args.meta_positive_rate,
|
||||
negative=data_args.meta_negative,
|
||||
ordered_prompt=data_args.ordered_prompt,
|
||||
),
|
||||
spot_asoc_nosier=spot_asoc_nosier,
|
||||
decoding_format=data_args.decoding_format,
|
||||
)
|
||||
else:
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
model=model,
|
||||
label_pad_token_id=label_pad_token_id,
|
||||
pad_to_multiple_of=8 if training_args.fp16 else None,
|
||||
)
|
||||
|
||||
def compute_metrics(eval_preds):
|
||||
preds, labels = eval_preds
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=False, clean_up_tokenization_spaces=False)
|
||||
if data_args.ignore_pad_token_for_loss:
|
||||
# Replace -100 in the labels as we can't decode them.
|
||||
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
||||
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=False, clean_up_tokenization_spaces=False)
|
||||
|
||||
decoded_preds = [postprocess_text(x) for x in decoded_preds]
|
||||
decoded_labels = [postprocess_text(x) for x in decoded_labels]
|
||||
|
||||
result = get_extract_metrics(
|
||||
pred_lns=decoded_preds,
|
||||
tgt_lns=decoded_labels,
|
||||
label_constraint=record_schema,
|
||||
decoding_format=data_args.decoding_format,
|
||||
)
|
||||
|
||||
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
|
||||
result["gen_len"] = np.mean(prediction_lens)
|
||||
result = {k: round(v, 4) for k, v in result.items()}
|
||||
return result
|
||||
|
||||
# Initialize our Trainer
|
||||
if training_args.trainer_type == "uie_pretrain":
|
||||
TRAINER = UIEPretrainConstraintSeq2SeqTrainer
|
||||
elif training_args.trainer_type == "uie_finetune":
|
||||
TRAINER = UIEFinetuneConstraintSeq2SeqTrainer
|
||||
elif training_args.trainer_type == "meta_pretrain":
|
||||
TRAINER = MetaPretrainConstraintSeq2SeqTrainer
|
||||
elif training_args.trainer_type == "meta_finetune":
|
||||
TRAINER = MetaFinetuneConstraintSeq2SeqTrainer
|
||||
else:
|
||||
TRAINER = OriginalConstraintSeq2SeqTrainer
|
||||
|
||||
trainer = TRAINER(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset if training_args.do_train else None,
|
||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
decoding_type_schema=record_schema,
|
||||
decoding_format=data_args.decoding_format,
|
||||
source_prefix=prefix,
|
||||
task=data_args.task,
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
if model_args.from_checkpoint:
|
||||
if last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
elif os.path.isdir(model_args.model_name_or_path):
|
||||
checkpoint = model_args.model_name_or_path
|
||||
else:
|
||||
checkpoint = None
|
||||
else:
|
||||
checkpoint = None
|
||||
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_train_file, "w") as writer:
|
||||
logger.info("***** Train results *****")
|
||||
for key, value in sorted(train_result.metrics.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
||||
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)
|
||||
results = {k: round(v, 4) for k, v in results.items()}
|
||||
|
||||
eval_results = trainer.predict(
|
||||
eval_dataset,
|
||||
metric_key_prefix="eval",
|
||||
max_length=data_args.val_max_target_length,
|
||||
num_beams=data_args.num_beams,
|
||||
)
|
||||
|
||||
output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in sorted(results.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
eval_preds = tokenizer.batch_decode(
|
||||
eval_results.predictions, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
||||
)
|
||||
eval_preds = [postprocess_text(pred) for pred in eval_preds]
|
||||
output_test_preds_file = os.path.join(training_args.output_dir, "eval_preds_seq2seq.txt")
|
||||
with open(output_test_preds_file, "w") as writer:
|
||||
writer.write("\n".join(eval_preds))
|
||||
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Test ***")
|
||||
|
||||
test_results = trainer.predict(
|
||||
test_dataset,
|
||||
metric_key_prefix="test",
|
||||
max_length=data_args.val_max_target_length,
|
||||
num_beams=data_args.num_beams,
|
||||
)
|
||||
test_metrics = test_results.metrics
|
||||
test_metrics["test_loss"] = round(test_metrics["test_loss"], 4)
|
||||
|
||||
output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_test_result_file, "w") as writer:
|
||||
logger.info("***** Test results *****")
|
||||
for key, value in sorted(test_metrics.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
test_preds = tokenizer.batch_decode(
|
||||
test_results.predictions, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
||||
)
|
||||
test_preds = [postprocess_text(pred) for pred in test_preds]
|
||||
output_test_preds_file = os.path.join(training_args.output_dir, "test_preds_seq2seq.txt")
|
||||
with open(output_test_preds_file, "w") as writer:
|
||||
writer.write("\n".join(test_preds))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@@ -0,0 +1,104 @@
|
|||
#!/usr/bin/env bash
|
||||
# -*- coding:utf-8 -*-
|
||||
export batch_size="16"
|
||||
export model_name=uie-base-en
|
||||
export data_name=absa/14lap
|
||||
export task_name="meta"
|
||||
export decoding_format='spotasoc'
|
||||
|
||||
source scripts/function_code.bash
|
||||
|
||||
for index in $(seq 1 ${run_time}); do
|
||||
|
||||
if [[ ! ${output_dir} ]]
|
||||
then
|
||||
output_dir=${model_folder}_run${index}
|
||||
echo "output_dir is not provided so create it automatically: ${output_dir}"
|
||||
else
|
||||
echo "output_dir is provided: ${output_dir}"
|
||||
fi
|
||||
|
||||
if [[ ${verbose} == true ]]
|
||||
then
|
||||
stdout_file=/dev/stdout
|
||||
stderr_file=/dev/stderr
|
||||
disable_tqdm=False
|
||||
else
|
||||
stdout_file=${output_dir}.log
|
||||
stderr_file=${output_dir}.err
|
||||
disable_tqdm=True
|
||||
fi
|
||||
|
||||
# CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} gdb --args ${run_command} run_seq2seq.py \
|
||||
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} ${run_command} run_seq2seq.py \
|
||||
--do_train ${constraint_decoding} ${fp16} \
|
||||
--trainer_type=${trainer_type} \
|
||||
--load_config_only=False \
|
||||
--use_fast_tokenizer=True \
|
||||
--ddp_find_unused_parameters=False \
|
||||
--predict_with_generate \
|
||||
--evaluation_strategy="no" \
|
||||
--metric_for_best_model eval_overall-F1 \
|
||||
--save_strategy="steps" \
|
||||
--save_steps=10000 \
|
||||
--save_total_limit 9999999 \
|
||||
--load_best_model_at_end=False \
|
||||
--max_source_length="128" \
|
||||
--max_prefix_length="-1" \
|
||||
--max_target_length="128" \
|
||||
--num_train_epochs=${epoch} \
|
||||
--task=${task_name} \
|
||||
--train_file=${data_folder}/train.json \
|
||||
--validation_file=${data_folder}/val.json \
|
||||
--test_file=${data_folder}/test.json \
|
||||
--record_schema=${data_folder}/record.schema \
|
||||
--per_device_train_batch_size=${batch_size} \
|
||||
--per_device_eval_batch_size=$((batch_size * 4)) \
|
||||
--output_dir=${output_dir} \
|
||||
--from_checkpoint=True \
|
||||
--logging_dir=${output_dir}_log \
|
||||
--logging_strategy="steps" \
|
||||
--logging_first_step=True \
|
||||
--logging_steps=100 \
|
||||
--model_name_or_path=${model_name} \
|
||||
--learning_rate=${lr} \
|
||||
--source_prefix="${task_name}: " \
|
||||
--lr_scheduler_type=${lr_scheduler} \
|
||||
--label_smoothing_factor=${label_smoothing} \
|
||||
--eval_steps ${eval_steps} \
|
||||
--decoding_format ${decoding_format} \
|
||||
--warmup_ratio ${warmup_ratio} \
|
||||
--preprocessing_num_workers=32 \
|
||||
--dataloader_num_workers=32 \
|
||||
--meta_negative=10 \
|
||||
--meta_positive_rate=${positive} \
|
||||
--skip_memory_metrics \
|
||||
--no_remove_unused_columns \
|
||||
--ordered_prompt=${ordered_prompt} \
|
||||
--save_better_checkpoint=False \
|
||||
--start_eval_step=${start_eval_step:-"0"} \
|
||||
--spot_noise=${spot_noise} \
|
||||
--asoc_noise=${asoc_noise} \
|
||||
--seed=${seed}${index} --disable_tqdm=${disable_tqdm} >${stdout_file} 2>${stderr_file}
|
||||
echo "exit code:" $?
|
||||
|
||||
# --max_source_length=${max_source_length:-"128"} \
|
||||
# --max_prefix_length=${max_prefix_length:-"-1"} \
|
||||
# --max_target_length=${max_target_length:-"128"} \
|
||||
# --save_strategy=${evaluation_strategy} \
|
||||
# --save_total_limit 1 \
|
||||
# --load_best_model_at_end \
|
||||
|
||||
if [[ ${verbose} != true ]]
|
||||
then
|
||||
tail -n 200 ${stderr_file}
|
||||
fi
|
||||
|
||||
# echo "Map Config" ${map_config}
|
||||
# python3 scripts/sel2record.py -p ${output_dir} -g ${data_folder} -v -d ${decoding_format} -c ${map_config}
|
||||
# python3 scripts/eval_extraction.py -p ${output_dir} -g ${data_folder} -w -m ${eval_match_mode:-"normal"}
|
||||
|
||||
# delete all optimizer.pt for saving disk
|
||||
find ${output_dir}/ | grep -P "optimizer.pt" | xargs rm -rf
|
||||
|
||||
done
|
|
@@ -0,0 +1,107 @@
|
|||
#!/usr/bin/env bash
|
||||
# -*- coding:utf-8 -*-
|
||||
export batch_size="16"
|
||||
export model_name=uie-base-en
|
||||
export data_name=absa/14lap
|
||||
export task_name="meta"
|
||||
export decoding_format='spotasoc'
|
||||
|
||||
source scripts/function_code.bash
|
||||
|
||||
for index in $(seq 1 ${run_time}); do
|
||||
|
||||
if [[ ! ${output_dir} ]]
|
||||
then
|
        output_dir=${model_folder}_run${index}
        echo "output_dir is not provided so create it automatically: ${output_dir}"
    else
        echo "output_dir is provided: ${output_dir}"
    fi

    # output_dir=${model_folder}_run${index}

    if [[ ${verbose} == true ]]
    then
        stdout_file=/dev/stdout
        stderr_file=/dev/stderr
        disable_tqdm=False
    else
        stdout_file=${output_dir}.log
        stderr_file=${output_dir}.err
        disable_tqdm=True
    fi

    # CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} gdb --args ${run_command} run_seq2seq.py \
    CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} ${run_command} run_seq2seq.py \
        --do_train --do_eval --do_predict ${constraint_decoding} ${fp16} \
        --use_prompt_tuning_model=${use_prompt_tuning_model} \
        --trainer_type=${trainer_type} \
        --load_config_only=False \
        --use_fast_tokenizer=True \
        --ddp_find_unused_parameters=False \
        --predict_with_generate \
        --evaluation_strategy=${evaluation_strategy} \
        --save_strategy=${evaluation_strategy} \
        --metric_for_best_model eval_overall-F1 \
        --save_total_limit 1 \
        --load_best_model_at_end \
        --max_source_length=${max_source_length:-"256"} \
        --max_prefix_length=${max_prefix_length:-"-1"} \
        --max_target_length=${max_target_length:-"192"} \
        --num_train_epochs=${epoch} \
        --task=${task_name} \
        --train_file=${data_folder}/train.json \
        --validation_file=${data_folder}/val.json \
        --test_file=${data_folder}/test.json \
        --record_schema=${data_folder}/record.schema \
        --per_device_train_batch_size=${batch_size} \
        --per_device_eval_batch_size=$((batch_size * 4)) \
        --output_dir=${output_dir} \
        --logging_dir=${output_dir}_log \
        --logging_strategy="steps" \
        --logging_first_step=True \
        --logging_steps=100 \
        --model_name_or_path=${model_name} \
        --learning_rate=${lr} \
        --source_prefix="${task_name}: " \
        --lr_scheduler_type=${lr_scheduler} \
        --label_smoothing_factor=${label_smoothing} \
        --eval_steps ${eval_steps} \
        --decoding_format ${decoding_format} \
        --warmup_ratio ${warmup_ratio} \
        --preprocessing_num_workers=32 \
        --dataloader_num_workers=32 \
        --meta_negative=${negative} \
        --meta_positive_rate=${positive} \
        --skip_memory_metrics \
        --no_remove_unused_columns \
        --ordered_prompt=${ordered_prompt} \
        --save_better_checkpoint=False \
        --start_eval_step=${start_eval_step:-"0"} \
        --spot_noise=${spot_noise} \
        --asoc_noise=${asoc_noise} \
        --seed=${seed}${index} --disable_tqdm=${disable_tqdm} >${stdout_file} 2>${stderr_file}
    echo "exit code:" $?

    # --save_strategy=${evaluation_strategy} \
    # --save_total_limit 1 \
    # --load_best_model_at_end \

    # --save_strategy="steps" \
    # --save_steps=5000 \
    # --save_total_limit 9999999 \
    # --load_best_model_at_end=True \

    if [[ ${verbose} != true ]]
    then
        tail -n 200 ${stderr_file}
    fi

    echo "Map Config" ${map_config}
    python3 scripts/sel2record.py -p ${output_dir} -g ${data_folder} -v -d ${decoding_format} -c ${map_config}
    python3 scripts/eval_extraction.py -p ${output_dir} -g ${data_folder} -w -m ${eval_match_mode:-"normal"}

    # delete all optimizer.pt for saving disk
    # find ${output_dir}/ | grep -P "optimizer.pt" | xargs rm -rf

done
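
The decoding and scoring at the end of each run can also be replayed on an existing run directory without retraining. A minimal sketch, assuming a finished run folder and its gold data folder (both paths below are placeholders, not taken from this diff):

``` bash
# Re-run only SEL decoding and scoring on an existing run directory.
# output_dir and data_folder are assumed example paths; substitute your own.
output_dir=output/meta-finetuned-model_run1
data_folder=data/text2spotasoc/relation/conll04

python3 scripts/sel2record.py -p ${output_dir} -g ${data_folder} -v -d spotasoc -c config/offset_map/closest_offset_en.yaml
python3 scripts/eval_extraction.py -p ${output_dir} -g ${data_folder} -w -m normal
```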

@@ -0,0 +1,96 @@

#!/usr/bin/env bash
# -*- coding:utf-8 -*-
export batch_size="16"
export model_name=uie-base-en
export data_name=absa/14lap
export task_name="meta"
export decoding_format='spotasoc'

source scripts/function_code.bash

for index in $(seq 1 ${run_time}); do
    output_dir=${model_folder}_run${index}

    if [[ ${verbose} == true ]]
    then
        stdout_file=/dev/stdout
        stderr_file=/dev/stderr
        disable_tqdm=False
    else
        stdout_file=${output_dir}.log
        stderr_file=${output_dir}.err
        disable_tqdm=True
    fi

    ratio_data_folder=${data_folder}_ratio/seed${index}

    for ratio in $(ls ${ratio_data_folder})
    do
        run_data_folder=${ratio_data_folder}/${ratio}
        run_output_folder=${output_dir}_${ratio}

        if [[ ${max_prefix_length} == 0 ]]
        then
            run_output_folder=${run_output_folder}_noprefix
        fi

        eval_steps=$(python scripts/get_eval_batch_num.py ${run_data_folder}/train.json ${batch_size} 20)
        echo Eval each ${eval_steps} batch

        CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} ${run_command} run_seq2seq.py \
            --do_train --do_eval --do_predict ${constraint_decoding} ${fp16} \
            --trainer_type=${trainer_type} \
            --load_config_only=False \
            --use_fast_tokenizer=True \
            --ddp_find_unused_parameters=False \
            --predict_with_generate \
            --evaluation_strategy=steps \
            --save_strategy=steps \
            --load_best_model_at_end \
            --metric_for_best_model eval_overall-F1 \
            --save_total_limit 1 \
            --max_source_length=${max_source_length:-"256"} \
            --max_prefix_length=${max_prefix_length:-"-1"} \
            --max_target_length=${max_target_length:-"192"} \
            --num_train_epochs=${epoch} \
            --task=${task_name} \
            --train_file=${run_data_folder}/train.json \
            --validation_file=${run_data_folder}/val.json \
            --test_file=${run_data_folder}/test.json \
            --record_schema=${run_data_folder}/record.schema \
            --per_device_train_batch_size=${batch_size} \
            --per_device_eval_batch_size=$((batch_size * 4)) \
            --output_dir=${run_output_folder} \
            --logging_dir=${run_output_folder}_log \
            --model_name_or_path=${model_name} \
            --learning_rate=${lr} \
            --source_prefix="${task_name}: " \
            --lr_scheduler_type=${lr_scheduler} \
            --label_smoothing_factor=${label_smoothing} \
            --eval_steps ${eval_steps} \
            --decoding_format ${decoding_format} \
            --warmup_ratio ${warmup_ratio} \
            --preprocessing_num_workers=4 \
            --dataloader_num_workers=0 \
            --meta_negative=${negative} \
            --meta_positive_rate=${positive} \
            --skip_memory_metrics \
            --no_remove_unused_columns \
            --ordered_prompt=${ordered_prompt} \
            --save_better_checkpoint=True \
            --spot_noise=${spot_noise} \
            --asoc_noise=${asoc_noise} \
            --seed=${seed} --disable_tqdm=${disable_tqdm} >${stdout_file} 2>${stderr_file}

        echo "Map Config" ${map_config}
        python3 scripts/sel2record.py -p ${run_output_folder} -g ${run_data_folder} -v -d ${decoding_format} -c ${map_config}
        python3 scripts/eval_extraction.py -p ${run_output_folder} -g ${run_data_folder} -w -m ${eval_match_mode:-"normal"}

        # delete all pytorch_model.bin of checkpoints in low-resource exps for saving disk
        # find ${run_output_folder}/ | grep -P "checkpoint-\d+/pytorch_model.bin" | xargs rm -rf
        # delete all optimizer.pt in low-resource exps for saving disk
        # find ${run_output_folder}/ | grep -P "optimizer.pt" | xargs rm -rf

    done

done
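
The ratio loop above assumes the low-resource splits were generated beforehand, one subfolder per sampling ratio under each seed; every split is trained separately and written to its own `${output_dir}_${ratio}` folder. A sketch of the assumed layout (subfolder names are illustrative, the actual ones come from the preprocessing step):

``` bash
# Assumed low-resource data layout consumed by the ratio loop:
#   ${data_folder}_ratio/
#     seed1/<ratio>/train.json  val.json  test.json  record.schema
#     seed2/<ratio>/...
```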

@@ -0,0 +1,99 @@

#!/usr/bin/env bash
# -*- coding:utf-8 -*-
export batch_size="16"
export model_name=uie-base-en
export data_name=absa/14lap
export task_name="meta"
export decoding_format='spotasoc'

source scripts/function_code.bash

for index in $(seq 1 ${run_time}); do
    output_dir=${model_folder}_run${index}

    if [[ ${verbose} == true ]]
    then
        stdout_file=/dev/stdout
        stderr_file=/dev/stderr
        disable_tqdm=False
    else
        stdout_file=${output_dir}.log
        stderr_file=${output_dir}.err
        disable_tqdm=True
    fi

    shot_data_folder=${data_folder}_shot/seed${index}

    for shot in $(ls ${shot_data_folder})
    do

        run_data_folder=${shot_data_folder}/${shot}
        run_output_folder=${output_dir}_${shot}

        if [[ ${max_prefix_length} == 0 ]]
        then
            run_output_folder=${run_output_folder}_noprefix
        fi

        echo ${run_data_folder}

        eval_steps=$(python scripts/get_eval_batch_num.py ${run_data_folder}/train.json ${batch_size} 20)
        echo Eval each ${eval_steps} batch

        CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} ${run_command} run_seq2seq.py \
            --do_train --do_eval --do_predict ${constraint_decoding} ${fp16} \
            --trainer_type=${trainer_type} \
            --load_config_only=False \
            --use_fast_tokenizer=True \
            --ddp_find_unused_parameters=False \
            --predict_with_generate \
            --evaluation_strategy=steps \
            --save_strategy=steps \
            --load_best_model_at_end \
            --metric_for_best_model eval_overall-F1 \
            --save_total_limit 1 \
            --max_source_length=${max_source_length:-"256"} \
            --max_prefix_length=${max_prefix_length:-"-1"} \
            --max_target_length=${max_target_length:-"192"} \
            --num_train_epochs=${epoch} \
            --task=${task_name} \
            --train_file=${run_data_folder}/train.json \
            --validation_file=${run_data_folder}/val.json \
            --test_file=${run_data_folder}/test.json \
            --record_schema=${run_data_folder}/record.schema \
            --per_device_train_batch_size=${batch_size} \
            --per_device_eval_batch_size=$((batch_size * 4)) \
            --output_dir=${run_output_folder} \
            --logging_dir=${run_output_folder}_log \
            --model_name_or_path=${model_name} \
            --learning_rate=${lr} \
            --source_prefix="${task_name}: " \
            --lr_scheduler_type=${lr_scheduler} \
            --label_smoothing_factor=${label_smoothing} \
            --eval_steps ${eval_steps} \
            --decoding_format ${decoding_format} \
            --warmup_ratio ${warmup_ratio} \
            --preprocessing_num_workers=4 \
            --dataloader_num_workers=0 \
            --meta_negative=${negative} \
            --meta_positive_rate=${positive} \
            --skip_memory_metrics \
            --no_remove_unused_columns \
            --ordered_prompt=${ordered_prompt} \
            --save_better_checkpoint=True \
            --spot_noise=${spot_noise} \
            --asoc_noise=${asoc_noise} \
            --seed=${seed} --disable_tqdm=${disable_tqdm} >${stdout_file} 2>${stderr_file}

        echo "Map Config" ${map_config}
        python3 scripts/sel2record.py -p ${run_output_folder} -g ${run_data_folder} -v -d ${decoding_format} -c ${map_config}
        python3 scripts/eval_extraction.py -p ${run_output_folder} -g ${run_data_folder} -w -m ${eval_match_mode:-"normal"}

        # delete all pytorch_model.bin of checkpoints in low-resource exps for saving disk
        # find ${run_output_folder}/ | grep -P "checkpoint-\d+/pytorch_model.bin" | xargs rm -rf
        # delete all optimizer.pt in low-resource exps for saving disk
        # find ${run_output_folder}/ | grep -P "optimizer.pt" | xargs rm -rf

    done

done

@@ -0,0 +1,25 @@

#!/usr/bin/env bash
# -*- coding:utf-8 -*-

# Check Offset Mapping Performance
# Verifies the accuracy of different SEL2Record offset-mapping (back-annotation) strategies
# bash scripts/check_offset_map_gold_as_pred.bash data/text2spotasocname/absa/14lap config/offset_map/closest_offset_en.yaml spotasocname

folder_name=$1
config_name=$2
parser_format=$3

# Use the gold linearized records as if they were predictions, so a perfect
# offset-mapping strategy should recover the gold annotations exactly.
cat ${folder_name}/val.json | python -c "import json, sys
for line in sys.stdin:
    print(json.loads(line.strip())['record'])
" > ${folder_name}/eval_preds_seq2seq.txt

python scripts/sel2record.py \
    -c ${config_name} \
    -g ${folder_name} \
    -p ${folder_name} \
    -d ${parser_format}

python scripts/eval_extraction.py \
    -g ${folder_name} \
    -p ${folder_name} -w

@@ -0,0 +1,135 @@

#!/usr/bin/env python
# -*- coding:utf-8 -*-
import argparse
import json
import os
import sys
import numpy as np
from pprint import pprint
from uie.extraction.scorer import EntityScorer, RelationScorer, EventScorer


def read_file(file_name):
    return [line for line in open(file_name).readlines()]


def write_to_file(result, output_filename, prefix=None):
    with open(output_filename, 'w') as output:
        for key, value in result.items():
            if prefix:
                key = '%s_%s' % (prefix, key)
            output.write("%s=%s\n" % (key, value))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', dest='gold_folder', help="Golden Dataset folder")
    parser.add_argument('-p', dest='pred_folder', nargs='+', help="Predicted model folder")
    parser.add_argument('-v', dest='verbose', action='store_true', help='Show more information during running')
    parser.add_argument('-w', dest='write_to_file', action='store_true', help="Write evaluation results to predicted folder")
    parser.add_argument('-m', dest='match_mode', default='normal', choices=['set', 'normal', 'multimatch'])
    parser.add_argument('-case', dest='case', action='store_true', help='Show case study')
    options = parser.parse_args()

    data_dict = {
        'eval': ['eval_preds_record.txt', 'val.json'],
        'test': ['test_preds_record.txt', 'test.json'],
    }

    task_dict = {
        'entity': EntityScorer,
        'relation': RelationScorer,
        'event': EventScorer,
    }

    result_list = {'eval': list(), 'test': list()}
    for pred_folder in options.pred_folder:
        gold_folder = options.gold_folder

        for data_key, (generation, gold_file) in data_dict.items():

            gold_filename = os.path.join(gold_folder, gold_file)
            pred_filename = os.path.join(pred_folder, generation)

            if not os.path.exists(pred_filename):
                sys.stderr.write("%s not found.\n" % pred_filename)
                continue

            print("pred:", pred_filename)
            print("gold:", gold_filename)

            if options.case:
                for pred_line, gold_line in zip(read_file(pred_filename), read_file(gold_filename)):
                    gold_instance = json.loads(gold_line)
                    pred_instance = json.loads(pred_line)
                    print('=========================')
                    print(gold_instance['text'])
                    for task in task_dict:
                        scorer = task_dict[task]
                        gold = scorer.load_gold_list([gold_instance[task]])[0]
                        pred = scorer.load_pred_list([pred_instance[task]])[0]
                        min_length = max(
                            len(gold['string']),
                            len(pred['string']),
                            len(gold.get('string_trigger', [])),
                            len(pred.get('string_trigger', [])),
                            len(gold.get('string_role', [])),
                            len(pred.get('string_role', [])),
                        )
                        if min_length == 0:
                            continue
                        if task == 'entity':
                            print("Entity Gold:", sorted(gold['string']))
                            print("Entity Pred:", sorted(pred['string']))
                        if task == 'relation':
                            print("Relation Gold:", sorted(gold['string']))
                            print("Relation Pred:", sorted(pred['string']))
                        if task == 'event':
                            print("Event Gold Trigger:", sorted(gold['string_trigger']))
                            print("Event Pred Trigger:", sorted(pred['string_trigger']))
                            print("Event Gold Role :", sorted(gold['string_role']))
                            print("Event Pred Role :", sorted(pred['string_role']))

            results = dict()
            for task in task_dict:
                if task not in json.loads(read_file(pred_filename)[0]):
                    continue
                scorer = task_dict[task]
                gold_list = [json.loads(line)[task] for line in read_file(gold_filename)]
                pred_list = [json.loads(line)[task] for line in read_file(pred_filename)]

                assert len(pred_list) == len(gold_list)
                gold_instance_list = scorer.load_gold_list(gold_list)
                pred_instance_list = scorer.load_pred_list(pred_list)
                assert len(pred_instance_list) == len(gold_instance_list)
                sub_results = scorer.eval_instance_list(
                    gold_instance_list=gold_instance_list,
                    pred_instance_list=pred_instance_list,
                    verbose=options.verbose,
                    match_mode=options.match_mode,
                )
                results.update(sub_results)

            pprint(results)
            result_list[data_key] += [results]

            if options.write_to_file:
                output_filename = "%s/%s_results.txt" % (pred_folder, data_key)
                write_to_file(
                    result=results,
                    output_filename=output_filename,
                    prefix=data_key,
                )

    print("===========> AVG <===========")

    for data_key in data_dict:
        if len(result_list[data_key]) < 1:
            continue
        for key in result_list[data_key][0]:
            ave = np.mean([result[key] for result in result_list[data_key]])
            print(data_key, key, ave)


if __name__ == "__main__":
    main()
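
For reference, a minimal standalone invocation of this scorer (the folder paths are placeholders): it compares `eval_preds_record.txt` / `test_preds_record.txt` in each predicted folder against `val.json` / `test.json` in the gold folder, averages metrics over the predicted folders, and with `-w` writes `eval_results.txt` / `test_results.txt` back into each predicted folder.

``` bash
# Score one or more run folders against a gold dataset folder (paths are assumed examples).
python3 scripts/eval_extraction.py \
    -g data/text2spotasoc/relation/conll04 \
    -p output/run1 output/run2 \
    -w -m normal
```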
Some files were not shown because too many files have changed in this diff.